import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from lib.utils import DecoratorTimer


def doa_report(df):
    """Group responses by knowledge concept and item, then score them.

    Every row of ``df`` is expanded into one record per knowledge concept
    the row is tagged with (``knowledge[i] == 1``).  The records are grouped
    by knowledge index and, inside each knowledge, by ``item_id``; the
    resulting ``score`` and ``theta`` arrays are passed to ``doa_eval``.

    Args:
        df: DataFrame with the columns ``item_id``, ``score``, ``theta`` and
            ``knowledge``.  ``knowledge`` is an iterable of per-concept
            flags.  ``theta`` is either a per-concept sequence (list, tuple
            or non-scalar ``numpy.ndarray``) aligned with ``knowledge``, or
            a single value that is reused for every tagged concept.

    Returns:
        The value returned by ``doa_eval`` for the grouped data.
    """
    knowledges = []        # knowledge concept index of each record
    knowledge_item = []    # item id of each record
    knowledge_truth = []   # observed score of each record
    knowledge_theta = []   # predicted mastery of each record

    columns = ["item_id", "score", "theta", "knowledge"]
    for item, score, theta, knowledge in df[columns].values:
        # Tuples and arrays are per-concept vectors just like lists; treating
        # them as scalars would store the whole vector as one prediction.
        per_concept = isinstance(theta, (list, tuple)) or (
            isinstance(theta, np.ndarray) and theta.ndim > 0
        )
        if per_concept:
            for i, (theta_i, knowledge_i) in enumerate(zip(theta, knowledge)):
                if knowledge_i == 1:
                    knowledges.append(i)
                    knowledge_item.append(item)
                    knowledge_truth.append(score)
                    knowledge_theta.append(theta_i)
        else:  # pragma: no cover
            # Scalar theta: the same value applies to every tagged concept.
            for i, knowledge_i in enumerate(knowledge):
                if knowledge_i == 1:
                    knowledges.append(i)
                    knowledge_item.append(item)
                    knowledge_truth.append(score)
                    knowledge_theta.append(theta)

    knowledge_df = pd.DataFrame({
        "knowledge": knowledges,
        "item_id": knowledge_item,
        "score": knowledge_truth,
        "theta": knowledge_theta,
    })

    knowledge_ground_truth = []
    knowledge_prediction = []
    for _, group_df in knowledge_df.groupby("knowledge"):
        item_truth = []
        item_prediction = []
        for _, item_group_df in group_df.groupby("item_id"):
            item_truth.append(item_group_df["score"].values)
            item_prediction.append(item_group_df["theta"].values)
        knowledge_ground_truth.append(item_truth)
        knowledge_prediction.append(item_prediction)

    return doa_eval(knowledge_ground_truth, knowledge_prediction)

def doa_eval(y_true, y_pred):
    """Compute the DOA (presumably "degree of agreement") ranking score.

    For every knowledge concept, each item contributes all
    (positive, negative) response pairs, where a response is positive when
    its label equals ``1`` and negative otherwise.  A pair is concordant
    when the positive response has the strictly larger prediction; pairs
    with equal predictions are discarded.  A concept's DOA is
    ``concordant / valid_pairs``; concepts without any valid pair are left
    out of the result.

    Args:
        y_true: Sequence with one entry per knowledge concept; each entry is
            a sequence holding one array-like of labels per item.
        y_pred: Same nesting as ``y_true``, holding the predictions aligned
            with the labels.

    Returns:
        dict with the keys

        * ``doa`` -- mean of the per-concept DOA values, NaN when no concept
          has a valid pair,
        * ``doa_know_support`` -- number of concepts with a valid pair,
        * ``doa_z_support`` -- total number of valid pairs,
        * ``doa_list`` -- the per-concept DOA values.

    >>> import numpy as np
    >>> y_true = [
    ...     [np.array([1, 0, 1])],
    ...     [np.array([0, 1, 1])]
    ... ]
    >>> y_pred = [
    ...     [np.array([.5, .4, .6])],
    ...     [np.array([.2, .3, .5])]
    ... ]
    >>> float(doa_eval(y_true, y_pred)['doa'])
    1.0
    >>> y_pred = [
    ...     [np.array([.4, .5, .6])],
    ...     [np.array([.3, .2, .5])]
    ... ]
    >>> float(doa_eval(y_true, y_pred)['doa'])
    0.5
    """
    doa_per_knowledge = []
    doa_support = 0
    z_support = 0
    for knowledge_label, knowledge_pred in zip(y_true, y_pred):
        concordant = 0
        valid_pairs = 0
        for label, pred in zip(knowledge_label, knowledge_pred):  # one item
            is_positive = np.asarray(label) == 1
            # An item needs both a positive and a negative response to form
            # any (1, 0) pair; this also skips empty items.
            if is_positive.all() or not is_positive.any():
                continue
            pred = np.asarray(pred)
            pos_pred = pred[np.flatnonzero(is_positive)]
            neg_pred = pred[np.flatnonzero(~is_positive)]
            ties = 0
            for single_pos in pos_pred:
                concordant += np.count_nonzero(neg_pred < single_pos)
                ties += np.count_nonzero(neg_pred == single_pos)
            # Tied pairs carry no ranking information and are not counted.
            valid_pairs += len(pos_pred) * len(neg_pred) - ties
        if valid_pairs > 0:
            doa_per_knowledge.append(concordant / valid_pairs)
            z_support += valid_pairs  # number of valid pairs
            doa_support += 1  # number of concepts with a defined DOA

    # np.mean([]) would emit a RuntimeWarning; return NaN explicitly instead.
    if doa_per_knowledge:
        mean_doa = np.mean(doa_per_knowledge)
    else:
        mean_doa = np.float64(np.nan)
    return {
        "doa": mean_doa,
        "doa_know_support": doa_support,
        "doa_z_support": z_support,
        "doa_list": doa_per_knowledge,
    }

# def doa(df_interaction, stu_embedding, Q_matrix):
#     cpt_arr = []
#     # stu_arr = []
#     exer_arr = []
#     label_gt = []
#     label_pd = []

#     for stu_id, exer_id, score in df_interaction[["stu_id", "exer_id", "label"]].values:
#         emb_arr = stu_embedding[stu_id]
#         q_arr = Q_matrix[exer_id]
#         for cpt_id in np.argwhere(q_arr):
#             cpt_arr.append(int(cpt_id))
#             # stu_arr.append(stu_id)
#             exer_arr.append(exer_id)
#             label_gt.append(score)
#             label_pd.append(emb_arr[cpt_id])

#     df = pd.DataFrame({
#         "cpt": cpt_arr,
#         # "user_id": stu_arr,
#         "exer_id": exer_arr,
#         "label_gt": label_gt,
#         "label_pd": label_pd
#     })
    
#     doa = np.full((Q_matrix.shape[1],), np.nan, dtype=np.float32)
#     doa_support = 0
#     z_support = 0
#     for cpt_id, group_df in df.groupby("cpt"):
#         _doa = 0
#         _z = 0
#         for _, item_group_df in group_df.groupby("exer_id"):
#             label = item_group_df["label_gt"].values
#             pred = item_group_df["label_pd"].values
#             if label.sum() == label.shape[0] or label.sum() == 0: continue
#             pos_idx = np.argwhere(label == 1)
#             neg_idx = np.argwhere(label == 0)
#             pos_pred = pred[pos_idx]
#             neg_pred = pred[neg_idx]
#             invalid = 0
#             for _pos_pred in pos_pred:
#                 _doa += (neg_pred < _pos_pred).sum()
#                 invalid += (neg_pred == _pos_pred).sum()
#             _z += pos_pred.shape[0] * neg_pred.shape[0] - invalid
#         if _z > 0:
#             doa[cpt_id] = _doa / _z
#             z_support += _z 
#             doa_support += 1
        
#     return {
#         "doa": np.nanmean(doa),
#         "doa_know_support": doa_support,
#         "doa_z_support": z_support,
#         "doa_list": doa,
#     }


    



def doa(df):
    """Expand responses per knowledge concept and score them with ``doa_eval2``.

    Each row of ``df`` yields one record for every knowledge concept whose
    flag in ``knowledge`` equals ``1``.  The records are grouped by knowledge
    index and then by ``item_id``, and the grouped ``score`` / ``theta``
    arrays are forwarded to ``doa_eval2``.

    Args:
        df: DataFrame with the columns ``item_id``, ``score``, ``theta`` and
            ``knowledge``.  ``knowledge`` is an iterable of per-concept
            flags.  ``theta`` is either a per-concept sequence (list, tuple
            or non-scalar ``numpy.ndarray``) aligned with ``knowledge``, or
            a single value shared by all tagged concepts.

    Returns:
        The value returned by ``doa_eval2`` for the grouped data.
    """
    concept_ids = []
    item_ids = []
    scores = []
    thetas = []

    for item, score, theta, knowledge in df[["item_id", "score", "theta", "knowledge"]].values:
        # Tuples and arrays hold one mastery value per concept, same as a
        # list; only true scalars are broadcast to every tagged concept.
        is_vector = isinstance(theta, (list, tuple)) or (
            isinstance(theta, np.ndarray) and theta.ndim > 0
        )
        if is_vector:
            for i, (theta_i, knowledge_i) in enumerate(zip(theta, knowledge)):
                if knowledge_i == 1:
                    concept_ids.append(i)
                    item_ids.append(item)
                    scores.append(score)
                    thetas.append(theta_i)
        else:  # pragma: no cover
            for i, knowledge_i in enumerate(knowledge):
                if knowledge_i == 1:
                    concept_ids.append(i)
                    item_ids.append(item)
                    scores.append(score)
                    thetas.append(theta)

    knowledge_df = pd.DataFrame({
        "knowledge": concept_ids,
        "item_id": item_ids,
        "score": scores,
        "theta": thetas,
    })

    knowledge_ground_truth = []
    knowledge_prediction = []
    for _, group_df in knowledge_df.groupby("knowledge"):
        truth_by_item = []
        prediction_by_item = []
        for _, item_group_df in group_df.groupby("item_id"):
            truth_by_item.append(item_group_df["score"].values)
            prediction_by_item.append(item_group_df["theta"].values)
        knowledge_ground_truth.append(truth_by_item)
        knowledge_prediction.append(prediction_by_item)

    return doa_eval2(knowledge_ground_truth, knowledge_prediction)

def doa_eval2(y_true, y_pred):
    """Compute the DOA (presumably "degree of agreement") ranking score.

    For every knowledge concept, each item contributes all
    (positive, negative) response pairs, where a response is positive when
    its label equals ``1`` and negative otherwise.  A pair is concordant
    when the positive response has the strictly larger prediction; pairs
    with equal predictions are discarded.  A concept's DOA is
    ``concordant / valid_pairs``; concepts without any valid pair are left
    out of the result.

    Args:
        y_true: Sequence with one entry per knowledge concept; each entry is
            a sequence holding one array-like of labels per item.
        y_pred: Same nesting as ``y_true``, holding the predictions aligned
            with the labels.

    Returns:
        dict with the keys

        * ``doa`` -- mean of the per-concept DOA values, NaN when no concept
          has a valid pair,
        * ``doa_know_support`` -- number of concepts with a valid pair,
        * ``doa_z_support`` -- total number of valid pairs,
        * ``doa_list`` -- the per-concept DOA values.

    >>> import numpy as np
    >>> y_true = [
    ...     [np.array([1, 0, 1])],
    ...     [np.array([0, 1, 1])]
    ... ]
    >>> y_pred = [
    ...     [np.array([.5, .4, .6])],
    ...     [np.array([.2, .3, .5])]
    ... ]
    >>> float(doa_eval2(y_true, y_pred)['doa'])
    1.0
    >>> y_pred = [
    ...     [np.array([.4, .5, .6])],
    ...     [np.array([.3, .2, .5])]
    ... ]
    >>> float(doa_eval2(y_true, y_pred)['doa'])
    0.5
    """
    doa_values = []
    doa_support = 0
    z_support = 0
    for knowledge_label, knowledge_pred in zip(y_true, y_pred):
        agree = 0
        pair_count = 0
        for label, pred in zip(knowledge_label, knowledge_pred):  # one item
            positive_mask = np.asarray(label) == 1
            # Without at least one positive and one negative response there
            # is no (1, 0) pair to rank; this also skips empty items.
            if positive_mask.all() or not positive_mask.any():
                continue
            pred = np.asarray(pred)
            pos_pred = pred[np.flatnonzero(positive_mask)]
            neg_pred = pred[np.flatnonzero(~positive_mask)]
            tied = 0
            for one_pos in pos_pred:
                agree += np.count_nonzero(neg_pred < one_pos)
                tied += np.count_nonzero(neg_pred == one_pos)
            # Tied pairs carry no ranking information and are not counted.
            pair_count += len(pos_pred) * len(neg_pred) - tied
        if pair_count > 0:
            doa_values.append(agree / pair_count)
            z_support += pair_count  # number of valid pairs
            doa_support += 1  # number of concepts with a defined DOA

    # np.mean([]) would emit a RuntimeWarning; return NaN explicitly instead.
    if doa_values:
        mean_doa = np.mean(doa_values)
    else:
        mean_doa = np.float64(np.nan)
    return {
        "doa": mean_doa,
        "doa_know_support": doa_support,
        "doa_z_support": z_support,
        "doa_list": doa_values,
    }
