# -*- coding: UTF-8 -*-


import random
import cv2
from PIL import ImageFilter
import numpy as np
import torchvision.transforms as transforms


class MultiCropTransform(object):
    """
    Produce several random crops of one image at different sizes and scales (multi-crop).

    For every crop group ``i`` the pipeline of ``old_transform`` is copied. Each
    ``transforms.RandomResizedCrop`` in it is substituted by a new one with output size
    ``size_crops[i]`` and scale range ``(min_scale_crops[i], max_scale_crops[i])``. That
    pipeline is then applied ``nmb_crops[i]`` times. The replaced crop's ``ratio`` and
    ``interpolation`` (and ``antialias``, when the installed torchvision has it) are kept.
    This means only size and scale change between groups.

    If ``old_transform`` contains no ``RandomResizedCrop``, ``size_crops`` and the scale
    ranges have no effect. Each group is then just ``nmb_crops[i]`` runs of the original
    pipeline.

    Args:
        old_transform (torchvision.transforms.Compose): Template pipeline.
        size_crops (List[int]): Output size of the crops in each group.
        nmb_crops (List[int]): Number of crops to generate for each group.
        min_scale_crops (List[float]): Lower bound of the crop scale range per group.
        max_scale_crops (List[float]): Upper bound of the crop scale range per group.
    """

    def __init__(self,
                 old_transform: transforms.Compose,
                 size_crops,
                 nmb_crops,
                 min_scale_crops,
                 max_scale_crops):
        """
        Build one crop pipeline per requested crop.

        Args:
            old_transform (transforms.Compose): Template pipeline whose
                ``RandomResizedCrop`` steps get replaced.
            size_crops (List[int]): Output size of the crops in each group.
            nmb_crops (List[int]): Number of crops to generate for each group.
            min_scale_crops (List[float]): Lower bound of the crop scale range per group.
            max_scale_crops (List[float]): Upper bound of the crop scale range per group.

        Raises:
            AssertionError: If the four per-group lists do not all have the same length.
                It is raised explicitly rather than via ``assert``, so the check also runs
                under ``python -O``.
        """
        lengths = (len(size_crops), len(nmb_crops),
                   len(min_scale_crops), len(max_scale_crops))
        if len(set(lengths)) != 1:
            raise AssertionError(
                "size_crops, nmb_crops, min_scale_crops and max_scale_crops must have "
                "the same length, got {}, {}, {} and {}".format(*lengths)
            )

        trans = []
        for size, count, min_scale, max_scale in zip(size_crops, nmb_crops,
                                                      min_scale_crops, max_scale_crops):
            steps = [
                self._resized_crop_like(t, size, (min_scale, max_scale))
                if isinstance(t, transforms.RandomResizedCrop) else t
                for t in old_transform.transforms
            ]
            # The same Compose object is repeated `count` times. RandomResizedCrop samples
            # its crop box on every call, so the repeated entries still give different crops.
            trans.extend([transforms.Compose(steps)] * count)
        self.trans = trans

    @staticmethod
    def _resized_crop_like(template, size, scale):
        """Return a RandomResizedCrop with the given size/scale and ``template``'s other settings."""
        kwargs = {
            "scale": scale,
            "ratio": template.ratio,
            "interpolation": template.interpolation,
        }
        # `antialias` only exists on newer torchvision releases; forward it when present.
        if hasattr(template, "antialias"):
            kwargs["antialias"] = template.antialias
        return transforms.RandomResizedCrop(size, **kwargs)

    def __call__(self, img):
        """
        Run every crop pipeline on ``img``.

        Args:
            img: Input accepted by ``old_transform`` (for example a ``PIL.Image.Image``).

        Returns:
            list: One output per pipeline, ordered by crop group as passed to ``__init__``.
            The element type is whatever the last step of ``old_transform`` produces.
        """
        return [pipeline(img) for pipeline in self.trans]
