import torch
import torch.nn as nn
import torch.nn.functional as F

class KTModel(nn.Module):
    """LSTM model that scores a next question from a question/response history.

    Each step's question embedding is routed into one of two halves of the
    LSTM input depending on the response value, the LSTM hidden state is
    projected to a scalar, and a per-question scalar offset (from the same
    shared embedding table) is added.

    Args:
        input_num: Number of rows in the embedding table (distinct question ids).
        hidden_dim: Embedding width and LSTM hidden size.
        drop: Dropout probability used in ``forward`` while training.
    """

    def __init__(self, input_num, hidden_dim, drop):
        super().__init__()
        self.input_num = input_num
        self.hidden_dim = hidden_dim
        self.drop = drop

        # Kept as a raw Parameter (not nn.Embedding) so the table can be used
        # both for lookups and, in get_pred, projected as a whole.
        self.embedding = nn.Parameter(torch.empty((input_num, hidden_dim)))
        nn.init.normal_(self.embedding)

        self.lstm = nn.LSTM(2 * hidden_dim, hidden_dim, batch_first=True)
        self.hlin = nn.Linear(hidden_dim, 1)
        self.qlin = nn.Linear(hidden_dim, 1)

    def _encode_interactions(self, cur_q, cur_r):
        """Build the LSTM input: ``[emb * r, emb * (1 - r)]`` on the last dim."""
        cur_q_emb = F.embedding(cur_q, self.embedding)
        cur_r = cur_r.unsqueeze(-1)
        return torch.cat((cur_q_emb * cur_r, cur_q_emb * (1 - cur_r)), dim=-1)

    def forward(self, cur_q, cur_r, nxt_q):
        """Return sigmoid scores for ``nxt_q`` at every step.

        Args:
            cur_q: Integer index tensor of question ids; assumed
                ``(batch, seq)`` since the LSTM is ``batch_first`` — confirm
                against the caller.
            cur_r: Responses with the same shape as ``cur_q``; presumably
                0/1 values — confirm against the caller.
            nxt_q: Integer index tensor of the questions to score, same shape
                as ``cur_q``.

        Returns:
            Tensor shaped like ``cur_q`` with values in ``(0, 1)``.
        """
        x = self._encode_interactions(cur_q, cur_r)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # model.eval() really disables dropout.
        x = F.dropout(x, self.drop, training=self.training)

        h, _ = self.lstm(x)
        h = F.dropout(h, self.drop, training=self.training)

        nxt_q_emb = F.embedding(nxt_q, self.embedding)
        y = self.hlin(h) + self.qlin(nxt_q_emb)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(y).squeeze(-1)

    def get_pred(self, cur_q, cur_r):
        """Return raw scores for every question at every step.

        No dropout and no sigmoid are applied here: the result is the
        pre-sigmoid value that ``forward`` would compute, evaluated for all
        ``input_num`` questions at once.

        Args:
            cur_q: Integer index tensor of question ids (see ``forward``).
            cur_r: Responses with the same shape as ``cur_q``.

        Returns:
            Tensor with the shape of ``cur_q`` plus a trailing dimension of
            size ``input_num``.
        """
        x = self._encode_interactions(cur_q, cur_r)

        h, _ = self.lstm(x)

        hl = self.hlin(h)
        # Project the whole embedding table: one scalar offset per question.
        ql = self.qlin(self.embedding).squeeze(-1)

        # Broadcast the per-step scalar against the per-question offsets.
        return hl + ql.reshape(1, 1, -1)