#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time: 2025/9/9 19:48
# @Author: hb925
# @File: data_trans.py
import argparse
import collections
import json
import os
import random
import re
import shutil

import numpy as np
import pandas as pd


def _gen_fold_index(data_lens, folds):
    fold_size = data_lens // folds
    remains = data_lens - fold_size * folds
    folds_set = []
    for i in range(folds):
        folds_set.extend([i for _ in range(fold_size)])
        print(f'fold:{i},len:{fold_size}')
    if remains > 0:
        folds_set.extend([int(k % 5) for k in range(remains)])
    random.shuffle(folds_set)
    random.shuffle(folds_set)
    random.shuffle(folds_set)
    return folds_set


def filter_maxlength(all_data, max_len, min_len=3, drop_remains=False):
    """Split sequences longer than ``max_len`` into consecutive chunks.

    Each row of ``all_data`` is one sequence record. The length of the
    ``correct`` column decides whether the row is too long. For such rows every
    list/tuple/ndarray column is sliced into windows of ``max_len`` items,
    while scalar columns (e.g. ``user``) are copied into every chunk.

    Args:
        all_data: DataFrame containing at least the columns ``user`` and
            ``correct``. It is not modified.
        max_len: Maximum number of interactions per output row; must be > 0.
        min_len: A trailing partial chunk is kept only if it holds at least
            this many (and at least one) interactions.
        drop_remains: If True, keep only the first ``max_len`` interactions of
            each over-long sequence and discard the rest.

    Returns:
        A new DataFrame with a fresh ``RangeIndex``: the chunked rows first,
        followed by the rows that were already short enough. A ``seq_len``
        column, if present, is dropped from the result.

    Raises:
        ValueError: If ``max_len`` is not positive.
        AssertionError: If a required column is missing.
    """
    assert {"user", "correct"} <= set(all_data.columns), \
        "all_data must contain the columns 'user' and 'correct'"
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    # Keep the lengths in a separate Series so the caller's frame is untouched.
    seq_lens = all_data["correct"].apply(len)
    too_long = seq_lens > max_len
    columns = list(all_data.columns)

    records = []
    for (_, row), s_len in zip(all_data[too_long].iterrows(), seq_lens[too_long]):
        if drop_remains:
            n_chunks = 1
        else:
            n_chunks, tail = divmod(int(s_len), max_len)
            # Never emit an empty trailing chunk, even when min_len <= 0.
            if tail > 0 and tail >= min_len:
                n_chunks += 1
        for start in range(0, n_chunks * max_len, max_len):
            record = {}
            for col, value in row.items():
                if isinstance(value, (list, tuple, np.ndarray)):
                    value = value[start:start + max_len]
                record[col] = value
            records.append(record)

    # Only prepend the chunk frame when chunks exist. Calling ``apply`` on an
    # empty selection used to crash here, and concatenating an empty object
    # frame could also disturb the dtypes of the untouched rows.
    frames = []
    if records:
        frames.append(pd.DataFrame(records, columns=columns))
    frames.append(all_data[~too_long])
    seq = pd.concat(frames).reset_index(drop=True)
    return seq.drop(columns=["seq_len"], errors="ignore")


def filter_datas(data_dir, max_len, data_file_dir=None, drop_remains=False):
    """Return the train/valid/test pickle paths, truncated to ``max_len`` if asked.

    If any of ``train.pkl``, ``valid.pkl`` or ``test.pkl`` is missing from
    ``data_file_dir``, all three are regenerated from ``all.pkl`` through
    ``expand_split_data``. With ``max_len > 0`` each split is chunked by
    ``filter_maxlength`` and cached next to the source as ``<split>-<max_len>.pkl``.
    An existing cache file is reused as-is, whatever ``drop_remains`` value
    produced it.

    Args:
        data_dir: Dataset directory, used when ``data_file_dir`` is None.
        max_len: Maximum sequence length; a value <= 0 disables truncation.
        data_file_dir: Directory holding the split pickles (defaults to
            ``data_dir``).
        drop_remains: Forwarded to ``filter_maxlength``.

    Returns:
        A list with the three paths (train, valid, test) to load.
    """
    if data_file_dir is None:
        data_file_dir = data_dir
    data_paths = [os.path.join(data_file_dir, name)
                  for name in ("train.pkl", "valid.pkl", "test.pkl")]
    if not all(os.path.exists(path) for path in data_paths):
        expand_split_data(data_file_dir, out_dir=data_file_dir)
    if max_len <= 0:
        return data_paths

    new_paths = []
    for src in data_paths:
        # Only touch the file extension; str.replace would also rewrite any
        # ".pkl" that happens to occur in a directory name.
        root, ext = os.path.splitext(src)
        dst = f'{root}-{max_len}{ext}'
        if not os.path.exists(dst):
            data = pd.read_pickle(src)
            # drop_remains must go by keyword: passed positionally it lands in
            # the min_len slot of filter_maxlength.
            new_data = filter_maxlength(data, max_len, drop_remains=drop_remains)
            pd.to_pickle(new_data, dst)
        new_paths.append(dst)
    return new_paths


# Serialize the dataset splits and their metadata to disk.
def save_dataset(train_data, valid_data, test_data, data_info, out_dir, all_datas=None):
    """Pickle the given frames into ``out_dir`` and write ``info.json``.

    ``out_dir`` is created if needed. Each frame that is not None is written
    under a fixed file name (``train.pkl``, ``valid.pkl``, ``test.pkl``,
    ``all.pkl``). ``data_info`` is always dumped to ``info.json`` with an
    indent of 4.
    """
    os.makedirs(out_dir, exist_ok=True)
    targets = (
        ("train.pkl", train_data),
        ("valid.pkl", valid_data),
        ("test.pkl", test_data),
        ("all.pkl", all_datas),
    )
    for file_name, frame in targets:
        if frame is None:
            continue
        frame.to_pickle(os.path.join(out_dir, file_name))
    info_path = os.path.join(out_dir, "info.json")
    with open(info_path, "w", encoding="utf-8") as info_file:
        json.dump(data_info, info_file, indent=4)


def split_dataset(all_data, data_info, dataset_out_dir, test_frac=0.2, folds=5, keep_split=False,expand_split=False):
    """Assign every row of ``all_data`` to a train/valid/test split and a fold, then save.

    With ``keep_split`` the assignment is restored from ``split_fold.pkl`` in
    ``dataset_out_dir`` by merging on ``user``. Otherwise the rows are shuffled,
    ``test_frac`` of them become the test split, 10% of the remainder become the
    validation split, and (if ``folds`` > 0) fold ids come from
    ``_gen_fold_index``. Split labels are 0 = train, 1 = valid, -1 = test.

    Args:
        all_data: DataFrame with one row per user sequence.
        data_info: Dataset metadata dict. It is updated in place with
            ``test_frac`` and ``folds`` and written to ``info.json``.
        dataset_out_dir: Directory that receives the output files (and holds
            ``split_fold.pkl`` when ``keep_split`` is set).
        test_frac: Fraction of rows sampled into the test split.
        folds: Number of cross-validation folds; falsy or <= 0 skips the
            ``fold`` column.
        keep_split: Reuse a previously stored split instead of sampling.
        expand_split: Also write ``train.pkl`` / ``valid.pkl`` / ``test.pkl``
            next to ``all.pkl``.

    Returns:
        The updated ``data_info`` dict.

    Raises:
        ValueError: If ``keep_split`` is set but ``split_fold.pkl`` is missing,
            or it does not cover every user of ``all_data``.
    """
    dataset_info_split = {'test_frac': test_frac, 'folds': folds}
    if keep_split:
        old_split_file = os.path.join(dataset_out_dir, "split_fold.pkl")
        if not os.path.exists(old_split_file):
            raise ValueError("no split info file")
        old_split = pd.read_pickle(old_split_file)
        assert all(k in old_split.columns for k in ["user", "split", "fold"])
        assert len(all_data) == len(old_split)
        all_data = pd.merge(all_data, old_split, how="left", on="user", validate="one_to_one")
        # A left merge leaves NaN for users unknown to the old split; such rows
        # would silently vanish from every split below, so fail loudly instead.
        if all_data["split"].isna().any():
            raise ValueError("split info file does not cover every user of the new data")
    else:
        print("############### split train test ################")
        all_data = all_data.sample(frac=1.0)
        test_index = all_data.sample(frac=test_frac).index
        val_train_data = all_data[~all_data.index.isin(test_index)]
        valid_index = val_train_data.sample(frac=0.1).index
        train_index = val_train_data[~val_train_data.index.isin(valid_index)].index
        print("############### split train test over ################")
        if folds and folds > 0:
            print("############### split fold ################")
            data_lens = all_data.shape[0]
            folds_set = _gen_fold_index(data_lens, folds)
            all_data.loc[:, "fold"] = folds_set
            all_data = all_data.astype({"fold": np.int8})
        all_data.loc[train_index, "split"] = 0
        all_data.loc[valid_index, "split"] = 1
        all_data.loc[test_index, "split"] = -1
        all_data = all_data.astype({"split": np.int8})
        print("############### split fold over################")
        print("############### saving  dataname ################")
    # Record the split parameters in both branches. Previously the keep_split
    # branch built an updated copy of the metadata and then never used it.
    data_info.update(dataset_info_split)
    if expand_split:
        # Derive the subsets from the labelled frame so they carry the same
        # "split"/"fold" columns regardless of which branch produced the labels.
        train_data = all_data[all_data["split"] == 0]
        valid_data = all_data[all_data["split"] == 1]
        test_data = all_data[all_data["split"] == -1]
        save_dataset(train_data, valid_data, test_data, data_info, dataset_out_dir, all_data)
    else:
        save_dataset(None, None, None, data_info, dataset_out_dir, all_data)
    return data_info


def expand_kf_data(data_dir, target_dir=None, max_len=-1, min_len=3, drop_remains=False):
    """Materialise one train/valid/test directory per fold found in ``all.pkl``.

    Reads ``all.pkl`` and ``info.json`` from ``data_dir``. For every distinct
    value ``k`` of the ``fold`` column, the rows of that fold are written as
    both the validation and the test set, and all other rows as the training
    set, into ``<target_dir>/k<k>`` (``target_dir`` defaults to ``data_dir``).
    When ``max_len`` is positive each subset is first passed through
    ``filter_maxlength``.
    """
    print("############### expand kf data ################")
    out_root = data_dir if target_dir is None else target_dir
    all_file_path = os.path.join(data_dir, "all.pkl")
    info_file_path = os.path.join(data_dir, "info.json")
    assert os.path.exists(all_file_path)
    assert os.path.exists(info_file_path)
    all_datas = pd.read_pickle(all_file_path)
    assert "fold" in all_datas.columns
    with open(info_file_path, "r", encoding="utf-8") as info_file:
        dataset_info = json.load(info_file)

    fold_ids = all_datas["fold"].astype(np.int32).value_counts().keys()
    for fold_id in fold_ids:
        in_fold = all_datas["fold"] == fold_id
        held_out = all_datas[in_fold].reset_index(drop=True)
        subsets = {
            "train": all_datas[~in_fold].reset_index(drop=True),
            "valid": held_out,
            # The held-out fold doubles as the test set.
            "test": held_out.copy(),
        }
        print(f'train:{len(subsets["train"])},valid:{len(subsets["valid"])},test:{len(subsets["test"])}')
        sub_out_dir = os.path.join(out_root, f'k{fold_id}')
        os.makedirs(sub_out_dir, exist_ok=True)
        fold_info = dataset_info.copy()
        if max_len and max_len > 0:
            subsets = {
                name: filter_maxlength(frame, max_len, min_len, drop_remains)
                for name, frame in subsets.items()
            }
        save_dataset(subsets["train"], subsets["valid"], subsets["test"], fold_info, sub_out_dir)
    print("############### expand kf data over################")


def expand_split_data(data_dir, max_len=-1, min_len=3, out_dir=None, drop_remains=False):
    """Write ``train.pkl`` / ``valid.pkl`` / ``test.pkl`` from the ``split`` column of ``all.pkl``.

    Reads ``all.pkl`` and ``info.json`` from ``data_dir`` and selects rows by
    split label (0 = train, 1 = valid, -1 = test). When ``max_len`` is positive
    each subset is chunked with ``filter_maxlength``. The three pickles and a
    copy of ``info.json`` are then written to the output directory.

    Args:
        data_dir: Directory containing ``all.pkl`` and ``info.json``.
        max_len: Maximum sequence length; falsy or <= 0 disables chunking.
        min_len: Forwarded to ``filter_maxlength``.
        out_dir: Output directory. Falls back to ``data_dir`` when None or when
            the directory does not exist.
        drop_remains: Forwarded to ``filter_maxlength`` (same meaning as in
            ``expand_kf_data``).
    """
    all_file_path = os.path.join(data_dir, "all.pkl")
    info_file_path = os.path.join(data_dir, "info.json")
    if out_dir is None or not os.path.exists(out_dir):
        out_dir = data_dir
    assert os.path.exists(all_file_path)
    assert os.path.exists(info_file_path)
    all_datas = pd.read_pickle(all_file_path)
    assert "split" in all_datas.columns
    # Close the read handle before info.json is written again below; the
    # output path is the very same file whenever out_dir == data_dir.
    with open(info_file_path, "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    subsets = {
        "train.pkl": all_datas[all_datas["split"] == 0].reset_index(drop=True),
        "valid.pkl": all_datas[all_datas["split"] == 1].reset_index(drop=True),
        "test.pkl": all_datas[all_datas["split"] == -1].reset_index(drop=True),
    }
    if max_len and max_len > 0:
        subsets = {
            file_name: filter_maxlength(frame, max_len, min_len, drop_remains=drop_remains)
            for file_name, frame in subsets.items()
        }
    for file_name, frame in subsets.items():
        frame.to_pickle(os.path.join(out_dir, file_name))
    with open(os.path.join(out_dir, "info.json"), "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=4)


# Convert the raw interaction records into per-learner sequences and attach skill/group information.
def to_seq_data(fn, dataset_dir, dataset_name, min_len=10, min_skill_inters=-1, min_problem_inters=-1, max_user_num=-1,
                remove_same_question=False, **kwargs):
    """Build one sequence row per user from a header-less interaction CSV.

    ``fn`` is read as three columns named ``user``, ``problem`` and ``correct``;
    rows with a missing value in any of them are dropped. Optional filters are
    then applied, problem/skill/group ids are attached, and the rows are grouped
    by ``user`` into array-valued columns. The result is pickled to
    ``<dataset_dir>/all.pkl`` and the metadata is written to
    ``<dataset_dir>/info.json``.

    Args:
        fn: Path of the interaction CSV. ``question_skill.csv`` and
            ``question_group.csv`` are looked up in the same directory.
        dataset_dir: Output directory for ``info.json`` and ``all.pkl``.
        dataset_name: Stored as ``name`` in the metadata.
        min_len: Minimum number of interactions a user needs to be kept.
        min_skill_inters: Only recorded in the metadata; not used for filtering here.
        min_problem_inters: If > 0, keep only problems with at least this many interactions.
        max_user_num: If between 0 and the number of distinct users, raises the
            per-user minimum length to the interaction count found at this
            position of ``value_counts()``.
        remove_same_question: Not used in this function.
        **kwargs: Not used in this function.

    Returns:
        A tuple ``(seq, dataset_info, dataset_dir)``.

    Raises:
        KeyError: If a mapping file exists but lacks an entry for a problem id.
    """
    use_cols = {"user": np.int64, "problem": np.int64, "correct": np.int8}
    data_dir_parent = os.path.dirname(fn)
    question_skill_file = os.path.join(data_dir_parent, "question_skill.csv")
    question_group_file = os.path.join(data_dir_parent, "question_group.csv")

    df = pd.read_csv(fn, encoding="utf-8", encoding_errors="ignore", header=None, names=["user", "problem", "correct"],
                     low_memory=False)
    df = df.dropna(subset=use_cols.keys())
    if min_problem_inters > 0:
        # Keep rows whose problem occurs at least min_problem_inters times.
        df = df[df.groupby("problem")["problem"].transform('count').ge(min_problem_inters)]
    user_count = df["user"].value_counts()
    min_len_max_user_num = -1
    if 0 < max_user_num < user_count.size:
        # NOTE(review): value_counts() sorts by count descending by default, so
        # this is the count of the (max_user_num + 1)-th most active user and the
        # ">=" filter below keeps at least max_user_num + 1 users. Confirm
        # whether index max_user_num - 1 was intended.
        min_len_max_user_num = user_count.iloc[max_user_num]
    min_len_1 = max(min_len, min_len_max_user_num)
    if min_len_1 > 0:
        # Keep users with at least min_len_1 interactions.
        df = df[df.groupby("user")["user"].transform('count').ge(min_len_1)]
    df = df.astype(use_cols, copy=False)
    # Look up the knowledge-point (skill) id of each problem.
    if os.path.exists(question_skill_file):
        question_skill_data = pd.read_csv(question_skill_file, encoding="utf-8", encoding_errors="ignore", header=None,
                                          names=["problem", "skill"], low_memory=False)
        question_skill_map = question_skill_data.set_index("problem").to_dict(orient='dict')["skill"]
        df["skill"] = df["problem"].apply(lambda x: question_skill_map[x])
    else:
        # No mapping file: every interaction gets skill id 0.
        df["skill"] = 0
    # Look up the group id of each problem.
    if os.path.exists(question_group_file):
        question_group_data = pd.read_csv(question_group_file, encoding="utf-8", encoding_errors="ignore", header=None,
                                          names=["problem", "group"], low_memory=False)
        question_group_map = question_group_data.set_index("problem").to_dict(orient='dict')["group"]
        df["group"] = df["problem"].apply(lambda x: question_group_map[x])
    else:
        # No mapping file: every interaction gets group id 0.
        df["group"] = 0
    # Sizes are taken as "largest id + 1" — presumably ids are 0-based; verify against the data.
    user_num = df["user"].max() + 1
    problem_num = df["problem"].max() + 1
    skill_num = df["skill"].max() + 1
    group_num = df["group"].max() + 1
    seq_cols = ["problem", "skill","group", "correct"]

    def proc_group(r):
        # Collapse one user's rows into a single row of arrays plus the sequence length.
        # r.sort_values(by=[ORDER_KEY], inplace=True, key=lambda x: pd.to_numeric(x, errors='coerce'))
        # Order the interactions by their row index in the filtered frame.
        r = r.sort_index()
        values = [r[k].values for k in seq_cols]
        values.append(len(r))
        return pd.Series(values, index=seq_cols + ["seq_len"])

    seq = df.groupby('user').apply(
        proc_group
    )
    # Turn the group key (user) back into a regular column.
    seq.reset_index(inplace=True)
    # seq = seq[seq["seq_len"] > min_len]
    max_len = seq["seq_len"].max()
    # seq_len was only needed to compute max_len; it is not stored.
    seq.drop(columns=["seq_len"], inplace=True)
    dataset_info = collections.OrderedDict(name=dataset_name,
                                           user_num=int(user_num),
                                           problem_num=int(problem_num),
                                           skill_num=int(skill_num),
                                           group_num=int(group_num),
                                           org_max_len=int(max_len),
                                           min_len=int(min_len),
                                           min_len_max_user_num=int(min_len_max_user_num),
                                           max_user_num=int(max_user_num),
                                           min_skill_inters=min_skill_inters,
                                           min_problem_inters=min_problem_inters)
    print("数据集基本信息:")
    print(dataset_info)
    with open(os.path.join(dataset_dir, "info.json"), "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=4)
    all_data_file = os.path.join(dataset_dir, "all.pkl")
    seq.to_pickle(all_data_file)
    print(f"data columns: {seq.columns}")
    return seq, dataset_info, dataset_dir

def clean_data_dir(data_dir,keep_split=False):
    """Remove previously generated dataset artefacts from ``data_dir``.

    Deletes every ``*.pkl`` file and ``info.json`` directly inside
    ``data_dir`` and recursively removes every sub-directory. Other files
    (e.g. the raw CSV inputs) are left alone.

    Args:
        data_dir: Directory to clean.
        keep_split: If True, ``all.pkl`` and ``split_fold.pkl`` are preserved.
            The latter carries the split/fold assignment that
            ``split_dataset(keep_split=True)`` reads back, so deleting it made
            the keep-split workflow fail.
    """
    preserved = {"all.pkl", "split_fold.pkl"} if keep_split else set()
    for file in os.listdir(data_dir):
        if file in preserved:
            continue
        file_path = os.path.join(data_dir, file)
        if os.path.isfile(file_path):
            if file.endswith(".pkl") or file == "info.json":
                os.remove(file_path)
        else:
            shutil.rmtree(file_path)
# Command-line interface of the dataset generation script.
parser = argparse.ArgumentParser(description='数据集生成')
parser.add_argument('data_name', nargs='?', type=str, help='数据集名称', default="assist2009")
parser.add_argument('--data_base', default='../dataset', help='data base  path')
parser.add_argument('--max_len', type=int, default=100, help='The max length of sequence')
parser.add_argument('--min_len', type=int, default=10, help='The min length of sequence')
parser.add_argument('--keep_split', action='store_true', default=False,
                    help='maintain the original division of the dataset')


def main(argv=None):
    """Regenerate ``all.pkl`` / ``info.json`` for one dataset directory.

    Args:
        argv: Argument list for the parser; None means ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If ``record.csv`` (or, with ``--keep_split``, the
            previous ``all.pkl``) is missing.
        ValueError: If ``--keep_split`` is given but the previous ``all.pkl``
            has no ``user`` / ``split`` / ``fold`` columns.
    """
    args = parser.parse_args(argv)
    data_name = args.data_name
    data_dir = os.path.join(os.path.abspath(args.data_base), data_name)
    data_file = os.path.join(data_dir, "record.csv")
    if not os.path.exists(data_file):
        raise FileNotFoundError(data_file)
    # To keep the previous division, the fold and split fields of the old data
    # are merged into the new data by learner id later on (in split_dataset).
    old_split = None
    if args.keep_split:
        old_data_file = os.path.join(data_dir, "all.pkl")
        if not os.path.exists(old_data_file):
            raise FileNotFoundError(old_data_file)
        old_data = pd.read_pickle(old_data_file)
        if not all(k in old_data.columns for k in ["user", "split", "fold"]):
            raise ValueError("all.pkl do not contain division info")
        old_split = old_data[["user", "split", "fold"]]
    clean_data_dir(data_dir,args.keep_split)
    if old_split is not None:
        # Written only after the cleanup: clean_data_dir removes *.pkl files,
        # and split_dataset(keep_split=True) needs this file to exist.
        old_split.to_pickle(os.path.join(data_dir, "split_fold.pkl"))
    all_data, data_info, dataset_out_dir = to_seq_data(data_file, data_dir, data_name,min_len=args.min_len)
    split_dataset(all_data, data_info, dataset_out_dir, test_frac=0.2, folds=5, keep_split=args.keep_split)
    # Optional follow-up steps, enable as needed:
    # expand_split_data(data_dir, max_len=args.max_len, min_len=3, out_dir=None)
    # expand_kf_data(data_dir, max_len=args.max_len)
    # filter_datas(data_dir, max_len=args.max_len)


if __name__ == "__main__":
    main()
