from collections.abc import Sequence
import torch
from PIL import Image

try:
    import accimage
except ImportError:
    accimage = None

from torchvision.transforms import functional as F

# Fully qualified display names for the PIL resampling filters, keyed by the
# filter constant; used to render ``interpolation`` in ``__repr__``.
_pil_interpolation_to_str = dict([
    (Image.NEAREST, 'PIL.Image.NEAREST'),
    (Image.BILINEAR, 'PIL.Image.BILINEAR'),
    (Image.BICUBIC, 'PIL.Image.BICUBIC'),
    (Image.LANCZOS, 'PIL.Image.LANCZOS'),
    (Image.HAMMING, 'PIL.Image.HAMMING'),
    (Image.BOX, 'PIL.Image.BOX'),
])

class RandomResize(torch.nn.Module):
    """Resize the input image so that its shorter side has a randomly sampled length.

    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    On every call a new target length is drawn uniformly from the closed integer
    interval ``[low, high]`` given by ``size_range`` and passed to ``F.resize`` as an
    int, for scale augmentation as in the ResNet paper (https://arxiv.org/pdf/1512.03385.pdf).

    Args:
        size_range (int or sequence): Inclusive ``(min_size, max_size)`` bounds for the
            sampled shorter side. An int ``s`` or a one-element sequence ``[s]`` is
            treated as the fixed range ``[s, s]``.
        interpolation (int, optional): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
            and ``PIL.Image.BICUBIC`` are supported.

    Raises:
        TypeError: If ``size_range`` is neither an int nor a sequence.
        ValueError: If ``size_range`` is a sequence whose length is not 1 or 2, or
            whose lower bound exceeds its upper bound.

    Attributes:
        resize_to (int or None): Length sampled by the most recent ``forward`` call;
            ``None`` until ``forward`` has been called.
    """

    def __init__(self, size_range, interpolation=Image.BILINEAR):
        super().__init__()
        if not isinstance(size_range, (int, Sequence)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size_range)))
        if isinstance(size_range, Sequence) and len(size_range) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        low, high = self._bounds(size_range)
        # Reject reversed ranges up front instead of sampling from a nonsensical interval.
        if low > high:
            raise ValueError("size_range lower bound {} exceeds upper bound {}".format(low, high))
        self.size_range = size_range
        self.interpolation = interpolation
        self.resize_to = None

    @staticmethod
    def _bounds(size_range):
        """Return the inclusive ``(low, high)`` bounds described by ``size_range``."""
        if isinstance(size_range, int):
            return size_range, size_range
        # size_range[-1] is the upper bound of a 2-sequence and the single value of a 1-sequence.
        return size_range[0], size_range[-1]

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        low, high = self._bounds(self.size_range)
        # torch.rand() lies in [0, 1), so scaling by the count of integers in
        # [low, high] and truncating gives an offset in [0, high - low]; the old
        # (high - low) factor made ``high`` itself unreachable.
        offset = int(float(torch.rand(1)) * (high - low + 1))
        self.resize_to = low + offset
        return F.resize(img, self.resize_to, self.interpolation)

    def __repr__(self):
        # Fall back to str() for interpolation values missing from the PIL table
        # instead of raising KeyError.
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, str(self.interpolation))
        return self.__class__.__name__ + '(size_range={0}, interpolation={1})'.format(self.size_range, interpolate_str)


# Lighting data augmentation taken from https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py
class Lighting:
    """Lighting noise (AlexNet-style, PCA-based colour augmentation).

    Each call draws three coefficients ``alpha ~ N(0, alphastd)`` and adds to every
    pixel the per-channel offset ``rgb[i] = sum_j eigvec[i, j] * alpha[j] * eigval[j]``.

    Args:
        alphastd (float): Standard deviation of the random coefficients. ``0`` disables
            the augmentation (the input is returned unchanged).
        eigval (Tensor): 3 eigenvalues (reshaped with ``view(1, 3)``).
        eigvec (Tensor): ``(3, 3)`` matrix whose column ``j`` is scaled by
            ``alpha[j] * eigval[j]``.
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        """Return ``img`` plus a random per-channel lighting offset.

        Args:
            img (Tensor): Floating-point image whose last three dimensions are
                ``(3, H, W)``; leading batch dimensions are allowed.

        Returns:
            Tensor: A new tensor with the same shape as ``img`` (``img`` itself is not
            modified), or ``img`` unchanged when ``alphastd == 0``.
        """
        if self.alphastd == 0:
            return img

        # Draw the coefficients with the image's dtype/device.
        alpha = img.new_empty((3,)).normal_(0, self.alphastd)
        # Convert BOTH PCA tensors to the image's dtype/device; previously only
        # eigvec was converted, so a CPU eigval broke inputs on another device.
        eigvec = self.eigvec.to(img)
        eigval = self.eigval.to(img)
        # The (1, 3) rows broadcast over the (3, 3) matrix, scaling column j by
        # alpha[j] * eigval[j]; summing over columns gives one offset per channel.
        rgb = (eigvec * alpha.view(1, 3) * eigval.view(1, 3)).sum(1)
        # expand_as keeps the output shape identical to img's.
        return img.add(rgb.view(3, 1, 1).expand_as(img))