from utils import set_seed
import math
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data import random_split
from typing import Callable, Optional

import torchvision.transforms as transforms
import tonic
from torchvision.datasets import utils
from spikingjelly.datasets import pad_sequence_collate, np_savez, configure

import torch
from torchaudio.transforms import Spectrogram, MelScale, AmplitudeToDB, Resample
from torchaudio.datasets.speechcommands import SPEECHCOMMANDS
from torchvision import transforms
from torch.utils.data import Dataset
import augmentations
from typing import Union, Tuple
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
import os
import h5py
import time
from concurrent.futures import ThreadPoolExecutor

def drop_event_numpy(events: np.ndarray, drop_probability: float):
    """Remove a random subset of events along axis 0.

    Exactly ``int(drop_probability * n + 0.5)`` distinct rows are removed (a
    half-up rounding of the requested fraction). The rows are chosen without
    replacement, and the surviving rows keep their original order.
    """
    total = events.shape[0]
    n_to_drop = int(drop_probability * total + 0.5)
    doomed = np.random.choice(total, n_to_drop, replace=False)
    survivors = np.delete(events, doomed, axis=0)
    return survivors

@dataclass(frozen=True)
class DropEvent:
    """Randomly drop a fraction ``p`` of the events.

    The actual removal is done by ``drop_event_numpy``, which deletes
    ``int(p * n_events + 0.5)`` randomly chosen events.

    Parameters:
        p (float or pair of floats): Fraction of events to drop. A ``(p_min, p_max)``
            tuple (or 2-element list, e.g. straight from a JSON/YAML config) samples
            the fraction uniformly from ``[p_min, p_max)`` on every call.

    Example:
        >>> transform1 = DropEvent(p=0.2)
        >>> transform2 = DropEvent(p=(0, 0.5))
    """

    p: Union[float, Tuple[float, float]]

    @staticmethod
    def get_params(p: Union[float, Tuple[float, float]]):
        """Return ``p`` itself, or a uniform sample from ``[p[0], p[1])`` for a range."""
        # isinstance (rather than ``type(p) == tuple``) also accepts tuple subclasses,
        # and lists are accepted because config files usually deserialize pairs as lists.
        if isinstance(p, (tuple, list)):
            p = (p[1] - p[0]) * np.random.random_sample() + p[0]
        return p

    def __call__(self, events):
        p = self.get_params(p=self.p)
        return drop_event_numpy(events=events, drop_probability=p)
    
class DropEventChunk:
    """
    Randomly drop one contiguous chunk of events.

    Parameters:
        p (float): probability that a chunk is dropped on a given call.
        max_drop_size (float): upper bound on the chunk length, as a fraction of the
            number of events (the bound is capped at the number of events).

    The chunk length is drawn from ``[1, max_drop_events)`` and its start from
    ``[0, n_events - length)``, so the last event is never part of the chunk.
    Inputs too short to allow a chunk of at least one event are returned unchanged.
    """
    def __init__(self, p, max_drop_size):
        self.drop_prob = p
        self.max_drop_size = max_drop_size

    def __call__(self, events):
        n_events = len(events)
        # Whole number of events, capped at n_events so the start range below is
        # never empty (randint(low, high) requires low < high).
        max_drop_events = min(int(self.max_drop_size * n_events), n_events)
        # The coin flip is drawn first so the RNG stream matches the previous behaviour;
        # max_drop_events <= 1 would make randint(1, max_drop_events) raise.
        if np.random.rand() < self.drop_prob and max_drop_events > 1:
            drop_size = np.random.randint(1, max_drop_events)
            start = np.random.randint(0, n_events - drop_size)
            events = np.delete(events, slice(start, start + drop_size), axis=0)
        return events
    
class Jitter1D:
    """
    Randomly jitter the ``'x'`` (channel) coordinate of each event.

    Each event's ``x`` is shifted by a sample of ``np.random.normal(0, var)``.
    ``var`` is passed as the ``scale`` argument, so despite its name it is a
    standard deviation. The sample is truncated toward zero by the ``int16`` cast,
    so shifts with magnitude below 1 become 0. Events that land outside
    ``[0, sensor_size)`` are discarded. The input array is not modified.

    Parameters:
        sensor_size (int): number of channels; valid x values are ``0 .. sensor_size - 1``.
        var (float): standard deviation of the shift, in channels.
    """
    def __init__(self, sensor_size, var):
        self.sensor_size = sensor_size
        self.var = var

    def __call__(self, events):
        shift = np.random.normal(0, self.var, len(events)).astype(np.int16)
        # Shift out of place: the caller's array stays untouched. Because the sum uses
        # numpy's promoted dtype, negative results are seen by the mask instead of
        # wrapping around for unsigned 'x' fields.
        shifted_x = events['x'] + shift
        in_range = (shifted_x >= 0) & (shifted_x < self.sensor_size)
        jittered = events[in_range]  # boolean indexing returns a new array
        jittered['x'] = shifted_x[in_range]
        return jittered
    
def time_skew_numpy(events: np.ndarray, coefficient: float, offset: int = 0):
    """Apply the affine time map ``t -> t * coefficient + offset`` in place.

    ``coefficient`` and ``offset`` may each be a ``(low, high)`` tuple. In that case
    a value is drawn uniformly from ``[low, high)``, coefficient first, then offset.
    The ``'t'`` field of ``events`` is overwritten, and the same array object is
    returned.
    """
    def sample(value):
        # A tuple denotes a range to draw from; anything else is used verbatim.
        if not isinstance(value, tuple):
            return value
        return (value[1] - value[0]) * np.random.random_sample() + value[0]

    scale = sample(coefficient)
    shift = sample(offset)
    events["t"] = events["t"] * scale + shift
    return events


@dataclass(frozen=True)
class TimeSkew:
    """Scale and shift event timestamps (``t * coefficient + offset``).

    Either field may be a ``(low, high)`` tuple to sample the value per call; the
    transform delegates to ``time_skew_numpy`` on a copy of the input.
    """

    coefficient: Union[float, Tuple[float, float]]
    offset: Union[float, Tuple[float, float]] = 0

    def __call__(self, events):
        # time_skew_numpy writes in place, so hand it a private copy.
        skewed = time_skew_numpy(events.copy(), self.coefficient, self.offset)
        return skewed

def time_jitter_numpy(
    events: np.ndarray,
    std: float = 1,
    clip_negative: bool = False,
    sort_timestamps: bool = False,
):
    """Add zero-mean Gaussian noise to the ``'t'`` field of ``events``.

    Noise is drawn with standard deviation ``std`` and then divided by 1000 before
    being added. This presumably means ``std`` is in milliseconds while ``t`` is in
    seconds; confirm against the data. The addition happens in place, so ``'t'``
    must be a floating-point field.

    Parameters:
        clip_negative: drop events whose jittered timestamp is negative.
        sort_timestamps: reorder events by the jittered timestamp.

    Returns:
        ``events`` itself when neither option is set, otherwise a new array.
    """
    noise = np.random.normal(0, std, len(events))
    events["t"] += noise / 1000
    if clip_negative:
        # Keep everything that is not strictly negative (NaN timestamps are kept).
        events = events[~(events["t"] < 0)]
    if sort_timestamps:
        order = np.argsort(events["t"])
        events = events[order]
    return events

@dataclass(frozen=True)
class TimeJitter:
    """Jitter event timestamps with zero-mean Gaussian noise.

    Delegates to ``time_jitter_numpy`` on a copy of the input, optionally dropping
    events that become negative and (by default) re-sorting by timestamp.
    """

    std: float
    clip_negative: bool = False
    sort_timestamps: bool = True

    def __call__(self, events):
        # time_jitter_numpy modifies 't' in place; protect the caller's array.
        jittered = time_jitter_numpy(
            events.copy(), self.std, self.clip_negative, self.sort_timestamps
        )
        return jittered

def uniform_noise_numpy(events: np.ndarray, sensor_size: int, n: int):
    """Add ``n`` noise events drawn uniformly over the channel and time ranges.

    Noise ``'x'`` values are drawn from ``[0, sensor_size)`` and then cast to the
    field's dtype. Noise ``'t'`` values come from ``[t.min(), t.max())`` of the
    input. Any other field of the structured dtype is left at 0 in the noise
    events. The result is sorted by ``'t'`` with a stable sort, so original events
    keep their relative order on ties.

    Parameters:
        events: structured array with at least the fields ``'t'`` and ``'x'``.
        sensor_size: number of channels (a single int).
        n: the number of noise events added.

    Returns:
        A new array of ``len(events) + n`` events. If ``events`` is empty, it is
        returned unchanged, because no time range exists to draw from.
    """
    if len(events) == 0:
        return events

    # Only channels with a known range receive noise. Unknown channels previously
    # reused the bounds of whichever channel came before them.
    bounds = {
        "x": (0, sensor_size),
        "t": (events["t"].min(), events["t"].max()),
    }
    noise_events = np.zeros(n, dtype=events.dtype)
    for channel in events.dtype.names:
        if channel in bounds:
            low, high = bounds[channel]
            noise_events[channel] = np.random.uniform(low=low, high=high, size=n)
    noisy_events = np.concatenate((events, noise_events))
    return noisy_events[np.argsort(noisy_events["t"], kind="stable")]


@dataclass(frozen=True)
class UniformNoise:
    """Add ``n`` noise events spread uniformly over the channel and time ranges.

    Delegates to ``uniform_noise_numpy``. Not applied if the input is empty.

    Parameters:
        sensor_size: number of channels, as a single int (not a tonic-style
            ``(x, y, p)`` tuple).
        n: number of events to add. A ``(n_min, n_max)`` tuple (or 2-element list)
           samples ``n`` uniformly from ``[n_min, n_max)`` on each call.

    Example:
        >>> transform = UniformNoise(sensor_size=700, n=3000)
    """

    sensor_size: int
    n: Union[int, Tuple[int, int]]

    @staticmethod
    def get_params(n: Union[int, Tuple[int, int]]):
        """Return ``n`` itself, or an int sampled from ``[n[0], n[1])`` for a range."""
        # Lists are accepted too, because config files usually deserialize pairs as lists.
        if isinstance(n, (tuple, list)):
            n = int((n[1] - n[0]) * np.random.random_sample() + n[0])
        return n

    def __call__(self, events):
        if len(events) == 0:
            return events

        n = self.get_params(n=self.n)
        return uniform_noise_numpy(
            events=events, sensor_size=self.sensor_size, n=n
        )

def event_collate_fun(batch: list):
    """Collate ``(frames, label)`` pairs into a padded batch.

    Each ``frames`` sample is channel-first ``(C, T)`` and is transposed to
    ``(T, C)``. Samples are padded along time to the longest one, giving
    ``(N, T_max, C)``. Padding uses -1 as a sentinel. The returned mask is True
    wherever the padded tensor equals -1, which matches the padding only if real
    values are never -1 (for example, non-negative counts). The sentinel is then
    clamped to 0.

    Returns:
        ``(padded_frames, labels, padding_mask)``.
    """
    sequences = [torch.as_tensor(sample).permute(1, 0) for sample, _ in batch]
    targets = [target for _, target in batch]

    padded = pad_sequence(sequences, batch_first=True, padding_value=-1.)  # N, T, C
    padding_mask = padded == -1
    padded = padded.clamp_min_(0)

    return padded, torch.as_tensor(targets), padding_mask

def SHD_dataloaders(config):
    """Build the SHD train and test DataLoaders.

    Seeds RNGs with ``set_seed(config.seed)`` and builds (or reuses) the cached
    frame files through ``BinnedSpikingHeidelbergDigits``. Both splits are wrapped
    in DataLoaders that use ``event_collate_fun``.

    Reads from ``config``:
        - ``seed``, ``time_step``, ``datasets_path``, ``n_bins`` and ``batch_size``.
        - Everything the integration function reads (``spatial_jitter``, ``noise``,
          ``drop_event``, ``max_drop_chunk``).
        - Optionally ``num_workers`` (default 4).

    NOTE: training augmentations run only when the frame cache is first built.
    The augmented train frames are therefore fixed across epochs and runs, and
    changing the augmentation settings does not invalidate the cache.

    Returns:
        ``(train_loader, test_loader)``.
    """
    set_seed(config.seed)

    # Named functions (called as f(h5_file, i, output_dir, W) by the dataset).
    def train_custom_integrate_function(h5_file, i, output_dir, W):
        return integrate_events_file_to_frames_file_by_fixed_duration_shd(
            config=config, augmentation=True, duration=config.time_step,
            h5_file=h5_file, output_dir=output_dir, i=i, W=W)

    def test_custom_integrate_function(h5_file, i, output_dir, W):
        return integrate_events_file_to_frames_file_by_fixed_duration_shd(
            config=config, augmentation=False, duration=config.time_step,
            h5_file=h5_file, output_dir=output_dir, i=i, W=W)

    # The cache directory is keyed by the integration window. A hard-coded name would
    # silently reuse frames built with a different config.time_step. With
    # time_step == 10 this is still "duration_10".
    frames_dir_name = f"duration_{config.time_step}"
    num_workers = getattr(config, "num_workers", 4)

    train_dataset = BinnedSpikingHeidelbergDigits(
        config.datasets_path, config.n_bins, train=True,
        train_custom_integrate_function=train_custom_integrate_function,
        test_custom_integrate_function=test_custom_integrate_function,
        custom_integrated_frames_dir_name=frames_dir_name)
    test_dataset = BinnedSpikingHeidelbergDigits(
        config.datasets_path, config.n_bins, train=False,
        custom_integrated_frames_dir_name=frames_dir_name)

    train_loader = DataLoader(train_dataset, collate_fn=event_collate_fun, batch_size=config.batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, collate_fn=event_collate_fun, batch_size=config.batch_size, num_workers=num_workers)

    return train_loader, test_loader

def integrate_events_segment_to_frame_shd(x: np.ndarray, W: int, j_l: int = 0, j_r: int = -1) -> np.ndarray:
    """Count events per channel in ``x[j_l:j_r]`` into a float vector of length ``W``.

    NOTE: with the default ``j_r=-1`` the slice excludes the last element of ``x``.
    Pass ``j_r=len(x)`` to include it. Channel indices must lie in ``[0, W)``.
    """
    # Widen to a plain int before counting (avoids overflow in narrow dtypes).
    segment = x[j_l:j_r].astype(int)
    histogram = np.bincount(segment)
    occupied = np.arange(len(histogram))
    frame = np.zeros(W)
    frame[occupied] += histogram
    return frame

def integrate_events_by_fixed_duration_shd(events: np.ndarray, duration: int, W: int) -> np.ndarray:
    """Integrate 1-D events into consecutive frames of fixed time ``duration``.

    Timestamps are multiplied by 1000 (presumably seconds to milliseconds; confirm
    against the dataset) and shifted so the first event sits at 0. Events in the
    window ``[k * duration, (k + 1) * duration)`` are counted per channel into frame
    ``k``. The last frame also receives events whose window index runs past it,
    i.e. events at exactly ``frames_num * duration``.

    Assumes ``events`` is sorted by ``'t'``: both the shift by ``t[0]`` and
    ``np.searchsorted`` rely on it.

    Parameters:
        events: structured array with fields ``'t'`` and ``'x'``.
        duration: window length, in the scaled time units.
        W: number of channels per frame.

    Returns:
        Float array of shape ``(frames_num, W)``. Empty input gives ``(0, W)``;
        any non-empty input gives at least one frame.
    """
    if len(events) == 0:
        # No events, no time span: return an empty stack instead of failing on t[0].
        return np.zeros([0, W])

    x = events['x']
    t = 1000 * events['t']
    t = t - t[0]
    n_events = len(t)

    # max(1, ...): if every event shares one timestamp, t[-1] == 0 would give zero
    # frames and frames[-1] below would raise.
    frames_num = max(1, int(math.ceil(t[-1] / duration)))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration

    left = 0
    for k in range(frames_num - 1):
        # First event belonging to window k + 1 (frame_index is sorted because t is).
        right = np.searchsorted(frame_index, k + 1, side='left')
        frames[k] = integrate_events_segment_to_frame_shd(x, W, left, right)
        left = right

    # The final frame takes every remaining event.
    frames[-1] = integrate_events_segment_to_frame_shd(x, W, left, n_events)
    return frames

def integrate_events_file_to_frames_file_by_fixed_duration_shd(config, augmentation, duration: int, h5_file: h5py.File, i: int, output_dir: str, W: int) -> int:
    """Integrate the ``i``-th sample of an SHD h5 file into frames and save them.

    Reads ``h5_file['spikes']['times'][i]``, ``h5_file['spikes']['units'][i]`` and
    ``h5_file['labels'][i]``. When ``augmentation`` is set, it applies Jitter1D,
    UniformNoise, DropEvent and DropEventChunk in that order. It then integrates the
    events with ``integrate_events_by_fixed_duration_shd`` and saves ``frames``
    with ``np_savez`` under ``output_dir/<label>/<i>``.

    Parameters:
        config: provides ``spatial_jitter``, ``noise``, ``drop_event`` and
            ``max_drop_chunk`` (read only when ``augmentation`` is set).
        augmentation: whether to apply the augmentation pipeline.
        duration: frame duration passed to the integration.
        h5_file: opened SHD h5 file.
        i: sample index inside ``h5_file``.
        output_dir: split directory containing one sub-directory per label.
        W: number of channels per frame.

    Returns:
        Number of frames written.
    """
    times = h5_file['spikes']['times'][i]
    units = h5_file['spikes']['units'][i]
    # Fill the structured array field by field instead of zipping per-event Python
    # tuples. This assumes times and units have equal length.
    events = np.empty(len(times), dtype=[('t', float), ('x', int)])
    events['t'] = times
    events['x'] = units
    label = h5_file['labels'][i]

    if augmentation:
        # NOTE: the DropEventChunk probability (0.3) is hard-coded, not read from config.
        trans = tonic.transforms.Compose([
            Jitter1D(sensor_size=700, var=config.spatial_jitter),
            UniformNoise(sensor_size=700, n=(0, config.noise)),
            DropEvent(p=config.drop_event),
            DropEventChunk(p=0.3, max_drop_size=config.max_drop_chunk),
        ])
        events = trans(events)

    fname = os.path.join(output_dir, str(label), str(i))
    frames = integrate_events_by_fixed_duration_shd(events, duration, W)
    np_savez(fname, frames=frames)
    return frames.shape[0]

class BinnedSpikingHeidelbergDigits(object):
    """Spiking Heidelberg Digits (SHD) dataset served as channel-binned frame stacks.

    The SHD dataset was proposed in `The Heidelberg Spiking Data Sets for the
    Systematic Evaluation of Spiking Neural Networks
    <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

    If ``<root>/<custom_integrated_frames_dir_name>`` does not exist, every sample of
    ``<root>/extract/shd_train.h5`` and ``shd_test.h5`` is passed to the matching
    integrate function, which writes per-sample frame files under
    ``{train,test}/<label>/``. Otherwise the existing directory is reused as-is,
    without any integrity check. Items are ``(binned_frames, label)`` pairs (see
    ``__getitem__``).

    Events in this dataset are ``(x, t)`` rather than ``(x, y, t, p)``. For that
    reason the class does not derive from
    ``spikingjelly.datasets.NeuromorphicDatasetFolder``, although the procedure is
    similar.
    """

    def __init__(
            self,
            root: str,
            n_bins: int,
            train: Optional[bool] = None,
            train_custom_integrate_function: Callable = None,
            test_custom_integrate_function: Callable = None,
            custom_integrated_frames_dir_name: str = 'frames',
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Parameters:
            root: dataset root containing ``download/`` and ``extract/``.
            n_bins: number of adjacent channels summed into one bin.
            train: load the ``train`` split if truthy, otherwise ``test``.
            train_custom_integrate_function, test_custom_integrate_function: called as
                ``f(h5_file, i, output_dir, W)`` for every sample of the matching split
                when the frames directory does not exist yet. Both are required then.
            custom_integrated_frames_dir_name: name of the frame cache directory
                under ``root``.
            transform: optional callable applied to the binned frames.
            target_transform: optional callable applied to the label.

        Raises:
            ValueError: if the frames directory must be built but an integrate
                function is missing.
        """
        self.n_bins = n_bins
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        download_root = os.path.join(root, 'download')
        extract_root = os.path.join(root, 'extract')
        self._prepare_extracted_files(download_root, extract_root)

        frames_np_root = os.path.join(root, custom_integrated_frames_dir_name)
        if os.path.exists(frames_np_root):
            print(f'The directory [{frames_np_root}] already exists.')
        else:
            self._build_frames_cache(extract_root, frames_np_root,
                                     train_custom_integrate_function,
                                     test_custom_integrate_function)

        self.frames_np_root = frames_np_root
        sub_dir = 'train' if self.train else 'test'
        self.frames_path, self.frames_label = self._index_frames(os.path.join(frames_np_root, sub_dir))
        self.length = len(self.frames_label)

    def _prepare_extracted_files(self, download_root: str, extract_root: str) -> None:
        """Ensure ``extract_root`` exists, downloading and extracting raw files if needed."""
        if os.path.exists(extract_root):
            print(f'The directory [{extract_root}] for saving extracted files already exists.\n'
                  f'SpikingJelly will not check the data integrity of extracted files.\n'
                  f'If extracted files are not integrated, please delete [{extract_root}] manually, '
                  f'then SpikingJelly will re-extract files from [{download_root}].')
            return

        # NOTE(review): resource_url_md5(), downloadable() and extract_downloaded_files()
        # are not defined on this class, whose only base is ``object``. Reaching this
        # point therefore raises AttributeError unless a subclass supplies them.
        # Putting the extracted shd_train.h5 / shd_test.h5 under ``extract_root``
        # skips this path.
        if os.path.exists(download_root):
            print(f'The [{download_root}] directory for saving downloaded files already exists, check files...')
            for file_name, url, md5 in self.resource_url_md5():
                fpath = os.path.join(download_root, file_name)
                if utils.check_integrity(fpath=fpath, md5=md5):
                    continue
                print(f'The file [{fpath}] does not exist or is corrupted.')

                if os.path.exists(fpath):
                    # A corrupted file is removed before it is downloaded again.
                    os.remove(fpath)
                    print(f'Remove [{fpath}]')

                if self.downloadable():
                    print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                    utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                else:
                    raise NotImplementedError(
                        f'This dataset can not be downloaded by SpikingJelly, please download [{file_name}] from [{url}] manually and put files at {download_root}.')
        else:
            os.mkdir(download_root)
            print(f'Mkdir [{download_root}] to save downloaded files.')
            resource_list = self.resource_url_md5()
            if not self.downloadable():
                raise NotImplementedError(f'This dataset can not be downloaded by SpikingJelly, '
                                          f'please download files manually and put files at [{download_root}]. '
                                          f'The resources file_name, url, and md5 are: \n{resource_list}')
            for file_name, url, md5 in resource_list:
                print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)

        os.mkdir(extract_root)
        print(f'Mkdir [{extract_root}].')
        self.extract_downloaded_files(download_root, extract_root)

    def _make_split_dirs(self, frames_np_root: str, split: str) -> str:
        """Create ``<frames_np_root>/<split>/<label>`` for every class; return the split dir."""
        split_root = os.path.join(frames_np_root, split)
        os.mkdir(split_root)
        print(f'Mkdir [{split_root}].')
        for label in range(self.classes_number()):
            class_dir = os.path.join(split_root, str(label))
            os.mkdir(class_dir)
            print(f'Mkdir [{class_dir}].')
        return split_root

    def _build_frames_cache(self, extract_root: str, frames_np_root: str,
                            train_integrate: Optional[Callable],
                            test_integrate: Optional[Callable]) -> None:
        """Create the frame cache tree and fill it from the raw train/test h5 files."""
        if train_integrate is None or test_integrate is None:
            # Check before creating directories. Otherwise an empty cache would be
            # left behind and silently reused by the next run.
            raise ValueError(
                f'[{frames_np_root}] does not exist yet; both train_custom_integrate_function '
                f'and test_custom_integrate_function are required to build it.')

        os.mkdir(frames_np_root)
        print(f'Mkdir [{frames_np_root}].')
        frames_np_train_root = self._make_split_dirs(frames_np_root, 'train')
        frames_np_test_root = self._make_split_dirs(frames_np_root, 'test')

        W = self.get_W()
        # use multi-thread to accelerate
        t_ckp = time.time()
        # The h5 files wrap the executor. Its shutdown waits for every task before
        # the files are closed.
        with h5py.File(os.path.join(extract_root, 'shd_train.h5'), 'r') as train_h5, \
                h5py.File(os.path.join(extract_root, 'shd_test.h5'), 'r') as test_h5:
            with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                sub_threads = []
                for i in range(len(train_h5['labels'])):
                    print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                    sub_threads.append(tpe.submit(train_integrate, train_h5, i, frames_np_train_root, W))
                for i in range(len(test_h5['labels'])):
                    print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                    sub_threads.append(tpe.submit(test_integrate, test_h5, i, frames_np_test_root, W))

                # result() waits for each task and re-raises its exception with the
                # full traceback.
                for sub_thread in sub_threads:
                    sub_thread.result()
        print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')

    def _index_frames(self, split_root: str):
        """Return ``(paths, labels)`` for every frame file under ``split_root/<label>/``."""
        paths, labels = [], []
        for label in range(self.classes_number()):
            class_dir = os.path.join(split_root, str(label))
            # sorted(): sample indices must not depend on the filesystem's listdir
            # order, or seeded shuffling would not be reproducible across machines.
            for fname in sorted(os.listdir(class_dir)):
                paths.append(os.path.join(class_dir, fname))
                labels.append(label)
        return paths, labels

    def classes_number(self):
        return 20

    def __getitem__(self, i: int):
        """Return ``(binned_frames, label)`` for sample ``i``.

        ``binned_frames`` is float32 with shape ``(n_channels // n_bins, n_frames)``.
        Entry ``[b, f]`` is the sum of channels ``b * n_bins`` to
        ``(b + 1) * n_bins - 1`` of frame ``f``. Channels beyond the last full bin
        are dropped. ``transform`` and ``target_transform`` are applied when set.
        """
        # The context manager closes the .npz file handle.
        with np.load(self.frames_path[i], allow_pickle=True) as data:
            frames = data['frames'].astype(np.float32)
        label = self.frames_label[i]

        n_frames, n_channels = frames.shape
        binned_len = n_channels // self.n_bins
        # Vectorised equivalent of summing frames[:, n_bins*b : n_bins*(b+1)] per bin b.
        binned = frames[:, :binned_len * self.n_bins] \
            .reshape(n_frames, binned_len, self.n_bins) \
            .sum(axis=2)
        binned_frames = np.ascontiguousarray(binned.T)

        if self.transform is not None:
            binned_frames = self.transform(binned_frames)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return binned_frames, label

    def __len__(self):
        return self.length

    def get_W(self):
        return 700

def build_transform(is_train, *, sample_rate=16000, window_size=256, hop_length=80,
                    n_mels=140, f_min=50, f_max=14000):
    """Build the waveform to log-mel-spectrogram pipeline.

    Steps, in order:
        1. ``augmentations.PadOrTruncate(sample_rate)``.
        2. Resample to ``sample_rate // 2``.
        3. Training only: ``RandomRoll`` along dim 1 and ``SpeedPerturbation``
           (rates 0.5 to 1.5, p=0.5).
        4. Power ``Spectrogram`` (``n_fft=window_size``).
        5. ``MelScale`` at the resampled rate.
        6. ``AmplitudeToDB``.

    The keyword-only arguments default to the values previously hard-coded here,
    so ``build_transform(is_train)`` builds the same pipeline.

    NOTE(review): after resampling, the Nyquist frequency is half the resampled rate
    (4000 Hz with the defaults), well below the default ``f_max`` of 14000 Hz. The
    upper mel filters therefore likely receive no energy. This is left unchanged to
    keep features reproducible; confirm whether it is intended.
    """
    resampled_rate = sample_rate // 2

    t = [augmentations.PadOrTruncate(sample_rate),
         Resample(sample_rate, resampled_rate)]
    if is_train:
        t.extend([augmentations.RandomRoll(dims=(1,)),
                  augmentations.SpeedPerturbation(rates=(0.5, 1.5), p=0.5)
                  ])

    t.append(Spectrogram(n_fft=window_size, hop_length=hop_length, power=2))

    t.extend([MelScale(n_mels=n_mels,
                       sample_rate=resampled_rate,
                       f_min=f_min,
                       f_max=f_max,
                       n_stft=window_size // 2 + 1),
              AmplitudeToDB()
              ])

    return transforms.Compose(t)

labels = ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']


def target_transform(word):
    """Map a keyword to its index in ``labels`` as a 0-d tensor.

    Raises ValueError (from ``list.index``) for words not in ``labels``. This is a
    named function rather than a lambda so that it can be pickled. DataLoader
    workers started with the ``spawn`` method need that to receive a dataset
    holding it.
    """
    return torch.tensor(labels.index(word))
