import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
import math
import torch.nn.functional as F
from enum import IntEnum
import numpy as np
from .utils import transformer_FFN, ut_mask, pos_encode, get_clones
from torch.nn import Module, Embedding, LSTM, Linear, Dropout, LayerNorm, TransformerEncoder, TransformerEncoderLayer, \
        MultiLabelMarginLoss, MultiLabelSoftMarginLoss, CrossEntropyLoss, BCELoss, MultiheadAttention
from torch.nn.functional import one_hot, cross_entropy, multilabel_margin_loss, binary_cross_entropy
from functools import partial



from typing import Type, Tuple, Optional
from functools import partial

import torch
import torch.nn as nn
import timm



# Module-wide default device: CUDA (the current CUDA device) when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Dim(IntEnum):
    """Symbolic axis indices: ``batch`` (0), ``seq`` (1), ``feature`` (2).

    Presumably intended for indexing 3-D ``[batch, seq, feature]`` tensors —
    confirm against callers; the visible code in this module does not use it.
    """

    batch = 0  # leading axis
    seq = 1  # sequence / time axis
    feature = 2  # feature / channel axis

class lKT(nn.Module):
    """Knowledge-tracing model combining long- and short-range HyperMixer views.

    ``forward`` embeds the interaction sequence. For each step ``i`` it runs:

    * ``hyper_mixer_block`` on the whole sequence with every position after
      ``i`` zeroed out (long-range view);
    * ``short_hyper_mixer_block`` on a window of at most 10 positions
      (short-range view).

    Both per-step outputs are shifted right by one step and concatenated with
    the problem embedding. The result is fed to the ``out`` MLP, and a sigmoid
    turns it into the prediction.

    Args:
        n_question: ids in ``cseqs`` index ``q_embed`` / ``q_diff_embed``
            (tables with ``n_question + 1`` rows).
        n_pid: ids in ``qseqs`` index ``p_embed`` / ``p_diff_embed``
            (tables with ``n_pid + 1`` rows).
        d_model: embedding and hidden width.
        dropout: dropout rate inside the ``out`` head.
        final_fc_dim, final_fc_dim2: hidden widths of the ``out`` head.
        mixer_ratio1, mixer_ratio2: ``mlp_ratio`` passed to both HyperMixer blocks.
        emb_type: must be one of ``SUPPORTED_EMB_TYPES`` when ``forward`` runs.
        kq_same, separate_qa: stored as attributes only.
        d_ff, seq_len, batch_size, emb_path, pretrain_dim, use_sweep:
            accepted for interface compatibility; unused by this class.
    """

    # emb_type values that forward() knows how to handle.
    SUPPORTED_EMB_TYPES = ("qid", "qidaktrasch", "qid_scalar", "qid_norasch")

    def __init__(self, n_question, n_pid,
                 d_model, dropout, d_ff=256,
                 seq_len=200,
                 kq_same=1, final_fc_dim=512, final_fc_dim2=256, mixer_ratio1=0.5, mixer_ratio2=2.0,
                 separate_qa=False, batch_size=256, emb_type="qid", emb_path="", pretrain_dim=768,
                 use_sweep=0):
        super().__init__()
        self.model_name = "lkt"
        print(f"model_name: {self.model_name}, emb_type: {emb_type}")
        self.n_question = n_question
        self.dropout = dropout
        self.kq_same = kq_same
        self.n_pid = n_pid

        self.model_type = self.model_name
        self.separate_qa = separate_qa
        self.emb_type = emb_type
        self.embed_l = d_model
        self.hidden_dropout_prob = 0.1

        # Embedding tables. Keep the creation order stable: it determines the
        # random initialisation and the parameter / state_dict order.
        # NOTE: difficult_param and qa_embed are registered but never read by forward().
        self.difficult_param = nn.Embedding(self.n_pid + 1, 1)
        self.p_embed = nn.Embedding(self.n_pid + 1, self.embed_l)
        self.p_diff_embed = nn.Embedding(self.n_pid + 1, self.embed_l)
        self.q_diff_embed = nn.Embedding(self.n_question + 1, self.embed_l)
        self.q_embed = nn.Embedding(self.n_question + 1, self.embed_l)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1, self.embed_l)
        self.a_embed = nn.Embedding(2, self.embed_l)  # indexed by the response value (2 rows)

        # Long-range and short-range history encoders.
        self.hyper_mixer_block = HyperMixerBlock(dim=self.embed_l, mlp_ratio=(mixer_ratio1, mixer_ratio2))
        self.short_hyper_mixer_block = HyperMixerBlock(dim=self.embed_l, mlp_ratio=(mixer_ratio1, mixer_ratio2))

        # Prediction head over [long view | short view | problem embedding].
        self.out = nn.Sequential(
            nn.Linear(2 * d_model + self.embed_l, final_fc_dim), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim, final_fc_dim2), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim2, 1),
        )

        self.reset()

    def reset(self):
        """Zero-initialise the problem-indexed tables when ``n_pid > 0``.

        The tables are ``difficult_param``, ``p_embed`` and ``p_diff_embed``.
        They are named explicitly rather than selected by ``size(0) == n_pid + 1``.
        A shape-based match would also zero any unrelated parameter whose first
        dimension happens to equal ``n_pid + 1``: for example ``q_embed`` /
        ``q_diff_embed`` when ``n_question == n_pid``, or a LayerNorm or Linear
        of that width.
        """
        if self.n_pid <= 0:
            return
        for table in (self.difficult_param, self.p_embed, self.p_diff_embed):
            torch.nn.init.constant_(table.weight, 0.)

    def forward(self, dcur, qtest=False, train=False):
        """Predict the correctness probability for every step of the sequence.

        Args:
            dcur: dict holding integer tensors "qseqs", "cseqs", "rseqs" and
                "shft_qseqs", "shft_cseqs", "shft_rseqs". Presumably each is
                ``(batch, seq_len - 1)``; TODO confirm against the data loader.
            qtest: if true (and ``train`` is false), also return the features
                fed to the ``out`` head.
            train: if true, return ``(y1, 0, 0, torch.tensor(0))``.

        Returns:
            ``y1`` of shape ``(batch, 1 + shifted_len)``, or the tuples
            described above.

        Raises:
            ValueError: if ``self.emb_type`` is not in ``SUPPORTED_EMB_TYPES``.
        """
        emb_type = self.emb_type
        if emb_type not in self.SUPPORTED_EMB_TYPES:
            # Fail early with a clear message instead of an UnboundLocalError at the return.
            raise ValueError(
                f"lKT.forward does not support emb_type={emb_type!r}; "
                f"expected one of {self.SUPPORTED_EMB_TYPES}"
            )

        # Move inputs to wherever the parameters live, not to a module-global
        # default, so a CPU-resident model still works on a CUDA machine.
        model_device = self.q_embed.weight.device

        q, c, r = dcur["qseqs"].long(), dcur["cseqs"].long(), dcur["rseqs"].long()
        qshft, cshft, rshft = dcur["shft_qseqs"].long(), dcur["shft_cseqs"].long(), dcur["shft_rseqs"].long()

        # Re-attach the first step to the shifted sequences to get full-length sequences.
        pid_data = torch.cat((q[:, 0:1], qshft), dim=1).to(model_device)
        q_data = torch.cat((c[:, 0:1], cshft), dim=1).to(model_device)
        target = torch.cat((r[:, 0:1], rshft), dim=1).to(model_device)

        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.a_embed(target) + q_embed_data
        p_embed_data = self.p_embed(pid_data)
        p_diff_data = self.p_diff_embed(pid_data)
        q_diff_data = self.q_diff_embed(q_data)

        q_embed_data = q_embed_data + q_diff_data * p_diff_data

        seq_len = qa_embed_data.size(1)
        # Hoisted: build a broadcastable {0,1} mask per step instead of a full
        # host-side ones tensor copied to the device every iteration.
        positions = torch.arange(seq_len, device=qa_embed_data.device)
        long_steps, short_steps = [], []
        for i in range(seq_len):
            visible = (positions <= i).to(qa_embed_data.dtype).view(1, seq_len, 1)
            masked_data = qa_embed_data * visible  # positions after i are zeroed

            # Long-range view: whole (masked) sequence, keep the output at step i.
            long_out = self.hyper_mixer_block(masked_data)

            # Short-range view. NOTE(review): for i <= 9 the window [0, 10)
            # contains step i, but for i >= 10 the window [i - 10, i) stops at
            # i - 1 and excludes step i itself. Confirm this asymmetry is
            # intended (e.g. vs. a window [i - 9, i + 1)).
            if i < 11:
                short_out = self.short_hyper_mixer_block(masked_data[:, :10, :])
            else:
                short_out = self.short_hyper_mixer_block(masked_data[:, i - 10:i, :])

            long_steps.append(long_out[:, i:i + 1, :])
            short_steps.append(short_out[:, -1:, :])

        hidden = torch.cat([torch.cat(long_steps, dim=1), torch.cat(short_steps, dim=1)], dim=-1)

        # Shift right by one step so the prediction at t only sees history up
        # to t - 1. Step 0 uses the (doubled) question embedding instead.
        first = q_embed_data[:, 0:1, :]
        first = torch.cat([first, first], dim=-1)
        d_output = torch.cat([first, hidden], dim=1)[:, :seq_len, :]

        concat_q = torch.cat([d_output, p_embed_data], dim=-1)
        output = self.out(concat_q).squeeze(-1)
        y1 = torch.sigmoid(output)
        y2, y3 = 0, 0

        if train:
            return y1, y2, y3, torch.tensor(0)
        if qtest:
            return y1, concat_q
        return y1


import  random


from enum import IntEnum

from torch.nn import Parameter
class Dim(IntEnum):
    """Symbolic axis indices: ``batch`` (0), ``seq`` (1), ``feature`` (2).

    NOTE: this re-declares an identical ``Dim`` enum defined earlier in this
    module; this later definition is the one bound once the module finishes
    importing. Presumably intended for indexing ``[batch, seq, feature]``
    tensors — confirm against callers.
    """

    batch = 0  # leading axis
    seq = 1  # sequence / time axis
    feature = 2  # feature / channel axis


class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper computing ``fn(LayerNorm(x)) + x``.

    :param dim: feature size normalised by the LayerNorm (last axis)
    :param fn: module applied to the normalised input
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (fn, then norm) is kept for state_dict compatibility.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # The LayerNorm sees a copy of x, so it never keeps x itself for backward.
        normalized = self.norm(x.clone())
        residual = self.fn(normalized)
        return residual + x
    
def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    """Build ``Sequential(dense(dim, dim * expansion_factor), ELU, Dropout(dropout))``.

    There is no projection back to ``dim``, so the output width is
    ``dim * expansion_factor``. ``dense`` is any ``(in, out)`` layer factory
    (``nn.Linear`` by default).
    """
    widened = dim * expansion_factor
    layers = [
        dense(dim, widened),
        nn.ELU(),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)

class HyperMixerBlock(nn.Module):
    """HyperMixer block: token mixing with :class:`HyperMixer`, then a channel MLP.

    Input and output are ``[batch size, tokens, channels]`` tensors.
    """

    def __init__(
            self,
            dim: int,
            mlp_ratio: Tuple[float, float] = (0.5, 2.0),
            norm_layer: Type = partial(nn.LayerNorm, eps=1e-06),
            act_layer: Type = nn.GELU,
            drop: float = 0.0,
            drop_path: float = 0.
    ) -> None:
        """
        Constructor method
        :param dim (int): Channel dimension
        :param mlp_ratio (Tuple[float, float]): Hidden-size ratios (relative to ``dim``) of the
            token-mixing layer and of the channel MLP. Default = (0.5, 2.0)
        :param norm_layer (Type): Normalization factory. Default = LayerNorm(eps=1e-6)
        :param act_layer (Type): Activation factory. Default = nn.GELU
        :param drop (float): Dropout rate. Default = 0.
        :param drop_path (float): Drop-path rate. Default = 0.
        """
        super(HyperMixerBlock, self).__init__()
        token_ratio, channel_ratio = timm.models.layers.to_2tuple(mlp_ratio)
        tokens_dim = int(token_ratio * dim)
        channels_dim = int(channel_ratio * dim)
        # Submodules are created in a fixed order (this fixes the RNG-dependent init).
        self.norm1: nn.Module = norm_layer(dim)
        self.mlp_tokens: nn.Module = HyperMixer(dim=dim, hidden_dim=tokens_dim, act_layer=act_layer, drop=drop)
        self.drop_path = timm.models.layers.DropPath(drop_prob=drop_path)
        self.norm2: nn.Module = norm_layer(dim)
        self.mlp_channels: nn.Module = timm.models.layers.Mlp(
            in_features=dim,
            hidden_features=channels_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(
            self,
            x: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass
        :param x (torch.Tensor): Input tensor of the shape [batch size, tokens, channels]
        :return (torch.Tensor): Output tensor of the shape [batch size, tokens, channels]
        """
        # NOTE(review): the token-mixing residual is taken from the *normalised*
        # input rather than from x (not a standard pre-norm residual) — confirm intent.
        normed = self.norm1(x)
        mixed = normed + self.drop_path(self.mlp_tokens(normed))
        mixed = mixed + self.drop_path(self.mlp_channels(self.norm2(mixed)))
        # NOTE(review): mlp_channels is applied a second time (shared weights, no
        # normalisation) — confirm this is deliberate.
        return mixed + self.drop_path(self.mlp_channels(mixed))

class HyperMixer(nn.Module):
    """HyperMixer token-mixing layer.

    The token-mixing matrix ``W`` is generated from the input by ``mlp_1``.
    The same ``W`` is used for both mixing steps:
    ``drop(W @ drop(act(W^T @ x)))``.
    """

    def __init__(
            self,
            dim: int,
            hidden_dim: int,
            act_layer: Type = nn.GELU,
            drop: float = 0.1,
    ) -> None:
        """
        Constructor method
        :param dim (int): Channel dimension
        :param hidden_dim (int): Output size of ``mlp_1`` (columns of the generated mixing matrix)
        :param act_layer (Type): Activation applied between the two mixing steps (``self.act``).
            NOTE: ``mlp_1`` always uses ``nn.GELU``, regardless of this argument.
        :param drop (float): Dropout rate
        """
        super(HyperMixer, self).__init__()
        self.mlp_1: nn.Module = timm.models.layers.Mlp(in_features=dim, out_features=hidden_dim, act_layer=nn.GELU,
                                                       drop=drop)
        self.drop = nn.Dropout(drop)
        self.act = act_layer()

    def forward(
            self,
            x: torch.Tensor,
            pos_emb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass
        :param x (torch.Tensor): Input tensor of the shape [batch size, tokens, channels]
        :param pos_emb (torch.Tensor): Optional positional embeddings added to x
            before generating the mixing weights
        :return (torch.Tensor): Output tensor of the shape [batch size, tokens, channels]
        """
        # Compare with None explicitly: truth-testing a multi-element tensor
        # raises "Boolean value of Tensor with more than one value is ambiguous".
        weight_input = x if pos_emb is None else x + pos_emb
        weights: torch.Tensor = self.mlp_1(weight_input)  # [batch, tokens, hidden_dim]
        # Mix tokens: project onto hidden_dim "slots", activate, then map back to tokens.
        x: torch.Tensor = self.drop(self.act(weights.transpose(1, 2) @ x))
        x: torch.Tensor = self.drop(weights @ x)
        return x