# Data Loaders for NeuroBench Challenge

import numpy as np
import torch
import stork
from neurobench.datasets.primate_reaching import PrimateReaching
from neurobench.datasets.utils import download_url
from urllib.error import URLError

import os
import sys

import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): despite the name, 4e-3 equals the default `dt` (0.004) of
# DatasetLoader_foundation below, so this looks like a sampling *period* in
# seconds rather than a rate. It is not referenced in the visible part of this
# file -- confirm external usage before renaming.
SAMPLING_RATE = 4e-3



def get_dataloader_foundation(cfg, dtype=torch.float32):
    """Build a ``DatasetLoader_foundation`` from a configuration object.

    Args:
        cfg: Configuration object read via attribute access. ``cfg.data``
            must provide ``data_dir``, ``ratio_val``, ``random_val``,
            ``extend_data``, ``sample_duration``, ``remove_segments_inactive``,
            ``p_drop``, ``p_insert``, ``jitter_sigma`` and ``padding``;
            ``cfg.testFlag`` is required as well. The optional flags
            ``cfg.model.output_feedback``, ``cfg.model.self_and_crossAttention``
            and ``cfg.session_classfication`` count as ``False`` when they are
            missing or falsy.
        dtype: dtype forwarded to the loader. Defaults to ``torch.float32``.

    Returns:
        DatasetLoader_foundation: The configured loader.
    """
    model_cfg = cfg.model

    # The first truthy flag wins (output_feedback before
    # self_and_crossAttention); the trailing ``or False`` normalises any falsy
    # leftover (None, 0, ...) to a plain False.
    output_feedback = (
        getattr(model_cfg, "output_feedback", False)
        or getattr(model_cfg, "self_and_crossAttention", False)
        or False
    )
    session_classfication = getattr(cfg, "session_classfication", False) or False

    data_cfg = cfg.data
    return DatasetLoader_foundation(
        basepath=data_cfg.data_dir,
        ratio_val=data_cfg.ratio_val,
        random_val=data_cfg.random_val,
        extend_data=data_cfg.extend_data,
        sample_duration=data_cfg.sample_duration,
        remove_segments_inactive=data_cfg.remove_segments_inactive,
        p_drop=data_cfg.p_drop,
        p_insert=data_cfg.p_insert,
        jitter_sigma=data_cfg.jitter_sigma,
        dtype=dtype,
        output_feedback=output_feedback,
        session_classfication=session_classfication,
        testFlag=cfg.testFlag,
        padding=data_cfg.padding,
    )


# def compute_input_firing_rates(data, cfg, nb_inputs):
#
#     mean1 = 0
#     mean2 = 0
#
#     if hasattr(cfg.model, "output_feedback") and cfg.model.output_feedback:
#         output_feedback=cfg.model.output_feedback
#     else:
#         output_feedback=False
#
#     if nb_inputs==16:
#         for i in range(len(data)):
#             mean1 += torch.sum(data[i][0][:, :nb_inputs]) / cfg.data.sample_duration / nb_inputs
#             mean1 /= len(data)
#             return mean1, None
#     else:
#         for i in range(len(data)):
#             mean1 += torch.sum(data[i][0][:, :96]) / cfg.data.sample_duration / 96
#             try:
#                 mean2 += torch.sum(data[i][0][:, 96:193]) / cfg.data.sample_duration / 96
#             except:
#                 continue
#
#         mean1 /= len(data)
#         mean2 /= len(data)
#
#         if output_feedback:
#             mean3 = 0
#             for i in range(len(data)):
#                 mean3 += torch.sum(data[i][1]) / cfg.data.sample_duration / 2
#             mean3 /= len(data)
#
#
#             # For LOCO
#             if hasattr(cfg.data, "data_len") and cfg.data.data_len == 192:
#                 return mean1, mean2, mean3
#             # elif data[0][0].shape[1] == 192:
#             elif data[0][0].shape[1] >= 192:
#                 return mean1, mean2, mean3
#             # FOR INDY
#             else:
#                 return mean1, None, mean3
#
#
#         else:
#             # Pre-training only uses mean1 for initialisation
#             if len(data)>12000:
#                 return mean1, None
#             # For LOCO
#             if hasattr(cfg.data,"data_len") and cfg.data.data_len == 192:
#                 return mean1, mean2
#             # elif data[0][0].shape[1] == 192:
#             elif data[0][0].shape[1] >= 192:
#                 return mean1, mean2
#
#             # FOR INDY
#             else:
#                 return mean1, None

def compute_input_firing_rates(data, cfg, nb_inputs):
    """Compute mean input firing rates over a dataset.

    Every element of ``data`` is indexed as ``sample[0]`` (a 2-D tensor whose
    second axis holds the input channels) and, when output feedback is
    enabled, ``sample[1]`` (the feedback signal). A rate is the summed
    activity divided by ``cfg.data.sample_duration``, by the number of
    channels and by the number of contributing samples.

    Args:
        data: Non-empty, indexable collection of samples as described above.
        cfg: Configuration object. ``cfg.data.sample_duration`` is required;
            ``cfg.model.output_feedback`` and ``cfg.data.data_len`` are
            optional.
        nb_inputs: Number of input channels. The value 16 selects a dedicated
            path that averages over the first ``nb_inputs`` channels only.

    Returns:
        tuple: ``(mean1, mean2)``, or ``(mean1, mean2, mean3)`` when
        ``cfg.model.output_feedback`` is truthy.

        * ``mean1`` -- rate of channels ``0..95`` (``0..nb_inputs-1`` on the
          16-input path).
        * ``mean2`` -- rate of channels ``96..191``, averaged over the samples
          that have more than 96 channels. It is ``None`` unless
          ``cfg.data.data_len == 192`` or the first sample has at least 192
          channels, and it is also ``None`` on the 16-input path and for
          datasets with more than 12000 samples without output feedback.
          It is the plain integer ``0`` when no sample has more than 96
          channels.
        * ``mean3`` -- rate of the feedback signal, normalised by 2 channels.

    Raises:
        ValueError: If ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("data must contain at least one sample")

    output_feedback = hasattr(cfg.model, "output_feedback") and cfg.model.output_feedback
    sample_duration = cfg.data.sample_duration

    # Fast path for 16 inputs: stack everything and normalise once.
    if nb_inputs == 16:
        inputs_data = torch.stack([d[0][:, :nb_inputs] for d in data])
        mean1 = inputs_data.sum() / (sample_duration * nb_inputs * n_samples)
        return mean1, None

    device = data[0][0].device
    first_channel_sums = torch.zeros(n_samples, device=device)
    second_channel_sums = torch.zeros(n_samples, device=device)
    n_with_second_array = 0

    for i, sample in enumerate(data):
        inputs = sample[0]
        first_channel_sums[i] = inputs[:, :96].sum()

        # Slicing never raises, so no try/except is needed. The upper bound is
        # 192 (not 193) so that exactly the 96 channels used for the
        # normalisation below are summed; any extra column appended after
        # channel 191 is excluded.
        if inputs.shape[1] > 96:
            second_channel_sums[i] = inputs[:, 96:192].sum()
            n_with_second_array += 1

    per_array_norm = sample_duration * 96
    mean1 = first_channel_sums.sum() / (per_array_norm * n_samples)
    if n_with_second_array:
        mean2 = second_channel_sums.sum() / (per_array_norm * n_with_second_array)
    else:
        mean2 = 0

    # ``a and b or c`` in the original groups as ``(a and b) or c``; spelled
    # out here so the precedence is explicit.
    has_second_array = (
        hasattr(cfg.data, "data_len") and cfg.data.data_len == 192
    ) or data[0][0].shape[1] >= 192

    if output_feedback:
        feedback_sums = torch.stack([d[1].sum() for d in data])
        mean3 = feedback_sums.sum() / (sample_duration * 2 * n_samples)
        if has_second_array:
            return mean1, mean2, mean3
        return mean1, None, mean3

    # Large datasets without feedback only report mean1 (the legacy comment
    # describes this as the pre-training case).
    if n_samples > 12000:
        return mean1, None

    if has_second_array:
        return mean1, mean2
    return mean1, None




class PretrainPrimateReaching(PrimateReaching):
    """Primate Reaching dataset whose download step tolerates unknown checksums.

    The constructor forwards every argument unchanged to ``PrimateReaching``.
    ``download`` differs in that a session file missing from ``self.md5s`` is
    handled with ``md5=None`` instead of a looked-up checksum.
    """

    def __init__(
        self,
        file_path,
        filename,
        num_steps,
        train_ratio=0.8,
        label_series=False,
        biological_delay=0,
        spike_sorting=False,
        stride=0.004,
        bin_width=0.028,
        max_segment_length=2000,
        split_num=1,
        remove_segments_inactive=False,
        download=True,
    ):
        """Forward all arguments positionally to ``PrimateReaching.__init__``."""
        super().__init__(
            file_path, filename, num_steps, train_ratio, label_series,
            biological_delay, spike_sorting, stride, bin_width,
            max_segment_length, split_num, remove_segments_inactive, download,
        )

    def download(self):
        """Download the session file unless ``_check_exists`` accepts the local copy."""
        # Sessions without an entry in the checksum table use md5=None.
        md5 = self.md5s[self.filename] if self.filename in self.md5s.keys() else None

        if self._check_exists(self.file_path, md5):
            return

        target_dir = os.path.dirname(self.file_path)
        os.makedirs(target_dir, exist_ok=True)

        url = f"{self.url}{self.filename}"
        try:
            print(f"Downloading {url}")
            download_url(url, self.file_path, md5=md5)
        except URLError as error:
            # NOTE(review): the error is only printed and no alternative URL is
            # tried in this method, so a failed download surfaces later when
            # the file is read -- confirm this best-effort behaviour is wanted.
            print(f"Failed to download (trying next):\n{error}")
        finally:
            print()


class DatasetLoader_foundation:
    """Loads the data from the PrimateReaching dataset and splits it into train, val and test sets. The train and valid sets are split into samples of a given length, while the test set is kept as a single sample. The data is returned as a tuple of stork RasDatasets. This datasets can then be used as usual with the stork StandardGenerator."""

    def __init__(
        self,
        basepath,
        num_steps=1,
        dt=0.004,
        ratio_val=0.25,
        biological_delay=0,
        spike_sorting=False,
        label_series=False,
        random_val=False,
        extend_data=True,
        sample_duration=2,
        bin_width=None,
        stride=None,
        remove_segments_inactive=False,
        dtype=torch.float32,
        p_drop=0.0,
        p_insert=0.0,
        jitter_sigma=0.0,
        output_feedback=False,
        session_classfication=False,
        testFlag=False,
        padding="zeros",
    ):
        """Initialize

        Args:
            basepath (str): the path to the data folder
            filename (str): the name of the specific data file to load
            num_steps (int, optional): Argument for the Neurobench dataloader. Should be 1. Defaults to 1.
            dt (float, optional): Time step, should be 0.004 for the monkey data. Defaults to 0.004.
            ratio_val (list, optional): Ratio for validation set. Defaults to 0.25
            biological_delay (int, optional): Delay of readout w.r.t input. Defaults to 0.
            spike_sorting (bool, optional): If True, using single unit activities, otherwise multi unit activities. Defaults to False.
            label_series (bool, optional): Some neurobench argument. Just leave as is. Defaults to False.
            random_val (bool, optional): If True, samples the validation samples randomly from the train data. Otherwise takes the last samples. Defaults to False.
            extend_data (bool, optional): If true, extends the data to overlapping samples. Defaults to True.
            sample_duration (int, optional): sample duration in seconds. Defaults to 2.
            bin_width (_type_, optional): Some neurobench argument. Just leave as is. Defaults to None.
            stride (_type_, optional): Some neurobench argument. Just leave as is. Defaults to None.
            remove_segments_inactive (bool, optional): Some neurobench argument. Just leave as is. Defaults to False.
            dtype (_type_, optional): The dtype of the datasets. Defaults to torch.float32.
        """
        # 获取当前服务器的特征（例如检查主目录）
        home_dir = os.path.expanduser("~")
        if sys.platform == "win32":
            print("Windows")
            basepath = basepath.replace("/home/User/Data_dir/", "E:/User/")
        elif sys.platform == "linux":
            print("Linux")
            if home_dir == "/home2/User":
                # 如果是在服务器2上，路径不需要做改变
                basepath = basepath.replace("/home/", "/home2/")

        self.basepath = basepath
        self.num_steps = num_steps
        self.dt = dt
        self.ratio_val = ratio_val
        self.biological_delay = biological_delay
        self.spike_sorting = spike_sorting
        self.label_series = label_series
        self.random_val = random_val
        self.extend_data = extend_data
        self.sample_duration = sample_duration
        self.remove_segments_inactive = remove_segments_inactive
        self.dtype = dtype
        self.p_drop = p_drop
        self.p_insert = p_insert
        self.jitter_sigma = jitter_sigma
        self.output_feedback=output_feedback
        self.session_classfication=session_classfication
        self.testFlag=testFlag
        self.padding=padding

        if bin_width is None:
            self.bin_width = self.dt
        else:
            self.bin_width = bin_width
        if stride is None:
            self.stride = self.dt
        else:
            self.stride = stride

        self.n_time_steps = int(sample_duration / dt)

    def get_single_session_data(self, filename, nb_inputs, with_S1=False, only_S1=False, zscore=False, session_code=None):
        """Load one session file and build its train, validation and test datasets.

        The session is read through ``PretrainPrimateReaching`` with
        ``train_ratio=0.5``. Its train and validation indices are pooled and
        re-split here according to ``self.ratio_val``; the test indices are
        used as provided. When ``self.remove_segments_inactive`` is set, the
        file is loaded a second time with ``remove_segments_inactive=False``
        and the test data comes from that second copy.

        Side effects: overwrites ``self.ind_train``, ``self.ind_val`` and
        ``self.ind_test``, and prints the padding mode when padding is applied.

        Args:
            filename: Name of the session file inside ``self.basepath``.
            nb_inputs: Number of input columns to keep, or the column at which
                the recording is split (see ``with_S1`` / ``only_S1``).
            with_S1: If True, a 96-column recording keeps its first
                ``nb_inputs`` columns, while a 192-column recording is split
                at ``nb_inputs`` and both parts are concatenated along the
                first (time) axis with the labels duplicated to match.
            only_S1: Only used when ``with_S1`` is False; keeps the columns
                from ``nb_inputs`` onwards.
            zscore: If True, labels are standardised with the global mean and
                standard deviation of their own tensor (the train/validation
                source and the test source are normalised separately).
            session_code: Optional session identifier. When
                ``self.session_classfication`` is set it is converted with
                ``torch.tensor`` and appended to the labels as extra columns;
                otherwise it is iterated element by element, each element is
                converted with ``int`` and the result is appended to the
                inputs as extra columns.

        Returns:
            tuple: ``(train, val, test)``. ``val`` and ``test`` come from
            ``to_ras``; ``train`` comes from ``to_dataset`` when
            ``self.output_feedback`` is set and from ``to_ras`` otherwise.
            The test dataset holds the whole test recording as one sample.

        Raises:
            ValueError: If ``with_S1`` is True and the recording has neither
                96 nor 192 columns.
        """
        shared_kwargs = dict(
            file_path=self.basepath,
            filename=filename,
            num_steps=self.num_steps,
            train_ratio=0.5,  # Hardcoded here for 25 % test split
            bin_width=self.bin_width,
            biological_delay=self.biological_delay,
        )
        dataset = PretrainPrimateReaching(
            remove_segments_inactive=self.remove_segments_inactive, **shared_kwargs
        )

        # The test split is always taken from data loaded without removing
        # inactive segments, which requires a second load when removal is on.
        if self.remove_segments_inactive:
            dataset_test = PretrainPrimateReaching(
                remove_segments_inactive=False, **shared_kwargs
            )
        else:
            dataset_test = dataset

        # Pool train and validation indices (75 %) and make our own split.
        ind_tv = dataset.ind_train + dataset.ind_val

        # Effective validation ratio = ratio_val / 0.75
        eff_ratio_val = self.ratio_val / 0.75
        n_val = int(np.round(len(dataset) * eff_ratio_val))

        # NOTE(review): with n_val == 0 the slice ``ind_tv[:-n_val]`` is empty
        # rather than the full pool -- confirm ratio_val is never 0.
        if self.random_val:
            # Validation is a contiguous block starting at a random position;
            # everything else in the pool is used for training.
            start_idx = np.random.choice(a=ind_tv[:-n_val], size=1)[0]
            ind_val = np.array(ind_tv[start_idx : start_idx + n_val])
            ind_train = np.array(sorted(set(ind_tv) - set(ind_val)))
        else:
            # Validation is the tail of the pool.
            ind_train = np.array(ind_tv[:-n_val])
            ind_val = np.array(sorted(set(ind_tv) - set(ind_train)))

        spikes = dataset.samples.T
        spikes_testdat = dataset_test.samples.T
        labels = dataset.labels.T
        labels_testdat = dataset_test.labels.T
        if zscore:
            labels = (labels - labels.mean()) / labels.std()
            labels_testdat = (labels_testdat - labels_testdat.mean()) / labels_testdat.std()

        test_idx = dataset_test.ind_test
        self.ind_train = ind_train
        self.ind_val = ind_val
        self.ind_test = test_idx

        n_channels = spikes.shape[1]
        if with_S1 and n_channels not in (96, 192):
            raise ValueError("data dimension err!")

        raw_train = spikes[ind_train]
        raw_val = spikes[ind_val]
        raw_test = spikes_testdat[test_idx]
        raw_parts = (raw_train, raw_val, raw_test)

        # Select / assemble the input columns for each split.
        duplicate_labels = False
        if with_S1 and n_channels == 192:
            # Both column blocks become separate recordings stacked in time.
            spikes_train, spikes_val, spikes_test = (
                torch.cat((part[:, :nb_inputs], part[:, nb_inputs:]))
                for part in raw_parts
            )
            duplicate_labels = True
        elif with_S1:
            spikes_train, spikes_val, spikes_test = (
                part[:, 0:nb_inputs] for part in raw_parts
            )
        elif only_S1:
            spikes_train, spikes_val, spikes_test = (
                part[:, nb_inputs:] for part in raw_parts
            )
        elif nb_inputs == 192 and n_channels == 96:
            # A 96-column session must be widened to 192 inputs.
            if self.padding == "zeros":
                print("Padding with zeros...")
                spikes_train, spikes_val, spikes_test = (
                    torch.cat((part, torch.zeros_like(part)), dim=1)
                    for part in raw_parts
                )
            else:
                print("Padding with copying...")
                spikes_train, spikes_val, spikes_test = (
                    torch.cat((part, part), dim=1) for part in raw_parts
                )
        else:
            spikes_train, spikes_val, spikes_test = (
                part[:, 0:nb_inputs] for part in raw_parts
            )

        labels_train = labels[ind_train]
        labels_val = labels[ind_val]
        labels_test = labels_testdat[test_idx]
        if duplicate_labels:
            labels_train, labels_val, labels_test = (
                torch.cat((part, part))
                for part in (labels_train, labels_val, labels_test)
            )

        if self.output_feedback:
            # Training inputs additionally receive the labels of the previous
            # time step (zeros for the very first step). Validation and test
            # inputs are left unchanged.
            previous_labels = torch.zeros_like(labels_train)
            previous_labels[1:] = labels_train[:-1]
            spikes_train = torch.cat((spikes_train, previous_labels), dim=1)

        if session_code is not None:
            if self.session_classfication:
                # Append the session code to the labels.
                code_row = torch.tensor(session_code, dtype=self.dtype).unsqueeze(0)
                labels_train, labels_val, labels_test = (
                    torch.cat((part, code_row.repeat(part.shape[0], 1)), dim=1)
                    for part in (labels_train, labels_val, labels_test)
                )
            else:
                # Append the session code to the inputs.
                code_digits = [int(char) for char in session_code]
                code_row = torch.tensor(code_digits, dtype=self.dtype).unsqueeze(0)
                spikes_train, spikes_val, spikes_test = (
                    torch.cat((part, code_row.repeat(part.shape[0], 1)), dim=1)
                    for part in (spikes_train, spikes_val, spikes_test)
                )

        if self.testFlag:
            spikes_test = spikes_test[0:3600, :]
            labels_test = labels_test[0:3600, :]

        # Cut train and validation recordings into samples of n_time_steps.
        if self.extend_data:
            logger.info("Extending data...")
            extend_kwargs = dict(
                chunks=self.n_time_steps, chunksize=int(self.n_time_steps / 5)
            )
        else:
            # chunks=99 with the default chunksize of 100 yields offset 0 only.
            extend_kwargs = dict(chunks=99)

        train_data, train_labels = self.extend_spikes(
            spikes_train, labels_train, **extend_kwargs
        )
        val_data, val_labels = self.extend_spikes(
            spikes_val, labels_val, **extend_kwargs
        )

        # The test recording stays in one piece (a batch of one sample).
        test_data = torch.stack([spikes_test])
        test_labels = torch.stack([labels_test])

        # Augmentation applies to the training dataset only.
        augment = any([self.p_drop > 0, self.p_insert > 0, self.jitter_sigma > 0])
        if augment:
            data_augmentation_kwargs = dict(
                data_augmentation=True,
                p_drop=self.p_drop,
                p_insert=self.p_insert,
                sigma_t=self.jitter_sigma,
            )
        else:
            data_augmentation_kwargs = {}

        build_train = self.to_dataset if self.output_feedback else self.to_ras
        train_ras_data = build_train(
            train_data, train_labels, **data_augmentation_kwargs
        )
        val_ras_data = self.to_ras(val_data, val_labels)
        test_ras_data = self.to_ras(test_data, test_labels)

        return train_ras_data, val_ras_data, test_ras_data

    def get_multiple_sessions_data(self, filenames, nb_inputs, with_S1=False, zscore=False, session_codes=None):
        """Loads data from multiple sessions and concatenates them into a single dataset (split in train, test and validation).

        Args:
            filenames (list): List of filenames to load (all files should be in the folder specified by basepath    )

        Returns:
            tuple of stork RasDatasets for train and validation and a list of test dataset (one dataset for each session)
        """

        ds_train, ds_valid, ds_test = [], [], []

        if session_codes is not None:
            for filename, session_code in zip(filenames, session_codes):
                monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                    self.get_single_session_data(filename, nb_inputs=nb_inputs, with_S1=with_S1, zscore=zscore, session_code=session_code)
                )
                ds_train.append(monkey_ds_train)
                ds_valid.append(monkey_ds_valid)
                ds_test.append(monkey_ds_test)
        else:
            for filename in filenames:
                monkey_ds_train, monkey_ds_valid, monkey_ds_test = (
                    self.get_single_session_data(filename, nb_inputs=nb_inputs, with_S1=with_S1, zscore=zscore)
                )
                ds_train.append(monkey_ds_train)
                ds_valid.append(monkey_ds_valid)
                ds_test.append(monkey_ds_test)

        dataset_train = torch.utils.data.ConcatDataset(ds_train)
        dataset_valid = torch.utils.data.ConcatDataset(ds_valid)

        return dataset_train, dataset_valid, ds_test

    def extend_spikes(self, spikes, labels, chunks="all", chunksize=100):
        """Given spike data and labels of the shape [time x neuron], it cuts it into overlapping samples of shape [samples x n_time_steps x neuron]"""
        ##  将给定的脉冲数据（spikes）和标签（labels）按时间切割成重叠的样本，并将其转换为一个三维的数组，方便后续处理和训练。重叠数：chunks

        if chunks == "all":
            chunks = self.n_time_steps

        extended_spikes = []
        extended_labels = []

        for t in range(0, chunks, chunksize):
            curr_spikes = spikes[t:]
            curr_labels = labels[t:]

            splitter = np.arange(
                self.n_time_steps, curr_spikes.shape[0], self.n_time_steps
            )

            extended_spikes += np.split(curr_spikes, splitter)[:-1]
            extended_labels += np.split(curr_labels, splitter)[:-1]

        ## 张量拼接
        extended_spikes = torch.stack(extended_spikes)
        extended_labels = torch.stack(extended_labels)

        return extended_spikes, extended_labels

    def to_ras(self, data, labels, **data_augmentation_kwargs):
        """Convert dense samples into a ``stork.datasets.RasDataset``.

        For every sample, the positions whose value equals 1 are collected
        into a tensor with two rows: the time indices and the unit indices of
        those entries. Events are ordered unit by unit, with ascending time
        inside each unit.

        Args:
            data: Samples of shape [samples x time x units].
            labels: Labels handed to the dataset unchanged.
            **data_augmentation_kwargs: Extra keyword arguments passed on to
                ``RasDataset``.

        Returns:
            stork.datasets.RasDataset: Dataset built with
            ``nb_steps=data.shape[-2]``, ``nb_units=data.shape[-1]`` and
            ``time_scale=1.0``.
        """
        ras_data = []
        for sample in data:
            # nonzero() walks its input in row-major order, so transposing to
            # [units x time] enumerates all events of unit 0 first, then unit
            # 1, and so on, each in ascending time.
            is_spike = torch.as_tensor(sample).T == 1
            unit_idx, time_idx = torch.nonzero(is_spike, as_tuple=True)
            ras_data.append(torch.stack((time_idx, unit_idx)).to(dtype=self.dtype))

        monkey_ds = stork.datasets.RasDataset(
            (ras_data, labels),
            dtype=self.dtype,
            nb_steps=data.shape[-2],
            nb_units=data.shape[-1],
            time_scale=1.0,
            **data_augmentation_kwargs,
        )

        return monkey_ds

    def to_dataset(self, data, labels, **data_augmentation_kwargs):


        # 保持原始的脉冲结构或直接使用原始数据
        processed_data = data.clone().to(dtype=self.dtype)

        # 使用PyTorch的TensorDataset封装
        dataset = torch.utils.data.TensorDataset(processed_data, labels.to(dtype=self.dtype))

        return dataset
