#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time: 2025/9/9 14:57
# @Author: hb925
# @File: dkt.py

import argparse
import datetime
import json
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.nn import Embedding, Linear, Dropout, LSTM
from torch.utils.data import DataLoader

from data_load import split_generator


class DKT(nn.Module):
    """Deep Knowledge Tracing model: an LSTM over embedded (id, response) interactions.

    At step ``t`` the model consumes the interaction observed at ``t`` and predicts the
    probability of a correct response at step ``t + 1`` for the id observed at ``t + 1``.

    ``mode`` selects which id sequence is modelled and therefore the vocabulary size:

    * ``"KC"``   -- skill ids, vocabulary ``skill_num``;
    * ``"Q"``    -- problem ids, vocabulary ``problem_num``;
    * ``"Ours"`` -- group ids, vocabulary ``group_num``.

    The class also carries a small Keras-like harness (``compile_model``, ``train_step``,
    ``test_step``, ``compute_metrics``). It accumulates per-batch losses, labels and
    predictions until ``reset_state`` is called.
    """

    def __init__(self, problem_num, skill_num, group_num, max_len, mode="KC", emb_size=64, hidden_units=128,
                 dropout=0.1, **kwargs):
        """Build the interaction embedding, LSTM and output layers.

        Args:
            problem_num: number of problem ids (vocabulary for ``mode="Q"``).
            skill_num: number of skill ids (vocabulary for ``mode="KC"``).
            group_num: number of group ids (vocabulary for ``mode="Ours"``).
            max_len: stored as ``self.max_len``; not otherwise used by this class.
            mode: one of ``"KC"``, ``"Q"``, ``"Ours"``.
            emb_size: interaction embedding size.
            hidden_units: LSTM hidden size.
            dropout: dropout probability applied to the LSTM states and to the output logits.
            **kwargs: ignored; printed so that misspelt options are visible.

        Raises:
            NotImplementedError: if ``mode`` is not one of the supported modes.
        """
        if kwargs:
            print(f"unused params for model:{kwargs}")
        super().__init__()
        vocab_sizes = {"KC": skill_num, "Q": problem_num, "Ours": group_num}
        if mode not in vocab_sizes:
            raise NotImplementedError(f"unsupported mode: {mode!r}")
        self.name = "dkt" + "-" + mode
        # Per-batch accumulators consumed by compute_metrics() and cleared by reset_state().
        self._losses = []
        self._labels = []
        self._outputs = []
        self.optimizer = None  # set by compile_model()
        self.mode = mode
        self.num_c = vocab_sizes[mode]
        self.emb_size = emb_size
        self.max_len = max_len
        self.hidden_units = hidden_units
        self.dropout = dropout
        # Each id has two interaction tokens: id (incorrect) and id + num_c (correct).
        self.interaction_emb = Embedding(self.num_c * 2, self.emb_size)
        self.lstm_layer = LSTM(self.emb_size, self.hidden_units, batch_first=True)
        # Kept as a module for compatibility; forward() reads its ``p`` but applies dropout
        # functionally so the ``training`` argument can override the module's train/eval mode.
        self.dropout_layer = Dropout(self.dropout)
        self.out_layer = Linear(self.hidden_units, self.num_c)

    def _select_ids(self, skill, problem, group):
        """Return the id sequence modelled under ``self.mode``."""
        if self.mode == "KC":
            return skill
        if self.mode == "Q":
            return problem
        if self.mode == "Ours":
            return group
        raise NotImplementedError(f"unsupported mode: {self.mode!r}")

    def forward(self, x, mask=None, training=None, **kwargs):
        """Predict next-step correctness probabilities.

        Args:
            x: ``(skill, problem, group, r)`` as produced by ``data_map``. These are assumed to be
                ``(batch, seq_len)`` long tensors with padded positions already zeroed -- TODO confirm.
            mask: accepted for interface compatibility; unused.
            training: ``True``/``False`` forces dropout on/off. ``None`` makes dropout follow
                ``self.training`` (i.e. ``model.train()`` / ``model.eval()``).
            **kwargs: ignored.

        Returns:
            Tensor of shape ``(batch, seq_len - 1, 1)``. Entry ``t`` is the predicted probability
            of a correct response at step ``t + 1`` for the id observed at step ``t + 1``.
        """
        skill, problem, group, r = x
        ids = self._select_ids(skill, problem, group)
        # Fold (id, response) into one token: id for wrong answers, id + num_c for correct ones.
        interactions = ids + self.num_c * r
        # Steps 0..T-2 are the inputs; predictions target steps 1..T-1.
        h, _ = self.lstm_layer(self.interaction_emb(interactions[:, :-1]))
        apply_dropout = self.training if training is None else bool(training)
        p = self.dropout_layer.p
        h = F.dropout(h, p=p, training=apply_dropout)
        logits = self.out_layer(h)
        # NOTE(review): dropout on the output logits (before the sigmoid) is unusual; kept for
        # parity with the original implementation -- confirm it is intended.
        logits = F.dropout(logits, p=p, training=apply_dropout)
        probs = torch.sigmoid(logits)
        # Select the probability of the id actually observed at the next step. This is identical to
        # (one_hot(next_ids) * probs).sum(-1, keepdim=True) without materialising the one-hot tensor.
        return torch.gather(probs, -1, ids[:, 1:].unsqueeze(-1))

    def compile_model(self, optimizer=None, lr=0.001, weight_decay=0):
        """Create the optimizer used by ``train_step``.

        Args:
            optimizer: ``"adam"`` or ``"sgd"`` (case-insensitive).
            lr: learning rate.
            weight_decay: weight decay passed to the optimizer.

        Raises:
            ValueError: if ``optimizer`` is missing or not a supported name.
        """
        factories = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
        key = optimizer.lower() if isinstance(optimizer, str) else None
        if key not in factories:
            raise ValueError(f"unknown optimizer name: {optimizer!r}")
        self.optimizer = factories[key](self.parameters(), lr=lr, weight_decay=weight_decay)

    def reset_state(self):
        """Clear the accumulated per-batch losses, labels and predictions."""
        self._losses.clear()
        self._labels.clear()
        self._outputs.clear()

    def compute_loss(self, x, y_pred, y, sample_weight):
        """Binary cross-entropy between predicted probabilities and labels.

        ``x`` and ``sample_weight`` are accepted for interface compatibility but are not used.
        """
        loss = F.binary_cross_entropy(y_pred, y)
        return loss

    def compute_metrics(self):
        """Aggregate metrics over all batches accumulated since the last ``reset_state``.

        Returns:
            dict with ``loss`` (mean of per-batch losses), ``auc``, ``acc`` (threshold 0.5)
            and ``rmse``. ``auc`` is NaN when the labels contain a single class, where
            ROC AUC is undefined.

        Raises:
            ValueError: if no batch has been accumulated.
        """
        if not self._labels:
            raise ValueError("compute_metrics() called with no accumulated batches; "
                             "run train_step()/test_step() first")
        loss = np.mean(self._losses)
        ts = np.concatenate(self._labels, axis=0)
        ps = np.concatenate(self._outputs, axis=0)
        prelabels = (ps >= 0.5).astype(np.int64)
        if np.unique(ts).size < 2:
            auc = float("nan")
        else:
            auc = metrics.roc_auc_score(ts, ps)
        acc = metrics.accuracy_score(ts, prelabels)
        # sqrt of MSE instead of ``squared=False``, which newer scikit-learn releases removed.
        rmse = float(np.sqrt(metrics.mean_squared_error(ts, ps)))
        return {"loss": loss, "auc": auc, "acc": acc, "rmse": rmse}

    def train_step(self, data):
        """Run one optimisation step on a raw batch and record its loss, labels and predictions.

        Args:
            data: raw batch accepted by ``data_map``.

        Returns:
            ``{"loss": loss}`` where ``loss`` is a detached scalar tensor.

        Raises:
            RuntimeError: if ``compile_model`` has not been called.
        """
        if self.optimizer is None:
            raise RuntimeError("call compile_model() before train_step()")
        x, y, mask, sample_weight = self.data_map(data)
        self.optimizer.zero_grad()
        y_pred = self(x, training=True, mask=mask)
        # Targets start at step 1, so align the validity mask with them.
        target_mask = mask[:, 1:, :]
        y_pred, y = y_pred.masked_select(target_mask), y.masked_select(target_mask)
        loss = self.compute_loss(x, y_pred, y, sample_weight)
        loss.backward()
        self.optimizer.step()
        loss = loss.detach()
        self._losses.append(loss.cpu().item())
        self._labels.append(y.detach().cpu().numpy())
        self._outputs.append(y_pred.detach().cpu().numpy())
        return {"loss": loss}

    def test_step(self, data):
        """Evaluate a raw batch without dropout or gradient tracking, and record the results.

        Args:
            data: raw batch accepted by ``data_map``.

        Returns:
            ``{"loss": loss}`` with ``loss`` as a Python float.
        """
        x, y, mask, sample_weight = self.data_map(data)
        with torch.no_grad():
            y_pred = self(x, training=False, mask=mask)
            target_mask = mask[:, 1:, :]
            y_pred, y = y_pred.masked_select(target_mask), y.masked_select(target_mask)
            loss = self.compute_loss(x, y_pred, y, sample_weight)

        y_pred, y = y_pred.cpu().numpy(), y.cpu().numpy()
        loss = loss.cpu().item()
        self._labels.append(y)
        self._outputs.append(y_pred)
        self._losses.append(loss)
        return {"loss": loss}

    @property
    def inputs_specs(self):
        """Names of the input feature columns and of the label column expected in a batch."""
        return ("problem", "skill", "group"), "correct"

    def data_map(self, data):
        """Convert a raw batch into ``(x, target, mask, sample_weight)``.

        Args:
            data: ``((problem, skill, group), y)``. Entries of ``y`` below 0 mark padding.

        Returns:
            - ``x = (skill, problem, group, r)`` as long tensors with padded positions zeroed.
              Note that this order differs from ``inputs_specs``.
            - ``target``: ``y`` as float with a trailing singleton dim and the first step dropped.
            - ``mask``: bool validity mask (``y >= 0``) with a trailing singleton dim.
            - ``sample_weight``: always ``None``.
        """
        (problem, skill, group), y = data
        mask = torch.ge(y, 0)  # already a bool tensor
        # Zero padded ids, presumably so they remain valid embedding / gather indices.
        problem = (problem * mask).long()
        skill = (skill * mask).long()
        group = (group * mask).long()
        y = y.unsqueeze(-1).float()
        r = (y.squeeze(-1) * mask).long()
        return (skill, problem, group, r), y[:, 1:], mask.unsqueeze(-1), None
