import torch
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import Dataset

# Names exported by `from <this module> import *`.
__all__ = ['CELEBA']

class CELEBA(Dataset):
    """CelebA images exposed as sets of (pixel coordinate, pixel value) pairs.

    Each image is resized to ``size``, converted to a tensor and normalised
    with mean 0.5 / std 0.5 per channel. An item is returned as flattened
    per-pixel data: normalised ``(row, col)`` coordinates in ``[0, 1)`` and
    the corresponding channel values.

    Args:
        root: Root directory handed to ``torchvision.datasets.CelebA``.
        size: ``(height, width)`` the images are resized to. Ignored when
            ``visualize`` is true.
        train: Use the ``'train'`` split if true, otherwise ``'test'``.
        download: Forwarded to ``torchvision.datasets.CelebA``.
        visualize: If true, images are resized to ``VISUALIZE_SIZE``, only a
            random subset of pixels is returned as context, and the dataset
            length is capped at ``VISUALIZE_LENGTH``.
    """

    # (height, width) forced in visualize mode.
    # NOTE(review): presumably the native aligned CelebA resolution - confirm.
    VISUALIZE_SIZE = (218, 178)
    # Number of randomly selected context pixels per image in visualize mode.
    VISUALIZE_NUM_CONTEXT = 1000
    # Upper bound on the number of images exposed in visualize mode.
    VISUALIZE_LENGTH = 200

    def __init__(self, root, size=(32, 32), train=True, download=False, visualize=False):
        super().__init__()
        self.name = 'CELEBA'
        self.split = 'train' if train else 'test'
        self.visualize = visualize
        # Copy into a fresh list so the instance never aliases the caller's
        # (or the default) sequence, while `self.size` stays a list as before.
        self.size = list(self.VISUALIZE_SIZE) if visualize else list(size)

        self.transform = transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.dataset = datasets.CelebA(
            root=root,
            split=self.split,
            download=download,
            transform=self.transform,
        )

        self.coordinates = self._build_coordinates(self.size[0], self.size[1])

    @staticmethod
    def _build_coordinates(height, width):
        """Return a ``(height * width, 2)`` float32 tensor of pixel coordinates.

        Rows are in row-major pixel order; column 0 is ``row / height`` and
        column 1 is ``col / width``.
        """
        # Divide in float64 and cast at the end so the values match a
        # double-precision Python computation followed by `.float()`.
        rows = torch.arange(height, dtype=torch.float64) / height
        cols = torch.arange(width, dtype=torch.float64) / width
        grid = torch.stack(
            (rows.repeat_interleave(width), cols.repeat(height)),
            dim=1,
        )
        return grid.float()

    def __getitem__(self, index):
        """Return ``(context_x, context_y, target_x, target_y)`` for one image.

        ``target_x`` is a copy of the coordinate grid, shape ``(H*W, 2)``, and
        ``target_y`` holds the pixel values, shape ``(H*W, C)``. Outside
        visualize mode the context equals the target. In visualize mode the
        context is a random subset of at most ``VISUALIZE_NUM_CONTEXT`` pixels.
        """
        image, _ = self.dataset[index]
        # (C, H, W) -> (C, H*W) -> (H*W, C): one row of channel values per pixel.
        target_y = image.flatten(start_dim=1).transpose(0, 1)
        target_x = self.coordinates.clone()

        if not self.visualize:
            return target_x, target_y, target_x, target_y

        # Sample from the actual number of pixels rather than a hard-coded
        # 218*178, so the indices are always valid for the current grid.
        num_pixels = target_x.size(0)
        chosen = torch.randperm(num_pixels)[:self.VISUALIZE_NUM_CONTEXT]
        context_x = target_x[chosen, :]
        context_y = target_y[chosen, :]
        return context_x, context_y, target_x, target_y

    def __len__(self):
        """Return the number of images (capped in visualize mode)."""
        total = len(self.dataset)
        if self.visualize:
            # Never report more items than the wrapped dataset contains.
            return min(self.VISUALIZE_LENGTH, total)
        return total
