Grift Model (models.grift)
The Grift Graph Neural Network is a message-passing GNN for
top-quark combinatorics. Import with:
```python
from AnalysisG.models.grift import Grift
```
Grift
Grift is a ModelTemplate subclass wrapping the grift C++ class.
Architecture (from grift.cxx)
Grift uses 7 torch.nn.Sequential sub-networks arranged in a recurrent message-passing loop:
| Sub-network | Role |
|---|---|
| rnn_x | Node encoder |
| rnn_dx | Edge message network |
| rnn_hxx | Hidden-state aggregator |
| rnn_txx | Top-edge predictor |
| rnn_rxx | Resonance predictor |
| mlp_ntop | Number-of-tops readout |
| mlp_sig | Signal classifier |
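The recurrent loop can be pictured with a minimal, dependency-free sketch. It assumes a sum aggregation of incoming edge messages at each receiving node and uses plain Python callables where Grift uses its learned torch::nn::Sequential sub-networks; the functions message_passing_step and recurrent_message_passing are hypothetical and are not part of Grift's API.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
EdgeIndex = Tuple[Sequence[int], Sequence[int]]  # (sources, destinations) in COO form


def message_passing_step(
    h: Sequence[Vector],
    edge_index: EdgeIndex,
    message_fn: Callable[[Vector, Vector], Vector],
    update_fn: Callable[[Vector, Vector], Vector],
) -> List[Vector]:
    """One round: a message per edge i -> j, summed at j, then a per-node update."""
    src, dst = edge_index
    dim = len(h[0])
    agg = [[0.0] * dim for _ in h]  # running sum of incoming messages per node
    for i, j in zip(src, dst):
        msg = message_fn(h[i], h[j])
        agg[j] = [a + m for a, m in zip(agg[j], msg)]
    return [update_fn(h_n, a_n) for h_n, a_n in zip(h, agg)]


def recurrent_message_passing(h, edge_index, message_fn, update_fn, n_iter: int) -> List[Vector]:
    """Apply the same step n_iter times, feeding each output back in as the new hidden state."""
    state = [list(v) for v in h]
    for _ in range(n_iter):
        state = message_passing_step(state, edge_index, message_fn, update_fn)
    return state


if __name__ == "__main__":
    h0 = [[1.0], [2.0], [3.0]]
    edges = ([0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1])  # fully connected, no self-loops
    copy_source = lambda h_i, h_j: list(h_i)          # stand-in for a learned message network
    residual_add = lambda h_n, a_n: [x + y for x, y in zip(h_n, a_n)]  # stand-in update
    print(recurrent_message_passing(h0, edges, copy_source, residual_add, n_iter=2))
```

Reusing the same step functions on every iteration is what makes the loop recurrent: the parameters are shared across rounds while the hidden state evolves.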
Architecture dimensions (C++ defaults):
| Field | Default | Meaning |
|---|---|---|
| | 1024 | Width of the |
| _xrec | 128 | Recurrent hidden-state dimension used by all sub-networks |
| _xin | 6 | Input node-feature dimension |
| _xout | 2 | Output edge/graph prediction dimension (binary classification) |
| _xtop | 5 | Number-of-tops feature dimension |
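A small sketch can make the default dimensions concrete for a single event graph. It assumes a fully connected graph (with or without self-loops), one edge prediction per edge, and graph-level number-of-tops and signal outputs; GriftDims and expected_shapes are hypothetical helpers, and the mapping of each dimension to an output is an assumption drawn from the Meaning column above.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class GriftDims:
    """Hypothetical mirror of the C++ defaults (_xin, _xrec, _xout, _xtop)."""
    xin: int = 6     # input node features
    xrec: int = 128  # recurrent hidden state
    xout: int = 2    # edge / graph prediction classes (binary)
    xtop: int = 5    # number-of-tops dimension


def expected_shapes(dims: GriftDims, num_nodes: int, self_loops: bool = False) -> Dict[str, Tuple[int, int]]:
    """Assumed tensor shapes for a single, fully connected graph."""
    num_edges = num_nodes * num_nodes if self_loops else num_nodes * (num_nodes - 1)
    return {
        "node_features": (num_nodes, dims.xin),
        "node_hidden": (num_nodes, dims.xrec),
        "edge_logits": (num_edges, dims.xout),  # assumed: one score pair per edge
        "ntops": (1, dims.xtop),                # assumed: graph-level readout
        "signal": (1, dims.xout),               # assumed: graph-level classifier
    }


if __name__ == "__main__":
    for name, shape in expected_shapes(GriftDims(), num_nodes=4).items():
        print(f"{name}: {shape}")
```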
Hyper-parameters (Python layer)
| Property | Type | Description |
|---|---|---|
| | | Dropout probability applied to hidden layers. Default |
| | | When |
| | | When |
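The C++ member drop_out defaults to 0, which leaves activations untouched. For a non-zero value, the following is a minimal sketch of standard inverted dropout, the scheme torch::nn::Dropout implements; that Grift applies dropout this way inside its sub-networks is an assumption, and inverted_dropout is a hypothetical helper.

```python
import random
from typing import List, Optional, Sequence


def inverted_dropout(
    x: Sequence[float],
    p: float,
    training: bool = True,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Zero each element with probability p and rescale survivors by 1 / (1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must lie in [0, 1)")
    if not training or p == 0.0:
        return list(x)  # evaluation mode or p = 0: identity
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in x]


if __name__ == "__main__":
    acts = [1.0, 2.0, 3.0, 4.0]
    print(inverted_dropout(acts, p=0.0))                        # unchanged
    print(inverted_dropout(acts, p=0.5, rng=random.Random(0)))  # each value zeroed or doubled
```

Rescaling the surviving activations keeps their expected value the same in training and evaluation, so no correction is needed at inference time.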
Expected Input Features
Grift reads the following tensors from graph_t inside forward():
| Feature | Slot | Registered by |
|---|---|---|
| | node data | graph class |
| | node data | graph class |
| | node data | graph class |
| | node data | graph class |
| | node data | graph class |
| | node data | graph class |
| edge index | COO | graph class |
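The edge index is stored in COO form: two parallel rows of node indices, sources and destinations, with one column per edge. A minimal sketch that builds one for a fully connected graph follows; whether the AnalysisG graph classes connect every node pair, and whether they include self-loops, is an assumption here, and fully_connected_coo is a hypothetical helper.

```python
from typing import List, Tuple


def fully_connected_coo(num_nodes: int, self_loops: bool = False) -> Tuple[List[int], List[int]]:
    """Return (sources, destinations) covering every ordered node pair."""
    src: List[int] = []
    dst: List[int] = []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j and not self_loops:
                continue  # skip i -> i unless self-loops are requested
            src.append(i)
            dst.append(j)
    return src, dst


if __name__ == "__main__":
    print(fully_connected_coo(3))
```

Stacking the two lists row-wise gives the familiar 2 x E edge-index layout.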
Usage
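```python
from AnalysisG import Analysis
from AnalysisG.models.grift import Grift

ana = Analysis()
ana.AddModel(Grift)  # register the Grift model with the analysis
ana.Epochs = 20      # number of training epochs
ana.Start()          # run the configured analysis
```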
C++ Reference
- class grift : public model_template
Public Functions
- grift()
- ~grift()
- virtual model_template *clone() override
  Create a default-constructed clone. Override in subclasses.
  Returns: new model instance.
- virtual void forward(graph_t*) override
  Execute one forward pass. Override in subclasses to implement the model computation.
  Parameters: data, a pointer to the input graph.
- torch::Tensor message(torch::Tensor trk_i, torch::Tensor trk_j, torch::Tensor pmc, torch::Tensor hx_i, torch::Tensor hx_j)
- torch::Tensor node_encode(torch::Tensor pmc, torch::Tensor num_node, torch::Tensor *node_rnn)
- torch::Tensor recurse(torch::Tensor *node_i, torch::Tensor *idx_mat, torch::Tensor *edge_index_, torch::Tensor *edge_index, torch::Tensor *edge_rnn, torch::Tensor *node_dnn, torch::Tensor *top_edge, torch::Tensor *pmc, torch::Tensor *node_s)
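message() receives pmc alongside the per-node tensors, which suggests the kinematics of each node pair feed into the edge messages. A quantity commonly used in top-quark combinatorics is the invariant mass of the combined four-momentum. The sketch below assumes pmc holds Cartesian four-vectors ordered (px, py, pz, e); that ordering and the helpers invariant_mass and edge_masses are hypothetical, and the sketch does not claim to reproduce what grift::message computes.

```python
import math
from typing import List, Sequence, Tuple

FourVec = Tuple[float, float, float, float]  # assumed ordering: (px, py, pz, e)


def invariant_mass(*vecs: FourVec) -> float:
    """Invariant mass of the summed four-momenta, m = sqrt(E^2 - |p|^2)."""
    px = sum(v[0] for v in vecs)
    py = sum(v[1] for v in vecs)
    pz = sum(v[2] for v in vecs)
    e = sum(v[3] for v in vecs)
    m2 = e * e - (px * px + py * py + pz * pz)
    return math.sqrt(max(m2, 0.0))  # clamp small negative values from rounding


def edge_masses(pmc: Sequence[FourVec], src: Sequence[int], dst: Sequence[int]) -> List[float]:
    """Pair mass for every edge (i, j) in a COO edge index."""
    return [invariant_mass(pmc[i], pmc[j]) for i, j in zip(src, dst)]


if __name__ == "__main__":
    pmc = [(3.0, 0.0, 0.0, 5.0), (-3.0, 0.0, 0.0, 5.0)]
    print(edge_masses(pmc, [0], [1]))
```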
Public Members
- int _xrec = 128
- int _xin = 6
- int _xout = 2
- int _xtop = 5
- bool is_mc = true
- bool init = false
- bool pagerank = false
- double drop_out = 0
- torch::nn::Sequential *rnn_x = nullptr
- torch::nn::Sequential *rnn_dx = nullptr
- torch::nn::Sequential *rnn_txx = nullptr
- torch::nn::Sequential *rnn_rxx = nullptr
- torch::nn::Sequential *rnn_hxx = nullptr
- torch::nn::Sequential *mlp_ntop = nullptr
- torch::nn::Sequential *mlp_sig = nullptr
- torch::Tensor x_nulls
- torch::Tensor dx_nulls
- torch::Tensor te_nulls