import os
import PIL
from typing import Optional
from functools import partial

import numpy as np
import pandas as pd

import torch
from torchvision import datasets, transforms
from torchvision.datasets.utils import verify_str_arg

from .image_dataset import ImageDataset

from utils.data_utils import AttrEncoder


class CelebADataset(ImageDataset):
    """Image dataset built from the CelebA partition and attribute list files.

    Image paths are taken from ``list_eval_partition.txt`` and resolved under
    ``<root>/img_align_celeba``. Labels are the selected columns of
    ``list_attr_celeba.txt``, binarised as ``value > 0``.
    """

    # Not referenced inside this class; presumably image ids used by callers
    # elsewhere in the project -- TODO confirm.
    permanent_images = ["172346", "168036"]

    def __init__(self, root, label_idx=(15, 31), image_size=(64, 64), split="train"):
        """Read the CelebA index files and set up images, labels and transforms.

        Args:
            root: Directory containing ``list_eval_partition.txt``,
                ``list_attr_celeba.txt`` and the ``img_align_celeba`` folder.
            label_idx: Column position(s) of the attribute matrix to use as
                labels. The defaults are presumably "Eyeglasses" and "Smiling"
                in the standard CelebA attribute order -- TODO confirm.
            image_size: Target size handed to ``transforms.Resize``.
            split: One of ``"train"``, ``"valid"``, ``"test"`` or ``"all"``
                (case-insensitive).

        Raises:
            ValueError: If ``split`` is not one of the accepted names
                (raised by ``verify_str_arg``).
        """
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        # The default is an immutable tuple; hand NumPy a list so the column
        # selection below is unambiguously a fancy index.
        if isinstance(label_idx, tuple):
            label_idx = list(label_idx)

        fn = partial(os.path.join, root)
        # ``sep=r"\s+"`` replaces the deprecated ``delim_whitespace=True``.
        splits = pd.read_csv(fn("list_eval_partition.txt"), sep=r"\s+", header=None, index_col=0)
        attr = pd.read_csv(fn("list_attr_celeba.txt"), sep=r"\s+", header=1)

        # Positional row selector: everything for "all", otherwise a boolean
        # array marking the rows whose partition code matches the split.
        if split_ is None:
            mask = slice(None)
        else:
            mask = (splits[1] == split_).to_numpy()

        # image
        images = splits.index

        transform = self.get_transform(image_size)

        images = images[mask]
        images = [fn("img_align_celeba", image) for image in images]

        # label
        # NOTE: the mask comes from the partition file and is applied
        # positionally, so both files are assumed to list rows in the same order.
        labels = attr.values[:, label_idx]
        labels = labels > 0

        # The encoder is built from the labels of every split, before masking.
        target_transform = self.get_target_transform(labels)

        labels = labels[mask]

        super().__init__(images, labels, transform=transform, target_transform=target_transform)

    def get_transform(self, image_size):
        """Return the image pipeline: crops, resize, random flip, tensor conversion.

        Args:
            image_size: Output size passed to ``transforms.Resize``.

        Returns:
            A ``transforms.Compose`` applying a 128 centre crop, a random
            120x120 crop, the resize, a horizontal flip with probability 0.5
            and ``ToTensor``. The random steps are applied for every split.
        """
        return transforms.Compose([
            transforms.CenterCrop(128),
            transforms.RandomCrop((120, 120)),
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])

    def get_target_transform(self, labels):
        """Return an ``AttrEncoder`` constructed from ``labels``.

        Args:
            labels: Boolean label array the encoder is initialised with.
        """
        return AttrEncoder(labels)
