# Data Loaders for NeuroBench Challenge

import numpy as np
import torch
import stork
from neurobench.datasets.primate_reaching import PrimateReaching
from neurobench.datasets.utils import download_url
from urllib.error import URLError

import os

import logging

from .data import get_dataloader
from .getMiceData import get_dataloader_Mice
from .getC05data import get_dataloader_C05, DatasetLoader_C05
from .getMAZEdata import get_dataloader_MAZE, DatasetLoader_MAZE
from .getRTTdata import get_dataloader_RTT, DatasetLoader_RTT
from .getB04data import get_dataloader_B04, DatasetLoader_B04
from .getPOYOdata import get_dataloader_POYO, DatasetLoader_POYO
from .data_foundationDemo1 import get_dataloader_foundation, DatasetLoader_foundation
import torch.nn.functional as F

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): 4e-3 equals the default `dt` (0.004) of DatasetLoader in this file, so despite
# its name this looks like a sampling period in seconds rather than a rate — confirm.
SAMPLING_RATE = 4e-3

def get_dataloader_foundation_crossSet(cfg, dtype=torch.float32, Dataset="all"):
    """Build a cross-dataset ``DatasetLoader`` from a configuration object.

    Args:
        cfg: Configuration object. The fields read here are the attributes of
            ``cfg.data`` listed below, plus ``cfg.predict_value`` and
            ``cfg.testFlag``.
        dtype: Forwarded unchanged to ``DatasetLoader`` as ``dtype``.
        Dataset: Forwarded unchanged to ``DatasetLoader`` as ``Dataset``.

    Returns:
        The newly constructed ``DatasetLoader``.
    """
    data_cfg = cfg.data

    # Collect the constructor arguments first, then make a single call.
    loader_kwargs = {
        "basepath": data_cfg.home_dir,
        "setpath": data_cfg.set_dir,
        "ratio_val": data_cfg.ratio_val,
        "random_val": data_cfg.random_val,
        "extend_data": data_cfg.extend_data,
        "sample_duration": data_cfg.sample_duration,
        "remove_segments_inactive": data_cfg.remove_segments_inactive,
        "p_drop": data_cfg.p_drop,
        "p_insert": data_cfg.p_insert,
        "jitter_sigma": data_cfg.jitter_sigma,
        "dtype": dtype,
        "dt": data_cfg.dt,
        "predict_value": cfg.predict_value,
        "Dataset": Dataset,
        "zscore": data_cfg.zscore,
        "testFlag": cfg.testFlag,
        "continuous_trial": data_cfg.continuous_trial,
        "mix_continuous_uncontinuous": data_cfg.mix_continuous_uncontinuous,
        "padding": data_cfg.padding,
    }
    return DatasetLoader(**loader_kwargs)

def compute_input_firing_rates(data, cfg, nb_inputs):
    """Compute the mean per-channel input firing rate over a dataset.

    For every sample, the entries of the selected input columns are summed,
    divided by ``cfg.data.sample_duration`` and by the number of channels, and
    the per-sample values are then averaged over all samples.

    Args:
        data: Indexable dataset supporting ``len``. ``data[i][0]`` must be a
            tensor with at least two axes; the second axis is sliced as the
            input-channel axis.
        cfg: Configuration object; only ``cfg.data.sample_duration`` is read.
        nb_inputs: Number of input channels. The value 16 selects the first
            ``nb_inputs`` columns; any other value uses two groups of 96
            columns (``[:96]`` and ``[96:]``).

    Returns:
        A tuple ``(mean1, mean2)``. ``mean2`` is ``None`` unless
        ``nb_inputs != 16`` and the first sample has exactly 192 columns.

    Raises:
        ValueError: If ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("Cannot compute firing rates for an empty dataset.")

    duration = cfg.data.sample_duration

    if nb_inputs == 16:
        total = 0
        for i in range(n_samples):
            total += torch.sum(data[i][0][:, :nb_inputs]) / duration / nb_inputs
        # Average only after *all* samples have been accumulated.
        return total / n_samples, None

    n_channels = 96
    total_first = 0
    total_second = 0
    for i in range(n_samples):
        inputs = data[i][0]  # fetch each sample once
        total_first += torch.sum(inputs[:, :n_channels]) / duration / n_channels
        # Slicing past the last column yields an empty tensor whose sum is 0,
        # so inputs with only 96 columns need no special handling here.
        total_second += torch.sum(inputs[:, n_channels:]) / duration / n_channels

    mean1 = total_first / n_samples
    mean2 = total_second / n_samples

    # For LOCO (192 input columns) both group means are reported.
    if data[0][0].shape[1] == 2 * n_channels:
        return mean1, mean2

    # For INDY only the first group is meaningful.
    return mean1, None

class DatasetLoader:
    """Loads the data from the PrimateReaching dataset and splits it into train, val and test sets. The train and valid sets are split into samples of a given length, while the test set is kept as a single sample. The data is returned as a tuple of stork RasDatasets. This datasets can then be used as usual with the stork StandardGenerator."""
    def __init__(
        self,
        basepath,
        setpath,
        num_steps=1,
        dt=0.004,
        ratio_val=0.25,
        biological_delay=0,
        spike_sorting=False,
        label_series=False,
        random_val=False,
        extend_data=True,
        sample_duration=2,
        bin_width=None,
        stride=None,
        remove_segments_inactive=False,
        dtype=torch.float32,
        p_drop=0.0,
        p_insert=0.0,
        jitter_sigma=0.0,
        predict_value='velocity',
        Dataset="all",
        zscore=True,
        testFlag=False,
        continuous_trial=True,
        mix_continuous_uncontinuous=False,
        padding="zeros",
    ):
        """Store the settings and build the requested per-dataset sub-loaders.

        Args:
            basepath (str): Root path of the data folder. It is rewritten for
                the known machines (see the platform handling below).
            setpath (Mapping[str, str]): Sub-directory of each dataset relative
                to ``basepath``. Keys read here: ``'indy_data_dir'``,
                ``'C05_data_dir'``, ``'B04_data_dir'``, ``'MAZE_data_dir'``,
                ``'RTT_data_dir'`` and ``'POYO_data_dir'`` (only those of the
                sub-loaders that are actually built).
            num_steps (int, optional): Argument for the Neurobench dataloader. Should be 1. Defaults to 1.
            dt (float, optional): Time step, should be 0.004 for the monkey data. Defaults to 0.004.
            ratio_val (float, optional): Ratio for validation set. Defaults to 0.25.
            biological_delay (int, optional): Delay of readout w.r.t input. Defaults to 0.
            spike_sorting (bool, optional): If True, using single unit activities, otherwise multi unit activities. Defaults to False.
            label_series (bool, optional): Some neurobench argument. Just leave as is. Defaults to False.
            random_val (bool, optional): If True, samples the validation samples randomly from the train data. Otherwise takes the last samples. Defaults to False.
            extend_data (bool, optional): If true, extends the data to overlapping samples. Defaults to True.
            sample_duration (int, optional): sample duration in seconds. Defaults to 2.
            bin_width (optional): Stored as ``self.bin_width``; falls back to ``dt`` when None. Defaults to None.
            stride (optional): Stored as ``self.stride``; falls back to ``dt`` when None. Defaults to None.
            remove_segments_inactive (bool, optional): Some neurobench argument. Just leave as is. Defaults to False.
            dtype (optional): The dtype of the datasets. Defaults to torch.float32.
            p_drop (float, optional): Forwarded to every sub-loader. Defaults to 0.0.
            p_insert (float, optional): Forwarded to every sub-loader. Defaults to 0.0.
            jitter_sigma (float, optional): Forwarded to every sub-loader. Defaults to 0.0.
            predict_value (str, optional): Forwarded to the C05, B04, MAZE and POYO sub-loaders. Defaults to 'velocity'.
            Dataset (str, optional): ``"all"`` builds every sub-loader; one of
                ``"indy_and_loco"``, ``"C05"``, ``"B04"``, ``"MAZE"``, ``"RTT"``,
                ``"POYO"`` builds only that one. Any other value builds none
                (a warning is logged). Defaults to "all".
            zscore (bool, optional): Stored and passed to the sub-loaders'
                ``get_single_session_data`` calls. Defaults to True.
            testFlag (bool, optional): Forwarded to every sub-loader. Defaults to False.
            continuous_trial (bool, optional): Forwarded to the C05 and MAZE sub-loaders. Defaults to True.
            mix_continuous_uncontinuous (bool, optional): Forwarded to the C05, MAZE and POYO sub-loaders. Defaults to False.
            padding (str, optional): Forwarded to every sub-loader except MAZE. Defaults to "zeros".
        """
        import sys

        # Adapt the configured data root to the machine the code is running on.
        if sys.platform == "win32":
            print("Windows")
            basepath = basepath.replace("/home/User/Data_dir/", "D:/PPPPProject/SNN_Environment/")
        elif sys.platform == "linux":
            print("Linux")
        # On the second server the home directories live under /home2, so the
        # configured /home/ prefix is rewritten (a single pass is sufficient).
        if os.path.expanduser("~") == "/home2/User":
            basepath = basepath.replace("/home/", "/home2/")

        self.basepath = basepath
        self.num_steps = num_steps
        self.dt = dt
        self.ratio_val = ratio_val
        self.biological_delay = biological_delay
        self.spike_sorting = spike_sorting
        self.label_series = label_series
        self.random_val = random_val
        self.extend_data = extend_data
        self.sample_duration = sample_duration
        self.remove_segments_inactive = remove_segments_inactive
        self.dtype = dtype
        self.p_drop = p_drop
        self.p_insert = p_insert
        self.jitter_sigma = jitter_sigma
        self.predict_value = predict_value
        self.zscore = zscore
        self.testFlag = testFlag
        self.continuous_trial = continuous_trial
        self.mix_continuous_uncontinuous = mix_continuous_uncontinuous
        self.padding = padding

        # Bin width and stride default to one time step.
        self.bin_width = self.dt if bin_width is None else bin_width
        self.stride = self.dt if stride is None else stride

        self.n_time_steps = int(sample_duration / dt)

        self.dataloader = {}
        self.Dataset = Dataset

        def build_sub_loader(name):
            """Construct the sub-loader that is registered under ``name``."""
            # Arguments shared by every sub-loader. ``dt`` is always passed so
            # that all sub-loaders agree with ``self.n_time_steps``.
            kwargs = dict(
                ratio_val=self.ratio_val,
                random_val=self.random_val,
                extend_data=self.extend_data,
                sample_duration=self.sample_duration,
                remove_segments_inactive=self.remove_segments_inactive,
                p_drop=self.p_drop,
                p_insert=self.p_insert,
                jitter_sigma=self.jitter_sigma,
                dtype=self.dtype,
                dt=self.dt,
                testFlag=self.testFlag,
            )
            # Each loader class only receives the extra options it was given
            # originally (e.g. MAZE takes no ``padding``, RTT no ``predict_value``).
            if name == "indy_and_loco":
                loader_cls, dir_key = DatasetLoader_foundation, 'indy_data_dir'
                kwargs.update(padding=self.padding)
            elif name == "C05":
                loader_cls, dir_key = DatasetLoader_C05, 'C05_data_dir'
                kwargs.update(
                    predict_value=self.predict_value,
                    continuous_trial=self.continuous_trial,
                    mix_continuous_uncontinuous=self.mix_continuous_uncontinuous,
                    padding=self.padding,
                )
            elif name == "B04":
                loader_cls, dir_key = DatasetLoader_B04, 'B04_data_dir'
                kwargs.update(predict_value=self.predict_value, padding=self.padding)
            elif name == "MAZE":
                loader_cls, dir_key = DatasetLoader_MAZE, 'MAZE_data_dir'
                kwargs.update(
                    predict_value=self.predict_value,
                    continuous_trial=self.continuous_trial,
                    mix_continuous_uncontinuous=self.mix_continuous_uncontinuous,
                )
            elif name == "RTT":
                loader_cls, dir_key = DatasetLoader_RTT, 'RTT_data_dir'
                kwargs.update(padding=self.padding)
            elif name == "POYO":
                loader_cls, dir_key = DatasetLoader_POYO, 'POYO_data_dir'
                kwargs.update(
                    predict_value=self.predict_value,
                    mix_continuous_uncontinuous=self.mix_continuous_uncontinuous,
                    padding=self.padding,
                )
            else:
                raise ValueError(f"Unknown dataset {name}.")
            return loader_cls(
                basepath=os.path.join(self.basepath, setpath[dir_key]),
                **kwargs,
            )

        # Construction order used for Dataset="all".
        known = ("indy_and_loco", "C05", "B04", "MAZE", "RTT", "POYO")
        if self.Dataset == "all":
            selected = known
        else:
            selected = tuple(name for name in known if name == self.Dataset)

        if not selected:
            # Keep the historical behaviour (no sub-loader, no exception) but
            # make the situation visible instead of failing later with KeyError.
            logging.getLogger(__name__).warning(
                "Unknown Dataset %r: no sub-loader was created.", self.Dataset
            )

        for name in selected:
            self.dataloader[name] = build_sub_loader(name)

    def get_multiple_set_data(self, filenames, nb_inputs, with_S1):

        ds_train, ds_valid, ds_test = [], [], []

        if with_S1:
            print("with_S1 is True, using S1 data")

        for key in filenames:
            if key == "indy" or key == "loco":
                for fileidx,filename in filenames[key].items():
                    monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                        self.dataloader["indy_and_loco"].get_single_session_data(
                            filename, nb_inputs=nb_inputs, with_S1=with_S1, zscore=self.zscore
                        )
                    )
                    ds_train.append(monkey_ds_train)
                    ds_valid.append(monkey_ds_valid)
                    ds_test.append(monkey_ds_test)
            elif key == "C05":
                for fileidx,filename in filenames[key].items():
                    monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                        self.dataloader["C05"].get_single_session_data(filename, nb_inputs=nb_inputs, zscore=self.zscore)
                    )
                    ds_train.append(monkey_ds_train)
                    ds_valid.append(monkey_ds_valid)
                    ds_test.append(monkey_ds_test)
            elif key == "MAZE":
                for fileidx,filename in filenames[key].items():
                    monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                        self.dataloader["MAZE"].get_single_session_data(
                            filename, nb_inputs=nb_inputs, with_PMd=with_S1, zscore=self.zscore
                        )
                    )
                    ds_train.append(monkey_ds_train)
                    ds_valid.append(monkey_ds_valid)
                    ds_test.append(monkey_ds_test)
            elif key == "RTT":
                for fileidx,filename in filenames[key].items():
                    monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                        self.dataloader["RTT"].get_single_session_data(
                            filename, nb_inputs=nb_inputs, zscore=self.zscore
                        )
                    )
                    ds_train.append(monkey_ds_train)
                    ds_valid.append(monkey_ds_valid)
                    ds_test.append(monkey_ds_test)
            elif key == "Sub_T":
                for fileidx,filename in filenames[key].items():
                    monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                        self.dataloader["POYO"].get_single_session_data(filename, nb_inputs=nb_inputs, zscore=self.zscore)
                    )
                    ds_train.append(monkey_ds_train)
                    ds_valid.append(monkey_ds_valid)
                    ds_test.append(monkey_ds_test)
            else:
                raise ValueError(f"Unknown key {key}.")

        dataset_train = torch.utils.data.ConcatDataset(ds_train)
        dataset_valid = torch.utils.data.ConcatDataset(ds_valid)

        return dataset_train, dataset_valid, ds_test


    def get_single_session_data(self, filename, monkeyname, nb_inputs):
        if monkeyname == "indy" or monkeyname == "loco":
            train_dat, val_dat, test_dat = (
                self.dataloader["indy_and_loco"].get_single_session_data(
                    filename,
                    nb_inputs=nb_inputs,
                    with_S1=False,
                    zscore=self.zscore,
                )
            )
        elif monkeyname == "MAZE":
            train_dat, val_dat, test_dat = (
                self.dataloader["MAZE"].get_single_session_data(
                    filename,
                    nb_inputs=nb_inputs,
                    with_PMd=False,
                    zscore=self.zscore,
                )
            )
        elif monkeyname == "RTT":
            train_dat, val_dat, test_dat = (
                self.dataloader["RTT"].get_single_session_data(
                    filename,
                    nb_inputs=nb_inputs,
                    with_S1=False,
                    zscore=self.zscore,
                )
            )
        elif monkeyname == "C05":
            train_dat, val_dat, test_dat = (
                self.dataloader["C05"].get_single_session_data(filename, nb_inputs=nb_inputs, zscore=self.zscore)
            )
        elif monkeyname == "B04":
            train_dat, val_dat, test_dat = (
                self.dataloader["B04"].get_single_session_data(filename, nb_inputs=nb_inputs, zscore=self.zscore)
            )
        elif monkeyname == "Sub_T":
            train_dat, val_dat, test_dat = (
                self.dataloader["POYO"].get_single_session_data(filename, nb_inputs=nb_inputs, zscore=self.zscore)
            )
        else:
            raise ValueError(f"Unknown monkeyname {monkeyname}.")
        # if self.mix_continuous_uncontinuous:
        #     if monkeyname == "MAZE" or monkeyname == "RTT":
        #         # 获取连续 trial 数据(已为500步，无需补0)
        #         self.dataloader["MAZE"].continuous_trial=True
        #         self.dataloader["MAZE"].sample_duration = 2
        #         self.dataloader["MAZE"].n_time_steps = self.n_time_steps = int(2 / self.dt)
        #         train_dat_conti, val_dat_conti, test_dat_conti = (
        #                     self.dataloader["MAZE"].get_single_session_data(
        #                         filename,
        #                         nb_inputs=nb_inputs,
        #                         with_PMd=False,
        #                         zscore=self.zscore,
        #                     )
        #                 )
        #
        #         # 获取1秒数据并补0到500步
        #         self.dataloader["MAZE"].continuous_trial=False
        #         self.dataloader["MAZE"].sample_duration = 1
        #         self.dataloader["MAZE"].n_time_steps = self.n_time_steps = int(1 / self.dt)
        #         train_dat_1, val_dat_1, test_dat_1 = (
        #                     self.dataloader["MAZE"].get_single_session_data(
        #                         filename,
        #                         nb_inputs=nb_inputs,
        #                         with_PMd=False,
        #                         zscore=self.zscore,
        #                     )
        #                 )
        #         # 补0处理
        #         max_steps = int(2 / self.dt)  # 2秒数据的时间步数
        #         train_dat_1 = pad_sequences_to_max(train_dat_1, max_steps)
        #         val_dat_1 = pad_sequences_to_max(val_dat_1, max_steps)
        #
        #         self.dataloader["MAZE"].sample_duration = 0.7
        #         self.dataloader["MAZE"].n_time_steps = self.n_time_steps = int(0.7 / self.dt)
        #         train_dat_07, val_dat_07, test_dat_07 = (
        #                     self.dataloader["MAZE"].get_single_session_data(
        #                         filename,
        #                         nb_inputs=nb_inputs,
        #                         with_PMd=False,
        #                         zscore=self.zscore,
        #                     )
        #                 )
        #         # 补0处理
        #         train_dat_07 = pad_sequences_to_max(train_dat_07, max_steps)
        #         val_dat_07 = pad_sequences_to_max(val_dat_07, max_steps)
        #
        #         self.dataloader["MAZE"].sample_duration = 0.4
        #         self.dataloader["MAZE"].n_time_steps = self.n_time_steps = int(0.4 / self.dt)
        #         train_dat_04, val_dat_04, test_dat_04 = (
        #                     self.dataloader["MAZE"].get_single_session_data(
        #                         filename,
        #                         nb_inputs=nb_inputs,
        #                         with_PMd=False,
        #                         zscore=self.zscore,
        #                     )
        #                 )
        #         # 补0处理
        #         train_dat_04 = pad_sequences_to_max(train_dat_04, max_steps)
        #         val_dat_04 = pad_sequences_to_max(val_dat_04, max_steps)
        #
        #         train_dat=[]
        #         val_dat=[]
        #         train_dat.append(train_dat_conti)
        #         train_dat.append(train_dat_1)
        #         train_dat.append(train_dat_07)
        #         train_dat.append(train_dat_04)
        #         val_dat.append(val_dat_conti)
        #         val_dat.append(val_dat_1)
        #         val_dat.append(val_dat_07)
        #         val_dat.append(val_dat_04)
        #         train_dat = torch.utils.data.ConcatDataset(train_dat)
        #         val_dat = torch.utils.data.ConcatDataset(val_dat)
        #         test_dat = test_dat_conti
        #
        #     elif monkeyname == "C05":
        #
        #         # 获取连续 trial 数据(已为500步，无需补0)
        #         self.dataloader["C05"].continuous_trial=True
        #         self.dataloader["C05"].sample_duration = 2
        #         self.dataloader["C05"].n_time_steps = self.n_time_steps = int(2 / self.dt)
        #         train_dat_conti, val_dat_conti, test_dat_conti = (
        #                     self.dataloader["C05"].get_single_session_data(filename,zscore=self.zscore)
        #                 )
        #
        #         # 获取1秒数据并补0到500步
        #         self.dataloader["C05"].continuous_trial=False
        #         self.dataloader["C05"].sample_duration = 1
        #         self.dataloader["C05"].n_time_steps = self.n_time_steps = int(1 / self.dt)
        #         train_dat_1, val_dat_1, test_dat_1 = (
        #                     self.dataloader["C05"].get_single_session_data(filename,zscore=self.zscore)
        #                 )
        #         # 补0处理
        #         max_steps = int(2 / self.dt)  # 2秒数据的时间步数
        #         train_dat_1 = pad_sequences_to_max(train_dat_1, max_steps)
        #         val_dat_1 = pad_sequences_to_max(val_dat_1, max_steps)
        #
        #         # 获取0.7秒数据并补0到500步
        #         self.dataloader["C05"].sample_duration = 0.7
        #         self.dataloader["C05"].n_time_steps = self.n_time_steps = int(0.7 / self.dt)
        #         train_dat_07, val_dat_07, test_dat_07 = (
        #                     self.dataloader["C05"].get_single_session_data(filename,zscore=self.zscore)
        #                 )
        #         # 补0处理
        #         train_dat_07 = pad_sequences_to_max(train_dat_07, max_steps)
        #         val_dat_07 = pad_sequences_to_max(val_dat_07, max_steps)
        #
        #         # 获取0.4秒数据并补0到500步
        #         self.dataloader["C05"].sample_duration = 0.4
        #         self.dataloader["C05"].n_time_steps = self.n_time_steps = int(0.4 / self.dt)
        #         train_dat_04, val_dat_04, test_dat_04 = (
        #                     self.dataloader["C05"].get_single_session_data(filename,zscore=self.zscore)
        #                 )
        #         # 补0处理
        #         train_dat_04 = pad_sequences_to_max(train_dat_04, max_steps)
        #         val_dat_04 = pad_sequences_to_max(val_dat_04, max_steps)
        #
        #         train_dat=[]
        #         val_dat=[]
        #         train_dat.append(train_dat_conti)
        #         train_dat.append(train_dat_1)
        #         train_dat.append(train_dat_07)
        #         train_dat.append(train_dat_04)
        #         val_dat.append(val_dat_conti)
        #         val_dat.append(val_dat_1)
        #         val_dat.append(val_dat_07)
        #         val_dat.append(val_dat_04)
        #         train_dat = torch.utils.data.ConcatDataset(train_dat)
        #         val_dat = torch.utils.data.ConcatDataset(val_dat)
        #         test_dat = test_dat_conti
        #
        #     else:
        #         raise ValueError(f"Unknown monkeyname {monkeyname}.")
        # else:
        #     if monkeyname == "indy" or monkeyname == "loco":
        #         train_dat, val_dat, test_dat = (
        #                     self.dataloader["indy_and_loco"].get_single_session_data(
        #                         filename,
        #                         nb_inputs=nb_inputs,
        #                         with_S1=False,
        #                         zscore=self.zscore,
        #                     )
        #                 )
        #     elif monkeyname == "MAZE" or monkeyname == "RTT":
        #         train_dat, val_dat, test_dat = (
        #                     self.dataloader["MAZE"].get_single_session_data(
        #                         filename,
        #                         nb_inputs=nb_inputs,
        #                         with_PMd=False,
        #                         zscore=self.zscore,
        #                     )
        #                 )
        #     elif monkeyname == "C05":
        #         train_dat, val_dat, test_dat = (
        #                     self.dataloader["C05"].get_single_session_data(filename,zscore=self.zscore)
        #                 )
        #     else:
        #         raise ValueError(f"Unknown monkeyname {monkeyname}.")


        return train_dat, val_dat, test_dat

    def get_multiple_set_data_divide_set(self, filenames, nb_inputs, with_S1, save_path=None):
        """Return the cached per-dataset and combined dataset dictionaries.

        The data is read from ``self.basepath`` by calling
        ``load_multiple_set_data_divide_set`` with ``Flag='divide'``.  This
        method does not build the datasets session by session and does not
        write anything to disk; the arguments below are accepted only so that
        the call signature stays stable for existing callers.

        Args:
            filenames: Currently unused.
            nb_inputs: Currently unused.
            with_S1: Currently unused.
            save_path: Currently unused; nothing is saved.

        Returns:
            tuple: ``(div_ds, dataset_all)`` as produced by
            ``load_multiple_set_data_divide_set``.
        """
        # NOTE(review): the cached files are returned regardless of `filenames`,
        # `nb_inputs` and `with_S1` -- confirm the cache matches what callers expect.
        div_ds, dataset_all = self.load_multiple_set_data_divide_set(
            self.basepath, Flag='divide'
        )
        return div_ds, dataset_all

    def load_multiple_set_data_divide_set(self, load_path, Flag):
        """Load previously saved dataset objects from ``load_path``.

        Args:
            load_path (str): Directory that holds the saved ``.pt`` files.
            Flag (str): ``'divide'`` loads ``div_ds.pt`` and ``dataset_all.pt``;
                ``'multi'`` loads ``dataset.pt``.

        Returns:
            For ``'divide'``, the tuple ``(div_ds, dataset_all)``; for
            ``'multi'``, the single object stored in ``dataset.pt``.

        Raises:
            ValueError: If ``Flag`` is neither ``'divide'`` nor ``'multi'``.
            FileNotFoundError: If a required file is missing in ``load_path``.
        """
        if Flag == 'divide':
            file_names = ("div_ds.pt", "dataset_all.pt")
        elif Flag == 'multi':
            file_names = ("dataset.pt",)
        else:
            raise ValueError(f"Unknown Flag {Flag}. Please use 'divide' or 'multi'.")

        try:
            logger.info(f"正在从 {load_path} 加载数据集")

            file_paths = [os.path.join(load_path, name) for name in file_names]
            if not all(os.path.exists(path) for path in file_paths):
                raise FileNotFoundError(f"在 {load_path} 中未找到数据集文件")

            # weights_only=False unpickles arbitrary objects, so only use this
            # with files from a trusted source.  The alternative is to register
            # ConcatDataset via torch.serialization.add_safe_globals and call
            # torch.load with its default settings.
            loaded = [torch.load(path, weights_only=False) for path in file_paths]

            logger.info("数据集加载完成")
        except Exception as e:
            # Log the failure and let the caller see the original exception.
            logger.error(f"加载数据集时发生错误: {e}")
            raise

        if Flag == 'divide':
            div_ds, dataset_all = loaded
            return div_ds, dataset_all
        return loaded[0]

def pad_sequences_to_max(data, max_length):
    """Zero-pad every sample along time to ``max_length`` and wrap the result with ``to_ras``.

    Args:
        data: Iterable of samples where ``sample[0]`` is the spike tensor and
            ``sample[1]`` is the label tensor.  Both are padded with
            ``F.pad(x, (0, 0, 0, n))``, i.e. they are treated as 2-D tensors
            whose first dimension is time.
        max_length: Target number of time steps.

    Returns:
        The dataset returned by ``to_ras`` for the stacked, padded spikes and
        labels.

    Raises:
        ValueError: If a sample has more than ``max_length`` time steps.
    """
    padded_spike = []
    padded_label = []

    for sample in data:
        spike = sample[0]
        label = sample[1]
        # Number of zero rows to append at the end of the time dimension.
        pad_length = max_length - spike.shape[0]
        if pad_length < 0:
            # A negative pad would make F.pad silently crop the sample.
            raise ValueError("Sample length exceeds max_length. Please adjust max_length.")
        # pad_length == 0 (sample already at max_length) is valid: nothing is added.
        padded_spike.append(F.pad(spike, (0, 0, 0, pad_length)))
        padded_label.append(F.pad(label, (0, 0, 0, pad_length)))

    padded_spike = torch.stack(padded_spike)
    padded_label = torch.stack(padded_label)

    return to_ras(padded_spike, padded_label)

def to_ras(data, labels, **data_augmentation_kwargs):
    """Convert dense spike arrays to (time, unit) index pairs and wrap them in a RasDataset.

    For each sample, the entries that are exactly equal to 1 are collected
    unit by unit: row 0 of the resulting tensor holds the time indices and
    row 1 the matching unit indices.  Values other than 1 are ignored.

    Args:
        data: Tensor whose last two dimensions are (time steps, units);
            iterating over it yields one sample at a time.
        labels: Labels passed through unchanged to the dataset.
        **data_augmentation_kwargs: Extra keyword arguments forwarded to
            ``stork.datasets.RasDataset``.

    Returns:
        The ``stork.datasets.RasDataset`` built from the converted samples.
    """
    ras_data = []
    for sample in data:
        times, units = [], []
        for unit in range(sample.shape[-1]):
            hits = np.where(sample[:, unit] == 1)[0].tolist()
            times.extend(hits)
            units.extend([unit] * len(hits))
        ras_data.append(torch.tensor([times, units], dtype=data.dtype))

    nb_steps, nb_units = data.shape[-2], data.shape[-1]

    return stork.datasets.RasDataset(
        (ras_data, labels),
        dtype=data.dtype,
        nb_steps=nb_steps,
        nb_units=nb_units,
        time_scale=1.0,
        **data_augmentation_kwargs,
    )


