import torch
import os.path as osp
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
import numpy as np
#from .randaugment import RandAugmentMC
import jpeg4py as jpeg
from model.dataloader.augmentation import *

THIS_PATH = osp.dirname(__file__)

# Root prefixes for dataset locations; '' leaves the joined paths relative.
ROOT_PATH = ''
ROOT_PATH2 = ''

# Alternative absolute image root (kept for reference): /amax/data/tiered_imagenet_raw
IMAGE_PATH1 = 'data/tieredimagenet'

# Directory holding the <split>.csv index files, and a cache directory.
SPLIT_PATH = osp.join(ROOT_PATH, 'data/tieredimagenet')
CACHE_PATH = osp.join(ROOT_PATH, '.cache/')

# All three splits read their images from the same directory.
split_map = dict.fromkeys(('train', 'val', 'test'), IMAGE_PATH1)

def identity(x):
    """Return the argument unchanged (a no-op transform)."""
    unchanged = x
    return unchanged

def get_transforms(size, backbone, s=1, task_num=8):
    """Build the training-augmentation and evaluation transforms.

    Args:
        size: output side length (pixels) of the cropped images.
        backbone: network name selecting the normalization statistics;
            one of 'ConvNet', 'Res12', 'Res18', 'Res50'.
        s: strength multiplier for the color-jitter augmentation
            (brightness/contrast/saturation = 0.8 * s, hue = 0.2 * s).
        task_num: forwarded to ``multiCropsTransform`` as ``task_num``
            (presumably the number of augmented views per image; confirm).

    Returns:
        (data_transforms_aug, data_transforms): the multi-crop augmentation
        transform, and the deterministic resize + center-crop transform.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """
    if backbone in ('ConvNet', 'Res18', 'Res50'):
        # Commonly used ImageNet per-channel mean/std.
        normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
    elif backbone == 'Res12':
        # Per-channel statistics given on the 0-255 scale, rescaled to [0, 1]
        # (presumably computed on the training data; confirm).
        normalization = transforms.Normalize(
            np.array([x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]),
            np.array([x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]))
    else:
        raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.')

    # Previously built but unused (a hard-coded jitter was applied instead),
    # which made ``s`` a no-op; with the default s=1 the values are unchanged.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)

    augmentation = [
        transforms.RandomResizedCrop(size, scale=(0.5, 1.)),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        # Project-local GaussianBlur; [0.1, 2] is presumably its sigma range.
        transforms.RandomApply([GaussianBlur([0.1, 2])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalization,
    ]
    data_transforms_aug = multiCropsTransform(transforms.Compose(augmentation), task_num=task_num)

    # Evaluation view: resize slightly larger than the target, then center crop.
    data_transforms = transforms.Compose([transforms.Resize(size + 8),
                                          transforms.CenterCrop(size),
                                          transforms.ToTensor(),
                                          normalization])

    return data_transforms_aug, data_transforms


class tieredImageNet(Dataset):
    """tieredImageNet split indexed by ``SPLIT_PATH/<setname>.csv``.

    Usage::

        dataset = tieredImageNet('train', args)  # args.backbone_class picks normalization
        image, aug_image, label = dataset[0]

    Each CSV row (after the first line) is ``filename,wnid``. Image paths are
    stored at construction time and decoded lazily in ``__getitem__``.
    """

    def __init__(self, setname, args):
        csv_path = osp.join(SPLIT_PATH, setname + '.csv')
        self.data, self.label = self.parse_csv(csv_path, setname)
        self.num_class = len(set(self.label))

        image_size = 84
        self.transform_aug, self.transform = get_transforms(image_size, args.backbone_class)

    def parse_csv(self, csv_path, setname):
        """Read image paths and integer labels from a split CSV.

        The first line is skipped (presumably a header). Each remaining
        non-empty line must be ``filename,wnid``. Labels are assigned in order
        of each wnid's first appearance, and ``self.wnids[label]`` maps a
        label back to its wnid.

        Returns:
            (data, label): parallel lists of image paths and integer labels.
        """
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]

        data = []
        label = []
        self.wnids = []
        # Dict lookup instead of a running counter: O(1), and still correct
        # when rows of one class are not contiguous in the CSV.
        wnid_to_label = {}

        for line in tqdm(lines, ncols=64):
            if not line:  # tolerate blank lines instead of failing the unpack
                continue
            name, wnid = line.split(',')
            if wnid not in wnid_to_label:
                wnid_to_label[wnid] = len(self.wnids)
                self.wnids.append(wnid)
            # Only the path is stored; decoding happens per item in __getitem__.
            data.append(osp.join(split_map[setname], name))
            label.append(wnid_to_label[wnid])

        return data, label

    @staticmethod
    def _load_image(path):
        """Decode ``path`` into an RGB PIL image, trying jpeg4py first."""
        try:
            return Image.fromarray(jpeg.JPEG(path).decode()).convert('RGB')
        except Exception:
            # jpeg4py could not decode the file; fall back to PIL and make
            # sure the underlying file handle is closed.
            with Image.open(path) as img:
                return img.convert('RGB')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(image, aug_image, label)`` for item ``i``.

        ``image`` is the output of ``self.transform`` and ``aug_image`` the
        output of ``self.transform_aug``, both applied to the decoded image.
        """
        path, label = self.data[i], self.label[i]
        # The transforms operate on PIL images, so decode the stored path here.
        image = self._load_image(path)
        aug_image = self.transform_aug(image)
        image = self.transform(image)
        return image, aug_image, label
