#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time: 2025/9/9 14:57
# @Author: hb925
# @File: dkvmn.py

import argparse
import datetime
import json
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.nn import Embedding, Linear, Dropout, Parameter
from torch.nn.init import kaiming_normal_
from torch.utils.data import DataLoader

from data_load import split_generator


class DKVMN(nn.Module):
    """Dynamic Key-Value Memory Network (DKVMN) for knowledge tracing.

    The model keeps a static key memory ``Mk`` and a value memory that is
    rewritten after every interaction. Which id sequence plays the role of
    the "concept" is chosen by ``mode``:

    * ``"KC"``   -- skill ids (vocabulary size ``skill_num``)
    * ``"Q"``    -- problem ids (vocabulary size ``problem_num``)
    * ``"Ours"`` -- group ids (vocabulary size ``group_num``)

    Besides the network itself, the class carries a small training harness:
    ``compile_model`` creates the optimizer, ``train_step`` / ``test_step``
    process one batch and accumulate losses, labels and predictions, and
    ``compute_metrics`` summarises everything accumulated since the last
    ``reset_state``.
    """

    def __init__(self, problem_num, skill_num, group_num, max_len, mode="KC", memory_dim=256, memory_size=64,
                 dropout=0.2, **kwargs):
        """Build the embeddings, memories and read/write layers.

        Args:
            problem_num: Number of distinct problem ids (used when mode is "Q").
            skill_num: Number of distinct skill ids (used when mode is "KC").
            group_num: Number of distinct group ids (used when mode is "Ours").
            max_len: Accepted for interface compatibility; not used here.
            mode: One of "KC", "Q" or "Ours".
            memory_dim: Width of every memory slot and of all embeddings.
            memory_size: Number of memory slots.
            dropout: Dropout probability applied before the output layer.
            **kwargs: Ignored; reported on stdout so typos are visible.

        Raises:
            NotImplementedError: If ``mode`` is not one of the supported values.
        """
        if len(kwargs) > 0:
            print(f"unused params for model:{kwargs}")
        super().__init__()
        self.name = "dkvmn" + "-" + mode
        # Per-epoch accumulators, filled by train_step / test_step.
        self._losses = []
        self._labels = []
        self._outputs = []
        self.optimizer = None
        self.mode = mode
        if self.mode == "KC":
            self.num_c = skill_num
        elif self.mode == "Q":
            self.num_c = problem_num
        elif self.mode == "Ours":
            self.num_c = group_num
        else:
            raise NotImplementedError
        self.dim_s = memory_dim
        self.size_m = memory_size

        # NOTE: layers are created in the same order as before so that a
        # fixed random seed yields the same initial weights.
        self.k_emb_layer = Embedding(self.num_c, self.dim_s)
        self.Mk = Parameter(torch.empty(self.size_m, self.dim_s))
        self.Mv0 = Parameter(torch.empty(self.size_m, self.dim_s))

        kaiming_normal_(self.Mk)
        kaiming_normal_(self.Mv0)

        # Interaction embedding: index = concept_id + num_c * response,
        # hence twice the concept vocabulary.
        self.v_emb_layer = Embedding(self.num_c * 2, self.dim_s)

        self.f_layer = Linear(self.dim_s * 2, self.dim_s)
        self.dropout_layer = Dropout(dropout)
        self.p_layer = Linear(self.dim_s, 1)

        self.e_layer = Linear(self.dim_s, self.dim_s)
        self.a_layer = Linear(self.dim_s, self.dim_s)

    def forward(self, x, mask=None, training=None, **kwargs):
        """Predict the probability of a correct response at every step.

        Args:
            x: Tuple ``(skill, problem, group, response)`` of integer id
                tensors shaped ``(batch, seq_len)``, as produced by
                ``data_map``.
            mask: Unused; accepted for interface compatibility.
            training: Whether dropout is active. ``None`` falls back to the
                module's own ``self.training`` flag.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(batch, seq_len - 1, 1)`` with the sigmoid
            outputs for steps ``1 .. seq_len - 1``.

        Raises:
            NotImplementedError: If ``self.mode`` is not supported.
        """
        skill, problem, group, response = x
        if self.mode == "KC":
            concept = skill
        elif self.mode == "Q":
            concept = problem
        elif self.mode == "Ours":
            concept = group
        else:
            raise NotImplementedError
        if training is None:
            training = self.training

        interaction = concept + self.num_c * response
        k = self.k_emb_layer(concept)  # concept (key) embedding
        v = self.v_emb_layer(interaction)  # concept+response (value) embedding
        batch_size = concept.shape[0]

        Mvt = self.Mv0.unsqueeze(0).repeat(batch_size, 1, 1)
        Mv = [Mvt]
        # Attention of every step over the memory slots.
        w = torch.softmax(torch.matmul(k, self.Mk.T), dim=-1)

        # Write process: erase-then-add update of the value memory.
        e = torch.sigmoid(self.e_layer(v))
        a = torch.tanh(self.a_layer(v))
        for et, at, wt in zip(e.unbind(1), a.unbind(1), w.unbind(1)):
            slot_weight = wt.unsqueeze(-1)
            Mvt = Mvt * (1 - slot_weight * et.unsqueeze(1)) + slot_weight * at.unsqueeze(1)
            Mv.append(Mvt)
        Mv = torch.stack(Mv, dim=1)

        # Read process: step t reads the memory as it was *before* the
        # write of step t (hence Mv[:, :-1]).
        read = (w.unsqueeze(-1) * Mv[:, :-1]).sum(-2)
        f = torch.tanh(self.f_layer(torch.cat([read, k], dim=-1)))
        # Honour the explicit ``training`` flag so that evaluation is
        # deterministic even if the caller never switched to eval() mode.
        f = F.dropout(f, p=self.dropout_layer.p, training=bool(training))
        pred = torch.sigmoid(self.p_layer(f))
        return pred[:, 1:, :]

    def compile_model(self, optimizer=None, lr=0.001, weight_decay=0):
        """Create the optimizer used by ``train_step``.

        Args:
            optimizer: Optimizer name, ``"adam"`` or ``"sgd"`` (case-insensitive).
            lr: Learning rate.
            weight_decay: Weight-decay coefficient passed to the optimizer.

        Raises:
            ValueError: If ``optimizer`` is not a string or names an
                unsupported optimizer.
        """
        if not isinstance(optimizer, str):
            raise ValueError(f"optimizer must be 'adam' or 'sgd', got {optimizer!r}")
        name = optimizer.lower()
        if name == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        elif name == "sgd":
            self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError(f"unknown optimizer name: {optimizer!r}")

    def reset_state(self):
        """Drop all losses, labels and predictions accumulated so far."""
        self._losses.clear()
        self._labels.clear()
        self._outputs.clear()

    def compute_loss(self, x, y_pred, y, sample_weight):
        """Return the binary cross-entropy between ``y_pred`` and ``y``.

        ``x`` and ``sample_weight`` are accepted for interface compatibility
        and are not used.
        """
        loss = F.binary_cross_entropy(y_pred, y)
        return loss

    def compute_metrics(self):
        """Summarise everything accumulated since the last ``reset_state``.

        Returns:
            Dict with the mean batch ``loss`` and ``auc``, ``acc`` (threshold
            0.5) and ``rmse`` computed over all accumulated predictions.
        """
        loss = np.mean(self._losses)
        ts = np.concatenate(self._labels, axis=0)
        ps = np.concatenate(self._outputs, axis=0)
        prelabels = (ps >= 0.5).astype(int)
        auc = metrics.roc_auc_score(ts, ps)
        acc = metrics.accuracy_score(ts, prelabels)
        # The ``squared=False`` keyword was removed from recent scikit-learn
        # releases; taking the root explicitly works on every version.
        rmse = np.sqrt(metrics.mean_squared_error(ts, ps))
        return {"loss": loss, "auc": auc, "acc": acc, "rmse": rmse}

    def train_step(self, data):
        """Run one optimisation step on a batch and record its results.

        Args:
            data: ``((problem, skill, group), y)`` batch, see ``data_map``.

        Returns:
            ``{"loss": loss}`` where ``loss`` is a detached scalar tensor.

        Raises:
            RuntimeError: If ``compile_model`` has not been called yet.
        """
        if self.optimizer is None:
            raise RuntimeError("compile_model() must be called before train_step()")
        x, y, mask, sample_weight = self.data_map(data)
        # Compute prediction error
        y_pred = self(x, training=True, mask=mask)
        mask = mask[:, 1:, :]  # align with the predictions, which skip step 0
        y_pred, y = y_pred.masked_select(mask), y.masked_select(mask)
        loss = self.compute_loss(x, y_pred, y, sample_weight)
        # Backpropagation
        self._losses.append(loss.detach().cpu().item())
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self._labels.append(y.detach().cpu().numpy())
        self._outputs.append(y_pred.detach().cpu().numpy())

        # Detached so the caller does not keep the autograd graph alive.
        return {"loss": loss.detach()}

    def test_step(self, data):
        """Evaluate one batch without gradient tracking and record its results.

        Args:
            data: ``((problem, skill, group), y)`` batch, see ``data_map``.

        Returns:
            ``{"loss": loss}`` where ``loss`` is a Python float.
        """
        with torch.no_grad():
            x, y, mask, sample_weight = self.data_map(data)
            # Compute prediction error
            y_pred = self(x, training=False, mask=mask)
            mask = mask[:, 1:, :]  # align with the predictions, which skip step 0
            y_pred, y = y_pred.masked_select(mask), y.masked_select(mask)
            loss = self.compute_loss(x, y_pred, y, sample_weight)

        loss = loss.detach().cpu().item()
        self._labels.append(y.detach().cpu().numpy())
        self._outputs.append(y_pred.detach().cpu().numpy())
        self._losses.append(loss)
        return {"loss": loss}

    @property
    def inputs_specs(self):
        """Names of the input fields and of the label field of a batch."""
        return ("problem", "skill", "group"), "correct"

    def data_map(self, data):
        """Convert a raw batch into model inputs, targets and mask.

        Positions where ``y < 0`` are masked out; their ids and responses
        are replaced by 0 so they remain valid embedding indices.

        Args:
            data: ``((problem, skill, group), y)`` with tensors shaped
                ``(batch, seq_len)``.

        Returns:
            Tuple ``(x, y, mask, sample_weight)`` where ``x`` is
            ``(skill, problem, group, response)`` as long tensors, ``y`` is
            the float target for steps ``1 .. seq_len - 1`` shaped
            ``(batch, seq_len - 1, 1)``, ``mask`` is a bool tensor shaped
            ``(batch, seq_len, 1)`` and ``sample_weight`` is always ``None``.
        """
        (problem, skill, group), y = data
        mask = torch.ge(y, 0)
        problem = (problem * mask).type(torch.long)
        skill = (skill * mask).type(torch.long)
        group = (group * mask).type(torch.long)
        y = y.type(torch.float)
        response = (y * mask).type(torch.long)
        return (skill, problem, group, response), y.unsqueeze(-1)[:, 1:], mask.unsqueeze(-1), None
