import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.transforms import InterpolationMode
import numpy as np

# Deterministic front half of the ResNet preprocessing: scale the image so its
# shorter edge is 256 px (aspect ratio preserved, bilinear interpolation), then
# convert it to a float tensor with values rescaled to [0, 1].
initial_resnet_transform = transforms.Compose(
    [
        # An int size resizes the *shorter* edge only; the result is not 256x256.
        transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
        # PIL image -> CHW float tensor in [0, 1].
        transforms.ToTensor(),
    ]
)

# Training-time augmentation: take a crop of random area and aspect ratio and
# rescale it to 224x224, then mirror the result left/right at random.
custom_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224)),
    transforms.RandomHorizontalFlip(),
])
# Deterministic back half of the ResNet preprocessing: take the central
# 224x224 patch of the tensor and normalize each channel with the ImageNet
# mean/std used by torchvision's ResNet weights.
final_resnet_transform = transforms.Compose([
    transforms.CenterCrop(size=(224, 224)),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Full training pipeline: resize + tensor conversion, random crop/flip, then
# center crop + normalization. The 224x224 center crop changes nothing here,
# because custom_transform already emits 224x224 tensors; only Normalize acts.
augmented_transform = transforms.Compose([
    initial_resnet_transform,
    custom_transform,
    final_resnet_transform,
])
"""The inference transforms are available as ResNet18_Weights.IMAGENET1K_V1.transforms
 and perform the following preprocessing operations: they accept PIL.Image, batched (B, C, H, W)
 and single (C, H, W) image torch.Tensor objects. The images are resized to resize_size=[256]
 using interpolation=InterpolationMode.BILINEAR, followed by a central crop of crop_size=[224].
 Finally, the values are rescaled to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406]
 and std=[0.229, 0.224, 0.225]. The `resnet_transform` pipeline below reproduces these steps, but
 because it converts with ToTensor after resizing, it expects PIL images rather than tensors.
"""
# Deterministic evaluation pipeline (no random augmentation): the resize +
# to-tensor front half followed directly by the center-crop + normalize half.
resnet_transform = transforms.Compose([initial_resnet_transform, final_resnet_transform])

# Factory that builds the ResNet preprocessing pipeline for a given input_size.


def get_resnet_transform(input_size, model_name, resize_size=256):
    """Build the ResNet preprocessing pipeline for a square crop of ``input_size``.

    The pipeline resizes the shorter image edge to ``resize_size`` (bilinear,
    aspect ratio preserved) and converts the image to a float tensor in
    [0, 1]. It then center-crops to ``input_size`` x ``input_size`` and
    normalizes with the ImageNet mean/std used by torchvision's ResNet
    weights. When ``model_name`` is ``"lwp"``, a random horizontal flip is
    inserted between the tensor conversion and the crop.

    Args:
        input_size: Side length of the square center crop.
        model_name: Model identifier; ``"lwp"`` enables the extra flip.
        resize_size: Target length of the shorter edge before cropping.
            Defaults to 256, the value that used to be hard-coded. Keep it
            ``>= input_size``. ``CenterCrop`` zero-pads images smaller than
            the crop, which would add black borders.

    Returns:
        A ``transforms.Compose`` of the form ``[resize+to_tensor, (flip),
        crop+normalize]``. Because ``Resize`` runs before ``ToTensor``,
        inputs are expected to be PIL images.
    """
    initial = transforms.Compose(
        [
            # An int size scales the shorter edge; the other edge follows the
            # aspect ratio.
            transforms.Resize(resize_size, interpolation=InterpolationMode.BILINEAR),
            # PIL image -> CHW float tensor rescaled to [0, 1].
            transforms.ToTensor(),
        ]
    )
    final = transforms.Compose(
        [
            transforms.CenterCrop(size=(input_size, input_size)),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # "lwp" gets an additional random flip so that its amount of augmentation
    # is similar to what the other CL models use.
    extra = [transforms.RandomHorizontalFlip()] if model_name in ("lwp",) else []
    return transforms.Compose([initial, *extra, final])


from .augmentations import jitter, magnitude_warp


class _NumpyAugmentation:
    """Adapt a NumPy augmentation function into a tensor transform.

    ``augmentation_func`` receives a NumPy array. Inputs with at most two
    dimensions are given a temporary leading axis of size 1 (presumably the
    batch axis the augmentation expects), and exactly that axis is removed
    from the result. The output is converted back to a tensor on the input's
    device and with the input's dtype.

    Defined at module level (not inside ``get_custom_transform``) so that
    instances are picklable, e.g. when a DataLoader starts its workers with
    the ``spawn`` method.
    """

    def __init__(self, augmentation_func):
        self.augmentation_func = augmentation_func

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        device, dtype = x.device, x.dtype
        # detach() so tensors that require grad can still be converted.
        arr = x.detach().cpu().numpy()
        if arr.ndim <= 2:
            # Drop only the axis we added: squeeze() would also collapse any
            # other size-1 axis (e.g. a single channel) and change the shape.
            arr = self.augmentation_func(arr[np.newaxis, ...])[0]
        else:
            arr = self.augmentation_func(arr)
        return torch.from_numpy(arr).to(device=device, dtype=dtype)


def get_custom_transform(input_size, dataset_name):
    """Return the dataset-specific training augmentation.

    Args:
        input_size: Output side length of the random resized crop. It is only
            used for "celeba", "mtfl" and "fairface".
        dataset_name: One of:
            - "celeba", "mtfl", "fairface": a random resized crop to
              ``input_size`` x ``input_size`` followed by a random
              horizontal flip.
            - "physiq": the time-series augmentations ``jitter`` and then
              ``magnitude_warp``. Both operate on NumPy arrays and are
              wrapped to accept tensors.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
    """
    if dataset_name in ("celeba", "mtfl", "fairface"):
        return transforms.Compose(
            [
                # Crop of random area/aspect ratio, rescaled to input_size.
                transforms.RandomResizedCrop(size=(input_size, input_size)),
                transforms.RandomHorizontalFlip(),
            ]
        )
    if dataset_name == "physiq":
        # TODO: torch implementation of time series augmentations
        return transforms.Compose(
            [
                _NumpyAugmentation(jitter),
                _NumpyAugmentation(magnitude_warp),
            ]
        )
    raise ValueError("Dataset not found.")
