from numbers import Real

import torch
from torchvision.transforms import RandomVerticalFlip, RandomHorizontalFlip,\
    Compose

from mnist_auto_aug.transforms import OneEightyRotation, NinetyRotation,\
    Identity


# Registry of the transform names accepted in (transform_name, probability)
# pairs, mapped to the class that implements each one. pair_to_transform()
# looks the name up here and instantiates the class with ``p=probability``,
# so every value must be callable with a ``p`` keyword argument.
KEY_TO_TRANSFORM = {
    'no-aug': Identity,
    'hflip': RandomHorizontalFlip,
    'vflip': RandomVerticalFlip,
    '180rot': OneEightyRotation,
    '90rot': NinetyRotation,
}


class ClasswiseTransform(torch.nn.Module):
    """Apply a target-dependent transform to an image.

    Parameters
    ----------
    transforms : indexable collection of callables
        ``transforms[target]`` must be a callable accepting the image. The
        collection is stored unchanged in ``transform_per_class``.
    """

    def __init__(self, transforms):
        super().__init__()
        self.transform_per_class = transforms

    def forward(self, img, target):
        """Return ``img`` passed through the transform selected by ``target``.

        ``target`` is used directly as the index/key into
        ``transform_per_class``.
        """
        selected = self.transform_per_class[target]
        return selected(img)


def pair_to_transform(pair):
    """Map a ``(transform_name, probability)`` pair to a transform instance.

    Parameters
    ----------
    pair : sequence of length 2
        ``(transform_name, probability)``, where ``transform_name`` is a str
        key of ``KEY_TO_TRANSFORM`` and ``probability`` is a real number
        between 0 and 1 (inclusive).

    Returns
    -------
    object
        ``KEY_TO_TRANSFORM[transform_name](p=probability)``.

    Raises
    ------
    AssertionError
        If ``pair`` does not have exactly two elements, if the first element
        is not a str or not a key of ``KEY_TO_TRANSFORM``, or if the second
        element is not a real number in ``[0, 1]``.
    """
    # The checks raise AssertionError explicitly instead of using `assert`:
    # the exception type seen by callers is unchanged, but the validation is
    # no longer stripped when Python runs with -O.

    # Unfold the pair (operation, probability)
    if len(pair) != 2:
        raise AssertionError(f"pair should be of length 2. Got {len(pair)}")
    transform_key, probability = pair

    # Check its consistency
    if not isinstance(transform_key, str):
        raise AssertionError(
            "First element of pair should be a string. "
            f"Got {pair}."
        )
    if transform_key not in KEY_TO_TRANSFORM:
        raise AssertionError(
            "transforms must be one or more among "
            f"{list(KEY_TO_TRANSFORM)}. Got {transform_key}."
        )
    if not (isinstance(probability, Real) and 0 <= probability <= 1):
        raise AssertionError(
            f"probability must be a float between 0 and 1. Got {probability}."
        )

    # Create and return instantiated Transform object
    return KEY_TO_TRANSFORM[transform_key](p=probability)


def make_transform_objects(transform_tuples, compose=True):
    """Build transform object(s) from ``(transform_name, probability)`` specs.

    Parameters
    ----------
    transform_tuples : None, tuple or list of tuples
        * ``None``: an ``Identity(1.0)`` transform is returned.
        * a tuple of length 2: treated as a single
          ``(transform_name, probability)`` pair and converted with
          ``pair_to_transform``.
        * a list: every element is converted with ``pair_to_transform``.
    compose : boolean, optional
        Only used when ``transform_tuples`` is a list. If True, the resulting
        transforms are wrapped in a ``Compose``; otherwise the plain list of
        transforms is returned. Defaults to True.

    Returns
    -------
    A single transform, a ``Compose`` of transforms, or a list of transforms,
    depending on the input as described above.

    Raises
    ------
    ValueError
        If ``transform_tuples`` is neither None, a 2-element tuple, nor a
        list.
    """
    # Nothing requested: fall back to the no-op transform.
    if transform_tuples is None:
        return Identity(1.0)

    # A lone (name, probability) pair yields a single transform.
    is_single_pair = (
        isinstance(transform_tuples, tuple) and len(transform_tuples) == 2
    )
    if is_single_pair:
        return pair_to_transform(transform_tuples)

    # Anything that is not a list at this point is rejected.
    if not isinstance(transform_tuples, list):
        raise ValueError(
            'When not omitted, transform_tuples should be a tuple or list of '
            'tuples of the form (transform_name, probability) and '
            'type (str, float).'
        )

    built = [pair_to_transform(spec) for spec in transform_tuples]
    return Compose(built) if compose else built


# Hand-written class-wise policy. ClasswiseTransform.forward indexes this list
# with the sample's target, so entry i is the transform applied to images whose
# target is i. There are ten entries (targets 0-9) -- presumably the MNIST
# digit classes; confirm against the dataset used by the caller.
# NOTE(review): the positional argument is assumed to be the probability ``p``
# (pair_to_transform passes it as ``p=``); verify for the project-local
# Identity.
intuitive_policy_1 = ClasswiseTransform(
    [
        RandomVerticalFlip(0.5),  # target 0
        RandomVerticalFlip(0.5),  # target 1
        Identity(1),  # target 2
        RandomVerticalFlip(0.5),  # target 3
        Identity(1.),  # target 4
        Identity(1.),  # target 5
        Identity(1.),  # target 6
        Identity(1.),  # target 7
        RandomHorizontalFlip(0.5),  # target 8
        Identity(1.),  # target 9
    ]
)
