"""
This file contains code from PyTorch Vision (https://github.com/pytorch/vision) which is licensed under BSD 3-Clause License.
These snippets are the Copyright (c) of Soumith Chintala 2016. All other code is the Copyright (c) of the NeuroBench Developers 2023.
"""

from torch.utils.data import Dataset
import os
import sys
import torch
import math
import numpy as np
import h5py
from scipy.signal import convolve2d
from urllib.error import URLError
import matplotlib.pyplot as plt

import stork

import logging
import pickle

logger = logging.getLogger(__name__)
# The spikes recorded in the Primate Reaching datasets have an interval of 4 ms.
# This is no longer a module-level constant: PrimateReaching takes its sampling
# interval from its `stride` argument (default 0.004 s) as `self.SAMPLING_RATE`.

import hydra
from omegaconf import DictConfig
from hydra.utils import to_absolute_path
from pathlib import Path
import torch.nn.functional as F
from nlb_tools.nwb_interface import NWBDataset

from nlb_tools.nwb_interface import NWBDataset
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def get_dataloader_RTT(cfg, dtype=torch.float32):
    """Build a ``DatasetLoader_RTT`` from a Hydra/OmegaConf config.

    ``output_feedback`` is ``cfg.model.output_feedback`` when present and truthy,
    otherwise ``cfg.model.self_and_crossAttention`` when present and truthy,
    otherwise ``False``. ``session_classfication`` is read from the top level of
    ``cfg`` in the same way (falling back to ``False``). Every other option comes
    from ``cfg.data`` or ``cfg.testFlag``.

    Args:
        cfg: Config object with ``model``, ``data`` and ``testFlag`` entries.
        dtype: Tensor dtype of the produced datasets.

    Returns:
        DatasetLoader_RTT: The configured loader.
    """
    model_cfg = cfg.model
    # First truthy flag wins; missing keys behave like falsy values.
    output_feedback = (
        getattr(model_cfg, "output_feedback", None)
        or getattr(model_cfg, "self_and_crossAttention", None)
        or False
    )
    session_classfication = getattr(cfg, "session_classfication", None) or False

    data_cfg = cfg.data
    return DatasetLoader_RTT(
        basepath=data_cfg.data_dir,
        ratio_val=data_cfg.ratio_val,
        random_val=data_cfg.random_val,
        extend_data=data_cfg.extend_data,
        sample_duration=data_cfg.sample_duration,
        remove_segments_inactive=data_cfg.remove_segments_inactive,
        p_drop=data_cfg.p_drop,
        p_insert=data_cfg.p_insert,
        jitter_sigma=data_cfg.jitter_sigma,
        dtype=dtype,
        output_feedback=output_feedback,
        session_classfication=session_classfication,
        testFlag=cfg.testFlag,
        padding=data_cfg.padding,
    )


def compute_input_firing_rates(data, cfg, nb_inputs):
    """Compute the mean per-channel input firing rate over all samples in ``data``.

    Args:
        data: Indexable collection whose items are tuples ``(spikes, ...)``; each
            ``spikes`` is indexed as ``spikes[:, channels]`` (presumably
            ``(time, channels)`` — confirm against the caller).
        cfg: Config providing ``cfg.data.sample_duration`` (duration of one sample).
        nb_inputs (int): Number of input channels. With ``16``, only the first
            16 channels are used. Otherwise, the first 96 channels (M1) and the
            channels from index 96 on (S1) are averaged separately. Both are
            normalised by 96 channels.

    Returns:
        tuple: ``(mean1, None)`` for ``nb_inputs == 16``, else ``(mean1, mean2)``.
        ``mean2`` is 0 when the data has no channels beyond index 96.

    Raises:
        ValueError: If ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("compute_input_firing_rates: `data` is empty")

    duration = cfg.data.sample_duration
    mean1 = 0
    mean2 = 0
    for i in range(n_samples):
        spikes = data[i][0]
        if nb_inputs == 16:
            mean1 += torch.sum(spikes[:, :nb_inputs]) / duration / nb_inputs
        else:
            mean1 += torch.sum(spikes[:, :96]) / duration / 96
            # Slicing past the last column yields an empty tensor (sum 0), so
            # 96-channel data simply contributes nothing here.
            mean2 += torch.sum(spikes[:, 96:]) / duration / 96

    # Average over all samples, after the loop has visited every sample.
    mean1 /= n_samples
    if nb_inputs == 16:
        return mean1, None
    mean2 /= n_samples
    return mean1, mean2


class DatasetLoader_RTT:
    """Load the Primate Reaching (RTT) data and split it into train, val and test sets.

    The train and validation sets are cut into samples of ``sample_duration``
    seconds, while the test set is kept as a single long sample. The datasets
    are returned as stork ``RasDataset`` objects. When ``output_feedback`` is
    enabled, the training set is a ``torch.utils.data.TensorDataset`` instead.
    They can then be used as usual with the stork ``StandardGenerator``.
    """

    def __init__(
            self,
            basepath,
            num_steps=1,
            dt=0.004,
            ratio_val=0.25,
            biological_delay=0,
            spike_sorting=False,
            label_series=False,
            random_val=False,
            extend_data=True,
            sample_duration=2,
            bin_width=None,
            stride=None,
            remove_segments_inactive=False,
            dtype=torch.float32,
            p_drop=0.0,
            p_insert=0.0,
            jitter_sigma=0.0,
            output_feedback=False,
            session_classfication=False,
            testFlag=False,
            padding="zeros",
    ):
        """Initialize the loader (no data is read here).

        Args:
            basepath (str): Path to the data folder. It is rewritten for a few
                known machines (see the platform checks below).
            num_steps (int, optional): Forwarded to ``PrimateReaching``. Should be 1. Defaults to 1.
            dt (float, optional): Time step in seconds; 0.004 for the monkey data. Defaults to 0.004.
            ratio_val (float, optional): Fraction of the whole session used for validation. Defaults to 0.25.
            biological_delay (int, optional): Delay (in time steps) of the readout w.r.t. the input,
                forwarded to ``PrimateReaching``. Defaults to 0.
            spike_sorting (bool, optional): Stored only; not forwarded to ``PrimateReaching``. Defaults to False.
            label_series (bool, optional): Stored only; not forwarded. Defaults to False.
            random_val (bool, optional): If True, the validation set is a randomly placed contiguous
                window of the pooled train+val data; otherwise it is the last part. Defaults to False.
            extend_data (bool, optional): If True, cut train/val data into overlapping samples.
                Defaults to True.
            sample_duration (float, optional): Sample duration in seconds. Defaults to 2.
            bin_width (float, optional): Bin width in seconds forwarded to ``PrimateReaching``.
                Defaults to ``dt``.
            stride (float, optional): Stored only. Defaults to ``dt``.
            remove_segments_inactive (bool, optional): Forwarded to ``PrimateReaching`` for train/val;
                the test set is always loaded without removal. Defaults to False.
            dtype (torch.dtype, optional): dtype of the datasets. Defaults to torch.float32.
            p_drop (float, optional): Spike-drop probability for training-set augmentation. Defaults to 0.0.
            p_insert (float, optional): Spike-insert probability for training-set augmentation. Defaults to 0.0.
            jitter_sigma (float, optional): Spike-time jitter for training-set augmentation
                (passed to ``RasDataset`` as ``sigma_t``). Defaults to 0.0.
                Augmentation is enabled if any of the three is > 0. It is ignored
                for the ``TensorDataset`` built when ``output_feedback`` is set.
            output_feedback (bool, optional): If truthy, the previous time step's labels are appended
                to the training inputs and the training set becomes a ``TensorDataset``. Defaults to False.
            session_classfication (bool, optional): If truthy, session codes are appended to the
                labels instead of the inputs. Defaults to False.
            testFlag (bool, optional): If True, the test set is truncated to its first 3600 time steps.
                Defaults to False.
            padding (str, optional): How 96-channel data is widened to 192 inputs: ``"zeros"`` pads with
                zeros, any other value duplicates the channels. Defaults to "zeros".
        """
        # Machine-specific path rewriting. On Windows the Linux data directory is
        # mapped to a local drive. On the Linux server whose home is /home2/User,
        # "/home/" in the path is replaced by "/home2/".
        home_dir = os.path.expanduser("~")
        if sys.platform == "win32":
            logger.debug("Windows")
            basepath = basepath.replace("/home/User/Data_dir/", "H:/User/")
        elif sys.platform == "linux":
            logger.debug("Linux")
            if home_dir == "/home2/User":
                basepath = basepath.replace("/home/", "/home2/")

        self.basepath = basepath
        self.num_steps = num_steps
        self.dt = dt
        self.ratio_val = ratio_val
        self.biological_delay = biological_delay
        self.spike_sorting = spike_sorting
        self.label_series = label_series
        self.random_val = random_val
        self.extend_data = extend_data
        self.sample_duration = sample_duration
        self.remove_segments_inactive = remove_segments_inactive
        self.dtype = dtype
        self.p_drop = p_drop
        self.p_insert = p_insert
        self.jitter_sigma = jitter_sigma
        self.output_feedback = output_feedback
        self.session_classfication = session_classfication
        self.testFlag = testFlag
        self.padding = padding

        self.bin_width = self.dt if bin_width is None else bin_width
        self.stride = self.dt if stride is None else stride

        # Round before truncating: e.g. 1.2 / 0.004 evaluates to 299.99999999999994.
        self.n_time_steps = int(round(sample_duration / dt))

    def get_single_session_data(self, filename, nb_inputs, with_S1=False, only_S1=False, zscore=False, session_code=None):
        """Load one session and return its train, val and test datasets.

        Args:
            filename (str): Session folder name, forwarded to ``PrimateReaching``.
            nb_inputs (int): Number of input channels the model expects.
            with_S1 (bool): With 96-channel data, keep the first ``nb_inputs`` channels. With
                192-channel data, stack the halves ``[:, :nb_inputs]`` and ``[:, nb_inputs:]``
                along time and duplicate the labels (requires ``nb_inputs == 96``).
            only_S1 (bool): Keep only channels ``[:, nb_inputs:]``.
            zscore (bool): z-score the labels.
            session_code: Optional session code. When ``session_classfication`` is set, it is turned
                into a tensor and appended to the labels. Otherwise each character is parsed with
                ``int`` and the digits are appended to the inputs.

        Returns:
            tuple: ``(train, val, test)`` datasets. The test set is a single sample covering the
            whole test period.

        Raises:
            ValueError: If ``with_S1`` is set and the data has neither 96 nor 192 channels.
        """
        loader_kwargs = dict(
            file_path=self.basepath,
            filename=filename,
            num_steps=self.num_steps,
            train_ratio=0.5,  # PrimateReaching then splits segments ~50 % train, ~25 % val, ~25 % test
            bin_width=self.bin_width,
            biological_delay=self.biological_delay,
        )
        dataset = PrimateReaching(remove_segments_inactive=self.remove_segments_inactive, **loader_kwargs)

        # When inactive segments are removed for training, the test set must still cover the
        # full recording, so the session is loaded a second time without removal.
        if self.remove_segments_inactive:
            dataset_test = PrimateReaching(remove_segments_inactive=False, **loader_kwargs)
        else:
            dataset_test = dataset

        # Pool the original train & validation indices and draw our own validation split.
        ind_train, ind_val = self._split_train_val(dataset.ind_train + dataset.ind_val)

        # PrimateReaching stores (channels x time); transpose to (time x channels).
        spikes = dataset.samples.T
        spikes_testdat = dataset_test.samples.T
        labels = dataset.labels.T
        labels_testdat = dataset_test.labels.T
        if zscore:
            # Global mean/std over all label entries (all label dimensions jointly);
            # the test copy is normalised with its own statistics.
            labels = (labels - labels.mean()) / labels.std()
            labels_testdat = (labels_testdat - labels_testdat.mean()) / labels_testdat.std()

        self.ind_train = ind_train
        self.ind_val = ind_val
        self.ind_test = dataset_test.ind_test

        n_channels = spikes.shape[1]
        if not (with_S1 or only_S1) and nb_inputs == 192 and n_channels == 96:
            logger.info("Padding with zeros..." if self.padding == "zeros" else "Padding with copying...")

        raw_splits = (
            (spikes[ind_train], labels[ind_train]),
            (spikes[ind_val], labels[ind_val]),
            (spikes_testdat[self.ind_test], labels_testdat[self.ind_test]),
        )
        (spikes_train, labels_train), (spikes_val, labels_val), (spikes_test, labels_test) = [
            self._select_channels(x, y, nb_inputs, n_channels, with_S1, only_S1) for x, y in raw_splits
        ]

        if self.output_feedback:
            # Append the previous step's labels (zeros at t=0) as extra inputs; training set only.
            shifted_labels = torch.zeros_like(labels_train)
            shifted_labels[1:] = labels_train[:-1]
            spikes_train = torch.cat((spikes_train, shifted_labels), dim=1)

        if session_code is not None:
            if self.session_classfication:
                # Append the session code to the labels.
                code = torch.tensor(session_code, dtype=self.dtype).unsqueeze(0)
                labels_train = torch.cat((labels_train, code.repeat(labels_train.shape[0], 1)), dim=1)
                labels_val = torch.cat((labels_val, code.repeat(labels_val.shape[0], 1)), dim=1)
                labels_test = torch.cat((labels_test, code.repeat(labels_test.shape[0], 1)), dim=1)
            else:
                # Append the session code digits to the inputs.
                code = torch.tensor([int(char) for char in session_code], dtype=self.dtype).unsqueeze(0)
                spikes_train = torch.cat((spikes_train, code.repeat(spikes_train.shape[0], 1)), dim=1)
                spikes_val = torch.cat((spikes_val, code.repeat(spikes_val.shape[0], 1)), dim=1)
                spikes_test = torch.cat((spikes_test, code.repeat(spikes_test.shape[0], 1)), dim=1)

        if self.testFlag:
            spikes_test = spikes_test[0:3600, :]
            labels_test = labels_test[0:3600, :]

        # Cut train/val data into samples of n_time_steps.
        if self.extend_data:
            logger.info("Extending data...")
            # Overlapping samples: start offsets every n_time_steps / 5 steps.
            extend_kwargs = dict(chunks=self.n_time_steps, chunksize=max(1, int(self.n_time_steps / 5)))
        else:
            # Only offset 0 (range(0, 99, 100)): consecutive, non-overlapping samples.
            extend_kwargs = dict(chunks=99)
        train_data, train_labels = self.extend_spikes(spikes_train, labels_train, **extend_kwargs)
        val_data, val_labels = self.extend_spikes(spikes_val, labels_val, **extend_kwargs)

        # The test set is a single sample: add a leading sample dimension.
        test_data = spikes_test.unsqueeze(0)
        test_labels = labels_test.unsqueeze(0)

        # Augmentation kwargs apply to the training dataset only.
        if self.p_drop > 0 or self.p_insert > 0 or self.jitter_sigma > 0:
            data_augmentation_kwargs = dict(
                data_augmentation=True,
                p_drop=self.p_drop,
                p_insert=self.p_insert,
                sigma_t=self.jitter_sigma,
            )
        else:
            data_augmentation_kwargs = {}

        if self.output_feedback:
            train_ds = self.to_dataset(train_data, train_labels, **data_augmentation_kwargs)
        else:
            train_ds = self.to_ras(train_data, train_labels, **data_augmentation_kwargs)
        val_ds = self.to_ras(val_data, val_labels)
        test_ds = self.to_ras(test_data, test_labels)

        return train_ds, val_ds, test_ds

    def _split_train_val(self, ind_tv):
        """Split pooled train+val time indices into new (train, val) int64 index arrays."""
        # The pooled indices cover ~75 % of the session, so ratio_val is rescaled to make the
        # validation split ``ratio_val`` of the *whole* session: len(ind_tv) * ratio_val / 0.75.
        eff_ratio_val = self.ratio_val / 0.75
        n_tv = len(ind_tv)
        n_val = min(max(int(np.round(n_tv * eff_ratio_val)), 0), n_tv)
        n_train = n_tv - n_val

        if self.random_val:
            # Random *position* of a contiguous window inside ind_tv (indices may have gaps
            # when segments were removed, so index values cannot be used as positions).
            start = np.random.randint(0, n_train + 1)
            ind_val = np.asarray(ind_tv[start:start + n_val], dtype=np.int64)
            ind_train = np.array(sorted(set(ind_tv) - set(ind_val)), dtype=np.int64)
        else:
            # Slice by explicit length: ``ind_tv[:-0]`` would be empty when n_val == 0.
            ind_train = np.asarray(ind_tv[:n_train], dtype=np.int64)
            ind_val = np.array(sorted(set(ind_tv) - set(ind_train)), dtype=np.int64)
        return ind_train, ind_val

    def _select_channels(self, spikes, labels, nb_inputs, n_channels, with_S1, only_S1):
        """Apply the channel selection / padding of one split to its (spikes, labels) pair."""
        if with_S1:
            if n_channels == 96:
                return spikes[:, 0:nb_inputs], labels
            if n_channels == 192:
                # Stack the two channel halves along time (dim 0) and duplicate the labels.
                return (
                    torch.cat((spikes[:, :nb_inputs], spikes[:, nb_inputs:])),
                    torch.cat((labels, labels)),
                )
            raise ValueError("data dimension err!")
        if only_S1:
            return spikes[:, nb_inputs:], labels
        if nb_inputs == 192 and n_channels == 96:
            filler = torch.zeros_like(spikes) if self.padding == "zeros" else spikes
            return torch.cat((spikes, filler), dim=1), labels
        return spikes[:, 0:nb_inputs], labels

    def get_multiple_sessions_data(self, filenames, nb_inputs, with_S1=False, zscore=False, session_codes=None, only_S1=False):
        """Load several sessions and concatenate their train and validation sets.

        Args:
            filenames (iterable): Session folder names (all inside ``basepath``).
            nb_inputs (int): Number of input channels (see ``get_single_session_data``).
            with_S1 (bool): Forwarded to ``get_single_session_data``.
            zscore (bool): Forwarded to ``get_single_session_data``.
            session_codes (iterable, optional): One session code per filename. Pairs are
                formed with ``zip``, so extra items on either side are ignored.
            only_S1 (bool): Forwarded to ``get_single_session_data``. Defaults to False.

        Returns:
            tuple: ``(ConcatDataset train, ConcatDataset valid, list of per-session test datasets)``.
        """
        if session_codes is None:
            pairs = ((filename, None) for filename in filenames)
        else:
            pairs = zip(filenames, session_codes)

        ds_train, ds_valid, ds_test = [], [], []
        for filename, session_code in pairs:
            train_ds, valid_ds, test_ds = self.get_single_session_data(
                filename,
                nb_inputs=nb_inputs,
                with_S1=with_S1,
                only_S1=only_S1,
                zscore=zscore,
                session_code=session_code,
            )
            ds_train.append(train_ds)
            ds_valid.append(valid_ds)
            ds_test.append(test_ds)

        dataset_train = torch.utils.data.ConcatDataset(ds_train)
        dataset_valid = torch.utils.data.ConcatDataset(ds_valid)

        return dataset_train, dataset_valid, ds_test

    def extend_spikes(self, spikes, labels, chunks="all", chunksize=100):
        """Cut [time x neuron] spikes and labels into samples of ``n_time_steps`` steps.

        For every start offset ``t`` in ``range(0, chunks, chunksize)``, ``spikes[t:]`` is cut
        into consecutive windows of ``n_time_steps``. Only complete windows are kept, including
        the last one when the length is an exact multiple. Windows from different offsets
        overlap. ``chunks="all"`` means ``chunks = n_time_steps``.

        Returns:
            tuple: Stacked spikes ``[samples x n_time_steps x neuron]`` and the matching labels.
        """
        if chunks == "all":
            chunks = self.n_time_steps
        window = self.n_time_steps

        extended_spikes = []
        extended_labels = []
        for offset in range(0, chunks, chunksize):
            n_windows = max(spikes.shape[0] - offset, 0) // window
            for k in range(n_windows):
                start = offset + k * window
                extended_spikes.append(spikes[start:start + window])
                extended_labels.append(labels[start:start + window])

        return torch.stack(extended_spikes), torch.stack(extended_labels)

    def to_ras(self, data, labels, **data_augmentation_kwargs):
        """Convert dense ``[samples x time x units]`` data into a stork ``RasDataset``.

        Each sample becomes a ``(2, n_events)`` tensor. Row 0 holds the time indices and row 1
        the unit indices of every non-zero entry, ordered by unit and then by time. Entries > 0
        count as one event; per-electrode merged counts can exceed 1 and must not be dropped.
        """
        ras_data = []
        for sample in data:
            sample = torch.as_tensor(sample)
            # Transpose first so that nonzero() returns events ordered by unit, then time.
            units, times = torch.nonzero(sample.T > 0, as_tuple=True)
            ras_data.append(torch.stack((times, units)).to(dtype=self.dtype))

        monkey_ds_kwargs = dict(
            nb_steps=data.shape[-2], nb_units=data.shape[-1], time_scale=1.0
        )

        return stork.datasets.RasDataset(
            (ras_data, labels), dtype=self.dtype,
            **monkey_ds_kwargs, **data_augmentation_kwargs
        )

    def to_dataset(self, data, labels, **data_augmentation_kwargs):
        """Wrap dense data and labels in a ``TensorDataset`` (no spike-time conversion).

        ``data_augmentation_kwargs`` are accepted for signature parity with ``to_ras`` but are ignored.
        """
        processed_data = data.clone().to(dtype=self.dtype)
        return torch.utils.data.TensorDataset(processed_data, labels.to(dtype=self.dtype))














class PrimateReaching():
    """Load one RTT session from NWB files and split it into train/val/test indices.

    Spikes are merged per electrode (96 electrodes) and downsampled to ``stride``
    seconds. The finger velocity is used as the label. The recording is divided
    into segments at every change of the target position, and whole segments are
    assigned to the train, validation or test set by ``split_data``.
    """

    def __init__(
        self,
        file_path,
        filename,
        num_steps,
        train_ratio=0.8,
        label_series=False,
        biological_delay=0,
        spike_sorting=False,
        stride=0.004,
        bin_width=0.028,
        max_segment_length=2000,
        split_num=1,
        remove_segments_inactive=False,
        download=True,
        predict_value='velocity',
    ):
        """Load the session and compute the split indices.

        Args:
            file_path (str): Root data folder.
            filename (str): Session folder. ``"/sub-Indy/"`` is appended unless it already ends with it.
            num_steps (int): Number of bins returned per item by ``__getitem__`` (>= 1).
            train_ratio (float): Fraction of segments (per chunk) used for training. The rest is
                split roughly in half between validation and test.
            label_series (bool): If True, ``__getitem__`` also returns a label series.
            biological_delay (int): Shift (in time steps) between spikes and labels.
            spike_sorting (bool): Must be False for this dataset.
            stride (float): Sampling interval in seconds used for downsampling.
            bin_width (float): Must round to ``stride`` (binning ratio 1) for this dataset.
            max_segment_length (int): With ``remove_segments_inactive``, segments with at least
                this many steps are dropped.
            split_num (int): Number of consecutive chunks, each split train/val/test.
            remove_segments_inactive (bool): Drop over-long segments (see ``max_segment_length``).
            download: Unused; kept for API compatibility.
            predict_value: Unused; kept for API compatibility.

        Raises:
            FileNotFoundError: If the session folder does not exist.
        """
        self.task = "RTT"
        self.filename = self._session_dir(filename)
        self.file_path = os.path.join(file_path, self.filename)

        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"RTT session folder not found: {self.file_path}")

        # Related to processing of spike data.
        self.spike_sorting = spike_sorting
        self.delay = biological_delay
        self.stride = stride
        self.bin_width = bin_width
        self.num_steps = num_steps
        self.train_ratio = train_ratio
        self.label_series = label_series
        self.SAMPLING_RATE = stride
        self.ratio = int(np.round(self.bin_width / self.SAMPLING_RATE))

        # Parameter sanity checks.
        assert self.delay >= 0
        assert self.stride >= self.SAMPLING_RATE
        assert (
                self.bin_width >= self.SAMPLING_RATE
        ), "The binning window has to be greater than the sampling size (i.e. 0.004s)"
        assert self.num_steps >= 1
        assert 0 <= self.train_ratio <= 1

        # Beginning and end of each segment (filled by load_data).
        self.start_end_indices = None
        self.time_segments = None

        # Maximum length of a segment.
        self.max_segment_length = max_segment_length
        assert self.max_segment_length >= 0

        self.split_num = split_num

        # Time indices belonging to the training/validation/test sets.
        self.ind_train, self.ind_val, self.ind_test = [], [], []

        self.input_feature_size = 96

        self.load_data()

        # NOTE(review): time_segments are computed before the delay is applied, so with a
        # positive delay the last indices may exceed the shortened arrays — confirm.
        if self.delay > 0:
            self.apply_delay()

        if remove_segments_inactive and self.max_segment_length > 0:
            self.valid_segments = self.remove_segments_by_length()
        else:
            self.valid_segments = np.arange(self.time_segments.shape[0])

        self.split_data()

    @staticmethod
    def _session_dir(filename):
        """Return ``filename`` ending in ``/sub-Indy/``, appending it only when missing."""
        suffix = "/sub-Indy/"
        return filename if filename.endswith(suffix) else filename + suffix

    def __len__(self):
        return len(self.ind_train) + len(self.ind_test) + len(self.ind_val)

    def __getitem__(self, idx):
        """Return the ``num_steps`` input bins ending at ``idx`` and the label(s)."""
        # Indices of the congruent binning windows: idx, idx - ratio, idx - 2*ratio, ...
        mask = idx - np.arange(self.num_steps) * self.ratio
        if self.label_series:
            samples = self.samples[:, mask].transpose(0, 1)
            labels = self.labels[:, mask].transpose(0, 1)
            return samples, labels
        return self.samples[:, mask].transpose(0, 1), self.labels[:, idx]

    def load_data(self):
        """Load spikes, finger velocity and target position and define the segments."""
        # Check unsupported options before the (slow) NWB load.
        assert not self.spike_sorting, "The RTT dataset does not have spike sorting, so this argument should be False."
        assert self.ratio == 1, "The RTT dataset does not have binning, so this argument should be 1."

        print(f"Loading {self.filename}")
        dataset = NWBDataset(self.file_path, "*train", split_heldout=False)
        spikes, finger_vel, target_pos, _ = self.get_vel_and_spike(dataset, self.SAMPLING_RATE)

        # Segments start/end wherever the target position changes.
        self.start_end_indices = np.array(self.get_flag_index(target_pos))
        self.time_segments = np.array(
            self.split_into_segments(self.start_end_indices, target_pos.shape[1])
        )

        # Dimensions: (channels x timesteps)
        self.samples = torch.from_numpy(spikes).float()
        # Dimensions: (nr_features x timesteps). Already a velocity, so no differentiation needed.
        self.labels = torch.from_numpy(finger_vel).float()

    def apply_delay(self):
        """Shift the labels by the delay to account for the biological delay between
        spikes and movement onset."""
        # Dimension: No_of_Channels*No_of_Records
        self.samples = self.samples[:, : -self.delay]
        self.labels = self.labels[:, self.delay :]

    def split_data(self):
        """Split segments into training/validation/test set."""
        split_num = self.split_num
        total_segments = self.time_segments.shape[0]
        sub_length = int(total_segments / split_num)  # segments per chunk
        stride = int(self.stride / self.SAMPLING_RATE)

        train_len = math.floor(self.train_ratio * sub_length)
        val_len = math.floor((sub_length - train_len) / 2)

        offset = 0
        # valid_segments holds *global* segment numbers; a set makes lookups O(1).
        valid = set(np.asarray(self.valid_segments).tolist())

        # Within each chunk: the first train_len segments go to train, the next val_len to val,
        # the rest to test.
        for split_no in range(split_num):
            for i in range(sub_length):
                seg = split_no * sub_length + i
                if seg not in valid:
                    continue
                start, end = self.time_segments[seg]
                indices = list(np.arange(offset + start, end, stride))
                if i < train_len:
                    self.ind_train += indices
                elif i < train_len + val_len:
                    self.ind_val += indices
                else:
                    self.ind_test += indices

    def remove_segments_by_length(self):
        """Remove the segments where its duration exceeds the limit set by
        max_segment_length."""
        return np.nonzero(
            self.time_segments[:, 1] - self.time_segments[:, 0]
            < self.max_segment_length
        )[0]

    @staticmethod
    def split_into_segments(indices, last_idx):
        """Combine the start and end index into a NumPy array."""
        indices = np.insert(indices, 0, 0)
        indices = np.append(indices, [last_idx])
        start_end = np.array([indices[:-1], indices[1:]])

        return np.transpose(start_end)

    @staticmethod
    def get_flag_index(target_pos):
        """Return the time indices after which the target position changes.

        ``target_pos`` is ``(n_dims x time)``; any number of dimensions is accepted.
        """
        target_diff = np.diff(target_pos, axis=1, append=target_pos[:, -1:])
        return np.nonzero(np.sum(np.abs(target_diff), axis=0))[0]

    @staticmethod
    def get_vel_and_spike(raw_dataset, dt):
        """Extract per-electrode spikes, finger velocity and target position, downsampled to ``dt``.

        Args:
            raw_dataset: ``NWBDataset`` whose ``data`` frame has ``spikes``, ``finger_vel`` and
                ``target_pos`` column groups.
            dt (float): Target resolution in seconds. The raw rows are grouped ``dt * 1000`` at a
                time (assumes one raw row per millisecond — confirm against ``raw_dataset.bin_width``).

        Returns:
            tuple: ``(spikes [96 x T], finger_vel [n_vel x T], target_pos [n_target x T],
            time_in_seconds [T])``. Spike values are the per-bin maximum of the summed unit
            counts of each electrode and may exceed 1.
        """
        dt = int(round(dt * 1000))  # seconds -> ms; round() guards against float truncation
        lag = 0
        lag_bins = int(round(lag / raw_dataset.bin_width))

        # Keep only rows where the finger velocity (and its lagged counterpart) is defined.
        nans = raw_dataset.data.finger_vel.x.isna().reset_index(drop=True)
        valid_now = ~nans.to_numpy()
        spike_mask = valid_now & ~nans.shift(-lag_bins, fill_value=True).to_numpy()
        vel_mask = valid_now & ~nans.shift(lag_bins, fill_value=True).to_numpy()
        spike = raw_dataset.data.spikes[spike_mask]
        vel = raw_dataset.data.finger_vel[vel_mask]
        target = raw_dataset.data.target_pos[vel_mask]
        vel_index = vel.index

        # Merge the units of each electrode (1-96). The electrode number is parsed from the
        # channel name: the first digit of 3-character names, the first two digits of
        # 4-character names; other names are skipped.
        electrode_spikes = {i: np.zeros(len(spike)) for i in range(1, 97)}
        for chan in spike.columns:
            chan_str = str(chan)
            if len(chan_str) == 3:
                electrode_id = int(chan_str[0])
            elif len(chan_str) == 4:
                electrode_id = int(chan_str[:2])
            else:
                continue

            if 1 <= electrode_id <= 96:
                electrode_spikes[electrode_id] += spike[chan].values
            else:
                logger.warning("Channel %s maps to electrode %d outside 1-96; skipped.", chan, electrode_id)

        merged_spikes = pd.DataFrame(electrode_spikes)
        merged_spikes.index = spike.index  # keep the original time index

        # Truncate so that the length is divisible by dt.
        truncated_length = len(spike) // dt * dt
        vel_trunc = vel.iloc[:truncated_length]
        vel_index_trunc = vel_index[:truncated_length]
        spike_trunc = merged_spikes.iloc[:truncated_length]
        target_trunc = target.iloc[:truncated_length]

        def downsample_data(data, dt, dtype='spike'):
            """Downsample ``data`` to groups of ``dt`` rows.

            Spikes keep the maximum of each group; other data keeps the last value of each group.
            The new index is the first time stamp of each group.
            """
            groups = np.arange(len(data)) // dt
            if dtype == 'spike':
                downsampled = data.groupby(groups).max()
            else:
                downsampled = data.groupby(groups).last()
            new_time = [vel_index_trunc[i * dt] for i in range(len(downsampled))]
            downsampled.index = pd.Index(new_time, name='clock_time')
            return downsampled

        downsampled_spikes = downsample_data(spike_trunc, dt, 'spike')
        downsampled_vel = downsample_data(vel_trunc, dt, 'velocity')
        downsampled_target = downsample_data(target_trunc, dt, 'velocity')

        spikes_array = downsampled_spikes.values
        vel_array = downsampled_vel.values
        target_array = downsampled_target.values
        time_in_seconds = downsampled_vel.index.total_seconds().to_numpy()

        return spikes_array.T, vel_array.T, target_array.T, time_in_seconds


# 添加画轨迹的函数
def plot_trajectory(labels_data, num_steps=4000, title="Trajectory"):
    """Plot a randomly placed window of consecutive rows of ``labels_data``.

    For two-column data, three panels are drawn side by side: the 2D trajectory
    (with start/end markers) and each column against the time step. For any other
    width, only the first column is plotted against the time step.

    Args:
        labels_data: Array-like of shape [time steps, feature dims].
        num_steps (int): Window length; clipped to the number of rows. Defaults to 4000.
        title (str): Prefix of the plot titles.

    Raises:
        ValueError: If ``labels_data`` has no rows or ``num_steps`` is not positive.
    """
    import matplotlib.pyplot as plt

    total_steps = labels_data.shape[0]
    if total_steps == 0 or num_steps <= 0:
        raise ValueError("plot_trajectory needs at least one row and num_steps > 0")
    num_steps = min(num_steps, total_steps)

    # Random start such that the window fits; the last possible window is included.
    start_idx = np.random.randint(0, total_steps - num_steps + 1)
    trajectory = labels_data[start_idx:start_idx + num_steps]
    time_steps = np.arange(num_steps)

    if trajectory.shape[1] == 2:
        # One figure holding the three panels (no separate empty figure is created).
        _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

        # Panel 1: 2D trajectory.
        ax1.plot(trajectory[:, 0], trajectory[:, 1], 'b-', alpha=0.7, linewidth=1)
        ax1.scatter(trajectory[0, 0], trajectory[0, 1], c='green', s=100, label='Start', marker='o')
        ax1.scatter(trajectory[-1, 0], trajectory[-1, 1], c='red', s=100, label='End', marker='x')
        ax1.set_xlabel('X Position')
        ax1.set_ylabel('Y Position')
        ax1.set_title(f'{title} (2D Trajectory, {num_steps} steps)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: X over time.
        ax2.plot(time_steps, trajectory[:, 0], 'r-', alpha=0.7, linewidth=1)
        ax2.set_xlabel('Time Steps')
        ax2.set_ylabel('X Position')
        ax2.set_title(f'{title} (X Dimension)')
        ax2.grid(True, alpha=0.3)

        # Panel 3: Y over time.
        ax3.plot(time_steps, trajectory[:, 1], 'g-', alpha=0.7, linewidth=1)
        ax3.set_xlabel('Time Steps')
        ax3.set_ylabel('Y Position')
        ax3.set_title(f'{title} (Y Dimension)')
        ax3.grid(True, alpha=0.3)
    else:
        # Any other width: plot the first column only.
        plt.figure(figsize=(12, 8))
        plt.plot(time_steps, trajectory[:, 0], 'b-', alpha=0.7, linewidth=1)
        plt.xlabel('Time Steps')
        plt.ylabel('Value')
        plt.title(f'{title} (1D, {num_steps} steps)')
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()