Model Template
model_template is the base class for all user-defined GNN models. It
inherits from both notification (for logging) and tools (for utilities).
Subclasses must override forward(graph_t* data) and call
register_module to add PyTorch Sequential networks.
Class: model_template
Header: <templates/model_template.h>
Inheritance: notification, tools
Output Feature Maps (cproperty)
These properties map feature-name → loss-function-name and are set from the
Python OptimizerConfig via the Cython wrapper.
| Property | Value Type | Description |
|---|---|---|
|  |  | Map of graph-level output feature names to loss-function names. |
|  |  | Map of node-level output feature names to loss-function names. |
|  |  | Map of edge-level output feature names to loss-function names. |
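The feature-name → loss-function-name mapping can be pictured with a plain `std::map`. This is only a minimal sketch: `feature_loss_map`, `loss_for`, `example_node_outputs`, and the loss name `crossentropy` are hypothetical illustrations, not part of the framework's API, and the real cproperty type may differ.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for one output feature map:
// output feature name -> loss-function name.
using feature_loss_map = std::map<std::string, std::string>;

// Returns the loss-function name configured for `feature`,
// or an empty string when the feature has no entry.
inline std::string loss_for(const feature_loss_map& m, const std::string& feature) {
    assert(!feature.empty());  // an empty feature name is treated as a caller bug in this sketch
    auto it = m.find(feature);
    return it == m.end() ? std::string() : it->second;
}

// Builds an example node-level map; the names are illustrative only.
inline feature_loss_map example_node_outputs() {
    feature_loss_map m;
    m["node_cls"] = "crossentropy";
    return m;
}
```

Calling `loss_for(example_node_outputs(), "node_cls")` returns the configured loss name, while an unknown feature yields an empty string.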
Input Feature Lists (cproperty)
| Property | Value Type | Description |
|---|---|---|
|  |  | Names of graph-level data features to fetch from |
|  |  | Names of node-level data features to fetch from |
|  |  | Names of edge-level data features to fetch from |
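Because input features are requested by name, a model can only run if every requested name is actually available. The sketch below shows one way such a check might look using standard containers only. `missing_features` is a hypothetical helper, not a framework function, and the feature names are made up.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Returns the requested feature names that are absent from `available`,
// preserving the order in which they were requested.
inline std::vector<std::string> missing_features(
        const std::vector<std::string>& requested,
        const std::set<std::string>& available) {
    std::vector<std::string> missing;
    for (const std::string& name : requested) {
        assert(!name.empty());  // empty names are treated as a caller bug in this sketch
        if (available.count(name) == 0) {
            missing.push_back(name);
        }
    }
    return missing;
}
```

An empty result means every requested feature was found. A non-empty result lists exactly the names to report to the user.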
Device / Identity Properties (cproperty)
| Property | Type | Description |
|---|---|---|
|  |  | Device string, e.g. |
|  |  | Numeric CUDA device index ( |
|  |  | Model name, used as the HDF5 weight-file subdirectory. |
Public Fields
| Field | Type | Description |
|---|---|---|
|  |  | Current k-fold iteration index (set by |
|  |  | Current epoch (set by |
|  |  | Whether the current batch contains Monte Carlo data. Default |
|  |  | Use Python-pickle checkpointing instead of HDF5. Default |
|  |  | Set to |
|  |  | Enable PyTorch anomaly detection. Default |
|  |  | Retain computation graph for multiple backwards passes. Default |
|  |  | Directory for weight checkpoints. Default |
|  |  | Name of the HDF5 event-weight field. Default |
|  |  | ROOT tree name used for weight look-up. Default |
Virtual Methods (Override in Subclass)
| Signature | Description |
|---|---|
| forward(graph_t* data) | Primary override. Implement the GNN forward pass here. Call |
|  | Override to return a heap-allocated copy of the model. |
|  | Override to switch sub-modules between train/eval modes. |
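The "heap-allocated copy" override follows the common virtual-clone pattern. The sketch below illustrates that pattern with a torch-free stand-in. `toy_model_base`, `toy_my_model`, `hidden_dim`, and `duplicate` are hypothetical names, and the real base-class signature may differ (for example in constness or return type).

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical, torch-free stand-in for the base class; only cloning is modelled.
struct toy_model_base {
    std::string name = "base";
    virtual ~toy_model_base() = default;
    // Returns a heap-allocated copy; the caller takes ownership.
    virtual toy_model_base* clone() const { return new toy_model_base(*this); }
};

// Hypothetical user model with one illustrative hyper-parameter.
struct toy_my_model : toy_model_base {
    int hidden_dim = 64;
    toy_my_model() { name = "my_model"; }
    // Copies the derived object so that derived state is not sliced away.
    toy_model_base* clone() const override { return new toy_my_model(*this); }
};

// Wraps the raw pointer so the copy is released automatically.
inline std::unique_ptr<toy_model_base> duplicate(const toy_model_base& m) {
    toy_model_base* raw = m.clone();
    assert(raw != nullptr);
    return std::unique_ptr<toy_model_base>(raw);
}
```

Without the override in the subclass, cloning through a base pointer would produce a base-class object and lose `hidden_dim`. That is the reason the override is needed at all.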
Framework Methods
| Signature | Description |
|---|---|
|  | Registers a |
|  | Registers a |
|  | Stores the graph-level prediction tensor t under key name. |
|  | Stores the node-level prediction tensor t under key name. |
|  | Stores the edge-level prediction tensor t under key name. |
|  | Stores an auxiliary output tensor (not matched to a truth feature). |
|  | Computes and returns the loss for output feature name of type ( |
|  | Switches all registered modules to eval/train mode. |
|  | Saves the current model state (weights) to the checkpoint directory. |
|  | Restores the model state from the checkpoint directory. Returns |
|  | Verifies that all requested input features exist in the |
|  | Sets the optimizer type by name ( |
|  | Builds the PyTorch optimizer using the supplied parameters. |
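Registering modules is what lets the framework switch all of them between train and eval mode in one call. The sketch below shows that idea in a torch-free form. `toy_module`, `toy_registry`, and `set_training` are hypothetical stand-ins and do not reflect the framework's actual types or signatures.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a network module with a train/eval flag.
struct toy_module {
    std::string label;
    bool training = true;
};

// Hypothetical registry: keeps non-owning pointers to registered modules.
class toy_registry {
public:
    void register_module(toy_module* m) {
        assert(m != nullptr);  // registering a null module is a caller bug
        modules_.push_back(m);
    }
    // Applies the requested mode to every registered module.
    void set_training(bool on) {
        for (toy_module* m : modules_) {
            m->training = on;
        }
    }
    std::size_t size() const { return modules_.size(); }
private:
    std::vector<toy_module*> modules_;
};
```

A module that is never registered is not reached by `set_training` and keeps its old mode. The same reasoning suggests why a subclass should register each of its networks.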
Example forward Implementation:
void MyModel::forward(graph_t* data) {
    // Get input tensors
    torch::Tensor* pt = data->get_data_node("pt", this);
    torch::Tensor* eta = data->get_data_node("eta", this);
    torch::Tensor* ei = data->get_edge_index(this);

    // Run GNN layers
    torch::Tensor x = torch::cat({*pt, *eta}, 1);
    x = this->gnn_layer->forward(x);

    // Store output predictions
    this->prediction_node_feature("node_cls", x);
}
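The last step of `forward` stores a prediction under a key so that a loss can later be computed for that output feature. The sketch below mimics that flow without PyTorch. `toy_tensor`, `toy_prediction_store`, and `mse` are hypothetical, the method name merely mirrors `prediction_node_feature` from the example above, and mean squared error is used only as a simple example loss.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor: a flat vector of values.
using toy_tensor = std::vector<double>;

// Hypothetical store that keeps node-level predictions by feature name.
struct toy_prediction_store {
    std::map<std::string, toy_tensor> node;
    void prediction_node_feature(const std::string& name, const toy_tensor& t) {
        node[name] = t;  // a later call with the same key overwrites the entry
    }
};

// Mean squared error between a prediction and its truth values.
inline double mse(const toy_tensor& pred, const toy_tensor& truth) {
    assert(!pred.empty() && pred.size() == truth.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double d = pred[i] - truth[i];
        sum += d * d;
    }
    return sum / static_cast<double>(pred.size());
}
```

The key passed when storing the prediction is the same key used to look it up for the loss, so it should match the output feature name configured for the model.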