# Encoders

from SNN.Encoders import *
from SNN.transforms import *

from Datasets import *

from AbstractModels.util import *

from timm.optim import optim_factory

import torch

from torchvision import datasets as torchvision_datasets
from torchvision.transforms import v2 as transforms

import tonic
import tonic.transforms as tonic_transforms

# Registry of input encoders, looked up by their configuration name.
# The encoder objects come from the wildcard imports at the top of this
# module (presumably SNN.Encoders -- confirm against that package).
encoders = {
    "copy": CopyEncoder,
    "identity": IdentityEncoder,
}

# Registry of output decoders, looked up by their configuration name.
# Every value is one of the decode_* objects pulled in by the wildcard
# imports at the top of this module.
decoders = {
    "mean": decode_mean,
    "max": decode_max,
    "last": decode_last,
    "first": decode_first,
    "sum": decode_sum,
    "same": decode_same,
    "identity": decode_identity,
}

# Registry of dataset constructors for the non-neuromorphic datasets.
# FashionMNIST is taken straight from torchvision; the remaining entries
# resolve to whatever the wildcard imports above provide under those names.
datasets = {
    "fmnist": torchvision_datasets.FashionMNIST,
    "cifar10": CIFAR10,
    "cifar100": CIFAR100,
    "tinyimagenet": TinyImageNet,
}

# Registry of dataset constructors for the neuromorphic (event-based)
# datasets, kept separate from `datasets` so callers can tell the two
# families apart by which table contains the key.
neuromorphic_datasets = {
    "dvsgesture": DVS128Gesture,
    "cifar10dvs": CIFAR10DVS,
    "ncaltech101": NCaltech101,
    "ncars": NCars,
}

# Number of target classes for every dataset key used in `datasets` and
# `neuromorphic_datasets`.
dataset_classes = {
    # Frame-based datasets
    "fmnist": 10,
    "cifar10": 10,
    "cifar100": 100,
    "tinyimagenet": 200,
    # Neuromorphic datasets
    "dvsgesture": 11,
    "cifar10dvs": 10,
    "ncaltech101": 101,
    "ncars": 2,
}

# Per-dataset batch collate function.
# None means no custom collate function is registered for that dataset.
# The neuromorphic datasets each get their own tonic PadTensors instance
# configured with batch_first=False.
collate_fns = {
    "fmnist": None,
    "cifar10dvs": tonic.collation.PadTensors(batch_first=False),
    "ncaltech101": tonic.collation.PadTensors(batch_first=False),
    "ncars": tonic.collation.PadTensors(batch_first=False),
    "dvsgesture": tonic.collation.PadTensors(batch_first=False),
    "cifar10": None,
    "cifar100": None,
    "tinyimagenet": None,
}

# Dataset transformation rules.
#
# Key convention:
#   * '<dataset>'        -> the transform is applied to both the training and
#                           the validation set.
#   * '<dataset>_train'  -> the transform is applied to the training set only.
#   * '<dataset>_val'    -> the transform is applied to the validation set only.
dataset_transforms = {
    # ------------------------------------------------------------------
    # Fashion-MNIST: one pipeline shared by both splits.
    # ------------------------------------------------------------------
    "fmnist": transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=(0.2860,), std=(0.3205,)),
        ]
    ),
    # ------------------------------------------------------------------
    # CIFAR-10: both splits are resized to 64x64; only training is augmented.
    # ------------------------------------------------------------------
    "cifar10_train": transforms.Compose(
        [
            transforms.Resize((64, 64)),
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            Cutout(n_holes=1, length=16),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010],
            ),
        ]
    ),
    "cifar10_val": transforms.Compose(
        [
            transforms.Resize((64, 64)),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010],
            ),
        ]
    ),
    # ------------------------------------------------------------------
    # CIFAR-100: kept at 32x32; training uses the CIFAR10 AutoAugment policy.
    # ------------------------------------------------------------------
    "cifar100_train": transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            Cutout(n_holes=1, length=16),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.5071, 0.4867, 0.4408],
                std=[0.2675, 0.2565, 0.2761],
            ),
        ]
    ),
    "cifar100_val": transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.5071, 0.4867, 0.4408],
                std=[0.2675, 0.2565, 0.2761],
            ),
        ]
    ),
    # ------------------------------------------------------------------
    # Neuromorphic datasets
    # ------------------------------------------------------------------
    # CIFAR10-DVS: samples are converted to tensors and resized to 48x48.
    # NOTE: roll(interval=10) is currently disabled for the training split.
    "cifar10dvs_train": transforms.Compose(
        [
            torch.as_tensor,
            transforms.Resize((48, 48)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    "cifar10dvs_val": transforms.Compose(
        [
            torch.as_tensor,
            transforms.Resize((48, 48)),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    # N-Caltech101: resized to 48x48; training adds low-intensity NDA.
    "ncaltech101_train": transforms.Compose(
        [
            torch.as_tensor,
            transforms.Resize((48, 48)),
            NDA(intensity=NDA.Intensity.LOW),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    "ncaltech101_val": transforms.Compose(
        [
            torch.as_tensor,
            transforms.Resize((48, 48)),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    # DVS128 Gesture: these two pipelines use tonic's Compose, not torchvision's.
    # NOTE: a nearest-neighbour Resize((64, 64)) step is currently disabled
    # in both splits.
    "dvsgesture_train": tonic_transforms.Compose(
        [
            torch.from_numpy,
            roll(interval=5),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    "dvsgesture_val": tonic_transforms.Compose(
        [
            torch.from_numpy,
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    # N-CARS: events are downsampled to 48x48 and binned into 10 frames;
    # training adds low-intensity NDA.
    "ncars_train": transforms.Compose(
        [
            tonic_transforms.Downsample(
                sensor_size=NCars.sensor_size,
                target_size=(48, 48),
            ),
            tonic_transforms.ToFrame(sensor_size=(48, 48, 2), n_time_bins=10),
            NDA(intensity=NDA.Intensity.LOW),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    "ncars_val": transforms.Compose(
        [
            tonic_transforms.Downsample(
                sensor_size=NCars.sensor_size,
                target_size=(48, 48),
            ),
            tonic_transforms.ToFrame(sensor_size=(48, 48, 2), n_time_bins=10),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    ),
    # ------------------------------------------------------------------
    # Tiny ImageNet (listed last to keep the existing key order).
    # ------------------------------------------------------------------
    "tinyimagenet_train": transforms.Compose(
        [
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            Cutout(n_holes=1, length=16),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.480, 0.448, 0.398],
                std=[0.229, 0.226, 0.225],
            ),
        ]
    ),
    "tinyimagenet_val": transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.480, 0.448, 0.398],
                std=[0.229, 0.226, 0.225],
            ),
        ]
    ),
}

# Registry of optimizer classes, looked up by their configuration name.
# Adam, AdamW and SGD come from torch.optim; Lamb is taken from timm's
# optim_factory module.
optimizers = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
    "sgd": torch.optim.SGD,
    "lamb": optim_factory.Lamb,
}

# Registry of learning-rate scheduler classes, looked up by name.
# The string key 'None' (not the None object) maps to None, i.e. no
# scheduler class is registered for it.
schedulers = {
    "None": None,
    "cosine": torch.optim.lr_scheduler.CosineAnnealingLR,
    "step": torch.optim.lr_scheduler.StepLR,
    "multistep": torch.optim.lr_scheduler.MultiStepLR,
}

# Registry of loss classes, looked up by their configuration name.
losses = {
    "crossentropy": torch.nn.CrossEntropyLoss,
}

# Names of the surrogate-gradient functions accepted in configuration.
surrogate_gradients = [
    "heaviside",
    "super",
    "triangle",
    "tanh",
    "circ",
    "heavi_erfc",
    "rectangle",
    "asym_rectangle",
]