from typing import Union, Collection, Optional, Dict
import torch
import torchvision.transforms.functional as TF
from torch import nn
from transformers.image_processing_utils import get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    infer_channel_dimension_format,
    validate_preprocess_arguments,
    make_list_of_images,
    valid_images,
    ImageInput,
    ChannelDimension,
    PILImageResampling,
    get_channel_dimension_axis,
)


class DifferentiableCLIPImageProcessor(nn.Module):
    """
    Tensor-based counterpart of a CLIP-style image processor.

    The preprocessing configuration (resize, center crop, normalization, RGB conversion) is read from the wrapped
    `img_processor` on every call, and each step is carried out with `torch` / `torchvision` tensor operations
    instead of NumPy. As the class name suggests, the intent is to keep the pipeline usable inside autograd.

    Args:
        img_processor:
            The original image processor whose settings (`do_resize`, `size`, `resample`, `do_center_crop`,
            `crop_size`, `do_rescale`, `rescale_factor`, `do_normalize`, `image_mean`, `image_std`,
            `do_convert_rgb`) are used by `preprocess`.
    """

    def __init__(self, img_processor):
        super().__init__()
        self.original_processor = img_processor

    def forward(self, images: ImageInput, return_tensors="pt") -> torch.Tensor:
        """
        Make the module callable; simply delegates to `preprocess`.
        """
        return self.preprocess(images, return_tensors=return_tensors)

    def resize(
        self,
        image: torch.Tensor,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> torch.Tensor:
        """
        Resize an image. If `size` contains `"shortest_edge"`, the shortest edge of the image is resized to that value
        and the longest edge is resized to keep the input aspect ratio. Otherwise `size` must contain `"height"` and
        `"width"` and the image is resized to exactly that shape.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image. It is forwarded unchanged to
                `torchvision.transforms.functional.resize` as `interpolation`.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Returns:
            `torch.Tensor`: The resized image, in the same channel dimension format as the input.

        Raises:
            ValueError: If `size` contains neither `"shortest_edge"` nor both `"height"` and `"width"`.
        """
        if "shortest_edge" in size:
            target = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            target = (size["height"], size["width"])
            default_to_square = True
        else:
            raise ValueError(
                "Size must contain either 'shortest_edge' or 'height' and 'width'."
            )

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        output_size = get_resize_output_image_size(
            image,
            size=target,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )

        # torchvision treats the last two dimensions as (H, W), so a channels-last image has to be moved to
        # channels-first for the call and moved back afterwards.
        channels_last = input_data_format == ChannelDimension.LAST
        if channels_last:
            image = image.movedim(-1, -3)
        resized = TF.resize(image, output_size, interpolation=resample)
        return resized.movedim(-3, -1) if channels_last else resized

    def center_crop(
        self,
        image: torch.Tensor,
        size: Dict[str, int],
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> torch.Tensor:
        """
        Center crop an image to `(size["height"], size["width"])` using `torchvision.transforms.functional.center_crop`.

        Args:
            image (`torch.Tensor`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If unset, the image is assumed to already have its
                spatial dimensions last (`(..., H, W)`), which is what torchvision expects.

        Returns:
            `torch.Tensor`: The cropped image, in the same channel dimension format as the input.

        Raises:
            ValueError: If `size` does not provide both `"height"` and `"width"`.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(
                f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}"
            )
        crop_hw = (size["height"], size["width"])

        # Same layout handling as in `resize`: torchvision crops the last two dimensions.
        channels_last = input_data_format == ChannelDimension.LAST
        if channels_last:
            image = image.movedim(-1, -3)
        cropped = TF.center_crop(image, crop_hw)
        return cropped.movedim(-3, -1) if channels_last else cropped

    @staticmethod
    def _per_channel_stat(
        values: Union[float, Collection[float]],
        name: str,
        num_channels: int,
        image: torch.Tensor,
        channel_axis: int,
    ) -> torch.Tensor:
        """
        Validate a per-channel statistic and return it as a tensor shaped to broadcast along `channel_axis` of `image`.
        """
        if isinstance(values, Collection):
            if len(values) != num_channels:
                raise ValueError(
                    f"{name} must have {num_channels} elements if it is an iterable, got {len(values)}"
                )
        else:
            values = [values] * num_channels
        stat = torch.as_tensor(values, dtype=image.dtype, device=image.device)

        # Size 1 everywhere except the channel axis, so broadcasting lines the values up with the channels
        # regardless of the number of leading or trailing dimensions.
        broadcast_shape = [1] * image.ndim
        broadcast_shape[channel_axis] = num_channels
        return stat.reshape(broadcast_shape)

    def normalize(
        self,
        image: torch.Tensor,
        mean: Union[float, Collection[float]],
        std: Union[float, Collection[float]],
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> torch.Tensor:
        """
        Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.

        image = (image - mean) / std

        Args:
            image (`torch.Tensor`):
                The image to normalize. Integer images are cast to `torch.float32` first; floating point images keep
                their dtype.
            mean (`float` or `Collection[float]`):
                The mean to use for normalization.
            std (`float` or `Collection[float]`):
                The standard deviation to use for normalization.
            data_format (`ChannelDimension`, *optional*):
                The channel dimension format of the output image. If unset, will use the inferred format from the input.
            input_data_format (`ChannelDimension`, *optional*):
                The channel dimension format of the input image. If unset, will use the inferred format from the input.

        Returns:
            `torch.Tensor`: The normalized image.

        Raises:
            ValueError: If `mean` or `std` is a collection whose length differs from the number of channels.
        """
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        channel_axis = get_channel_dimension_axis(
            image, input_data_format=input_data_format
        )
        num_channels = image.shape[channel_axis]

        # Building mean/std in an integer dtype would truncate them (and can turn std into 0), so integer
        # images are promoted to float before any arithmetic.
        if not image.is_floating_point():
            image = image.to(torch.float32)

        mean = self._per_channel_stat(mean, "mean", num_channels, image, channel_axis)
        std = self._per_channel_stat(std, "std", num_channels, image, channel_axis)

        image = (image - mean) / std

        if data_format is not None:
            target_format = ChannelDimension(data_format)
            if target_format != ChannelDimension(input_data_format):
                # Move the channel axis with a tensor op; the NumPy-based format helper cannot be used on tensors.
                destination = (
                    image.ndim - 3
                    if target_format == ChannelDimension.FIRST
                    else image.ndim - 1
                )
                image = image.movedim(channel_axis, destination)
        return image

    def preprocess(
        self, images: ImageInput, return_tensors="pt"
    ) -> torch.Tensor:
        """
        Run the resize -> center crop -> normalize pipeline configured on the wrapped processor.

        Args:
            images (`ImageInput`):
                A single image or a list of images. The individual steps operate on `torch.Tensor` images.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                If `"pt"`, the processed images are stacked along a new first dimension into one tensor. Any other
                value returns the list of processed images instead.

        Returns:
            `torch.Tensor`: The stacked batch when `return_tensors == "pt"`, otherwise a list of processed images.

        Raises:
            ValueError: If `images` fails the `valid_images` check.
        """
        processor = self.original_processor

        do_resize = processor.do_resize
        size = get_size_dict(processor.size, param_name="size", default_to_square=False)
        resample = processor.resample
        do_center_crop = processor.do_center_crop
        crop_size = get_size_dict(
            processor.crop_size, param_name="crop_size", default_to_square=True
        )
        do_rescale = processor.do_rescale
        rescale_factor = processor.rescale_factor
        do_normalize = processor.do_normalize
        image_mean = processor.image_mean
        image_std = processor.image_std
        do_convert_rgb = processor.do_convert_rgb

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # The channel layout is inferred once from the first image and assumed to hold for the whole list.
        input_data_format = infer_channel_dimension_format(images[0])

        all_images = []
        for image in images:
            if do_resize:
                image = self.resize(
                    image=image,
                    size=size,
                    resample=resample,
                    input_data_format=input_data_format,
                )

            if do_center_crop:
                image = self.center_crop(
                    image=image,
                    size=crop_size,
                    input_data_format=input_data_format,
                )

            # Rescaling is deliberately not applied, even when `do_rescale` is set on the wrapped processor.
            # NOTE(review): presumably inputs are already scaled to the range the model expects - confirm with callers.

            if do_normalize:
                image = self.normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    input_data_format=input_data_format,
                )

            all_images.append(image)

        if return_tensors == "pt":
            return torch.stack(all_images, dim=0)
        return all_images
