import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
import math
import torch.nn.functional as F
from enum import IntEnum
import numpy as np
from .utils import transformer_FFN, ut_mask, pos_encode, get_clones
from torch.nn import Module, Embedding, LSTM, Linear, Dropout, LayerNorm, TransformerEncoder, TransformerEncoderLayer, \
        MultiLabelMarginLoss, MultiLabelSoftMarginLoss, CrossEntropyLoss, BCELoss, MultiheadAttention
from torch.nn.functional import one_hot, cross_entropy, multilabel_margin_loss, binary_cross_entropy

# Module-wide default device: CUDA when PyTorch can see a GPU, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Dim(IntEnum):
    """Integer labels for tensor axes (presumably batch, sequence, feature order)."""

    # Tuple-unpacking assigns 0, 1, 2 to the members in declaration order.
    batch, seq, feature = range(3)

class simpleKT(nn.Module):
    """Knowledge-tracing model that, despite the ``simplekt`` name, uses two GRUs.

    ``rnn`` runs over the interaction embeddings (concept + response) and
    ``rnn2`` runs over the question-id embeddings.  Both outputs are shifted one
    step to the right, so the features at step ``t`` summarise items ``0..t-1``
    only.  Step 0 is seeded with the first item's question embedding.  The two
    history states are concatenated with the current question embedding and fed
    through an MLP head that yields one probability per step.
    """

    def __init__(self, n_question, n_pid, 
            d_model, n_blocks, dropout, d_ff=256, 
            loss1=0.5, loss2=0.5, loss3=0.5, start=50, num_layers=2, nheads=4, seq_len=200, 
            kq_same=1, final_fc_dim=512, final_fc_dim2=256, num_attn_heads=8, separate_qa=False, l2=1e-5, emb_type="qid", emb_path="", pretrain_dim=768):
        """Build the embeddings, the two GRUs and the output head.

        Args:
            n_question: largest concept id; concept tables have ``n_question + 1`` rows.
            n_pid: largest question id; question tables have ``n_pid + 1`` rows.
            d_model: embedding width and GRU hidden size.
            n_blocks: accepted for interface compatibility; unused here.
            dropout: dropout probability inside the output head.
            d_ff, loss1, loss2, loss3, start, num_layers, seq_len,
            num_attn_heads, emb_path, pretrain_dim: accepted for interface
                compatibility; unused by this implementation.
            nheads: repeat factor used by ``get_attn_pad_mask``.
            kq_same, l2: stored as attributes only.
            final_fc_dim: width of the first hidden layer of the output head.
            final_fc_dim2: width of the second hidden layer of the output head.
            separate_qa: if True, interactions use a dedicated ``qa_embed`` table
                indexed by ``concept + n_question * response``; otherwise they
                are ``a_embed(response) + q_embed(concept)``.
            emb_type: must be one of "qid", "qidaktrasch", "qid_scalar" or
                "qid_norasch" for ``forward`` to run.
        """
        super().__init__()
        self.model_name = "simplekt"
        print(f"model_name: {self.model_name}, emb_type: {emb_type}")
        self.n_question = n_question
        self.dropout = dropout
        self.kq_same = kq_same
        self.n_pid = n_pid
        self.l2 = l2
        self.model_type = self.model_name
        self.separate_qa = separate_qa
        self.emb_type = emb_type
        self.embed_l = d_model
        # Read by get_attn_pad_mask; previously never assigned.
        self.nhead = nheads

        # The GRU width must be d_model: the head below consumes
        # 2 * d_model (two GRU states) + embed_l (current question) features.
        # Was hard-coded to 256, which only worked when d_model == 256.
        self.rnn = nn.GRU(d_model, d_model, 1, batch_first=True)   # over interactions
        self.rnn2 = nn.GRU(d_model, d_model, 1, batch_first=True)  # over question ids

        # NOTE(review): difficult_param is not read by forward().
        self.difficult_param = nn.Embedding(self.n_pid+1, 1)
        self.p_embed = nn.Embedding(self.n_pid+1, self.embed_l)
        self.p_diff_embed = nn.Embedding(self.n_pid+1, 1)
        self.q_diff_embed = nn.Embedding(self.n_question+1, self.embed_l)
        self.q_embed = nn.Embedding(self.n_question+1, self.embed_l)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1, self.embed_l)
        self.a_embed = nn.Embedding(2, self.embed_l)

        self.out = nn.Sequential(
            nn.Linear(2*d_model + self.embed_l,
                      final_fc_dim), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim, final_fc_dim2), nn.ReLU(
            ), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim2, 1)
        )

        self.reset()

    def reset(self):
        """Zero every parameter whose first dimension is ``n_pid + 1``.

        This targets the question-indexed tables (``difficult_param``,
        ``p_embed``, ``p_diff_embed``).  NOTE(review): it is a shape heuristic,
        so any other parameter whose first dimension happens to equal
        ``n_pid + 1`` (e.g. ``q_embed`` when ``n_question == n_pid``) is zeroed
        too.
        """
        for p in self.parameters():
            if p.size(0) == self.n_pid+1 and self.n_pid > 0:
                torch.nn.init.constant_(p, 0.)

    def base_emb(self, q_data, target):
        """Return ``(question_embedding, interaction_embedding)`` for concept ids and responses."""
        q_embed_data = self.q_embed(q_data)  # c_ct
        if self.separate_qa:
            qa_data = q_data + self.n_question * target
            qa_embed_data = self.qa_embed(qa_data)
        else:
            # c_ct + g_rt = e_(ct,rt)
            qa_embed_data = self.a_embed(target)+q_embed_data
        return q_embed_data, qa_embed_data

    def get_attn_pad_mask(self, sm):
        """Build a boolean key-padding mask (True where ``sm == 0``), repeated ``nhead`` times.

        ``sm`` is indexed as 2-D ``(batch, l)``; the result has shape
        ``(nhead * batch, l, l)``.
        """
        batch_size, l = sm.size()
        pad_attn_mask = sm.eq(0).unsqueeze(1)
        pad_attn_mask = pad_attn_mask.expand(batch_size, l, l)
        return pad_attn_mask.repeat(self.nhead, 1, 1)

    def forward(self, dcur, qtest=False, train=False):
        """Predict a correctness probability for every step of the sequence.

        Args:
            dcur: dict holding "qseqs", "cseqs", "rseqs" and their "shft_"
                counterparts; each tensor is indexed as at least 2-D
                (assumed ``(batch, seq)``).  The full sequence is rebuilt as
                the first unshifted item followed by all shifted items.
            qtest: when True (and ``train`` is False), also return the
                concatenated features fed to the head.
            train: when True, return ``(preds, 0, 0)``.

        Returns:
            ``preds`` of shape ``(batch, seq + 1)``, or the tuples described above.

        Raises:
            ValueError: if ``self.emb_type`` is not supported.
        """
        emb_type = self.emb_type
        if emb_type not in ("qid", "qidaktrasch", "qid_scalar", "qid_norasch"):
            # Previously fell through and failed with UnboundLocalError on `preds`.
            raise ValueError(f"simpleKT: unsupported emb_type {emb_type!r}")

        # Follow the model's own device rather than a module-global guess.
        model_device = next(self.parameters()).device
        q, c, r = dcur["qseqs"].long(), dcur["cseqs"].long(), dcur["rseqs"].long()
        qshft, cshft, rshft = dcur["shft_qseqs"].long(), dcur["shft_cseqs"].long(), dcur["shft_rseqs"].long()
        pid_data = torch.cat((q[:, 0:1], qshft), dim=1).to(model_device)
        q_data = torch.cat((c[:, 0:1], cshft), dim=1).to(model_device)
        target = torch.cat((r[:, 0:1], rshft), dim=1).to(model_device)

        # Interaction embedding is built from the plain concept embedding,
        # before the per-question adjustment below.
        q_embed_data, qa_embed_data = self.base_emb(q_data, target)
        p_embed_data = self.p_embed(pid_data)
        p_diff_data = self.p_diff_embed(pid_data)  # one scalar per question id
        q_diff_data = self.q_diff_embed(q_data)

        # Question encoder: c_ct + u_q * d_ct
        q_embed_data = q_embed_data + q_diff_data * p_diff_data

        seq_len = q_embed_data.size(1)
        # Shift right by one: step 0 gets `first`; step t gets the GRU output
        # after consuming items 0..t-1, so nothing at step t leaks into itself.
        first = q_embed_data[:, 0:1, :]
        hist_qa, _ = self.rnn(qa_embed_data)
        hist_p, _ = self.rnn2(p_embed_data)
        hist_qa = torch.cat([first, hist_qa], dim=1)[:, :seq_len, :]
        hist_p = torch.cat([first, hist_p], dim=1)[:, :seq_len, :]

        concat_q = torch.cat([hist_qa, hist_p, q_embed_data], dim=-1)
        output = self.out(concat_q).squeeze(-1)
        preds = torch.sigmoid(output)

        y2, y3 = 0, 0
        if train:
            return preds, y2, y3
        if qtest:
            return preds, concat_q
        return preds
