from torch.utils.data import Dataset
from skimage.io import imread
import glob
import pandas as pd
import numpy as np, os
import torch
import torchvision.transforms as trans
from torch.utils.data import Dataset, DataLoader, random_split
from functools import partialmethod



class CelebADataset(Dataset):
    """Unlabeled CelebA image dataset.

    Images are read from ``<path_to_data>/img_resize_celeba`` (expected to be
    pre-resized to 64 by 64). The files are sorted by name and split at a
    fixed index. The first ``TRAIN_SIZE`` files form the training set. All
    remaining files (CelebA's validation and test partitions together) form
    the held-out set.
    """

    # End of the first (training) partition in CelebA's standard split,
    # whose partition sizes are 162770 / 19867 / 19962.
    # Subclasses may override this for a different split point.
    TRAIN_SIZE = 162770

    def __init__(self, path_to_data, subsample=1, transform=None, train=True,
                 pivot=False, condition=None):
        """
        Parameters
        ----------
        path_to_data : str
            Root directory that contains the ``img_resize_celeba`` folder.
        subsample : int
            Only load every |subsample| number of images. Must be non-zero.
        transform : callable, optional
            Extra transform applied after the image has been converted to a
            float32 tensor.
        train : bool
            If True, use the first ``TRAIN_SIZE`` images; otherwise use the
            rest. Ignored when ``pivot`` is True.
        pivot : bool
            If True, use only the first image (after sorting by path).
        condition : str, optional
            If ``'test'``, keep at most the first 500 selected images.
        """
        self.root = path_to_data
        path_to_imgs = os.path.join(self.root, 'img_resize_celeba')
        # Sort so that the fixed-index train/held-out split is deterministic.
        all_paths = sorted(glob.glob(os.path.join(path_to_imgs, '*')))
        self.img_paths = self._select_paths(
            all_paths, subsample, train, pivot, condition)

        image_transforms = [trans.ToTensor(),
                            trans.ConvertImageDtype(torch.float32)]
        if transform is not None:
            image_transforms.append(transform)
        self.transform = trans.Compose(image_transforms)

        # Converts numpy arrays to float32 tensors. __getitem__ does not use
        # it (it returns a dummy label), but it is kept as a public attribute
        # in case callers rely on it.
        self.target_transform = trans.Compose([
            trans.Lambda(lambda x: torch.from_numpy(x).to(dtype=torch.float32))
        ])

    @classmethod
    def _select_paths(cls, all_paths, subsample, train, pivot, condition):
        """Return the slice of sorted ``all_paths`` for the requested split."""
        if pivot:
            selected = all_paths[:1]
        elif train:
            selected = all_paths[:cls.TRAIN_SIZE]
        else:
            selected = all_paths[cls.TRAIN_SIZE:]
        # Subsampling is applied after the split, as the split is index-based.
        selected = selected[::subsample]
        if condition == 'test':
            selected = selected[:500]
        return selected

    def __len__(self):
        """Return the number of selected images."""
        return len(self.img_paths)

    def get_supervised(self):
        """Return this dataset itself (it has no separate supervised view)."""
        return self

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, 0)``.

        The dataset has no labels, so a constant ``0`` is returned in the
        label position to keep the ``(input, target)`` convention.
        """
        sample = imread(self.img_paths[idx])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, 0