### Preamble ##########################################################################################################

"""
DO NOT USE: superseded by torchvision.transforms.

Image preprocessing functions that allow autograd with respect to the image. Necessary for classifier diffusion 
guidance and adversarial diffusion.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import torch
from torch.nn.functional import interpolate, pad
from typing import Union, Iterable, Optional, Tuple

#######################################################################################################################


def grad_normalize(
    input: torch.Tensor,
    mean: Union[float, torch.Tensor],
    std: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    :param input: torch.Tensor
        (*, h, w) floating point array of values to be normalised. Must have at least 2 dims.
    :param mean: float or torch.Tensor
        (*) array of floats or float representing mean(s).
    :param std: float or torch.Tensor
        (*) array of floats or float representing standard deviation(s).
    :return: torch.Tensor
        Normalised tensor with the same shape, dtype and device as 'input'.
    :raises TypeError: if 'input' is not a torch.Tensor.
    :raises ValueError: if 'input' is not floating point or has fewer than 2 dims, or if 'mean'/'std' have a shape
        other than (), (1,) or input.shape[:-2].

    Normalises 'input', i.e.,

    output = (input - mean) / std

    Note that 'input' is expected to be of shape (*, h, w), 'mean' and 'std' can be either scalar or of shape (*).
    Normalisation is performed across each (h, w) slice. 'mean' and 'std' are cast to the dtype/device of 'input'
    with differentiable operations, so gradients flow to both 'input' and tensor-valued statistics.
    """

    allowed_types = (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    )

    # Tensor, dtype and rank checks
    if not isinstance(input, torch.Tensor):
        raise TypeError("'input' must be a pytorch Tensor")
    if input.dtype not in allowed_types:
        raise ValueError("'input' must be a floating point tensor")
    # '< 2' rather than '== 1': a 0-dim input would otherwise broadcast silently to a (1, 1) output.
    if input.ndim < 2:
        raise ValueError("'input' must have a minimum of 2 dims")

    if not isinstance(mean, torch.Tensor):
        mean = torch.tensor(mean, dtype=input.dtype, device=input.device)
    elif (mean.dtype != input.dtype) or (mean.device != input.device):
        mean = mean.to(dtype=input.dtype, device=input.device)

    if not isinstance(std, torch.Tensor):
        std = torch.tensor(std, dtype=input.dtype, device=input.device)
    elif (std.dtype != input.dtype) or (std.device != input.device):
        std = std.to(dtype=input.dtype, device=input.device)

    # Shape checks
    allowed_dims = [
        torch.Size([]),
        torch.Size([1]),
        torch.Size(input.shape[:-2]),
    ]
    if mean.shape not in allowed_dims:
        raise ValueError(
            "'mean' must be either scalar or iterable with " f"shape in {allowed_dims} got {mean.shape}"
        )
    if std.shape not in allowed_dims:
        raise ValueError(
            "'std' must be either scalar or iterable with " f"shape in {allowed_dims} got {std.shape}"
        )

    # A single-element (1,) statistic is treated as a scalar so the output keeps the shape of 'input'
    # (otherwise a (h, w) input would come back as (1, h, w)).
    if mean.shape == torch.Size([1]):
        mean = mean.reshape(())
    if std.shape == torch.Size([1]):
        std = std.reshape(())

    # Append (h, w) singleton dims so the statistics broadcast over each (h, w) slice.
    # Don't do in-place operations to ensure autograd support
    mean = mean.unsqueeze(-1).unsqueeze(-1)
    std = std.unsqueeze(-1).unsqueeze(-1)

    return (input - mean) / std


def grad_scale(
    input: torch.Tensor,
    scale: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    :param input: torch.Tensor
        (*, h, w) floating point array of values to be scaled. Must have at least 2 dims.
    :param scale: float or torch.Tensor
        (*) array of floats or float representing scale(s).
    :return: torch.Tensor
        Scaled tensor with the same shape, dtype and device as 'input'.
    :raises TypeError: if 'input' is not a torch.Tensor.
    :raises ValueError: if 'input' is not floating point or has fewer than 2 dims, or if 'scale' has a shape other
        than (), (1,) or input.shape[:-2].

    Scales 'input', i.e.,

    output = input * scale

    Note that 'input' is expected to be of shape (*, h, w), and 'scale' can be either scalar or of shape (*). Scaling
    is performed across each (h, w) slice. 'scale' is cast to the dtype/device of 'input' with differentiable
    operations, so gradients flow to both 'input' and a tensor-valued 'scale'.
    """

    allowed_types = (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    )

    # Tensor, dtype and rank checks
    if not isinstance(input, torch.Tensor):
        raise TypeError("'input' must be a pytorch Tensor")
    if input.dtype not in allowed_types:
        raise ValueError("'input' must be a floating point tensor")
    # '< 2' rather than '== 1': a 0-dim input would otherwise broadcast silently to a (1, 1) output.
    if input.ndim < 2:
        raise ValueError("'input' must have a minimum of 2 dims")

    if not isinstance(scale, torch.Tensor):
        scale = torch.tensor(scale, dtype=input.dtype, device=input.device)
    elif (scale.dtype != input.dtype) or (scale.device != input.device):
        scale = scale.to(dtype=input.dtype, device=input.device)

    # Shape checks
    allowed_dims = [
        torch.Size([]),
        torch.Size([1]),
        torch.Size(input.shape[:-2]),
    ]
    if scale.shape not in allowed_dims:
        raise ValueError(
            "'scale' must be either scalar or iterable with " f"shape in {allowed_dims} got {scale.shape}"
        )

    # A single-element (1,) scale is treated as a scalar so the output keeps the shape of 'input'
    # (otherwise a (h, w) input would come back as (1, h, w)).
    if scale.shape == torch.Size([1]):
        scale = scale.reshape(())

    # Scale
    # Append (h, w) singleton dims so the scale broadcasts over each (h, w) slice.
    # Don't do in-place operations to ensure autograd support
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    return input * scale


def grad_resize(
    input: torch.Tensor,
    size: Union[int, Tuple[int, int]],
    preserve_aspctratio: bool = True,
    mode: str = "nearest",
) -> torch.Tensor:
    """
    :param input: torch.Tensor
        (_, _, h, w), (_, h, w), or (h, w) floating point array of images to be resized.
    :param size: int or (int, int)
        The length of the smallest resized side or the resized image dimensions, i.e., (h', w'). A list of two ints
        is accepted as well as a tuple. All values must be > 0. Note that if a single `int` is provided, then the
        smallest side will be scaled to the desired length and the larger side will be scaled to preserve the aspect
        ratio.
    :param preserve_aspctratio: bool
        Whether to preserve the aspect ratio of the image during the transformation. Note that this may crop a
        dimension of the image (the excess is removed symmetrically). Has no effect if `size` is a single integer.
    :param mode: str
        Method to be used for the interpolation. Options are "nearest" and "bilinear". These are passed to the
        torch.nn.functional.interpolate method. See pytorch documentation for further details.
    :return: torch.Tensor
        Resized tensor with the same number of dims as 'input'.
    :raises TypeError: if 'input' is not a tensor or 'size' is not an int / pair of ints.
    :raises ValueError: on non-floating dtype, unsupported rank, wrong 'size' length, non-positive 'size', or an
        unknown 'mode'.

    Resizes a batch of images to the desired (h', w') dimensions. Supports preserving aspect ratio by center-cropping
    to the desired size. Note that this function is a wrapper for torch.nn.functional.interpolate.
    """

    allowed_types = (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    )

    # Tensor and dtype checks
    if not isinstance(input, torch.Tensor):
        raise TypeError("'input' must be a pytorch Tensor")
    if input.dtype not in allowed_types:
        raise ValueError("'input' must be a floating point tensor")

    # interpolate expects (N, C, H, W); lift lower-rank inputs and undo it at the end.
    indim = input.ndim
    if indim == 2:
        input = input.unsqueeze(0).unsqueeze(0)
    elif indim == 3:
        input = input.unsqueeze(0)
    elif (indim < 2) or (indim > 4):
        raise ValueError("'input' must have dimensions 2, 3, or 4")

    # size check
    if isinstance(size, (tuple, list)):
        if not all(isinstance(i, int) for i in size):
            raise TypeError("'size' must be an integer or tuple of integers, got non int")
        if len(size) != 2:
            raise ValueError(f"'size' must have length 2, got {len(size)}")
        size = tuple(size)
        if not all(i > 0 for i in size):
            raise ValueError(f"'size' must be an integer or tuple of integers > 0, got {size}")
    elif isinstance(size, int):
        # Validated here (as in grad_centercrop) rather than failing deep inside interpolate.
        if size <= 0:
            raise ValueError(f"'size' must be an integer or tuple of integers > 0, got {size}")
    else:
        raise TypeError(f"'size' must be an integer or tuple of integers, got {type(size)}")

    # mode check
    if mode not in ["nearest", "bilinear"]:
        raise ValueError("'mode' must be either 'nearest' or 'bilinear'")

    in_h = input.shape[-2]
    in_w = input.shape[-1]

    if isinstance(size, int):
        # Scale the shorter side to 'size' and the longer one proportionally.
        if in_h > in_w:
            h = int(round(in_h * (size / in_w), 0))
            w = size
        elif in_h < in_w:
            w = int(round(in_w * (size / in_h), 0))
            h = size
        else:
            h = size
            w = size
        output = interpolate(input, size=(h, w), mode=mode)
    elif preserve_aspctratio:
        # Scale by the larger of the two ratios so both sides cover the target, then center-crop the overshoot.
        ah = size[0] / in_h
        aw = size[1] / in_w
        if ah > aw:
            h = size[0]
            w = int(round(ah * in_w, 0))

            lpoint = int((w - size[1]) // 2)
            upoint = lpoint + size[1]
            output = interpolate(input, size=(h, w), mode=mode)
            output = output[:, :, :, lpoint:upoint]
        else:
            h = int(round(aw * in_h, 0))
            w = size[1]

            lpoint = int((h - size[0]) // 2)
            upoint = lpoint + size[0]
            output = interpolate(input, size=(h, w), mode=mode)
            output = output[:, :, lpoint:upoint, :]
    else:
        output = interpolate(input, size=size, mode=mode)

    # Restore the caller's original rank.
    if indim == 2:
        output = output[0, 0, :, :]
    elif indim == 3:
        output = output[0, :, :, :]

    return output


def grad_centercrop(input: torch.Tensor, size: Union[int, Tuple[int, int]]) -> torch.Tensor:
    """
    :param input: torch.Tensor
        (*, h, w) array of images to be cropped. Must have at least 2 dims.
    :param size: int or (int, int)
        The output size (h', w') of the image(s). A single int gives a square crop; a list of two ints is accepted
        as well as a tuple. All values must be > 0.
    :return: torch.Tensor
        (*, h', w') tensor.
    :raises TypeError: if 'input' is not a tensor or 'size' is not an int / pair of ints.
    :raises ValueError: if 'input' has fewer than 2 dims, or 'size' has the wrong length or non-positive values.

    Center crops the images to the desired shape. Pads the image with 0s if it is smaller than the desired crop size.
    When the excess is odd, the extra row/column is removed from (or padded onto) the bottom/right side.
    """

    # Tensor and rank checks
    if not isinstance(input, torch.Tensor):
        raise TypeError("'input' must be a pytorch Tensor")
    # Without this, a 1D input would fail with an opaque IndexError on input.shape[-2].
    if input.ndim < 2:
        raise ValueError("'input' must have a minimum of 2 dims")

    # size check
    if isinstance(size, (tuple, list)):
        if not all(isinstance(i, int) for i in size):
            raise TypeError("'size' must be an integer or tuple of integers, got non int")
        if len(size) != 2:
            raise ValueError(f"'size' must have length 2, got {len(size)}")
        if not all(i > 0 for i in size):
            raise ValueError(f"'size' must be an integer or tuple of integers > 0, got {size}")
        ch, cw = size
    elif isinstance(size, int):
        if size <= 0:
            raise ValueError(f"'size' must be an integer or tuple of integers > 0, got {size}")
        ch = cw = size
    else:
        raise TypeError(f"'size' must be an integer or tuple of integers, got {type(size)}")

    h = input.shape[-2]
    w = input.shape[-1]

    # Zero-pad symmetrically along any dimension smaller than the crop, so the slice below is always in range.
    if (ch > h) or (cw > w):
        lh_pad = (ch - h) // 2 if ch > h else 0
        rh_pad = ch - h - lh_pad if ch > h else 0
        lw_pad = (cw - w) // 2 if cw > w else 0
        rw_pad = cw - w - lw_pad if cw > w else 0

        # F.pad orders the pads from the last dim backwards: (left, right, top, bottom).
        input = pad(input, (lw_pad, rw_pad, lh_pad, rh_pad))
        h = input.shape[-2]
        w = input.shape[-1]

    lh = (h - ch) // 2
    lw = (w - cw) // 2

    return input[..., lh : lh + ch, lw : lw + cw]


#######################################################################################################################
