import torch
from torch import nn
from torch.functional import F
from torch.nn import Linear, Parameter, ModuleList, BatchNorm1d


from torch_geometric.loader import DataLoader
from torch_geometric.nn import aggr

from modules.utils import get_activation_fn, MLP, get_padded_features, elimitate_padding
from modules.graph_transformer import GraphTransformer
from modules.reference_layer import ReferenceLayer


from sklearn.cluster import KMeans
import math

from torch_geometric.nn import GINConv, JumpingKnowledge
from torch.nn import Linear, Parameter, ModuleList, BatchNorm1d
from torch.utils.checkpoint import checkpoint

class GINExtractor(torch.nn.Module):
    """Stack of ``GINConv`` layers with an optional ``JumpingKnowledge`` combiner.

    Each GINConv wraps an MLP made of ``num_layer_mlp`` stacked
    ``Linear -> BatchNorm1d -> ReLU`` stages.

    Args:
        input_channels: Feature width of the input node features.
        out_channels: Feature width produced by the last GIN layer.
        num_layer_mlp: Number of ``Linear -> BatchNorm1d -> ReLU`` stages in the
            MLP of every GINConv (must be >= 1).
        num_layer_gin: Number of GINConv layers (must be >= 1).
        jump_mode: Optional ``JumpingKnowledge`` mode. When ``None`` the output
            of the last GIN layer is returned by ``forward``.
        hidden_channels: Width of the intermediate GIN layers. It is replaced by
            ``out_channels`` when falsy or when ``num_layer_gin <= 1``.

    Raises:
        ValueError: If ``num_layer_mlp`` or ``num_layer_gin`` is smaller than 1.
    """

    def __init__(self,
                 input_channels,
                 out_channels,
                 num_layer_mlp,
                 num_layer_gin,
                 jump_mode=None,
                 hidden_channels=None
                 ):
        super().__init__()
        if num_layer_mlp < 1:
            raise ValueError(f"num_layer_mlp must be >= 1, got {num_layer_mlp}")
        if num_layer_gin < 1:
            # With zero layers forward() would fail on xs[-1]; fail early instead.
            raise ValueError(f"num_layer_gin must be >= 1, got {num_layer_gin}")
        # NOTE(review): this reseeds the *global* torch RNG on every construction.
        # Consequences: every GINExtractor built with the same shapes starts from
        # identical weights, and any module created afterwards by the caller
        # inherits this seed. Kept as-is for reproducibility of existing runs.
        torch.manual_seed(1111)
        self.jump_layer = JumpingKnowledge(jump_mode) if jump_mode is not None else None
        hidden_channels = hidden_channels if hidden_channels and num_layer_gin > 1 else out_channels
        self.GIN_layers = ModuleList()
        self._num_layer_gin = num_layer_gin
        for layer in range(num_layer_gin):
            # First layer reads the raw input; last layer emits out_channels;
            # everything in between stays at hidden_channels.
            in_ch = input_channels if layer == 0 else hidden_channels
            out_ch = out_channels if layer == num_layer_gin - 1 else hidden_channels
            self.GIN_layers.append(
                GINConv(self._build_mlp(in_ch, out_ch, num_layer_mlp), train_eps=False)
            )

    @staticmethod
    def _build_mlp(in_channels, out_channels, num_layers):
        """Return ``num_layers`` stacked ``Linear -> BatchNorm1d -> ReLU`` stages.

        The first stage maps ``in_channels -> out_channels``; every later stage
        maps ``out_channels -> out_channels``.
        """
        modules = []
        for stage in range(num_layers):
            modules += [
                nn.Linear(in_channels if stage == 0 else out_channels, out_channels),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
            ]
        return nn.Sequential(*modules)

    def forward(self, x, edge_index):
        """Run every GIN layer over the graph.

        Args:
            x: Node feature tensor; cast to ``torch.float`` before use.
            edge_index: Graph connectivity passed to every GINConv.

        Returns:
            The output of the last GIN layer, or the ``JumpingKnowledge``
            combination of all layer outputs when ``jump_mode`` was given.

        Raises:
            AssertionError: If ``x`` contains NaN (skipped under ``python -O``).
        """
        x = x.to(torch.float)
        # NOTE: evaluating .any() in a Python assert forces a device->host sync.
        assert not torch.isnan(x).any(), "x contains NaN"
        xs = []
        for gin_layer in self.GIN_layers:
            x = gin_layer(x=x, edge_index=edge_index)
            xs.append(x)
        return xs[-1] if self.jump_layer is None else self.jump_layer(xs)

    def reset_parameters(self):
        """Re-initialise the parameters of every GINConv layer."""
        for layer in self.GIN_layers:
            layer.reset_parameters()



class GraphModel(torch.nn.Module):
    """Multi-view graph encoder: per-view GINs -> GraphTransformer -> ReferenceLayer.

    ``data.x`` is split column-wise into ``n_graph`` chunks of ``feat_dim``
    columns, and each chunk is encoded by its own :class:`GINExtractor`. The GIN
    outputs are concatenated, padded per graph (via ``get_padded_features``) and
    run through a ``GraphTransformer``. The padding is then removed and the GIN
    and transformer node features are concatenated and passed to a
    ``ReferenceLayer``. When ``readout`` is set, a pooled graph representation
    is appended. When ``use_mlp_head`` is set, an ``MLP`` head is applied last.

    Args:
        num_atoms: Forwarded to ``ReferenceLayer``. Also used as the MLP head
            input width when there is no readout (presumably the ReferenceLayer
            output width — TODO confirm).
        num_atom_supp: Forwarded to ``ReferenceLayer``.
        gamma: Forwarded to ``ReferenceLayer``.
        use_mlp_head: Whether to build and apply an ``MLP`` head.
        mlp_num_layers: Forwarded to ``MLP``.
        mlp_hidden_dim: Forwarded to ``MLP``.
        mlp_out_dim: Forwarded to ``MLP``.
        jumping_mode: ``JumpingKnowledge`` mode for every GINExtractor, or ``None``.
        readout: Graph pooling mode: ``"sum"``, ``"mean"``, ``"max"``, or
            ``None``/``""`` for no pooled representation.
        n_graph: Number of feature views (one GINExtractor per view).
        feat_dim: Number of ``data.x`` columns per view.
        gin_hidden_dim: Hidden width of the GIN layers.
        gin_num_layer: Number of GINConv layers per GINExtractor.
        **kwargs: ``ffn_embed_dim`` (768) and ``gin_mlp_layer`` (2), plus these
            ``GraphTransformer`` options, in positional order:
            ``num_encoder_layers`` (12), ``num_attn_heads`` (32),
            ``dropout`` (0.1), ``attn_dropout`` (0.1),
            ``activation_dropout`` (0.1), ``layerdrop`` (0.0),
            ``encoder_normalize_before`` (False), ``activation_fn`` ("gelu"),
            ``pre_layernorm`` (False), ``embed_scale`` (None).

    Raises:
        ValueError: If ``readout`` is not a supported pooling mode.
    """

    # Pooling modes accepted for ``readout``.
    _READOUT_MODES = ("sum", "mean", "max")

    def __init__(
            self,
            num_atoms,
            num_atom_supp,
            gamma,
            use_mlp_head = False,
            mlp_num_layers = 2,
            mlp_hidden_dim = 32,
            mlp_out_dim = None,
            jumping_mode = None,
            readout: str = None,
            n_graph = 6,
            feat_dim = 32,
            gin_hidden_dim = 128,
            gin_num_layer = 3,
            **kwargs
        ):
        super().__init__()
        # Treat "" like None so the head width chosen here and the readout
        # performed in forward() always agree.
        readout = readout or None
        if readout is not None and readout not in self._READOUT_MODES:
            # Previously an unknown mode left self.agg undefined and forward()
            # crashed with AttributeError; fail at construction instead.
            raise ValueError(
                f"readout must be one of {self._READOUT_MODES} or None, got {readout!r}"
            )

        self.num_atoms = num_atoms
        self.num_atom_supp = num_atom_supp
        self.feat_dim = feat_dim
        # Each view's GIN contributes an equal share of the transformer width.
        self.embed_dim = kwargs.get("ffn_embed_dim", 768) // n_graph
        self.readout = readout
        self.n_graph = n_graph
        self.n_garph = n_graph  # (sic) original misspelled attribute, kept for external users

        self.gin_num_layer = gin_num_layer
        self.gin_mlp_layer = kwargs.get("gin_mlp_layer", 2)

        self.gins = nn.ModuleDict()
        for i in range(n_graph):
            self.gins[str(i)] = GINExtractor(
                self.feat_dim,
                self.embed_dim,
                self.gin_mlp_layer,
                self.gin_num_layer,
                jump_mode=jumping_mode,
                hidden_channels=gin_hidden_dim,
            )

        self.gt = GraphTransformer(
            kwargs.get("num_encoder_layers", 12),
            self.embed_dim * n_graph,
            kwargs.get("ffn_embed_dim", 768),
            kwargs.get("num_attn_heads", 32),
            kwargs.get("dropout", 0.1),
            kwargs.get("attn_dropout", 0.1),
            kwargs.get("activation_dropout", 0.1),
            kwargs.get("layerdrop", 0.0),
            kwargs.get("encoder_normalize_before", False),
            kwargs.get("activation_fn", "gelu"),
            kwargs.get("pre_layernorm", False),
            kwargs.get("embed_scale", None),
        )
        # GIN features (n_graph * embed_dim) concatenated with transformer
        # features, assumed to have the same width — TODO confirm.
        readout_dim = 2 * n_graph * self.embed_dim
        self.reference_layer = ReferenceLayer(readout_dim, num_atoms, num_atom_supp, gamma)

        if readout is None:
            self.agg = None
        else:
            aggregators = {
                "sum": aggr.SumAggregation,
                "mean": aggr.MeanAggregation,
                "max": aggr.MaxAggregation,
            }
            self.agg = aggregators[readout]()

        if use_mlp_head:
            mlp_in_dim = num_atoms if readout is None else num_atoms + readout_dim
            self.mlp_head = MLP(
                mlp_in_dim,
                mlp_out_dim,
                mlp_num_layers,
                mlp_hidden_dim
            )
        else:
            self.mlp_head = None

    def forward(self, data):
        """Encode a batch of graphs.

        Args:
            data: A PyG batch providing ``x``, ``edge_index``, ``batch`` and
                ``ptr``. The first ``n_graph * feat_dim`` columns of ``x`` are
                consumed, one ``feat_dim``-wide chunk per view.

        Returns:
            The ReferenceLayer output. The pooled graph representation is
            concatenated to it when ``readout`` is set, and the MLP head is
            applied afterwards when one was configured.
        """
        batch = data.batch
        device = data.x.device
        # Offset of each graph's first node: [0, n_0, n_0 + n_1, ...].
        node_slice = torch.cat(
            [torch.tensor([0], device=batch.device), torch.bincount(batch).cumsum(0)]
        )

        # One GIN per view, each over its own feat_dim-wide column chunk.
        x_gin = torch.cat(
            [
                self.gins[str(i)](
                    data.x[:, i * self.feat_dim:(i + 1) * self.feat_dim], data.edge_index
                )
                for i in range(self.n_graph)
            ],
            dim=-1,
        )

        padding_flag, pad_x = get_padded_features(batch, device, x_gin)
        # NOTE(review): the transpose assumes GraphTransformer returns
        # (seq, batch, dim) — confirm against modules.graph_transformer.
        x_gt = self.gt(pad_x).transpose(0, 1)
        x_gt = elimitate_padding(x_gt, padding_flag)

        x = torch.cat([x_gin, x_gt], dim=-1)
        # Pool before the reference layer replaces x with its output.
        graph_rep = self.agg(x, ptr=data.ptr) if self.agg is not None else None

        x = self.reference_layer(x, node_slice)
        if graph_rep is not None:
            x = torch.cat([x, graph_rep], dim=-1)

        if self.mlp_head is not None:
            x = self.mlp_head(x)
        return x

    def reset_parameters(self):
        """Re-initialise the per-view GINs, the reference layer and the MLP head.

        NOTE(review): ``self.gt`` is not reset. Its reset API is not visible
        from this file.
        """
        # nn.ModuleDict has no reset_parameters(); reset each GIN explicitly.
        for gin in self.gins.values():
            gin.reset_parameters()
        self.reference_layer.reset_parameters()
        if self.mlp_head is not None:
            self.mlp_head.reset_parameters()