"""
This file contains code from PyTorch Vision (https://github.com/pytorch/vision) which is licensed under BSD 3-Clause License.
These snippets are the Copyright (c) of Soumith Chintala 2016. All other code is the Copyright (c) of the NeuroBench Developers 2023.
"""

from torch.utils.data import Dataset
import os
import torch
import math
import numpy as np
import h5py
from scipy.signal import convolve2d
from urllib.error import URLError
import matplotlib.pyplot as plt

import stork

import logging
import pickle

logger = logging.getLogger(__name__)
# The spikes recorded in the Primate Reaching datasets have an interval of 4ms.
# SAMPLING_RATE = 4e-3

import hydra
from omegaconf import DictConfig
from hydra.utils import to_absolute_path
from pathlib import Path
import torch.nn.functional as F
from nlb_tools.nwb_interface import NWBDataset

from nlb_tools.nwb_interface import NWBDataset
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def get_dataloader_MAZE(cfg, dtype=torch.float32):
    """Create a ``DatasetLoader_MAZE`` from a configuration object.

    Args:
        cfg: Configuration exposing ``cfg.data`` (with ``data_dir``,
            ``ratio_val``, ``random_val``, ``extend_data``,
            ``sample_duration``, ``remove_segments_inactive``, ``p_drop``,
            ``p_insert``, ``jitter_sigma`` and ``dt``) as well as the
            top-level entries ``predict_value`` and ``testFlag``.
        dtype: Forwarded unchanged to the loader. Defaults to
            ``torch.float32``.

    Returns:
        DatasetLoader_MAZE: The configured loader; no data is read here.
    """
    data_cfg = cfg.data
    loader_kwargs = dict(
        basepath=data_cfg.data_dir,
        ratio_val=data_cfg.ratio_val,
        random_val=data_cfg.random_val,
        extend_data=data_cfg.extend_data,
        sample_duration=data_cfg.sample_duration,
        remove_segments_inactive=data_cfg.remove_segments_inactive,
        p_drop=data_cfg.p_drop,
        p_insert=data_cfg.p_insert,
        jitter_sigma=data_cfg.jitter_sigma,
        dtype=dtype,
        dt=data_cfg.dt,
        predict_value=cfg.predict_value,
        testFlag=cfg.testFlag,
    )
    return DatasetLoader_MAZE(**loader_kwargs)


def compute_input_firing_rates(data, cfg, nb_inputs):
    """Compute the mean input rate per channel, averaged over all samples.

    For every sample the entries of the input tensor are summed, divided by
    ``cfg.data.sample_duration`` and by the number of channels of the group;
    the per-sample values are then averaged over the dataset.
    (Presumably this yields spikes per second per channel when the inputs
    are spike counts and the duration is in seconds - TODO confirm.)

    Args:
        data: Indexable, sized collection whose items carry the input tensor
            at position 0 (``data[i][0]``), indexed as ``[time, channel]``.
        cfg: Configuration object providing ``cfg.data.sample_duration``.
        nb_inputs (int): Number of input channels. ``16`` selects the
            single-group computation; any other value uses two groups of 96
            channels (columns ``[:96]`` and ``[96:]``).

    Returns:
        tuple: ``(mean1, mean2)``. ``mean2`` is ``None`` when
        ``nb_inputs == 16``.

    Raises:
        ValueError: If ``data`` is empty (the average is undefined).
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("Cannot compute firing rates of an empty dataset.")

    duration = cfg.data.sample_duration

    if nb_inputs == 16:
        mean1 = 0
        for i in range(n_samples):
            mean1 += torch.sum(data[i][0][:, :nb_inputs]) / duration / nb_inputs
        # Normalise once, after all samples have been accumulated.
        return mean1 / n_samples, None

    mean1 = 0
    mean2 = 0
    for i in range(n_samples):
        # Fetch the sample once so both groups are computed on the same tensor.
        inputs = data[i][0]
        mean1 += torch.sum(inputs[:, :96]) / duration / 96
        # Slicing past the last column gives an empty tensor whose sum is 0,
        # so samples with at most 96 channels contribute nothing here.
        mean2 += torch.sum(inputs[:, 96:]) / duration / 96

    return mean1 / n_samples, mean2 / n_samples


class DatasetLoader_MAZE:
    def __init__(
            self,
            basepath,
            num_steps=1,
            dt=0.004,
            ratio_val=0.1,  ###
            biological_delay=0,
            spike_sorting=False,
            label_series=False,
            random_val=False,
            extend_data=True,
            sample_duration=2,
            bin_width=None,
            stride=None,
            remove_segments_inactive=False,
            dtype=torch.float32,
            p_drop=0.0,
            p_insert=0.0,
            jitter_sigma=0.0,
            predict_value='velocity',
            testFlag=False,
            continuous_trial=True,
            mix_continuous_uncontinuous=False,
    ):
        """Initialize

        Args:
            basepath (str): the path to the data folder
            filename (str): the name of the specific data file to load
            num_steps (int, optional): Argument for the Neurobench dataloader. Should be 1. Defaults to 1.
            dt (float, optional): Time step, should be 0.004 for the monkey data. Defaults to 0.004.
            ratio_val (list, optional): Ratio for validation set. Defaults to 0.25
            biological_delay (int, optional): Delay of readout w.r.t input. Defaults to 0.
            spike_sorting (bool, optional): If True, using single unit activities, otherwise multi unit activities. Defaults to False.
            label_series (bool, optional): Some neurobench argument. Just leave as is. Defaults to False.
            random_val (bool, optional): If True, samples the validation samples randomly from the train data. Otherwise takes the last samples. Defaults to False.
            extend_data (bool, optional): If true, extends the data to overlapping samples. Defaults to True.
            sample_duration (int, optional): sample duration in seconds. Defaults to 2.
            bin_width (_type_, optional): Some neurobench argument. Just leave as is. Defaults to None.
            stride (_type_, optional): Some neurobench argument. Just leave as is. Defaults to None.
            remove_segments_inactive (bool, optional): Some neurobench argument. Just leave as is. Defaults to False.
            dtype (_type_, optional): The dtype of the datasets. Defaults to torch.float32.
        """

        self.SAMPLING_RATE = 1 / dt

        self.basepath = basepath
        self.num_steps = num_steps
        self.dt = dt
        self.ratio_val = ratio_val
        self.biological_delay = biological_delay
        self.spike_sorting = spike_sorting
        self.label_series = label_series
        self.random_val = random_val
        self.extend_data = extend_data
        self.sample_duration = sample_duration
        self.remove_segments_inactive = remove_segments_inactive
        self.dtype = dtype
        self.p_drop = p_drop
        self.p_insert = p_insert
        self.jitter_sigma = jitter_sigma
        self.predict_value = predict_value
        self.testFlag = testFlag
        self.continuous_trial = continuous_trial
        self.mix_continuous_uncontinuous = mix_continuous_uncontinuous

        if bin_width is None:
            self.bin_width = self.dt
        else:
            self.bin_width = bin_width
        if stride is None:
            self.stride = self.dt
        else:
            self.stride = stride

        # self.fileName = 'mice3689_0525.mat'
        # self.basepath = 'D:\\PPPPProject\\SNN_Environment\\neural-decoding-RSNN-main\\monkeyData\\mice3689_0525.mat'

        # self.random_val=False
        # self.ratio_val=0.1
        # self.extend_data=True
        # self.dt = 0.004
        # self.sample_duration = 10
        self.n_time_steps = int(sample_duration / dt)
        # self.n_time_steps=500

        # self.dtype = torch.float32

        # self.p_drop = 0
        # self.p_insert = 0
        # self.jitter_sigma = 0

        return

    def get_single_session_data(self, filename, nb_inputs=192 , with_PMd=False, zscore=False):
        """Load one session and return its train, validation and test data.

        The session is read through ``PrimateReaching``. Its train and
        validation trial indices are merged and re-split according to
        ``self.ratio_val``; the resulting trials are cut into fixed-length
        samples by ``extend_to_sample`` and wrapped in stork ``RasDataset``
        objects.

        Side effects: sets ``self.ind_train``, ``self.ind_val`` and
        ``self.ind_test``.

        Args:
            filename (str): Name of the session file inside ``self.basepath``.
            nb_inputs (int, optional): Number of input channels to use,
                forwarded to ``extend_to_sample`` and, for non-continuous
                testing, to ``get_test_spikes``. Defaults to 192.
            with_PMd (bool, optional): Forwarded to ``extend_to_sample``.
                Defaults to False.
            zscore (bool, optional): If True, labels are standardised with
                the mean and standard deviation of all labels of the session.
                Defaults to False.

        Returns:
            tuple: ``(train, val, test)``. ``train`` and ``val`` are
            ``RasDataset`` objects. ``test`` is a single ``RasDataset`` when
            ``self.continuous_trial`` is True, otherwise a list with one
            ``RasDataset`` per test trial.

        Raises:
            ValueError: If the requested validation size leaves the training
                or the validation split empty.
        """
        dataset = PrimateReaching(
            file_path=self.basepath,
            filename=filename,
            num_steps=self.num_steps,
            # Hard-coded split ratio. NOTE(review): the original comment spoke
            # of a 25 % test split - confirm how PrimateReaching derives the
            # val/test parts from train_ratio.
            train_ratio=0.7,
            bin_width=self.bin_width,
            biological_delay=self.biological_delay,
            remove_segments_inactive=self.remove_segments_inactive,
            # NOTE(review): self.dt is passed here, so a custom self.stride
            # is not used - confirm this is intended.
            stride=self.dt,
            predict_value=self.predict_value,
        )

        # Merge the train and validation trials and make our own split.
        # Assumes ind_train / ind_val are Python lists (concatenation) -
        # TODO confirm against PrimateReaching.
        ind_tv = dataset.ind_train + dataset.ind_val
        n_tv = len(ind_tv)

        # Validation size, expressed as a fraction of ALL segments.
        n_val = int(np.round(dataset.segmentsNum * self.ratio_val))
        if not 0 < n_val < n_tv:
            # Any such value would produce an empty split and fail further
            # down with an obscure error, so fail early and clearly.
            raise ValueError(
                f"ratio_val={self.ratio_val} yields {n_val} validation trials, "
                f"but it must be between 1 and {n_tv - 1}."
            )

        if self.random_val:
            # Draw the start POSITION of a contiguous validation block (the
            # drawn number is used to slice ind_tv, so it must be a position
            # in the list, not a trial id).
            start_pos = int(np.random.choice(a=n_tv - n_val, size=1)[0])
            ind_val = np.array(ind_tv[start_pos: start_pos + n_val])
            ind_train = np.array(sorted(set(ind_tv) - set(ind_val)))
        else:
            # The last n_val trials form the validation set.
            ind_train = np.array(ind_tv[:-n_val])
            ind_val = np.array(sorted(set(ind_tv) - set(ind_train)))

        spikes = dataset.samples
        labels = dataset.labels
        if zscore:
            # z-score standardisation, applied exactly once; train, val and
            # test samples are all cut from the same label tensor.
            labels = (labels - labels.mean()) / labels.std()
            dataset.labels = labels

        self.ind_train = ind_train
        self.ind_val = ind_val
        self.ind_test = np.array(dataset.ind_test)

        # Cut the train and validation trials into fixed-length samples
        # (optionally with overlapping, shifted windows).
        train_data, train_labels, val_data, val_labels = self.extend_to_sample(
            dataset.time_segments,
            spikes,
            labels,
            nb_inputs=nb_inputs,
            with_PMd=with_PMd,
            mix_trial=self.mix_continuous_uncontinuous,
        )

        # Augmentation is enabled for the training set only, and only when
        # at least one augmentation parameter is positive.
        if any([self.p_drop > 0, self.p_insert > 0, self.jitter_sigma > 0]):
            data_augmentation_kwargs = dict(
                data_augmentation=True,
                p_drop=self.p_drop,
                p_insert=self.p_insert,
                sigma_t=self.jitter_sigma
            )
        else:
            data_augmentation_kwargs = {}

        # Convert to RAS datasets.
        train_ras_data = self.to_ras(train_data, train_labels,
                                     **data_augmentation_kwargs)
        val_ras_data = self.to_ras(val_data, val_labels)

        if self.continuous_trial:
            # All test trials are concatenated into one long sample.
            # NOTE(review): nb_inputs is not forwarded here, so all channels
            # are kept - confirm this is intended.
            print("test continuous trial")
            test_data, test_labels, test_trials = self.get_test_spikes(
                self.ind_test,
                dataset.time_segments,
                spikes,
                labels,
                continusFlag = True
            )
            test_ras_data = self.to_ras(test_data, test_labels)
        else:
            # Every test trial becomes its own single-sample dataset.
            print("test un-continuous trial")
            test_data, test_labels, test_trials = self.get_test_spikes(
                self.ind_test,
                dataset.time_segments,
                spikes,
                labels,
                continusFlag = False,
                nb_inputs = nb_inputs,
            )
            test_ras_data = [
                self.to_ras(trial_data.unsqueeze(0), trial_labels.unsqueeze(0))
                for trial_data, trial_labels in zip(test_data, test_labels)
            ]

        return train_ras_data, val_ras_data, test_ras_data

    def extend_spikes(self,
                      trial_indx,
                      data_indx,
                      spike, label,
                      chunksize=10,
                      chunks=None,
                      nb_inputs = 192,
                      with_PMd = False,
                      continuous_trial=True,
                      ):
        print("continuous_trial: ", continuous_trial)
        if chunks == None:
            chunks = self.n_time_steps

        extended_spikes = []
        extended_labels = []
        extended_trials = []

        if continuous_trial:
            end_idx = int(data_indx[trial_indx[-1], 1])
            start_idx = int(data_indx[trial_indx[0], 0])
            for t in range(0, chunks, chunksize):
                curr_spikes = spike[start_idx+t:end_idx+1]
                curr_labels = label[start_idx+t:end_idx+1]

                splitter = np.arange(
                    self.n_time_steps, curr_spikes.shape[0], self.n_time_steps
                )

                extended_spikes += np.split(curr_spikes, splitter)[:-1]
                extended_labels += np.split(curr_labels, splitter)[:-1]

            ## 张量拼接
            extended_spikes = torch.stack(extended_spikes)
            extended_labels = torch.stack(extended_labels)
            extended_trials = trial_indx
        else:
            for tmp_indx in range(trial_indx.shape[0]):

                tmp_trial_indx = trial_indx[tmp_indx]

                # 根据trial_indx选取对应trial的segment
                tmp_startIndx = int(data_indx[tmp_trial_indx, 0])
                tmp_endIndx = int(data_indx[tmp_trial_indx, 1])
                tmp_trial_spikeindx = spike[tmp_startIndx:tmp_endIndx + 1, :]
                tmp_trial_labelindx = label[tmp_startIndx:tmp_endIndx + 1, :]

                for t in range(0, chunks, chunksize):

                    curr_spikes = tmp_trial_spikeindx[t:]
                    curr_labels = tmp_trial_labelindx[t:]


                    if curr_spikes.shape[0] != self.n_time_steps:
                        splitter = np.arange(
                            self.n_time_steps, curr_spikes.shape[0], self.n_time_steps
                        )
                        extended_spikes += np.split(curr_spikes, splitter)[:-1]
                        extended_labels += np.split(curr_labels, splitter)[:-1]
                    else:
                        extended_spikes.append(curr_spikes)
                        extended_labels.append(curr_labels)

                extended_trials += [torch.tensor(tmp_trial_indx)]

            ## 张量拼接
            extended_spikes = torch.stack(extended_spikes)
            extended_labels = torch.stack(extended_labels)
            extended_trials = torch.stack(extended_trials)

            extended_spikes, extended_labels = self.select_brain_area(
                extended_spikes, extended_labels, nb_inputs, with_PMd,
            )

        return extended_spikes, extended_labels, extended_trials

    def get_test_spikes(self,
                        trial_indx,
                        data_indx,
                        spike,
                        label,
                        continusFlag = False,
                        nb_inputs=192,
                        with_PMd=False,
                        ):

        # end_idx = int(data_indx[trial_indx[-1], 1])
        # start_idx = int(data_indx[trial_indx[0], 0])
        # test_spikes = spike[start_idx:end_idx+1]
        # test_labels = label[start_idx:end_idx+1]
        # test_trials = trial_indx


        test_spikes = []
        test_labels = []
        test_trials = []

        # for tmp_indx in range(2):
        for tmp_indx in range(trial_indx.shape[0]):
            tmp_trial_indx=trial_indx[tmp_indx]
            # 根据trial_indx选取对应trial的segment
            tmp_startIndx = int(data_indx[tmp_trial_indx, 0])
            tmp_endIndx = int(data_indx[tmp_trial_indx, 1])
            tmp_trial_spikeindx = spike[tmp_startIndx:tmp_endIndx + 1, :]
            tmp_trial_labelindx = label[tmp_startIndx:tmp_endIndx + 1, :]

            if nb_inputs==192:
                test_spikes += [tmp_trial_spikeindx[:,:]]
            elif nb_inputs==96:
                test_spikes += [tmp_trial_spikeindx[:, nb_inputs:]]
            else:
                raise ValueError(f"Unknown number of inputs {nb_inputs}.")
            test_labels += [tmp_trial_labelindx[:,:]]
            test_trials += [torch.tensor(tmp_trial_indx)]

        # 张量拼接
        # test_spikes = torch.stack(test_spikes)
        # test_labels = torch.stack(test_labels)
        # test_trials = torch.stack(test_trials)

        ## 张量拼接
        if continusFlag:
            test_spikes = torch.cat(test_spikes, dim=0)
            test_labels = torch.cat(test_labels, dim=0)
            test_spikes = test_spikes.unsqueeze(0)
            test_labels = test_labels.unsqueeze(0)
            test_trials = torch.stack(test_trials)

        # test_spikes, test_labels = self.select_brain_area(
        #     test_spikes, test_labels, nb_inputs, with_PMd,
        # )

        return test_spikes, test_labels, test_trials

    def select_brain_area(self, spikes, labels, nb_inputs, with_PMd=False):
        if nb_inputs==192:
            logger.info("using all channel data")
        elif nb_inputs==96:
            if with_PMd:
                logger.info("using PMd channel data and M1 channel data")
                spikes = torch.cat((spikes[:,:,:nb_inputs],spikes[:,:,nb_inputs:]))
                labels = torch.cat((labels,labels))
            else:
                logger.info("using M1 channel data")
                spikes = spikes[:,:,nb_inputs:]
        else:
            raise ValueError(f"Unknown number of inputs {nb_inputs}.")

        return spikes, labels

    def to_ras(self, data, labels, **data_augmentation_kwargs):
        """Convert dense samples into a stork ``RasDataset``.

        For every sample, the time steps at which a channel equals 1 are
        collected channel by channel. Each sample becomes a tensor with two
        rows: the time steps and the matching channel indices.

        Args:
            data: Tensor whose last two dimensions are ``(time, channel)``
                and whose first dimension enumerates the samples.
            labels: Labels handed to the ``RasDataset`` unchanged.
            **data_augmentation_kwargs: Extra keyword arguments forwarded to
                ``stork.datasets.RasDataset``.

        Returns:
            stork.datasets.RasDataset: Dataset built from the converted
            samples, with ``nb_steps`` and ``nb_units`` taken from the last
            two dimensions of ``data`` and ``time_scale=1.0``.
        """
        ras_data = []
        for sample in data:
            spike_steps = []
            spike_units = []
            for unit in range(sample.shape[-1]):
                # Time steps at which this channel holds exactly 1.
                steps = np.where(sample[:, unit] == 1)[0].tolist()
                spike_steps.extend(steps)
                spike_units.extend([unit] * len(steps))
            ras_data.append(
                torch.tensor([spike_steps, spike_units], dtype=self.dtype)
            )

        return stork.datasets.RasDataset(
            (ras_data, labels),
            dtype=self.dtype,
            nb_steps=data.shape[-2],
            nb_units=data.shape[-1],
            time_scale=1.0,
            **data_augmentation_kwargs,
        )

    def extend_to_sample(self, time_segments, spikes, labels, nb_inputs, with_PMd, mix_trial=False):
        # mix_trial: 用于混合不同长度的样本片段，并通过pad_sequences_to_max函数将它们补齐到相同长度

        train_data=[]
        train_labels=[]
        val_data=[]
        val_labels=[]

        if mix_trial:
            # 获取连续 trial 数据(已为500步，无需补0)
            self.sample_duration = 2
            self.n_time_steps = total_n_time_steps = int(2 / self.dt)
            # split val and train data into single samples  ## 增加数据量，并将数据切割成n_time_steps（500）的段落
            if self.extend_data:
                logger.info("Extending data...")
                train_data_conti, train_labels_conti, train_trials_conti = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=True,
                )
                val_data_conti, val_labels_conti, val_trials_conti = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=True,
                )
            else:
                train_data_conti, train_labels_conti, train_trials_conti = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10)-1),
                    continuous_trial=True,
                )
                val_data_conti, val_labels_conti, val_trials_conti = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10) - 1),
                    continuous_trial=True,
                )
            train_data.append(train_data_conti)
            train_labels.append(train_labels_conti)
            val_data.append(val_data_conti)
            val_labels.append(val_labels_conti)

            # 获取1秒数据并补0到500步
            self.sample_duration = 1
            self.n_time_steps = int(1 / self.dt)
            # split val and train data into single samples  ## 增加数据量，并将数据切割成n_time_steps（500）的段落
            if self.extend_data:
                logger.info("Extending data...")
                train_data_1, train_labels_1, train_trials_1 = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=False,
                )
                val_data_1, val_labels_1, val_trials_1 = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=False,
                )
            else:
                train_data_1, train_labels_1, train_trials_1 = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10)-1),
                    continuous_trial=False,
                )
                val_data_1, val_labels_1, val_trials_1 = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10) - 1),
                    continuous_trial=False,
                )
            train_data_1, train_labels_1=self.pad_sequences_to_max(train_data_1, train_labels_1, total_n_time_steps)
            val_data_1, val_labels_1 = self.pad_sequences_to_max(val_data_1, val_labels_1, total_n_time_steps)
            train_data.append(train_data_1)
            train_labels.append(train_labels_1)
            val_data.append(val_data_1)
            val_labels.append(val_labels_1)

            # 获取0.7秒数据并补0到500步
            self.sample_duration = 0.7
            self.n_time_steps = int(0.7 / self.dt)
            # split val and train data into single samples  ## 增加数据量，并将数据切割成n_time_steps（500）的段落
            if self.extend_data:
                logger.info("Extending data...")
                train_data_07, train_labels_07, train_trials_07 = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=False,
                )
                val_data_07, val_labels_07, val_trials_07 = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=False,
                )
            else:
                train_data_07, train_labels_07, train_trials_07 = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10)-1),
                    continuous_trial=False,
                )
                val_data_07, val_labels_07, val_trials_07 = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10) - 1),
                    continuous_trial=False,
                )
            train_data_07, train_labels_07 = self.pad_sequences_to_max(train_data_07, train_labels_07, total_n_time_steps)
            val_data_07, val_labels_07 = self.pad_sequences_to_max(val_data_07, val_labels_07, total_n_time_steps)
            train_data.append(train_data_07)
            train_labels.append(train_labels_07)
            val_data.append(val_data_07)
            val_labels.append(val_labels_07)

            # 获取0.4秒数据并补0到500步
            self.sample_duration = 0.4
            self.n_time_steps = int(0.4 / self.dt)
            # split val and train data into single samples  ## 增加数据量，并将数据切割成n_time_steps（500）的段落
            if self.extend_data:
                logger.info("Extending data...")
                train_data_04, train_labels_04, train_trials_04 = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=False,
                )
                val_data_04, val_labels_04, val_trials_04 = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    continuous_trial=False,
                )
            else:
                train_data_04, train_labels_04, train_trials_04 = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10)-1),
                    continuous_trial=False,
                )
                val_data_04, val_labels_04, val_trials_04 = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=str(int(((1 / self.dt) * self.sample_duration) / 10) - 1),
                    continuous_trial=False,
                )
            train_data_04, train_labels_04 = self.pad_sequences_to_max(train_data_04, train_labels_04, total_n_time_steps)
            val_data_04, val_labels_04 = self.pad_sequences_to_max(val_data_04, val_labels_04, total_n_time_steps)
            train_data.append(train_data_04)
            train_labels.append(train_labels_04)
            val_data.append(val_data_04)
            val_labels.append(val_labels_04)

            train_data = torch.cat(train_data, dim=0)
            train_labels = torch.cat(train_labels, dim=0)
            val_data = torch.cat(val_data, dim=0)
            val_labels = torch.cat(val_labels, dim=0)
        else:
            if self.extend_data:
                logger.info("Extending data...")
                train_data, train_labels, train_trials = self.extend_spikes(
                    self.ind_train,
                    time_segments,
                    spikes, labels,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    continuous_trial=self.continuous_trial,
                )
                val_data, val_labels, val_trials = self.extend_spikes(
                    self.ind_val,
                    time_segments,
                    spikes, labels,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    continuous_trial=self.continuous_trial,
                )
            else:
                train_data, train_labels, train_trials = self.extend_spikes(
                    self.ind_train, time_segments, spikes, labels,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=int(((1 / self.dt) * self.sample_duration) / 10) - 1,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    continuous_trial=self.continuous_trial,
                )
                val_data, val_labels, val_trials = self.extend_spikes(
                    self.ind_val, time_segments, spikes, labels,
                    chunksize=int(((1 / self.dt) * self.sample_duration) / 10),  # 以样本的1/10作为overlap
                    chunks=int(((1 / self.dt) * self.sample_duration) / 10) - 1,
                    nb_inputs=nb_inputs,
                    with_PMd=with_PMd,
                    continuous_trial=self.continuous_trial,
                )

        return train_data, train_labels, val_data, val_labels

    def pad_sequences_to_max(self, spike, label, max_length):
        """将序列数据补0到最大长度
        Args:
            data: 输入数据，形状为(样本数, 时间步数, 特征数)
            max_length: 目标最大时间步数
        Returns:
            补0后的序列数据
        """
        pad_length = max_length - spike.shape[1]

        padded_spike=F.pad(spike, (0, 0, 0, pad_length))
        padded_label= F.pad(label, (0, 0, 0, pad_length))


        return padded_spike, padded_label















class PrimateReaching():
    """Trial-segmented loader for pickled reaching recordings (MAZE / RTT).

    On construction the pickled dataset is loaded, spikes are collapsed to
    per-electrode multi-unit activity, labels are extracted, and the trials
    are split into train / validation / test index lists.
    """

    def __init__(
            self,
            file_path,
            filename,
            num_steps,
            train_ratio=0.5,
            label_series=False,
            biological_delay=0,
            spike_sorting=False,
            stride=0.004,
            bin_width=0.028,
            max_segment_length=2000,
            split_num=1,
            remove_segments_inactive=False,
            download=True,
            predict_value='velocity',
    ):
        """
        Initialises the Dataset for the Primate Reaching Task.

        Args:
            file_path (str): Root directory; it is joined with the resolved ``filename``.
            filename (str): Dataset folder name. Must contain "000128" (MAZE) or
                            "00129" (RTT); the matching subject sub-directory is
                            appended when it is not already the suffix.
            num_steps (int): Accepted for interface compatibility; not used here.
            train_ratio (float): Signature default is 0.5. NOTE: currently ignored,
                                 the ratio is pinned to 0.5 in the body.
            label_series (bool): Accepted for interface compatibility; not used here.
            biological_delay (int): Stored as ``self.delay`` (must be >= 0). The delay
                                    is not applied during construction.
            spike_sorting (bool): Accepted for interface compatibility; not used here.
            stride (float): Stored as ``self.stride`` and ``self.SAMPLING_RATE``.
                            Default is 0.004.
            bin_width (float): Accepted for interface compatibility; not used here.
            max_segment_length: Accepted for interface compatibility; not used here.
            split_num (int): NOTE: currently ignored, the value is pinned to 1 in the body.
            remove_segments_inactive (bool): Accepted for interface compatibility; not used here.
            download (bool): Accepted for interface compatibility; not used here.
            predict_value (str | sequence of str): Label to decode: 'velocity',
                            'position' or 'acceleration'. When a sequence is given,
                            its first element is used.

        Raises:
            ValueError: If ``filename`` matches neither supported dataset.
        """
        # Resolve the data location; MAZE and RTT live under different subject folders.
        if "000128" in filename:
            self.task = "MAZE"
            subject_dir = "/sub-Jenkins/"
        elif "00129" in filename:
            self.task = "RTT"
            subject_dir = "/sub-Indy/"
        else:
            raise ValueError(
                f"Unrecognised dataset filename {filename!r}: expected it to contain "
                "'000128' (MAZE) or '00129' (RTT)."
            )
        # endswith() avoids hard-coded slice lengths (the old 13-character slice
        # could never equal the 10-character "/sub-Indy/" suffix).
        self.filename = filename if filename.endswith(subject_dir) else filename + subject_dir
        self.file_path = os.path.join(file_path, self.filename)

        self.segmentBinNum = 20
        # NOTE(review): the split_num / train_ratio arguments are intentionally not
        # applied here; the values stay pinned so existing splits do not change.
        self.split_num = 1
        self.train_ratio = 0.5
        self.stride = stride
        self.SAMPLING_RATE = stride
        self.delay = biological_delay

        # The samples and labels of the dataset
        self.samples = None
        self.labels = None
        # These lists store the index of segments that belongs to training/validation/test set
        self.ind_train, self.ind_val, self.ind_test = [], [], []
        self.predict_value = predict_value

        # test parameters
        assert self.delay >= 0

        # Load the data; it is already organised per trial.
        self.load_data()

        # NOTE: apply_delay() is deliberately not invoked here; self.delay is only stored.

        self.valid_segments = np.arange(self.time_segments.shape[0])
        self.split_data()

    def load_data(self):
        """Load the pickled dataset and populate samples, labels and trial segments.

        Sets ``self.labels``, ``self.samples``, ``self.final_frameNum`` (frames per
        trial, ordered by trial id), ``self.time_segments`` and ``self.segmentsNum``.

        Raises:
            ValueError: If ``self.predict_value`` does not name a supported label.
        """
        print(f"Loading {self.filename}")
        # SECURITY: pickle.load can execute arbitrary code; only load dataset
        # files that come from a trusted source.
        with open(os.path.join(self.file_path, 'nwb_dataset.pkl'), 'rb') as f:
            dataset = pickle.load(f)

        # Align every trial to the go cue; align_range is (start, end), with
        # end=None left for the dataset object to interpret.
        start_time = 0
        end_time = None
        trial_data = dataset.make_trial_data(
            align_field='go_cue_time', align_range=(start_time, end_time)
        )
        # Labels are read from the same window as the spikes (no behavioural lag).
        lagged_trial_data = trial_data

        # predict_value may be a plain string (the signature default) or a
        # sequence such as a config list; in the latter case use its first entry.
        if isinstance(self.predict_value, str):
            target = self.predict_value
        else:
            target = self.predict_value[0]

        if target == 'velocity':
            self.labels = torch.from_numpy(lagged_trial_data.hand_vel.values)
        elif target == 'position':
            self.labels = torch.from_numpy(lagged_trial_data.hand_pos.values)
        elif target == 'acceleration':
            # Acceleration is the finite-difference gradient of velocity along dim 0.
            velocity = torch.from_numpy(lagged_trial_data.hand_vel.values)
            self.labels = torch.gradient(velocity, dim=0)[0]
        else:
            raise ValueError(
                f"Unsupported predict_value {self.predict_value!r}: expected "
                "'velocity', 'position' or 'acceleration'."
            )

        self.samples = self.convert_spikes_to_mua(trial_data.spikes)

        # Number of frames in each trial, ordered by trial id.
        self.final_frameNum = trial_data.trial_id.value_counts().sort_index().values

        self.time_segments = self.split_into_segments(self.final_frameNum)
        self.segmentsNum = self.time_segments.shape[0]

    @staticmethod
    def convert_spikes_to_mua(spikes):
        """Collapse per-unit spike columns into binary per-electrode activity.

        Args:
            spikes (pandas.DataFrame): One column per unit, labelled with an
                integer unit id; one row per time bin.

        Returns:
            torch.Tensor: float64 tensor of shape (len(spikes), 192). Columns
            correspond to electrode labels 101-196 followed by 201-296; values
            larger than 1 are clipped to 1.

        Raises:
            KeyError: If a unit id decodes to an electrode label outside those ranges.
        """
        electrode_ids = list(range(101, 197)) + list(range(201, 297))
        column_of = {electrode_id: col for col, electrode_id in enumerate(electrode_ids)}

        # Work on raw arrays so rows are combined by position; pandas column
        # arithmetic would align on index labels and yield NaN whenever the
        # incoming frame does not carry a 0..n-1 index.
        unit_counts = spikes.to_numpy()
        mua = np.zeros((len(spikes), len(electrode_ids)))

        for col, unit_id in enumerate(spikes.columns.tolist()):
            # Thousands digit equal to 1 selects the 100-series labels, anything
            # else the 200-series; the tens/hundreds digits give the electrode number.
            series_offset = 100 if (unit_id // 1000) == 1 else 200
            electrode_id = (unit_id // 10) % 100 + series_offset
            mua[:, column_of[electrode_id]] += unit_counts[:, col]

        # Several units firing in the same bin still count as a single event.
        mua[mua > 1] = 1
        return torch.from_numpy(mua)

    def apply_delay(self):
        """Shift the labels by the delay to account for the biological delay between
        spikes and movement onset.

        A delay of 0 is a no-op (``[:, :-0]`` would otherwise produce an empty slice).
        """
        if self.delay == 0:
            return
        # NOTE(review): the slicing is along dim 1, while load_data() appears to
        # build samples/labels with time along dim 0 -- confirm the intended axis
        # before enabling this method.
        self.samples = self.samples[:, : -self.delay]
        self.labels = self.labels[:, self.delay:]

    def split_data(self):
        """Split segments into training/validation/test set.

        Segments are divided into ``self.split_num`` equal chunks (any remainder
        is left unassigned). Within each chunk the first
        ``floor(train_ratio * chunk_len)`` segments go to ``ind_train``, the
        next half of the remainder (rounded down) to ``ind_val`` and the rest
        to ``ind_test``. Indices are appended to the existing lists.
        """
        split_num = self.split_num
        total_segments = self.time_segments.shape[0]  # total number of trials
        sub_length = int(total_segments / split_num)  # segments per chunk

        train_len = math.floor(self.train_ratio * sub_length)
        val_len = math.floor((sub_length - train_len) / 2)

        for split_no in range(split_num):
            # Offset by the chunk start so every chunk contributes its own segments.
            chunk_start = split_no * sub_length
            for i in range(sub_length):
                segment_index = chunk_start + i
                if i < train_len:
                    self.ind_train.append(segment_index)
                elif i < train_len + val_len:
                    self.ind_val.append(segment_index)
                else:
                    self.ind_test.append(segment_index)

    @staticmethod
    def split_into_segments(indices):
        """Turn per-trial frame counts into inclusive [start, end] row ranges.

        Args:
            indices: 1-D sequence of frame counts, one per trial.

        Returns:
            numpy.ndarray of shape (n_trials, 2); row k holds the first and the
            last (inclusive) frame index of trial k in the concatenated data.
            An empty input yields an array of shape (0, 2).
        """
        segment_ends = np.cumsum(indices)
        if segment_ends.size == 0:
            return np.empty((0, 2), dtype=segment_ends.dtype)

        # Each trial starts where the previous one ended; the first starts at 0.
        segment_starts = np.insert(segment_ends[:-1], 0, 0)
        return np.column_stack((segment_starts, segment_ends - 1))

    def read_test(self):
        """Debug helper: print the shapes of 'data_diff1' / 'data_fir' in an HDF5 file.

        NOTE(review): this opens ``self.file_path`` directly, which __init__ sets
        to a directory path -- presumably a leftover for .mat inputs; confirm
        before use.
        """
        # Open the HDF5-based .mat file
        with h5py.File(self.file_path, 'r') as f:
            # Read both datasets (arrays of object references)
            data_diff1 = f['data_diff1'][()]
            data_fir = f['data_fir'][()]

            # Inspect the containers
            print(f"data_diff1 的形状: {data_diff1.shape}")
            print(f"data_fir 的形状: {data_fir.shape}")

            # References to the first trial
            trial_diff1_ref = data_diff1[0, 0]
            trial_fir_ref = data_fir[0, 0]

            # Dereference to obtain the actual arrays
            trial_diff1 = f[trial_diff1_ref][()]
            trial_fir = f[trial_fir_ref][()]

            print(f"第一个 trial 的 2D 坐标数据形状: {trial_diff1.shape}")
            print(f"第一个 trial 的神经信号数据形状: {trial_fir.shape}")




