import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GINConv
import networkx as nx
import matplotlib.pyplot as plt



class GIN(nn.Module):
    """Graph Isomorphism Network assembled from DGL ``GINConv`` layers.

    Layout of the stack:

    * an input convolution followed by LayerNorm and SiLU,
    * ``n_layers - 2`` hidden convolutions, each followed by LayerNorm and
      SiLU and wrapped in a residual connection,
    * an output convolution with no normalisation or activation after it.

    Every convolution uses sum aggregation and applies a
    ``Linear -> SELU -> LayerNorm -> Linear`` MLP.

    Args:
        in_feats: Size of the last dimension of the input node features.
        hidden_feats: Width used inside every MLP and between layers.
        out_feats: Size of the last dimension of the output.
        n_layers: Nominal number of convolutions. The input and output
            convolutions are always created, so any value below 2 still
            yields two convolutions.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, n_layers=3):
        super().__init__()
        self.n_layers = n_layers
        num_hidden = n_layers - 2

        # Input block.
        self.conv1 = GINConv(
            self._make_mlp(in_feats, hidden_feats, hidden_feats),
            aggregator_type='sum',
        )
        self.hidden_ln1 = nn.LayerNorm(hidden_feats)

        # Hidden blocks: one convolution and one LayerNorm per block, kept in
        # two parallel lists.
        self.hidden_gins = nn.ModuleList([
            GINConv(
                self._make_mlp(hidden_feats, hidden_feats, hidden_feats),
                aggregator_type='sum',
            )
            for _ in range(num_hidden)
        ])
        self.hidden_lns = nn.ModuleList([
            nn.LayerNorm(hidden_feats) for _ in range(num_hidden)
        ])

        # Output block: projects to out_feats.
        self.conv_out = GINConv(
            self._make_mlp(hidden_feats, hidden_feats, out_feats),
            aggregator_type='sum',
        )

    @staticmethod
    def _make_mlp(in_dim, mid_dim, out_dim):
        """Return the two-layer MLP applied inside a single GINConv."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.SELU(),
            nn.LayerNorm(mid_dim),
            nn.Linear(mid_dim, out_dim),
        )

    def forward(self, g, features):
        """Run the network on graph ``g``.

        Args:
            g: Graph handed to every ``GINConv``; ``g.edata['weight']`` is
                passed to each convolution as ``edge_weight``.
            features: Node features whose last dimension is ``in_feats``.

        Returns:
            The output of the final convolution (last dimension
            ``out_feats``).
        """
        weights = g.edata['weight']

        h = self.conv1(g, features, edge_weight=weights)
        h = F.silu(self.hidden_ln1(h))

        for conv, norm in zip(self.hidden_gins, self.hidden_lns):
            # Residual connection around each hidden block.
            h = F.silu(norm(conv(g, h, edge_weight=weights))) + h

        return self.conv_out(g, h, edge_weight=weights)


class GINDeepSigns(nn.Module):
    """Sign invariant neural network with MLP aggregation.

        f(v1, ..., vk) = rho(enc(v1) + enc(-v1), ..., enc(vk) + enc(-vk))

    ``enc`` is a single :class:`GIN` shared by all input columns and ``rho``
    is an MLP applied to the per-node concatenation of the column encodings.

    Args:
        in_channels: Input feature size of ``enc``. ``forward`` feeds ``enc``
            a trailing dimension of size 1, which matches the default.
        hidden_channels: Hidden width of ``enc``.
        out_channels: Output feature size of ``enc`` for each column.
        num_layers: ``n_layers`` passed to ``enc``.
        k: Number of columns expected in the input of ``forward``; the input
            width of ``rho`` is ``out_channels * k``.
        dim_pe: Output feature size of ``rho``.
        rho_hidden_channels: Hidden width of ``rho``.
    """

    def __init__(self, in_channels=1, hidden_channels=64, out_channels=4, num_layers=8,
                 k=128, dim_pe=128, rho_hidden_channels=128):
        super().__init__()
        self.enc = GIN(in_channels, hidden_channels, out_channels, num_layers)
        rho_dim = out_channels * k
        self.rho = nn.Sequential(
            nn.Linear(rho_dim, rho_hidden_channels),
            nn.SELU(),
            nn.LayerNorm(rho_hidden_channels),
            nn.Linear(rho_hidden_channels, rho_hidden_channels),
            nn.SELU(),
            nn.LayerNorm(rho_hidden_channels),
            nn.Linear(rho_hidden_channels, dim_pe),
        )

    def forward(self, g, x):
        """Compute sign-invariant node embeddings.

        Args:
            g: Graph forwarded unchanged to ``enc``.
            x: 2-D tensor of shape ``N x K`` (one row per node).

        Returns:
            Tensor of shape ``N x dim_pe``.

        Raises:
            AssertionError: If ``x`` is not 2-D.
        """
        assert len(x.shape) == 2, "x must be a 2-D tensor of shape N x K"
        N = x.shape[0]  # Total number of nodes in the batch.
        x = x.unsqueeze(2)  # N x K -> N x K x 1
        # The node axis stays first for the graph convolutions, so the
        # encoder output is N x K x Out.
        x = self.enc(g, x) + self.enc(g, -x)
        # Flatten the K encodings of each node into one row. No transpose
        # here: the tensor is already node-major, and swapping the first two
        # axes before reshaping would interleave values of different nodes.
        x = x.reshape(N, -1)  # N x K x Out -> N x (K * Out)
        x = self.rho(x)  # N x dim_pe (Note: in the original codebase dim_pe is always K)
        return x
