import os
from PIL import Image
from typing import Optional

import numpy as np
import pandas as pd

import torch
from torchvision import datasets, transforms

from .base_dataset import BaseDataset

from ..utils.general_utils import load_idx
from ..utils.data_utils import AttrEncoder


class MorphoMNISTDataset:
    """Morpho-MNIST image paths paired with per-image attribute columns.

    At construction time the image file paths and attribute tables of both the
    train and the test split are read from ``<root>/original_data``. Use
    ``get_split`` to wrap one split in a ``BaseDataset``.
    """

    # Sub-directory of ``root`` holding the image folders, label files and CSVs.
    base_folder = "original_data"

    def __init__(self, root, label_names=None, image_size=(32, 32)):
        """Index both splits and build the shared image/target transforms.

        Args:
            root: Directory that contains the ``original_data`` folder.
            label_names: Attribute columns to keep, in this order. ``None``
                (the default) selects ``["thickness", "intensity"]``.
            image_size: Size passed to ``transforms.Resize``.
        """
        self.root = root
        # Build a fresh list per instance: a list literal in the signature would
        # be shared by every instance created with the default.
        if label_names is None:
            label_names = ["thickness", "intensity"]
        self.label_names = label_names

        self.train_image_paths, self.train_attr = self.get_data_df("train")
        self.test_image_paths, self.test_attr = self.get_data_df("test")
        # The target transform is built from the attributes of BOTH splits.
        full_attr = pd.concat((self.train_attr, self.test_attr))

        self.transform = self.get_transform(image_size)
        self.target_transform = self.get_target_transform(full_attr.values)

    def get_split(self, split="train"):
        """Return a ``BaseDataset`` for one split.

        Args:
            split: ``"train"`` selects the training data; any other value
                selects the test data.

        Returns:
            A ``BaseDataset`` built from the split's image paths and attribute
            values, sharing this object's transforms and loading images with
            ``Image.open``.
        """
        if split == "train":
            image_paths, attrs = self.train_image_paths, self.train_attr
        else:
            image_paths, attrs = self.test_image_paths, self.test_attr

        return BaseDataset(
            image_paths,
            attrs.values,
            transform=self.transform,
            target_transform=self.target_transform,
            load_fn=Image.open,
        )

    def get_data_df(self, split):
        """Collect the image paths and the attribute table of one split.

        Args:
            split: ``"train"`` reads the ``train-*`` files; any other value
                reads the ``t10k-*`` files and the ``test-images`` folder.

        Returns:
            A ``(image_paths, attrs)`` tuple. ``image_paths`` is a list of file
            paths ordered by the integer that precedes the first ``"."`` in
            each file name. ``attrs`` is the CSV content restricted to
            ``self.label_names``.
        """
        if split == "train":
            image_dir = "train-images"
            path_to_labels = os.path.join(self.root, self.base_folder, "train-labels-idx1-ubyte.gz")
            path_to_attr = os.path.join(self.root, self.base_folder, "train-morpho.csv")
        else:
            image_dir = "test-images"
            path_to_labels = os.path.join(self.root, self.base_folder, "t10k-labels-idx1-ubyte.gz")
            path_to_attr = os.path.join(self.root, self.base_folder, "t10k-morpho.csv")

        image_dir = os.path.join(self.root, self.base_folder, image_dir)
        # Numeric (not lexicographic) order, so "10.png" sorts after "2.png".
        # NOTE(review): presumably this order matches the CSV row order -- confirm.
        file_names = sorted(os.listdir(image_dir), key=lambda name: int(name.split(".")[0]))
        image_paths = [os.path.join(image_dir, name) for name in file_names]

        attrs = pd.read_csv(path_to_attr)
        # Expose the labels file as a "label" column so it can be requested
        # through ``label_names`` like any other attribute.
        attrs['label'] = load_idx(path_to_labels)
        # Drop columns whose name contains "Unnamed".
        attrs = attrs.drop(columns=list(attrs.filter(regex='Unnamed')))
        attrs = attrs[self.label_names]

        return image_paths, attrs

    @staticmethod
    def _to_rgb(img):
        """Return ``img`` unchanged if its mode is RGB, else convert via 'L' to RGB."""
        if img.mode != 'RGB':
            return img.convert('L').convert('RGB')
        return img

    def get_transform(self, image_size):
        """Build the image pipeline: force RGB, resize to ``image_size``, to tensor."""
        return transforms.Compose([
            # A named function instead of an anonymous lambda, so that the
            # transform can be pickled (lambdas cannot).
            transforms.Lambda(self._to_rgb),
            transforms.Resize(image_size),
            transforms.ToTensor()
        ])

    def get_target_transform(self, labels):
        """Wrap ``labels`` in an ``AttrEncoder`` that serves as the target transform."""
        return AttrEncoder(labels)
