from __future__ import division
from typing import Any, List, Sequence, Tuple, Union
import torch
from torch.nn import functional as F
import numpy as np


class ImageList(object):
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.
    This works by padding the images to the same size,
    and storing in a field the original sizes of each image.

    Attributes:
        tensor (Tensor): the batched images, with the batch as dimension 0.
        image_sizes (list[tuple[int, int]]): each tuple is (h, w)
    """

    def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
        """
        Arguments:
            tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1
            image_sizes (list[tuple[int, int]]): Each tuple is (h, w).
        """
        self.tensor = tensor
        self.image_sizes = image_sizes

    def __len__(self) -> int:
        return len(self.image_sizes)

    def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
        """
        Access the individual image in its original size.

        Args:
            idx: index of the image. Only integer indices work in practice: a
                slice would make ``size`` a list of sizes, which cannot be used
                as slice bounds below.

        Returns:
            Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1
        """
        size = self.image_sizes[idx]
        # Strip the padding by slicing the last two dims back to (h, w).
        return self.tensor[idx, ..., : size[0], : size[1]]  # type: ignore

    def to(self, *args: Any, **kwargs: Any) -> "ImageList":
        """
        Return a new ImageList whose tensor is ``self.tensor.to(*args, **kwargs)``.

        ``image_sizes`` is shared (not copied) with the returned instance.
        """
        cast_tensor = self.tensor.to(*args, **kwargs)
        return ImageList(cast_tensor, self.image_sizes)

    @staticmethod
    def from_tensors(
        tensors: Sequence[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0
    ) -> "ImageList":
        """
        Args:
            tensors: a tuple or list of `torch.Tensors`, each of shape (Hi, Wi) or
                (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded with `pad_value`
                so that they will have the same shape.
            size_divisibility (int): If `size_divisibility > 0`, also adds padding to ensure
                the common height and width is divisible by `size_divisibility`
            pad_value (float): value to pad

        Returns:
            an `ImageList`. Its tensor has the dtype and device of ``tensors[0]``;
            padding is added only at the bottom/right of each image.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)
            assert t.shape[1:-2] == tensors[0].shape[1:-2], t.shape
        # per dimension maximum (H, W) or (C_1, ..., C_K, H, W) where K >= 1 among all tensors
        max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))

        if size_divisibility > 0:
            import math

            stride = size_divisibility
            max_size = list(max_size)  # type: ignore
            # Round the spatial dims up to the next multiple of the stride.
            max_size[-2] = int(math.ceil(max_size[-2] / stride) * stride)  # type: ignore
            max_size[-1] = int(math.ceil(max_size[-1] / stride) * stride)  # type: ignore
            max_size = tuple(max_size)

        image_sizes = [im.shape[-2:] for im in tensors]

        if len(tensors) == 1:
            # This seems slightly (2%) faster.
            # TODO: check whether it's faster for multiple images as well
            image_size = image_sizes[0]
            padded = F.pad(
                tensors[0],
                [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]],
                value=pad_value,
            )
            batched_imgs = padded.unsqueeze_(0)
        else:
            batch_shape = (len(tensors),) + max_size
            batched_imgs = tensors[0].new_full(batch_shape, pad_value)
            for img, pad_img in zip(tensors, batched_imgs):
                pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)

        return ImageList(batched_imgs.contiguous(), image_sizes)

    @staticmethod
    def from_tensors_crop(
        tensors: Sequence[torch.Tensor], crop_size: int = 224, ratio: float = 1
    ) -> "ImageList":
        """
        Resize every image by ``ratio`` and take one random square crop from it.

        Args:
            tensors: a tuple or list of `torch.Tensors`, each of shape (C, Hi, Wi).
                The tensors are expected to share their leading (non-spatial)
                dimensions so the crops can be stacked into one batch.
            crop_size (int): side length of the square crop taken from every
                resized image. It must not exceed the resized height/width.
            ratio (float): scale factor passed to ``F.interpolate`` (default
                nearest-neighbour mode) before cropping; 1 keeps the original size.

        Returns:
            an `ImageList` whose tensor has shape (N, C, crop_size, crop_size),
            with the dtype and device of ``tensors[0]``. Its ``image_sizes``
            are the cropped (h, w) sizes.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)
            assert t.shape[1:-2] == tensors[0].shape[1:-2], t.shape

        # Allocate the output like the first input (same dtype/device), as
        # from_tensors does; every slot is overwritten in the loop below.
        leading_shape = tuple(tensors[0].shape[:-2])
        cropped_tensors = tensors[0].new_empty(
            (len(tensors),) + leading_shape + (crop_size, crop_size)
        )

        new_image_sizes = []
        for i, tensor in enumerate(tensors):
            height, width = tensor.shape[-2:]
            # F.interpolate wants (batch, channel, H, W): treat every leading
            # slice as its own batch entry with a single channel.
            flat = tensor.reshape(-1, 1, height, width)
            resized = F.interpolate(flat, scale_factor=ratio)
            # Restore the leading dims explicitly. A bare .squeeze() would also
            # drop a size-1 channel dimension and break cropping when C == 1.
            resized_image = resized.reshape(tuple(tensor.shape[:-2]) + tuple(resized.shape[-2:]))
            crop_image = crop_tensor(resized_image, (crop_size, crop_size))
            cropped_tensors[i] = crop_image
            new_image_sizes.append(crop_image.shape[-2:])

        return ImageList(cropped_tensors.contiguous(), new_image_sizes)


def crop_tensor(image, crop_sizes):
    """
    Take a uniformly random window of size ``crop_sizes`` from the last two
    dimensions of ``image``.

    Args:
        image (Tensor): tensor of shape (..., H, W).
        crop_sizes (tuple[int, int]): (crop_h, crop_w); must not exceed (H, W).

    Returns:
        Tensor: a new tensor (not a view of ``image``) of shape
        (..., crop_h, crop_w).

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    crop_h, crop_w = crop_sizes
    # Largest valid start offset along each spatial axis.
    max_x = image.size(-2) - crop_h
    max_y = image.size(-1) - crop_w
    if max_x < 0 or max_y < 0:
        raise ValueError(
            "crop size {} exceeds image size {}".format(
                tuple(crop_sizes), tuple(image.shape[-2:])
            )
        )
    # randint's upper bound is exclusive, so +1 makes the last offset
    # (bottom/right-most window) reachable; with max == 0 this yields 0.
    startx = np.random.randint(0, max_x + 1)
    starty = np.random.randint(0, max_y + 1)
    # Clone only the window instead of copying the whole image first.
    return image[..., startx:startx + crop_h, starty:starty + crop_w].clone()
