import torch
import os.path as osp
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
import numpy as np
#from .randaugment import RandAugmentMC
import jpeg4py as jpeg
from model.dataloader.augmentation import *



def identity(x):
    """Return the argument itself, untouched.

    Args:
        x: Any object.

    Returns:
        The very same object that was passed in.
    """
    unchanged = x
    return unchanged

def get_transforms(size, backbone, s=1, bs=8, strong=True):
    """Build the (augmentation, evaluation) transform pair for a backbone.

    Args:
        size: Target side length in pixels used by the crop / resize steps.
        backbone: Backbone name; selects the normalization statistics.
            Supported: 'ConvNet', 'Res12', 'Res18', 'Res50'.
        s: Strength multiplier for the color jitter of the strong pipeline
            (brightness/contrast/saturation = 0.8*s, hue = 0.2*s). torchvision
            restricts hue to [0, 0.5], so ``s`` must be <= 2.5. Unused when
            ``strong`` is False.
        bs: Forwarded to ``multiCropsTransform`` as ``task_num`` (presumably
            the number of augmented views produced per image -- confirm in
            ``multiCropsTransform``).
        strong: If True, use the strong pipeline (RandomResizedCrop, stronger
            jitter and blur); otherwise use the weaker pipeline.

    Returns:
        ``(data_transforms_aug, data_transforms)``:
        ``data_transforms_aug`` wraps the augmentation pipeline in
        ``multiCropsTransform``. ``data_transforms`` is a deterministic
        Resize(size + 8) -> CenterCrop(size) -> ToTensor -> Normalize pipeline.

    Raises:
        ValueError: If ``backbone`` is not supported.
    """
    if backbone in ('ConvNet', 'Res18', 'Res50'):
        # Standard ImageNet channel statistics.
        normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
    elif backbone == 'Res12':
        # Dataset-specific statistics given on the 0-255 scale; rescale to 0-1.
        normalization = transforms.Normalize(
            mean=[x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]],
            std=[x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]])
    else:
        raise ValueError('Non-supported Network Types. Please Revise Data Pre-Processing Scripts.')

    if strong:
        # The strength multiplier ``s`` scales the strong color jitter.
        # With the default s=1 this is ColorJitter(0.8, 0.8, 0.8, 0.2).
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        augmentation = [
            transforms.RandomResizedCrop(size, scale=(0.5, 1.)),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0.1, 2])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalization,
        ]
    else:
        augmentation = [
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0, 1])], p=0.5),
            transforms.RandomHorizontalFlip(),
            # NOTE(review): Resize(int) matches only the shorter edge to
            # ``size``. Non-square inputs therefore stay non-square here,
            # unlike the strong / eval pipelines. Confirm that inputs are
            # already square for datasets used with strong=False.
            transforms.Resize(size),
            transforms.ToTensor(),
            normalization,
        ]
    data_transforms_aug = multiCropsTransform(transforms.Compose(augmentation), task_num=bs)

    # Deterministic evaluation pipeline: slight upscale, then center crop.
    data_transforms = transforms.Compose([transforms.Resize(size + 8),
                                          transforms.CenterCrop(size),
                                          transforms.ToTensor(),
                                          normalization])

    return data_transforms_aug, data_transforms


class fs_dataset(Dataset):
    """Image dataset driven by a ``<setname>.csv`` split file.

    The split file is read from ``data/<dataset>/<setname>.csv``. Its first
    line is skipped as a header. In each remaining row, the first column is an
    image file name relative to the image directory and the second is a class
    id (``wnid``). Images live under ``data/<dataset>``, or under
    ``data/CUB/images`` for CUB. Class ids are mapped to integer labels
    0..K-1 in order of first appearance.

    Each item is ``(image, aug_image, label)``. ``image`` is the deterministic
    evaluation transform of the picture. ``aug_image`` is the output of the
    augmentation transform built by ``get_transforms``.

    Usage:
        dataset = fs_dataset('train', args)
    where ``args`` provides ``dataset``, ``backbone_class``, ``batchsize`` and
    ``strong``.
    """

    def __init__(self, setname, args):
        """Load the split ``setname`` ('train' / 'val' / 'test') of ``args.dataset``."""
        if args.dataset == 'CUB':
            image_path = osp.join('data', 'CUB', 'images')
        else:
            image_path = osp.join('data', args.dataset)
        split_path = osp.join('data', args.dataset)

        # parse_csv resolves image paths through this map, so it must exist first.
        self.split_map = {'train': image_path, 'val': image_path, 'test': image_path}

        csv_path = osp.join(split_path, setname + '.csv')
        self.data, self.label = self.parse_csv(csv_path, setname)
        self.num_class = len(set(self.label))

        # CIFARFS / FC100 use 32-pixel inputs; every other dataset uses 84.
        image_size = 32 if args.dataset in ('CIFARFS', 'FC100') else 84
        self.transform_aug, self.transform = get_transforms(
            image_size, args.backbone_class, bs=args.batchsize, strong=args.strong)

    def parse_csv(self, csv_path, setname):
        """Parse a split CSV into parallel lists of image paths and integer labels.

        Args:
            csv_path: Path to the CSV file. Its first line is skipped as a header.
            setname: Key into ``self.split_map`` giving the image root directory.

        Returns:
            ``(data, label)``: ``data[i]`` is the full image path and
            ``label[i]`` its integer class label. Labels are assigned 0..K-1 in
            order of first appearance of each class id. That order is also
            stored in ``self.wnids``.

        Blank lines are ignored. Rows of the same class need not be contiguous.
        """
        with open(csv_path, 'r') as f:
            rows = [ln.strip() for ln in f][1:]
        rows = [r for r in rows if r]  # tolerate blank / trailing lines

        data = []
        label = []
        self.wnids = []
        wnid_to_label = {}  # O(1) lookup instead of scanning self.wnids

        root = self.split_map[setname]
        for row in tqdm(rows, ncols=64):
            context = row.split(',')
            name, wnid = context[0], context[1]
            if wnid not in wnid_to_label:
                wnid_to_label[wnid] = len(self.wnids)
                self.wnids.append(wnid)
            data.append(osp.join(root, name))
            # Look up the class's own label (not the most recently created one),
            # so interleaved classes are labelled correctly.
            label.append(wnid_to_label[wnid])

        return data, label

    def __len__(self):
        """Number of images in the split."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(image, aug_image, label)`` for index ``i``."""
        path, label = self.data[i], self.label[i]
        try:
            # Try jpeg4py first; fall back to PIL for anything it cannot decode.
            image = Image.fromarray(jpeg.JPEG(path).decode()).convert('RGB')
        except Exception:
            # Closing the source file is safe: convert() returns a new image.
            with Image.open(path) as src:
                image = src.convert('RGB')
        aug_image = self.transform_aug(image)
        image = self.transform(image)
        return image, aug_image, label
