from collections import defaultdict
import logging
import copy
from math import ceil
from pathlib import Path
import shutil
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from scipy.sparse import vstack
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    TensorDataset,
    WeightedRandomSampler,
    random_split,
    Subset,
)
import scanpy as sc

from ccvae.data.utils import load_mtx_format
from ccvae.data.configs import (
    CELLIGNER,
    KANG,
    KANG_TRVAE,
    CATEGORY_NAME,
    CONDITION_NAME,
    KANG_CATEGORIES,
    KANG_TRVAE_CATEGORIES,
    UCI_INCOME
)
from ccvae.data.dataset_file_cache import DatasetFileCache

# Module-level logger for this file; its level is set to INFO here.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)


# Names of the batch sampling strategies that prepare_training_data branches on.
batch_sampling_modes = ['batched-by-category', 'inverse-weighted', 'uniform']

def make_strata(categories: pd.Series, batch_size: int) -> Dict[str, np.ndarray]:
    """Group sample positions by category, pooling small categories as "Other".

    Every category with at least ``batch_size`` samples becomes its own
    stratum. All categories with fewer samples are merged into a single
    stratum labelled "Other". A real category that is itself called "Other"
    is merged with the pooled samples instead of overwriting them.

    Args:
        categories: The category label of each sample; the position of a label
            in this sequence is the index recorded for that sample.
        batch_size: Minimum number of samples a category needs to keep its own
            stratum.

    Returns:
        A plain dict mapping each stratum label to an int64 array of sample
        positions. The "Other" key is only present if at least one category
        was pooled into it (or a category of that name exists).
    """
    # Category -> positions, in ascending order because enumerate counts upwards.
    positions: Dict[Any, List[int]] = defaultdict(list)
    for idx, category in enumerate(categories):
        positions[category].append(idx)

    # Accumulate in Python lists and convert once at the end; this avoids
    # re-allocating the "Other" array for every small category.
    grouped: Dict[Any, List[int]] = {}
    for category, indices in positions.items():
        label = category if len(indices) >= batch_size else "Other"
        grouped.setdefault(label, []).extend(indices)

    # A plain dict (not a defaultdict) so that looking up an unknown label
    # raises KeyError instead of silently creating an empty stratum.
    return {label: np.array(indices, dtype=np.int64) for label, indices in grouped.items()}

class StratifiedSampler:
    """Batch sampler whose batches each come from a single stratum.

    Iterating yields, stratum by stratum, consecutive slices of a freshly
    shuffled copy of that stratum's indices. The strata passed in are never
    modified.

    Args:
        strata: Mapping from stratum label to the sample indices in it.
        batch_size: Number of indices per batch.
        drop_last: If True, the incomplete final batch of every stratum is
            dropped; otherwise it is yielded with fewer indices.

    Raises:
        AssertionError: If any stratum other than 'Other' holds fewer than
            ``batch_size`` indices.
    """

    def __init__(self, strata, batch_size=128, drop_last=True):
        for s, v in strata.items():
            assert s == 'Other' or len(v) >= batch_size, f'Category {s} is smaller than the batch_size {batch_size}. Only the "Other" category of strata can be smaller.'
        self.strata = strata
        self.batch_size = batch_size
        self._drop_last = drop_last
        # __len__ should return the number of batches in one epoch
        self._len = sum(self._num_batches(len(v)) for v in self.strata.values())

    def _num_batches(self, stratum_len):
        """Return how many batches a stratum of the given size produces."""
        if self._drop_last:
            return stratum_len // self.batch_size
        # Integer ceiling division.
        return (stratum_len + self.batch_size - 1) // self.batch_size

    def __len__(self):
        return self._len

    def __iter__(self):
        for stratum in self.strata.values():
            # permutation() returns a shuffled copy, so the caller's strata
            # keep their original order (shuffle() would reorder them in place).
            shuffled = np.random.permutation(stratum)
            stop = self._num_batches(len(shuffled)) * self.batch_size
            for start in range(0, stop, self.batch_size):
                yield shuffled[start:start + self.batch_size]


def strata_inv_weights(strata: Dict[str, np.ndarray]) -> pd.Series:
    """Compute an inverse-frequency weight for every stratum.

    Args:
        strata: Mapping from stratum label to the sample indices in it.

    Returns:
        A Series indexed by stratum label holding
        ``total_sample_count / stratum_sample_count``.
    """
    # Count the samples of each stratum. The arrays hold sample indices, so
    # their size (not the sum of their values) is the sample count.
    counts = pd.Series({label: len(indices) for label, indices in strata.items()},
                       name='sample_count', dtype=np.int64)
    return counts.sum() / counts


def inv_weighted_data_loaders(tensor_dataset: TensorDataset,
                              metadata_df: pd.DataFrame,
                              strata_column: str,
                              strata: Dict[str, np.ndarray],
                              batch_size: int,
                              other_stratum_label: str = 'Other') -> Tuple[DataLoader, DataLoader]:
    """Build train and valid loaders that sample by inverse stratum weight.

    Args:
        tensor_dataset: Dataset both loaders read from.
        metadata_df: Per-sample metadata containing ``strata_column``.
        strata_column: Metadata column holding each sample's stratum label.
        strata: Mapping from stratum label to sample indices, used to derive
            the per-stratum weights.
        batch_size: Batch size of both loaders.
        other_stratum_label: Stratum whose weight is used for samples whose
            own label is not a stratum.

    Returns:
        ``(train_loader, valid_loader)``. There is no train/valid split; both
        loaders cover the whole dataset.
    """
    weight_by_stratum = strata_inv_weights(strata)

    def weight_for(label):
        # Labels that were pooled into the catch-all stratum have no weight of
        # their own, so they take the weight of that stratum.
        if label in weight_by_stratum:
            return weight_by_stratum[label]
        return weight_by_stratum[other_stratum_label]

    sample_weights = metadata_df[strata_column].apply(weight_for)

    def make_weighted_sampler(sample_weights):
        # Every sampler gets its own generator with the same fixed seed, so
        # two samplers built here start from the same random state.
        generator = torch.Generator()
        generator.manual_seed(0x62c663e50eb651da)
        return WeightedRandomSampler(sample_weights.to_numpy(),
                                     len(sample_weights),
                                     replacement=False,
                                     generator=generator)

    train_loader = DataLoader(
        tensor_dataset,
        batch_size=batch_size,
        sampler=make_weighted_sampler(sample_weights),
        shuffle=False,
        num_workers=0,
    )

    # The valid loader needs a fixed iteration order, so a second, identically
    # seeded sampler is drawn from once and its order frozen into an array.
    # Using a separate sampler leaves the training sampler's generator
    # untouched, so training still starts from its first pass.
    fixed_order = np.array(list(make_weighted_sampler(sample_weights)))
    fixed_sampler = BatchSampler(fixed_order, batch_size=batch_size, drop_last=False)
    valid_loader = DataLoader(tensor_dataset, batch_sampler=fixed_sampler, num_workers=0)
    return train_loader, valid_loader


def load_celligner(data_dir, top_var_number=None):
    """Load the celligner data

    Args:
        data_dir: Directory containing the files named in the CELLIGNER config.
        top_var_number: If not None, keep only the union of the
            ``top_var_number`` highest-variance genes of the cell line table
            and of the tumor table.

    Returns:
        feature_df (pd.DataFrame): n x p data frame containing the gene expression features
        metadata_df (pd.DataFrame): n x m data frame containing metadata. For Celligner, this
            contains the type (CL or tumor), the disease and the disease subtype (if known)
    """
    data_dir = Path(data_dir)
    hgnc_df = pd.read_csv(data_dir / CELLIGNER["hgnc_file"], delimiter="\t")
    info_df = pd.read_csv(data_dir / CELLIGNER["info_file"], index_col=0)
    # Convert gex tables to float32 to save memory. Doing so within read_csv throws an error.
    tumor_df = (
        pd.read_csv(data_dir / CELLIGNER["tumor_file"], delimiter="\t", index_col=0)
        .set_index("Gene")
        .astype(np.float32)
        .T
    )
    tumor_df = tumor_df.loc[:, ~tumor_df.columns.duplicated()]
    cl_df = pd.read_csv(data_dir / CELLIGNER["cl_file"], index_col=0).astype(np.float32)
    # Drop the " (ENS...)" suffix so the column names are comparable with the tumor table.
    cl_df.columns = cl_df.columns.map(lambda s: s.split(" (ENS")[0])
    cl_df = cl_df.loc[:, ~cl_df.columns.duplicated()]

    # Use Index.intersection explicitly: the `&` operator no longer performs a
    # set intersection on Index objects in current pandas.
    common_genes = cl_df.columns.intersection(tumor_df.columns)
    cl_df = cl_df[common_genes]
    assert cl_df.shape[1] == len(
        common_genes
    ), f"{cl_df.shape[1]} != {len(common_genes)}"
    tumor_df = tumor_df[common_genes]
    assert tumor_df.shape[1] == len(common_genes)

    # Filter most varying genes
    if top_var_number is not None:
        LOGGER.info(f"Variance filtering with {top_var_number} top gene features.")
        cl_top = cl_df.var(axis=0).sort_values(ascending=False).iloc[0:top_var_number].index
        tumor_top = tumor_df.var(axis=0).sort_values(ascending=False).iloc[0:top_var_number].index
        top_genes = set(cl_top) | set(tumor_top)
        # Keep the selected genes in the tables' existing column order so the
        # feature layout is the same on every run (iterating a set of strings
        # is not order-stable across interpreter runs).
        all_top = [gene for gene in common_genes if gene in top_genes]
        cl_df = cl_df.loc[:, all_top]
        tumor_df = tumor_df.loc[:, all_top]

    gex_df = pd.concat([tumor_df, cl_df])

    # Keep only genes whose HGNC locus group is neither non-coding RNA nor pseudogene.
    func_genes = set(
        hgnc_df[~hgnc_df.locus_group.isin(["non-coding RNA", "pseudogene"])].symbol
    )
    gex_df = gex_df[gex_df.columns.intersection(func_genes)]
    LOGGER.info(f"No. of selected gene features: {gex_df.shape[1]}")
    info_df = info_df.loc[gex_df.index]
    return gex_df, info_df[["disease", "subtype", "type"]]


def load_kang(data_dir):
    """Load the kang data

    Args:
        data_dir: Directory containing the files named in the KANG config.

    Returns:
        feature_df (pd.DataFrame): n x p data frame containing the gene expression features
        metadata_df (pd.DataFrame): n x m data frame containing metadata. For Kang, this
            contains the perturbation status
    """
    # Accept plain strings as well as Path objects, like load_celligner does.
    data_dir = Path(data_dir)
    filenames = KANG

    def read_cell_names(path, suffix):
        # add suffix so that barcodes are unique across experiments
        with open(path, "r") as f:
            return [line.strip() + suffix for line in f]

    cells: Dict[str, List[Any]] = {
        "unperturbed": read_cell_names(data_dir / filenames["unperturbed_cell_file"], '_unperturbed'),
    }
    LOGGER.info("Read %d unperturbed cell names", len(cells["unperturbed"]))
    cells["perturbed"] = read_cell_names(data_dir / filenames["perturbed_cell_file"], '_perturbed')
    LOGGER.info("Read %d perturbed cell names", len(cells["perturbed"]))

    # Each line of the gene file has two tab-separated fields; the second is the gene name.
    genes = []
    with open(data_dir / filenames["gene_name_file"]) as f:
        for line in f:
            _, name = [a.strip() for a in line.split("\t")]
            genes.append(name)

    unp_mat_file = data_dir / filenames["unperturbed_mat_file"]
    p_mat_file = data_dir / filenames["perturbed_mat_file"]
    unp_cache_file = data_dir / filenames["unperturbed_cache_file"]
    p_cache_file = data_dir / filenames["perturbed_cache_file"]
    cell_meta_file = data_dir / filenames["cell_metadata_file"]
    unperturbed = load_mtx_format(unp_mat_file, cache=unp_cache_file, reprocess=True)
    LOGGER.info("Loaded unperturbed matrix of size %s:", unperturbed.shape)
    perturbed = load_mtx_format(p_mat_file, cache=p_cache_file, reprocess=True)
    LOGGER.info("Loaded perturbed matrix of size %s:", perturbed.shape)
    additional_metadata = pd.read_csv(cell_meta_file, index_col=0)
    # add suffix so that barcodes are unique across experiments
    additional_metadata.index = np.where(
        additional_metadata['stim'] == 'ctrl',
        (additional_metadata.index + '_unperturbed').str.strip(),
        (additional_metadata.index + '_perturbed').str.strip(),
    )
    # replace errors in indices from muscData so that they match with GEO data
    additional_metadata.index = additional_metadata.index.str.replace('11', '1')
    LOGGER.info("Loaded additional cell metadata of size %s:", additional_metadata.shape)
    # One row per cell, unperturbed cells first, matching the row order of the
    # stacked expression matrix below.
    metadata = pd.DataFrame(
        data={
            "type": ["unperturbed"] * len(cells["unperturbed"])
            + ["perturbed"] * len(cells["perturbed"])
        },
        index=cells["unperturbed"] + cells["perturbed"],
    )
    metadata.index = metadata.index.str.strip()
    metadata = metadata.join(additional_metadata)
    LOGGER.info("Joined metadata with additional cell metadata. Now metadata is of size %s", metadata.shape)

    gex, metadata = prepare_singlecell_data(
        vstack([unperturbed, perturbed]),
        metadata,
        genes,
        keep_only_singlets=True, # remove doublets and ambiguous cells
    )
    LOGGER.info("Loaded gene expression data: %s", gex.shape)

    return gex, metadata


def load_kang_trvae(data_dir):
    """Load the trVAE version of the kang data, normalised and gene-filtered.

    The counts are total-count normalised and log1p-transformed with scanpy,
    then restricted to the 2000 most highly variable genes.

    Returns:
        feature_df (pd.DataFrame): n x p data frame containing the gene expression features
        metadata_df (pd.DataFrame): n x m data frame containing metadata. For Kang, this
            contains the perturbation status ('type') and a 'cell' column copied
            from 'cell_type'
    """
    counts_path = data_dir / KANG_TRVAE['counts']
    adata = sc.read(counts_path)
    LOGGER.info(f'Read data into scanpy, data size: {adata.X.shape}')

    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    LOGGER.info('Normalised data')

    sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000)
    LOGGER.info('Found 2000 most highly variable genes')

    is_highly_variable = adata.var['highly_variable']
    adata = adata[:, is_highly_variable]
    LOGGER.info(f'Subsetted data using 2000 most highly variable genes, data size: {adata.X.shape}')

    # The expression frame is built before the metadata columns are added below.
    gex = pd.DataFrame(adata.X, index=adata.obs.index, columns=adata.var.index)

    is_control = adata.obs['stim'] == 'CTRL'
    adata.obs['type'] = np.where(is_control, 'unperturbed', 'perturbed')
    adata.obs['cell'] = adata.obs['cell_type']
    LOGGER.info('Outputting gene expression and metadata DataFrames')
    return gex, adata.obs


def load_kang_trvae_counts(data_dir):
    """Load the trVAE version of the kang data as un-normalised values.

    A normalised, log1p-transformed copy of the data is used only to pick the
    2000 most highly variable genes; the returned features are the values as
    read from file, restricted to those genes.

    Returns:
        feature_df (pd.DataFrame): n x p data frame containing the gene expression features
        metadata_df (pd.DataFrame): n x m data frame containing metadata. For Kang, this
            contains the perturbation status ('type') and a 'cell' column copied
            from 'cell_type'
    """
    adata = sc.read(data_dir / KANG_TRVAE['counts'])
    # Keep an untouched copy: the scanpy preprocessing below modifies `adata`.
    adata_orig = copy.deepcopy(adata)
    LOGGER.info(f'Read data into scanpy, data size: {adata.X.shape}')
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    LOGGER.info(f'Normalised data')
    sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=2000)
    LOGGER.info(f'Found 2000 most highly variable genes')
    # The gene selection is computed on the normalised data but applied to the untouched copy.
    adata_orig = adata_orig[:, adata.var['highly_variable']]
    # Report the shape of the object that was actually subsetted.
    LOGGER.info(f'Subsetted data using 2000 most highly variable genes, data size: {adata_orig.X.shape}')
    gex = pd.DataFrame(
        adata_orig.X,
        index=adata_orig.obs.index,
        columns=adata_orig.var.index
    )
    adata_orig.obs['type'] = np.where(adata_orig.obs['stim'] == 'CTRL', 'unperturbed', 'perturbed')
    adata_orig.obs['cell'] = adata_orig.obs['cell_type']
    LOGGER.info(f'Outputting gene expression and metadata DataFrames')
    return gex, adata_orig.obs


# def load_kang_trvae(data_dir):
#     """Load the kang data

#     Returns:
#         feature_df (pd.DataFrame): n x p data frame containing the gene expression features
#         metadata_df (pd.DataFrame): n x m data frame containing metadata. For Kang, this
#             contains the perturbation status
#     """
#     LOGGER.info(f'Reading in expression and metadata DataFrames')
#     gex = pd.read_csv(data_dir / 'gex.csv', index_col=0)
#     metadata_df = pd.read_csv(data_dir / 'metadata.csv', index_col=0)
#     return gex, metadata_df


def load_uci_income(data_dir):
    """Load the UCI income features and metadata tables.

    Both files named in the UCI_INCOME config are read as comma-separated
    tables without an index column.

    Returns:
        (features_df, metadata_df), which must have the same number of rows.
    """
    features_path = data_dir / UCI_INCOME['features_file']
    metadata_path = data_dir / UCI_INCOME['metadata_file']
    features_df = pd.read_csv(features_path, index_col=None, delimiter=',')
    metadata_df = pd.read_csv(metadata_path, index_col=None, delimiter=',')
    assert len(features_df) == len(metadata_df)
    return features_df, metadata_df


def make_onehot_labels(identifiers, dtype, use_cuda):
    """One-hot encode a column of identifiers into a tensor.

    Args:
        identifiers: Object with a ``to_numpy()`` method holding one identifier
            per sample (e.g. a metadata column).
        dtype: torch dtype of the returned tensor.
        use_cuda: If True the tensor is moved to the GPU before being returned.

    Returns:
        A tensor with one row per sample and one column per distinct identifier.
    """
    # The encoder expects a 2-D input, so present the values as a single column.
    column = identifiers.to_numpy().reshape(-1, 1)
    encoded = OneHotEncoder().fit_transform(column).todense()
    labels = torch.as_tensor(encoded, dtype=dtype)
    return labels.cuda() if use_cuda else labels


def celligner_labels(include_diseases, metadata_df, use_cuda):
    """Build the one-hot label tensors for the Celligner dataset.

    Args:
        include_diseases: If True, 'disease' labels are added after the 'type' labels.
        metadata_df: Metadata frame containing the label columns.
        use_cuda: Passed on to make_onehot_labels.

    Returns:
        A list with one float32 one-hot tensor per label column.
    """
    label_columns = ['type', 'disease'] if include_diseases else ['type']
    label_tensors = []
    for name in label_columns:
        onehot = make_onehot_labels(metadata_df[name], torch.float32, use_cuda)
        LOGGER.info(f"Using '{name}' labels with shape {onehot.shape}")
        label_tensors.append(onehot)
    return label_tensors


def kang_labels(metadata_df, use_cuda):
    """Build the one-hot label tensors for the Kang datasets.

    Returns:
        A single-element list holding the float32 one-hot encoding of the
        'type' column of ``metadata_df``.
    """
    type_labels = make_onehot_labels(metadata_df['type'], torch.float32, use_cuda)
    LOGGER.info(f"Using 'type' labels with shape {type_labels.shape}")
    return [type_labels]


def uci_income_labels(metadata_df, use_cuda):
    """Build the one-hot label tensors for the UCI income dataset.

    Returns:
        A single-element list holding the float32 one-hot encoding of the
        'type' column, which this function treats as the gender labels.
    """
    gender_labels = make_onehot_labels(metadata_df['type'], torch.float32, use_cuda)
    LOGGER.info(f"Using gender labels from 'type' column with shape {gender_labels.shape}")
    return [gender_labels]


def load_direct(loader_fn, data_dir):
    """Load a dataset with ``loader_fn`` without touching the cache.

    Args:
        loader_fn: Callable taking ``data_dir`` and returning
            ``(features_df, metadata_df)``.
        data_dir: Directory handed to ``loader_fn``.

    Returns:
        A dict with 'features_tensor' (the features as a float32 tensor) and
        'metadata_df' (the metadata frame as returned by the loader).
    """
    LOGGER.info("Loading directly, not from and to cache")
    features_df, metadata_df = loader_fn(data_dir)
    return {
        'features_tensor': torch.as_tensor(features_df.to_numpy(), dtype=torch.float32),
        'metadata_df': metadata_df,
    }


def load_cached(loader_fn, data_dir, dataset_name):
    """Load a dataset through the file cache kept for ``data_dir``.

    If the cache is not valid, the data is loaded with ``loader_fn``, converted
    and written to a freshly initialised cache directory. Otherwise the cached
    files are read back.

    Args:
        loader_fn: Callable taking ``data_dir`` and returning
            ``(features_df, metadata_df)``.
        data_dir: Directory containing the dataset files.
        dataset_name: Name of the dataset, used in log messages only.

    Returns:
        A dict with 'features_tensor' (float32 tensor) and 'metadata_df'.

    Raises:
        Exception: Anything raised while building the cache is re-raised after
            the cache directory has been removed.
    """
    # Single source of truth for the cache file names, used for both writing and reading.
    features_cache_filename = 'features.pt'
    metadata_df_cache_filename = 'metadata_df.parquet'

    cache = DatasetFileCache(data_dir)
    if not cache.is_cache_valid():
        try:
            LOGGER.info(f'load_cached: Cache invalid for {dataset_name} at {str(data_dir)}.')
            cache_dir = cache.init_cache_dir()
            LOGGER.info(f'Initialised cache dir {str(cache_dir)}.')
            LOGGER.info(f'Loading {dataset_name} data...')
            features_df, metadata_df = loader_fn(data_dir)
            # Saving to compressed parquet reduces the file size from 1000 MB to 700 MB but loading
            # takes 19s vs <1s, so the features are stored as a torch tensor instead.
            features_tensor = torch.as_tensor(features_df.to_numpy(), dtype=torch.float32)
            torch.save(features_tensor, cache_dir / features_cache_filename)
            metadata_df.to_parquet(cache_dir / metadata_df_cache_filename)
            LOGGER.info(f'{dataset_name} data cached to {str(cache_dir)}.')
        except Exception as e:
            # Do not leave a partially written cache behind.
            LOGGER.error(f'Exception during creation of cached {dataset_name} dataset: {e}')
            LOGGER.error('Removing cache_dir.')
            cache.remove_cache_dir()
            raise
    else:
        cache_dir = cache.get_cache_dir()
        assert cache_dir is not None
        LOGGER.info(f'load_cached: Using cached data for {dataset_name} {str(data_dir)}.')
        features_tensor = torch.load(cache_dir / features_cache_filename)
        metadata_df = pd.read_parquet(cache_dir / metadata_df_cache_filename)
    return dict(features_tensor=features_tensor, metadata_df=metadata_df)


def prepare_singlecell_data(
    gex, metadata, genes, min_total_counts_cell=10, min_total_counts_gene=100, keep_only_singlets=True,
):
    """Filter a sparse cells x genes matrix and return it as a dense DataFrame.

    Args:
        gex: Sparse matrix with one row per cell and one column per gene.
        metadata: DataFrame with one row per cell, in the same order as the
            rows of ``gex``. Must have a 'multiplets' column when
            ``keep_only_singlets`` is True.
        genes: Gene names, one per column of ``gex``.
        min_total_counts_cell: Cells whose total count is not strictly greater
            than this are dropped.
        min_total_counts_gene: Genes whose total count is not strictly greater
            than this are dropped.
        keep_only_singlets: If True, additionally keep only cells whose
            'multiplets' value is 'singlet'.

    Returns:
        (gex_df, metadata): the filtered expression values as a DataFrame
        indexed by the remaining metadata index with the remaining gene names
        as columns, and the metadata rows of the remaining cells.
    """
    # ravel() always yields a 1-D mask. squeeze() would collapse the mask to a
    # 0-d array for a matrix with a single gene or a single cell.
    filter_genes = np.asarray(gex.sum(axis=0) > min_total_counts_gene).ravel()
    filter_cells = np.asarray(gex.sum(axis=1) > min_total_counts_cell).ravel()
    if keep_only_singlets:
        filter_cells = filter_cells & (metadata['multiplets'] == 'singlet').values
    csr_filter_cells = np.where(filter_cells)[0]
    csr_filter_genes = np.where(filter_genes)[0]
    gex = gex.tocsr()[csr_filter_cells, :][:, csr_filter_genes]
    # Apply the filters to the metadata
    metadata = metadata.iloc[filter_cells]
    genes = np.array(genes)[filter_genes]
    return pd.DataFrame(gex.toarray(), index=metadata.index, columns=genes), metadata


def prepare_training_data(data_dir,
                          load_data_fn,
                          batch_size,
                          input_label_fn,
                          use_cuda,
                          batch_sampling,
                          strata_fn,
                          dataset_name,
                          categories_to_leave_out=None,
                          condition_to_leave_out=None,
                          use_cache=True,
                          ):
    """
    Loads data and returns a TensorDataset, train and validation loaders, the metadata
    associated with each sample and the index labels of any held-out test samples. When
    use_cache is True the data is loaded through load_cached, which keeps faster loading
    versions of the data returned by load_data_fn.

    Args:
        data_dir (Path): Path to the directory to load the data from.
        load_data_fn (callable Path -> (pd.DataFrame, pd.DataFrame)): The loader function
            for the dataset, taking the path to the data directory as its argument and
            returning features (samples x genes) and metadata (samples x metadata) DataFrames.
        batch_size (int): Batch size for the data loaders.
        input_label_fn (Callable[[pd.DataFrame, bool], List[torch.Tensor]]): Optional function
            to extract a list of label tensors from the metadata DataFrame to be used as
            additional input data.
        use_cuda (bool): Use CUDA for the dataset and loaders.
        batch_sampling (str): Optional, if not None must be either 'batched-by-category',
            'inverse-weighted' or 'uniform'. None is treated like 'uniform'.
        strata_fn (callable (pd.DataFrame, int) -> (Dict[str, np.ndarray], str)): Function to
            create strata for batching. It is called with (metadata_df, batch_size) and must
            return a dict of strata labels to the indices of each group in the dataset,
            together with the name of the metadata column the strata were built from. Must
            be specified if batch_sampling is 'batched-by-category' or 'inverse-weighted'.
        dataset_name (str): The name of the dataset.
        categories_to_leave_out (list): categories to leave out in the test set (default: None)
        condition_to_leave_out (str or list of str): the condition(s) to leave out in the
            test set (default: None). Must be given together with categories_to_leave_out.
        use_cache (bool): Whether to load from (or save to) the cache

    Returns:
        torch.TensorDataset, torch.DataLoader, torch.DataLoader, pd.DataFrame, list:
        the dataset, the train loader, the validation loader, the metadata of the samples
        in the dataset, and the metadata index labels of the held-out test samples (empty
        if nothing was left out).

    Raises:
        NotImplementedError: If batch_sampling is not one of the recognised modes.
    """
    if use_cache:
        data_dict = load_cached(load_data_fn, data_dir, dataset_name)
    else:
        data_dict = load_direct(load_data_fn, data_dir)

    features = data_dict['features_tensor']
    metadata_df = data_dict['metadata_df']
    test_indices = []

    if categories_to_leave_out is not None or condition_to_leave_out is not None:
        assert categories_to_leave_out is not None
        assert condition_to_leave_out is not None
        assert dataset_name in ('kang', 'kang-trvae', 'kang-trvae-counts')
        # Series.isin only accepts list-likes, so wrap a single condition name.
        if isinstance(condition_to_leave_out, str):
            conditions = [condition_to_leave_out]
        else:
            conditions = list(condition_to_leave_out)
        to_leave_out = metadata_df[CONDITION_NAME[dataset_name]].isin(conditions)
        if dataset_name == 'kang':
            categories = [KANG_CATEGORIES[c] for c in categories_to_leave_out]
        else:
            categories = [KANG_TRVAE_CATEGORIES[c] for c in categories_to_leave_out]
        LOGGER.info(f"Leaving out condition: {condition_to_leave_out}, in categories: {categories}")
        # A sample is held out only if it matches both the condition and one of the categories.
        to_leave_out = to_leave_out & (metadata_df[CATEGORY_NAME[dataset_name]].isin(categories))
        test_indices = metadata_df.index[to_leave_out].tolist()
        LOGGER.info(f"Leaving out {len(test_indices)} test samples")
        # remove held out set from metadata and features
        metadata_df = metadata_df[~to_leave_out]
        # Index the tensor with an explicit boolean tensor rather than a pandas Series.
        keep_mask = torch.as_tensor((~to_leave_out).to_numpy())
        features = features[keep_mask]
        LOGGER.info(f"Metadata now of size: {metadata_df.shape}, data now of size: {features.shape}")

    LOGGER.info("Using data matrix with shape: %s", features.shape)
    input_tensors: List[torch.Tensor] = []
    if use_cuda:
        input_tensors.append(features.cuda())
    else:
        input_tensors.append(features)
    if input_label_fn is not None:
        input_tensors.extend(input_label_fn(metadata_df, use_cuda))
    dataset = TensorDataset(*input_tensors)

    if batch_sampling == 'batched-by-category':
        strata, strata_column = strata_fn(metadata_df, batch_size)
        LOGGER.info(f"Created {strata_column} strata: {list((k, len(v)) for k, v in strata.items())}")
        LOGGER.info(f'Configuring data loader to batch by category: {strata_column}, batch size: {batch_size}')
        # No train/valid split in this mode: both loaders use the same sampler
        # over the full dataset.
        batch_sampler = StratifiedSampler(strata, batch_size=batch_size)
        train_loader = DataLoader(dataset, batch_sampler=batch_sampler, shuffle=False, num_workers=0)
        valid_loader = DataLoader(dataset, batch_sampler=batch_sampler, shuffle=False, num_workers=0)
    elif batch_sampling == 'inverse-weighted':
        strata, strata_column = strata_fn(metadata_df, batch_size)
        LOGGER.info(f"Created {strata_column} strata: {list((k, len(v)) for k, v in strata.items())}")
        LOGGER.info(f'Configuring data loader with inverse-weighted sampling by category: {strata_column}, batch size: {batch_size}')
        train_loader, valid_loader = inv_weighted_data_loaders(tensor_dataset=dataset,
                                                               metadata_df=metadata_df,
                                                               strata_column=strata_column,
                                                               strata=strata,
                                                               batch_size=batch_size)
    elif batch_sampling == "uniform" or batch_sampling is None:
        # Hold out a random tenth of the samples for validation.
        L = len(dataset)
        valid_len = L // 10
        LOGGER.info(
            "Creating random train/valid split: %d:%d", L - valid_len, valid_len
        )
        trainset, validset = random_split(dataset, [L - valid_len, valid_len])
        train_loader = DataLoader(
            trainset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True,
        )
        valid_loader = DataLoader(
            validset, batch_size=batch_size, shuffle=False, num_workers=0
        )
    else:
        raise NotImplementedError(f"batch_sampling method ({batch_sampling}) not recognised.")

    return dataset, train_loader, valid_loader, metadata_df, test_indices
