from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torchvision.transforms.functional as F

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD


class ResizeMaxSize(nn.Module):
    """Resize an image so one side equals ``max_size``, then make it square.

    With ``fn='max'`` (default), the longer side is scaled to ``max_size``.
    The shorter side is then padded with ``fill`` up to ``max_size``
    (letterboxing). Padding is split evenly; any odd pixel goes to the
    right/bottom edge.

    With ``fn='min'``, the shorter side is scaled to ``max_size`` and the
    result is center-cropped to ``max_size`` x ``max_size``.

    In both modes the output is ``max_size`` x ``max_size``. Accepts PIL
    images or tensors laid out as ``[..., H, W]``.

    Args:
        max_size: target side length in pixels (must be an ``int``).
        interpolation: interpolation mode forwarded to ``F.resize``.
        fn: ``'min'`` to fit the shorter side; anything else fits the longer side.
        fill: pixel fill value used for padding.

    Raises:
        TypeError: if ``max_size`` is not an ``int``.
    """

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # Previously both branches picked `min` and the attribute was unused.
        self.fn = min if fn == 'min' else max
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            # torchvision tensors are [..., H, W]; spatial dims are the last two.
            height, width = img.shape[-2:]
        else:
            # PIL reports size as (width, height).
            width, height = img.size
        scale = self.max_size / float(self.fn(height, width))
        if scale != 1.0:
            height, width = (round(dim * scale) for dim in (height, width))
            img = F.resize(img, (height, width), self.interpolation)

        # Pad whenever a side is short, even if no resize happened.
        # Otherwise a non-square image already at max_size would come back non-square.
        pad_h = max(self.max_size - height, 0)
        pad_w = max(self.max_size - width, 0)
        if pad_h or pad_w:
            img = F.pad(
                img,
                padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2],
                fill=self.fill,
            )

        # Only reachable with fn='min': the longer side now exceeds max_size.
        if height > self.max_size or width > self.max_size:
            img = F.center_crop(img, [self.max_size, self.max_size])
        return img


def _convert_to_rgb(image):
    """Convert ``image`` to RGB mode.

    Calls the image's own ``convert('RGB')`` (e.g. a PIL image) and returns
    the result.
    """
    target_mode = 'RGB'
    rgb_image = image.convert(target_mode)
    return rgb_image


# class CatGen(nn.Module):
#     def __init__(self, num=4):
#         self.num = num
#     def mixgen_batch(image, text):
#         batch_size = image.shape[0]
#         index = np.random.permutation(batch_size)

#         cat_images = []
#         for i in range(batch_size):
#             # image mixup
#             image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
#             # text concat
#             text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
#         text = torch.stack(text)
#         return image, text

class MimicCLIPImageProcessor:
    """Wrap a ``Compose``-style pipeline behind a ``CLIPImageProcessor``-like surface.

    Exposes ``image_mean``, ``image_std``, ``crop_size`` and ``preprocess``
    on top of ``preprocess_func``.
    """

    def __init__(self, preprocess_func):
        """
        Args:
            preprocess_func: callable with a ``transforms`` list. Its last
                element must expose ``mean`` and ``std`` (e.g. the
                ``Normalize`` appended by ``image_transform``).
        """
        self.preprocess_func = preprocess_func
        self.image_mean = preprocess_func.transforms[-1].mean
        self.image_std = preprocess_func.transforms[-1].std

        # Always define crop_size (None if it can't be determined).
        # Previously it was only set when a *Crop transform was present, so
        # pipelines using ResizeMaxSize had no such attribute.
        self.crop_size = None
        square_size = None
        for transform in preprocess_func.transforms:
            name = type(transform).__name__
            if 'crop' in name.lower():
                size = transform.size
                if isinstance(size, int):
                    size = (size, size)
                # If several crops exist, the last one wins (same as before).
                self.crop_size = {'height': size[0], 'width': size[1]}
            elif name == 'ResizeMaxSize':
                # ResizeMaxSize targets a max_size x max_size output.
                square_size = transform.max_size
        if self.crop_size is None and square_size is not None:
            self.crop_size = {'height': square_size, 'width': square_size}

    def preprocess(self, image, return_tensors='pt'):
        """Run the pipeline on one image or a list/tuple of images.

        Args:
            image: a single image, or a list/tuple of images, accepted by
                ``preprocess_func``.
            return_tensors: accepted for interface compatibility and ignored;
                the result is always a torch tensor.

        Returns:
            ``{'pixel_values': tensor}``, where the per-image outputs are
            stacked along a new leading dimension.

        Raises:
            RuntimeError: from ``torch.stack`` if the list is empty or the
                per-image outputs have different shapes.
        """
        images = image if isinstance(image, (list, tuple)) else [image]
        pixel_values = torch.stack([self.preprocess_func(img) for img in images])
        return {'pixel_values': pixel_values}
        

def image_transform(
        image_size: int,
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
):
    """Build the train or eval image preprocessing ``Compose`` pipeline.

    Training uses ``RandomResizedCrop(scale=(0.9, 1.0))`` with bicubic
    interpolation. Evaluation uses one of two strategies:

    - ``ResizeMaxSize`` (when ``resize_longest_max``), padding with ``fill_color``;
    - otherwise a bicubic ``Resize`` followed by ``CenterCrop``.

    Every pipeline ends with RGB conversion, ``ToTensor`` and ``Normalize``.

    Args:
        image_size: target size. A square ``(s, s)`` list/tuple is collapsed
            to ``s``.
        is_train: select the training pipeline.
        mean: per-channel mean. Falsy values fall back to
            ``OPENAI_DATASET_MEAN``; a scalar is repeated for 3 channels.
        std: per-channel std, same rules as ``mean`` with
            ``OPENAI_DATASET_STD``.
        resize_longest_max: eval only; use ``ResizeMaxSize`` instead of
            Resize + CenterCrop.
        fill_color: padding fill value for ``ResizeMaxSize``.
    """
    def _per_channel(value, fallback):
        value = value or fallback
        return value if isinstance(value, (list, tuple)) else (value,) * 3

    mean = _per_channel(mean, OPENAI_DATASET_MEAN)
    std = _per_channel(std, OPENAI_DATASET_STD)

    is_square_pair = isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    if is_square_pair:
        # Pass an int so Resize() keeps aspect ratio by scaling the shortest edge.
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        geometry = [
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
        ]
    elif resize_longest_max:
        geometry = [ResizeMaxSize(image_size, fill=fill_color)]
    else:
        geometry = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    return Compose(geometry + [_convert_to_rgb, ToTensor(), normalize])
