import io
import os

import nori2 as nori
import torch
import torchvision.transforms as transforms
from PIL import Image

from .utils import NORMALIZE, get_unsupervised_transform


def DOMAIN_TRAIN_TRANSFORM(spatial_size=64):
    """Build the training-time image pipeline.

    Args:
        spatial_size: Output size passed to ``transforms.RandomResizedCrop``.

    Returns:
        A ``transforms.Compose`` that applies, in order:
        RandomResizedCrop(spatial_size), RandomHorizontalFlip, ToTensor,
        and finally ``NORMALIZE``.
    """
    steps = [
        transforms.RandomResizedCrop(spatial_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    steps.append(NORMALIZE)
    return transforms.Compose(steps)

def DOMAIN_TEST_TRANSFORM(spatial_size=64):
    """Build the evaluation-time image pipeline.

    Args:
        spatial_size: Side length of the final center crop.

    Returns:
        A ``transforms.Compose`` that applies, in order: Resize to a size
        slightly larger than ``spatial_size``, CenterCrop(spatial_size),
        ToTensor, and ``NORMALIZE``.
    """
    # Add 16 to the resize target for every whole multiple of 96 in
    # spatial_size: 64 -> 64 (no margin), 96 -> 112, 224 -> 256.
    margin = 16 * int(spatial_size / 96)
    pipeline = [
        transforms.Resize(spatial_size + margin),
        transforms.CenterCrop(spatial_size),
        transforms.ToTensor(),
        NORMALIZE,
    ]
    return transforms.Compose(pipeline)



class DomainNet(torch.utils.data.Dataset):
    """DomainNet image-classification dataset read through nori.

    Samples are listed in ``<domain_path>/<nori_prefix>.<split>.nori.list``,
    where ``split`` is ``train`` or ``val``. Each non-blank line holds
    whitespace-separated fields: the key passed to ``nori.Fetcher.get``,
    followed by an integer class label. Each item is decoded with PIL,
    converted to RGB and passed through ``self.transform`` (when not None).
    """

    def __init__(self, train, nori_prefix='domainnet', spatial_size=64,
                 unsupervised_transform=False,
                 domain_path='/data/Dataset/DomainNet/'):
        """Load the sample list and select the image transform.

        Args:
            train: When true, read the ``train`` list and use
                ``DOMAIN_TRAIN_TRANSFORM``; otherwise read the ``val`` list
                and use ``DOMAIN_TEST_TRANSFORM``.
            nori_prefix: File-name prefix of the list file.
            spatial_size: Output size forwarded to the transform builder.
            unsupervised_transform: When true, use
                ``get_unsupervised_transform`` for either split.
            domain_path: Directory that contains the list files.
        """
        super(DomainNet, self).__init__()

        split = 'train' if train else 'val'
        nori_name = '{}.{}.nori.list'.format(nori_prefix, split)
        nori_path = os.path.join(domain_path, nori_name)

        # The fetcher is created on first use in __getitem__, not here.
        # Presumably this is so each DataLoader worker builds its own
        # fetcher rather than inheriting one from the parent process.
        # TODO confirm.
        self.f = None
        self.f_list = self._read_list(nori_path)

        if unsupervised_transform:
            self.transform = get_unsupervised_transform(normalize=NORMALIZE, spatial_size=spatial_size)
        elif train:
            self.transform = DOMAIN_TRAIN_TRANSFORM(spatial_size=spatial_size)
        else:
            self.transform = DOMAIN_TEST_TRANSFORM(spatial_size=spatial_size)

    @staticmethod
    def _read_list(path):
        """Return the whitespace-split fields of every non-blank line in *path*.

        Blank lines are skipped. Keeping them as empty entries would inflate
        ``__len__`` and make ``__getitem__`` fail with ``IndexError``.
        """
        with open(path) as g:
            return [fields for fields in (line.split() for line in g) if fields]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the transformed RGB image. If ``self.transform`` is
        None, it is the raw PIL image. ``label`` is the int parsed from the
        second field of the list line.
        """
        if self.f is None:
            self.f = nori.Fetcher()

        ls = self.f_list[idx]
        img = Image.open(io.BytesIO(self.f.get(ls[0]))).convert('RGB')
        # Fall back to the raw image instead of leaving `img` unbound.
        if self.transform is not None:
            img = self.transform(img)
        label = int(ls[1])
        return img, label

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.f_list)


class DomainNet_TwoCrops(DomainNet):
    """DomainNet variant that returns two separately augmented views per sample.

    Both views come from pipelines built by ``get_unsupervised_transform``
    with the same arguments. Each item is ``((view1, view2), label)``.
    """

    def __init__(self, train, nori_prefix='domainnet', spatial_size=64):
        # Let the base class build ``self.transform`` as the unsupervised
        # pipeline at the requested size. Without these arguments it would
        # build a supervised pipeline at the default size, which would then
        # be discarded immediately.
        super(DomainNet_TwoCrops, self).__init__(
            train,
            nori_prefix=nori_prefix,
            spatial_size=spatial_size,
            unsupervised_transform=True,
        )
        # Second, independently constructed pipeline, used for the first view.
        self.transform_ = get_unsupervised_transform(normalize=NORMALIZE, spatial_size=spatial_size)

    def __getitem__(self, idx):
        """Return ``((view1, view2), label)`` for sample ``idx``."""
        if self.f is None:
            self.f = nori.Fetcher()
        ls = self.f_list[idx]
        image = Image.open(io.BytesIO(self.f.get(ls[0]))).convert('RGB')
        # Keep the original call order (transform_ first), in case the two
        # pipelines draw from shared random state.
        image1 = self.transform_(image)
        image2 = self.transform(image)
        label = int(ls[1])
        return (image1, image2), label