import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms.functional as F

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

from CLIP_utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

"""
Image Preprocessing Tools for CLIP_utils

This module provides utilities and classes for image preprocessing and augmentation for CLIP_utils models,
including transformations for both training and evaluation with support for custom augmentation and resizing configurations.

Key Features:
1. **`image_transform` Function**:
   - Handles image preprocessing for both training and evaluation.
   - Provides default normalization based on `OPENAI_DATASET_MEAN` and `OPENAI_DATASET_STD`.
   - Supports custom augmentation configurations (`AugmentationCfg`) for training.
   - Offers an option to resize images so that the longest edge matches the target size, padding the rest (`resize_longest_max`).

2. **`ResizeMaxSize` Class**:
   - Resizes images to match a specified size on the longest edge while maintaining aspect ratio.
   - Adds padding to center the image in a square canvas.

3. **Augmentation Configuration**:
   - Uses `AugmentationCfg` dataclass to specify augmentation parameters such as scale, ratio, color jitter, interpolation, etc.
   - Integrates optional `timm` library augmentations for advanced transformations.

4. **Utilities**:
   - `_convert_to_rgb`: Ensures input images are in RGB format.
   - Supports PyTorch's `torchvision.transforms` for standard resizing, cropping, and normalization.

5. **Training and Evaluation Modes**:
   - In training, random cropping and optional `timm` augmentations are applied.
   - In evaluation, deterministic resizing and cropping are used to ensure consistency.

Notes:
- Depends on `torch` and `torchvision`.
- Supports optional `timm` library for advanced augmentations when `use_timm` is set to `True`.
"""


@dataclass
class AugmentationCfg:
    """
    Container for the training-time augmentation settings consumed by ``image_transform``.

    Any field left as ``None`` is treated as "not specified" by the caller. The
    fields are declared in a fixed order, so positional construction relies on it.

    Attributes:
        scale (Tuple[float, float]): Lower/upper bound of the crop area fraction. Default: (0.9, 1.0)
        ratio (Optional[Tuple[float, float]]): Lower/upper bound of the crop aspect ratio.
        color_jitter (Optional[Union[float, Tuple[float, float, float]]]): Color jitter strength(s).
        interpolation (Optional[str]): Name of the resize interpolation to use.
        re_prob (Optional[float]): Probability of applying random erasing.
        re_count (Optional[int]): How many random-erase operations to apply.
        use_timm (bool): Build the pipeline with timm's ``create_transform``. Default: False
    """
    # Crop-area range; the only augmentation honored without timm.
    scale: Tuple[float, float] = (0.9, 1.0)
    # Crop aspect-ratio range (None = leave to the backend's default).
    ratio: Optional[Tuple[float, float]] = None
    # Single strength or (brightness, contrast, saturation) triple.
    color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
    # Interpolation name, e.g. as understood by timm.
    interpolation: Optional[str] = None
    # Random-erasing probability.
    re_prob: Optional[float] = None
    # Number of random-erasing operations.
    re_count: Optional[int] = None
    # Switch between the torchvision pipeline (False) and timm (True).
    use_timm: bool = False


class ResizeMaxSize(nn.Module):
    """
    Resize an image so that one edge equals ``max_size`` (aspect ratio preserved), then pad.

    With the default ``fn='max'``, the longest edge is scaled to ``max_size``. The
    shorter edge is then padded symmetrically with ``fill``, so the output is always a
    ``max_size`` x ``max_size`` square. This includes images whose longest edge already
    equals ``max_size``.

    With ``fn='min'``, the shortest edge is scaled to ``max_size``. The other edge is
    then at least ``max_size`` long, so no padding is applied and the output is not
    necessarily square.

    Args:
        max_size (int): Target length of the edge selected by ``fn``.
        interpolation (InterpolationMode): Interpolation method. Default: InterpolationMode.BICUBIC
        fn (str): ``'min'`` fits the shortest edge; any other value fits the longest edge. Default: 'max'
        fill (int): Padding fill value. Default: 0

    Raises:
        TypeError: If ``max_size`` is not an int.
    """

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # Edge selector used by forward(). The non-'min' branch must be `max`
        # (it previously read `min` in both branches and was never used).
        self.fn = min if fn == 'min' else max
        self.fill = fill

    def forward(self, img):
        """
        Apply the resize (and, when needed, the padding) to ``img``.

        Args:
            img (Union[torch.Tensor, PIL.Image]): Input image. Tensors are expected in
                torchvision layout ``(..., H, W)``.

        Returns:
            Union[torch.Tensor, PIL.Image]: The resized image. Any edge shorter than
            ``max_size`` is padded, centered, up to ``max_size``.
        """
        if isinstance(img, torch.Tensor):
            # torchvision tensor images are (..., H, W): the spatial dims are the last two.
            # shape[:2] would give (C, H) for a CHW tensor.
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        scale = self.max_size / float(self.fn(height, width))
        if scale != 1.0:
            # Clamp to 1 so extremely thin images never produce a zero-length edge.
            new_height, new_width = (max(1, round(dim * scale)) for dim in (height, width))
            img = F.resize(img, (new_height, new_width), self.interpolation)
        else:
            new_height, new_width = height, width

        # Pad whenever an edge falls short of max_size, even if no resize happened,
        # so that fn='max' always yields a square. With fn='min' the deficits are <= 0,
        # so they clamp to zero and no padding is added.
        pad_h = max(self.max_size - new_height, 0)
        pad_w = max(self.max_size - new_width, 0)
        if pad_h or pad_w:
            img = F.pad(
                img,
                padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2],
                fill=self.fill,
            )
        return img


def _convert_to_rgb(image):
    """
    Convert image to RGB format.
    
    Args:
        image (PIL.Image): Input image.
        
    Returns:
        PIL.Image: Image converted to RGB format.
    """
    return image.convert('RGB')


def image_transform(
        image_size: Union[int, Tuple[int, int]],
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        resize_size: int = 256,
):
    """
    Create image transformation pipeline for CLIP models.

    Training (``is_train=True``):
        * ``use_timm`` False:
            - RandomResizedCrop to ``image_size`` (bicubic, ``aug_cfg.scale``).
            - RGB conversion, ToTensor, Normalize.
            - Any other non-None ``aug_cfg`` fields are ignored, with a warning.
        * ``use_timm`` True: delegates to ``timm.data.create_transform`` (timm must be
          installed).
            - Horizontal flip is disabled and ``re_mode='pixel'``.
            - ``interpolation`` defaults to ``'random'`` and color jitter is disabled
              unless configured.
            - The remaining non-None ``aug_cfg`` fields are forwarded as keyword arguments.

    Evaluation (``is_train=False``):
        * ``resize_longest_max`` True: ``ResizeMaxSize(image_size)``, i.e. fit the
          longest edge and pad to a square with ``fill_color``.
        * otherwise: resize the shortest edge to ``resize_size`` (bicubic), then
          CenterCrop to ``image_size``.
        Both are followed by RGB conversion, ToTensor, Normalize.

    Args:
        image_size (int or tuple): Target image size (square) or (height, width).
        is_train (bool): Whether to use training transforms (with augmentation) or evaluation transforms.
        mean (tuple, optional): Mean values for normalization. Default: OPENAI_DATASET_MEAN
        std (tuple, optional): Standard deviation values for normalization. Default: OPENAI_DATASET_STD
        resize_longest_max (bool): If True, resize image to max size maintaining aspect ratio. Default: False
        fill_color (int): Color to use when padding images. Default: 0 (black)
        aug_cfg (dict or AugmentationCfg, optional): Augmentation configuration.
        resize_size (int): Size for the resize operation before the center crop in
            evaluation mode only. Default: 256

    Returns:
        torchvision.transforms.Compose (or the timm transform when ``use_timm`` is set):
        The composed transformation pipeline.

    Raises:
        ValueError: On the timm path, if a sequence ``image_size`` has fewer than two elements.
        TypeError: From ``ResizeMaxSize``, if ``resize_longest_max`` is set and
            ``image_size`` is a non-square (height, width) pair.
    """
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    if isinstance(aug_cfg, dict):
        aug_cfg = AugmentationCfg(**aug_cfg)
    else:
        aug_cfg = aug_cfg or AugmentationCfg()
    normalize = Normalize(mean=mean, std=std)

    if is_train:
        # Only explicitly-set (non-None) options are forwarded / reported.
        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
        use_timm = aug_cfg_dict.pop('use_timm', False)
        if use_timm:
            from timm.data import create_transform  # timm can still be optional
            if isinstance(image_size, (tuple, list)):
                if len(image_size) < 2:
                    raise ValueError(f"image_size must have at least 2 elements, got {image_size!r}")
                # tuple() so that list inputs concatenate with (3,) instead of raising TypeError.
                input_size = (3,) + tuple(image_size[-2:])
            else:
                input_size = (3, image_size, image_size)
            # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
            aug_cfg_dict.setdefault('interpolation', 'random')
            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
            train_transform = create_transform(
                input_size=input_size,
                is_training=True,
                hflip=0.,
                mean=mean,
                std=std,
                re_mode='pixel',
                **aug_cfg_dict,
            )
        else:
            # Fall back to the dataclass default if scale was explicitly passed as None.
            scale = aug_cfg_dict.pop('scale', AugmentationCfg.scale)
            train_transform = Compose([
                # Crop to the model input size so training matches the evaluation resolution.
                RandomResizedCrop(
                    image_size,
                    scale=scale,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                _convert_to_rgb,
                ToTensor(),
                normalize,
            ])
            if aug_cfg_dict:
                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
        return train_transform

    if resize_longest_max:
        # The output must be image_size-square for the model, not resize_size.
        transforms = [
            ResizeMaxSize(image_size, fill=fill_color)
        ]
    else:
        smallest_crop_edge = min(image_size) if isinstance(image_size, (list, tuple)) else image_size
        if resize_size < smallest_crop_edge:
            # Resize() makes the shortest edge resize_size, so a larger crop forces
            # CenterCrop to pad the image with zeros.
            warnings.warn(
                f'resize_size ({resize_size}) is smaller than the crop size ({image_size}); '
                f'CenterCrop will pad the image.'
            )
        transforms = [
            Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    transforms.extend([
        _convert_to_rgb,
        ToTensor(),
        normalize,
    ])
    return Compose(transforms)