import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
import math
import torch.nn.functional as F
import numpy as np

# Default compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class AKT(nn.Module):
    """Knowledge-tracing model with Rasch-style difficulty and two RNN encoders.

    No attention block is built here: sequence modelling is done by ``rnn``
    (over right-shifted interaction embeddings) and ``rnn2`` (over problem-id
    embeddings). Their outputs feed an MLP head that emits one probability per
    time step.

    Args:
        n_question: number of concepts/questions; embedding tables are sized from it.
        n_pid: number of problem ids (>= 0). 0 means "no problem ids": the
            problem-indexed tables are then indexed by concept id instead.
        d_model: size of every embedding and RNN hidden state.
        n_layer: number of stacked layers in each RNN.
        dropout: dropout probability inside the output MLP.
        RNN1type: "GRU", "RNN" or "LSTM" for the interaction encoder ``rnn``.
        kq_same: stored on the instance; not used by this class.
        final_fc_zoom: the output MLP's hidden width is ``final_fc_zoom * d_model``.
        RNN2type: "GRU", "RNN" or "LSTM" for the problem encoder ``rnn2``.
        separate_qa: if True, interactions use their own ``2*n_question+1`` table;
            otherwise a 2-row response table is added to the question embedding.
        ratio: mixing weight used by the "context-aware_rasch_qt" ablation.
        emb_type: must start with "qid".
        ablation: "", "only_qt", "only_rasch_qt", "only_context-aware_qt" or
            "context-aware_rasch_qt"; selects the question representation given
            to the output head. Any other value behaves like "".
        pretrain_dim, use_sweep: accepted for interface compatibility; unused.

    Raises:
        ValueError: on an unknown RNN type, a non-"qid" emb_type, or n_pid < 0.
    """

    _RNN_TYPES = {"GRU": nn.GRU, "RNN": nn.RNN, "LSTM": nn.LSTM}

    def __init__(self, n_question, n_pid, d_model, n_layer, dropout, RNN1type="GRU",
            kq_same=1, final_fc_zoom=4, RNN2type="GRU", separate_qa=False, ratio=1e-5, emb_type="qid", ablation="", pretrain_dim=768,use_sweep=0):
        super().__init__()
        # Validate eagerly: invalid values used to leave attributes undefined
        # and only failed later, inside forward().
        for label, rnn_type in (("RNN1type", RNN1type), ("RNN2type", RNN2type)):
            if rnn_type not in self._RNN_TYPES:
                raise ValueError(
                    f"{label} must be one of {sorted(self._RNN_TYPES)}, got {rnn_type!r}")
        if not emb_type.startswith("qid"):
            raise ValueError(f"emb_type must start with 'qid', got {emb_type!r}")
        if n_pid < 0:
            raise ValueError(f"n_pid must be >= 0, got {n_pid}")

        self.model_name = "akt"
        self.n_question = n_question
        self.dropout = dropout
        self.kq_same = kq_same
        self.n_pid = n_pid
        self.ratio = ratio
        self.model_type = self.model_name
        self.separate_qa = separate_qa
        self.emb_type = emb_type
        embed_l = d_model
        self.n_layer = n_layer
        self.ablation = ablation
        final_fc_dim = final_fc_zoom * embed_l

        # NOTE: submodules are created in the same order as before so that
        # seeded random initialisation stays reproducible.
        self.rnn = self._RNN_TYPES[RNN1type](embed_l, embed_l, self.n_layer, batch_first=True)
        self.rnn2 = self._RNN_TYPES[RNN2type](embed_l, embed_l, self.n_layer, batch_first=True)

        # Scalar problem difficulty (u_q); indexed by concept id when n_pid == 0.
        if self.n_pid > 0:
            self.difficult_param = nn.Embedding(self.n_pid + 1, 1)
        else:
            self.difficult_param = nn.Embedding(self.n_question + 1, 1)
        # Question variation embedding (d_ct).
        self.q_embed_diff = nn.Embedding(self.n_question + 1, embed_l)
        # Interaction variation embedding (f_(ct,rt)).
        self.qa_embed_diff = nn.Embedding(2 * self.n_question + 1, embed_l)
        # Per-problem embedding fed to rnn2 and to some ablations.
        p_rows = self.n_pid + 1 if self.n_pid > 0 else self.n_question + 5
        self.p_embed = nn.Embedding(p_rows, embed_l)

        # Base concept embedding (c_ct).
        self.q_embed = nn.Embedding(self.n_question + 5, embed_l)
        if self.separate_qa:
            self.qa_embed = nn.Embedding(2 * self.n_question + 1, embed_l)  # interaction emb
        else:
            # Response embedding (g_rt), added to the concept embedding in base_emb.
            self.qa_embed = nn.Embedding(2, embed_l)

        self.out = nn.Sequential(
            nn.Linear(d_model + embed_l, final_fc_dim), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim, embed_l), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(embed_l, 1),
        )
        self.reset()

    def reset(self):
        """Zero the problem-id-indexed tables when problem ids are in use.

        Only ``difficult_param`` and ``p_embed`` are targeted. The previous
        "first dimension == n_pid + 1" heuristic also zeroed unrelated weights
        whenever their first dimension happened to equal n_pid + 1. For example,
        ``q_embed_diff`` when n_pid == n_question, or an MLP layer when
        n_pid + 1 == d_model.
        """
        if self.n_pid > 0:
            torch.nn.init.constant_(self.difficult_param.weight, 0.)
            torch.nn.init.constant_(self.p_embed.weight, 0.)

    def base_emb(self, q_data, target):
        """Return ``(q_embed_data, qa_embed_data)`` for concept ids and responses.

        With ``separate_qa`` the interaction id is ``q + n_question * r``.
        Otherwise the response embedding is added to the concept embedding
        (c_ct + g_rt = e_(ct,rt)).
        """
        q_embed_data = self.q_embed(q_data)
        if self.separate_qa:
            qa_data = q_data + self.n_question * target
            qa_embed_data = self.qa_embed(qa_data)
        else:
            qa_embed_data = self.qa_embed(target) + q_embed_data
        return q_embed_data, qa_embed_data

    def _question_repr(self, q_embed_data, pid_data, pid_embed_data, ctx_q_embed):
        """Select the question representation fed to the output head per ``self.ablation``."""
        if self.ablation == "only_qt":
            return self.p_embed(pid_data)
        if self.ablation == "only_rasch_qt":
            return q_embed_data + pid_embed_data * self.p_embed(pid_data)
        if self.ablation == "only_context-aware_qt":
            return ctx_q_embed
        if self.ablation == "context-aware_rasch_qt":
            return (1 - self.ratio) * q_embed_data + self.ratio * pid_embed_data * ctx_q_embed
        return q_embed_data

    def forward(self, q_data, target, pid_data=None, qtest=False):
        """Predict per-step correctness probabilities.

        Args:
            q_data: concept ids, shape (batch, seq_len) (the RNNs are batch_first).
            target: responses, same shape; used as embedding indices (0/1 when
                ``separate_qa`` is False).
            pid_data: problem ids, same shape; required when n_pid > 0 and
                ignored (replaced by ``q_data``) when n_pid == 0.
            qtest: if True, also return the features fed to the output head.

        Returns:
            ``(preds, c_reg_loss)``, or ``(preds, c_reg_loss, concat_q)`` when
            ``qtest``. ``preds`` has shape (batch, seq_len). ``c_reg_loss`` is
            always 0 in this variant.
        """
        q_embed_data, qa_embed_data = self.base_emb(q_data, target)

        if self.n_pid > 0:
            pid_embed_data = self.difficult_param(pid_data)  # u_q
            if self.separate_qa:
                # e_(ct,rt) + u_q * f_(ct,rt)
                qa_embed_data = qa_embed_data + pid_embed_data * self.qa_embed_diff(target)
            else:
                # e_(ct,rt) + u_q * d_ct
                qa_embed_data = qa_embed_data + pid_embed_data * self.q_embed_diff(q_data)
        else:
            # No problem ids: difficulty and problem tables are indexed by concept.
            pid_embed_data = self.difficult_param(q_data)
            pid_data = q_data
        c_reg_loss = 0.

        # Shift interactions right by one step (step 0 reuses interaction 0), so
        # the RNN output at step t is built from interactions < t. Step 0
        # therefore sees its own response.
        # NOTE(review): callers presumably discard position 0 — confirm.
        shifted_qa = torch.cat([qa_embed_data[:, :1, :], qa_embed_data[:, :-1, :]], dim=1)
        d_output, _ = self.rnn(shifted_qa)
        # Context-aware problem representation from the (unshifted) problem sequence.
        ctx_q_embed, _ = self.rnn2(self.p_embed(pid_data))

        q_repr = self._question_repr(q_embed_data, pid_data, pid_embed_data, ctx_q_embed)
        concat_q = torch.cat([d_output, q_repr], dim=-1)
        preds = torch.sigmoid(self.out(concat_q).squeeze(-1))
        if qtest:
            return preds, c_reg_loss, concat_q
        return preds, c_reg_loss



