from locale import normalize
import random
from torch.utils import data

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch

from PIL import Image

from SSL.loader import GaussianBlur

# ImageNet normalization statistics as a (per-channel mean, per-channel std)
# pair; callers unpack it as transforms.Normalize(*IMAGENETNORM).
IMAGENETNORM = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

import numpy as np
import torch


class RandomPixelShuffle(object):
  """Permute every element of a tensor in a fixed, seed-determined order.

  The permutation is drawn once at construction time, so every call applies
  the same shuffle across pixels *and* channels.

  Args:
    image_chw: shape the shuffled result is reshaped back to (e.g. (C, H, W));
      its product must equal the number of elements of each input tensor.
    seed: seed for the numpy RandomState that draws the permutation.
  """

  def __init__(self, image_chw, seed):
    self.image_chw = image_chw
    n_elements = np.prod(image_chw)
    generator = np.random.RandomState(seed=seed)
    self.perm = torch.from_numpy(generator.permutation(n_elements))

  def __call__(self, tensor):
    flat = tensor.reshape(-1)
    shuffled = flat[self.perm]
    return shuffled.reshape(self.image_chw)
    

class RandomBlockShuffle(object):
  """Shuffle the non-overlapping square blocks of an image in a fixed order.

  The image is cut into (image_size // block_size) ** 2 blocks of
  block_size x block_size pixels. A permutation of those blocks is drawn once
  at construction time from `seed` and applied identically on every call.

  Args:
    image_size: side length of the square input image.
    block_size: side length of each square block; must divide image_size.
    seed: seed for the numpy RandomState that draws the block permutation.

  Raises:
    KeyError: if image_size is not divisible by block_size (KeyError is kept
      for backward compatibility with existing callers).
  """

  def __init__(self, image_size, block_size, seed):
    if image_size % block_size != 0:
      raise KeyError(f'RandomBlockShuffle: image size {image_size} cannot be divided by block size {block_size}')

    # Accept float-valued sizes (e.g. 8.0) but store plain ints for Unfold/Fold.
    image_size = int(image_size)
    block_size = int(block_size)

    self.image_size = image_size
    self.block_size = block_size
    self.n_blocks = (image_size // block_size) ** 2
    rng = np.random.RandomState(seed=seed)
    self.perm = torch.from_numpy(rng.permutation(self.n_blocks))

    # stride == kernel_size -> non-overlapping blocks, so Fold exactly inverts Unfold.
    self.unfold_op = torch.nn.Unfold(kernel_size=block_size, stride=block_size)
    self.fold_op = torch.nn.Fold(output_size=image_size, kernel_size=block_size, stride=block_size)

  def __call__(self, tensor):
    # unsqueeze(0) adds the batch dim that Unfold/Fold expect; input assumed (C, H, W).
    blocks = self.unfold_op(tensor.unsqueeze(0))  # (1, C * block_size**2, n_blocks)
    assert blocks.size(2) == self.n_blocks
    blocks = blocks[..., self.perm]  # reorder whole blocks along the last dim
    return self.fold_op(blocks).squeeze(0)  # back to (C, H, W)

def initialize_transform(transform_name, config, dataset, is_training):
    """
    Build the transform pipeline registered under `transform_name`.

    By default, transforms should take in `x` and return `transformed_x`.
    For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`,
    set `do_transform_y` to True when initializing the WILDSSubset.

    Returns None when `transform_name` is None and raises ValueError for an
    unrecognized name.
    """
    if transform_name is None:
        return None

    # Builders take different argument subsets, so wrap each as a zero-arg
    # callable; only the one whose name matches is actually invoked.
    registry = (
        ('image_base', lambda: initialize_image_base_transform(config, dataset)),
        ('image_resize_and_center_crop',
         lambda: initialize_image_resize_and_center_crop_transform(config, dataset)),
        ('image_base_ssl', lambda: initialize_ssl_transform(config)),
        ('cmnist', lambda: initialize_cmnist_transform(config, dataset, is_training)),
        ('cifar10', lambda: initialize_cifar10_transform(config, dataset, is_training)),
        ('celebA', lambda: initialize_celeba_transform(config, is_training)),
        ('sp', lambda: initialize_sp_transform(config, is_training)),
        ('waterbirds', lambda: initialize_wb_transform(config, is_training)),
        ('metashift', lambda: initialize_metashift_transform(config, dataset, is_training)),
    )
    for name, build in registry:
        if transform_name == name:
            return build()
    raise ValueError(f"{transform_name} not recognized")


def initialize_image_base_transform(config, dataset):
    """Square-crop (if needed), optionally resize, then tensorize and normalize.

    Args:
        config: reads `target_resolution` and `dataset`; no resize is added when
            target_resolution is None or the dataset is 'fmow'.
        dataset: reads `original_resolution`; a non-square resolution triggers a
            center crop to its shorter side.

    Returns:
        A transforms.Compose ending in ToTensor + IMAGENETNORM normalization.
    """
    resolution = dataset.original_resolution
    steps = []

    is_non_square = resolution is not None and min(resolution) != max(resolution)
    if is_non_square:
        steps.append(transforms.CenterCrop(min(resolution)))

    wants_resize = config.target_resolution is not None and config.dataset != 'fmow'
    if wants_resize:
        steps.append(transforms.Resize(config.target_resolution))

    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(*IMAGENETNORM),
    ])
    return transforms.Compose(steps)

def initialize_image_resize_and_center_crop_transform(config, dataset):
    """
    Scale the image by `config.resize_scale`, then center-crop it.

    Each side of `dataset.original_resolution` is multiplied by
    `config.resize_scale` (truncated to int). The result is center-cropped to
    `config.target_resolution`, or back to the original resolution when no
    target is configured, then converted to a tensor and ImageNet-normalized.

    Both `dataset.original_resolution` and `config.resize_scale` must be set
    (checked with assert).
    """
    original = dataset.original_resolution
    assert original is not None
    assert config.resize_scale is not None

    scaled_resolution = tuple(int(side * config.resize_scale) for side in original)
    crop_resolution = original if config.target_resolution is None else config.target_resolution

    pipeline = [
        transforms.Resize(scaled_resolution),
        transforms.CenterCrop(crop_resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)

def initialize_cmnist_transform(config, dataset, is_training):
    """Build the CMNIST transform: resize to 224, convert to tensor, normalize.

    Args:
        config: `normalize` is read as a (mean, std) pair. When it is missing
            or None, IMAGENETNORM is used, matching the cifar10/metashift
            builders. Previously a None value crashed in Normalize(*None).
        dataset, is_training: unused; kept for a uniform builder signature.

    Returns:
        A transforms.Compose applied identically for train and eval.
    """
    normalize = getattr(config, 'normalize', None)
    if normalize is None:
        normalize = IMAGENETNORM
    # NOTE(review): the resize is hard-coded to 224 and ignores
    # config.target_resolution — confirm this is intended.
    return transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(*normalize),
    ])

def initialize_metashift_transform(config, dataset, is_training):
    """Build the MetaShift transform pipeline.

    Training: random resized crop to `config.target_resolution` plus a random
    horizontal flip. Evaluation: resize to 256, then center-crop to
    `config.target_resolution`. Both end with ToTensor and normalization using
    `config.normalize`, or IMAGENETNORM when that is None.

    Args:
        config: reads `normalize` and `target_resolution`.
        dataset: unused; kept for a uniform builder signature.
        is_training: selects the augmentation pipeline when True.

    Returns:
        A transforms.Compose. Previously the eval branch returned None because
        its `return` was missing.
    """
    normalize = IMAGENETNORM if config.normalize is None else config.normalize
    target_resolution = config.target_resolution
    # `normalize` is a (mean, std) pair, not a transform, so it must be wrapped
    # in transforms.Normalize before it can go into Compose.
    if is_training:
        return transforms.Compose([
            transforms.RandomResizedCrop(size=target_resolution),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*normalize),
        ])
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(target_resolution),
        transforms.ToTensor(),
        transforms.Normalize(*normalize),
    ])



def initialize_cifar10_transform(config, dataset, is_training):
    """Build the CIFAR-10 transform pipeline.

    Training adds a 32x32 random crop with 4-pixel padding and a random
    horizontal flip. Both modes end with ToTensor and normalization using
    `config.normalize`, or IMAGENETNORM when that is None.

    Args:
        config: reads `normalize`.
        dataset: unused. It was previously dereferenced for an unused
            `min(dataset.original_resolution)`, which raised when the
            resolution was None.
        is_training: selects the augmentation pipeline when True.

    Returns:
        A transforms.Compose.
    """
    normalize = IMAGENETNORM if config.normalize is None else config.normalize
    steps = []
    if is_training:
        steps += [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(*normalize),
    ]
    return transforms.Compose(steps)

def initialize_celeba_transform(config, is_training):
    """Build the CelebA transform pipeline.

    Evaluation: center-crop the 178x218 image to its shorter side (178), then
    resize to `config.target_resolution`. Training: random resized crop
    (scale 0.7-1.0, aspect ratio 1.0-4/3) plus a random horizontal flip. Both
    end with ToTensor and ImageNet normalization.
    """
    orig_w, orig_h = 178, 218
    target_resolution = config.target_resolution

    if is_training:
        # The original aspect ratio is ~0.81, so crops are never squished any
        # further in that direction (ratio lower bound is 1.0).
        front = [
            transforms.RandomResizedCrop(
                size=target_resolution,
                scale=(0.7, 1.0),
                ratio=(1.0, 4.0 / 3.0),
                interpolation=2,
            ),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        front = [
            transforms.CenterCrop(min(orig_w, orig_h)),
            transforms.Resize(target_resolution),
        ]

    tail = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(front + tail)

def initialize_wb_transform(config, is_training):
    """Build the Waterbirds transform pipeline.

    Evaluation: resize to 256/224 times `config.target_resolution` (each side
    truncated to int), then center-crop back to the target. Training: random
    resized crop (scale 0.7-1.0, aspect ratio 3/4-4/3) plus a random
    horizontal flip. Both end with ToTensor and ImageNet normalization.

    `config.target_resolution` must be a non-None 2-element sequence
    (checked with assert).
    """
    target_resolution = config.target_resolution
    assert target_resolution is not None

    if is_training:
        front = [
            transforms.RandomResizedCrop(
                size=target_resolution,
                scale=(0.7, 1.0),
                ratio=(0.75, 4.0 / 3.0),
                interpolation=2,
            ),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # Enlarge slightly (by 256/224) so the center crop trims a border.
        enlarge = 256.0 / 224.0
        enlarged = (
            int(target_resolution[0] * enlarge),
            int(target_resolution[1] * enlarge),
        )
        front = [
            transforms.Resize(enlarged),
            transforms.CenterCrop(target_resolution),
        ]

    return transforms.Compose(front + [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

def initialize_sp_transform(config, is_training):
    """Build the 'sp' dataset transform pipeline.

    Evaluation: resize to a fixed 256x256 square, then center-crop to
    `config.target_resolution`. Training: random resized crop (scale 0.7-1.0,
    aspect ratio 3/4-4/3) plus a random horizontal flip. Both end with
    ToTensor and ImageNet normalization.

    `config.target_resolution` must not be None (checked with assert).
    """
    # Integer sizes: Resize forwards these to the image backend. PIL's resize
    # rejects float dimensions on modern Python, so the previous
    # (256.0, 256.0) could raise a TypeError.
    resize_resolution = (256, 256)
    target_resolution = config.target_resolution
    assert target_resolution is not None

    if not is_training:
        transform = transforms.Compose([
            transforms.Resize(resize_resolution),
            transforms.CenterCrop(target_resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    else:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(
                target_resolution,
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                interpolation=2,
            ),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    return transform

def initialize_ssl_transform(config):
    """Build the training-time augmentation pipeline for self-supervised learning.

    A dataset-specific random resized crop to `config.target_resolution` is
    chosen as follows:
      * 'celebA'     -> scale 0.7-1.0, aspect ratio 1.0-4/3
      * 'waterbirds' -> scale 0.7-1.0, aspect ratio 3/4-4/3
      * otherwise    -> scale 0.2-1.0, torchvision's default aspect ratio
    It is followed by a random horizontal flip, color jitter (p=0.8), random
    grayscale (p=0.2), ToTensor, and normalization with `config.normalize`
    (IMAGENETNORM when that attribute is absent or None).
    """
    normalize_factor = IMAGENETNORM
    if getattr(config, 'normalize', None) is not None:
        normalize_factor = config.normalize

    image_size = config.target_resolution
    if config.dataset == 'celebA':
        crop = transforms.RandomResizedCrop(
            image_size, scale=(0.7, 1.0), ratio=(1.0, 4.0 / 3.0), interpolation=2,
        )
    elif config.dataset == 'waterbirds':
        crop = transforms.RandomResizedCrop(
            image_size, scale=(0.7, 1.0), ratio=(0.75, 4.0 / 3.0), interpolation=2,
        )
    else:
        crop = transforms.RandomResizedCrop(image_size, scale=(0.2, 1.))

    pipeline = [
        crop,
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(*normalize_factor),
    ]
    # NOTE: optional block/pixel shuffling (RandomBlockShuffle /
    # RandomPixelShuffle, selected by config.img_shuffle) is not appended here;
    # it is currently disabled.
    return transforms.Compose(pipeline)
