import os
import torch
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import (
    default_loader,
    IMG_EXTENSIONS,
    has_file_allowed_extension,
)
from typing import cast, Tuple, Callable
import pandas as pd


def subsample_frames(samples, n_samples, rng=None):
    """Randomly pick ``n_samples`` sample indices from every video.

    Samples are grouped by their third element (e.g. the fov of the
    ``(path, class, fov)`` triples built by ``make_dataset_fov``), and
    ``n_samples`` distinct indices are drawn without replacement from each
    group.

    Args:
        samples: Sequence of tuples whose element ``[2]`` is the grouping key.
        n_samples: Number of indices to draw from each group.
        rng: Optional NumPy ``Generator``/``RandomState`` used for the draws,
            for reproducibility. Defaults to the global ``np.random`` state.

    Returns:
        List[int]: indices into ``samples``, concatenated group by group in
        sorted key order.

    Raises:
        ValueError: if a group has fewer than ``n_samples`` members.
    """
    # Group on the raw key values. Converting ``samples`` to a NumPy array
    # would coerce the mixed-type tuples to strings, after which integer keys
    # never compare equal and every group would come back empty.
    indices_by_video = {}
    for idx, sample in enumerate(samples):
        indices_by_video.setdefault(sample[2], []).append(idx)

    chooser = np.random if rng is None else rng
    subsampled_el = []
    for vid in sorted(indices_by_video):
        picked = chooser.choice(indices_by_video[vid], size=n_samples, replace=False)
        subsampled_el.extend(picked.tolist())
    return subsampled_el


def _read_fov_table(directory):
    """Return the ``{path: fov}`` mapping stored in ``<directory>/fov.csv``.

    The file is read as tab-separated with no header row: the first column is
    the sample path, the second its fov value.

    Raises:
        FileNotFoundError: if ``fov.csv`` does not exist in ``directory``.
    """
    csv_path = os.path.join(directory, "fov.csv")
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(
            f"Missing fov.csv in {directory}: expected a tab-separated file "
            "mapping each sample path to its fov."
        )
    df_fov = pd.read_csv(csv_path, delimiter="\t", header=None, names=["path", "fov"])
    # A plain dict gives an O(1) lookup per file instead of a DataFrame
    # ``.loc`` call per image (which would also return a Series, not a
    # scalar, for duplicated paths).
    return df_fov.set_index("path")["fov"].to_dict()


def make_dataset_fov(
    directory, class_to_idx=None, extensions=None, is_valid_file=None, allow_empty=False
):
    """Generate a list of samples of the form ``(path_to_sample, class, fov)``.

    Walks ``directory/<class>`` for every class in ``class_to_idx`` (sorted,
    following symlinks) and attaches to each accepted file the fov value listed
    for its path in ``directory/fov.csv``.

    Args:
        directory (str): dataset root; ``~`` is expanded.
        class_to_idx (Dict[str, int]): class folder name -> class index. Must
            be non-empty.
        extensions (optional): allowed file extensions. Exactly one of
            ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file (optional): predicate called with the full path of each
            file; files for which it returns False are skipped.
        allow_empty (bool): if True, classes with no valid file are not an
            error. Defaults to False.

    Returns:
        List[Tuple[str, int, Any]]: ``(path, class_index, fov)`` tuples. ``path``
        is ``os.path.join(root, fname)`` as produced by ``os.walk`` from the
        expanded ``directory`` and must appear verbatim in ``fov.csv``.

    Raises:
        ValueError: if ``class_to_idx`` is None or empty, or if the
            extensions/is_valid_file combination is invalid.
        FileNotFoundError: if ``fov.csv`` is missing, or if some class has no
            valid file and ``allow_empty`` is False.
        KeyError: if a valid file has no entry in ``fov.csv``.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        raise ValueError(
            "class_to_idx should not be None (should have been initialized earlier)"
        )
    elif not class_to_idx:
        raise ValueError(
            "'class_to_index' must have at least one entry to collect any samples."
        )

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    fov_by_path = _read_fov_table(directory)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # The validator receives the full path (as in torchvision's
                # make_dataset) so user-supplied checks can open the file.
                if not is_valid_file(path):
                    continue
                try:
                    fov = fov_by_path[path]
                except KeyError:
                    raise KeyError(
                        f"No fov entry for {path!r} in {directory}/fov.csv"
                    ) from None
                instances.append((path, class_index, fov))
                available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes and not allow_empty:
        msg = (
            f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        )
        if extensions is not None:
            ext_text = extensions if isinstance(extensions, str) else ", ".join(extensions)
            msg += f"Supported extensions are: {ext_text}"
        raise FileNotFoundError(msg)

    return instances


class ImageFolderFov(ImageFolder):
    """``ImageFolder`` whose samples also carry a per-image fov value.

    Samples are built by ``make_dataset_fov``, which reads each image's fov
    from ``fov.csv`` at the dataset root. When ``get_fov`` is True,
    ``__getitem__`` returns ``(sample, target, fov)``; otherwise it returns
    ``(sample, target)`` like a plain ``ImageFolder``.
    """

    def __init__(
        self,
        root,
        get_fov=True,
        transform=None,
        target_transform=None,
        loader=default_loader,
        is_valid_file=None,
    ):
        super().__init__(
            root=root,
            transform=transform,
            target_transform=target_transform,
            loader=loader,
            is_valid_file=is_valid_file,
        )
        # Only read by __getitem__ to decide whether to return the fov.
        self.get_fov = get_fov

    @staticmethod
    def make_dataset(
        directory,
        class_to_idx,
        extensions=None,
        is_valid_file=None,
        allow_empty=False,
    ):
        """Generates a list of samples of a form (path_to_sample, class, fov)

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.
            allow_empty (bool, optional): If True, classes without any valid
                file are allowed. Recent torchvision releases pass this from
                ``DatasetFolder.__init__``; accepting it avoids a TypeError
                there. Defaults to False.

        Returns:
            List[Tuple[str, int, Any]]: samples of a form (path_to_sample, class, fov)

        Raises:
            ValueError: if ``class_to_idx`` is None.
        """
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        kwargs = {"extensions": extensions, "is_valid_file": is_valid_file}
        if allow_empty:
            # Forwarded only when requested, so the default path calls
            # make_dataset_fov with exactly the same arguments as before.
            kwargs["allow_empty"] = allow_empty
        return make_dataset_fov(directory, class_to_idx, **kwargs)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(sample, target, fov)`` when ``self.get_fov`` is True,
            otherwise ``(sample, target)``, where target is class_index of the
            target class and fov is the value stored for the sample's path.
        """
        path, target, fov = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.get_fov:
            return sample, target, fov
        return sample, target