#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/4/21 08:34
# @Author  : hb
# @File    : data_load.py
import os

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from data_trans import filter_datas


def split_generator(data_dir,skill_num, problem_num,group_num, data_file_dir=None, feature_names=None, label_names=None, sample_num=-1, max_len=-1,drop_remains=False):
    """Build one ``KtDataSet`` per data split.

    Args:
        data_dir: dataset directory; also the fallback for ``data_file_dir``.
        skill_num, problem_num, group_num: forwarded to every ``KtDataSet``.
        data_file_dir: directory the split pickles are read from when
            ``max_len <= 0``; also passed to ``filter_datas`` otherwise.
        feature_names, label_names: forwarded to ``KtDataSet``.
        sample_num: forwarded as ``KtDataSet(samples=...)``.
        max_len: when > 0, the split paths come from
            ``filter_datas(data_dir, max_len, data_file_dir, drop_remains)`` and
            every dataset is padded to ``max_len``; otherwise ``train.pkl``,
            ``valid.pkl`` and ``test.pkl`` under ``data_file_dir`` are used.
        drop_remains: forwarded to ``filter_datas``.

    Returns:
        list of ``KtDataSet``, one per split path, in path order.
    """
    pkl_dir = data_dir if data_file_dir is None else data_file_dir
    use_max_len = max_len > 0

    if use_max_len:
        split_paths = filter_datas(data_dir, max_len, pkl_dir, drop_remains)
    else:
        split_paths = [os.path.join(pkl_dir, name) for name in ("train.pkl", "valid.pkl", "test.pkl")]

    # Build every dataset first, then pad, to keep the original two-phase order.
    loaded = [
        KtDataSet(path, skill_num, problem_num, group_num, feature_names, label_names, samples=sample_num)
        for path in split_paths
    ]
    if not use_max_len:
        return loaded
    # padding() modifies the dataset in place and returns the same object.
    return [dataset.padding(max_len) for dataset in loaded]
class KtDataSet(Dataset):
    """Map-style dataset of knowledge-tracing sequences held in a DataFrame.

    Each row of ``self.seq`` is one sample. ``__getitem__`` returns the row's
    feature cells as ``x`` and label cells as ``y`` (a tuple when more than one
    column is configured, otherwise the bare cell value).

    Requested feature/label columns that the loaded frame lacks are derived in
    ``_init_data``:

    * keys of ``skill_map_data`` / ``problem_map_data`` are looked up for every
      element of the ``"skill"`` / ``"problem"`` cell;
    * ``"skill_response"`` = ``skill + correct * skill_num``;
    * ``"problem_response"`` = ``problem + correct * problem_num``.
    """

    def __init__(self, pklfile_or_dataframe, skill_num, problem_num,group_num, feature_names=None, label_names=None,
                 samples=-1, problem_map_data=None, skill_map_data=None, **kwargs):
        """Load (or adopt) the frame, optionally subsample it, and derive missing columns.

        Args:
            pklfile_or_dataframe: path ending in ``.pkl`` (read with
                ``pd.read_pickle``) or an already-loaded DataFrame. A DataFrame
                is used without copying unless ``samples > 0``, so derived
                columns and padding are applied to the caller's object.
            skill_num: offset used to build ``"skill_response"``.
            problem_num: offset used to build ``"problem_response"``.
            group_num: stored on the instance; not used inside this class.
            feature_names: column name(s) returned as ``x``; a single string is
                accepted. Default: ``["skill", "skill_response",
                "problem_response", "problem", "group"]``.
            label_names: column name(s) returned as ``y``; a single string is
                accepted. Default: ``["correct"]``.
            samples: ``>= 1`` draws that many rows, ``0 < samples < 1`` draws
                that fraction, anything else keeps every row.
            problem_map_data, skill_map_data: optional ``{column: lookup}``
                mappings (or ``None``); each ``lookup`` must be indexable by
                the problem / skill ids.
            **kwargs: accepted and ignored.
        """
        if feature_names is None:
            feature_names = ["skill", "skill_response", "problem_response", "problem","group"]
        if label_names is None:
            label_names = ["correct"]
        if isinstance(pklfile_or_dataframe, (str,)) and pklfile_or_dataframe.endswith(".pkl"):
            # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
            df = pd.read_pickle(pklfile_or_dataframe)
        else:
            df = pklfile_or_dataframe
        if samples > 0:
            if samples >= 1:
                df = df.sample(int(samples))
            else:
                df = df.sample(frac=samples)
        if isinstance(label_names, str):
            label_names = [label_names]
        else:
            label_names = list(label_names)
        if isinstance(feature_names, str):
            feature_names = [feature_names]
        else:
            feature_names = list(feature_names)

        self.feature_names = feature_names
        self.label_names = label_names
        self.seq = df
        self.skill_num = skill_num
        self.problem_num = problem_num
        self.group_num=group_num
        # Incremented on every __getitem__ call; not read anywhere in this class.
        self.position = -1
        self.size = len(self.seq)
        self.skill_map_data = skill_map_data
        self.problem_map_data = problem_map_data
        self._init_data()

    def padding(self, max_len, pad_in_end=True):
        """Pad every ndarray cell of ``self.seq`` to ``max_len`` with -1, in place.

        Non-ndarray cells are left untouched. ``np.pad`` rejects negative pad
        widths, so every sequence must already be at most ``max_len`` long.

        Args:
            max_len: target length of each padded sequence.
            pad_in_end: pad after the data when True, before it when False.

        Returns:
            self, so the call can be chained.
        """
        def pad(x):
            if not isinstance(x, np.ndarray):
                return x
            gap = max_len - len(x)
            return np.pad(x, (0, gap) if pad_in_end else (gap, 0), constant_values=-1)

        for column in list(self.seq.columns):
            self.seq[column] = self.seq[column].apply(pad)
        return self

    def _init_data(self):
        """Derive requested feature/label columns that ``self.seq`` does not contain.

        Mutates ``self.seq`` in place. Requested names matched by none of the
        rules below stay absent, and ``__getitem__`` will fail on them.
        """
        # Both mappings default to None in __init__; treat that as "no extra
        # columns" instead of crashing on ``None.keys()``.
        skill_maps = self.skill_map_data if self.skill_map_data is not None else {}
        problem_maps = self.problem_map_data if self.problem_map_data is not None else {}
        for k in set(self.feature_names + self.label_names) - set(self.seq.columns):
            if k in skill_maps.keys():
                values = skill_maps[k]
                self.seq[k] = self.seq["skill"].apply(lambda x: np.array([values[i] for i in x]))
            if k in problem_maps.keys():
                values = problem_maps[k]
                self.seq[k] = self.seq["problem"].apply(lambda x: np.array([values[i] for i in x]))
            if k == "skill_response":
                self.seq[k] = self._encode_response("skill", self.skill_num)
            if k == "problem_response":
                self.seq[k] = self._encode_response("problem", self.problem_num)

    def _encode_response(self, id_column, num):
        """Return ``ids + correct * num`` for every row, as a list in row order.

        With 0/1 correctness this gives each (id, correctness) pair its own index.
        """
        # Pair the two columns directly. The previous ``row[0]`` / ``row[1]`` on a
        # label-indexed Series relied on pandas' deprecated positional fallback.
        return [ids + correct * num for ids, correct in zip(self.seq[id_column], self.seq["correct"])]

    def __len__(self):
        """Number of rows after optional subsampling."""
        return self.size

    def __getitem__(self, index):
        """Return ``(x, y)`` for the row at position ``index``.

        ``x`` / ``y`` is a tuple of the feature / label cells when more than one
        column is configured, otherwise the single cell value.
        """
        row = self.seq.iloc[index]
        self.position += 1
        x_datas = [row[arg] for arg in self.feature_names]
        y_datas = [row[arg] for arg in self.label_names]
        x_datas = tuple(x_datas) if len(x_datas) > 1 else x_datas[0]
        y_datas = tuple(y_datas) if len(y_datas) > 1 else y_datas[0]
        return x_datas, y_datas