"""
This file contains code from PyTorch Vision (https://github.com/pytorch/vision) which is licensed under BSD 3-Clause License.
These snippets are the Copyright (c) of Soumith Chintala 2016. All other code is the Copyright (c) of the NeuroBench Developers 2023.
"""

from torch.utils.data import Dataset
import os
import torch
import math
import numpy as np
import h5py
from scipy.signal import convolve2d
from urllib.error import URLError
import matplotlib.pyplot as plt
from temporaldata import Data as temp_Data


import stork

import logging

logger = logging.getLogger(__name__)
# The spikes recorded in the Primate Reaching datasets have an interval of 4ms.
# SAMPLING_RATE = 4e-3

import hydra
from omegaconf import DictConfig
from hydra.utils import to_absolute_path
from pathlib import Path
import torch.nn.functional as F


def get_dataloader_POYO(cfg, dtype=torch.float32):
    """Create a ``DatasetLoader_POYO`` from an experiment config.

    Args:
        cfg: Config object exposing the loader options under ``cfg.data``
            (``data_dir``, ``ratio_val``, ``random_val``, ``extend_data``,
            ``sample_duration``, ``remove_segments_inactive``, ``p_drop``,
            ``p_insert``, ``jitter_sigma``, ``dt``, ``padding``) plus the
            top-level ``cfg.predict_value`` and ``cfg.testFlag``.
        dtype (torch.dtype, optional): dtype of the produced datasets.
            Defaults to torch.float32.

    Returns:
        DatasetLoader_POYO: The configured loader.
    """
    data_cfg = cfg.data
    options = {
        "basepath": data_cfg.data_dir,
        "ratio_val": data_cfg.ratio_val,
        "random_val": data_cfg.random_val,
        "extend_data": data_cfg.extend_data,
        "sample_duration": data_cfg.sample_duration,
        "remove_segments_inactive": data_cfg.remove_segments_inactive,
        "p_drop": data_cfg.p_drop,
        "p_insert": data_cfg.p_insert,
        "jitter_sigma": data_cfg.jitter_sigma,
        "dtype": dtype,
        "dt": data_cfg.dt,
        "predict_value": cfg.predict_value,
        "testFlag": cfg.testFlag,
        "padding": data_cfg.padding,
    }
    return DatasetLoader_POYO(**options)


def compute_input_firing_rates(data, cfg, nb_inputs):
    """Average per-channel input firing rate over a dataset.

    Each sample's spike count is divided by ``cfg.data.sample_duration`` and by
    the number of channels in the group, then averaged over all samples.

    Args:
        data: Indexable dataset; ``data[i][0]`` is the spike tensor of sample ``i``
            (assumed shape ``(time, channels)`` -- TODO confirm against caller).
        cfg: Config exposing ``cfg.data.sample_duration`` (seconds per sample).
        nb_inputs (int): If 16, the rate of the first 16 channels is computed.
            Otherwise two rates are computed: channels ``[:96]`` and ``[96:]``,
            both normalised by 96 channels.

    Returns:
        tuple: ``(mean1, mean2)``. ``mean2`` is None when ``nb_inputs == 16``.

    Raises:
        ValueError: If ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("data must contain at least one sample")
    duration = cfg.data.sample_duration

    if nb_inputs == 16:
        mean1 = 0
        for i in range(n_samples):
            mean1 += torch.sum(data[i][0][:, :nb_inputs]) / duration / nb_inputs
        # Normalise once, after all samples have been accumulated.
        return mean1 / n_samples, None

    mean1 = 0
    mean2 = 0
    for i in range(n_samples):
        sample_spikes = data[i][0]
        mean1 += torch.sum(sample_spikes[:, :96]) / duration / 96
        # An empty [96:] slice sums to 0, so no exception handling is needed here.
        mean2 += torch.sum(sample_spikes[:, 96:]) / duration / 96
    return mean1 / n_samples, mean2 / n_samples


class DatasetLoader_POYO:
    """Turn one primate-reaching session into stork ``RasDataset`` train/val/test splits.

    Reading the session file is delegated to ``PrimateReaching``; this class cuts
    the resulting time-major spike/label series into fixed-length samples and
    wraps them as stork ``RasDataset`` objects.
    """

    # Durations (seconds) used when mixing sample lengths (``mix_trial=True`` in
    # ``extend_to_sample``): windows of MIX_TOTAL_DURATION taken across trial
    # boundaries, plus shorter per-trial windows zero-padded up to that length.
    MIX_TOTAL_DURATION = 2
    MIX_SHORT_DURATIONS = (1, 0.7, 0.4)

    def __init__(
            self,
            basepath,
            num_steps=1,
            dt=0.004,
            ratio_val=0.1,
            biological_delay=0,
            spike_sorting=False,
            label_series=False,
            random_val=False,
            extend_data=True,
            sample_duration=2,
            bin_width=None,
            stride=None,
            remove_segments_inactive=False,
            dtype=torch.float32,
            p_drop=0.0,
            p_insert=0.0,
            jitter_sigma=0.0,
            predict_value='velocity',
            testFlag=False,
            mix_continuous_uncontinuous=False,
            padding="zeros",
    ):
        """Store the loader configuration (no data is read here).

        Args:
            basepath (str): Path to the data folder (passed to ``PrimateReaching`` as ``file_path``).
            num_steps (int, optional): Argument for the NeuroBench dataloader. Should be 1. Defaults to 1.
            dt (float, optional): Time step in seconds; 0.004 for the monkey data. Defaults to 0.004.
            ratio_val (float, optional): Ratio for the validation set; stored on the instance. Defaults to 0.1.
            biological_delay (int, optional): Delay of the readout w.r.t. the input. Defaults to 0.
            spike_sorting (bool, optional): If True, single-unit activity, otherwise multi-unit activity.
                Stored on the instance. Defaults to False.
            label_series (bool, optional): NeuroBench argument; leave as is. Defaults to False.
            random_val (bool, optional): If True, validation samples are drawn randomly from the train
                data, otherwise the last samples are taken. Stored on the instance. Defaults to False.
            extend_data (bool, optional): If True, cut overlapping samples (start offsets every 1/10
                of a sample). Defaults to True.
            sample_duration (float, optional): Sample duration in seconds. Defaults to 2.
            bin_width (float, optional): Bin width in seconds; None means ``dt``. Defaults to None.
            stride (float, optional): Stride in seconds; None means ``dt``. Defaults to None.
            remove_segments_inactive (bool, optional): Forwarded to ``PrimateReaching``. Defaults to False.
            dtype (torch.dtype, optional): dtype of the produced datasets. Defaults to torch.float32.
            p_drop (float, optional): Training augmentation parameter forwarded to stork. Defaults to 0.0.
            p_insert (float, optional): Training augmentation parameter forwarded to stork. Defaults to 0.0.
            jitter_sigma (float, optional): Training augmentation parameter forwarded to stork as
                ``sigma_t``. Defaults to 0.0. Augmentation is enabled if any of the three is > 0.
            predict_value (str, optional): Forwarded to ``PrimateReaching``. Defaults to 'velocity'.
            testFlag (bool, optional): Stored on the instance. Defaults to False.
            mix_continuous_uncontinuous (bool, optional): If True, train/val samples mix windows of
                different durations (see ``extend_to_sample``). Defaults to False.
            padding (str, optional): How 96-channel spikes are widened to 192 channels: "zeros" pads
                with zeros, anything else duplicates the channels. Defaults to "zeros".
        """
        self.SAMPLING_RATE = 1 / dt

        self.basepath = basepath
        self.num_steps = num_steps
        self.dt = dt
        self.ratio_val = ratio_val
        self.biological_delay = biological_delay
        self.spike_sorting = spike_sorting
        self.label_series = label_series
        self.random_val = random_val
        self.extend_data = extend_data
        self.sample_duration = sample_duration
        self.remove_segments_inactive = remove_segments_inactive
        self.dtype = dtype
        self.p_drop = p_drop
        self.p_insert = p_insert
        self.jitter_sigma = jitter_sigma
        self.predict_value = predict_value
        self.testFlag = testFlag
        # Hardcoded: test trials are returned as separate (non-concatenated) datasets.
        self.continuous_trial = False
        self.mix_continuous_uncontinuous = mix_continuous_uncontinuous
        self.padding = padding

        self.bin_width = self.dt if bin_width is None else bin_width
        self.stride = self.dt if stride is None else stride

        # Number of time steps per sample.
        self.n_time_steps = int(sample_duration / dt)

    def get_single_session_data(self, filename, nb_inputs=96, zscore=False):
        """Load one session and return its train, val and test stork datasets.

        Args:
            filename (str): Name of the session file inside ``self.basepath``.
            nb_inputs (int, optional): Requested number of input channels. If 192 and the
                session has 96 channels, train/val AND test spikes are widened to 192
                channels (see ``_pad_channels``). Defaults to 96.
            zscore (bool, optional): If True, z-score the labels with their global (not
                per-column) mean and standard deviation. Defaults to False.

        Returns:
            tuple: ``(train_ras_data, val_ras_data, test_ras_data)``. ``test_ras_data`` is a
            single ``RasDataset`` if ``self.continuous_trial`` is True, otherwise a list with
            one ``RasDataset`` per test trial.
        """
        dataset = PrimateReaching(
            file_path=self.basepath,
            filename=filename,
            num_steps=self.num_steps,
            train_ratio=0.5,  # hardcoded here, independent of self.ratio_val
            bin_width=self.bin_width,
            biological_delay=self.biological_delay,
            remove_segments_inactive=self.remove_segments_inactive,
            stride=self.stride,
            predict_value=self.predict_value,
        )
        # The test split currently comes from the same session object.
        dataset_test = dataset

        if zscore:
            # Normalise each distinct labels tensor exactly once (dataset_test may be dataset).
            dataset.labels = (dataset.labels - dataset.labels.mean()) / dataset.labels.std()
            if dataset_test is not dataset:
                dataset_test.labels = (
                    (dataset_test.labels - dataset_test.labels.mean()) / dataset_test.labels.std()
                )

        labels = dataset.labels
        labels_testdat = dataset_test.labels

        self.ind_train = np.array(dataset.ind_train)
        self.ind_val = np.array(dataset.ind_val)
        self.ind_test = np.array(dataset_test.ind_test)

        # Widen train and test spikes identically so every split has the same channel count.
        spikes = self._pad_channels(dataset.samples, nb_inputs)
        if dataset_test is dataset:
            spikes_testdat = spikes
        else:
            spikes_testdat = self._pad_channels(dataset_test.samples, nb_inputs)

        train_data, train_labels, val_data, val_labels = self.extend_to_sample(
            dataset.time_segments, spikes, labels, mix_trial=self.mix_continuous_uncontinuous
        )

        # Augmentation is applied to the training set only, and only if requested.
        if self.p_drop > 0 or self.p_insert > 0 or self.jitter_sigma > 0:
            data_augmentation_kwargs = dict(
                data_augmentation=True,
                p_drop=self.p_drop,
                p_insert=self.p_insert,
                sigma_t=self.jitter_sigma
            )
        else:
            data_augmentation_kwargs = {}

        train_ras_data = self.to_ras(train_data, train_labels, **data_augmentation_kwargs)
        val_ras_data = self.to_ras(val_data, val_labels)

        if self.continuous_trial:
            # All test trials concatenated into a single sample.
            print("test continuous trial")
            test_data, test_labels, _ = self.get_test_spikes(
                self.ind_test, dataset.time_segments, spikes_testdat, labels_testdat, continusFlag=True
            )
            test_ras_data = self.to_ras(test_data, test_labels)
        else:
            # One dataset per test trial (trials may differ in length).
            print("test un-continuous trial")
            test_data, test_labels, _ = self.get_test_spikes(
                self.ind_test, dataset.time_segments, spikes_testdat, labels_testdat, continusFlag=False
            )
            test_ras_data = [
                self.to_ras(trial_spikes.unsqueeze(0), trial_labels.unsqueeze(0))
                for trial_spikes, trial_labels in zip(test_data, test_labels)
            ]

        return train_ras_data, val_ras_data, test_ras_data

    def _pad_channels(self, spikes, nb_inputs):
        """Widen 96-channel spikes to 192 channels when ``nb_inputs == 192``.

        With ``self.padding == "zeros"`` the extra 96 channels are zeros, otherwise
        the 96 channels are duplicated. Any other combination is returned unchanged.
        """
        if nb_inputs != 192 or spikes.shape[1] != 96:
            return spikes
        if self.padding == "zeros":
            print("Padding with zeros...")
            return torch.cat((spikes, torch.zeros_like(spikes)), dim=1)
        print("Padding with copying...")
        return torch.cat((spikes, spikes), dim=1)

    def extend_spikes(self,
                      trial_indx,
                      data_indx,
                      spike,
                      label,
                      chunksize=10,
                      chunks=None,
                      continuous_trial=True,
                      n_time_steps=None,
                      ):
        """Cut the given trials into fixed-length samples.

        For every start offset ``t`` in ``range(0, chunks, chunksize)`` the series,
        shifted by ``t``, is split into consecutive windows of ``n_time_steps`` rows
        and the last piece of every split is dropped.
        NOTE(review): the last piece is dropped even when it is a complete window
        (series length an exact multiple of ``n_time_steps``).

        Args:
            trial_indx: Indices of the trials (rows of ``data_indx``) to use.
            data_indx: Rows are ``[start, end]`` (inclusive) row indices of each trial
                in ``spike``/``label``.
            spike: Time-major spike tensor, indexed along dim 0.
            label: Time-major label tensor (2-D in the per-trial mode).
            chunksize (int): Step between start offsets (overlap granularity).
            chunks (int, optional): Exclusive upper bound of the start offsets; None means
                ``n_time_steps``. A value below ``chunksize`` gives offset 0 only, i.e.
                non-overlapping windows.
            continuous_trial (bool): If True, all rows from the first to the last trial in
                ``trial_indx`` are cut as one series, so windows may cross trial boundaries
                (this also covers rows of trials between them that are not listed).
                Otherwise every trial is cut separately.
            n_time_steps (int, optional): Window length in rows; None means ``self.n_time_steps``.

        Returns:
            tuple: ``(spikes, labels, trials)``; spikes/labels are the stacked windows,
            ``trials`` is ``trial_indx`` (continuous mode) or a tensor with one entry per trial.
        """
        print("continuous_trial: ", continuous_trial)

        if n_time_steps is None:
            n_time_steps = self.n_time_steps
        if chunks is None:
            chunks = n_time_steps

        extended_spikes = []
        extended_labels = []
        extended_trials = []

        if continuous_trial:
            end_idx = int(data_indx[trial_indx[-1], 1])
            start_idx = int(data_indx[trial_indx[0], 0])
            for t in range(0, chunks, chunksize):
                curr_spikes = spike[start_idx + t:end_idx + 1]
                curr_labels = label[start_idx + t:end_idx + 1]

                splitter = np.arange(n_time_steps, curr_spikes.shape[0], n_time_steps)

                extended_spikes += np.split(curr_spikes, splitter)[:-1]
                extended_labels += np.split(curr_labels, splitter)[:-1]

            extended_spikes = torch.stack(extended_spikes)
            extended_labels = torch.stack(extended_labels)
            extended_trials = trial_indx
        else:
            for tmp_trial_indx in trial_indx:
                # Select the rows belonging to this trial.
                tmp_startIndx = int(data_indx[tmp_trial_indx, 0])
                tmp_endIndx = int(data_indx[tmp_trial_indx, 1])
                tmp_trial_spikeindx = spike[tmp_startIndx:tmp_endIndx + 1, :]
                tmp_trial_labelindx = label[tmp_startIndx:tmp_endIndx + 1, :]

                for t in range(0, chunks, chunksize):
                    curr_spikes = tmp_trial_spikeindx[t:]
                    curr_labels = tmp_trial_labelindx[t:]

                    splitter = np.arange(n_time_steps, curr_spikes.shape[0], n_time_steps)

                    extended_spikes += np.split(curr_spikes, splitter)[:-1]
                    extended_labels += np.split(curr_labels, splitter)[:-1]
                extended_trials += [torch.tensor(tmp_trial_indx)]

            extended_spikes = torch.stack(extended_spikes)
            extended_labels = torch.stack(extended_labels)
            extended_trials = torch.stack(extended_trials)

        return extended_spikes, extended_labels, extended_trials

    def get_test_spikes(self,
                        trial_indx,
                        data_indx,
                        spike,
                        label,
                        continusFlag=False
                        ):
        """Collect the complete spike/label series of each test trial.

        Args:
            trial_indx: Indices of the test trials (rows of ``data_indx``).
            data_indx: Rows are ``[start, end]`` (inclusive) row indices of each trial.
            spike: Time-major spike tensor.
            label: Time-major label tensor.
            continusFlag (bool): If True, concatenate all trials in time.

        Returns:
            tuple: ``(test_spikes, test_labels, test_trials)``. With ``continusFlag=False``
            these are lists with one entry per trial. With ``continusFlag=True`` the trials
            are concatenated along time and given a leading axis of size 1, and
            ``test_trials`` is a stacked tensor.
        """
        test_spikes = []
        test_labels = []
        test_trials = []

        for tmp_trial_indx in trial_indx:
            tmp_startIndx = int(data_indx[tmp_trial_indx, 0])
            tmp_endIndx = int(data_indx[tmp_trial_indx, 1])
            test_spikes.append(spike[tmp_startIndx:tmp_endIndx + 1, :])
            test_labels.append(label[tmp_startIndx:tmp_endIndx + 1, :])
            test_trials.append(torch.tensor(tmp_trial_indx))

        if continusFlag:
            test_spikes = torch.cat(test_spikes, dim=0).unsqueeze(0)
            test_labels = torch.cat(test_labels, dim=0).unsqueeze(0)
            test_trials = torch.stack(test_trials)

        return test_spikes, test_labels, test_trials

    def select_brain_area(self, spikes, labels, nb_inputs, with_PMd=False):
        """Select input channels by brain area.

        Args:
            spikes: Tensor indexed as ``spikes[:, :, channels]``.
            labels: Labels matching ``spikes`` along dim 0.
            nb_inputs (int): 192 keeps all channels; 96 selects a half (see below).
            with_PMd (bool): For 96 inputs: if True, the two 96-channel halves are stacked
                as separate samples along dim 0 (labels duplicated); otherwise only channels
                ``[96:]`` are kept.

        Raises:
            ValueError: For any other ``nb_inputs``.
        """
        if nb_inputs == 192:
            logger.info("using all channel data")
        elif nb_inputs == 96:
            if with_PMd:
                logger.info("using PMd channel data and M1 channel data")
                spikes = torch.cat((spikes[:, :, :nb_inputs], spikes[:, :, nb_inputs:]))
                labels = torch.cat((labels, labels))
            else:
                logger.info("using M1 channel data")
                spikes = spikes[:, :, nb_inputs:]
        else:
            raise ValueError(f"Unknown number of inputs {nb_inputs}.")

        return spikes, labels

    def to_ras(self, data, labels, **data_augmentation_kwargs):
        """Convert dense spike samples into a stork ``RasDataset``.

        Args:
            data: Samples indexed as ``data[i][:, channel]``; ``data.shape[-2]`` is used as
                the number of steps and ``data.shape[-1]`` as the number of units. Only
                entries exactly equal to 1 are recorded as spikes.
            labels: Labels passed through to the ``RasDataset``.
            **data_augmentation_kwargs: Extra keyword arguments for ``RasDataset``.

        Returns:
            stork.datasets.RasDataset: One ``[times, units]`` tensor per sample.
        """
        ras_data = [[[], []] for _ in data]

        for i, sample in enumerate(data):
            for j in range(sample.shape[-1]):
                spike_times = np.where(sample[:, j] == 1)[0].tolist()
                ras_data[i][0] += spike_times
                ras_data[i][1] += [j] * len(spike_times)
            ras_data[i] = torch.tensor(ras_data[i], dtype=self.dtype)

        monkey_ds_kwargs = dict(
            nb_steps=data.shape[-2], nb_units=data.shape[-1], time_scale=1.0
        )

        monkey_ds = stork.datasets.RasDataset(
            (ras_data, labels), dtype=self.dtype,
            **monkey_ds_kwargs, **data_augmentation_kwargs
        )

        return monkey_ds

    def _extend_train_val(self, time_segments, spikes, labels, duration, n_time_steps, continuous_trial):
        """Cut the train and val trials into samples of ``n_time_steps`` rows (``duration`` seconds).

        Returns:
            tuple: ``(train_data, train_labels, val_data, val_labels)``.
        """
        # Offset step between overlapping windows: 1/10 of a sample.
        chunksize = int(((1 / self.dt) * duration) / 10)
        if self.extend_data:
            logger.info("Extending data...")
            chunks = None  # all offsets in range(0, n_time_steps, chunksize)
        else:
            chunks = chunksize - 1  # offset 0 only -> non-overlapping windows

        splits = []
        for trial_indices in (self.ind_train, self.ind_val):
            split_data, split_labels, _ = self.extend_spikes(
                trial_indices, time_segments, spikes, labels,
                chunksize=chunksize,
                chunks=chunks,
                continuous_trial=continuous_trial,
                n_time_steps=n_time_steps,
            )
            splits.extend((split_data, split_labels))
        return tuple(splits)

    def extend_to_sample(self, time_segments, spikes, labels, mix_trial=False):
        """Build train and val samples from ``self.ind_train`` / ``self.ind_val``.

        Args:
            time_segments: Rows are ``[start, end]`` (inclusive) row indices of each trial.
            spikes: Time-major spike tensor.
            labels: Time-major label tensor.
            mix_trial (bool): If False, cut samples of ``self.n_time_steps`` steps
                (``self.continuous_trial`` decides whether windows may cross trials).
                If True, combine windows of ``MIX_TOTAL_DURATION`` seconds that may cross
                trial boundaries with per-trial windows of each ``MIX_SHORT_DURATIONS``
                length, zero-padded at the end to the total length. This does not
                modify ``self.sample_duration`` or ``self.n_time_steps``.

        Returns:
            tuple: ``(train_data, train_labels, val_data, val_labels)``.
        """
        if not mix_trial:
            return self._extend_train_val(
                time_segments, spikes, labels,
                duration=self.sample_duration,
                n_time_steps=self.n_time_steps,
                continuous_trial=self.continuous_trial,
            )

        total_n_time_steps = int(self.MIX_TOTAL_DURATION / self.dt)
        # Full-length windows across trial boundaries (no padding needed).
        parts = [self._extend_train_val(
            time_segments, spikes, labels,
            duration=self.MIX_TOTAL_DURATION,
            n_time_steps=total_n_time_steps,
            continuous_trial=True,
        )]
        # Shorter per-trial windows, zero-padded to the full length.
        for duration in self.MIX_SHORT_DURATIONS:
            tr_data, tr_labels, va_data, va_labels = self._extend_train_val(
                time_segments, spikes, labels,
                duration=duration,
                n_time_steps=int(duration / self.dt),
                continuous_trial=False,
            )
            tr_data, tr_labels = self.pad_sequences_to_max(tr_data, tr_labels, total_n_time_steps)
            va_data, va_labels = self.pad_sequences_to_max(va_data, va_labels, total_n_time_steps)
            parts.append((tr_data, tr_labels, va_data, va_labels))

        train_data, train_labels, val_data, val_labels = (
            torch.cat(list(group), dim=0) for group in zip(*parts)
        )
        return train_data, train_labels, val_data, val_labels

    def pad_sequences_to_max(self, spike, label, max_length):
        """Zero-pad samples along the time axis (dim 1) up to ``max_length`` steps.

        Args:
            spike: Tensor of shape (samples, time steps, features).
            label: Tensor of shape (samples, time steps, label dims).
            max_length (int): Target number of time steps. If it is smaller than the
                current length, ``F.pad`` with a negative amount crops instead.

        Returns:
            tuple: ``(padded_spike, padded_label)``.
        """
        pad_length = max_length - spike.shape[1]

        # (0, 0) leaves the last dim untouched; (0, pad_length) pads the end of dim 1.
        padded_spike = F.pad(spike, (0, 0, 0, pad_length))
        padded_label = F.pad(label, (0, 0, 0, pad_length))

        return padded_spike, padded_label


def plot_and_save_labels(labels, predict_value='velocity', save_path='output.png', max_samples=1000):
    """Plot the X and Y label components, save the figure and print statistics.

    Args:
        labels: 2-D torch.Tensor or numpy array; column 0 is plotted as the X
            direction and column 1 as the Y direction.
        predict_value (str): Label type, used in the subplot titles.
        save_path (str): Path of the saved image. Defaults to 'output.png'.
        max_samples (int): Only the first ``max_samples`` rows are plotted
            (truncation, not resampling); the printed statistics use all rows.
    """
    import matplotlib.pyplot as plt
    import torch

    if labels is None:
        print("Label data is empty, cannot plot")
        return

    # detach()/cpu() so tensors that require grad or live on a GPU can be converted.
    labels_np = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else labels

    fig, axes = plt.subplots(2, 1, figsize=(12, 8))

    # X direction
    axes[0].plot(labels_np[:max_samples, 0], 'b-', linewidth=0.8)
    axes[0].set_title(f'{predict_value.capitalize()} - X Direction')
    axes[0].set_ylabel('X Value')
    axes[0].grid(True, alpha=0.3)

    # Y direction
    axes[1].plot(labels_np[:max_samples, 1], 'r-', linewidth=0.8)
    axes[1].set_title(f'{predict_value.capitalize()} - Y Direction')
    axes[1].set_ylabel('Y Value')
    axes[1].set_xlabel('Time Steps')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()

    # Save BEFORE show(): show() may close/clear the figure, which would save a blank image.
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"图片已保存到: {save_path}")
    plt.show()
    plt.close(fig)  # release the figure's memory

    # Summary statistics over all rows.
    print("Label数据统计:")
    print(f"形状: {labels.shape}")
    print(f"X方向 - 均值: {labels_np[:, 0].mean():.4f}, 标准差: {labels_np[:, 0].std():.4f}")
    print(f"Y方向 - 均值: {labels_np[:, 1].mean():.4f}, 标准差: {labels_np[:, 1].std():.4f}")


def plot_spikes(spikes, max_channels=96, max_time=1000, save_path=None):
    """Draw a spike raster plot.

    Args:
        spikes: torch.Tensor or numpy.ndarray of shape (channels, time steps).
            The input is transposed internally to (time steps, channels); only
            entries equal to 1 are drawn as spikes.
        max_channels (int): Maximum number of channels to show.
        max_time (int): Maximum number of time steps to show.
        save_path (str, optional): Where to save the image; if None the figure is shown.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import torch

    # (channels, time) -> (time, channels)
    spikes = spikes.T
    print(spikes.shape)
    if isinstance(spikes, torch.Tensor):
        # detach() so tensors that require grad can be converted.
        spikes = spikes.detach().cpu().numpy()
    spikes = spikes[:max_time, :max_channels]

    plt.figure(figsize=(12, 6))
    for ch in range(spikes.shape[1]):
        spike_times = np.where(spikes[:, ch] == 1)[0]
        plt.vlines(spike_times, ch + 0.5, ch + 1.5, color='black', linewidth=4)
    plt.xlabel('Time step')
    plt.ylabel('Channel')
    plt.title(f'Spike Raster Plot ({spikes.shape[1]} channels)')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图片已保存到: {save_path}")
    else:
        plt.show()
    plt.close()


def remove_zero_columns(matrix):
    """Return ``matrix`` without the columns whose entries are all zero.

    Args:
        matrix: 2-D NumPy ndarray or PyTorch tensor of shape (n_samples, n_features).

    Returns:
        The same kind of object as the input, containing only the columns with
        at least one non-zero entry.

    Raises:
        TypeError: If ``matrix`` is neither a NumPy array nor a PyTorch tensor.
    """
    if torch.is_tensor(matrix):
        keep = (matrix != 0).any(dim=0)
    elif isinstance(matrix, np.ndarray):
        keep = (matrix != 0).any(axis=0)
    else:
        raise TypeError("输入必须是 NumPy 数组或 PyTorch 张量")
    return matrix[:, keep]




class PrimateReaching:
    """Primate Reaching recording loaded from a temporaldata-formatted HDF5 file.

    After construction:
        samples: float tensor of shape (time, features) with binned spikes.
        labels: float tensor of shape (time, label_dims) with cursor kinematics.
            Both are truncated to a common length.
        time_segments, ind_train, ind_val, ind_test: the split of the recording
            derived from the file's predefined train/valid/test domains.
    """

    # Cursor kinematics are upsampled by this factor (100 Hz -> 250 Hz, i.e. a 4 ms grid).
    _LABEL_UPSAMPLE_FACTOR = 2.5
    # Width (number of electrode columns) of the per-electrode unit table.
    _N_OUTPUT_CHANNELS = 192

    def __init__(
            self,
            file_path,
            filename,
            num_steps,
            train_ratio=0.5,
            label_series=False,
            biological_delay=0,
            spike_sorting=False,
            stride=0.004,
            bin_width=0.004,
            max_segment_length=2000,
            split_num=1,
            remove_segments_inactive=False,
            download=True,
            predict_value='velocity',
    ):
        """Load a Primate Reaching recording and prepare binned spikes and labels.

        Args:
            file_path (str): Directory containing the HDF5 (.h5) files.
            filename (str): Name of the file to load; ".h5" is appended when missing.
            num_steps (int): Number of consecutive timesteps per sample (>= 1).
                In the real-time case, this should be 1.
            train_ratio (float): Ratio for ratio-based train/(val+test) splitting.
                Default is 0.5. Stored only; the split actually used comes from
                the file's predefined domains.
            label_series (bool): Whether the labels are series. Useful for training
                with multiple timesteps. Default is False.
            biological_delay (int): Number of steps of delay applied by
                ``apply_delay`` (>= 0). Default is 0, i.e. no delay.
            spike_sorting (bool): Keep individual units as features instead of
                OR-ing them per electrode. Default is False.
            stride (float): Step of the binning grid. Default is 0.004 (4 ms).
            bin_width (float): Width of the binning window. Default is 0.004 (4 ms).
            max_segment_length (int): Upper limit of a segment in data points.
                Default is 2000 (8 s). Stored only; not used by this class.
            split_num (int): Number of chunks to break the time series into.
                Default is 1. Stored only; not used by this class.
            remove_segments_inactive (bool): Stored only; not used by this class.
            download (bool): Accepted for API compatibility; this class never downloads.
            predict_value (str or sequence of str): Target kinematics: 'velocity',
                'position' or 'acceleration'. If a sequence is given, its first
                element is used.
        """
        # Resolve the HDF5 path, appending the ".h5" extension when missing.
        self.filename = filename if filename.endswith(".h5") else filename + ".h5"
        self.file_path = os.path.join(file_path, self.filename)

        self.split_num = split_num
        self.train_ratio = train_ratio
        self.max_segment_length = max_segment_length
        self.remove_segments_inactive = remove_segments_inactive
        self.stride = stride
        self.SAMPLING_RATE = stride
        self.delay = biological_delay

        self.bin_width = bin_width
        # Number of stride-sized steps in one binning window (the convolution length).
        self.ratio = int(np.round(self.bin_width / self.SAMPLING_RATE))
        self.label_series = label_series
        self.num_steps = num_steps
        assert self.num_steps >= 1
        self.spike_sorting = spike_sorting

        # The samples and labels of the dataset
        self.samples = None
        self.labels = None
        # Indices (into time_segments) of the segments in each split
        self.ind_train, self.ind_val, self.ind_test = [], [], []
        self.predict_value = predict_value

        assert self.delay >= 0

        self.load_data()

    def _target_name(self):
        """Return the requested target name ('velocity', 'position', ...).

        ``predict_value`` may be a plain string or a sequence whose first
        element is the target name.
        """
        if isinstance(self.predict_value, str):
            return self.predict_value
        return self.predict_value[0]

    @staticmethod
    def _electrode_from_unit_id(unit_id):
        """Return the integer N from an 'elec<N>' component of a '/'-separated unit id.

        ``unit_id`` may be bytes (decoded as UTF-8) or str. Returns None when
        no component starts with 'elec'.
        """
        if isinstance(unit_id, bytes):
            unit_id = unit_id.decode('utf-8')
        for part in unit_id.split('/'):
            if part.startswith('elec'):
                return int(part[4:])
        return None

    @staticmethod
    def _upsample_labels(labels):
        """Linearly upsample (time, dims) labels by 2.5x along time (100 Hz -> 250 Hz).

        Returns a (dims, int(time * 2.5)) tensor: channel-major, like the
        binned spike array before alignment.
        """
        target_length = int(labels.shape[0] * PrimateReaching._LABEL_UPSAMPLE_FACTOR)
        return F.interpolate(
            labels.T.unsqueeze(0),
            size=target_length,
            mode='linear',
            align_corners=False,
        ).squeeze(0)

    def _load_labels(self, dataset):
        """Read the requested cursor kinematics and upsample them to the spike grid.

        Raises:
            ValueError: if the requested target is not supported.
        """
        target = self._target_name()
        if target == 'velocity':
            raw = dataset.cursor.vel
        elif target in ('acceleration', 'acceleraction'):
            # the misspelt 'acceleraction' is still accepted for existing configs
            raw = dataset.cursor.acc
        elif target == 'position':
            raw = dataset.cursor.pos
        else:
            raise ValueError(f"Unsupported predict_value: {target!r}")
        return self._upsample_labels(torch.from_numpy(raw).float())

    def _bin_spikes(self, dataset):
        """Convert spike timestamps into a binned (features, time) array.

        1. Units are grouped by the electrode parsed from their id.
        2. Each unit's spikes are marked as binary events on a ``self.stride`` grid.
        3. The events are kept per unit (``spike_sorting``) or OR-ed into
           per-electrode channels.
        4. The result is summed over a ``self.ratio``-step sliding window
           ("valid" mode).
        """
        # Bin edges spaced by `stride`, starting `bin_width` before dataset.start
        # (presumably so the first binning window is complete; TODO confirm).
        new_t = np.arange(dataset.start - self.bin_width, dataset.end, self.stride)
        edges = new_t.squeeze()
        units = dataset.units
        # Hoisted: these are read many times below.
        unit_index = dataset.spikes.unit_index
        timestamps = dataset.spikes.timestamps

        # The electrode of the last unit sets the electrode count, rounded up to a multiple of 96.
        last_elec = self._electrode_from_unit_id(units.id[-1])
        if last_elec is None:
            print("未找到匹配的数字")
            raise ValueError(f"No 'elec<N>' component found in unit id {units.id[-1]!r}")
        print(last_elec)
        elec_number = math.ceil(last_elec / 96) * 96

        # Map each unit_number to its electrode. Units without an 'elec' part stay on electrode 0.
        unit_id_to_elec = np.zeros(len(units))
        for unit_id, unit_number in zip(units.id, units.unit_number):
            elec = self._electrode_from_unit_id(unit_id)
            if elec is None:
                print("未找到匹配的数字")
                continue
            unit_id_to_elec[unit_number] = elec

        assert unit_index.max() == len(units) - 1, "unit_index should be equal to the number of units - 1"

        # spikes[u, e] holds unit u's spike times if u sits on electrode e, otherwise None.
        spikes = np.empty((len(units), elec_number), dtype=object)
        for i in range(len(units)):
            spikes[i, int(unit_id_to_elec[i])] = timestamps[np.where(unit_index == i)]

        # Stack the units of each electrode into rows 0..k-1 of that electrode's column.
        filled = np.array([[x is not None for x in row] for row in spikes], dtype=bool)
        max_units_per_elec = filled.sum(axis=0).max()
        grouped = np.empty((max_units_per_elec, self._N_OUTPUT_CHANNELS), dtype=object)
        for col_idx in range(spikes.shape[1]):
            for row_idx, unit in enumerate(np.flatnonzero(filled[:, col_idx])):
                grouped[row_idx, col_idx] = spikes[unit, col_idx]

        spike_train = np.zeros((*grouped.shape, len(new_t)), dtype=np.int8)
        for row_idx, row in enumerate(grouped):
            for col_idx, element in enumerate(row):
                if not isinstance(element, np.ndarray):
                    continue  # empty slot: this electrode has fewer units
                counts, _ = np.histogram(element, bins=edges)
                # np.histogram assigns a spike to the lower edge of its window;
                # +1 shifts it to the upper edge.
                spike_train[row_idx, col_idx, np.nonzero(counts)[0] + 1] = 1

        if self.spike_sorting:
            # Flatten units x channels into a single feature axis.
            spike_train = np.transpose(spike_train, (2, 1, 0)).reshape(new_t.shape[0], -1)
            # Remove features that never fire.
            spike_train = spike_train[:, spike_train.any(axis=0)]
            spike_train = spike_train.transpose()
        else:
            # Combine the units of each electrode into one channel.
            spike_train = np.bitwise_or.reduce(spike_train, axis=0)

        # A sliding-window sum over `ratio` steps implements the binning window.
        if self.ratio != 1:
            return convolve2d(spike_train, np.ones((1, self.ratio)), mode="valid")
        return spike_train

    def load_data(self):
        """Load spikes and cursor labels from the temporaldata HDF5 file.

        Sets ``has_predefined_domains``, ``labels`` and ``samples`` (time-major,
        truncated to a common length), ``time_segments``, ``ind_train``,
        ``ind_val`` and ``ind_test``.
        """
        print(f"Loading {self.filename}")
        # temporaldata may read lazily from the open handle, so every access to
        # `dataset` stays inside this block. The file is closed afterwards.
        with h5py.File(self.file_path, "r") as h5file:
            dataset = temp_Data.from_hdf5(h5file)
            print("Keys in the file:", list(dataset.keys()))

            if (hasattr(dataset, 'train_domain') and hasattr(dataset, 'valid_domain')
                    and hasattr(dataset, 'test_domain')):
                print("Found predefined domains in dataset")
                self.has_predefined_domains = True
            else:
                print("No predefined domains found, will use ratio-based splitting")
                self.has_predefined_domains = False

            self.labels = self._load_labels(dataset)
            self.samples = torch.from_numpy(self._bin_spikes(dataset)).float()

            # Both arrays are (features, time) here: align the time axes, then
            # make them time-major.
            min_length = min(self.samples.shape[1], self.labels.shape[1])
            self.samples = self.samples[:, :min_length].T
            self.labels = self.labels[:, :min_length].T

            # NOTE(review): split_into_segments needs train/valid/test domains.
            # Ratio-based splitting is not implemented, so files without them fail here.
            self.time_segments, self.ind_train, self.ind_val, self.ind_test = \
                self.split_into_segments(dataset)

    def apply_delay(self):
        """Shift labels relative to samples by ``self.delay`` steps.

        ``load_data`` leaves both tensors time-major, (time, features), so the
        shift is applied along axis 0. The last ``delay`` samples and the first
        ``delay`` labels are dropped, which pairs sample t with label t + delay.
        A delay of 0 is a no-op; slicing with ``-0`` would otherwise empty the tensor.
        """
        if self.delay == 0:
            return
        self.samples = self.samples[:-self.delay]
        self.labels = self.labels[self.delay:]

    @staticmethod
    def split_into_segments(dataset, dt=0.004):
        """Convert the train/valid/test domains into step-index segments.

        Args:
            dataset: object exposing ``train_domain``, ``valid_domain`` and
                ``test_domain``, each with 1-D ``start`` / ``end`` arrays.
            dt (float): step size used to convert times into indices. Default is
                0.004, matching the default 4 ms stride.

        Returns:
            tuple: ``(time_segments, ind_train, ind_val, ind_test)``.
                ``time_segments`` is a float array of shape (n_segments, 2)
                holding ``floor(start / dt)`` and ``floor(end / dt)``, ordered
                train, then validation, then test. The three lists give the rows
                of ``time_segments`` that belong to each split.
        """
        domains = (dataset.train_domain, dataset.valid_domain, dataset.test_domain)
        starts = [np.asarray(d.start) for d in domains]
        ends = [np.asarray(d.end) for d in domains]

        time_segments = np.stack(
            [np.floor(np.concatenate(starts) / dt), np.floor(np.concatenate(ends) / dt)],
            axis=1,
        ).astype(np.float64)

        n_train, n_val, n_test = (s.shape[0] for s in starts)
        ind_train = list(np.arange(n_train))
        ind_val = list(np.arange(n_val) + n_train)
        ind_test = list(np.arange(n_test) + n_train + n_val)
        return time_segments, ind_train, ind_val, ind_test

    def read_test(self):
        """Debug helper: print the shapes of the first trial in a MATLAB (HDF5-based) .mat file.

        Expects ``data_diff1`` and ``data_fir`` cell arrays of object
        references in ``self.file_path``. It only prints.
        """
        with h5py.File(self.file_path, 'r') as f:
            data_diff1 = f['data_diff1'][()]  # first cell array
            data_fir = f['data_fir'][()]  # second cell array

            # author's note: both expected to be (1, 232)
            print(f"data_diff1 的形状: {data_diff1.shape}")
            print(f"data_fir 的形状: {data_fir.shape}")

            # Cells hold object references; dereference the first trial of each.
            trial_diff1_ref = data_diff1[0, 0]
            trial_fir_ref = data_fir[0, 0]
            trial_diff1 = f[trial_diff1_ref][()]  # first trial: 2-D coordinate data
            trial_fir = f[trial_fir_ref][()]  # first trial: neural signal data

            print(f"第一个 trial 的 2D 坐标数据形状: {trial_diff1.shape}")
            print(f"第一个 trial 的神经信号数据形状: {trial_fir.shape}")

    def plot_labels(self, save_path=None, max_samples=1000):
        """Plot the X and Y components of ``self.labels`` and print summary statistics.

        Args:
            save_path: if given, save the figure there (300 dpi); otherwise show it.
            max_samples: longer label series are subsampled with a fixed step.
                The plot and the printed mean/std both use the subsampled data.
        """
        if self.labels is None:
            print("Label data is empty, cannot plot")
            return

        if isinstance(self.labels, torch.Tensor):
            labels_np = self.labels.detach().cpu().numpy()
        else:
            labels_np = self.labels

        if labels_np.shape[0] > max_samples:
            labels_np = labels_np[::labels_np.shape[0] // max_samples]

        target = self._target_name().capitalize()
        fig, axes = plt.subplots(2, 1, figsize=(12, 8))

        axes[0].plot(labels_np[:, 0], 'b-', linewidth=0.8)
        axes[0].set_title(f'{target} - X Direction')
        axes[0].set_ylabel('X Value')
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(labels_np[:, 1], 'r-', linewidth=0.8)
        axes[1].set_title(f'{target} - Y Direction')
        axes[1].set_ylabel('Y Value')
        axes[1].set_xlabel('Time Steps')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Image saved to: {save_path}")
        else:
            plt.show()
        plt.close(fig)  # release the figure's memory

        print(f"Label data statistics:")
        print(f"Shape: {self.labels.shape}")
        print(f"X direction - Mean: {labels_np[:, 0].mean():.4f}, Std: {labels_np[:, 0].std():.4f}")
        print(f"Y direction - Mean: {labels_np[:, 1].mean():.4f}, Std: {labels_np[:, 1].std():.4f}")

    @staticmethod
    def plot_spikes(spikes, max_time=1000):
        """Draw a spike raster plot and show it.

        Args:
            spikes: IrregularTimeSeries-like object with ``timestamps`` and
                ``unit_index`` attributes.
            max_time: only spikes with timestamp <= max_time are drawn (same
                unit as the timestamps).
        """
        timestamps = np.array(spikes.timestamps)
        unit_index = np.array(spikes.unit_index)
        keep = timestamps <= max_time
        plt.figure(figsize=(10, 4))
        plt.scatter(timestamps[keep], unit_index[keep], s=2, color='black')
        plt.xlabel('Time (s)')
        plt.ylabel('Unit')
        plt.title('Spike Raster Plot')
        plt.xlim(0, max_time)
        plt.tight_layout()
        plt.show()

