import torch 
import torch.nn  as nn 
import torch.nn.functional  as F 
from torch_geometric.nn  import GATConv

class GTN(nn.Module):
    """Two-layer stack of ``GATConv`` layers with edge features.

    The first layer keeps the feature width at ``node_dim`` and is followed by
    a ReLU. The second layer maps to ``num_classes`` output channels with no
    activation, so ``forward`` returns raw (un-normalized) outputs.

    Args:
        node_dim: Width of the input node features. It is also used as the
            hidden width between the two layers.
        edge_dim: Width of the edge features. Passed to every ``GATConv`` as
            ``edge_dim``.
        num_classes: Number of output channels of the final layer.
    """

    def __init__(self, node_dim, edge_dim, num_classes):
        super().__init__()
        # (in_channels, out_channels) for each layer, in application order.
        channel_plan = [(node_dim, node_dim), (node_dim, num_classes)]
        # GTN layers adapted from GAT. Self-loops are disabled, so GATConv does
        # not insert extra (i, i) edges into edge_index before attending.
        self.gtn_layers = nn.ModuleList(
            GATConv(in_ch, out_ch, edge_dim=edge_dim, add_self_loops=False)
            for in_ch, out_ch in channel_plan
        )

    def forward(self, x, edge_index, edge_attr):
        """Run node features through the layer stack.

        Args:
            x: Node feature tensor fed to the first ``GATConv``.
            edge_index: Graph connectivity, forwarded unchanged to every layer.
            edge_attr: Edge features, forwarded unchanged to every layer.

        Returns:
            The output of the final ``GATConv`` (no activation applied).
        """
        *hidden_convs, head_conv = self.gtn_layers
        h = x
        for conv in hidden_convs:
            # Only hidden layers get the ReLU; the last layer stays linear.
            h = F.relu(conv(h, edge_index, edge_attr=edge_attr))
        return head_conv(h, edge_index, edge_attr=edge_attr)