"""
This file contains code from PyTorch Vision (https://github.com/pytorch/vision) which is licensed under BSD 3-Clause License.
These snippets are the Copyright (c) of Soumith Chintala 2016. All other code is the Copyright (c) of the NeuroBench Developers 2023.
"""


from torch.utils.data import Dataset
import os
import torch
import math
import numpy as np
import h5py
from scipy.signal import convolve2d
from urllib.error import URLError
import matplotlib.pyplot as plt


import stork
from neurobench.datasets.primate_reaching import PrimateReaching
from neurobench.datasets.utils import download_url


import logging
logger = logging.getLogger(__name__)
# The spikes recorded in the Primate Reaching datasets have an interval of 4ms.
# SAMPLING_RATE = 4e-3

def get_dataloader_Mice(cfg, dtype=torch.float32):
    """Create a ``DatasetLoader_Mice`` configured from ``cfg.data``.

    Args:
        cfg: Configuration object exposing a ``data`` attribute with the
            fields ``data_dir``, ``ratio_val``, ``random_val``,
            ``extend_data``, ``sample_duration``, ``remove_segments_inactive``,
            ``p_drop``, ``p_insert``, ``jitter_sigma`` and ``dt``.
        dtype: Torch dtype forwarded to the loader. Defaults to
            ``torch.float32``.

    Returns:
        DatasetLoader_Mice: The configured loader (no data is read here).
    """
    data_cfg = cfg.data

    # Map every loader keyword to the config field it is read from.
    loader_kwargs = {
        "basepath": data_cfg.data_dir,
        "ratio_val": data_cfg.ratio_val,
        "random_val": data_cfg.random_val,
        "extend_data": data_cfg.extend_data,
        "sample_duration": data_cfg.sample_duration,
        "remove_segments_inactive": data_cfg.remove_segments_inactive,
        "p_drop": data_cfg.p_drop,
        "p_insert": data_cfg.p_insert,
        "jitter_sigma": data_cfg.jitter_sigma,
        "dtype": dtype,
        "dt": data_cfg.dt,
    }

    return DatasetLoader_Mice(**loader_kwargs)

def compute_input_firing_rates(data, cfg, nb_inputs):
    """Return the mean input firing rate(s), averaged over all samples.

    For every sample the spike counts are summed, divided by
    ``cfg.data.sample_duration`` and by the number of channels in the group,
    and the per-sample rates are then averaged over ``len(data)``.

    Args:
        data: Indexable collection with a length; ``data[i][0]`` must be a
            2-D tensor whose second axis indexes input channels.
        cfg: Configuration object exposing ``cfg.data.sample_duration``.
        nb_inputs (int): Number of input channels. ``16`` selects the
            single-group path; any other value uses two groups of 96
            channels (columns ``[:96]`` and ``[96:]``).

    Returns:
        tuple: ``(mean1, mean2)``. ``mean2`` is ``None`` unless the first
        sample has exactly 192 columns.

    Raises:
        ValueError: If ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("data must contain at least one sample")

    duration = cfg.data.sample_duration

    # Samples are fetched by index (not by iteration) because ``data`` may be
    # a map-style dataset that only implements __len__ / __getitem__.
    if nb_inputs == 16:
        mean1 = 0
        for i in range(n_samples):
            mean1 += torch.sum(data[i][0][:, :nb_inputs]) / duration / nb_inputs
        # Average only after ALL samples were accumulated.
        mean1 /= n_samples
        return mean1, None

    mean1 = 0
    mean2 = 0
    for i in range(n_samples):
        inputs = data[i][0]
        mean1 += torch.sum(inputs[:, :96]) / duration / 96
        # Slicing past the last column yields an empty tensor whose sum is 0,
        # so this is safe for inputs with only 96 channels.
        mean2 += torch.sum(inputs[:, 96:]) / duration / 96

    mean1 /= n_samples
    mean2 /= n_samples

    # 192 channels: two groups of 96, report both rates.
    if data[0][0].shape[1] == 192:
        return mean1, mean2

    # Otherwise only the first group is meaningful.
    return mean1, None

class DatasetLoader_Mice:
    """Load a recorded session and expose it as ``stork`` RAS datasets.

    The raw recording is read through ``PrimateReaching_Mice``, re-split into
    train / validation / test index sets, cut into fixed-length samples of
    ``n_time_steps`` steps and wrapped in ``stork.datasets.RasDataset``.
    """

    def __init__(
            self,
            basepath,
            num_steps=1,
            dt=0.004,
            ratio_val=0.1,
            biological_delay=0,
            spike_sorting=False,
            label_series=False,
            random_val=False,
            extend_data=True,
            sample_duration=2,
            bin_width=None,
            stride=None,
            remove_segments_inactive=False,
            dtype=torch.float32,
            p_drop=0.0,
            p_insert=0.0,
            jitter_sigma=0.0,
    ):
        """Store the loading configuration; no data is read here.

        Args:
            basepath (str): Path to the folder containing the data files.
            num_steps (int, optional): Forwarded to the underlying dataset.
                Defaults to 1.
            dt (float, optional): Time step in seconds. Defaults to 0.004.
            ratio_val (float, optional): Ratio used to size the validation
                split. Defaults to 0.1.
            biological_delay (int, optional): Forwarded to the underlying
                dataset. Defaults to 0.
            spike_sorting (bool, optional): Stored on the instance only.
                Defaults to False.
            label_series (bool, optional): Stored on the instance only.
                Defaults to False.
            random_val (bool, optional): If True, the validation block starts
                at a randomly drawn position inside the train+val data,
                otherwise the last samples are used. Defaults to False.
            extend_data (bool, optional): If True, samples are cut at several
                start offsets so that they overlap. Defaults to True.
            sample_duration (float, optional): Sample duration in seconds.
                Defaults to 2.
            bin_width (float, optional): Defaults to ``dt`` when None.
            stride (float, optional): Defaults to ``dt`` when None.
            remove_segments_inactive (bool, optional): Forwarded to the
                underlying dataset. Defaults to False.
            dtype (torch.dtype, optional): dtype of the produced datasets.
                Defaults to torch.float32.
            p_drop (float, optional): Augmentation parameter forwarded to the
                training ``RasDataset``. Defaults to 0.0.
            p_insert (float, optional): Augmentation parameter forwarded to
                the training ``RasDataset``. Defaults to 0.0.
            jitter_sigma (float, optional): Forwarded to the training
                ``RasDataset`` as ``sigma_t``. Defaults to 0.0.
        """
        self.SAMPLING_RATE = 1 / dt

        self.basepath = basepath
        self.num_steps = num_steps
        self.dt = dt
        self.ratio_val = ratio_val
        self.biological_delay = biological_delay
        self.spike_sorting = spike_sorting
        self.label_series = label_series
        self.random_val = random_val
        self.extend_data = extend_data
        self.sample_duration = sample_duration
        self.remove_segments_inactive = remove_segments_inactive
        self.dtype = dtype
        self.p_drop = p_drop
        self.p_insert = p_insert
        self.jitter_sigma = jitter_sigma

        # Bin width and stride fall back to the time step when not given.
        self.bin_width = self.dt if bin_width is None else bin_width
        self.stride = self.dt if stride is None else stride

        # Number of time steps in one sample.
        self.n_time_steps = int(sample_duration / dt)

    def get_single_session_data_Mice(self, filename):
        """Load one session and return its train, validation and test data.

        Args:
            filename (str): Name of the data file inside ``basepath``.

        Returns:
            tuple of stork RasDatasets: Train, val and test datasets. The
            test dataset holds the whole test span as one single sample.
        """
        dataset = PrimateReaching_Mice(
            file_path=self.basepath,
            filename=filename,
            num_steps=self.num_steps,
            train_ratio=0.5,  # Hardcoded here for 25 % test split
            bin_width=self.bin_width,
            biological_delay=self.biological_delay,
            remove_segments_inactive=self.remove_segments_inactive,
            stride=self.dt,
        )

        # Merge the dataset's train and validation indices (75 %) and make
        # our own validation split from them.
        ind_tv = dataset.ind_train + dataset.ind_val

        # Effective validation ratio = ratio_val / 0.75
        eff_ratio_val = self.ratio_val / 0.75

        # Size of the validation set, in time steps.
        # NOTE(review): n_val == 0 makes ``ind_tv[:-n_val]`` empty — confirm
        # that ratio_val is always large enough.
        n_val = int(np.round(dataset.samples.shape[0] * eff_ratio_val))

        if self.random_val:
            # Draw a random start and take n_val consecutive entries as
            # validation data; everything else is training data.
            start_idx = np.random.choice(a=ind_tv[:-n_val], size=1)[0]
            ind_val = np.array(ind_tv[start_idx: start_idx + n_val])
            ind_train = np.array(sorted(set(ind_tv) - set(ind_val)))
        else:
            # Sequential split: the last n_val entries are validation data.
            ind_train = np.array(ind_tv[:-n_val])
            ind_val = np.array(sorted(set(ind_tv) - set(ind_train)))

        spikes = dataset.samples
        labels = dataset.labels

        self.ind_train = ind_train
        self.ind_val = ind_val
        self.ind_test = dataset.ind_test

        # Split into train, val and test.
        spikes_train, labels_train = spikes[ind_train], labels[ind_train]
        spikes_val, labels_val = spikes[ind_val], labels[ind_val]
        spikes_test = spikes[dataset.ind_test]
        labels_test = labels[dataset.ind_test]

        # Cut train and val data into samples of n_time_steps steps each.
        if self.extend_data:
            logger.info("Extending data...")
        train_data, train_labels = self._cut_into_samples(
            spikes_train, labels_train
        )
        val_data, val_labels = self._cut_into_samples(spikes_val, labels_val)

        # The test span stays in one piece: a batch holding a single sample.
        test_data = torch.stack([spikes_test])
        test_labels = torch.stack([labels_test])

        # Augmentation is only requested when at least one knob is non-zero,
        # and only for the training dataset.
        if any([self.p_drop > 0, self.p_insert > 0, self.jitter_sigma > 0]):
            data_augmentation_kwargs = dict(
                data_augmentation=True,
                p_drop=self.p_drop,
                p_insert=self.p_insert,
                sigma_t=self.jitter_sigma
            )
        else:
            data_augmentation_kwargs = {}

        # Convert everything to RAS datasets.
        train_ras_data = self.to_ras(train_data, train_labels,
                                     **data_augmentation_kwargs)
        val_ras_data = self.to_ras(val_data, val_labels)
        test_ras_data = self.to_ras(test_data, test_labels)

        return train_ras_data, val_ras_data, test_ras_data

    def _cut_into_samples(self, spikes, labels):
        """Cut a continuous recording into samples according to ``extend_data``."""
        if self.extend_data:
            # Overlapping samples: restart the cutting at several offsets
            # spread over one sample length.
            return self.extend_spikes(
                spikes,
                labels,
                chunks=self.n_time_steps,
                chunksize=int((1 / self.dt) / 2.5),
            )
        # Non-overlapping samples: a single pass starting at offset 0.
        # Arguments are passed by keyword; giving ``chunks`` both
        # positionally and by keyword raises a TypeError.
        return self.extend_spikes(spikes, labels, chunks=1)

    def extend_spikes(self, spikes, labels, chunks="all", chunksize=100):
        """Cut [time x neuron] data into samples of ``n_time_steps`` steps.

        Cutting is repeated for every start offset in
        ``range(0, chunks, chunksize)``, which yields overlapping samples
        when more than one offset is visited.

        Args:
            spikes (torch.Tensor): Spike data, time along the first axis.
            labels (torch.Tensor): Labels, time along the first axis.
            chunks (int or "all", optional): Exclusive upper bound for the
                start offsets; "all" means ``n_time_steps``.
            chunksize (int, optional): Step between start offsets.
                Defaults to 100.

        Returns:
            tuple: ``(spikes, labels)`` stacked to
            [samples x n_time_steps x ...].
        """
        if chunks == "all":
            chunks = self.n_time_steps

        window = self.n_time_steps
        extended_spikes = []
        extended_labels = []

        for offset in range(0, chunks, chunksize):
            curr_spikes = spikes[offset:]
            curr_labels = labels[offset:]

            # Take consecutive windows; the trailing window is always
            # discarded, even when it happens to be full length.
            for start in range(0, curr_spikes.shape[0] - window, window):
                extended_spikes.append(curr_spikes[start:start + window])
                extended_labels.append(curr_labels[start:start + window])

        # Stack the list of windows into one tensor each.
        extended_spikes = torch.stack(extended_spikes)
        extended_labels = torch.stack(extended_labels)

        return extended_spikes, extended_labels

    def to_ras(self, data, labels, **data_augmentation_kwargs):
        """Convert dense [samples x time x unit] data to a stork ``RasDataset``.

        Args:
            data (torch.Tensor): Dense spike data; entries equal to 1 are
                treated as spikes.
            labels (torch.Tensor): Labels passed through to the dataset.
            **data_augmentation_kwargs: Extra keyword arguments forwarded to
                ``stork.datasets.RasDataset``.

        Returns:
            stork.datasets.RasDataset: Dataset built from the spike
            (time, unit) lists and ``labels``.
        """
        # One [times, units] pair of lists per sample.
        ras_data = [[[], []] for _ in data]

        for i, sample in enumerate(data):
            for j in range(sample.shape[-1]):
                spike_times = np.where(sample[:, j] == 1)[0].tolist()
                ras_data[i][0] += spike_times
                ras_data[i][1] += [j] * len(spike_times)
            ras_data[i] = torch.tensor(ras_data[i], dtype=self.dtype)

        monkey_ds_kwargs = dict(
            nb_steps=data.shape[-2], nb_units=data.shape[-1], time_scale=1.0
        )

        monkey_ds = stork.datasets.RasDataset(
            (ras_data, labels), dtype=self.dtype,
            **monkey_ds_kwargs, **data_augmentation_kwargs
        )

        return monkey_ds
























class PrimateReaching_Mice():
    """Read one ``.mat`` recording and split its time steps into subsets.

    After construction the instance exposes ``samples`` and ``labels`` as
    float tensors and ``ind_train`` / ``ind_val`` / ``ind_test`` as lists of
    time-step indices.
    """

    def __init__(
            self,
            file_path,
            filename,
            num_steps,
            train_ratio=0.5,
            label_series=False,
            biological_delay=0,
            spike_sorting=False,
            stride=0.004,
            bin_width=0.028,
            max_segment_length=2000,
            split_num=1,
            remove_segments_inactive=False,
            download=True,
    ):
        """Load the file and compute the train/validation/test indices.

        Args:
            file_path (str): Folder containing the data file.
            filename (str): File name; ``.mat`` is appended when missing.
            num_steps: Accepted for interface compatibility; unused here.
            train_ratio (float, optional): Fraction of the segments of each
                chunk assigned to training. Defaults to 0.5.
            label_series: Accepted for interface compatibility; unused here.
            biological_delay: Accepted for interface compatibility; unused.
            spike_sorting: Accepted for interface compatibility; unused here.
            stride (float, optional): Stored as both ``stride`` and
                ``SAMPLING_RATE``. Defaults to 0.004.
            bin_width: Accepted for interface compatibility; unused here.
            max_segment_length: Accepted for interface compatibility; unused.
            split_num (int, optional): Number of chunks the segments are
                divided into before splitting. Defaults to 1.
            remove_segments_inactive: Accepted for interface compatibility;
                unused here.
            download: Accepted for interface compatibility; unused here.
        """
        # Resolve the data path.
        self.filename = filename if filename.endswith(".mat") else filename + ".mat"
        self.file_path = os.path.join(file_path, self.filename)

        # The recording is divided into this many equally sized segments.
        self.segmentBinNum = 20
        # Honour the constructor arguments; their defaults (1 and 0.5) are
        # the values that used to be hard-coded here.
        self.split_num = split_num
        self.train_ratio = train_ratio
        self.stride = stride
        self.SAMPLING_RATE = stride

        # The samples and labels of the dataset
        self.samples = None
        self.labels = None
        # These lists store the time-step indices that belong to the
        # training/validation/test set
        self.ind_train, self.ind_val, self.ind_test = [], [], []

        self.load_data()
        # Every segment is considered valid.
        self.valid_segments = np.arange(self.time_segments.shape[0])
        self.split_data()

    def load_data(self):
        """Load samples and labels from the file and define the segments.

        Reads the datasets ``startTime``, ``endTime``, ``preData2`` and
        ``Fir_ch``. ``preData2`` is min-max normalised with its global
        minimum and maximum and becomes ``labels``; ``Fir_ch`` becomes
        ``samples``.
        """
        print(f"Loading {self.filename}")
        # ``[()]`` copies each dataset into memory, so the file can be closed
        # as soon as the reads are done.
        with h5py.File(self.file_path, "r") as dataset:
            print("Keys in the file:", list(dataset.keys()))
            startTime = dataset["startTime"][()]
            endTime = dataset["endTime"][()]
            preData = dataset["preData2"][()]
            Fir_ch = dataset["Fir_ch"][()]

        # Min-max normalisation of the labels over the whole array.
        min_vals = preData.min()
        max_vals = preData.max()
        preData = (preData - min_vals) / (max_vals - min_vals)

        # Define the segments' start & end indices.
        durationTime = endTime - startTime
        self.segmentBins = durationTime / self.segmentBinNum
        n_timesteps = preData.shape[0]
        self.start_end_indices = self._segment_starts(
            n_timesteps, self.segmentBinNum
        )
        # One [start, end] row per segment.
        self.time_segments = np.array(
            self.split_into_segments(self.start_end_indices, n_timesteps)
        )

        # Convert to float tensors.
        # NOTE(review): both arrays are later indexed by time step along
        # axis 0 — confirm the on-disk layout is time-first.
        self.samples = torch.from_numpy(Fir_ch).float()
        self.labels = torch.from_numpy(preData).float()

    @staticmethod
    def _segment_starts(n_timesteps, n_segments):
        """Return the integer start index of ``n_segments`` equal segments.

        Boundaries are computed in floating point and rounded afterwards.
        Passing a non-integer step to ``np.arange`` together with an integer
        dtype truncates the step, which makes the boundaries drift, and
        ``np.rint`` on its own returns floats that cannot be used as indices.
        """
        boundaries = np.arange(n_segments) * n_timesteps / n_segments
        return np.rint(boundaries).astype(np.int64)

    def split_data(self):
        """Split segments into training/validation/test set.

        The segments are divided into ``split_num`` chunks. Inside each chunk
        the first ``train_ratio`` share of segments goes to training, half of
        the remainder to validation and the rest to testing. The time-step
        indices of every segment are appended to the matching list.
        """
        split_num = self.split_num
        total_segments = self.time_segments.shape[0]
        # Number of segments in each chunk.
        sub_length = int(total_segments / split_num)
        # Step between indices, expressed relative to the sampling rate.
        stride = int(self.stride / self.SAMPLING_RATE)

        train_len = math.floor(self.train_ratio * sub_length)
        val_len = math.floor((sub_length - train_len) / 2)

        for split_no in range(split_num):
            for i in range(sub_length):
                # ``valid_segments`` holds indices over all segments, so the
                # membership test uses the global segment index.
                segment = split_no * sub_length + i
                if segment not in self.valid_segments:
                    continue

                start, stop = self.time_segments[segment]
                indices = list(np.arange(start, stop, stride))

                if i < train_len:
                    self.ind_train += indices
                elif i < train_len + val_len:
                    self.ind_val += indices
                else:
                    self.ind_test += indices

    @staticmethod
    def split_into_segments(indices, last_idx):
        """Combine the start and end index into a NumPy array.

        Args:
            indices: Start index of every segment.
            last_idx: End index of the last segment.

        Returns:
            np.ndarray: Array of shape (len(indices), 2) holding
            ``[start, end]`` rows; each end is the next segment's start.
        """
        indices = np.append(indices, [last_idx])
        start_end = np.array([indices[:-1], indices[1:]])

        return np.transpose(start_end)
