"""
This code was adapted from:
M. Xiao, Q. Meng, Z. Zhang, D. He, and Z. Lin, "Online Training Through Time for Spiking Neural Networks,"
NeurIPS, 2022.
"""
from torch.utils.data import Dataset
import os
import numpy as np
import threading
import zipfile
from torchvision.datasets import utils
import torch


class FunctionThread(threading.Thread):
    def __init__(self, f, *args, **kwargs):
        super().__init__()
        self.f = f
        self.args = args
        self.kwargs = kwargs
    def run(self):
        self.f(*self.args, **self.kwargs)


def integrate_events_to_frames(events, height, width, frames_num=10, split_by='time', normalization=None):
    '''
    * :ref:`API in English <integrate_events_to_frames.__init__-en>`

    .. _integrate_events_to_frames.__init__-cn:

    :param events: 键是{'t', 'x', 'y', 'p'}，值是np数组的的字典
    :param height: 脉冲数据的高度，例如对于CIFAR10-DVS是128
    :param width: 脉冲数据的宽度，例如对于CIFAR10-DVS是128
    :param frames_num: 转换后数据的帧数
    :param split_by: 脉冲数据转换成帧数据的累计方式，允许的取值为 ``'number', 'time'``
    :param normalization: 归一化方法，允许的取值为 ``None, 'frequency', 'max', 'norm', 'sum'``
    :return: 转化后的frames数据，是一个 ``shape = [frames_num, 2, height, width]`` 的np数组

    记脉冲数据为 :math:`E_{i} = (t_{i}, x_{i}, y_{i}, p_{i}), i=0,1,...,N-1`，转换为帧数据 :math:`F(j, p, x, y), j=0,1,...,M-1`。

    若划分方式 ``split_by`` 为 ``'time'``，则

    .. math::

        \\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\
        j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\
        j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j <  M - 1 \\cr N, & j = M - 1 \\end{cases} \\\\
        F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})}

    若划分方式 ``split_by`` 为 ``'number'``，则

    .. math::

        j_{l} & = [\\frac{N}{M}] \\cdot j \\\\
        j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j <  M - 1 \\cr N, & j = M - 1 \\end{cases}\\\\
        F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})}

    其中 :math:`\\mathcal{I}` 为示性函数，当且仅当 :math:`(p, x, y) = (p_{i}, x_{i}, y_{i})` 时为1，否则为0。

    若 ``normalization`` 为 ``'frequency'``，

        若 ``split_by`` 为 ``time`` 则

            .. math::
                F_{norm}(j, p, x, y) = \\begin{cases} \\frac{F(j, p, x, y)}{\\Delta T}, & j < M - 1
                \\cr \\frac{F(j, p, x, y)}{\\Delta T + (t_{N-1} - t_{0}) \\bmod M}, & j = M - 1 \\end{cases}

        若 ``split_by`` 为 ``number`` 则

            .. math::
                F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{t_{j_{r}} - t_{j_{l}}}


    若 ``normalization`` 为 ``'max'`` 则

    .. math::
        F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\mathrm{max} F(j, p)}

    若 ``normalization`` 为 ``'norm'`` 则

    .. math::
        F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y) - \\mathrm{E}(F(j, p))}{\\sqrt{\\mathrm{Var}(F(j, p))}}

    若 ``normalization`` 为 ``'sum'`` 则

    .. math::
        F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\sum_{a, b} F(j, p, a, b)}

    * :ref:`中文API <integrate_events_to_frames.__init__-cn>`

    .. _integrate_events_to_frames.__init__-en:

    :param events: a dict with keys are {'t', 'x', 'y', 'p'} and values are numpy arrays
    :param height: the height of events data, e.g., 128 for CIFAR10-DVS
    :param width: the width of events data, e.g., 128 for CIFAR10-DVS
    :param frames_num: frames number
    :param split_by: how to split the events, can be ``'number', 'time'``
    :param normalization: how to normalize frames, can be ``None, 'frequency', 'max', 'norm', 'sum'``
    :return: the frames data with ``shape = [frames_num, 2, height, width]``

    The events data are denoted by :math:`E_{i} = (t_{i}, x_{i}, y_{i}, p_{i}), i=0,1,...,N-1`, and the converted frames
    data are denoted by :math:`F(j, p, x, y), j=0,1,...,M-1`.

    If ``split_by`` is ``'time'``, then

    .. math::

        \\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\
        j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\
        j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j <  M - 1 \\cr N, & j = M - 1 \\end{cases} \\\\
        F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})}

    If ``split_by`` is ``'number'``, then

    .. math::

        j_{l} & = [\\frac{N}{M}] \\cdot j \\\\
        j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j <  M - 1 \\cr N, & j = M - 1 \\end{cases}\\\\
        F(j, p, x, y) & = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I_{p, x, y}(p_{i}, x_{i}, y_{i})}

    where :math:`\\mathcal{I}` is the characteristic function，if and only if :math:`(p, x, y) = (p_{i}, x_{i}, y_{i})`,
    this function is identically 1 else 0.

    If ``normalization`` is ``'frequency'``,

        if ``split_by`` is ``time``,

            .. math::
                F_{norm}(j, p, x, y) = \\begin{cases} \\frac{F(j, p, x, y)}{\\Delta T}, & j < M - 1
                \\cr \\frac{F(j, p, x, y)}{\\Delta T + (t_{N-1} - t_{0}) \\bmod M}, & j = M - 1 \\end{cases}

        if ``split_by`` is ``number``,

            .. math::
                F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{t_{j_{r}} - t_{j_{l}}}

    If ``normalization`` is ``'max'``, then

    .. math::
        F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\mathrm{max} F(j, p)}

    If ``normalization`` is ``'norm'``, then

    .. math::
        F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y) - \\mathrm{E}(F(j, p))}{\\sqrt{\\mathrm{Var}(F(j, p))}}

    If ``normalization`` is ``'sum'``, then

    .. math::
        F_{norm}(j, p, x, y) = \\frac{F(j, p, x, y)}{\\sum_{a, b} F(j, p, a, b)}
    '''
    frames = np.zeros(shape=[frames_num, 2, height * width])

    # 创建j_{l}和j_{r}
    j_l = np.zeros(shape=[frames_num], dtype=int)
    j_r = np.zeros(shape=[frames_num], dtype=int)
    if split_by == 'time':
        events['t'] -= events['t'][0]  # 时间从0开始
        assert events['t'][-1] > frames_num
        dt = events['t'][-1] // frames_num  # 每一段的持续时间
        idx = np.arange(events['t'].size)
        for i in range(frames_num):
            t_l = dt * i
            t_r = t_l + dt
            mask = np.logical_and(events['t'] >= t_l, events['t'] < t_r)
            idx_masked = idx[mask]
            ################### fix bug ######################
            if len(idx_masked) == 0: # no events
                j_l[i] = -1
                j_r[i] = -1
            else:
                j_l[i] = idx_masked[0]
                j_r[i] = idx_masked[-1] + 1 if i < frames_num - 1 else events['t'].size

    elif split_by == 'number':
        di = events['t'].size // frames_num
        for i in range(frames_num):
            j_l[i] = i * di
            j_r[i] = j_l[i] + di if i < frames_num - 1 else events['t'].size
    else:
        raise NotImplementedError

    # 开始累计脉冲
    # 累计脉冲需要用bitcount而不能直接相加，原因可参考下面的示例代码，以及
    # https://stackoverflow.com/questions/15973827/handling-of-duplicate-indices-in-numpy-assignments
    # height = 3
    # width = 3
    # frames = np.zeros(shape=[2, height, width])
    # events = {
    #     'x': np.asarray([1, 2, 1, 1]),
    #     'y': np.asarray([1, 1, 1, 2]),
    #     'p': np.asarray([0, 1, 0, 1])
    # }
    #
    # frames[0, events['y'], events['x']] += (1 - events['p'])
    # frames[1, events['y'], events['x']] += events['p']
    # print('wrong accumulation\n', frames)
    #
    # frames = np.zeros(shape=[2, height, width])
    # for i in range(events['p'].__len__()):
    #     frames[events['p'][i], events['y'][i], events['x'][i]] += 1
    # print('correct accumulation\n', frames)
    #
    # frames = np.zeros(shape=[2, height, width])
    # frames = frames.reshape(2, -1)
    #
    # mask = [events['p'] == 0]
    # mask.append(np.logical_not(mask[0]))
    # for i in range(2):
    #     position = events['y'][mask[i]] * height + events['x'][mask[i]]
    #     events_number_per_pos = np.bincount(position)
    #     idx = np.arange(events_number_per_pos.size)
    #     frames[i][idx] += events_number_per_pos
    # frames = frames.reshape(2, height, width)
    # print('correct accumulation by bincount\n', frames)

    for i in range(frames_num):
        ################### fix bug ######################
        if j_l[i] != -1:
            x = events['x'][j_l[i]:j_r[i]]
            y = events['y'][j_l[i]:j_r[i]]
            p = events['p'][j_l[i]:j_r[i]]
            mask = []
            mask.append(p == 0)
            mask.append(np.logical_not(mask[0]))
            for j in range(2):
                position = y[mask[j]] * height + x[mask[j]]
                events_number_per_pos = np.bincount(position)
                frames[i][j][np.arange(events_number_per_pos.size)] += events_number_per_pos

            if normalization == 'frequency':
                if split_by == 'time':
                    if i < frames_num - 1:
                        frames[i] /= dt
                    else:
                        frames[i] /= (dt + events['t'][-1] % frames_num)
                elif split_by == 'number':
                        frames[i] /= (events['t'][j_r[i]] - events['t'][j_l[i]])  # 表示脉冲发放的频率

                else:
                    raise NotImplementedError

        # 其他的normalization方法，在数据集类读取数据的时候进行通过调用normalize_frame(frames: np.ndarray, normalization: str)
        # 函数操作，而不是在转换数据的时候进行
    return frames.reshape((frames_num, 2, height, width))

def normalize_frame(frames: np.ndarray or torch.Tensor, normalization: str):
    eps = 1e-5  # 涉及到除法的地方，被除数加上eps，防止出现除以0
    for i in range(frames.shape[0]):
        if normalization == 'max':
            frames[i][0] /= max(frames[i][0].max(), eps)
            frames[i][1] /= max(frames[i][1].max(), eps)

        elif normalization == 'norm':
            frames[i][0] = (frames[i][0] - frames[i][0].mean()) / np.sqrt(max(frames[i][0].var(), eps))
            frames[i][1] = (frames[i][1] - frames[i][1].mean()) / np.sqrt(max(frames[i][1].var(), eps))

        elif normalization == 'sum':
            frames[i][0] /= max(frames[i][0].sum(), eps)
            frames[i][1] /= max(frames[i][1].sum(), eps)

        else:
            raise NotImplementedError
    return frames

def convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, suffix, read_function, height, width,
                                              frames_num=10, split_by='time', normalization=None, thread_num=1, compress=False):
    # 遍历events_data_dir目录下的所有脉冲数据文件，在frames_data_dir目录下生成帧数据文件
    def cvt_fun(events_file_list):
        for events_file in events_file_list:
            frames = integrate_events_to_frames(read_function(events_file), height, width, frames_num, split_by,
                                                normalization)
            if compress:
                frames_file = os.path.join(frames_data_dir,
                                           os.path.basename(events_file)[0: -suffix.__len__()] + '.npz')
                np.savez_compressed(frames_file, frames)
            else:
                frames_file = os.path.join(frames_data_dir,
                                           os.path.basename(events_file)[0: -suffix.__len__()] + '.npy')
                np.save(frames_file, frames)
    events_file_list = utils.list_files(events_data_dir, suffix, True)
    if thread_num == 1:
        cvt_fun(events_file_list)
    else:
        # 多线程加速
        thread_list = []
        block = events_file_list.__len__() // thread_num
        for i in range(thread_num - 1):
            thread_list.append(FunctionThread(cvt_fun, events_file_list[i * block: (i + 1) * block]))
            thread_list[-1].start()
            print(f'thread {i} start, processing files index: {i * block} : {(i + 1) * block}.')
        thread_list.append(FunctionThread(cvt_fun, events_file_list[(thread_num - 1) * block:]))
        thread_list[-1].start()
        print(f'thread {thread_num} start, processing files index: {(thread_num - 1) * block} : {events_file_list.__len__()}.')
        for i in range(thread_num):
            thread_list[i].join()
            print(f'thread {i} finished.')

def extract_zip_in_dir(source_dir, target_dir):
    '''
    :param source_dir: 保存有zip文件的文件夹
    :param target_dir: 保存zip解压后数据的文件夹
    :return: None

    将 ``source_dir`` 目录下的所有*.zip文件，解压到 ``target_dir`` 目录下的对应文件夹内
    '''

    for file_name in os.listdir(source_dir):
        if file_name[-3:] == 'zip':
            with zipfile.ZipFile(os.path.join(source_dir, file_name), 'r') as zip_file:
                zip_file.extractall(os.path.join(target_dir, file_name[:-4]))

class EventsFramesDatasetBase(Dataset):
    @staticmethod
    def get_wh():
        '''
        :return: (width, height)
            width: int
                events或frames图像的宽度
            height: int
                events或frames图像的高度
        :rtype: tuple
        '''
        raise NotImplementedError

    @staticmethod
    def read_bin(file_name: str):
        '''
        :param file_name: 脉冲数据的文件名
        :type file_name: str
        :return: events
            键是{'t', 'x', 'y', 'p'}，值是np数组的的字典
        :rtype: dict
        '''
        raise NotImplementedError

    @staticmethod
    def get_events_item(file_name):
        '''
        :param file_name: 脉冲数据的文件名
        :type file_name: str
        :return: (events, label)
            events: dict
                键是{'t', 'x', 'y', 'p'}，值是np数组的的字典
            label: int
                数据的标签
        :rtype: tuple
        '''
        raise NotImplementedError

    @staticmethod
    def get_frames_item(file_name):
        '''
        :param file_name: 帧数据的文件名
        :type file_name: str
        :return: (frames, label)
            frames: np.ndarray
                ``shape = [frames_num, 2, height, width]`` 的np数组
            label: int
                数据的标签
        :rtype: tuple
        '''
        raise NotImplementedError

    @staticmethod
    def download_and_extract(download_root: str, extract_root: str):
        '''
        :param download_root: 保存下载文件的文件夹
        :type download_root: str
        :param extract_root: 保存解压后文件的文件夹
        :type extract_root: str

        下载数据集到 ``download_root``，然后解压到 ``extract_root``。
        '''
        raise NotImplementedError

    @staticmethod
    def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str, normalization: str or None):
        '''
        :param events_data_dir: 保存脉冲数据的文件夹，文件夹的文件全部是脉冲数据
        :type events_data_dir: str
        :param frames_data_dir: 保存帧数据的文件夹
        :type frames_data_dir: str
        :param frames_num: 转换后数据的帧数
        :type frames_num: int
        :param split_by: 脉冲数据转换成帧数据的累计方式
        :type split_by: str
        :param normalization: 归一化方法
        :type normalization: str or None

        将 ``events_data_dir`` 文件夹下的脉冲数据全部转换成帧数据，并保存在 ``frames_data_dir``。
        转换参数的详细含义，参见 ``integrate_events_to_frames`` 函数。
        '''
        raise NotImplementedError