import functools
from typing import Tuple, Union, List, Callable, Optional

import pandas as pd
import tifffile as tiff
import torch
from torch.utils.data import Dataset as PyTorchDataset

from spc.dfconst import FILEPATH_COLUMN, TREATMENT_COLUMN, COMPOUND_NAME_COLUMN, COMPOUND_UM_COLUMN


def load_image_from_df_cached(
    df: pd.DataFrame,
    max_cache_size: int,
) -> Callable:
    """Build a memoized image loader over the rows of ``df``.

    The returned callable takes a positional row index, reads the file named in
    that row's ``FILEPATH_COLUMN`` with ``tifffile.imread`` and returns it as a
    ``torch.Tensor``. Results are memoized with ``functools.lru_cache`` keyed on
    the index. ``max_cache_size=0`` disables caching and ``None`` makes the cache
    unbounded. Cache hits return the same tensor object, so callers should not
    modify it in place.
    """
    @functools.lru_cache(maxsize=max_cache_size)
    def _read_image(index: int) -> torch.Tensor:
        # The column is looked up on every (uncached) call, not captured up front.
        path = df[FILEPATH_COLUMN].iloc[index]
        return torch.tensor(tiff.imread(path))

    return _read_image


class LabelledDataset(PyTorchDataset):
    """Image dataset described by one or more CSV files, with integer class labels.

    Each CSV row refers to one image whose path is stored in ``FILEPATH_COLUMN``.
    Rows are given an integer ``label`` column by numbering the distinct value
    combinations of ``label_cols``. ``__getitem__`` returns ``(image, label)``.
    """

    def __init__(
        self,
        csv_fpath: Union[str, List[str]],
        label_cols: Union[str, List[str]],
        max_cache_size: Optional[int] = 0,
        dna_pct_threshold: Optional[float] = None,
        select_channels: Optional[List[int]] = None,
    ):
        """
        Args:
            csv_fpath: path, or list of paths, to CSV files. Their rows are
                concatenated and renumbered from 0.
            label_cols: column name(s) whose value combinations define the labels.
            max_cache_size: ``maxsize`` of the LRU image cache. ``0`` disables
                caching and ``None`` makes it unbounded.
            dna_pct_threshold: if given, rows whose ``cell_pct`` column is below
                this value are dropped.
            select_channels: if given, indices selected along the first axis of
                each loaded image tensor.

        Raises:
            ValueError: if ``csv_fpath`` is an empty list.
        """
        super().__init__()
        if isinstance(csv_fpath, str):
            csv_fpath = [csv_fpath]
        if not csv_fpath:
            raise ValueError('csv_fpath must contain at least one CSV path')

        dfs = []
        for csv in csv_fpath:
            df = pd.read_csv(csv)
            if dna_pct_threshold is not None:
                df = df[df['cell_pct'] >= dna_pct_threshold]
            dfs.append(df)
        # ignore_index renumbers the concatenated rows 0..n-1.
        self.df = pd.concat(dfs, ignore_index=True)

        if TREATMENT_COLUMN not in self.df.columns:
            # Derive a treatment key by joining the compound-name and compound-uM columns.
            self.df[TREATMENT_COLUMN] = (
                self.df[COMPOUND_NAME_COLUMN].astype(str)
                + '_'
                + self.df[COMPOUND_UM_COLUMN].astype(str)
            )

        if isinstance(label_cols, str):
            label_cols = [label_cols]
        self.label_cols = label_cols
        self.df['label'] = self._encode_labels(self.df, label_cols)

        self.load_im = load_image_from_df_cached(df=self.df, max_cache_size=max_cache_size)
        self.select_channels = select_channels

    @staticmethod
    def _encode_labels(df: pd.DataFrame, label_cols: List[str]) -> pd.Series:
        """Number each row's combination of ``label_cols`` densely from 0.

        Groups are numbered in sorted key order. ``dropna=False`` makes rows with
        missing values in a label column form their own group(s). With the default
        ``dropna=True`` they would get a NaN label that ``int()`` cannot convert.
        """
        return df.groupby(label_cols, dropna=False).ngroup()

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` for row ``index``.

        If ``select_channels`` is set, only those indices of the image's first
        axis are kept.
        """
        im = self.load_im(index)
        if self.select_channels is not None:
            im = im[self.select_channels]
        # Read a single cell instead of materialising the whole row as a Series.
        return im, int(self.df['label'].iat[index])

    def __len__(self) -> int:
        return len(self.df)

    def get_meta(self, indices: Union[int, torch.Tensor]) -> Union[pd.DataFrame, pd.Series]:
        """Return the metadata row(s) at positional ``indices``.

        A single int yields a Series; a collection of indices yields a DataFrame.
        """
        return self.df.iloc[indices]

    def get_df(self) -> pd.DataFrame:
        """Return the underlying metadata dataframe (not a copy)."""
        return self.df

    def nlabels(self) -> int:
        """Return the number of distinct integer labels."""
        return self.df['label'].nunique()

    def get_label_string(self, index: int) -> str:
        """Return row ``index``'s values in ``label_cols``, joined with ``'_'``."""
        # Column-wise access keeps each column's own dtype. A whole-row Series
        # could upcast ints to floats and yield e.g. "5.0".
        return '_'.join(str(self.df[col].iat[index]) for col in self.label_cols)


class IndexFilteredDataset(PyTorchDataset):
    """A view of ``source_dataset`` restricted to, and ordered by, ``retained_indices``.

    Position ``i`` of this dataset maps to position ``retained_indices[i]`` of the source.
    """

    def __init__(self, source_dataset: LabelledDataset, retained_indices: List[int]):
        super().__init__()
        self.source_dataset = source_dataset
        self.retained_indices = retained_indices

    def __len__(self) -> int:
        return len(self.retained_indices)

    def __getitem__(self, subset_index: int) -> Tuple[torch.Tensor, int]:
        # Cast so numpy / tensor index entries reach the source as plain Python ints.
        source_index = int(self.retained_indices[subset_index])
        return self.source_dataset[source_index]

    def get_df(self) -> pd.DataFrame:
        """Return the source metadata rows for the retained indices, renumbered from 0."""
        full_df = self.source_dataset.get_df()
        subset = full_df.iloc[self.retained_indices]
        return subset.reset_index(drop=True)


