import io
import torch
import nori2 as nori
import torchvision.transforms as transforms
from PIL import Image

from .utils import get_unsupervised_transform


# Per-channel mean and standard deviation handed to ``transforms.Normalize``
# below; both are 3-tuples, one value per RGB channel.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Single shared normalization step reused by every pipeline in this module.
IMAGENET_NORMALIZE = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

def IMAGENET_TRAIN_TRANSFORM(spatial_size=224):
    """Build the supervised training pipeline.

    Applies, in order: a random resized crop to ``spatial_size``, a random
    horizontal flip, conversion to a tensor, and ``IMAGENET_NORMALIZE``.
    """
    steps = [
        transforms.RandomResizedCrop(spatial_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        IMAGENET_NORMALIZE,
    ]
    return transforms.Compose(steps)

def IMAGENET_EVAL_TRANSFORM(spatial_size=224):
    """Build the deterministic evaluation pipeline.

    The image is first resized (torchvision ``Resize`` with an int target)
    to ``spatial_size`` plus a margin of 16 pixels per whole multiple of 96
    in ``spatial_size`` (e.g. 224 -> 256). It is then center-cropped to
    ``spatial_size``, converted to a tensor, and normalized with
    ``IMAGENET_NORMALIZE``.
    """
    margin = 16 * int(spatial_size / 96)
    resize_to = spatial_size + margin
    steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(spatial_size),
        transforms.ToTensor(),
        IMAGENET_NORMALIZE,
    ]
    return transforms.Compose(steps)

class ImageNet(torch.utils.data.Dataset):
    """ImageNet / ImageNet-100 classification dataset read through nori.

    Each non-blank line of the ``.nori.list`` index file is split on
    whitespace. The first field is the nori id passed to ``Fetcher.get``, and
    the second field is parsed with ``int`` as the class label.

    Args:
        train: Read the ``train`` index when true, otherwise the ``val`` index.
        ws: Read the index from ``/unsullied/sharefs/...`` instead of
            ``/data/Dataset/ImageNet2012/``.
        nori_prefix: ``'imagenet'`` selects the ``imagenet.*`` index; any
            other value selects the ``imagenet100.*`` index.
        spatial_size: Output size forwarded to the transform builders.
        unsupervised_transform: Use ``get_unsupervised_transform`` instead of
            the supervised train/eval pipelines.
    """

    def __init__(self, train, ws=False, nori_prefix='imagenet100', spatial_size=224, unsupervised_transform=False):
        super(ImageNet, self).__init__()
        split = 'train' if train else 'val'
        # Only 'imagenet' is special-cased; every other prefix falls back to
        # the imagenet100 index.
        if nori_prefix == 'imagenet':
            nori_name = 'imagenet.{}.nori.list'.format(split)
        else:
            nori_name = 'imagenet100.{}.nori.list'.format(split)

        if ws:
            nori_path = '/unsullied/sharefs/g:brain/imagenet/ILSVRC2012/' + nori_name
        else:
            nori_path = '/data/Dataset/ImageNet2012/' + nori_name

        # The nori fetcher is created lazily on the first __getitem__ call,
        # not here. Presumably this keeps the dataset copyable/picklable into
        # DataLoader worker processes before any fetcher exists -- TODO confirm.
        self.f = None
        # One entry per index line: the whitespace-split fields of that line.
        self.f_list = []
        with open(nori_path) as index_file:
            for line in index_file:
                fields = line.split()
                # Skip blank lines (e.g. a trailing empty line). They would
                # otherwise become empty entries that inflate __len__ and
                # raise IndexError in __getitem__.
                if fields:
                    self.f_list.append(fields)

        if unsupervised_transform:
            self.transform = get_unsupervised_transform(spatial_size=spatial_size, normalize=IMAGENET_NORMALIZE)
        elif train:
            self.transform = IMAGENET_TRAIN_TRANSFORM(spatial_size=spatial_size)
        else:
            self.transform = IMAGENET_EVAL_TRANSFORM(spatial_size=spatial_size)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th index entry.

        The image bytes are fetched from nori, decoded with PIL, converted to
        RGB, and passed through ``self.transform`` when it is set. When the
        transform is ``None``, the RGB PIL image itself is returned.
        """
        if self.f is None:
            self.f = nori.Fetcher()

        fields = self.f_list[idx]
        img = Image.open(io.BytesIO(self.f.get(fields[0]))).convert('RGB')
        # Previously `img` was only bound inside this branch, so a None
        # transform raised UnboundLocalError.
        if self.transform is not None:
            img = self.transform(img)
        label = int(fields[1])
        return img, label

    def __len__(self):
        """Number of (non-blank) entries in the index file."""
        return len(self.f_list)


class ImageNet_TwoCrops(ImageNet):
    """ImageNet dataset that yields two separately transformed views.

    Two transform pipelines are built independently with
    ``get_unsupervised_transform`` and applied to the same decoded image.
    Each item is ``((view_from_transform_, view_from_transform), label)``.

    Note: the default ``nori_prefix`` here is ``'imagenet'``, unlike the
    parent's ``'imagenet100'``.
    """

    def __init__(self, train, ws=False, nori_prefix='imagenet', spatial_size=224):
        # Forward spatial_size so the parent is configured consistently with
        # this subclass instead of silently using its default of 224. The
        # parent's transform is replaced just below.
        super(ImageNet_TwoCrops, self).__init__(train, ws=ws, nori_prefix=nori_prefix, spatial_size=spatial_size)
        self.transform = get_unsupervised_transform(spatial_size=spatial_size, normalize=IMAGENET_NORMALIZE)
        self.transform_ = get_unsupervised_transform(spatial_size=spatial_size, normalize=IMAGENET_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((image1, image2), label)`` for the ``idx``-th entry.

        ``image1`` comes from ``self.transform_`` and ``image2`` from
        ``self.transform``, both applied to the same RGB PIL image.
        """
        # Same lazy fetcher creation as in ImageNet.__getitem__.
        if self.f is None:
            self.f = nori.Fetcher()
        fields = self.f_list[idx]
        image = Image.open(io.BytesIO(self.f.get(fields[0]))).convert('RGB')
        # Keep the original application order (transform_ first), so any
        # random-state consumption by the transforms is unchanged.
        image1 = self.transform_(image)
        image2 = self.transform(image)
        label = int(fields[1])
        return (image1, image2), label