import math
import torch


def un_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Reverse per-channel normalization of an image and move it to CPU.

    For every channel ``c`` covered by ``mean`` this computes
    ``(x * std[c] + mean[c]) * 255``. It then clamps to [0, 255], rounds to the
    nearest integer and casts to ``torch.uint8``. The input tensor is never
    modified.

    Args:
        tensor (torch.Tensor): Normalized image(s). The channel axis is dim 1
            (e.g. B x C x H x W), or dim 0 for a single C x H x W image. With
            the default mean/std, pixel values are expected in (-1, 1).
        mean (sequence): Mean per channel. Only the first ``len(mean)``
            channels are un-normalized.
        std (sequence): Standard deviation per channel. It must have at least
            ``len(mean)`` entries.

    Returns:
        tensor: Un-normalized image as a uint8 PyTorch Tensor in range
        [0, 255], on CPU, with the same shape as the input.
    """
    # Always work on a private copy. If the input is already a float CPU
    # tensor, .cpu()/.float() return the very same object, and the in-place
    # ops below would otherwise overwrite the caller's data.
    tensor = tensor.detach().to(device='cpu', dtype=torch.float32, copy=True)

    # A single C x H x W image gets a temporary batch axis. unsqueeze() is a
    # view, so the in-place updates below still land in `tensor`.
    batched = tensor.unsqueeze(0) if tensor.dim() == 3 else tensor
    for c, m in enumerate(mean):
        batched[:, c, :, :].mul_(std[c]).add_(m)

    # Clamp before the cast: converting out-of-range floats to uint8 is not
    # well-defined. Round so that e.g. 127.9999 (float error) maps to 128,
    # not 127.
    return tensor.mul_(255.).clamp_(0., 255.).round_().to(torch.uint8)


def make_grid(tensor, nrow=8, padding=2,
              normalize=False, range_=None, scale_each=False, pad_value=0):
    """Make a grid of images.

    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/utils.py

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W),
            a single image of shape (C x H x W) or (H x W), or a list of images
            all of the same size. Single-channel images are replicated to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is (ceil(B / nrow), nrow). Default is 8.
        padding (int, optional): Amount of padding. Default is 2.
        normalize (bool, optional): If True, shift the image to the range (0, 1)
            by subtracting the minimum and dividing by (maximum - minimum).
        range_ (tuple, optional): Tuple (min, max) of numbers used to normalize
            the image. Values are clamped to this range first. By default,
            min and max are computed from the tensor.
        scale_each (bool, optional): If True, scale each image in the batch
            separately rather than by the (min, max) over all images.
        pad_value (float or sequence, optional): Value for the padded pixels.
            If a sequence, one value per channel.

    Returns:
        Tensor: For a single image, that image (C x H x W) without padding.
        Otherwise a (C x H' x W') grid with the input's dtype and device.

    Raises:
        TypeError: If ``tensor`` is neither a tensor nor a list of tensors.
        ValueError: If ``pad_value`` is a sequence whose length differs from
            the number of channels.

    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # If list of tensors, convert to a 4D mini-batch Tensor.
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W -> 1 x H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W -> 1 x C x H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel -> 3-channel
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize:
        tensor = tensor.clone()  # avoid modifying the caller's tensor in-place
        if range_ is not None:
            assert isinstance(range_, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # Clamp, then map [low, high] into [0, 1). The epsilon avoids
            # division by zero for constant images.
            img.clamp_(min=low, max=high)
            img.add_(-low).div_(high - low + 1e-5)

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # loop over mini-batch dimension; each t is a view
                norm_range(t, range_)
        else:
            norm_range(tensor, range_)

    if tensor.size(0) == 1:
        # Drop only the batch axis. A bare squeeze() would also drop H or W == 1.
        return tensor.squeeze(0)

    # Make the mini-batch of images into a grid.
    nmaps = tensor.size(0)
    num_channels = tensor.size(1)
    xmaps = min(nrow, nmaps)
    ymaps = (nmaps + xmaps - 1) // xmaps  # ceil(nmaps / xmaps)
    height, width = tensor.size(2) + padding, tensor.size(3) + padding
    grid_shape = (num_channels, height * ymaps + padding, width * xmaps + padding)

    # Allocate the grid already filled with the pad value.
    if isinstance(pad_value, (int, float)):
        grid = tensor.new_full(grid_shape, pad_value)
    else:
        if len(pad_value) != num_channels:
            raise ValueError('Specified tuple pad_value per channel, but has {} != {} elements'
                             .format(len(pad_value), num_channels))
        grid = tensor.new_empty(grid_shape)
        for c, v in enumerate(pad_value):  # pad per channel
            grid[c] = v

    # Image k goes to row k // xmaps, column k % xmaps.
    for k in range(nmaps):
        y, x = divmod(k, xmaps)
        grid.narrow(1, y * height + padding, height - padding) \
            .narrow(2, x * width + padding, width - padding) \
            .copy_(tensor[k])
    return grid
