import torch
from torch.nn import Linear, ReLU, ModuleList
import torch.nn.functional as F
from torch_geometric.nn import LEConv

class SPMotifNet1(torch.nn.Module):
    """Node-classification variant of the original CRCG GNN.

    Architecture: a linear node embedding, then ``num_unit`` LEConv + ReLU
    message-passing layers, then a two-layer MLP head producing per-node
    logits (no softmax is applied).

    Args:
        in_channels: Size of each input node feature vector.
        hid_channels: Hidden embedding size used by the embedding and conv layers.
        num_classes: Number of output classes (width of the logits).
        num_unit: Number of LEConv + ReLU layers.
    """

    def __init__(self, in_channels, hid_channels=64, num_classes=3, num_unit=2):
        super().__init__()
        self.node_emb = Linear(in_channels, hid_channels)
        self.convs    = ModuleList([LEConv(hid_channels, hid_channels) for _ in range(num_unit)])
        self.relus    = ModuleList([ReLU() for _ in range(num_unit)])
        self.causal_mlp = torch.nn.Sequential(
            Linear(hid_channels, 2*hid_channels),
            ReLU(),
            Linear(2*hid_channels, num_classes)
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Compute per-node class logits.

        Args:
            x: Node feature matrix. Assumes shape (num_nodes, in_channels);
                TODO confirm against callers.
            edge_index: Graph connectivity passed straight to each LEConv.
            edge_attr: Passed to LEConv as ``edge_weight``. Presumably a 1-D
                per-edge weight tensor, or None. NOTE(review): confirm that
                callers do not pass multi-dimensional edge features here.
            batch: Accepted but unused. It is kept so the signature matches
                graph-level pipelines that always pass it.

        Returns:
            Raw (un-normalized) logits, one row per input row of ``x``, with
            ``num_classes`` columns.
        """
        x = self.node_emb(x)
        for conv, act in zip(self.convs, self.relus):
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr)
            x = act(x)
        return self.causal_mlp(x)

    def reset_parameters(self):
        """Re-initialize every submodule that exposes ``reset_parameters``.

        ``self.modules()`` yields ``self`` first. Calling our own
        ``reset_parameters`` again would recurse forever, so ``self`` is
        skipped. Nested modules may be reset more than once (a parent conv
        resets its own children, which ``modules()`` then yields again).
        Repeating a reset only redraws the random initialization.
        """
        for layer in self.modules():
            if layer is self:
                continue
            reset = getattr(layer, 'reset_parameters', None)
            if callable(reset):
                reset()