import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GINConv
import networkx as nx
import matplotlib.pyplot as plt



class GIN(nn.Module):
    """Three-layer GIN with a residual connection around the second layer.

    Each ``GINConv`` uses sum aggregation and wraps an MLP of the form
    ``Linear -> SELU -> LayerNorm -> Linear``. The first two conv outputs are
    passed through LayerNorm and SiLU; the third conv output is returned
    unchanged.

    Args:
        in_feats: Size of the last dimension of the input node features.
        hidden_feats: Width used inside every MLP and between conv layers.
        out_feats: Size of the last dimension of the returned tensor.

    Note:
        ``hidden_ln`` and ``fc`` are constructed and registered as submodules
        but are not used by ``forward``.
    """

    def __init__(self, in_feats, hidden_feats, out_feats):
        super().__init__()

        def build_mlp(d_in, d_out):
            # Update function handed to GINConv; the middle width is always
            # hidden_feats.
            return nn.Sequential(
                nn.Linear(d_in, hidden_feats),
                nn.SELU(),
                nn.LayerNorm(hidden_feats),
                nn.Linear(hidden_feats, d_out),
            )

        # The MLPs are created first and in layer order, so the sequence of
        # parameter initialisations is the same as layer order.
        first_mlp = build_mlp(in_feats, hidden_feats)
        second_mlp = build_mlp(hidden_feats, hidden_feats)
        third_mlp = build_mlp(hidden_feats, out_feats)

        # Norms applied to conv outputs (hidden_ln is not used in forward).
        self.hidden_ln1 = nn.LayerNorm(hidden_feats)
        self.hidden_ln2 = nn.LayerNorm(hidden_feats)
        self.hidden_ln = nn.LayerNorm(hidden_feats)

        # Sum-aggregating GIN layers.
        self.conv1 = GINConv(first_mlp, aggregator_type='sum')
        self.conv2 = GINConv(second_mlp, aggregator_type='sum')
        self.conv3 = GINConv(third_mlp, aggregator_type='sum')

        # Linear head; not used in forward.
        self.fc = nn.Linear(hidden_feats, out_feats)

    def forward(self, g, features, edge_weight=None):
        """Run the three conv layers on graph ``g``.

        Args:
            g: Graph object passed as the first argument to every ``GINConv``.
            features: Node features; the last dimension must be ``in_feats``.
            edge_weight: Optional value forwarded unchanged to every
                ``GINConv`` as ``edge_weight``.

        Returns:
            The output of the third conv layer (last dimension ``out_feats``).
        """
        # Layer 1: conv -> LayerNorm -> SiLU.
        h = self.conv1(g, features, edge_weight=edge_weight)
        h = F.silu(self.hidden_ln1(h))

        # Layer 2: same pattern, then add the layer-1 activation back in.
        skip = h
        h = self.conv2(g, h, edge_weight=edge_weight)
        h = F.silu(self.hidden_ln2(h)) + skip

        # Layer 3: raw conv output, no norm or activation.
        return self.conv3(g, h, edge_weight=edge_weight)


class NLayerGIN(nn.Module):
    """GIN stack with a configurable number of residual hidden layers.

    The network is: an input ``GINConv`` followed by LayerNorm and SiLU,
    ``n_layers - 2`` hidden ``GINConv`` layers (each followed by LayerNorm,
    SiLU and a residual add), and an output ``GINConv`` whose result is
    returned unchanged. Every ``GINConv`` uses sum aggregation and wraps an
    MLP of the form ``Linear -> SELU -> LayerNorm -> Linear``.

    Args:
        in_feats: Size of the last dimension of the input node features.
        hidden_feats: Width used inside every MLP and between conv layers.
        out_feats: Size of the last dimension of the returned tensor.
        n_layers: Total number of conv layers. Values below 2 still build the
            input and output layers, because the hidden-layer loop simply
            runs zero times.
    """

    def __init__(self, in_feats: int, hidden_feats: int, out_feats: int, n_layers: int = 3):
        super(NLayerGIN, self).__init__()

        self.n_layers = n_layers

        # Input layer: in_feats -> hidden_feats.
        mlp1 = nn.Sequential(
            nn.Linear(in_feats, hidden_feats),
            nn.SELU(),
            nn.LayerNorm(hidden_feats),
            nn.Linear(hidden_feats, hidden_feats),
        )
        self.conv1 = GINConv(mlp1, aggregator_type='sum')
        self.hidden_ln1 = nn.LayerNorm(hidden_feats)

        # Hidden layers: hidden_feats -> hidden_feats, one LayerNorm per conv.
        # The two lists are kept the same length so forward can pair them up.
        self.hidden_gins = nn.ModuleList()
        self.hidden_lns = nn.ModuleList()
        for _ in range(n_layers - 2):
            mlp_ = nn.Sequential(
                nn.Linear(hidden_feats, hidden_feats),
                nn.SELU(),
                nn.LayerNorm(hidden_feats),
                nn.Linear(hidden_feats, hidden_feats),
            )
            self.hidden_gins.append(GINConv(mlp_, aggregator_type='sum'))
            self.hidden_lns.append(nn.LayerNorm(hidden_feats))

        # Output layer: hidden_feats -> out_feats.
        mlp_out = nn.Sequential(
            nn.Linear(hidden_feats, hidden_feats),
            nn.SELU(),
            nn.LayerNorm(hidden_feats),
            nn.Linear(hidden_feats, out_feats)
        )
        self.conv_out = GINConv(mlp_out, aggregator_type='sum')

    def forward(self, g, features, edge_weight=None):
        """Run every conv layer on graph ``g``.

        Args:
            g: Graph object passed as the first argument to every ``GINConv``.
            features: Node features; the last dimension must be ``in_feats``.
            edge_weight: Optional value forwarded to every ``GINConv`` as
                ``edge_weight``. When ``None`` (the default), the weights are
                read from ``g.edata['weight']``, which is what this method
                always did before the parameter existed.

        Returns:
            The output of the final conv layer (last dimension ``out_feats``).

        Raises:
            KeyError: Presumably, if ``edge_weight`` is ``None`` and ``g`` has
                no ``'weight'`` edge feature (the lookup is a plain
                ``g.edata['weight']``; confirm against the DGL version used).
        """
        # Resolve the weights once instead of re-reading g.edata per layer.
        if edge_weight is None:
            edge_weight = g.edata['weight']

        x = F.silu(self.hidden_ln1(self.conv1(g, features, edge_weight=edge_weight)))

        # Each hidden layer: conv -> LayerNorm -> SiLU, plus a residual add.
        for conv, norm in zip(self.hidden_gins, self.hidden_lns):
            x = F.silu(norm(conv(g, x, edge_weight=edge_weight))) + x

        # Output layer: raw conv output, no norm or activation.
        return self.conv_out(g, x, edge_weight=edge_weight)

if __name__ == "__main__":
    # Demo run: build a default NLayerGIN, run it once on a tiny weighted
    # graph, and print the result.
    n_in = 5        # input node feature size
    n_hidden = 10   # hidden width
    n_out = 2       # output feature size

    # The model is created before any random data is drawn, so the order of
    # random number consumption is: model parameters, edge weights, features.
    net = NLayerGIN(n_in, n_hidden, n_out)

    # Directed 5-cycle: 0 -> 1 -> 2 -> 3 -> 4 -> 0.
    src = torch.arange(5)
    dst = (src + 1) % 5

    graph = dgl.graph((src, dst))
    graph.edata['weight'] = torch.rand(graph.num_edges())
    # Add a self-loop edge to the graph.
    # NOTE(review): the 'weight' values given to the new self-loop edges
    # depend on add_self_loop's fill behaviour - confirm for the installed
    # DGL version.
    graph = dgl.add_self_loop(graph)

    print("Graph:", graph)

    # Random input features, one row per node.
    feats = torch.randn(graph.num_nodes(), n_in)
    print("Node features shape:", feats.shape)

    result = net(graph, feats)
    print("Output shape:", result.shape)
    print("Output:", result)

