ModelTemplate (Python)
The ModelTemplate Cython class wraps the C++ model_template.
User-defined GNN model classes must subclass both ModelTemplate and
torch.nn.Module and override forward(data: graph_t*).
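The subclassing rule above can be sketched as follows. This is a minimal illustration, not the library's code: so that it runs without PyTorch or the compiled extension, both base classes are replaced with hypothetical pure-Python stand-ins (ModelTemplate and Module below), and MyGNN is an invented example name. The base order mirrors the sentence above; calling the instance dispatches to forward(), as torch.nn.Module does.

```python
# Minimal sketch of the subclassing pattern. ASSUMPTION: ModelTemplate and
# Module below are hypothetical pure-Python stand-ins for the compiled Cython
# class and for torch.nn.Module, so the sketch runs without either installed.


class ModelTemplate:
    """Hypothetical stand-in for the Cython ModelTemplate wrapper."""

    def forward(self, data):
        # Subclasses are expected to replace this method.
        raise NotImplementedError("subclasses must override forward()")


class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __call__(self, *args, **kwargs):
        # Like torch.nn.Module, calling the instance dispatches to forward().
        return self.forward(*args, **kwargs)


class MyGNN(ModelTemplate, Module):
    """User model that inherits from both bases and overrides forward()."""

    def __init__(self):
        super().__init__()  # runs base-class initialisers along the MRO
        self.calls = 0

    def forward(self, data):
        # A real model would read tensors from `data` and store predictions.
        self.calls += 1
        return data


model = MyGNN()
result = model("graph")  # dispatched via Module.__call__ to MyGNN.forward
```

Because MyGNN defines its own forward(), it shadows the stand-in base version that raises NotImplementedError; forgetting the override would surface as that error on the first call.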
Properties
| Property | Type | Description |
|---|---|---|
| | | Output graph-level feature map (feature name → output key). |
| | | Output node-level feature map. |
| | | Output edge-level feature map. |
| | | Input graph-level feature names expected by the model. |
| | | Input node-level feature names expected by the model. |
| | | Input edge-level feature names expected by the model. |
| | | Target compute device (e.g. |
| | | Directory where model checkpoints are written/read. |
| | | File stem used when saving model weights. |
| | | ROOT tree name used when saving predictions. |
| | | Model name string. |
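The input and output feature properties above come in graph/node/edge triples. One way to use them, sketched here under stated assumptions, is a fail-fast check that every declared input name is actually provided by the data before training starts. The field names i_graph through o_edge, the FeatureSpec class, and the missing_inputs helper are all hypothetical, since the real property names are not given in the table; the feature names pt, eta, phi and dr are illustrative only.

```python
# Sketch: grouping the model's declared feature names and checking them
# against what a dataset provides. ASSUMPTION: the field names i_graph, i_node,
# i_edge, o_graph, o_node, o_edge are hypothetical; the real property names are
# not shown in the table above.
from dataclasses import dataclass, field


@dataclass
class FeatureSpec:
    # Input feature names the model expects, per level.
    i_graph: list = field(default_factory=list)
    i_node: list = field(default_factory=list)
    i_edge: list = field(default_factory=list)
    # Output maps (feature name -> output key), per level.
    o_graph: dict = field(default_factory=dict)
    o_node: dict = field(default_factory=dict)
    o_edge: dict = field(default_factory=dict)

    def missing_inputs(self, available):
        """Return {level: [names]} for declared inputs absent from `available`.

        `available` maps "graph", "node" or "edge" to a set of feature names.
        Levels with nothing missing are omitted from the result.
        """
        missing = {}
        for level in ("graph", "node", "edge"):
            have = available.get(level, set())
            absent = [n for n in getattr(self, "i_" + level) if n not in have]
            if absent:
                missing[level] = absent
        return missing


spec = FeatureSpec(i_node=["pt", "eta", "phi"], i_edge=["dr"],
                   o_edge={"top_edge": "edge_score"})  # illustrative names
print(spec.missing_inputs({"node": {"pt", "eta"}, "edge": {"dr"}}))
# {'node': ['phi']}
```

Using default_factory gives each instance its own list or dict, so two specs never share mutable defaults.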
C++ Interface (called from forward)
These methods are available on the graph_t* pointer passed to forward:
| Method | Description |
|---|---|
| | Retrieve a graph-level data tensor by name. |
| | Retrieve a node-level data tensor by name. |
| | Retrieve an edge-level data tensor by name. |
| | Retrieve a graph-level truth tensor by name. |
| | Retrieve a node-level truth tensor by name. |
| | Retrieve an edge-level truth tensor by name. |
| | Retrieve the |
| | Store a graph-level prediction tensor. |
| | Store a node-level prediction tensor. |
| | Store an edge-level prediction tensor. |
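To unit-test a forward() implementation without the C++ backend, one option is a small in-memory stand-in for graph_t* that mirrors the data, truth and prediction accessors listed above. This is a sketch only: MockGraph and every method name in it (get_data_node, set_pred_edge, and so on) are hypothetical, because the real method names are not given in the table, and the truncated "Retrieve the" row is left out. Plain Python lists stand in for tensors.

```python
# Sketch: an in-memory mock of the graph accessor interface.
# ASSUMPTION: MockGraph and all method names below are hypothetical; they
# mirror the table's descriptions, not the real C++ method names.


class MockGraph:
    """Hypothetical stand-in for the graph_t* passed to forward()."""

    LEVELS = ("graph", "node", "edge")

    def __init__(self, data, truth):
        # data / truth: {level: {feature name: value}}
        self._data = data
        self._truth = truth
        self.preds = {level: {} for level in self.LEVELS}

    @staticmethod
    def _lookup(store, level, name):
        # Re-raise with a message naming the level and feature.
        try:
            return store[level][name]
        except KeyError:
            raise KeyError(f"no {level}-level feature named {name!r}") from None

    # Data tensors, looked up by name.
    def get_data_graph(self, name):
        return self._lookup(self._data, "graph", name)

    def get_data_node(self, name):
        return self._lookup(self._data, "node", name)

    def get_data_edge(self, name):
        return self._lookup(self._data, "edge", name)

    # Truth tensors, looked up by name.
    def get_truth_graph(self, name):
        return self._lookup(self._truth, "graph", name)

    def get_truth_node(self, name):
        return self._lookup(self._truth, "node", name)

    def get_truth_edge(self, name):
        return self._lookup(self._truth, "edge", name)

    # Prediction tensors, stored by name.
    def set_pred_graph(self, name, value):
        self.preds["graph"][name] = value

    def set_pred_node(self, name, value):
        self.preds["node"][name] = value

    def set_pred_edge(self, name, value):
        self.preds["edge"][name] = value


def forward(data):
    """Toy forward(), written as a free function: score each node as 2 * pt."""
    pt = data.get_data_node("pt")
    data.set_pred_node("score", [2 * x for x in pt])


g = MockGraph(data={"node": {"pt": [1.0, 3.5]}},
              truth={"node": {"is_top": [0, 1]}})
forward(g)
print(g.preds["node"]["score"])  # [2.0, 7.0]
```

Keeping predictions in a plain dict per level makes it easy to assert on exactly what forward() stored, level by level.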