"""
Dataset Transformation Utilities

This module provides standardized image and video transformation pipelines for multimodal
AI training and evaluation. It includes vision encoder-specific normalization parameters
and preprocessing strategies optimized for different model architectures.

The transformations handle the conversion from raw image/video data to model-ready tensors
with appropriate normalization, augmentation, and resizing operations. The module supports
multiple vision encoder types including InternVideo, CLIP, and ViT architectures.

Key Features:
    - Vision encoder-specific normalization parameters
    - Training and evaluation transformation pipelines
    - Support for random augmentation during training
    - Configurable image resolution and interpolation methods
    - Standardized preprocessing for multimodal model compatibility

Supported Vision Encoders (selected by substring match on the encoder name):
    - InternVideo: Video understanding models with ImageNet normalization
    - CLIP: Contrastive language-image models with OpenAI CLIP normalization statistics
    - ViT: Vision Transformer models with ImageNet normalization
    - UMT: Unmasked Teacher (UMT) video models with ImageNet normalization

Dependencies:
    - torchvision: PyTorch vision library for image transformations
    - torchvision.transforms: Standard image transformation operations

Technical Details:
    - Input format: torch.Tensor of torch.uint8, shape (T, C, H, W) where T=1 for images
    - Output format: Normalized float tensors ready for model input
    - Normalization: Per-channel mean and standard deviation normalization
    - Interpolation: BICUBIC interpolation for high-quality resizing

Author: AI Model Development Team
License: MIT
"""

from torchvision import transforms
from torchvision.transforms import InterpolationMode


# Per-channel normalization statistics used by the supported encoder families.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def get_normalization_params(vision_enc_name: str):
    """
    Return the per-channel ``(mean, std)`` normalization statistics for an encoder.

    The encoder family is detected by case-sensitive substring match on
    ``vision_enc_name``, checked in this order:

    1. ``"internvideo"`` or ``"umt"`` -> ImageNet statistics
    2. ``"clip"``                     -> CLIP statistics
    3. ``"vit"``                      -> ImageNet statistics

    ``"clip"`` is tested before ``"vit"`` because CLIP backbones are usually named
    after their ViT architecture (e.g. ``"clip_vit_base_patch16"``). If ``"vit"``
    were checked first, such names would silently receive ImageNet statistics.

    Args:
        vision_enc_name (str): Name of the vision encoder.

    Returns:
        tuple: ``(mean, std)``, each a 3-tuple of floats (one value per channel).

    Raises:
        NotImplementedError: If the name matches none of the supported families.

    Example:
        >>> get_normalization_params("clip_vit_base_patch16")[0]
        (0.48145466, 0.4578275, 0.40821073)
    """
    if "internvideo" in vision_enc_name or "umt" in vision_enc_name:
        return _IMAGENET_MEAN, _IMAGENET_STD
    if "clip" in vision_enc_name:
        return _CLIP_MEAN, _CLIP_STD
    if "vit" in vision_enc_name:
        return _IMAGENET_MEAN, _IMAGENET_STD
    raise NotImplementedError(f"Vision encoder '{vision_enc_name}' is not supported")


def _uint8_to_unit_float(x):
    """Convert a uint8 tensor with values in [0, 255] to float32 values in [0, 1]."""
    # A module-level function rather than a lambda: functions are pickled by
    # qualified name, lambdas cannot be pickled at all. This keeps the returned
    # Compose picklable, e.g. for DataLoader workers started with "spawn".
    return x.float().div(255.0)


def _to_float_and_normalize(mean, std):
    """Return the shared tail of both pipelines: uint8 -> [0, 1] float, then Normalize."""
    return [
        transforms.Lambda(_uint8_to_unit_float),
        transforms.Normalize(mean, std),
    ]


def get_train_transform(
    vision_enc_name: str = "pretrain_internvideo2_1b_patch14_224",
    random_aug: bool = False,
    image_res: int = 224,
) -> "transforms.Compose":
    """
    Create the training transformation pipeline for a multimodal vision encoder.

    The pipeline converts uint8 image/video tensors into normalized float tensors.
    It applies optional RandAugment, a random resized crop, and a random horizontal
    flip, followed by encoder-specific normalization.

    Args:
        vision_enc_name (str, optional): Name of the vision encoder, used to pick the
            normalization statistics (see ``get_normalization_params``).
            Defaults to "pretrain_internvideo2_1b_patch14_224".
        random_aug (bool, optional): If True, prepend ``transforms.RandAugment()``
            (torchvision defaults). Defaults to False.
        image_res (int, optional): Side length of the square output crop.
            Defaults to 224.

    Returns:
        transforms.Compose: Pipeline of, in order:
            - RandAugment (only when ``random_aug`` is True)
            - RandomResizedCrop to ``image_res`` with scale (0.5, 1.0), BICUBIC
            - RandomHorizontalFlip (p=0.5, torchvision default)
            - uint8 [0, 255] -> float32 [0, 1] conversion
            - Normalize with encoder-specific mean and std

    Raises:
        NotImplementedError: If the vision encoder name is not supported.

    Technical Details:
        - Input: torch.Tensor of torch.uint8, shape (T, C, H, W) (T=1 for images)
        - Output: float tensor of shape (T, C, image_res, image_res)
        - Augmentations and resizing run on the uint8 input, before conversion.

    Normalization Parameters:
        - InternVideo/UMT/ViT: ImageNet (mean=(0.485, 0.456, 0.406),
                               std=(0.229, 0.224, 0.225))
        - CLIP (including CLIP ViT names): mean=(0.48145466, 0.4578275, 0.40821073),
                               std=(0.26862954, 0.26130258, 0.27577711)

    Example:
        >>> transform = get_train_transform(
        ...     vision_enc_name="clip_vit_base_patch16",
        ...     random_aug=True,
        ...     image_res=256
        ... )
        >>> processed_tensor = transform(raw_tensor)  # doctest: +SKIP
    """
    # Resolve normalization first so unsupported encoders fail before any
    # transform objects are built.
    mean, std = get_normalization_params(vision_enc_name)

    steps = []
    if random_aug:
        steps.append(transforms.RandAugment())
    steps.append(
        transforms.RandomResizedCrop(
            image_res,
            scale=(0.5, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        )
    )
    steps.append(transforms.RandomHorizontalFlip())
    steps.extend(_to_float_and_normalize(mean, std))

    return transforms.Compose(steps)


def get_test_transform(
    vision_enc_name: str = "pretrain_internvideo2_1b_patch14_224",
    image_res: int = 224,
) -> "transforms.Compose":
    """
    Create the deterministic evaluation pipeline for a multimodal vision encoder.

    Unlike the training pipeline, this applies no random augmentation. Each input is
    resized to exactly ``(image_res, image_res)`` and then normalized with the same
    encoder-specific statistics as ``get_train_transform``.

    Args:
        vision_enc_name (str, optional): Name of the vision encoder, used to pick the
            normalization statistics (see ``get_normalization_params``).
            Defaults to "pretrain_internvideo2_1b_patch14_224".
        image_res (int, optional): Output side length. Defaults to 224; it should
            match the resolution used during training.

    Returns:
        transforms.Compose: Pipeline of, in order:
            - Resize to (image_res, image_res), BICUBIC
            - uint8 [0, 255] -> float32 [0, 1] conversion
            - Normalize with encoder-specific mean and std

    Raises:
        NotImplementedError: If the vision encoder name is not supported.

    Technical Details:
        - Input: torch.Tensor of torch.uint8, shape (T, C, H, W) (T=1 for images)
        - Output: float tensor of shape (T, C, image_res, image_res)
        - Resize: both sides are resized directly to ``image_res``. There is no crop,
          so the aspect ratio is not preserved.

    Example:
        >>> transform = get_test_transform(
        ...     vision_enc_name="clip_vit_base_patch16",
        ...     image_res=256
        ... )
        >>> processed_tensor = transform(raw_tensor)  # doctest: +SKIP
    """
    mean, std = get_normalization_params(vision_enc_name)

    steps = [
        transforms.Resize(
            (image_res, image_res),
            interpolation=InterpolationMode.BICUBIC,
        ),
    ]
    steps.extend(_to_float_and_normalize(mean, std))

    return transforms.Compose(steps)
