#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/4/21 08:33
# @Author  : hb
# @File    : simplekt_np.py
import argparse
import datetime
import json
import math
import os
import random
from enum import IntEnum
from typing import Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.nn.init import xavier_uniform_, constant_
from torch.utils.data import DataLoader

from data_load import split_generator


def logs2str(logs: Dict) -> str:
    """Render a mapping of metric names to numbers as one compact string.

    Each entry becomes ``name:value`` with the value formatted to four
    decimal places; entries are joined with commas in the mapping's
    iteration order.

    Args:
        logs: Mapping whose values accept the ``.4f`` format spec.

    Returns:
        The comma-joined string (empty when ``logs`` is empty).
    """
    return ",".join(f'{name}:{value:.4f}' for name, value in logs.items())


def to_device(value, device):
    """Recursively move every tensor inside a nested container to ``device``.

    Args:
        value: A ``torch.Tensor``, or a tuple / list / dict (possibly nested)
            whose leaves are tensors.
        device: Target handed to ``Tensor.to``.

    Returns:
        The moved tensor; a new ``list`` for tuple or list input (tuples are
        not preserved); or a new ``dict`` with the same keys.

    Raises:
        ValueError: If ``value`` or any nested leaf is of another type.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, (tuple, list)):
        return [to_device(item, device) for item in value]
    if isinstance(value, dict):
        return {key: to_device(item, device) for key, item in value.items()}
    raise ValueError("nonsupport type!")

class Dim(IntEnum):
    """Named axis indices for tensors laid out as (batch, seq, feature)."""

    batch = 0    # index of the batch axis
    seq = 1      # index of the sequence (time-step) axis
    feature = 2  # index of the feature / embedding axis
class LearnablePositionalEmbedding(nn.Module):
    """Trainable positional embedding table with one vector per position."""

    def __init__(self, d_model, max_len=512):
        """Create the table with small random initial values.

        Args:
            d_model: Size of each position vector.
            max_len: Number of positions stored in the table.
        """
        super().__init__()
        # Standard-normal draw scaled by 0.1; the leading axis of size 1 lets
        # the table broadcast over the batch axis when it is added to inputs.
        table = (0.1 * torch.randn(max_len, d_model)).unsqueeze(0)
        self.weight = nn.Parameter(table, requires_grad=True)

    def forward(self, x):
        """Return the embeddings for the first ``x.size(Dim.seq)`` positions.

        Only the sequence length of ``x`` is read; its values are not used.

        Returns:
            A slice of the table with shape (1, seq, d_model). Slicing does
            not fail for sequences longer than ``max_len``; the result is
            then simply ``max_len`` positions long.
        """
        seq_len = x.size(Dim.seq)
        return self.weight[:, :seq_len, :]


class CosinePositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional embedding table."""

    def __init__(self, d_model, max_len=512):
        """Pre-compute the sine/cosine table for ``max_len`` positions.

        Even feature columns hold ``sin(pos * freq)`` and odd columns hold
        ``cos(pos * freq)``, with the frequencies computed in log space.

        Args:
            d_model: Size of each position vector (odd sizes are supported).
            max_len: Number of positions stored in the table.
        """
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so the frequency axis is trimmed to match.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Stored as a frozen parameter with a leading broadcast axis.
        self.weight = nn.Parameter(pe.unsqueeze(0), requires_grad=False)

    def forward(self, x):
        """Return the embeddings for the first ``x.size(Dim.seq)`` positions.

        Only the sequence length of ``x`` is read; its values are not used.

        Returns:
            A slice of the table with shape (1, seq, d_model).
        """
        return self.weight[:, :x.size(Dim.seq), :]  # (1, seq, feature)


class Architecture(nn.Module):
    """Stack of transformer blocks applied to question / interaction embeddings."""

    def __init__(self, n_question, n_blocks, d_model, d_feature,
                 d_ff, n_heads, dropout, kq_same,  seq_len):
        """Build the attention stack and the positional embedding.

        Args:
            n_question: Accepted but not used by this class.
            n_blocks: Number of stacked attention blocks.
            d_model: Dimension of the attention input/output.
            d_feature: Accepted but not used; the per-head size passed to
                each block is ``d_model // n_heads``.
            d_ff: Hidden size of each block's feed-forward network.
            n_heads: Number of attention heads.
            dropout: Dropout probability forwarded to each block.
            kq_same: Flag forwarded to each block (query/key share a
                projection when it equals 1).
            seq_len: Number of positions in the positional embedding table.
        """
        super().__init__()
        self.d_model = d_model
        per_head = d_model // n_heads
        layers = []
        for _ in range(n_blocks):
            layers.append(
                TransformerLayer(d_model=d_model, d_feature=per_head,
                                 d_ff=d_ff, dropout=dropout, n_heads=n_heads,
                                 kq_same=kq_same))
        self.blocks_2 = nn.ModuleList(layers)
        self.position_emb = CosinePositionalEmbedding(d_model=self.d_model, max_len=seq_len)

    def forward(self, q_embed_data, qa_embed_data):
        """Run the block stack over position-encoded embeddings.

        Args:
            q_embed_data: Question embeddings, (batch, seq, d_model).
            qa_embed_data: Question-answer (interaction) embeddings of the
                same shape.

        Returns:
            The output of the last block, same shape as ``q_embed_data``.
        """
        # Add positional information to both streams.
        x = q_embed_data + self.position_emb(q_embed_data)
        y = qa_embed_data + self.position_emb(qa_embed_data)

        # Encoder. Queries and keys come from the question stream, values from
        # the interaction stream (the "Knowledge Retriever" of the model
        # figure, per the original author's note). apply_pos=True adds the
        # FFN + residual + layer norm after attention.
        # mask=0 means a position cannot see its own response; the first
        # position's attended value is all zeros, so the first question
        # carries question information only and no interaction information.
        for block in self.blocks_2:
            x = block(mask=0, query=x, key=x, values=y, apply_pos=True)
        return x


class TransformerLayer(nn.Module):
    """One transformer block: masked multi-head attention, then an optional FFN.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalisation.
    """

    def __init__(self, d_model, d_feature,
                 d_ff, n_heads, dropout, kq_same):
        """Create the attention head, feed-forward network and norms.

        Args:
            d_model: Dimension of the block input/output.
            d_feature: Per-head dimension handed to the attention module.
            d_ff: Hidden size of the position-wise feed-forward network.
            n_heads: Number of attention heads.
            dropout: Dropout probability used throughout the block.
            kq_same: Queries and keys share one projection when this equals 1.
        """
        super().__init__()
        kq_same = kq_same == 1
        # Multi-head attention block.
        self.masked_attn_head = MultiHeadAttention(
            d_model, d_feature, n_heads, dropout, kq_same=kq_same)

        # Norm and dropout applied after the attention sub-layer.
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

        # Norm and dropout applied after the feed-forward sub-layer.
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, mask, query, key, values, apply_pos=True):
        """Apply the block.

        Args:
            mask: Integer look-ahead offset. 0 lets a position attend only to
                strictly earlier positions; 1 lets it attend to the current
                and earlier positions.
            query: Query input, (batch, seq, d_model).
            key: Key input.
            values: Value input.
            apply_pos: When true, also apply the feed-forward sub-layer.

        Returns:
            The transformed query tensor, same shape as ``query``.
        """
        seqlen = query.size(1)

        # Boolean visibility mask built directly on the query's device:
        # entry [i, j] is True when column j is visible from row i, i.e.
        # when j - i < mask.
        positions = torch.arange(seqlen, device=query.device)
        visible = (positions.unsqueeze(0) - positions.unsqueeze(1)) < mask
        src_mask = visible.view(1, 1, seqlen, seqlen)

        # With mask == 0 the first position sees nothing, so its attention
        # row is zeroed: the first question gets no interaction history and
        # its attended value is all zeros.
        zero_pad = mask == 0
        query2 = self.masked_attn_head(
            query, key, values, mask=src_mask, zero_pad=zero_pad)

        query = query + self.dropout1(query2)  # residual 1
        query = self.layer_norm1(query)  # layer norm
        if apply_pos:
            query2 = self.linear2(self.dropout(  # FFN
                self.activation(self.linear1(query))))
            query = query + self.dropout2(query2)  # residual 2
            query = self.layer_norm2(query)  # layer norm
        return query


class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional shared query/key projection.

    Holds the key, value and (unless shared) query projections, followed by
    scaled dot-product attention and an output projection.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout, kq_same, bias=True):
        """Create the projection layers.

        Args:
            d_model: Input/output dimension.
            d_feature: Per-head dimension; ``n_heads * d_feature`` must equal
                ``d_model`` for the reshapes in ``forward`` to work.
            n_heads: Number of heads.
            dropout: Dropout probability applied to the attention weights.
            kq_same: Truthy to reuse the key projection for queries; falsy
                (``False`` or ``0``) to create a separate query projection.
            bias: Whether the linear layers carry bias terms.
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_feature
        self.h = n_heads
        # Normalised to bool so integer flags (0 / 1) behave like False / True.
        self.kq_same = bool(kq_same)

        self.v_linear = nn.Linear(d_model, d_model, bias=bias)
        self.k_linear = nn.Linear(d_model, d_model, bias=bias)
        if not self.kq_same:
            self.q_linear = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.proj_bias = bias
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the q/k/v weights and zero all biases."""
        xavier_uniform_(self.k_linear.weight)
        xavier_uniform_(self.v_linear.weight)
        if not self.kq_same:
            xavier_uniform_(self.q_linear.weight)

        if self.proj_bias:
            constant_(self.k_linear.bias, 0.)
            constant_(self.v_linear.bias, 0.)
            if not self.kq_same:
                constant_(self.q_linear.bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, q, k, v, mask, zero_pad):
        """Compute multi-head attention.

        Args:
            q: Query input, (batch, seq, d_model).
            k: Key input, (batch, seq, d_model).
            v: Value input, (batch, seq, d_model).
            mask: Visibility mask passed to ``attention``; positions where
                it equals 0 are blocked.
            zero_pad: When true, the first query position's attention row is
                zeroed.

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        bs = q.size(0)

        # Linear projections, split into h heads of size d_k.
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        query_proj = self.k_linear if self.kq_same else self.q_linear
        q = query_proj(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # Move the head axis forward: (batch, h, seq, d_k).
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        attended = attention(q, k, v, self.d_k,
                             mask, self.dropout, zero_pad)

        # Concatenate the heads and apply the final linear layer.
        concat = attended.transpose(1, 2).contiguous() \
            .view(bs, -1, self.d_model)
        return self.out_proj(concat)


def attention(q, k, v, d_k, mask, dropout, zero_pad):
    """Scaled dot-product attention used by ``MultiHeadAttention``.

    Args:
        q: Queries, (batch, heads, q_len, d_k).
        k: Keys, (batch, heads, k_len, d_k).
        v: Values, (batch, heads, k_len, d_k).
        d_k: Per-head dimension used for the ``1 / sqrt(d_k)`` scaling.
        mask: Tensor broadcastable to the score matrix; positions where it
            equals 0 are blocked.
        dropout: Callable applied to the attention weights.
        zero_pad: When true, the attention weights of the first query
            position are replaced by zeros, so its output is all zeros.

    Returns:
        Attended values, (batch, heads, q_len, d_k).
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / \
        math.sqrt(d_k)  # (batch, heads, q_len, k_len)

    # A very large negative score becomes a zero weight after softmax. A row
    # that is blocked everywhere becomes uniform instead of NaN.
    scores.masked_fill_(mask == 0, -1e32)
    scores = F.softmax(scores, dim=-1)

    if zero_pad:
        # Zero the first row of weights. zeros_like keeps the row on the
        # same device and dtype as the scores and sizes it by key length.
        first_row = torch.zeros_like(scores[:, :, :1, :])
        scores = torch.cat([first_row, scores[:, :, 1:, :]], dim=2)

    scores = dropout(scores)
    return torch.matmul(scores, v)
class SimpleKT(torch.nn.Module):
    """SimpleKT knowledge-tracing network bundled with train/eval step helpers.

    Depending on ``mode`` the "question" stream fed to the attention stack is
    the skill id (``"KC"``), the problem id (``"Q"``) or the group id
    (``"Ours"``). In the ``"KC"`` and ``"Ours"`` modes a per-problem
    difficulty embedding additionally modulates the question embedding.
    """

    def __init__(self, problem_num, skill_num, group_num, max_len, mode="KC", emb_size=64, n_blocks=4,
                 dropout=0.1, d_ff=256, kq_same=1, final_fc_dim=512, final_fc_dim2=256,
                 num_attn_heads=8, separate_qa=False, l2=1e-5, **kwargs):
        """Build embeddings, the attention stack and the prediction head.

        Args:
            problem_num: Size of the problem-id vocabulary.
            skill_num: Size of the skill-id vocabulary.
            group_num: Size of the group-id vocabulary.
            max_len: Maximum sequence length (positional table size).
            mode: ``"KC"``, ``"Q"`` or ``"Ours"``; selects the id streams.
            emb_size: Embedding / model dimension.
            n_blocks: Number of attention blocks.
            dropout: Dropout probability.
            d_ff: Feed-forward hidden size inside each block.
            kq_same: Queries and keys share a projection when equal to 1.
            final_fc_dim: Width of the first layer of the prediction head.
            final_fc_dim2: Width of the second layer of the prediction head.
            num_attn_heads: Number of attention heads.
            separate_qa: Use one interaction embedding per (question,
                response) pair instead of question + response embeddings.
            l2: Stored on the instance; not otherwise used in this class.
            **kwargs: Ignored; reported on stdout when non-empty.

        Raises:
            NotImplementedError: If ``mode`` is not one of the known values.
        """
        if len(kwargs) > 0:
            print(f"unused params for model:{kwargs}")
        super().__init__()
        self.name = "SimpleKT" + "-" + mode
        self._losses = []
        self._labels = []
        self._outputs = []
        self.optimizer = None
        self.max_len = max_len
        self.mode = mode
        if self.mode == "KC":
            self.n_question = skill_num
            self.n_pid = problem_num
        elif self.mode == "Q":
            self.n_question = problem_num
            self.n_pid = -1  # no separate problem-difficulty stream
        elif self.mode == "Ours":
            self.n_question = group_num
            self.n_pid = problem_num
        else:
            raise NotImplementedError
        self.dropout = dropout
        self.kq_same = kq_same
        self.l2 = l2
        self.separate_qa = separate_qa
        # Head count read by get_attn_pad_mask.
        self.nhead = num_attn_heads
        d_model = embed_l = emb_size
        seq_len = max_len
        if self.n_pid > 0:
            # Problem difficulty.
            self.difficult_param = nn.Embedding(self.n_pid, embed_l)
        # Question embedding summarising the variation among the problems
        # (questions) that contain the current question (concept).
        self.q_embed_diff = nn.Embedding(self.n_question, embed_l)
        self.qa_embed_diff = nn.Embedding(2 * self.n_question, embed_l)
        self.q_embed = nn.Embedding(self.n_question, embed_l)
        if self.separate_qa:
            self.qa_embed = nn.Embedding(2 * self.n_question, embed_l)
        else:  # default
            self.qa_embed = nn.Embedding(2, embed_l)
        # Architecture object: the stack of attention blocks.
        self.model = Architecture(n_question=self.n_question, n_blocks=n_blocks, n_heads=num_attn_heads,
                                  dropout=dropout,
                                  d_model=d_model, d_feature=d_model // num_attn_heads, d_ff=d_ff,
                                  kq_same=self.kq_same, seq_len=seq_len)

        # Prediction head over [attention output ; question embedding].
        self.out = nn.Sequential(
            nn.Linear(d_model + embed_l, final_fc_dim), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim, final_fc_dim2), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim2, 1)
        )

    def reset(self):
        """Zero every parameter whose first dimension equals ``n_pid + 1``.

        NOTE(review): not called in this file, and the difficulty embedding
        is created with ``n_pid`` rows, so it does not match this test.
        Confirm the intended size before relying on this method.
        """
        for p in self.parameters():
            if p.size(0) == self.n_pid + 1 and self.n_pid > 0:
                torch.nn.init.constant_(p, 0.)

    def base_emb(self, q_data, target):
        """Embed question ids and (question, response) interactions.

        Args:
            q_data: Integer question ids, (batch, seq).
            target: Integer responses, same shape as ``q_data``.

        Returns:
            Tuple ``(q_embed_data, qa_embed_data)``, each (batch, seq, emb).
        """
        q_embed_data = self.q_embed(q_data)  # (batch, seq, d_model): c_ct
        if self.separate_qa:
            qa_data = q_data + self.n_question * target
            qa_embed_data = self.qa_embed(qa_data)
        else:
            # (batch, seq, d_model): c_ct + g_rt = e_(ct, rt)
            qa_embed_data = self.qa_embed(target) + q_embed_data
        return q_embed_data, qa_embed_data

    def get_attn_pad_mask(self, sm):
        """Build a padding mask from a (batch, len) tensor of ids.

        Args:
            sm: 2-D tensor in which the value 0 marks padding.

        Returns:
            Boolean tensor of shape (batch * nhead, len, len) that is True
            in the columns where ``sm`` equals 0.
        """
        batch_size, l = sm.size()
        pad_attn_mask = sm.data.eq(0).unsqueeze(1)
        pad_attn_mask = pad_attn_mask.expand(batch_size, l, l)
        return pad_attn_mask.repeat(self.nhead, 1, 1)

    def forward(self, x, mask=None, training=None, **kwargs):
        """Predict the probability of a correct response at each step.

        Args:
            x: Tuple ``(skill, problem, group, response)`` of integer tensors
                shaped (batch, seq), as produced by ``data_map``.
            mask: Accepted but not used.
            training: Accepted but not used.
            **kwargs: Ignored.

        Returns:
            Probabilities for positions 1..seq-1, shaped (batch, seq-1, 1).

        Raises:
            NotImplementedError: If ``self.mode`` is not a known mode.
        """
        q, p, g, r = x
        target = r
        if self.mode == "KC":
            q_data, pid_data = q, p
        elif self.mode == "Q":
            q_data, pid_data = p, None
        elif self.mode == "Ours":
            q_data, pid_data = g, p
        else:
            raise NotImplementedError
        q_embed_data, qa_embed_data = self.base_emb(q_data, target)
        if self.n_pid > 0:
            # d_ct: summarises the variation among the problems (questions)
            # that contain the current question (concept).
            q_embed_diff_data = self.q_embed_diff(q_data)
            # uq: difficulty of the current problem.
            pid_embed_data = self.difficult_param(pid_data)
            # Question encoder: uq * d_ct + c_ct.
            q_embed_data = q_embed_data + pid_embed_data * q_embed_diff_data
        d_output = self.model(q_embed_data, qa_embed_data)

        concat_q = torch.cat([d_output, q_embed_data], dim=-1)
        output = self.out(concat_q).squeeze(-1)
        preds = torch.sigmoid(output)

        # The first position has no history, so it is dropped.
        return preds[:, 1:].unsqueeze(dim=-1)

    def compile_model(self, optimizer=None, lr=0.001, weight_decay=0):
        """Create the optimizer used by ``train_step``.

        Args:
            optimizer: ``"adam"`` or ``"sgd"`` (case-insensitive).
            lr: Learning rate.
            weight_decay: Weight-decay coefficient.

        Raises:
            ValueError: If the optimizer name is not recognised.
        """
        name = str.lower(optimizer)
        if name == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        elif name == "sgd":
            self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError("unknow optimizer name")

    def reset_state(self):
        """Clear the losses, labels and outputs accumulated by the steps."""
        self._losses.clear()
        self._labels.clear()
        self._outputs.clear()

    def compute_loss(self, x, y_pred, y, sample_weight):
        """Return the binary cross-entropy between ``y_pred`` and ``y``.

        ``x`` and ``sample_weight`` are accepted but not used.
        """
        loss = F.binary_cross_entropy(y_pred, y)
        return loss

    def compute_metrics(self):
        """Aggregate the accumulated step results into summary metrics.

        Returns:
            Dict with keys ``"loss"`` (mean of recorded losses), ``"auc"``,
            ``"acc"`` (threshold 0.5) and ``"rmse"``.
        """
        loss = np.mean(self._losses)
        ts = np.concatenate(self._labels, axis=0)
        ps = np.concatenate(self._outputs, axis=0)
        prelabels = (ps >= 0.5).astype(int)
        auc = metrics.roc_auc_score(ts, ps)
        acc = metrics.accuracy_score(ts, prelabels)
        # Root taken explicitly: the ``squared=False`` keyword is not
        # available in every scikit-learn release.
        rmse = np.sqrt(metrics.mean_squared_error(ts, ps))
        return {"loss": loss, "auc": auc, "acc": acc, "rmse": rmse}

    def _masked_predictions(self, data, training):
        """Run the model on a batch and keep only the valid positions."""
        x, y, mask, sample_weight = self.data_map(data)
        y_pred = self(x, training=training, mask=mask)
        mask = mask[:, 1:, :]  # align with predictions, which skip step 0
        return x, y_pred.masked_select(mask), y.masked_select(mask), sample_weight

    def train_step(self, data):
        """Run one optimisation step on a batch and record its results.

        Requires ``compile_model`` to have been called first.

        Returns:
            Dict ``{"loss": loss}`` where ``loss`` is the loss tensor.
        """
        x, y_pred, y, sample_weight = self._masked_predictions(data, training=True)
        loss = self.compute_loss(x, y_pred, y, sample_weight)
        self._losses.append(loss.detach().cpu().item())
        # Backpropagation.
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self._labels.append(y.detach().cpu().numpy())
        self._outputs.append(y_pred.detach().cpu().numpy())

        return {"loss": loss}

    def test_step(self, data):
        """Evaluate a batch without gradient tracking and record its results.

        Returns:
            Dict ``{"loss": loss}`` where ``loss`` is a Python float.
        """
        with torch.no_grad():
            x, y_pred, y, sample_weight = self._masked_predictions(data, training=False)
            loss = self.compute_loss(x, y_pred, y, sample_weight)

        loss = loss.detach().cpu().item()
        self._labels.append(y.detach().cpu().numpy())
        self._outputs.append(y_pred.detach().cpu().numpy())
        self._losses.append(loss)
        return {"loss": loss}

    @property
    def inputs_specs(self):
        """Names of the input fields and of the label field."""
        return ("problem", "skill", "group"), "correct"

    def data_map(self, data):
        """Convert a raw batch into model inputs, labels and a validity mask.

        Positions whose label is negative are treated as padding: their ids
        and responses are zeroed and the mask is False there.

        Args:
            data: ``((problem, skill, group), y)`` with tensors of shape
                (batch, seq).

        Returns:
            Tuple ``(x, y, mask, sample_weight)``: ``x`` is ``(skill,
            problem, group, response)`` as long tensors; ``y`` holds the
            float labels for steps 1..seq-1, shaped (batch, seq-1, 1);
            ``mask`` is boolean, shaped (batch, seq, 1); ``sample_weight``
            is always ``None``.
        """
        (problem, skill, group), y = data
        mask = torch.ge(y, 0).type(torch.bool)
        problem = (problem * mask).type(torch.long)
        skill = (skill * mask).type(torch.long)
        group = (group * mask).type(torch.long)
        y = y.unsqueeze(-1).type(torch.float)
        r = (y.squeeze(-1) * mask).type(torch.long)
        return (skill, problem, group, r), y[:, 1:], mask.unsqueeze(-1).type(torch.bool), None

