import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.glob import SumPooling

from models.egt import EGTLayer


class GNN_Layer(torch.nn.Module):
    """Stack of message-passing layers augmented with one virtual node per graph.

    At every layer the current virtual-node embedding is added to all nodes of
    its graph, the conv layer is applied, and the virtual node is then updated
    from a sum-pool of that (virtual-node-augmented) layer input plus its
    previous value, passed through a per-layer MLP. The final virtual-node
    embedding is returned as the graph-level representation.

    Args:
        hidden_channels: Width of node and virtual-node embeddings.
        layers: Number of message-passing layers.
        gnn_layer: Conv module used for the first layer. Layers after the
            first use independent deep copies of it, so layers do not share
            weights.
        dropout: Dropout probability applied after each conv and after each
            virtual-node MLP.
    """

    def __init__(self, hidden_channels, layers, gnn_layer, dropout = 0.5):
        super().__init__()
        self.layers = layers
        self.dropout = dropout
        self.hidden_channels = hidden_channels

        # Initial virtual-node embedding. forward() always feeds this Linear a
        # zero input, so its output equals the bias; zero BOTH weight and bias
        # so the virtual node really starts at 0 (zeroing only the weight left
        # a randomly initialised bias as the starting embedding).
        self.virtualnode_embedding = nn.Linear(1, hidden_channels)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        nn.init.constant_(self.virtualnode_embedding.bias.data, 0)

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.mlp_virtualnode_list = nn.ModuleList()

        for i in range(layers):
            # Appending the same module instance repeatedly would tie every
            # layer to one set of weights; give each layer its own copy.
            self.convs.append(gnn_layer if i == 0 else copy.deepcopy(gnn_layer))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))

        for _ in range(layers):
            self.mlp_virtualnode_list.append(nn.Sequential(
                nn.Linear(hidden_channels, 2 * hidden_channels),
                nn.BatchNorm1d(2 * hidden_channels),
                nn.GELU(),
                nn.Linear(2 * hidden_channels, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.GELU(),
            ))

        self.pool = SumPooling()

    def forward(self, g, x):
        """Run all layers and return the final per-graph virtual-node embeddings.

        Args:
            g: Batched DGL graph; ``g.batch_size`` and ``g.batch_num_nodes()``
                map each virtual node to the nodes of its graph.
            x: Node features; presumably ``(num_nodes, hidden_channels)`` —
                TODO confirm against callers.

        Returns:
            Tensor with one row per graph (``g.batch_size`` rows) of width
            ``hidden_channels``.
        """
        # One virtual node per graph, created with x's dtype and device.
        virtualnode_embedding = self.virtualnode_embedding(x.new_zeros(g.batch_size, 1))
        num_nodes_per_graph = g.batch_num_nodes()

        h_list = [x]
        for layer in range(self.layers):
            # Message from virtual nodes to graph nodes: broadcast each graph's
            # virtual node to every node of that graph.
            h_list[layer] = h_list[layer] + torch.repeat_interleave(
                virtualnode_embedding, num_nodes_per_graph, dim=0)

            # Message passing among graph nodes.
            h = self.convs[layer](g, h_list[layer]).reshape(-1, self.hidden_channels)
            h = self.batch_norms[layer](h)
            if layer < self.layers - 1:
                # No activation after the last layer.
                h = F.gelu(h)
            h = F.dropout(h, self.dropout, training=self.training)
            h_list.append(h)

            # Message from graph nodes to virtual nodes, then per-layer MLP.
            # NOTE(review): this pools the layer INPUT h_list[layer], not the new
            # output h. Since only the virtual node is returned, the last layer's
            # conv/batch-norm output (h_list[-1]) never reaches the result and
            # those parameters get no gradient — confirm this is intended.
            virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding
            virtualnode_embedding = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.dropout,
                training=self.training,
            )

        return virtualnode_embedding


class GNN(nn.Module):
    """Graph classifier: node projection -> virtual-node EGT stack -> linear head.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of the hidden embeddings.
        num_class: Number of output logits per graph.
        gnn_layer: Currently ignored — an ``EGTLayer`` is always constructed.
        num_heads: Number of heads passed to ``EGTLayer``.
        layers: Number of message-passing layers in the stack.
        dropout: Dropout passed to ``EGTLayer``; the virtual-node stack itself
            is built with dropout 0.
    """

    def __init__(self, in_channels, hidden_channels, num_class, gnn_layer: str='EGT', num_heads: int=16, layers: int=2, dropout: float=0.1):
        super().__init__()
        self.node_embed = nn.Linear(in_channels, hidden_channels)
        egt_conv = EGTLayer(hidden_channels, num_heads, dropout=dropout)
        self.gnn = GNN_Layer(hidden_channels, layers, egt_conv, dropout=0)
        # Not used by forward(); kept so the attribute remains available.
        self.pool = SumPooling()
        self.predict = nn.Linear(hidden_channels, num_class)

    def forward(self, graph, x, return_h=False):
        """Classify each graph in ``graph``.

        Args:
            graph: Batched graph handed to the virtual-node stack.
            x: Input node features.
            return_h: When true, also return the per-graph representation.

        Returns:
            ``logits`` or, if ``return_h`` is set, the tuple ``(logits, h)``
            where ``h`` is the graph representation fed to the linear head.
        """
        graph_repr = self.gnn(graph, self.node_embed(x))
        logits = self.predict(graph_repr)
        return (logits, graph_repr) if return_h else logits
    