import tensorflow as tf

import torch
import torchvision as tv
import PIL

from .config import PP_CONFIG

class DictWrapper(torch.utils.data.Dataset):
    """Adapt a dataset of ``(image, label)`` pairs so each item is a dict.

    Items come back as ``{"image": ..., "label": ...}``. The wrapped
    dataset's ``targets`` attribute is re-exposed as-is, so the base dataset
    must define it.
    """

    def __init__(self, base_dataset):
        # Keep the source dataset and mirror its label collection.
        self.base_dataset = base_dataset
        self.targets = base_dataset.targets

    def __len__(self):
        base = self.base_dataset
        return len(base)

    def __getitem__(self, index):
        # The base item must unpack into exactly two values.
        sample = self.base_dataset[index]
        img, lbl = sample
        return {"image": img, "label": lbl}

class UnifiedPreprocessor:
    """Build paired TensorFlow and torchvision preprocessing for one model.

    After construction two attributes are available:

    * ``tfds_pp`` -- a callable taking one image tensor (for use in a
      ``tf.data`` map).
    * ``torch_pp`` -- a ``torchvision.transforms.Compose`` pipeline.

    Args:
        model_name: Key looked up in ``PP_CONFIG``. The value ``'BiT'``
            selects the 128x128 BiT pipelines; any other value selects the
            32x32 pipelines.
        dataset: Dataset name. Names containing ``'cifar'`` (case-sensitive
            substring match) use CIFAR channel statistics; all others use
            0.5 for both mean and divisor.

    Raises:
        Whatever ``PP_CONFIG[model_name]`` raises for an unknown model name
        (``KeyError`` if ``PP_CONFIG`` is a plain dict).
    """

    def __init__(self, model_name, dataset):
        if 'cifar' in dataset:
            self.mean = (0.4914, 0.4822, 0.4465)
            # Despite the name, this is the per-channel divisor passed as the
            # ``std`` argument of ``Normalize`` (and divided by in
            # ``_tfds_pp``), not a variance.
            self.var = (0.2023, 0.1994, 0.2010)
        else:
            self.mean = (0.5, 0.5, 0.5)
            self.var = (0.5, 0.5, 0.5)

        if PP_CONFIG[model_name] == 'BiT':
            self.tfds_pp = self._BiT_tfds_pp
            self.torch_pp = self._BiT_torch_pp()
        else:
            self.tfds_pp = self._tfds_pp
            self.torch_pp = self._torch_pp()

    @staticmethod
    def _BiT_tfds_pp(image):
        """Resize to 128x128 and map pixel values from [0, 255] to [-1, 1].

        Assumes ``image`` holds raw 0-255 pixel values (TODO confirm the
        tfds input is unscaled). ``tf.image.resize`` uses its default
        (bilinear) method and returns float32.
        """
        image = tf.image.resize(image, [128, 128])
        # (x - 127.5) / 127.5 == (x / 255 - 0.5) / 0.5, matching the torch
        # pipeline's ToTensor + Normalize(0.5, 0.5).
        image = (image - 127.5) / 127.5
        return image

    @staticmethod
    def _BiT_torch_pp():
        """Return the torchvision counterpart of ``_BiT_tfds_pp``."""
        return tv.transforms.Compose([
            # Use torchvision's InterpolationMode rather than a PIL constant:
            # a bare ``import PIL`` does not load ``PIL.Image``, and
            # integer/PIL interpolation values are a legacy form for
            # torchvision transforms. The resampling filter is the same.
            tv.transforms.Resize(
                (128, 128),
                interpolation=tv.transforms.InterpolationMode.BILINEAR,
            ),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def _tfds_pp(self, image):
        """Resize to 32x32, scale to [0, 1], then standardize per channel.

        Assumes 0-255 input values (TODO confirm). The 3-tuples in
        ``self.mean``/``self.var`` broadcast over the last axis, so the
        image is presumably channels-last with 3 channels.
        """
        image = tf.image.resize(image, [32, 32])
        image = image / 255.
        image = (image - self.mean) / self.var
        return image

    def _torch_pp(self):
        """Return the torchvision pipeline for the non-BiT models."""
        return tv.transforms.Compose([
            # NOTE(review): this center-crops while ``_tfds_pp`` resizes;
            # the two agree only when inputs are already 32x32 (e.g.
            # CIFAR). Confirm this asymmetry is intended for other datasets.
            tv.transforms.CenterCrop(size=(32, 32)),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize(self.mean, self.var),
        ])