import os
import PIL
from typing import Optional, List, Tuple

import numpy as np
import pandas as pd

import torch
from torchvision import datasets, transforms
from torchvision.datasets.utils import verify_str_arg

from .image_dataset import ImageDataset

from utils.data_utils import AttrEncoder


class OpenBHBDataset(ImageDataset):
    """OpenBHB image dataset driven by ``split.tsv`` / ``attr.tsv`` metadata.

    ``root`` is expected to contain:

    * ``split.tsv`` -- tab-separated; its index column holds image file names
      (resolved under ``root/img``) and its single data column holds a split id
      (0 = train, 1 = valid, 2 = test).
    * ``attr.tsv`` -- tab-separated attribute table with the ``label_names``
      columns.
    * ``img/`` -- the image files themselves.
    """

    # Maps the public split name to the id stored in split.tsv; ``None`` = keep all rows.
    _SPLIT_IDS = {
        "train": 0,
        "valid": 1,
        "test": 2,
        "all": None,
    }

    def __init__(self, root: str, label_names: Optional[List[str]] = None,
                 image_size: Tuple[int, int] = (192, 160), pad_size: int = 8, split: str = "train"):
        """Read the metadata for ``split`` and build the underlying image dataset.

        Args:
            root: Dataset directory containing ``split.tsv``, ``attr.tsv`` and ``img/``.
            label_names: Columns of ``attr.tsv`` used as targets.
                Defaults to ``["age", "sex"]``.
            image_size: Output ``(height, width)`` of the transform pipeline.
            pad_size: Padding used by the random-crop augmentation.
            split: ``"train"``, ``"valid"``, ``"test"`` or ``"all"`` (case-insensitive).

        Raises:
            ValueError: If ``split`` is not one of the accepted names (via ``verify_str_arg``).
        """
        if label_names is None:
            label_names = ["age", "sex"]
        split_id = self._SPLIT_IDS[
            verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))
        ]

        splits = pd.read_csv(os.path.join(root, "split.tsv"), index_col=0, sep="\t")
        attr = pd.read_csv(os.path.join(root, "attr.tsv"), index_col=0, sep="\t")

        images = [os.path.join(root, "img", image) for image in splits.index]
        # list(...) so a tuple of names selects columns rather than one MultiIndex key.
        # NOTE(review): rows of attr.tsv are assumed to be in the same order as
        # split.tsv -- labels are taken positionally, not joined on the index.
        labels = attr[list(label_names)].values

        transform = self.get_transform(image_size, pad_size)
        # Fitted on every row before split filtering, presumably so that the
        # encoding is identical across train/valid/test -- confirm intent.
        target_transform = self.get_target_transform(labels)

        if split_id is not None:
            # Always 1-D (one entry per row), even for a single-row split.tsv.
            mask = splits.iloc[:, 0].to_numpy() == split_id
            # `images` is a plain list and cannot be indexed by a NumPy mask.
            images = [image for image, keep in zip(images, mask) if keep]
            labels = labels[mask]

        super().__init__(images, labels, transform=transform, target_transform=target_transform)

    def get_transform(self, image_size: Tuple[int, int], pad_size: int):
        """Build the image transform pipeline.

        Center-crops to 218x182, resizes to ``image_size`` shrunk by
        ``pad_size``, then takes a random ``image_size`` crop out of the
        padded image. It also applies a random horizontal flip and converts
        the result to a tensor.

        NOTE(review): the random crop/flip are applied for every split,
        including valid/test -- confirm this is intended.
        """
        return transforms.Compose([
            transforms.CenterCrop((218, 182)),
            transforms.Resize((image_size[0] - pad_size, image_size[1] - pad_size)),
            transforms.RandomCrop(image_size, padding=pad_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])

    def get_target_transform(self, labels):
        """Return an ``AttrEncoder`` built from the label array ``labels``."""
        return AttrEncoder(labels)
