from .BaseDataFmt import BaseDataFmt
import numpy as np
import pandas as pd
from itertools import chain
from dataset import VAEDataset
import torch
from lib.utils import set_same_seeds, tensor2npy
import os
from tqdm import tqdm


class EduFmtFold(BaseDataFmt):
    """K-fold data format for learner-item interactions plus an item->concept Q table.

    The train/val/test interaction CSVs are concatenated and re-split into
    ``fold_num`` stratified folds. Each fold gets a train and a validation
    ``VAEDataset``. Each train dataset's ``Q_mat`` is one of the following:
    the full Q matrix, a Q matrix with some items' annotations removed, or that
    reduced matrix refilled from the fold's training interactions. Which one is
    controlled by ``Q_delete_ratio`` / ``Q_fill_type``.
    """

    expose_default_cfg = {
        'adj_mat_type': None,
        'Q_delete_ratio': 0.0,
        'Q_delete_type': 'by_exer',  # not referenced inside this class
        "fold_num": 5,
        # "None" disables filling; any other value (e.g. sim_dist_for_by_exer)
        # triggers fill_Q_by_sim_dist in build_datasets.
        'Q_fill_type': "None",
        'seed': 2023,
        'params_topk': 5,
        'params_votek': 2,
        "cpt_aff_name": "cpt_affiliation"
    }

    def __init__(self, cfg):
        """Load the interaction/Q CSVs from ``cfg.data_folder_path`` and build fold datasets.

        Reads ``{DATASET}-train.csv``, ``{DATASET}-val.csv`` and ``{DATASET}-test.csv``
        and concatenates them into ``self.df_interaction``, which :meth:`k_fold`
        re-splits. Also reads ``{DATASET}-Q.csv``, whose ``cpt_list`` column holds
        comma-separated integer concept ids.
        """
        super().__init__(cfg)
        set_same_seeds(self.data_cfg['seed'])
        folder = cfg.data_folder_path
        df_interaction_train = pd.read_csv(f"{folder}/{cfg.DATASET}-train.csv", encoding='utf-8')
        df_interaction_val = pd.read_csv(f"{folder}/{cfg.DATASET}-val.csv", encoding='utf-8')
        df_interaction_test = pd.read_csv(f"{folder}/{cfg.DATASET}-test.csv", encoding='utf-8')
        self.df_Q = pd.read_csv(f"{folder}/{cfg.DATASET}-Q.csv", encoding='utf-8')
        self.df_Q['cpt_list'] = self.df_Q['cpt_list'].astype(str).apply(lambda x: [int(i) for i in x.split(',')])
        # The original split is discarded; k_fold() re-splits the union.
        self.df_interaction = pd.concat(
            [df_interaction_train, df_interaction_val, df_interaction_test], ignore_index=True
        )

        # NOTE: dict_cpt_affiliation is only assigned when the JSON file exists;
        # otherwise the attribute is absent.
        aff_name = self.data_cfg['cpt_aff_name']
        if os.path.exists(f"{folder}/{aff_name}.json"):
            self.dict_cpt_affiliation = self.read_json(f"{folder}/{aff_name}.json")

        self.check()
        self.build_info()
        self.build_datasets()

    def k_fold(self):
        """Split ``self.df_interaction`` into ``fold_num`` train/validation folds.

        Uses ``StratifiedKFold`` stratified on ``uid`` (shuffled with the configured
        seed), so each user's interactions are distributed across the folds. Fills
        ``self.train_df_list`` / ``self.val_df_list`` with one DataFrame per fold,
        each with a fresh 0-based index.
        """
        from sklearn.model_selection import StratifiedKFold

        df_interaction = self.df_interaction
        skf = StratifiedKFold(
            n_splits=self.data_cfg['fold_num'], shuffle=True, random_state=self.data_cfg['seed']
        )
        self.train_df_list = []
        self.val_df_list = []
        for train_index, val_index in skf.split(df_interaction, df_interaction['uid']):
            self.train_df_list.append(df_interaction.iloc[train_index].reset_index(drop=True))
            self.val_df_list.append(df_interaction.iloc[val_index].reset_index(drop=True))

    @classmethod
    def from_cfg(cls, cfg):
        """Alternate constructor; equivalent to ``cls(cfg)``."""
        return cls(cfg)

    def check(self):
        """Validate the config: ``Q_delete_ratio`` must be in ``[0, 1)``.

        Raises:
            AssertionError: if the ratio is out of range (note: asserts are
                stripped when Python runs with ``-O``).
        """
        assert 0 <= self.data_cfg['Q_delete_ratio'] < 1

    def build_info(self):
        """Store user/item/concept counts in ``self.data_cfg['dt_info']`` and log them.

        Raises:
            AssertionError: if the Q table does not cover exactly the items seen
                in the interactions (compared by unique-count).
        """
        self.data_cfg['dt_info'] = {
            'user_count': self.df_interaction['uid'].nunique(),
            'item_count': self.df_interaction['iid'].nunique(),
            'cpt_count': len(set(chain.from_iterable(self.df_Q['cpt_list'].to_list())))
        }
        assert self.df_Q['iid'].nunique() == self.data_cfg['dt_info']['item_count']

        self.logger.info(f"dt_info: {self.data_cfg['dt_info']}")

    def build_datasets(self):
        """Build ``self.train_dataset_list`` and ``self.val_dataset_list`` (one entry per fold).

        ``self.Q_mat`` is the full Q matrix, and ``self.missing_df_Q`` is the Q table
        after item-level deletion (:meth:`get_df_Q_train_v2`). Each fold's train
        dataset gets its ``Q_mat`` attribute set to:

        * the full Q matrix when ``Q_delete_ratio == 0``;
        * the reduced Q matrix when ``Q_fill_type == "None"``;
        * otherwise, the reduced Q table refilled from that fold's training
          interactions via :meth:`fill_Q_by_sim_dist`.
        """
        self.k_fold()
        item_count = self.cfg.data_cfg['dt_info']['item_count']
        cpt_count = self.cfg.data_cfg['dt_info']['cpt_count']

        self.Q_mat = self.get_Q_mat_from_df_arr(self.df_Q, item_count, cpt_count)

        # Called unconditionally (even when nothing is deleted): it calls
        # set_same_seeds(...), which presumably reseeds the global RNGs, so
        # skipping it would change downstream randomness.
        self.missing_df_Q = self.get_df_Q_train_v2()
        missing_Q_mat = self.get_Q_mat_from_df_arr(self.missing_df_Q, item_count, cpt_count)

        delete_q = self.data_cfg['Q_delete_ratio'] != 0.0
        fill_q = self.data_cfg['Q_fill_type'] != "None"

        self.train_dataset_list = []
        self.val_dataset_list = []
        for train_df, val_df in zip(self.train_df_list, self.val_df_list):
            train_dataset = VAEDataset.from_df_arr(self.cfg, interact_df=train_df)
            self.train_dataset_list.append(train_dataset)
            self.val_dataset_list.append(VAEDataset.from_df_arr(self.cfg, interact_df=val_df))

            if not delete_q:
                train_dataset.Q_mat = self.Q_mat
            elif fill_q:
                # Filling uses only this fold's training interactions.
                fill_df_Q = self.fill_Q_by_sim_dist(
                    train_df, self.missing_df_Q,
                    params_topk=self.data_cfg['params_topk'],
                    params_votek=self.data_cfg['params_votek'],
                )
                train_dataset.Q_mat = self.get_Q_mat_from_df_arr(fill_df_Q, item_count, cpt_count)
            else:
                train_dataset.Q_mat = missing_Q_mat

    def get_df_Q_train_v1(self):
        """Delete a fraction of individual (item, concept) Q entries.

        An item may lose some of its concepts. Items that lose all of them are
        absent from the returned table. Every concept keeps at least one
        (item, concept) entry. Uses the current global numpy RNG.

        Returns:
            DataFrame with columns ``iid`` and ``cpt_list`` (list of concept ids).

        Raises:
            AssertionError: if the ratio would leave fewer entries than concepts.
        """
        ratio = self.data_cfg['Q_delete_ratio']

        iid2cptlist = self.df_Q.set_index('iid')['cpt_list'].to_dict()
        iid_lis = np.array(list(chain(*[[i] * len(iid2cptlist[i]) for i in iid2cptlist])))
        cpt_lis = np.array(list(chain(*list(iid2cptlist.values()))))
        entry_arr = np.vstack([iid_lis, cpt_lis]).T  # one row per (iid, cpt) entry
        np.random.shuffle(entry_arr)
        entry_num = len(iid_lis)

        # Ensure the ratio is not so large that some concept would lose all its items.
        delete_num = int(ratio * entry_num)
        preserved_num = entry_num - delete_num
        assert preserved_num >= self.data_cfg['dt_info']['cpt_count']

        # 1. Keep the first (post-shuffle) entry of every concept.
        # reference: https://stackoverflow.com/questions/64834655/python-how-to-find-first-duplicated-items-in-an-numpy-array
        _, idx = np.unique(entry_arr[:, 1], return_index=True)
        bool_idx = np.zeros_like(entry_arr[:, 1], dtype=bool)
        bool_idx[idx] = True
        preserved_entry = entry_arr[bool_idx, :]
        other_entry = entry_arr[~bool_idx, :]

        # 2. Fill the remaining budget with randomly chosen other entries.
        left_entry = other_entry[np.random.choice(len(other_entry), preserved_num - len(idx), replace=False), :]
        final_entry = np.vstack([preserved_entry, left_entry])
        df = pd.DataFrame({'iid': final_entry[:, 0], 'cpt_list': final_entry[:, 1]})

        return df.groupby('iid').agg(lambda x: list(x)).reset_index()

    def get_df_Q_train_v2(self):
        """Delete whole items from the Q table.

        An item keeps either its complete concept list or nothing. About
        ``Q_delete_ratio * item_count`` items are dropped, but every concept keeps
        at least one annotated item. If that constraint alone already needs at
        least the target number of items, a warning is logged and only the
        constraint-required items are kept.

        Calls ``set_same_seeds(20230518)`` first, so the deletion pattern is
        driven by that fixed seed rather than the configured one (assuming the
        helper seeds numpy). Assumes item ids are ``0..item_count-1`` (see
        ``np.arange`` below) — TODO confirm.

        Returns:
            Subset of ``self.df_Q`` rows (fresh 0-based index).
        """
        set_same_seeds(20230518)
        ratio = self.data_cfg['Q_delete_ratio']
        iid2cptlist = self.df_Q.set_index('iid')['cpt_list'].to_dict()
        iid_lis = np.array(list(chain(*[[i] * len(iid2cptlist[i]) for i in iid2cptlist])))
        cpt_lis = np.array(list(chain(*list(iid2cptlist.values()))))
        entry_arr = np.vstack([iid_lis, cpt_lis]).T  # one row per (iid, cpt) entry

        np.random.shuffle(entry_arr)

        # Guarantee every concept keeps at least one item: take, per concept, the
        # item of its first (post-shuffle) entry and keep that whole item.
        # reference: https://stackoverflow.com/questions/64834655/python-how-to-find-first-duplicated-items-in-an-numpy-array
        _, idx = np.unique(entry_arr[:, 1], return_index=True)
        bool_idx = np.zeros_like(entry_arr[:, 1], dtype=bool)
        bool_idx[idx] = True
        preserved_exers = np.unique(entry_arr[bool_idx, 0])

        delete_num = int(ratio * self.data_cfg['dt_info']['item_count'])
        preserved_num = self.data_cfg['dt_info']['item_count'] - delete_num

        if len(preserved_exers) >= preserved_num:
            # The coverage constraint already keeps at least the target number; give up on the ratio.
            self.logger.warning(
                f"Cant Satisfy Delete Require: {len(preserved_exers)=},{preserved_num=}"
            )
        else:
            # Top up with randomly chosen additional items until preserved_num is reached.
            need_preserved_num = preserved_num - len(preserved_exers)

            left_iids = np.arange(self.data_cfg['dt_info']['item_count'])
            left_iids = left_iids[~np.isin(left_iids, preserved_exers)]
            np.random.shuffle(left_iids)
            choose_iids = left_iids[0:need_preserved_num]

            preserved_exers = np.hstack([preserved_exers, choose_iids])

        return self.df_Q.copy()[self.df_Q['iid'].isin(preserved_exers)].reset_index(drop=True)

    def get_Q_mat_from_df_arr(self, df_Q_arr, item_count, cpt_count):
        """Convert a Q table into a dense ``(item_count, cpt_count)`` int64 0/1 tensor.

        Rows of ``df_Q_arr`` provide ``iid`` and ``cpt_list``; ids are used directly
        as row/column indices (so they are expected to be 0-based). Items absent
        from the table get an all-zero row.
        """
        Q_mat = torch.zeros((item_count, cpt_count), dtype=torch.int64)
        for _, item in df_Q_arr.iterrows():
            for cpt_id in item['cpt_list']:
                Q_mat[item['iid'], cpt_id] = 1
        return Q_mat

    def fill_Q_by_sim_dist(self, df_interaction, df_Q_left, params_topk=5, params_votek=2):
        """Fill concept annotations for the items missing from ``df_Q_left``.

        Similarity between a missing item ``i`` and another item ``j`` is the
        number of users who answered both with the same correctness, divided by
        the number of users who answered ``i``. It is 0 when no user answered
        both. Item ``i`` itself and all other missing items get -1, so they rank
        last. For each missing item, its ``params_topk`` most similar items are
        taken. Those annotated in ``df_Q_left`` vote with their concepts, and the
        ``params_votek`` most frequent concepts (ties -> smaller concept id)
        become its annotation. The annotation is empty if no neighbour is
        annotated. The share of filled items whose concepts overlap their true
        concepts in ``self.df_Q`` is logged as the hit ratio.

        Args:
            df_interaction: interactions with ``uid``, ``iid`` and ``label`` columns;
                ``label == 1`` counts as correct, anything else as incorrect.
            df_Q_left: Q table (``iid``, ``cpt_list``) of items that kept annotations.
            params_topk: number of nearest-neighbour items consulted.
            params_votek: maximum number of concepts assigned per missing item.

        Returns:
            ``df_Q_left`` followed by the filled rows, with a fresh 0-based index.
        """
        preserved_exers = df_Q_left['iid'].to_numpy()
        user_count = self.cfg.data_cfg['dt_info']['user_count']
        item_count = self.cfg.data_cfg['dt_info']['item_count']

        missing_iids = np.setdiff1d(np.arange(item_count), preserved_exers)
        if missing_iids.size == 0:
            # Nothing to fill (the hit ratio below would also divide by zero).
            self.logger.info("[FILL_Q] No missing items, nothing to fill")
            return df_Q_left.reset_index(drop=True)

        # (user, item) response matrix: 1 = correct, -1 = incorrect, 0 = unanswered.
        interact_mat = torch.zeros((user_count, item_count), dtype=torch.int8).to(self.environ_cfg['device'])
        idx = df_interaction[df_interaction['label'] == 1][['uid', 'iid']].to_numpy()
        interact_mat[idx[:, 0], idx[:, 1]] = 1
        idx = df_interaction[df_interaction['label'] != 1][['uid', 'iid']].to_numpy()
        interact_mat[idx[:, 0], idx[:, 1]] = -1

        interact_mat = interact_mat.T  # -> (item, user)

        # 1. Similarity rows for the missing items (other rows stay 0 and are unused).
        sim_mat = torch.zeros((item_count, item_count))
        missing_idx = torch.as_tensor(missing_iids, dtype=torch.long)
        for iid in tqdm(missing_iids, desc="[FILL_Q_MAT] compute sim_mat"):
            responses = interact_mat[iid]
            answered = responses != 0
            both_answered = answered & (interact_mat != 0)  # users who answered both items
            agree = (responses == interact_mat) & both_answered
            row = agree.sum(dim=1) / answered.sum()
            row[both_answered.sum(dim=1) == 0] = 0.0
            # Finish the row on the interaction device and copy it once: indexing the
            # CPU sim_mat with a device-resident mask fails when the device is CUDA.
            sim_mat[iid] = row.cpu()
            sim_mat[iid, iid] = -1.0  # never select the item itself
            sim_mat[iid, missing_idx] = -1.0  # never select other unannotated items

        assert torch.isnan(sim_mat).sum() == 0

        # 2. Vote concepts from the top-k most similar annotated items.
        _, topk_mat_idx = torch.topk(sim_mat, dim=1, k=params_topk, largest=True, sorted=True)
        topk_mat_idx = tensor2npy(topk_mat_idx)

        index_df_Q = df_Q_left.set_index('iid')
        preserved_set = set(preserved_exers.tolist())  # O(1) membership instead of array scans
        missing_iid_fill_cpts = {}
        for iid in tqdm(missing_iids, desc="[FILL_Q_MAT] fill process"):
            neighbour_cpts = list(chain(*[
                index_df_Q.loc[nbr]['cpt_list'] for nbr in topk_mat_idx[iid] if nbr in preserved_set
            ]))
            cpts, counts = np.unique(neighbour_cpts, return_counts=True)
            # Stable sort: equal counts keep np.unique's ascending concept order.
            ranked = sorted(zip(cpts, counts), key=lambda x: x[1], reverse=True)
            missing_iid_fill_cpts[iid] = [cpt for cpt, _ in ranked[0:params_votek]]

        missing_fill_df_Q = pd.DataFrame(
            {'iid': list(missing_iid_fill_cpts.keys()), 'cpt_list': list(missing_iid_fill_cpts.values())}
        )
        final_df_Q = pd.concat([df_Q_left, missing_fill_df_Q], axis=0, ignore_index=True)

        # Diagnostic: share of filled items sharing at least one concept with the truth.
        true_Q = self.df_Q.set_index('iid')
        hits = sum(
            1 for iid, cpts in missing_iid_fill_cpts.items()
            if set(true_Q.loc[iid]['cpt_list']) & set(cpts)
        )
        hit_ratio = hits / len(missing_iid_fill_cpts)

        self.logger.info(f"[FILL_Q] Hit_ratio={hit_ratio}")

        return final_df_Q