from typing import Optional
from itertools import chain
from functools import partial

import torch
import torch.nn as nn
import torch as th
from dgl.nn import DegreeEncoder, GraphormerLayer, PathEncoder, SpatialEncoder
import torch.nn.functional as F
from models import Graphormer, MLP
from loss_func import sce_loss
from utils_gmae import create_norm, drop_edge
import numpy as np

def all_to_device(lst, device):
    """Lazily move every element of ``lst`` to ``device``.

    Args:
        lst: Iterable of objects exposing a ``.to(device)`` method.
        device: Target passed unchanged to each element's ``.to``.

    Returns:
        A single-use generator yielding ``element.to(device)`` in the order
        of ``lst``. Nothing is moved until the generator is consumed (for
        example by tuple unpacking).
    """
    moved = (item.to(device) for item in lst)
    return moved

def build_model_rtl(param):
    """Construct a ``PreModel_RTL`` from a flat, positional parameter pack.

    Args:
        param: Iterable of exactly 19 values, in this order: device,
            node_dim, edge_dim, num_hidden, num_layers,
            num_attention_heads, activation, max_degree, num_spatial,
            multi_hop_max_dist, pre_layernorm, in_drop, encoder_type,
            decoder_type, mask_rate, loss_fn, replace_rate, alpha_l,
            concat_hidden.

    Returns:
        The constructed ``PreModel_RTL`` instance.

    Raises:
        ValueError: If ``param`` does not contain exactly 19 values.
    """
    (
        device,
        node_dim,
        edge_dim,
        num_hidden,
        num_layers,
        num_attention_heads,
        activation,
        max_degree,
        num_spatial,
        multi_hop_max_dist,
        pre_layernorm,
        in_drop,
        encoder_type,
        decoder_type,
        mask_rate,
        loss_fn,
        replace_rate,
        alpha_l,
        concat_hidden,
    ) = param

    model_kwargs = {
        "device": device,
        "node_dim": node_dim,
        "edge_dim": edge_dim,
        "num_hidden": num_hidden,
        "num_layers": num_layers,
        "num_attention_heads": num_attention_heads,
        "activation": activation,
        "max_degree": max_degree,
        "num_spatial": num_spatial,
        "multi_hop_max_dist": multi_hop_max_dist,
        # The parameter pack calls this "in_drop"; the model calls it "feat_drop".
        "feat_drop": in_drop,
        "pre_layernorm": pre_layernorm,
        "mask_rate": mask_rate,
        "encoder_type": encoder_type,
        "decoder_type": decoder_type,
        "loss_fn": loss_fn,
        "replace_rate": replace_rate,
        "alpha_l": alpha_l,
        "concat_hidden": concat_hidden,
    }
    return PreModel_RTL(**model_kwargs)


def setup_module(m_type, enc_dec, embedding_dim, edge_dim, out_dim, num_encoder_layers, ffn_embedding_dim, num_attention_heads, activation, 
                dropout, pre_layernorm=True, max_degree=256, num_spatial=64, multi_hop_max_dist=5) -> nn.Module:
    if m_type == "gt":
        mod = Graphormer(
            embedding_dim=embedding_dim,
            edge_dim=edge_dim,
            out_dim=256,
            num_encoder_layers=num_encoder_layers,

            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,

            activation_fn=activation,

            dropout=dropout,
            pre_layernorm=pre_layernorm,
            max_degree=max_degree,
            num_spatial=num_spatial,
            multi_hop_max_dist=multi_hop_max_dist,
            
            encoding=(enc_dec == "encoding"),
        )
    elif m_type == "mlp":
        mod = MLP(
            input_dim=256,
            hidden_dim=512,
            num_layers=3,
            output_dim=27,
            activation="relu",
            norm="batchnorm"
        )

    else:
        raise NotImplementedError
    
    return mod


class PreModel_RTL(nn.Module):
    """Masked node-attribute pre-training model built from an encoder and a decoder.

    ``forward`` draws a random node mask, runs the encoder on the batch,
    appends zero-valued placeholder tokens for the masked nodes, decodes each
    graph separately and scores the decoder output at the placeholder
    positions against the original features of the masked nodes.
    """

    def __init__(
            self,
            device,
            node_dim: int,
            edge_dim:int,
            num_hidden: int,
            num_layers: int,
            num_attention_heads: int,
            activation: str,
            max_degree: int,
            num_spatial: int,
            multi_hop_max_dist: int,
            feat_drop: float,
            pre_layernorm: bool,
            mask_rate: float = 0.3,
            encoder_type: str = "gt",
            decoder_type: str = "gt",
            loss_fn: str = "sce",
            replace_rate: float = 0.1,
            alpha_l: float = 2,
            concat_hidden: bool = False,
            
         ):
        """Create all sub-modules and select the loss function.

        Args:
            device: Device the input tensors are moved to before encoding.
            node_dim: Embedding width used by the node-level encoders and as
                the encoder's ``embedding_dim``.
            edge_dim: Edge feature width, forwarded to ``PathEncoder`` and to
                ``setup_module``.
            num_hidden: Hidden width; also exposed as ``output_hidden_dim``.
            num_layers: Number of layers for both encoder and decoder.
            num_attention_heads: Number of attention heads.
            activation: Activation name forwarded to ``setup_module``.
            max_degree: Forwarded to ``DegreeEncoder`` and ``setup_module``.
            num_spatial: Forwarded to ``SpatialEncoder`` (as ``max_dist``)
                and to ``setup_module``.
            multi_hop_max_dist: Forwarded to ``PathEncoder`` (as ``max_len``)
                and to ``setup_module``.
            feat_drop: Dropout rate forwarded to ``setup_module``.
            pre_layernorm: Forwarded to ``setup_module``.
            mask_rate: Fraction of nodes masked in each ``forward`` call.
            encoder_type: Module kind for the encoder (see ``setup_module``).
            decoder_type: Module kind for the decoder (see ``setup_module``).
            loss_fn: Loss name understood by ``setup_loss_fn``.
            replace_rate: Stored together with its complement
                ``_mask_token_rate``; not read elsewhere in this class.
            alpha_l: ``alpha`` argument of ``sce_loss`` when ``loss_fn`` is
                ``"sce"``.
            concat_hidden: When true, ``encoder_to_decoder`` expects an input
                of width ``num_hidden * num_layers`` instead of ``num_hidden``.
        """
        super().__init__()
        self.device = device
        self._mask_rate = mask_rate

        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden
        self._in_dim = node_dim

        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate
        self.num_heads = num_attention_heads

        # NOTE: sub-modules are created in a fixed order so that seeded
        # parameter initialisation stays reproducible; do not reorder.
        self.graph_token = nn.Embedding(1, node_dim)

        self.atom_encoder = nn.Embedding(
            512 * self._in_dim + 1, node_dim, padding_idx=0)

        self.degree_encoder = DegreeEncoder(
            max_degree=max_degree, embedding_dim=node_dim
        )

        self.path_encoder = PathEncoder(
            max_len=multi_hop_max_dist,
            feat_dim=edge_dim,
            num_heads=num_attention_heads,
        )

        self.spatial_encoder = SpatialEncoder(
            max_dist=num_spatial, num_heads=num_attention_heads
        )
        self.graph_token_virtual_distance = nn.Embedding(1, num_attention_heads)

        self.emb_layer_norm = nn.LayerNorm(node_dim)

        self.out_proj = nn.Linear(self._output_hidden_size, self._in_dim)

        enc_num_hidden = num_hidden
        dec_in_dim = num_hidden

        # build encoder
        self.encoder = setup_module(
            m_type=encoder_type,
            embedding_dim=node_dim,
            edge_dim=edge_dim,
            out_dim=256,
            num_encoder_layers=num_layers,
            ffn_embedding_dim=enc_num_hidden,
            num_attention_heads=num_attention_heads,
            activation=activation,
            dropout=feat_drop,
            pre_layernorm=pre_layernorm,
            max_degree=max_degree,
            num_spatial=num_spatial,
            multi_hop_max_dist=multi_hop_max_dist,
            enc_dec="encoding",
        )

        # build decoder for attribute prediction
        self.decoder = setup_module(
            m_type=decoder_type,
            embedding_dim=num_hidden,
            edge_dim=edge_dim,
            out_dim=node_dim,
            num_encoder_layers=num_layers,
            ffn_embedding_dim=enc_num_hidden,
            num_attention_heads=num_attention_heads,
            activation=activation,
            dropout=feat_drop,
            pre_layernorm=pre_layernorm,
            max_degree=max_degree,
            num_spatial=num_spatial,
            multi_hop_max_dist=multi_hop_max_dist,
            enc_dec="decoding",
        )

        # Registered (and therefore part of the state dict) but not read by
        # any method of this class.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, 1, num_hidden))

        if concat_hidden:
            self.encoder_to_decoder = nn.Linear(dec_in_dim * num_layers, dec_in_dim, bias=False)
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        # * setup loss function
        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        """Hidden width (``num_hidden``) given at construction time."""
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the callable used to score reconstructions.

        Args:
            loss_fn: One of ``"mse"``, ``"sce"``, ``"nll"``, ``"ce"``,
                ``"cos"`` or ``"mae"``.
            alpha_l: Bound as ``alpha`` of ``sce_loss``; ignored otherwise.

        Returns:
            A callable taking ``(prediction, target)``.

        Raises:
            NotImplementedError: If ``loss_fn`` is not a supported name.
        """
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        elif loss_fn == "nll":
            criterion = nn.NLLLoss()
        elif loss_fn == "ce":
            criterion = nn.CrossEntropyLoss()
        elif loss_fn == "cos":
            # NOTE(review): CosineSimilarity yields a similarity, not a loss,
            # and ``forward`` feeds it flattened tensors -- confirm this
            # option is actually usable.
            criterion = nn.CosineSimilarity()
        elif loss_fn == "mae":
            criterion = nn.L1Loss()
        else:
            raise NotImplementedError(f"Unsupported loss function: {loss_fn!r}")
        return criterion

    def forward(self, batched_data):
        """Mask random nodes, reconstruct their features and compute the loss.

        Args:
            batched_data: Indexable batch whose element ``1`` is the node
                feature tensor; its dimension 1 is taken as the node count.
                The full layout is unpacked in ``mask_attr_prediction``.

        Returns:
            Tuple ``(loss, loss_item, graph_emb)``: the loss tensor, a dict
            ``{"loss": float}`` and the second output of the encoder.
        """
        num_nodes = batched_data[1].size(1)
        num_mask = int(self._mask_rate * num_nodes)
        # The same randomly shuffled node mask is shared by every graph in
        # the batch.
        mask = np.hstack([
            np.zeros(num_nodes - num_mask),
            np.ones(num_mask),
        ])
        np.random.shuffle(mask)
        mask = torch.from_numpy(mask).bool()

        # Prepend an always-unmasked entry for the encoder; it is stripped
        # again before node features are indexed.
        # (presumably the graph-token position -- confirm against the encoder)
        mask = torch.cat([torch.tensor([False]), mask])

        x_rec_masked, x_init_masked, graph_emb = self.mask_attr_prediction(batched_data, mask)
        # Both tensors are flattened to 1-D before being scored.
        x_init = x_init_masked.reshape(-1)
        x_rec = x_rec_masked.reshape(-1)

        if x_init.shape[0] == 0:
            # Nothing was masked (e.g. mask_rate * num_nodes < 1). Return a
            # zero loss that lives on the model's device and supports
            # ``backward()`` so the training loop can continue unattended.
            print("No masked nodes")
            loss = torch.zeros((), device=x_rec.device, requires_grad=True)
            return loss, {"loss": 0.0}, graph_emb

        loss = self.criterion(x_rec, x_init)
        loss_item = {"loss": loss.item()}
        return loss, loss_item, graph_emb

    def mask_attr_prediction(self, batched_data, mask):
        """Encode the batch and decode predictions for the masked nodes.

        Args:
            batched_data: Iterable of six tensors, unpacked as ``(attn_mask,
                node_feat, in_degree, out_degree, path_data, dist)``.
            mask: 1-D boolean tensor; entry 0 is the extra leading slot added
                by ``forward`` and the remaining entries mark masked nodes.

        Returns:
            Tuple ``(dec_outputs, gold, graph_emb)``: decoder outputs at the
            placeholder positions, the original features of the masked
            nodes, and the second output of the encoder.
        """
        (attn_mask, node_feat, in_degree, out_degree, path_data, dist) = all_to_device(batched_data, self.device)
        batched_data_encoder = (attn_mask, node_feat, in_degree, out_degree, path_data, dist)

        output, graph_emb = self.encoder(batched_data_encoder, mask)
        output = self.encoder_to_decoder(output)

        node_mask = mask[1:]  # drop the leading extra slot

        gold = node_feat[:, node_mask]
        num_masked = int(node_mask.sum())

        # Zero placeholders for the masked nodes, appended after the encoder
        # output. A plain tensor is enough: the per-call nn.Parameter used
        # previously was never registered or optimised, so it was always zero.
        mask_token = output.new_zeros(output.shape[0], num_masked, output.shape[2])
        dec_input = torch.cat([output, mask_token], dim=1)

        # The decoder is applied to one graph at a time; results are stacked
        # back into a batch in a single copy.
        dec_outputs = torch.stack([self.decoder(graph_input) for graph_input in dec_input], dim=0)

        # Keep only the trailing placeholder positions. An explicit start
        # index is used because ``-0:`` would select the whole tensor.
        dec_outputs = dec_outputs[:, dec_outputs.shape[1] - num_masked:, :]

        return dec_outputs, gold, graph_emb

    def decode(self, output, in_degree, out_degree, graph_attn_bias, graph_attn_mask, mask):
        """Decode with degree embeddings and reordered attention tensors.

        Nodes are reordered so that unmasked nodes come first and masked
        nodes last; degree embeddings are added to both groups before the
        decoder is called with the permuted attention bias and mask. This
        method is not called by any other method of this class.

        Args:
            output: Encoder output; its first position along dimension 1 is
                dropped. Presumably holds unmasked nodes only -- confirm.
            in_degree: In-degree tensor fed to the degree encoder.
            out_degree: Out-degree tensor fed to the degree encoder.
            graph_attn_bias: Attention bias, permuted along dimensions 1 and 2.
            graph_attn_mask: Attention mask, permuted along dimensions 1 and 2.
            mask: 1-D boolean tensor with a leading extra slot, as in
                ``mask_attr_prediction``.

        Returns:
            ``out_proj`` applied to the decoder output at the masked positions.
        """
        mask = mask[1:]  # drop the leading extra slot

        output = output[:, 1:, :]  # drop the first position along dim 1
        pos_embed = self.degree_encoder(torch.stack((in_degree, out_degree)))
        pos_embed_vis = pos_embed[:, ~mask, :]
        pos_embed_mask = pos_embed[:, mask, :]

        # Permutation placing unmasked nodes first and masked nodes last.
        node_index_mask = mask.nonzero().view(-1)
        node_index_vis = (~mask).nonzero().view(-1)
        new_node_index = torch.cat([node_index_vis, node_index_mask])
        graph_attn_bias = graph_attn_bias[:, new_node_index, :, :][:, :, new_node_index, :]
        graph_attn_mask = graph_attn_mask[:, new_node_index, :][:, :, new_node_index]

        # Zero placeholder broadcast over the masked positions (plain tensor
        # instead of an unregistered per-call nn.Parameter).
        mask_token = torch.zeros(1, 1, self._output_hidden_size, device=self.device)
        output = torch.cat([output + pos_embed_vis, mask_token + pos_embed_mask], dim=1)
        num_masked = pos_embed_mask.shape[1]

        output, _ = self.decoder(output, graph_attn_bias, graph_attn_mask)
        # Explicit start index: ``-0:`` would select everything.
        output = output[:, output.shape[1] - num_masked:]

        output = self.out_proj(output.to(self.device))

        return output

    @property
    def enc_params(self):
        """Iterator over the encoder's parameters."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Iterator over the parameters of ``encoder_to_decoder`` and the decoder."""
        return chain(self.encoder_to_decoder.parameters(), self.decoder.parameters())
