from typing import Optional, Callable

import torch
import torch.nn as nn
from torch.nn import LayerNorm, Dropout
from torch_geometric.nn import JumpingKnowledge

from .multihead_attention import MultiheadAttention

from .graph_transformer_layer import GraphTransformerLayer

from .utils import get_activation_fn

from torch.utils.checkpoint import checkpoint

class GraphTransformer(nn.Module):
    """Stack of ``GraphTransformerLayer`` blocks over dense, padded node features.

    ``forward`` takes a rank-3 tensor whose channel 0 along the last axis is a
    padding flag (a value of 0 marks a padded node) and whose remaining
    channels are node features. The features are optionally scaled,
    layer-normed and dropped out. They then pass through ``num_encoder_layers``
    transformer layers, and finally through ``Linear(embed_dim, embed_dim)``
    followed by ``out_activation_fn``.

    Args:
        num_encoder_layers: Number of ``GraphTransformerLayer`` blocks.
        embed_dim: Node feature width. The input's feature channels
            (``batched_data[..., 1:]``) must have this size, because the output
            projection is ``Linear(embed_dim, embed_dim)``.
        ffn_embed_dim: Forwarded to every ``GraphTransformerLayer``.
        num_attn_heads: Forwarded to every ``GraphTransformerLayer``.
        dropout: Dropout applied to the input features; also forwarded to
            every layer.
        attn_dropout: Forwarded to every layer.
        activation_dropout: Forwarded to every layer.
        layerdrop: Stored as ``self.layerdrop``. NOTE: not currently applied
            in ``forward``.
        encoder_normalize_before: If true, apply ``LayerNorm(embed_dim)`` to
            the input features before dropout.
        activation_fn: Name of the activation forwarded to every layer.
        pre_layernorm: Forwarded to every layer. When true, also creates
            ``self.final_layer_norm`` (NOTE: not currently applied in
            ``forward``); otherwise that attribute is ``None``.
        embed_scale: If not ``None``, multiply the input features by it.
        apply_gt_init: Must be false; ``True`` raises ``NotImplementedError``.
        out_activation_fn: Name of the activation applied after the output
            projection (resolved via ``get_activation_fn``).
        jumping_knowledge: If given, a ``JumpingKnowledge`` mode used to build
            ``self.jump_layer``. NOTE: not currently applied in ``forward``.
    """

    def __init__(
        self,
        num_encoder_layers: int = 12,
        embed_dim: int = 768,
        ffn_embed_dim: int = 768,
        num_attn_heads: int = 32,

        dropout: float = 0.1,
        attn_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,

        encoder_normalize_before: bool = False,
        activation_fn: str = "gelu",
        pre_layernorm: bool = False,
        embed_scale: Optional[float] = None,
        apply_gt_init: bool = False,

        out_activation_fn: str = "gelu",

        jumping_knowledge: Optional[str] = None,
    ):
        super().__init__()
        # Fail before allocating the layer stack rather than after it.
        if apply_gt_init:
            raise NotImplementedError("Not implemented yet")

        self.dropout_module = Dropout(dropout)
        self.layerdrop = layerdrop
        self.embed_dim = embed_dim
        self.apply_gt_init = apply_gt_init
        self.embed_scale = embed_scale

        # No structural attention bias is computed yet; kept for API parity.
        self.graph_attn_bias = None

        self.emb_layer_norm = LayerNorm(embed_dim) if encoder_normalize_before else None

        # Always define the attribute so callers can test it against None.
        self.final_layer_norm = LayerNorm(embed_dim) if pre_layernorm else None

        self.layers = nn.ModuleList(
            [
                GraphTransformerLayer(
                    embed_dim=embed_dim,
                    ffn_embed_dim=ffn_embed_dim,
                    num_attn_heads=num_attn_heads,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    pre_layernorm=pre_layernorm,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        if jumping_knowledge is not None:
            self.jump_layer = JumpingKnowledge(jumping_knowledge)
        else:
            self.jump_layer = None

        self.lm_head_transform_weight = nn.Linear(embed_dim, embed_dim)

        self.lm_activation_fn = get_activation_fn(out_activation_fn)

        # Re-initialize every submodule (including self) with reset_parameters.
        self.apply(self.reset_parameters)

    def forward(
        self,
        batched_data: torch.Tensor,
    ) -> torch.Tensor:
        """Encode a padded batch of graphs.

        Args:
            batched_data: Tensor of shape ``[n_graph, n_node, 1 + embed_dim]``.
                Channel 0 is the padding flag (0 marks a padded node) and
                channels ``1:`` are the node features (cast to float here).

        Returns:
            Tensor of shape ``[n_graph, n_node, embed_dim]``: per-node outputs
            after the transformer stack, the output projection and
            ``out_activation_fn``.
        """
        # Bool mask of padded nodes: [n_graph, n_node].
        padding_mask = batched_data[:, :, 0].eq(0)

        # Node features: [n_graph, n_node, embed_dim].
        x = batched_data[:, :, 1:].float()

        if self.embed_scale is not None:
            x = x * self.embed_scale

        if self.emb_layer_norm is not None:
            x = self.emb_layer_norm(x)

        x = self.dropout_module(x)

        # Node-first layout for the layers:
        # [n_graph, n_node, embed_dim] -> [n_node, n_graph, embed_dim].
        x = x.transpose(0, 1)

        # NOTE: self.layerdrop, self.jump_layer and self.final_layer_norm are
        # not applied here yet.
        for layer in self.layers:
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                self_attn_bias=None,
            )

        # Back to batch-first: [n_node, n_graph, embed_dim] -> [n_graph, n_node, embed_dim].
        # Tensor.transpose is not in-place, so the result must be reassigned.
        x = x.transpose(0, 1)

        return self.lm_activation_fn(self.lm_head_transform_weight(x))

    def reset_parameters(self, module):
        """Initialize the weights specific to the Graph transformer.

        Called on every submodule through ``self.apply``. Each affected weight
        is drawn from N(0, 0.02):

        * ``nn.Linear``: weight resampled, bias (if any) zeroed.
        * ``nn.Embedding``: weight resampled, ``padding_idx`` row (if any) zeroed.
        * ``MultiheadAttention``: ``q_proj``/``k_proj``/``v_proj`` weights resampled.
        """

        def normal_(data):
            # with FSDP, module params will be on CUDA, so we cast them back to CPU
            # so that the RNG is consistent with and without FSDP
            data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

        if isinstance(module, nn.Linear):
            normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            normal_(module.weight.data)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, MultiheadAttention):
            normal_(module.q_proj.weight.data)
            normal_(module.k_proj.weight.data)
            normal_(module.v_proj.weight.data)

