The ModelTemplate Source Files
To add a new model to the framework, navigate to the template files in src/AnalysisG/templates/model and copy the source files into the existing model directory. Rename the copied folder appropriately, and rename the files so that they use the new model name (<model-name>) as their prefix. The template files already contain all the required structures, so they essentially only need search-and-replace modifications.
The C++ Source Code
Header file (model.h):
#ifndef <model-name>_H
#define <model-name>_H

#include <templates/model_template.h>

class <model-name>: public model_template
{
    public:
        <model-name>();
        ~<model-name>();
        model_template* clone() override;
        void forward(graph_t*) override;
        torch::nn::Sequential* example = nullptr;
};

#endif
Source file:
#include <model.h>

model::model(){
    // a small example network: Linear(2, 2) -> ReLU -> Linear(2, 2)
    this -> example = new torch::nn::Sequential({
        {"L1", torch::nn::Linear(2, 2)},
        {"RELU", torch::nn::ReLU()},
        {"L2", torch::nn::Linear(2, 2)}
    });
    this -> register_module(this -> example);
}

void model::forward(graph_t* data){
    // fetch the input data of the model.
    // If a variable is not available, get_data_* returns a nullptr,
    // so check the pointer before calling clone() on optional inputs.
    torch::Tensor graph = data -> get_data_graph("graph") -> clone();
    torch::Tensor node = data -> get_data_node("node") -> clone();
    torch::Tensor edge = data -> get_data_edge("edge") -> clone();
    torch::Tensor edge_index = data -> edge_index -> clone();

    // output the prediction weights for graphs, nodes, and edges.
    this -> prediction_graph_feature("...", <some-tensor>);
    this -> prediction_node_feature("...", <some-tensor>);
    this -> prediction_edge_feature("...", <some-tensor>);

    // outside of inference mode, return early and skip the inference-only outputs below.
    if (!this -> inference_mode){return;}
    this -> prediction_extra("...", <some-tensor>); // any variables that should be dumped during inference.
}

model::~model(){}

model_template* model::clone(){
    return new model();
}
Cython Interface Files
The code below defines the Cython interface of the model, which allows the model to be initialized from the Python interpreter. As with the C++ code, the interface requires a header (.pxd) and a source file (.pyx).
Header file (model.pxd):
# distutils: language=c++
# cython: language_level=3

from libcpp cimport bool
from AnalysisG.core.model_template cimport model_template, ModelTemplate

cdef extern from "<models/model.h>":
    cdef cppclass model(model_template):
        model() except+

cdef class ExampleModel(ModelTemplate):
    pass
Source file (model.pyx):
# distutils: language=c++
# cython: language_level=3

from AnalysisG.core.model_template cimport ModelTemplate
from AnalysisG.models.<model-name> cimport ExampleModel

cdef class ExampleModel(ModelTemplate):
    def __cinit__(self):
        self.nn_ptr = new model()

    def __init__(self):
        pass

    def __dealloc__(self):
        del self.nn_ptr