Recursive Graph Neural Network (models.RecursiveGraphNeuralNetwork)
Import with:
from AnalysisG.models.RecursiveGraphNeuralNetwork import RecursiveGraphNeuralNetwork
RecursiveGraphNeuralNetwork
RecursiveGraphNeuralNetwork is a ModelTemplate subclass wrapping the recursivegraphneuralnetwork C++ class.
The model performs recursive message passing over node and edge features. Its eight sub-networks jointly predict edge-level top assignments, node-level aggregated features, and the event-level exotic-resonance and number-of-tops outputs.
Architecture (from RecursiveGraphNeuralNetwork.cxx)
Constructor signature: recursivegraphneuralnetwork(int rep=1024, double drop_out=0.1)
| Sub-network | Role |
|---|---|
| `rnn_dx` | Edge message network |
| `rnn_x` | Node encoder |
| `rnn_merge` | Hidden-state merge |
| `rnn_update` | Edge-prediction updater |
| `exotic_mlp` | Exotic resonance head |
| `node_aggr_mlp` | Node aggregation |
| `ntops_mlp` | Number-of-tops head |
| `exo_mlp` | Second exotic head (post-aggregation): same pattern as … |
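The sub-network roles listed above suggest a loop of the form "score edges, aggregate nodes, repeat". The following is a toy, pure-Python sketch of that general idea only. It is not the AnalysisG implementation: the scalar node states, the hand-written `message` function, the averaging step, and the stopping rule (stop once the edge predictions no longer change) are all assumptions standing in for the learned sub-networks.

```python
from itertools import permutations


def message(h_i, h_j):
    """Toy stand-in for a learned edge network: returns (score_unlinked, score_linked)."""
    gap = abs(h_i - h_j)
    return (gap, 1.0 - gap)  # "linked" wins when the two states are closer than 0.5


def recursive_message_passing(node_states, max_iter=10):
    """Repeat predict-edges / aggregate-nodes until the edge predictions stop changing."""
    hidden = list(node_states)
    edges = list(permutations(range(len(hidden)), 2))  # fully connected, directed
    linked = set()
    for _ in range(max_iter):
        # Two-class edge decision: keep an edge when the "linked" score is larger.
        new_linked = set()
        for i, j in edges:
            unlinked_score, linked_score = message(hidden[i], hidden[j])
            if linked_score > unlinked_score:
                new_linked.add((i, j))
        if new_linked == linked:
            break  # predictions are stable, so stop recursing
        linked = new_linked
        # Node aggregation: average each state with those of its predicted partners.
        updated = []
        for i in range(len(hidden)):
            group = [hidden[i]] + [hidden[j] for (src, j) in linked if src == i]
            updated.append(sum(group) / len(group))
        hidden = updated
    return hidden, linked


states, links = recursive_message_passing([0.0, 0.1, 1.0, 0.9])
print(sorted(links))  # edges kept once the predictions have stabilised
```

In this toy run the four nodes settle into two pairs, and each pair ends up sharing a common hidden state. The real model replaces every hand-written step here with a trained network.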
Architecture dimensions:
| Field | Default | Meaning |
|---|---|---|
| `_dx` | 26 | Input edge-feature dimension (concatenated node features for both endpoints) |
| `_x` | 5 | Input node-feature dimension |
| `_output` | 2 | Output edge-prediction dimension (binary: same-top or not) |
| `_rep` | 256 | Internal hidden-state dimension (overridden by the constructor argument `rep`, which defaults to 1024) |
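Since the edge output is two-dimensional (same-top or not), one way to read a prediction is to take the argmax per edge and group nodes joined by "same-top" edges. The sketch below illustrates that reading with a small union-find. The function name `group_nodes`, the `(not_same_top, same_top)` ordering of the two scores, and the grouping step itself are assumptions for illustration; this page does not state how AnalysisG post-processes the edge output.

```python
def group_nodes(n_nodes, edge_index, edge_logits):
    """Group nodes connected by edges whose second ("same-top") score wins the argmax."""
    parent = list(range(n_nodes))

    def find(i):
        # Follow parent links to the group representative, compressing the path.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for (i, j), (not_same, same) in zip(edge_index, edge_logits):
        if same > not_same:  # argmax over the two output scores
            parent[find(i)] = find(j)

    groups = {}
    for node in range(n_nodes):
        groups.setdefault(find(node), []).append(node)
    return sorted(groups.values())


edges = [(0, 1), (1, 2), (3, 4), (0, 3)]
logits = [(0.1, 0.9), (0.2, 0.8), (-1.0, 2.0), (1.5, -0.5)]
candidates = group_nodes(5, edges, logits)
print(candidates)  # → [[0, 1, 2], [3, 4]]
```

The last edge is rejected because its first score is larger, so the two groups stay separate.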
Hyper-parameters (Python layer)
| Property | Type | Description |
|---|---|---|
| `drop_out` | `float` | Dropout probability. Default `0.1`. |
| `res_mass` | `float` | Target resonance mass [MeV] used as an auxiliary loss constraint. Default `0`. |
| `is_mc` | `bool` | Include MC-truth features in the input. Default `True`. |
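To make the role of a target resonance mass in MeV concrete, the sketch below computes the invariant mass of a set of four-momenta and its difference from a target value, using the standard relation m² = E² − |p|² in natural units. The `(px, py, pz, E)` component order and the helper `mass_residual` are assumptions for illustration; the actual form of the auxiliary loss is not documented on this page.

```python
import math


def invariant_mass(four_vectors):
    """Invariant mass of the summed four-momenta, given as (px, py, pz, E) tuples."""
    px, py, pz, energy = (sum(component) for component in zip(*four_vectors))
    mass_squared = energy**2 - (px**2 + py**2 + pz**2)
    return math.sqrt(max(mass_squared, 0.0))  # clamp tiny negatives from rounding


def mass_residual(four_vectors, res_mass):
    """Hypothetical constraint term: candidate mass minus the target mass (same units)."""
    return invariant_mass(four_vectors) - res_mass


# Two back-to-back massless particles of 50 GeV each, expressed in MeV.
pair = [(50_000.0, 0.0, 0.0, 50_000.0), (-50_000.0, 0.0, 0.0, 50_000.0)]
print(invariant_mass(pair))  # → 100000.0
```

Inputs and the target must share the same unit; here everything is in MeV to match the property description.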
Usage
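# Import the analysis driver and the model wrapper.
from AnalysisG import Analysis
from AnalysisG.models.RecursiveGraphNeuralNetwork import RecursiveGraphNeuralNetwork
ana = Analysis()
# Register the model with the analysis.
ana.AddModel(RecursiveGraphNeuralNetwork)
# Set the number of training epochs, then launch the run.
ana.Epochs = 20
ana.Start()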
C++ Reference
class recursivegraphneuralnetwork : public model_template
Public Functions
- recursivegraphneuralnetwork(int rep = 1024, double dpt = 0.1)
- ~recursivegraphneuralnetwork()
- virtual model_template *clone() override
  Create a default-constructed clone. Override in subclasses.
  Returns: New model instance.
- virtual void forward(graph_t*) override
  Execute one forward pass. Override in subclasses to implement the model computation.
  Parameters: data – Pointer to the input graph.
- torch::Tensor message(torch::Tensor trk_i, torch::Tensor trk_j, torch::Tensor pmc, torch::Tensor pmc_i, torch::Tensor pmc_j, torch::Tensor hx_i, torch::Tensor hx_j)
Public Members
- int _dx = 26
- int _x = 5
- int _output = 2
- int _rep = 256
- double res_mass = 0
- double drop_out = 0.1
- bool is_mc = true
- torch::nn::Sequential *rnn_x = nullptr
- torch::nn::Sequential *rnn_dx = nullptr
- torch::nn::Sequential *rnn_merge = nullptr
- torch::nn::Sequential *rnn_update = nullptr
- torch::nn::Sequential *exotic_mlp = nullptr
- torch::nn::Sequential *node_aggr_mlp = nullptr
- torch::nn::Sequential *ntops_mlp = nullptr
- torch::nn::Sequential *exo_mlp = nullptr