import torch
import torchvision.transforms as transforms

from typing import List, Tuple, Optional, Dict
# from torchvision.transforms import InterpolationMode
# from torchvision.transforms import functional as F 
from torchvision.transforms.autoaugment import _apply_op
from torchvision.transforms import functional as F, InterpolationMode

from torchvision.transforms import TrivialAugmentWide


from torch import Tensor


class MyTrivialAugmentWide(TrivialAugmentWide):
    """``TrivialAugmentWide`` restricted to a smaller set of operations.

    Only ``_augmentation_space`` changes behaviour; the constructor just
    forwards its arguments. The space offered here is Identity,
    ShearX/ShearY, TranslateX/TranslateY, Rotate, Brightness, Color,
    Contrast and AutoContrast. Sharpness, Posterize, Solarize and Equalize
    are intentionally left out.

    NOTE(review): presumably Posterize and Solarize were dropped because the
    inputs are float tensors (Posterize's 8-bit counts and Solarize's 255.0
    threshold assume 0-255 values), and likewise Equalize; the reason for
    dropping Sharpness is not visible here — confirm.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        """Forward every argument unchanged to ``TrivialAugmentWide``.

        Args:
            num_magnitude_bins: Number of magnitude bins (31 by default).
            interpolation: Interpolation mode handed to the base class.
            fill: Optional fill values handed to the base class.
        """
        super().__init__(
            num_magnitude_bins=num_magnitude_bins,
            interpolation=interpolation,
            fill=fill,
        )

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Return ``{op_name: (magnitudes, signed)}`` for the allowed ops.

        Args:
            num_bins: Length of each magnitude ramp.

        Returns:
            Identity and AutoContrast map to a 0-d zero tensor with
            ``signed=False``. Every other op maps to ``num_bins`` evenly
            spaced magnitudes from 0.0 up to that op's upper bound, with
            ``signed=True``.
        """
        # (op_name, upper bound of its magnitude ramp). These are listed in
        # the exact order they appear in the returned dict; the order is kept
        # stable because the base class may select ops by position.
        ramped_ops = (
            ("ShearX", 0.99),
            ("ShearY", 0.99),
            ("TranslateX", 32.0),
            ("TranslateY", 32.0),
            ("Rotate", 135.0),
            ("Brightness", 0.99),
            ("Color", 0.99),
            ("Contrast", 0.99),
        )

        space: Dict[str, Tuple[Tensor, bool]] = {
            "Identity": (torch.tensor(0.0), False),
        }
        for op_name, upper in ramped_ops:
            # A fresh tensor per op, so no two entries share storage.
            space[op_name] = (torch.linspace(0.0, upper, num_bins), True)
        space["AutoContrast"] = (torch.tensor(0.0), False)
        return space


def cifar10_complex_transform():
    """Build the "complex" CIFAR-10 augmentation pipeline.

    Returns:
        A ``transforms.Compose`` whose single step is a
        ``MyTrivialAugmentWide`` built with its default arguments
        (31 magnitude bins, nearest interpolation, no fill).
    """
    # Augmentation strength is deliberately left at the class defaults.
    augment = MyTrivialAugmentWide()
    steps = [augment]
    return transforms.Compose(steps)
