from enum import Enum

from torchvision.transforms import v2 as transforms

import torch

import numpy as np
import random

from .Cutout import Cutout

class NDA(transforms.Transform):
    """
    Neuromorphic Data Augmentation (NDA) policy described in the work
    [1] Y. Li, Y. Kim, H. Park, T. Geller, and P. Panda, “Neuromorphic Data Augmentation for Training Spiking Neural Networks,”
    Jul. 20, 2022, arXiv: arXiv:2203.06145. Accessed: Jul. 23, 2024. [Online]. Available: http://arxiv.org/abs/2203.06145

    Original implementation taken from
    https://github.com/Intelligent-Computing-Lab-Yale/NDA_SNN/blob/main/functions/data_loaders.py
    with modifications to fit the PyTorch Transform API.

    Every call first applies ``RandomHorizontalFlip`` and then exactly one
    augmentation drawn uniformly from ``self.choices``:

    * ``'roll'``   - circular shift of the last two dims by independent integer
      offsets in ``[-intensity[0], intensity[0]]``;
    * ``'rotate'`` - ``RandomRotation(degrees=intensity[1])``;
    * ``'shear'``  - ``RandomAffine(degrees=0, shear=(-intensity[3], intensity[3]))``;
    * ``'cutout'`` - ``Cutout(intensity[2])`` (see the ``Cutout`` module for its
      exact semantics).

    The augmentation choice and the roll offsets are drawn from torch's global
    RNG, so they follow ``torch.manual_seed``.
    """

    class Intensity(Enum):
        """Preset magnitudes, each stored as ``[roll, rotate, cutout, shear]``."""
        LOW = [3, 15, 8, 15]  # [roll, rotate, cutout, shear]
        MEDIUM = [5, 30, 16, 30]  # [roll, rotate, cutout, shear]
        HIGH = [7, 45, 24, 45]  # [roll, rotate, cutout, shear]

    def __init__(self, intensity: Intensity = Intensity.LOW):
        """
        Args:
            intensity: ``NDA.Intensity`` preset selecting the magnitude of
                every augmentation (roll pixels, rotate degrees, cutout
                magnitude, shear degrees).
        """
        super().__init__()
        # Copy the preset: Enum values are shared lists, and aliasing them
        # would let a mutation of ``self.intensity`` change the preset globally.
        self.intensity = list(intensity.value)
        self.choices = ['roll', 'rotate', 'shear', 'cutout']

        self.flip = transforms.RandomHorizontalFlip()

        self.rotate = transforms.RandomRotation(degrees=self.intensity[1])
        self.cutout = Cutout(self.intensity[2])
        # A 2-element shear range shears along the x axis only.
        self.shearx = transforms.RandomAffine(degrees=0, shear=(-self.intensity[3], self.intensity[3]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply a random flip followed by one randomly chosen NDA augmentation.

        Args:
            x: tensor, or anything ``torch.as_tensor`` accepts. Spatial dims are
               expected last (``..., H, W``), matching the torchvision tensor
               convention used by the flip/rotate/affine sub-transforms.

        Returns:
            The augmented tensor.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x)
        x = self.flip(x)

        # Draw the augmentation from torch's RNG (not NumPy) so that
        # torch.manual_seed governs it, like the torchvision sub-transforms.
        aug = self.choices[torch.randint(len(self.choices), (1,)).item()]
        if aug == 'roll':
            max_shift = self.intensity[0]
            # Inclusive range [-max_shift, max_shift], same as random.randint.
            off1, off2 = torch.randint(-max_shift, max_shift + 1, (2,)).tolist()
            # Negative dims always address (H, W), whatever the number of
            # leading (time/channel/batch) dims.
            x = torch.roll(x, shifts=(off1, off2), dims=(-2, -1))
        elif aug == 'rotate':
            x = self.rotate(x)
        elif aug == 'shear':
            x = self.shearx(x)
        elif aug == 'cutout':
            x = self.cutout(x)

        return x
