import os
from PIL import Image

import numpy as np
import pandas as pd

import torch
from torchvision import datasets, transforms
from torchvision.datasets.utils import verify_str_arg

from .base_dataset import BaseDataset


class PendulumDataset(BaseDataset):
    """Image/label dataset read from a pendulum data directory.

    The constructor reads ``pendulum_label_downstream.txt`` (parsed with
    ``pandas.read_csv``) from ``root`` and collects the image files found in
    the ``train``, ``valid`` and ``test`` sub-directories of ``root``. The
    ``partition`` column of the label file selects which rows (and the
    image paths at the same positions) belong to the requested split.

    Args:
        root: Directory containing the label file and the ``train``,
            ``valid`` and ``test`` image folders.
        label_names: Columns of the label file to use as targets. Defaults
            to ``["i", "j"]``.
        image_size: Size passed to ``transforms.Resize``.
        split: One of ``"train"``, ``"valid"``, ``"test"`` or ``"all"``
            (case-insensitive).

    Raises:
        ValueError: If ``split`` is not one of the accepted values (raised
            by ``verify_str_arg``).
    """

    def __init__(self, root, label_names=None, image_size=(64, 64), split="train"):
        # Avoid a shared mutable default; ``None`` stands for the historical
        # default of ``["i", "j"]``.
        if label_names is None:
            label_names = ["i", "j"]

        attrs_df = pd.read_csv(os.path.join(root, "pendulum_label_downstream.txt"))
        partitions = attrs_df['partition'].values

        # Partition codes as stored in the label file; "all" disables filtering.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_name = verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))
        split_code = split_map[split_name]
        mask = slice(None) if split_code is None else (partitions == split_code)

        # image
        # NOTE(review): image paths are paired with label rows purely by
        # position, and os.listdir() returns names in arbitrary order.
        # Presumably the label file follows the same order -- TODO confirm.
        images = []
        for subdir in ("train", "valid", "test"):
            imgdir = os.path.join(root, subdir)
            images += [os.path.join(imgdir, fname) for fname in os.listdir(imgdir)]

        transform = self.get_transform(image_size)

        images = np.array(images)[mask]

        # label
        # Take a private, writable copy so the in-place edit below neither
        # touches the DataFrame nor fails on a read-only array.
        labels = attrs_df[label_names].to_numpy(copy=True)
        labels[labels == -1] = 0

        # The hook sees the labels of every partition, not just the selected one.
        target_transform = self.get_target_transform(labels)

        # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype
        # that alias used to resolve to.
        labels = labels[mask].astype(np.float64)

        super().__init__(images, labels, transform=transform, target_transform=target_transform, load_fn=Image.open)

    def get_transform(self, image_size):
        """Return the image transform: resize to ``image_size``, then ``ToTensor``."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor()
        ])

    def get_target_transform(self, labels):
        """Return the target transform; this class applies none.

        ``labels`` is the label array for all partitions (with ``-1``
        already replaced by ``0``). It is unused here.
        """
        return None
