import os
import PIL
from typing import Optional
from functools import partial

import numpy as np
import pandas as pd

import torch
from MorphoMNIST.morphomnist import io
from torchvision import datasets, transforms
from torchvision.datasets.utils import verify_str_arg

from .image_dataset import ImageDataset

from utils.data_utils import AttrEncoder

# Small numeric tolerance. It is not referenced anywhere in the visible part of
# this module -- presumably consumed by code importing it (TODO confirm, or
# remove if unused).
EPS = 1e-3

class MorphoMNISTDataset(ImageDataset):
    """Morpho-MNIST images with morphometric attributes and test counterfactuals.

    Files read below ``root``:

    * ``original_data/train-images/`` and ``original_data/test-images/``:
      image files whose names start with an integer followed by ``.``.
    * ``original_data/{train,t10k}-labels-idx1-ubyte.gz``: digit labels.
    * ``original_data/{train,t10k}-morpho.csv``: morphometric attributes.
    * ``perturbation_data/test-images/`` and
      ``perturbation_data/t10k-morpho.csv``: counterfactual (perturbed) test
      images and their attributes; the CSV must contain an ``index`` column.

    All of the above are read for every split, including ``"train"``.
    """

    _VALID_SPLITS = ("train", "valid", "test", "all")

    def __init__(self, root, label_names=None, split="train"):
        """Load image paths, attribute labels and counterfactual metadata.

        Args:
            root: Directory containing ``original_data`` and
                ``perturbation_data``.
            label_names: Attribute columns used as targets. Defaults to
                ``["thickness", "intensity"]``.
            split: ``"train"``, ``"valid"`` or ``"test"`` (case-insensitive).
                ``"valid"`` and ``"test"`` both select the t10k data.

        Raises:
            ValueError: If ``split`` is not a recognised name, or is ``"all"``,
                which is a recognised name without a data selection.
        """
        # Build the default per call so no list is shared between instances.
        if label_names is None:
            label_names = ["thickness", "intensity"]

        # Keep the normalised value so that every later comparison sees the
        # same string that was validated.
        split = verify_str_arg(split.lower(), "split", self._VALID_SPLITS)
        if split == "all":
            # Counterfactual data exists only for the test images, so there is
            # no consistent train+test view to serve.
            raise ValueError(
                "split='all' is not supported by MorphoMNISTDataset; "
                "use 'train', 'valid' or 'test'"
            )
        self.split = split
        self.root = root

        _, train_image_paths, train_df = self._read_split_tables("train")
        cf_image_paths, test_image_paths, test_df = self._read_split_tables("test")

        # Always the test directory name, independent of `split`; kept for
        # backward compatibility.
        # NOTE(review): confirm whether this should follow the selected split.
        self.image_dir = "test-images"

        # The attribute encoder is given train and test attributes together,
        # whichever split is served.
        complete_df = pd.concat((train_df, test_df))
        attr_encoder_labels = complete_df[label_names].values

        if split == "train":
            image_paths = train_image_paths
            labels = train_df[label_names].values
        else:  # "valid" and "test" share the t10k data
            image_paths = test_image_paths
            labels = test_df[label_names].values

        cf_label_df = pd.read_csv(
            os.path.join(root, "perturbation_data", "t10k-morpho.csv")
        )
        self.cf_label_df = cf_label_df.set_index("index")
        self.cf_image_paths = cf_image_paths

        transform = self.get_transform(split=split)
        target_transform = self.get_target_transform(attr_encoder_labels)
        super().__init__(
            image_paths, labels, transform=transform, target_transform=target_transform
        )

    @staticmethod
    def _list_images_in_numeric_order(directory):
        """Return full paths of files in ``directory`` sorted by integer name prefix."""
        file_names = sorted(
            os.listdir(directory), key=lambda name: int(name.split(".")[0])
        )
        return [os.path.join(directory, name) for name in file_names]

    def _read_split_tables(self, split):
        """Return ``(cf_image_paths, image_paths, attrs)`` for one data split.

        ``cf_image_paths`` is ``None`` for ``"train"``; any other value reads
        the test data and its perturbed images. ``attrs`` is the morpho CSV
        with a ``label`` column added from the idx label file and any
        ``Unnamed`` columns dropped.
        """
        original_dir = os.path.join(self.root, "original_data")
        if split == "train":
            image_dir, file_prefix = "train-images", "train"
            cf_image_paths = None
        else:
            image_dir, file_prefix = "test-images", "t10k"
            cf_image_paths = self._list_images_in_numeric_order(
                os.path.join(self.root, "perturbation_data", image_dir)
            )

        image_paths = self._list_images_in_numeric_order(
            os.path.join(original_dir, image_dir)
        )

        labels = io.load_idx(
            os.path.join(original_dir, f"{file_prefix}-labels-idx1-ubyte.gz")
        )
        attrs = pd.read_csv(os.path.join(original_dir, f"{file_prefix}-morpho.csv"))
        attrs["label"] = labels
        # Drop columns whose names match "Unnamed" (typically a saved index).
        attrs.drop(list(attrs.filter(regex="Unnamed")), axis=1, inplace=True)

        return cf_image_paths, image_paths, attrs

    def get_transform(self, split="train"):
        """Return the image transform for ``split``.

        Training uses a random 32x32 crop after 4 pixels of padding; every
        other split uses a fixed 2-pixel pad.
        """
        if split == "train":
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomCrop((32, 32), padding=4)
            ])
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Pad(padding=2)
        ])

    def get_target_transform(self, labels):
        """Return an ``AttrEncoder`` built from ``labels``."""
        return AttrEncoder(labels)

    def get_cf_image(self, index: int):
        """Return the transformed counterfactual image at position ``index``."""
        # The context manager closes the file handle once the transform has
        # read the pixel data.
        with PIL.Image.open(self.cf_image_paths[index]) as image:
            return self.transform(image)

    def get_cf_label(self, index):
        """Return the counterfactual label for position ``index``.

        For the training split this defers to the parent class. Otherwise it
        returns the ``thickness`` and ``intensity`` values of the row at
        position ``index`` in the perturbation CSV, as a NumPy array.
        """
        if self.split == "train":
            return super().get_cf_label(index)
        # NOTE(review): the columns are fixed here and do not follow the
        # `label_names` passed to the constructor -- confirm this is intended.
        return self.cf_label_df.iloc[index][["thickness", "intensity"]].to_numpy()
    