'''
Description: QDKT, a question-level DKT model (an LSTM over question-response interactions).
version: 0.0.1
Author: Zijie Li
Date: 2023-08-03 00:10:48
LastEditors: Zijie Li
LastEditTime: 2023-10-30 01:06:28
E-mail:icaq@qq.com
The project is developed by Yunnan Key Laboratory of Smart Education and Key Laboratory of Educational Informatization for Nationalities.
'''
import os
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .que_base_model import QueBaseModel
from pykt.utils import debug_print
# Module-wide default device: CUDA whenever PyTorch reports it is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class QDKTNet(nn.Module):
    """Question-level DKT network: an LSTM over (question, response) interaction embeddings.

    Each interaction is encoded as ``q + num_q * r`` and looked up in an embedding
    table with ``2 * num_q`` rows, so responses are expected to be 0 or 1
    (assumption: other values would index past the table). The LSTM output at each
    step is projected to one probability per question, and the probability of the
    next question (``data['qshft']``) is selected.

    Args:
        num_q: number of distinct question ids; ids index ``[0, num_q)``.
        num_c: number of concepts; stored but not used by this network.
        emb_size: interaction-embedding size, also used as the LSTM hidden size.
        dropout: dropout probability applied to the LSTM outputs.
        emb_type: stored as ``self.emb_type``; does not affect the computation.
        emb_path, pretrain_dim, mlp_layer_num, other_config: accepted for interface
            compatibility with other models and ignored here.
        device: stored as ``self.device``. ``forward`` places ``data['qshft']`` on
            the device of the network's own outputs, so it always matches the weights.
    """

    def __init__(self, num_q, num_c, emb_size, dropout=0.1, emb_type='qaid', emb_path="", pretrain_dim=768, device='cpu', mlp_layer_num=1, other_config=None):
        super().__init__()
        self.model_name = "qdkt"
        self.num_q = num_q
        self.num_c = num_c
        self.emb_size = emb_size
        self.hidden_size = emb_size
        self.device = device
        self.emb_type = emb_type

        # Rows [0, num_q) encode (q, r=0); rows [num_q, 2*num_q) encode (q, r=1).
        self.interaction_emb = nn.Embedding(self.num_q * 2, self.emb_size)
        self.lstm_layer = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.out_layer = nn.Linear(self.hidden_size, self.num_q)

    def forward(self, q, c, r, data=None):
        """Predict the probability of answering each next question correctly.

        Args:
            q: integer question ids, at least 2-D with time on dim 1 (the code slices
                ``[:, :-1]``); presumably (batch, seq_len). TODO confirm against caller.
            c: concept ids; unused by this network.
            r: responses, combined with ``q`` as ``q + num_q * r``.
            data: mapping that must contain ``'qshft'``, the next-step question ids,
                shaped like ``q[:, :-1]``.

        Returns:
            ``{"y": probs}``, where ``probs`` has the shape of ``data['qshft']`` and
            holds the sigmoid probability for each next question.

        Raises:
            TypeError: if ``data`` is None.
        """
        if data is None:
            raise TypeError("QDKTNet.forward requires `data` containing 'qshft'")

        # The last step is dropped: there is no next question to predict from it.
        x = (q + self.num_q * r)[:, :-1]

        xemb = self.interaction_emb(x)
        h, _ = self.lstm_layer(xemb)
        h = self.dropout_layer(h)
        y = torch.sigmoid(self.out_layer(h))  # one probability per question

        # Select the probability of the next question. gather yields the same values
        # as multiplying by a one-hot mask and summing, without materialising a
        # (..., num_q) one-hot tensor. The index follows y's device, so inputs and
        # weights cannot end up on different devices.
        qshft = data['qshft'].long().to(y.device)
        y = torch.gather(y, -1, qshft.unsqueeze(-1)).squeeze(-1)
        return {"y": y}

class QDKT(QueBaseModel):
    """QDKT model wrapper: builds a ``QDKTNet`` and exposes train/predict steps.

    Args:
        num_q, num_c, emb_size, dropout, emb_type, emb_path, pretrain_dim,
            mlp_layer_num, other_config: forwarded to ``QDKTNet``.
        device: forwarded to the base class and ``QDKTNet``; the network is moved
            there with ``.to(device)``.
        seed: forwarded to the base class.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_q, num_c, emb_size, dropout=0.1, emb_type='qaid', emb_path="", pretrain_dim=768, device='cpu', seed=0, mlp_layer_num=1, other_config=None, **kwargs):
        model_name = "qdkt"
        # Avoid a shared mutable default; downstream still receives a dict.
        if other_config is None:
            other_config = {}

        debug_print(f"emb_type is {emb_type}", fuc_name="QDKT")

        super().__init__(model_name=model_name, emb_type=emb_type, emb_path=emb_path, pretrain_dim=pretrain_dim, device=device, seed=seed)
        self.model = QDKTNet(num_q=num_q, num_c=num_c, emb_size=emb_size, dropout=dropout, emb_type=emb_type,
                             emb_path=emb_path, pretrain_dim=pretrain_dim, device=device, mlp_layer_num=mlp_layer_num, other_config=other_config)

        self.model = self.model.to(device)
        self.emb_type = self.model.emb_type
        self.loss_func = self._get_loss_func("binary_crossentropy")

    def train_one_step(self, data, process=True, return_all=False):
        """Run a forward pass and compute the loss for one batch.

        ``return_all`` is accepted for interface compatibility and unused.

        Returns:
            ``(y, loss)``: predicted probabilities and the loss from ``get_loss``
            computed against ``rshft`` under the ``sm`` mask.
        """
        outputs, data_new = self.predict_one_step(data, return_details=True, process=process)
        loss = self.get_loss(outputs['y'], data_new['rshft'], data_new['sm'])
        # QDKTNet only produces 'y'; no separate question-level output is returned.
        return outputs['y'], loss

    def predict_one_step(self, data, return_details=False, process=True, return_raw=False):
        """Predict next-question correctness probabilities for one batch.

        ``return_raw`` is accepted for interface compatibility and unused.

        Returns:
            ``(outputs, data_new)`` when ``return_details`` is true, otherwise
            ``outputs['y']``.
        """
        data_new = self.batch_to_device(data, process=process)
        # Send inputs to wherever the network's weights actually live, rather than a
        # module-global default that may differ from the constructor's `device`.
        target = next(self.model.parameters()).device
        outputs = self.model(
            data_new['cq'].long().to(target),
            data_new['cc'].to(target),
            data_new['cr'].to(target).long(),
            data=data_new,
        )
        if return_details:
            return outputs, data_new
        return outputs['y']