"""
    # reference: https://github.com/bigdata-ustc/EduCDM/blob/main/EduCDM/MIRT/MIRT.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
from lib.utils import tensor2npy
from .BaseModel import BaseModel
from lib.callback import CallbackList, Callback


class MIRT(BaseModel):
    """Multidimensional Item Response Theory (MIRT) model.

    Each user has an ``emb_dim``-dimensional ability vector ``theta``; each
    item has an ``emb_dim``-dimensional discrimination vector ``a`` and a
    scalar intercept ``b``.  The predicted probability of a correct response
    is ``sigmoid(sum(a * theta) - b)``.

    Notes carried over from the original docstring (the ``fix_a`` / ``fix_c``
    flags are not referenced anywhere in this class, so this list looks like
    it was inherited from an IRT implementation -- TODO confirm):
        Variant 1: fix_a = True,  fix_c = True
        Variant 2: fix_a = False, fix_c = True
        Variant 3: fix_a = False, fix_c = False
    """
    expose_default_cfg = {
        "a_range": -1.0,  # discrimination range; a negative value disables it
        "emb_dim": 32
    }

    def __init__(self, cfg, xavier_init=True):
        super().__init__(cfg, xavier_init)

    def build_cfg(self):
        """Normalise the model config and cache dataset sizes on ``self``.

        A negative ``a_range`` is replaced by ``None``, which makes
        ``forward`` use softplus instead of a scaled sigmoid for the
        discrimination parameter.
        """
        a_range = self.model_cfg['a_range']
        if a_range is not None and a_range < 0:
            self.model_cfg['a_range'] = None

        self.n_user = self.data_cfg['dt_info']['user_count']
        self.n_item = self.data_cfg['dt_info']['item_count']
        self.emb_dim = self.model_cfg['emb_dim']

    def build_model(self):
        """Create the embedding tables holding the model parameters."""
        self.theta = nn.Embedding(self.n_user, self.emb_dim)  # student ability
        self.a = nn.Embedding(self.n_item, self.emb_dim)  # item discrimination
        self.b = nn.Embedding(self.n_item, 1)  # item intercept term

    def forward(self, user_idx, item_idx):
        """Predict the probability of a correct response.

        Args:
            user_idx: index tensor used to look up ``theta``.
            item_idx: index tensor used to look up ``a`` and ``b``.

        Returns:
            Tensor of probabilities produced by :meth:`irf`.

        Raises:
            ValueError: if ``theta``, ``a`` or ``b`` contains NaN.
        """
        theta = self.theta(user_idx)
        a = self.a(item_idx)
        b = self.b(item_idx).flatten()

        a_range = self.model_cfg['a_range']
        if a_range is not None:
            # Bound the discrimination to (0, a_range).
            a = a_range * torch.sigmoid(a)
        else:
            # Keep the discrimination positive to preserve the
            # monotonicity assumption.
            a = F.softplus(a)

        has_nan = (
            torch.isnan(theta).any()
            or torch.isnan(a).any()
            or torch.isnan(b).any()
        )
        if has_nan:  # pragma: no cover
            raise ValueError('ValueError:theta,a,b may contains nan!  The diff_range or a_range is too large.')
        return self.irf(theta, a, b)

    @staticmethod
    def irf(theta, a, b):
        """Item response function: ``sigmoid(sum(a * theta, dim=-1) - b)``.

        This is the logistic function ``1 / (1 + exp(-z))`` with
        ``z = sum(a * theta) - b``; the minus sign in front of the sum in the
        textbook form comes from that ``exp(-z)``.  ``torch.sigmoid`` is used
        instead of the explicit ``1 / (1 + exp(...))`` because ``exp`` can
        overflow to ``inf`` for large arguments, which can then produce NaN
        gradients in the backward pass.
        """
        return torch.sigmoid(torch.sum(torch.multiply(a, theta), dim=-1) - b)

    def fit(self, train_dataset, val_dataset=None, callbacks: Sequence[Callback] = ()):
        """Train the model with binary cross-entropy loss.

        Hyper-parameters are read from ``self.train_cfg`` (``lr``,
        ``epoch_num``, ``batch_size``, ``num_workers``, ``eval_batch_size``,
        ``weight_decay``, ``eps``, ``optim``).  Every batch is indexed as
        ``batch[:, 0]`` (users), ``batch[:, 1]`` (items) and ``batch[:, 2]``
        (labels).

        Args:
            train_dataset: dataset wrapped in a shuffled ``DataLoader``.
            val_dataset: optional dataset evaluated after every epoch; its
                metrics are added to the epoch logs with a ``val_`` prefix.
            callbacks: callbacks driven through a ``CallbackList``.
        """
        lr = self.train_cfg['lr']
        epoch_num = self.train_cfg['epoch_num']
        batch_size = self.train_cfg['batch_size']
        num_workers = self.train_cfg['num_workers']
        eval_batch_size = self.train_cfg['eval_batch_size']
        weight_decay = self.train_cfg['weight_decay']
        eps = self.train_cfg['eps']

        model = self.train()
        optimizer = self._get_optim(optimizer=self.train_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer = optimizer

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
        if val_dataset is not None:
            val_loader = DataLoader(val_dataset, shuffle=False, batch_size=eval_batch_size, num_workers=num_workers)

        callback_list = CallbackList(callbacks=callbacks, model=model, logger=self.logger)
        callback_list.on_train_begin()
        for epoch in range(epoch_num):
            callback_list.on_epoch_begin(epoch + 1)
            # evaluate() (called at the end of the previous epoch) leaves the
            # module in eval mode, so training mode is restored every epoch.
            model.train()
            # One slot per batch; unfilled slots stay NaN and are ignored by
            # nanmean below.
            logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32))
            for batch_id, batch in enumerate(
                    tqdm(train_loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[EPOCH={:03d}]".format(epoch + 1))
            ):
                batch = batch.to(self.device)
                users = batch[:, 0]
                items = batch[:, 1]
                labels = batch[:, 2].float()
                pd = model(users, items).flatten()
                loss = F.binary_cross_entropy(input=pd, target=labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                logs['loss'][batch_id] = loss.item()

            # Collapse the per-batch arrays into a single epoch mean.
            for name in logs:
                logs[name] = float(np.nanmean(logs[name]))

            if val_dataset is not None:
                val_metrics = self.evaluate(val_loader)
                logs.update({f"val_{metric}": val_metrics[metric] for metric in val_metrics})

            callback_list.on_epoch_end(epoch + 1, logs=logs)
            # Presumably set by a callback (e.g. early stopping) -- the
            # writer of this flag is not visible in this file.
            if self.share_obj_dict.get('stop_training', False):
                break

        callback_list.on_train_end()

    @torch.no_grad()
    def evaluate(self, loader):
        """Compute the configured metrics over every batch of ``loader``.

        Switches the module to eval mode (and does not switch it back).

        Args:
            loader: iterable of batches indexed as ``batch[:, 0]`` (users),
                ``batch[:, 1]`` (items) and ``batch[:, 2]`` (ground truth).

        Returns:
            dict mapping each name in ``self.eval_cfg['metrics']`` to the
            value returned by the corresponding metric function.
        """
        self.eval()
        pd_list = []
        gt_list = []
        for batch in tqdm(loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[PREDICT]"):
            batch = batch.to(self.device)
            u = batch[:, 0]
            i = batch[:, 1]
            r = batch[:, 2]
            pd_list.append(self.forward(u, i).flatten())
            gt_list.append(r.flatten())
        y_pd = tensor2npy(torch.hstack(pd_list))
        y_gt = tensor2npy(torch.hstack(gt_list))
        eval_result = {
            metric: self._get_metrics(metric)(y_gt, y_pd) for metric in self.eval_cfg['metrics']
        }
        return eval_result