import csv
import time
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import transforms

import kornia.augmentation as K
from kornia.constants import Resample

# Per-channel normalisation statistics keyed by dataset name. Each tuple has
# three entries (one per image channel) and is handed to ``K.Normalize`` by
# ``CustomTransform``. Key order is the same in both tables.
MEAN = dict(
    cifar=(0.5245, 0.5013, 0.4562),
    tiny=(0.4823, 0.4471, 0.3952),
    mini=(0.4713, 0.4503, 0.4039),
    cub=(0.4851, 0.4957, 0.4202),
    cars=(0.4709, 0.4609, 0.4555),
    aircraft=(0.4914, 0.5154, 0.5389),
)

STD = dict(
    cifar=(0.2687, 0.2599, 0.2815),
    tiny=(0.2788, 0.2693, 0.2827),
    mini=(0.2750, 0.2661, 0.2824),
    cub=(0.2324, 0.2273, 0.2629),
    cars=(0.2919, 0.2905, 0.2991),
    aircraft=(0.2522, 0.2445, 0.2611),
)

class CustomTransform:
    """Batch image augmentation built on kornia.

    Training mode (``test=False``) applies, in order: random resized crop to
    ``img_size`` x ``img_size`` (p=1.0), random horizontal flip (p=0.5),
    colour jitter (p=0.8), random grayscale (p=0.2), a random Gaussian blur
    (p=0.5) and per-dataset normalisation. Test mode only resizes to
    ``img_size`` x ``img_size`` and normalises. Both paths run under
    ``torch.no_grad()``.

    Args:
        data_name: Key into the module-level ``MEAN`` / ``STD`` tables.
        img_size: Side length of the square output images.
    """

    def __init__(self, data_name, img_size):
        super(CustomTransform, self).__init__()
        self.img_size = img_size
        rnd_resizedcrop = K.RandomResizedCrop(
            size=(img_size, img_size), scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333),
            resample=Resample.BILINEAR.name,
            p=1.0, same_on_batch=False
        )
        rnd_hflip = K.RandomHorizontalFlip(
            p=0.5, same_on_batch=False
        )
        rnd_color_jitter = K.ColorJitter(
            brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2,
            p=0.8, same_on_batch=False
        )
        rnd_gray = K.RandomGrayscale(
            p=0.2, same_on_batch=False
        )

        # Fixed part of the training pipeline; the Gaussian blur is built per
        # call in __call__ because its sigma is re-sampled every time.
        self.transform = nn.Sequential(
            rnd_resizedcrop,
            rnd_hflip,
            rnd_color_jitter,
            rnd_gray
        )

        self.resize = K.Resize((img_size, img_size))
        self.normalize = K.Normalize(
            MEAN[data_name], STD[data_name], p=1.
        )

    @staticmethod
    def _blur_kernel_size(img_size):
        """Return the smallest odd integer that is >= ``img_size + 1``.

        Gaussian kernels must have an odd size (kornia raises on even sizes).
        The plain ``img_size + 1`` is already odd for even ``img_size``, so
        that case is unchanged. For odd ``img_size`` it is bumped to
        ``img_size + 2``, which avoids the previous crash.
        """
        size = int(img_size) + 1
        if size % 2 == 0:
            size += 1
        return size

    def __call__(self, x, test=False):
        """Augment (or, with ``test=True``, only resize) and normalise ``x``.

        Args:
            x: Image batch accepted by kornia augmentations. Presumably a
                float tensor of shape (B, C, H, W); TODO confirm with callers.
            test: If True, skip all random augmentation.

        Returns:
            The transformed batch.
        """
        if test:
            with torch.no_grad():
                x = self.resize(x)
                x = self.normalize(x)
            return x

        # One sigma is drawn per call and used for both axes, so every image
        # in the batch that gets blurred shares the same sigma.
        sigma = np.random.uniform(0.1, 2.0)
        ksize = self._blur_kernel_size(self.img_size)
        rnd_gaussian_blur = K.RandomGaussianBlur(
            kernel_size=(ksize, ksize),
            sigma=(sigma, sigma),
            p=0.5, same_on_batch=False
        )
        with torch.no_grad():
            x = self.transform(x)
            x = rnd_gaussian_blur(x)
            x = self.normalize(x)
        return x

class SelfSupDataset(Dataset):
    """Map-style dataset over pre-saved samples at ``{root}/preprocess/{index}.pt``.

    Args:
        root: Directory containing the ``preprocess`` folder. Defaults to the
            original hard-coded location so existing callers are unaffected.
        length: Number of samples reported by ``__len__``. A sampler will
            request indices ``0 .. length - 1``, so files ``0.pt`` through
            ``{length - 1}.pt`` are expected to exist (TODO confirm against the
            preprocessing script). Defaults to the original 38400.
    """

    def __init__(self, root="ANONYMIZED", length=38400):
        super().__init__()
        self.path = f"{root}/preprocess"
        self.length = length

    def __getitem__(self, index):
        # map_location="cpu" loads every tensor onto the CPU, whatever device
        # it was saved from.
        return torch.load(f"{self.path}/{index}.pt", map_location="cpu")

    def __len__(self):
        return self.length

class EpisodicDataset:
    """Samples N-way few-shot classification episodes from a saved split.

    The split file is read with ``torch.load`` and must contain ``'images'``
    (indexable with an integer index array) and ``'class_idx'`` (one entry of
    image indices per class).

    Args:
        data_name: One of ``'cifar'``, ``'tiny'``, ``'mini'``, ``'cub'``,
            ``'cars'``, ``'aircraft'``.
        meta_level: Split name used in the file path (presumably e.g.
            train/val/test; TODO confirm against the preprocessing script).
    """

    def __init__(self, data_name, meta_level):
        assert data_name in ['cifar', 'tiny', 'mini', 'cub', 'cars', 'aircraft']

        data = torch.load(f'/ANONYMIZED/{data_name}/{meta_level}.pth')
        self.images, self.class_idx = data['images'], data['class_idx']
        # Number of classes available for sampling.
        self.C = self.class_idx.shape[0]

    def get_task(self, way=5, support=10, query=15, rank=0):
        """Sample one ``way``-way episode.

        Draws ``way`` distinct classes, shuffles each class's images, and takes
        the first ``support`` images as the support set and the next ``query``
        as the query set. Episode labels are ``0 .. way-1`` in the order the
        classes were drawn.

        Args:
            way: Number of classes in the episode.
            support: Support images per class.
            query: Query images per class.
            rank: Worker rank. It only affects the time-derived seed.

        Returns:
            ``(x_tr, y_tr, x_te, y_te)``: support images and labels, then query
            images and labels, concatenated class by class. Labels are int64
            and always match the number of images returned, even when a class
            has fewer than ``support + query`` images.
        """
        # Seed derived from the wall clock. Taking it modulo ``rank + 1``
        # makes the seed range depend on the rank.
        seed = int((time.time() % (rank + 1)) * 10000)
        # Use a private generator rather than torch.manual_seed. That yields
        # the same episode for a given seed without resetting the global torch
        # RNG, which other randomness (dropout, augmentation) relies on.
        gen = torch.Generator()
        gen.manual_seed(seed)

        classes = torch.randperm(self.C, generator=gen)[:way]

        x_tr, y_tr, x_te, y_te = [], [], [], []
        for label, cls in enumerate(classes):
            idx = np.asarray(self.class_idx[int(cls)]).astype(int)

            x = self.images[idx]
            x = x[torch.randperm(len(idx), generator=gen)]

            support_x = x[:support]
            query_x = x[support:support + query]

            # Size labels from what was actually sliced, so under-populated
            # classes cannot produce label/image length mismatches.
            x_tr.append(support_x)
            y_tr.append(torch.full((len(support_x),), label, dtype=torch.long))

            x_te.append(query_x)
            y_te.append(torch.full((len(query_x),), label, dtype=torch.long))

        x_tr, y_tr = torch.cat(x_tr, dim=0), torch.cat(y_tr, dim=0)
        x_te, y_te = torch.cat(x_te, dim=0), torch.cat(y_te, dim=0)

        return x_tr, y_tr, x_te, y_te