import inspect
from typing import List
from typing import Union

import kornia
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from kornia.augmentation.container import AugmentationSequential


def torchvision_to_kornia(torchvision_transforms: List[nn.Module]) -> AugmentationSequential:
    """Wrap a list of transforms in a kornia ``AugmentationSequential``.

    Each transform whose class name also exists as a member of
    ``kornia.augmentation`` is replaced by an instance of that kornia member.
    The replacement is constructed from those entries of the transform's
    instance ``__dict__`` whose names match an argument name of the kornia
    constructor. Transforms without a same-named kornia member are passed
    through unchanged.

    Args:
        torchvision_transforms: Transform modules, in application order.

    Returns:
        An ``AugmentationSequential`` holding the (possibly replaced)
        transforms in the same order.
    """
    # Name -> object table of everything exposed by kornia.augmentation.
    available = dict(inspect.getmembers(kornia.augmentation))

    def tv_to_kornia(tv: nn.Module) -> nn.Module:
        name = type(tv).__name__
        if name not in available:
            # No kornia counterpart with this name: keep the transform as is.
            return tv
        kornia_clz = available[name]
        accepted = inspect.getfullargspec(kornia_clz).args
        # Only forward attributes that the kornia constructor knows by name.
        shared = set(tv.__dict__).intersection(accepted)
        kwargs = {key: tv.__dict__[key] for key in shared}
        return kornia_clz(**kwargs)

    converted = [tv_to_kornia(tv) for tv in torchvision_transforms]
    return AugmentationSequential(*converted)


class GlobalContrastNormalization(torch.nn.Module):
    """Apply global contrast normalization to a tensor.

    The mean over *all* elements of the input is subtracted and the result is
    divided by a scale computed from the centered values:

    * ``'l1'``: the mean absolute value of the centered tensor.
    * ``'l2'``: the L2-norm of the centered tensor divided by the number of
      elements.

    The statistics are taken over the whole tensor that is passed in, so to
    obtain a *per sample* normalization (and not one across a batch or the
    dataset) the transform has to be applied to one sample at a time.

    Note that an input whose elements are all equal has a scale of zero and
    therefore yields non-finite values.
    """

    def __init__(self, scale='l2'):
        """
        Args:
            scale: Which scale to divide by, either ``'l1'`` or ``'l2'``.
        """
        super().__init__()
        assert scale in ('l1', 'l2')
        self.scale = scale

    def __call__(self, x: "Union[torch.Tensor, Image.Image]") -> torch.Tensor:
        """Return a normalized copy of ``x``; the input is left untouched.

        Args:
            x: The tensor to normalize. PIL images are not supported yet.

        Returns:
            A new tensor of the same shape as ``x``.
        """
        assert isinstance(x, torch.Tensor), 'No PIL-Images supported yet.'
        # Out-of-place arithmetic: never modify the caller's tensor (in-place
        # ops would also fail on leaf tensors that require grad).
        centered = x - torch.mean(x)  # mean over all features (pixels)
        if self.scale == 'l1':
            x_scale = torch.mean(torch.abs(centered))
        else:  # 'l2' -- the only other value accepted by __init__
            x_scale = torch.sqrt(torch.sum(centered ** 2)) / centered.numel()
        return centered / x_scale


class RandomGaussianNoise(torch.nn.Module):
    """Apply additive white Gaussian noise to randomly chosen samples.

    The first dimension of the input is treated as the sample (batch)
    dimension: for every index along it, noise is added with probability
    ``p``; the remaining samples are returned without noise.
    """

    def __init__(self, std: float, p: float = 0.5):
        """
        Args:
            std: Standard deviation of the zero-mean Gaussian noise.
            p: Probability with which each sample receives noise.
        """
        super().__init__()
        self.std = std
        self.p = p

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus noise on the randomly selected samples.

        Args:
            x: Tensor with at least one dimension; dimension 0 indexes samples.

        Returns:
            A new tensor of the same shape as ``x``.
        """
        assert isinstance(x, torch.Tensor), 'No PIL-Images supported yet.'
        # One Bernoulli(p) draw per sample, created on x's device so that the
        # mask can be multiplied with the noise without a device mismatch.
        selected = torch.rand(size=(x.size(0), ), device=x.device) < self.p
        # Shape (B, 1, ..., 1) so the mask broadcasts over all non-batch dims.
        mask = selected.reshape((x.size(0), ) + (1, ) * (x.dim() - 1))
        return x + self.std * torch.randn_like(x) * mask

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p}, std={self.std})"
