from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoImageProcessor
import torch
from image_shuffle_transform import PatchShuffleTransform

def get_clip_val_transforms(
    noise_option,
    image_size=224,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
    patch_size=28,
):
    """Build the CLIP-style evaluation transform pipeline.

    Steps, in order: bicubic resize (an int size resizes the shorter side to
    ``image_size``), square center crop, ``_convert_to_rgb``, ``ToTensor``,
    an optional ``PatchShuffleTransform``, and channel-wise normalization.

    Args:
        noise_option: Perturbation selector. Only ``'patch_shuffle'`` is acted
            on here; any other value adds no perturbation step.
        image_size: Target of the shorter-side resize and side length of the
            square center crop.
        mean: Per-channel normalization mean (defaults are the CLIP image
            statistics).
        std: Per-channel normalization std (defaults are the CLIP image
            statistics).
        patch_size: Patch size passed to ``PatchShuffleTransform`` when
            ``noise_option == 'patch_shuffle'``.

    Returns:
        A ``transforms.Compose`` applying the steps above.
    """
    # Tuples as defaults: Normalize keeps a reference to mean/std, so a shared
    # mutable default list could be altered for every later call.
    transform_list = [
        transforms.Resize(size=image_size, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True),
        transforms.CenterCrop(size=(image_size, image_size)),
        _convert_to_rgb,
        transforms.ToTensor(),
    ]

    if noise_option == 'patch_shuffle':
        # Inserted after ToTensor and before Normalize.
        transform_list.append(PatchShuffleTransform(patch_size=patch_size))

    transform_list.append(transforms.Normalize(mean=mean, std=std))
    print(transform_list)
    return transforms.Compose(transform_list)


def _resolve_processor_image_size(img_processor):
    """Return the square input side length declared by an image processor.

    ``img_processor.size`` may be an int or a mapping whose keys vary by
    processor class. Candidates are tried in this order:
    ``size['height']``, then ``crop_size['height']`` (if the processor has a
    ``crop_size``), then ``size['shortest_edge']``.

    Raises:
        KeyError: if none of the candidates yields a value.
    """
    size = img_processor.size
    if isinstance(size, int):
        return size
    crop_size = getattr(img_processor, "crop_size", None) or {}
    for source, key in ((size, "height"), (crop_size, "height"), (size, "shortest_edge")):
        try:
            value = source[key]
        except (KeyError, TypeError):
            continue
        if value is not None:
            return value
    raise KeyError(f"Cannot determine image size from processor size={size!r}")


def get_model_transforms(
    model_name,
    noise_option
):
    """Return evaluation transforms appropriate for ``model_name``.

    Names starting with ``"open-clip:"`` get the default pipeline from
    ``get_clip_val_transforms``. Any other name is loaded with
    ``AutoImageProcessor.from_pretrained``, and the processor's size,
    ``image_mean`` and ``image_std`` are used. If loading the processor or
    reading those attributes fails, a warning is issued and a 224px pipeline
    with mean/std 0.5 is used instead. ``noise_option`` is forwarded on every
    path, including the fallback.

    Args:
        model_name: Model identifier; the ``"open-clip:"`` prefix selects the
            CLIP pipeline.
        noise_option: Perturbation selector forwarded to
            ``get_clip_val_transforms`` (e.g. ``'patch_shuffle'``).

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    import warnings

    if model_name.startswith("open-clip:"):
        return get_clip_val_transforms(noise_option)

    # Keep the try narrow: only processor loading / attribute lookup may fall
    # back. Building the transforms themselves should surface real errors.
    try:
        img_processor = AutoImageProcessor.from_pretrained(model_name)
        img_size = _resolve_processor_image_size(img_processor)
        mean = img_processor.image_mean
        std = img_processor.image_std
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
        warnings.warn(
            f"Could not derive transforms from an image processor for {model_name!r} "
            f"({type(err).__name__}: {err}); falling back to 224px center crop "
            f"with mean/std 0.5."
        )
        return get_clip_val_transforms(
            noise_option,
            image_size=224,
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        )

    return get_clip_val_transforms(noise_option, image_size=img_size, mean=mean, std=std)
    
