import numbers
import random

import numpy as np
import torch
from PIL import Image
from torchvision import transforms


class ResizeImage:
    """Resize an image to a fixed size via its ``resize`` method.

    Args:
        size: Target size. An ``int`` is expanded to the square pair
            ``(size, size)``; any other value is stored unchanged and must
            unpack into exactly two items when the transform is called.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, int) else size

    def __call__(self, img):
        """Return ``img.resize((a, b))`` where ``a, b`` are the stored pair.

        The two values are forwarded in the order they were given. Note that
        PIL's ``Image.resize`` reads such a pair as ``(width, height)``.

        Args:
            img: Object exposing ``resize(size_pair)`` (e.g. a PIL image).
        Returns:
            Whatever ``img.resize`` returns.
        """
        first, second = self.size
        target = (first, second)
        return img.resize(target)


class RandomSizedCrop:
    """Take a square crop of fixed side length at a random position.

    Despite the name, the crop size is fixed; only the location is random.
    The input is indexed as ``img[:, rows, cols]``, so it is expected to be an
    array/tensor whose axes 1 and 2 are the spatial ones.

    Args:
        size: Side length (in pixels) of the square crop.
        interpolation: Stored on the instance but not used by ``__call__``.
            Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.interpolation = interpolation
        self.size = size

    def __call__(self, img):
        """Crop ``img`` to ``size`` x ``size`` at a uniformly drawn offset.

        Args:
            img: Array-like with ``shape`` and 3-axis slicing support.
        Returns:
            The slice ``img[:, top:top + size, left:left + size]``.
        """
        side = self.size
        # Draw the row offset first, then the column offset (the order matters
        # for reproducibility under a fixed random seed).
        top = random.randint(0, img.shape[1] - side)
        left = random.randint(0, img.shape[2] - side)
        return img[:, top : top + side, left : left + side]


class Normalize(object):
    """Subtract a per-channel mean from a tensor image, in place.

    The mean is either given directly or loaded from a ``.npy`` file. No
    division by a standard deviation is performed:
    channel = channel - mean

    Args:
        mean (sequence): One entry per channel of the tensor that will be
            normalized (e.g. R, G, B). May be a list, tuple, ndarray or
            tensor. ``None`` or an empty sequence means "not provided".
        meanfile: Path or file object accepted by ``numpy.load``. Used only
            when ``mean`` is not provided. The loaded array is cast to
            float32, divided by 255 and its first axis is reordered with
            ``[2, 1, 0]`` (e.g. BGR -> RGB).
    Raises:
        ValueError: If neither ``mean`` nor ``meanfile`` is provided.
    """

    def __init__(self, mean=None, meanfile=None):
        # Compare against None / length explicitly: plain truthiness raises
        # for multi-element tensors and ndarrays.
        if mean is not None and len(mean) > 0:
            self.mean = mean
        elif meanfile is not None:
            arr = np.load(meanfile)
            self.mean = torch.from_numpy(arr.astype("float32") / 255.0)[[2, 1, 0], :, :]
        else:
            raise ValueError("Normalize requires either `mean` or `meanfile`")

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
                It is modified in place.
        Returns:
            Tensor: The same tensor object, with the mean subtracted.
        """
        # TODO: make efficient
        # Iterating a tensor yields views of each channel, so sub_ writes
        # straight into the caller's tensor.
        for channel, channel_mean in zip(tensor, self.mean):
            channel.sub_(channel_mean)
        return tensor


class PlaceCrop(object):
    """Crops the given PIL.Image at a fixed position.

    Args:
        size (sequence or number): Desired output size of the crop. If size is
            a number instead of a sequence like (h, w), a square crop
            (size, size) is made.
        start_x: Left edge of the crop, in pixels. Fractional values are
            rounded to the nearest whole pixel when the crop is applied.
        start_y: Top edge of the crop, in pixels. Rounded like ``start_x``.
    """

    def __init__(self, size, start_x, start_y):
        # numbers.Number (as in CenterCrop) also accepts floats and NumPy
        # integer scalars, which a plain `int` check would try to unpack.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.start_x = start_x
        self.start_y = start_y

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        crop_h, crop_w = self.size
        # Snap the origin to a whole pixel once and derive the far edges from
        # it. Callers may pass a fractional start (image_test passes 15.5 with
        # its defaults); computing the far edge as `start + size` on the raw
        # float would let the two edges be rounded independently downstream,
        # which can yield a crop that is one pixel off the requested size.
        left = int(round(self.start_x))
        top = int(round(self.start_y))
        return img.crop((left, top, left + crop_w, top + crop_h))


class ForceFlip:
    """Horizontally flip the given PIL.Image, unconditionally.

    Unlike a random flip, this transform mirrors every image it receives.
    """

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: Result of ``img.transpose`` with the left-right flip mode.
        """
        flip_mode = Image.FLIP_LEFT_RIGHT
        mirrored = img.transpose(flip_mode)
        return mirrored


class CenterCrop(object):
    """Crops the given (C, H, W) array or tensor at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img: Array-like indexed as ``img[:, row, col]``; ``shape[1]`` is
                taken as the height and ``shape[2]`` as the width.
        Returns:
            The centered slice of ``img`` with spatial size (th, tw).
        """
        # Axis 1 is rows (height) and axis 2 is columns (width), matching the
        # slicing below. Mixing the two up only goes unnoticed for square
        # inputs with square crops.
        h, w = img.shape[1], img.shape[2]
        th, tw = self.size
        h_off = int((h - th) / 2.0)
        w_off = int((w - tw) / 2.0)
        return img[:, h_off : h_off + th, w_off : w_off + tw]


def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Build the training-time preprocessing pipeline.

    The pipeline resizes the image, takes a random resized crop, applies a
    random horizontal flip, converts to a tensor and normalizes it.

    Args:
        resize_size: Size passed to ``ResizeImage``.
        crop_size: Output size passed to ``transforms.RandomResizedCrop``.
        alexnet: If true, normalize by subtracting the mean image loaded from
            ``./ilsvrc_2012_mean.npy``; otherwise use torchvision's
            ``Normalize`` with fixed per-channel mean and std.
    Returns:
        A ``transforms.Compose`` of the five steps.
    """
    if alexnet:
        normalize = Normalize(meanfile="./ilsvrc_2012_mean.npy")
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    steps = [
        ResizeImage(resize_size),
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Build the deterministic test-time preprocessing pipeline.

    The pipeline resizes the image, takes one crop placed near the center,
    converts to a tensor and normalizes it.

    Args:
        resize_size: Size passed to ``ResizeImage``.
        crop_size: Size passed to ``PlaceCrop``.
        alexnet: If true, normalize by subtracting the mean image loaded from
            ``./ilsvrc_2012_mean.npy``; otherwise use torchvision's
            ``Normalize`` with fixed per-channel mean and std.
    Returns:
        A ``transforms.Compose`` of the four steps.
    """
    if alexnet:
        normalize = Normalize(meanfile="./ilsvrc_2012_mean.npy")
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # True division: the offset can be fractional (15.5 with the defaults) and
    # is handed to PlaceCrop as-is for both x and y.
    margin = resize_size - crop_size - 1
    start_center = margin / 2

    steps = [
        ResizeImage(resize_size),
        PlaceCrop(crop_size, start_center, start_center),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def image_test_10crop(resize_size=256, crop_size=224, alexnet=False):
    """Build ten test-time pipelines: five crop positions, flipped and not.

    Each pipeline resizes the image, optionally mirrors it with ``ForceFlip``,
    crops at one fixed position, converts to a tensor and normalizes. The
    returned list holds the five flipped variants first, then the five
    unflipped ones; within each group the crop origins ``(x, y)`` are
    (first, first), (last, last), (last, first), (first, last) and
    (center, center).

    Args:
        resize_size: Size passed to ``ResizeImage``.
        crop_size: Size passed to ``PlaceCrop``.
        alexnet: If true, normalize by subtracting the mean image loaded from
            ``./ilsvrc_2012_mean.npy``; otherwise use torchvision's
            ``Normalize`` with fixed per-channel mean and std.
    Returns:
        list: Ten ``transforms.Compose`` objects sharing one normalize step.
    """
    if alexnet:
        normalize = Normalize(meanfile="./ilsvrc_2012_mean.npy")
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    start_first = 0
    # True division: the center offset can be fractional (15.5 by default).
    start_center = (resize_size - crop_size - 1) / 2
    start_last = resize_size - crop_size - 1

    # (start_x, start_y) of each crop, in the order they appear in the output.
    origins = (
        (start_first, start_first),
        (start_last, start_last),
        (start_last, start_first),
        (start_first, start_last),
        (start_center, start_center),
    )

    def build(start_x, start_y, flip):
        # Every pipeline gets its own transform instances; only `normalize`
        # is shared between them.
        steps = [ResizeImage(resize_size)]
        if flip:
            steps.append(ForceFlip())
        steps.append(PlaceCrop(crop_size, start_x, start_y))
        steps.append(transforms.ToTensor())
        steps.append(normalize)
        return transforms.Compose(steps)

    data_transforms = []
    for flip in (True, False):
        for start_x, start_y in origins:
            data_transforms.append(build(start_x, start_y, flip))
    return data_transforms


def inv_preprocess(imgs, num_images=1):
    """Rescale a batch of images to roughly the [0, 1] range, in place.

    The global minimum and maximum of ``imgs`` are computed once; the tensor
    is then clamped to that range, shifted so the minimum becomes 0 and
    divided by ``(max - min + 1e-5)``.

    Args:
      imgs: batch of input images (a tensor supporting in-place ops). It is
        modified in place.
      num_images: accepted for interface compatibility; currently unused, the
        whole batch is always processed.

    Returns:
      The same ``imgs`` object after rescaling.
    """
    lowest = float(imgs.min())
    highest = float(imgs.max())

    imgs.clamp_(min=lowest, max=highest)
    imgs.add_(-lowest)
    # The small epsilon keeps the division finite for constant inputs.
    imgs.div_(highest - lowest + 1e-5)
    return imgs
