# Encoders
from SNN.Encoders import *
from SNN.transforms import *

from Datasets import *

from AbstractModels.util import *

import torch

from norse.torch.module import encode

from torchvision import datasets as torchvision_datasets
from torchvision.transforms import v2 as transforms

import tonic

# Registry of input encoders, looked up by their configuration name.
encoders = dict(
    copy=CopyEncoder,
    identity=IdentityEncoder,
)

# Registry of output decoders, looked up by their configuration name.
decoders = dict(mean=decode_mean)

# Registry of frame-based image datasets, keyed by dataset name.
# 'cifar10' and 'imagenet' resolve to the project-local wrappers; the rest
# come straight from torchvision.
datasets = dict(
    fmnist=torchvision_datasets.FashionMNIST,
    mnist=torchvision_datasets.MNIST,
    cifar10=CIFAR10,
    cifar100=torchvision_datasets.CIFAR100,
    imagenet=ImageNet,
)

# Registry of event-based (neuromorphic) datasets, keyed by dataset name.
neuromorphic_datasets = dict(cifar10dvs=CIFAR10DVS)

# Number of target classes for every selectable dataset name.
# Every name that can be chosen as a dataset needs an entry here; 'mnist'
# (10 digit classes) was previously missing, so looking it up raised KeyError.
dataset_classes = {
    'fmnist': 10,
    'mnist': 10,
    'cifar10': 10,
    'cifar100': 100,
    'imagenet': 1000,
    'cifar10dvs': 10,
}

# Per-dataset batch collation callables, keyed by dataset name.
# NOTE(review): None presumably means "use the data loader's default
# collation" -- confirm against the code that consumes this table.
collate_fns = {
    'fmnist': None,
    # 'mnist' is a selectable image dataset and previously had no entry here;
    # it is treated like the other image datasets.
    'mnist': None,
    # Uses tonic's padding collator with batch_first=False.
    'cifar10dvs': tonic.collation.PadTensors(batch_first=False),
    'cifar10': None,
    'cifar100': None,
    'imagenet': None,
}

# Dataset transformation rules
# If the dataset name is added 'as is', the transform is applied to both the training and validation sets.
# If the dataset name is added with a '_train' or '_val' suffix, the transform is applied only to the training or validation set, respectively.
dataset_transforms = {
    # Fashion-MNIST: single-channel images, converted to float and normalized.
    'fmnist_train': transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=(0.2860,), std=(0.3205,))
    ]),
    'fmnist_val': transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=(0.2860,), std=(0.3205,))
    ]),
    # MNIST: previously had no transform entries at all. The pipeline mirrors
    # Fashion-MNIST; (0.1307, 0.3081) are the commonly used MNIST
    # training-set mean and standard deviation.
    'mnist_train': transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    ]),
    'mnist_val': transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    ]),
    # CIFAR-10: training adds crop / flip / cutout augmentation.
    # NOTE(review): Cutout runs before ToImage/ToDtype here -- confirm that the
    # project's Cutout accepts the image type the dataset yields at this point.
    'cifar10_train': transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        Cutout(n_holes=1, length=16),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ]),
    'cifar10_val': transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    ]),
    # CIFAR-100: same as CIFAR-10 plus AutoAugment (using the CIFAR10 policy).
    # NOTE(review): as above, Cutout runs before ToImage/ToDtype -- confirm.
    'cifar100_train': transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        Cutout(n_holes=1, length=16),
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]),
    ]),
    'cifar100_val': transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
    ]),
    # ImageNet: random 224 crop for training; resize to 256 then center-crop
    # to 224 for validation.
    'imagenet_train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    'imagenet_val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToImage(),
        transforms.ToDtype(torch.float, scale=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    # Neuromorphic Datasets
    # These are just placeholders now, real transforms are incorporated directly into the dataset class
    'cifar10dvs_train': transforms.Compose([
        torch.as_tensor,
        transforms.Resize((48, 48), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToDtype(torch.float, scale=True)
    ]),
    'cifar10dvs_val': transforms.Compose([
        torch.as_tensor,
        transforms.Resize((48, 48), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToDtype(torch.float, scale=True)
    ])
}

# Registry of optimizer classes, looked up by their configuration name.
optimizers = dict(
    adam=torch.optim.Adam,
    adamw=torch.optim.AdamW,
    sgd=torch.optim.SGD,
)

# Registry of learning-rate scheduler classes, looked up by configuration name.
# The string key 'None' maps to the None object (no scheduler class).
schedulers = dict(
    [
        ('None', None),
        ('cosine', torch.optim.lr_scheduler.CosineAnnealingLR),
    ]
)

# Registry of loss classes, looked up by their configuration name.
losses = dict(crossentropy=torch.nn.CrossEntropyLoss)

# Names of the selectable surrogate-gradient shapes.
surrogate_gradients = ["triangle", "rectangle"]
