import os
from typing import Callable, Optional
import torch
import pandas as pd

from .datasetfolder import DatasetFolder, StandardTransform
from .utils import BinaryLoader


class CSVDataset(DatasetFolder):
    """Dataset whose samples are listed in a CSV file instead of discovered on disk.

    The CSV is parsed by the module-level ``make_dataset`` function, which
    returns the class list, the class -> index mapping, the ``(path, target)``
    samples and a per-sample metadata path (or ``None``). Sample files are
    read through a ``BinaryLoader``, which receives the matching metadata path.

    Note that ``DatasetFolder.__init__`` is not called; the attributes the
    parent normally sets (``classes``, ``class_to_idx``, ``samples``,
    ``targets``, ``loader``, ``extensions``, ``transforms``) are set here.

    Args:
        csv_path: Path to the CSV index file.
        data_path: Root directory prepended to the relative paths in the CSV;
            ``~`` is expanded.
        extensions: Stored as ``self.extensions``; not otherwise used here.
        transform: Optional sample transform; combined with
            ``target_transform`` into a ``StandardTransform``.
        target_transform: Optional target transform (see ``transform``).
        transforms: Optional joint transform; mutually exclusive with
            ``transform``/``target_transform``.
        cache_size: Forwarded as ``BinaryLoader(cache_size=...)`` when truthy;
            otherwise ``BinaryLoader()`` is built with its own defaults.

    Raises:
        ValueError: If ``transforms`` is combined with ``transform`` or
            ``target_transform``.
    """

    # YAML tag for this class; presumably consumed by a YAML registration
    # mechanism elsewhere in the project — confirm.
    yaml_tag = u'!CSVDataset'

    def __init__(
        self,
        csv_path: str,
        data_path: str,
        extensions = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        cache_size: Optional[int] = None
    ) -> None:
        data_path = os.path.expanduser(data_path)
        self.csv_path = csv_path
        self.data_path = data_path

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms
        classes, class_to_idx, samples, metadata_paths = self.make_dataset(self.csv_path, self.data_path)

        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.metadata_paths = metadata_paths
        self.targets = [s[1] for s in samples]

        # cache_size of None or 0 falls back to BinaryLoader's own default.
        self._loader = BinaryLoader(cache_size=cache_size) if cache_size else BinaryLoader()
        # Use a bound method rather than a lambda: lambdas cannot be pickled,
        # which breaks DataLoader workers started with the "spawn" method.
        self.loader = self._load_with_metadata

    def _load_with_metadata(self, path: str):
        """Load ``path`` with the binary loader, passing its metadata path (or None)."""
        return self._loader(path, metadata_path=self.get_metadata_path(path))

    def make_dataset(self, csv_path: str, data_path: str):
        """Parse the CSV index; delegates to the module-level ``make_dataset``."""
        return make_dataset(csv_path, data_path)

    def get_metadata_path(self, path: str):
        """Return the metadata path recorded for sample ``path``, or None if absent."""
        return self.metadata_paths.get(path, None)


def make_dataset(csv_path: str, data_path: str):
    """Parse a dataset index CSV into the pieces used by ``CSVDataset``.

    The CSV must provide ``path``, ``target`` and ``class`` columns. A
    ``metadata_path`` column is optional: the column itself may be absent,
    and individual cells in it may be empty. Paths from the CSV are joined
    onto ``data_path``.

    Args:
        csv_path: Path to the CSV file.
        data_path: Directory prepended to every ``path`` / ``metadata_path``.

    Returns:
        A tuple ``(classes, class_to_idx, samples, metadata_paths)``:

        - ``classes``: class names in order of first appearance in the CSV.
        - ``class_to_idx``: maps each class to the ``target`` of its first row.
        - ``samples``: list of ``(full_path, target)`` in CSV row order.
        - ``metadata_paths``: maps each full sample path to its full metadata
          path, or ``None`` when no metadata path is given.
    """
    df = pd.read_csv(csv_path, dtype={'path': str, 'metadata_path': str, 'target': int, 'class': str})

    classes = df["class"].unique().tolist()
    first_rows = df.drop_duplicates(subset=["class"])
    class_to_idx = dict(zip(first_rows["class"], first_rows["target"]))

    def get_path(rel_path):
        # Empty CSV cells arrive as NaN; treat them (and None) as "no path".
        if rel_path is None or pd.isna(rel_path):
            return None
        return os.path.join(data_path, rel_path)

    # Sample paths are mandatory, so they are joined directly (a missing
    # value still raises instead of silently becoming None).
    paths = [os.path.join(data_path, rel_path) for rel_path in df["path"]]
    if "metadata_path" in df.columns:
        meta_paths = [get_path(rel_path) for rel_path in df["metadata_path"]]
    else:
        # The whole column may be absent from the CSV.
        meta_paths = [None] * len(paths)
    targets = df["target"].to_list()

    metadata_paths = dict(zip(paths, meta_paths))
    samples = list(zip(paths, targets))
    return classes, class_to_idx, samples, metadata_paths
