from abc import ABC, abstractmethod
import logging
from pathlib import Path
import re
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import sklearn
from sklearn.ensemble import RandomForestClassifier
import torch
from umap import UMAP

from ccvae.metrics import compute_mmd, knn_metric, silhouette_coeff, kbet
from ccvae.nn.utils import MultinomialDecoder


# Module-level logger; its level is pinned to INFO so this module's INFO records are
# not filtered out by the logger itself.
LOG = logging.getLogger(name=__name__)
LOG.setLevel(level=logging.INFO)

# Default random seed for UMAP projections (see get_umap_params / calc_umap).
UMAP_SEED = 42


class MetricHandler(ABC):
    """
    Interface for dataset-specific metrics.

    calculate() produces a DataFrame of raw results and reduce() turns that DataFrame
    into a Series of summary metrics.
    """

    @abstractmethod
    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        """
        Compute metrics for the samples described by sample_metadata_df.

        The result is meant to be indexed by samples from sample_metadata_df. Its columns
        may hold metrics, intermediate results or per-sample metadata, e.g. UMAP
        coordinates and metrics derived from them. Some implementations (the kBET
        handlers) return a single aggregate row instead.
        """
        raise NotImplementedError

    @abstractmethod
    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """
        Collapse selected columns of results_df into a Series of summary metrics.

        results_df is expected to be the DataFrame returned by calculate().
        """
        raise NotImplementedError


def select_metric_handlers(dataset_name) -> List[MetricHandler]:
    """
    Build the metric handlers registered for dataset_name.

    Unknown datasets get an empty list rather than an error.
    """
    if dataset_name == 'celligner':
        return [CellignerSilhouetteHandler(), CellignerKBETHandler()]
    if dataset_name in ('kang', 'kang-trvae', 'kang-trvae-counts'):
        return [KangSilhouetteHandler(), KangKBETHandler()]
    if dataset_name == 'uci-income':
        return [
            IncomeClassificationHandler(),
            IncomeSilhouetteHandler(n_neighbours=1000),
            IncomeKBETHandler(n_neighbours=100),
        ]
    return []


def get_umap_params(dataset_name):
    """
    Return keyword arguments for calc_umap() appropriate for dataset_name.

    Every dataset gets ``seed=UMAP_SEED``. 'uci-income' additionally overrides UMAP's
    ``a``/``b`` parameters and uses 100 neighbours instead of calc_umap's default.

    Raises:
        ValueError: if dataset_name is not a recognised dataset.
    """
    uses_defaults = ('celligner', 'kang', 'kang-trvae', 'kang-trvae-counts')
    if dataset_name == 'uci-income':
        return dict(seed=UMAP_SEED, a=0.1, b=2.0, n_neighbours=100)
    if dataset_name in uses_defaults:
        return dict(seed=UMAP_SEED)
    raise ValueError(f'Invalid dataset_name {dataset_name}')


def calc_metrics(handlers: List[MetricHandler],
                 output_dir: Path,
                 dataset: torch.utils.data.TensorDataset,
                 sample_metadata_df: pd.DataFrame,
                 models: List[Union[torch.nn.Module, Path]],
                 dataset_name: str,
                 model_type: str,
                 load_model_fn: Callable[[Path], pl.LightningModule],
                 baseline_logprob: float,
                 use_cuda: bool,
                 metrics_tracking_fn: Optional[Callable]=None):
    """
    Compute latents, a UMAP projection and metrics for every model, and write them to disk.

    models can be either a list of models or list of paths to Pytorch Lightning checkpoints
    to load.

    For each model this writes ``latents-<model_id>.parquet`` and
    ``ckpt-results/<model_id>.parquet`` (UMAP coordinates plus per-sample handler results)
    under output_dir. One summary row per model is written to ``metrics.csv``.

    Args:
        handlers: metric handlers, applied to the UMAP coordinates of each model's latents.
        output_dir: directory that receives all output files.
        dataset: tensors passed to predict().
        sample_metadata_df: one row per sample of dataset. Its "type" column must contain
            exactly two distinct values, which are used for the MMD metric.
        models: in-memory models or checkpoint paths to evaluate.
        dataset_name: selects UMAP parameters via get_umap_params().
        model_type: 'vae' means unconditional; any other value is treated as conditional.
        load_model_fn: loads a model from a checkpoint path.
        baseline_logprob: reference log-probability reported next to each model's.
        use_cuda: forwarded to predict().
        metrics_tracking_fn: optional callback that receives each model's metrics as keyword
            arguments, plus ``step``. ``step`` is the epoch parsed from the checkpoint name,
            or 9999 if the name contains none.

    Raises:
        ValueError: if a model reference is neither a Path nor a torch.nn.Module, or if
            dataset_name is unknown to get_umap_params().
    """
    is_conditional_model = model_type != 'vae'
    model_metrics = {}
    for i, model_ref in enumerate(models):
        model, model_id = _resolve_model(model_ref, i, load_model_fn)
        LOG.info(f'Calculating outputs and metrics for {model_type} {model_id}')
        features, model_logprob = predict(model,
                                          dataset,
                                          conditional_model=is_conditional_model,
                                          use_cuda=use_cuda)
        feature_array = features.cpu().numpy()

        latent_df = pd.DataFrame(
            data=feature_array,
            index=sample_metadata_df.index,
            columns=[f"Z{j}" for j in range(feature_array.shape[1])],
        )
        LOG.info("Generated latents for full data set: %s", latent_df.shape)
        latent_file = output_dir / f"latents-{model_id}.parquet"
        LOG.info("Saving loc from latent distribution to %s", latent_file)
        latent_df.to_parquet(latent_file)

        # Basic metrics: MMD between the two sample types in latent space.
        type_a, type_b = sample_metadata_df["type"].unique()
        mmd = compute_mmd(
            latent_df[sample_metadata_df["type"] == type_a].values,
            latent_df[sample_metadata_df["type"] == type_b].values,
        )

        metrics_series = pd.Series(
            {
                "model_logprob": model_logprob,
                "mmd": mmd,
                "baseline_logprob": baseline_logprob,
                "logprob_improvement": model_logprob - baseline_logprob,
            },
            name=model_id,  # Name gets dropped after concatenation but add it here for any logging
        )

        umap_params = get_umap_params(dataset_name)
        LOG.info(f'Computing UMAP projection with params {umap_params}.')
        umap_df = calc_umap(feature_array, sample_metadata_df, **umap_params)
        LOG.info('UMAP projection complete.')

        # Handlers operate on the 2-D UMAP coordinates, not on the raw latents.
        handler_result_dfs = [umap_df]
        for handler in handlers:
            metric_results_df = handler.calculate(umap_df.to_numpy(), sample_metadata_df)
            # Keep results indexed by samples, including query subsets such as silhouette
            # scores. Aggregate-only results (calc_kbet's single 'kbet-0' row) cannot be
            # aligned per sample, so they are only reduced.
            if metric_results_df.index.isin(sample_metadata_df.index).all():
                handler_result_dfs.append(metric_results_df)
            reduced_scores = handler.reduce(metric_results_df)
            # Series.append was removed in pandas 2.0; pd.concat is the equivalent.
            metrics_series = pd.concat([metrics_series, reduced_scores], verify_integrity=True)
        model_metrics[model_id] = metrics_series

        # umap_df is always present, so per-checkpoint results are always written.
        handler_results_dir = output_dir / 'ckpt-results'
        handler_results_dir.mkdir(exist_ok=True, parents=True)
        handler_results_path = handler_results_dir / f'{model_id}.parquet'
        LOG.info(f'Saving checkpoint metrics: {handler_results_path}')
        pd.concat(handler_result_dfs, axis=1, sort=True).to_parquet(handler_results_path)
        if not handlers:
            LOG.info('No additional metric handlers registered for this dataset.')

    metrics_df = pd.DataFrame.from_dict(model_metrics, orient='index')
    LOG.info(f'Metrics\n{metrics_df}')
    main_metrics_path = output_dir / "metrics.csv"
    metrics_df.to_csv(main_metrics_path)
    LOG.info(f'Saved to {main_metrics_path}')

    if metrics_tracking_fn is not None:
        LOG.info('Logging metrics to external tracking.')
        for model_id, metrics in model_metrics.items():
            metrics_tracking_fn(**metrics.to_dict(), step=_epoch_from_model_id(model_id))


def _resolve_model(model_ref, index, load_model_fn):
    """Return (model, model_id) for a checkpoint Path or an in-memory torch.nn.Module."""
    if isinstance(model_ref, Path):
        return load_model_fn(model_ref), model_ref.stem
    if isinstance(model_ref, torch.nn.Module):
        return model_ref, f'model-{index}'
    raise ValueError(f'Invalid model reference, must be a torch.nn.Module or Path to checkpoint: {model_ref}')


_EPOCH_PATTERN = re.compile(r'epoch=([0-9]+)')
# Dummy step for checkpoints without an epoch in their name (e.g. the last checkpoint);
# tracking is just to be indicative of performance anyway.
_UNKNOWN_EPOCH = 9999


def _epoch_from_model_id(model_id) -> int:
    """Return the integer after 'epoch=' in model_id, or _UNKNOWN_EPOCH if absent."""
    match = _EPOCH_PATTERN.search(str(model_id))
    return int(match.group(1)) if match else _UNKNOWN_EPOCH


def predict(model: torch.nn.Module,
            dataset: torch.utils.data.TensorDataset,
            conditional_model: bool,
            use_cuda: bool) -> Tuple[torch.Tensor, float]:
    """
    Encode the whole dataset in one pass and score the reconstruction of its inputs.

    dataset.tensors must hold 1-3 tensors, with the input features first. The second
    tensor (conditional labels) is forwarded only when conditional_model is True. A third
    tensor is never forwarded: None is always passed in its place.
    NOTE(review): confirm that discarding the third tensor is intended.

    Returns:
        A tuple of qz.loc (the encoder output's loc) and the mean log-probability of the
        inputs under the decoder output evaluated at qz.loc.
    """
    if use_cuda:
        # Lightning models end up back on CPU (or were copied and never moved) but
        # dataset is on cuda.
        model.cuda()
        decoder = model.decoder
        if isinstance(decoder, MultinomialDecoder) and decoder.baseline is not None:
            decoder.baseline = decoder.baseline.cuda()
    model.eval()
    with torch.no_grad():
        tensors = dataset.tensors
        assert 0 < len(tensors) < 4
        inputs = tensors[0]
        labels = tensors[1] if conditional_model and len(tensors) > 1 else None
        extra = None
        posterior = model.forward(inputs, labels, extra)
        decoded = model.forward_decoder(posterior.loc, labels, extra)
        # With return_hidden the decoder returns a sequence whose first item is the
        # output distribution.
        output_dist = decoded[0] if model.decoder.return_hidden else decoded
        logprob = output_dist.log_prob(inputs)
    return posterior.loc, logprob.mean().item()


def class_knn_score(feature_array: np.ndarray,
                    sample_metadata_df: pd.DataFrame,
                    query_selector: np.ndarray,
                    group_labels: np.ndarray,
                    class_labels: np.ndarray,
                    n_neighbours: int = 50,
                    ) -> pd.DataFrame:
    """
    Compute a 2-D UMAP projection of all samples and the knn_metric score for query samples.

    Args:
        feature_array: one row of features per sample in sample_metadata_df.
        sample_metadata_df: sample metadata; its index labels the output rows.
        query_selector: boolean mask selecting the samples the metric is computed for.
        group_labels: per-sample labels, passed to knn_metric as ``labels``.
        class_labels: per-sample values, passed to knn_metric as ``class_partition``.
        n_neighbours: neighbourhood size for knn_metric (default 50, the previous
            hard-coded value).

    Returns:
        DataFrame with ``umap-0``/``umap-1`` for every sample and a ``knn-metric`` column.
        ``knn-metric`` is NaN for samples outside the query set.
    """
    assert len(feature_array) == len(sample_metadata_df)
    assert len(feature_array) == len(query_selector)
    assert len(feature_array) == len(class_labels)
    assert len(feature_array) == len(group_labels)
    # Use the seeded projection (same defaults as UMAP() plus random_state=UMAP_SEED) so
    # results are reproducible and consistent with calc_metrics.
    umap_df = calc_umap(feature_array, sample_metadata_df)
    metric_array = knn_metric(feature_array,
                              queries=query_selector,
                              labels=group_labels,
                              class_partition=class_labels,
                              n_neighbours=n_neighbours,
                              return_counts=False)
    # Only the query set of samples has values for the metric, so filter the index and concat
    # with the umap_df to combine columns into one frame.
    metric_df = pd.DataFrame(metric_array[:, np.newaxis],
                             index=sample_metadata_df.index[query_selector],
                             columns=['knn-metric'],
                             dtype=np.float32)
    return pd.concat([umap_df, metric_df], axis=1, sort=True)


def calc_umap(latents_array: np.ndarray,
              metadata_df: pd.DataFrame,
              n_neighbours: int = 15,
              a: Optional[float] = None,
              b: Optional[float] = None,
              seed: int = UMAP_SEED) -> pd.DataFrame:
    """
    Project latents_array to two dimensions with UMAP.

    Args:
        latents_array: one row per sample in metadata_df.
        metadata_df: supplies the row index of the result.
        n_neighbours: passed to UMAP as ``n_neighbors``.
        a, b: passed through to UMAP unchanged; None leaves them to UMAP.
        seed: UMAP ``random_state``. ``transform_seed`` is fixed at 42.

    Returns:
        float32 DataFrame with columns ``umap-0`` and ``umap-1``.
    """
    n_components = 2
    reducer = UMAP(n_components=n_components,
                   n_neighbors=n_neighbours,
                   a=a,
                   b=b,
                   random_state=seed,
                   transform_seed=42)
    embedding = reducer.fit_transform(latents_array)
    column_names = ['umap-%d' % k for k in range(n_components)]
    return pd.DataFrame(embedding,
                        index=metadata_df.index,
                        columns=column_names,
                        dtype=np.float32)


def calc_silhouette(feature_array: np.ndarray,
                    sample_metadata_df: pd.DataFrame,
                    query_selector: np.ndarray,
                    group_labels: np.ndarray,
                    class_labels: np.ndarray,
                    n_neighbours=50,
                    ) -> pd.DataFrame:
    """
    Compute silhouette scores for the query samples via silhouette_coeff.

    All array arguments must have one entry per sample in sample_metadata_df.
    query_selector is used as a boolean mask on the sample index.

    Returns:
        float32 DataFrame indexed by the query samples, with columns
        ``silhouette-<n>`` and ``mean-silhouette-<n>`` (n = n_neighbours).
        Assumes silhouette_coeff returns one value per query sample in each output.
    """
    n_samples = len(feature_array)
    for other in (sample_metadata_df, query_selector, class_labels, group_labels):
        assert n_samples == len(other)
    per_sample, per_sample_mean = silhouette_coeff(features=feature_array,
                                                   queries=query_selector,
                                                   labels=group_labels,
                                                   class_partition=class_labels,
                                                   n_neighbours=n_neighbours)
    values = np.stack([per_sample, per_sample_mean], axis=1)
    column_names = [f'silhouette-{n_neighbours}', f'mean-silhouette-{n_neighbours}']
    return pd.DataFrame(values,
                        index=sample_metadata_df.index[query_selector],
                        columns=column_names,
                        dtype=np.float32)


def calc_kbet(feature_array: np.ndarray,
              sample_metadata_df: pd.DataFrame,
              query_selector: np.ndarray,
              group_labels: np.ndarray,
              class_labels: np.ndarray,
              n_neighbours=50,
              ) -> pd.DataFrame:
    """
    Compute a single kBET score from k-nearest-neighbour class counts.

    Expected neighbour counts are proportional to the overall class sizes in class_labels.
    They are passed to kbet() together with column 1 of the counts from knn_metric, at
    significance 0.01.

    Returns:
        1x1 float32 DataFrame with index ``['kbet-0']`` and column ``kbet-<n_neighbours>``.
        This is an aggregate result, not one row per sample.
    """
    n_samples = len(feature_array)
    for other in (sample_metadata_df, query_selector, class_labels, group_labels):
        assert n_samples == len(other)
    _, knn_counts = knn_metric(feature_array,
                               queries=query_selector,
                               labels=group_labels,
                               class_partition=class_labels,
                               n_neighbours=n_neighbours,
                               return_counts=True)

    count_1 = class_labels.sum()
    count_0 = len(class_labels) - count_1
    expected_0 = n_neighbours * count_0 / (count_1 + count_0)
    expected_1 = n_neighbours - expected_0
    # use the other class's frequency for each sample (0 -> expected_1, 1 -> expected_0)
    # NOTE(review): expected_freq has one entry per sample. If knn_metric returns counts
    # only for query samples, the two line up only when every sample is a query — confirm.
    expected_freq = np.where(class_labels, expected_0, expected_1)

    score = kbet(knn_counts[:, 1],
                 expected_freq=expected_freq,
                 n_neighbours=n_neighbours,
                 significance=0.01)

    return pd.DataFrame([[score]],
                        columns=[f'kbet-{n_neighbours}'],
                        index=['kbet-0'],
                        dtype=np.float32)


def calc_mean_kbet(feature_array: np.ndarray,
                   sample_metadata_df: pd.DataFrame,
                   query_selector: np.ndarray,
                   group_labels: np.ndarray,
                   class_labels: np.ndarray,
                   n_neighbours=50,
                   ) -> pd.DataFrame:
    """
    Compute a kBET score separately within each group of group_labels.

    Each group is scored with calc_kbet() on that group's samples only. Class mixing is
    therefore assessed within groups, not across the whole data set. Within a group all
    labels are identical, so no further grouping happens. Groups with no query samples
    get NaN.

    NOTE: not used by any handler in this module.

    Returns:
        float32 DataFrame indexed by the unique group labels, with a single column
        ``mean-kbet-<n_neighbours>``.
    """
    assert len(feature_array) == len(sample_metadata_df)
    assert len(feature_array) == len(query_selector)
    assert len(feature_array) == len(class_labels)
    assert len(feature_array) == len(group_labels)
    feature_array = np.asarray(feature_array)
    query_selector = np.asarray(query_selector)
    group_labels = np.asarray(group_labels)
    class_labels = np.asarray(class_labels)

    unique_labels = np.unique(group_labels)
    scores = []
    for group in unique_labels:
        in_group = group_labels == group
        group_query = query_selector[in_group]
        if not group_query.any():
            # Nothing to score for this group.
            scores.append(np.nan)
            continue
        group_result = calc_kbet(feature_array[in_group],
                                 sample_metadata_df.loc[in_group],
                                 query_selector=group_query,
                                 group_labels=group_labels[in_group],
                                 class_labels=class_labels[in_group],
                                 n_neighbours=n_neighbours)
        scores.append(group_result.iloc[0, 0])

    return pd.DataFrame(np.array(scores, dtype=np.float32),
                        index=unique_labels,
                        columns=[f'mean-kbet-{n_neighbours}'],
                        dtype=np.float32)


class CellignerSilhouetteHandler(MetricHandler):
    """
    Silhouette metrics for Celligner data.

    Cell-line ("CL") samples are both the queries and the class partition; disease labels
    are passed as group_labels.
    """

    def __init__(self, n_neighbours=100):
        # Neighbourhood size forwarded to calc_silhouette().
        self.n_neighbours = n_neighbours

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_cell_line = sample_metadata_df.type.eq("CL").to_numpy()
        diseases = sample_metadata_df.disease.to_numpy()
        return calc_silhouette(feature_array,
                               sample_metadata_df,
                               query_selector=is_cell_line,
                               group_labels=diseases,
                               class_labels=is_cell_line,
                               n_neighbours=self.n_neighbours)

    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """Average every result column over the returned rows."""
        return results_df.mean(axis=0)


class CellignerKBETHandler(MetricHandler):
    """
    kBET metric for Celligner data.

    Cell-line ("CL") samples are both the queries and the class partition; disease labels
    are passed as group_labels.
    """

    def __init__(self, n_neighbours=100):
        # Neighbourhood size forwarded to calc_kbet().
        self.n_neighbours = n_neighbours

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_cell_line = sample_metadata_df.type.eq("CL").to_numpy()
        diseases = sample_metadata_df.disease.to_numpy()
        return calc_kbet(feature_array,
                         sample_metadata_df,
                         query_selector=is_cell_line,
                         group_labels=diseases,
                         class_labels=is_cell_line,
                         n_neighbours=self.n_neighbours)

    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """Average every result column over the returned rows."""
        return results_df.mean(axis=0)


class CellignerKNNHandler(MetricHandler):
    """
    kNN class-mixing metric for Celligner data.

    Cell-line ("CL") samples are the queries and the class partition; disease labels are
    passed as group_labels.
    """

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_cell_line = sample_metadata_df.type.eq("CL").to_numpy()
        diseases = sample_metadata_df.disease.to_numpy()
        return class_knn_score(feature_array,
                               sample_metadata_df,
                               query_selector=is_cell_line,
                               group_labels=diseases,
                               class_labels=is_cell_line)

    def reduce(self,
               results_df: pd.DataFrame) -> pd.Series:
        """Mean of the ``knn-metric`` column (NaNs for non-query rows are skipped)."""
        return pd.Series({'knn-metric': results_df['knn-metric'].mean()})


class KangSilhouetteHandler(MetricHandler):
    """
    Silhouette metrics for Kang data.

    "perturbed" samples are both the queries and the class partition; cell labels are
    passed as group_labels.
    """

    def __init__(self, n_neighbours=100):
        # Neighbourhood size forwarded to calc_silhouette().
        self.n_neighbours = n_neighbours

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_perturbed = sample_metadata_df.type.eq("perturbed").to_numpy()
        cells = sample_metadata_df.cell.to_numpy()
        return calc_silhouette(feature_array,
                               sample_metadata_df,
                               query_selector=is_perturbed,
                               group_labels=cells,
                               class_labels=is_perturbed,
                               n_neighbours=self.n_neighbours)

    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """Average every result column over the returned rows."""
        return results_df.mean(axis=0)


class KangKBETHandler(MetricHandler):
    """
    kBET metric for Kang data.

    "perturbed" samples are both the queries and the class partition; cell labels are
    passed as group_labels.
    """

    def __init__(self, n_neighbours=100):
        # Neighbourhood size forwarded to calc_kbet().
        self.n_neighbours = n_neighbours

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_perturbed = sample_metadata_df.type.eq("perturbed").to_numpy()
        cells = sample_metadata_df.cell.to_numpy()
        return calc_kbet(feature_array,
                         sample_metadata_df,
                         query_selector=is_perturbed,
                         group_labels=cells,
                         class_labels=is_perturbed,
                         n_neighbours=self.n_neighbours)

    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """Average every result column over the returned rows."""
        return results_df.mean(axis=0)


class KangKNNHandler(MetricHandler):
    """
    kNN class-mixing metric for Kang data.

    "perturbed" samples are the queries and the class partition; cell labels are passed
    as group_labels.
    """

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_perturbed = sample_metadata_df.type.eq("perturbed").to_numpy()
        cells = sample_metadata_df.cell.to_numpy()
        return class_knn_score(feature_array,
                               sample_metadata_df,
                               query_selector=is_perturbed,
                               group_labels=cells,
                               class_labels=is_perturbed)

    def reduce(self,
               results_df: pd.DataFrame) -> pd.Series:
        """Mean of the ``knn-metric`` column (NaNs for non-query rows are skipped)."""
        return pd.Series({'knn-metric': results_df['knn-metric'].mean()})


class IncomeSilhouetteHandler(MetricHandler):
    """
    Silhouette of "Female" vs other samples for UCI income data.

    Every sample is a query, and all samples share a single dummy group.
    """

    def __init__(self, n_neighbours=1000):
        # Neighbourhood size forwarded to calc_silhouette().
        self.n_neighbours = n_neighbours

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        n_samples = len(sample_metadata_df)
        is_female = (sample_metadata_df.type == "Female").to_numpy()
        silhouette_df = calc_silhouette(feature_array,
                                        sample_metadata_df,
                                        query_selector=np.ones(n_samples, dtype=bool),
                                        # No groups in this experiment: one all-True group.
                                        group_labels=np.ones(n_samples, dtype=bool),
                                        class_labels=is_female,
                                        n_neighbours=self.n_neighbours)
        # The mean silhouette score does not make sense for this data, so drop it.
        return silhouette_df.drop(columns=f'mean-silhouette-{self.n_neighbours}')

    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """Average every result column over the returned rows."""
        return results_df.mean(axis=0)


class IncomeKBETHandler(MetricHandler):
    """
    kBET of "Female" vs other samples for UCI income data.

    Every sample is a query, and all samples share a single dummy group.
    """

    def __init__(self, n_neighbours=1000):
        # Neighbourhood size forwarded to calc_kbet().
        self.n_neighbours = n_neighbours

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        n_samples = len(sample_metadata_df)
        is_female = (sample_metadata_df.type == "Female").to_numpy()
        return calc_kbet(feature_array,
                         sample_metadata_df,
                         query_selector=np.ones(n_samples, dtype=bool),
                         # No groups in this experiment: one all-True group.
                         group_labels=np.ones(n_samples, dtype=bool),
                         class_labels=is_female,
                         n_neighbours=self.n_neighbours)

    def reduce(self, results_df: pd.DataFrame) -> pd.Series:
        """Average every result column over the returned rows."""
        return results_df.mean(axis=0)


class IncomeKNNHandler(MetricHandler):
    """
    kNN class-mixing metric for UCI income data.

    "Female" samples are the queries and the class partition; all samples share a single
    dummy group.
    """

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        is_female = (sample_metadata_df.type == "Female").to_numpy()
        single_group = np.ones_like(is_female)
        return class_knn_score(feature_array,
                               sample_metadata_df,
                               query_selector=is_female,
                               group_labels=single_group,
                               class_labels=is_female)

    def reduce(self,
               results_df: pd.DataFrame) -> pd.Series:
        """Mean of the ``knn-metric`` column (NaNs for non-query rows are skipped)."""
        return pd.Series({'knn-metric': results_df['knn-metric'].mean()})


def classify(features: np.ndarray, labels: np.ndarray, random_state=0xcb95423):
    """
    Fit a random forest (max_depth=6) on features/labels and predict the same features.

    The predictions are in-sample: the training data is also the evaluation data.
    """
    forest = RandomForestClassifier(max_depth=6, random_state=random_state)
    return forest.fit(features, labels).predict(features)


class IncomeClassificationHandler(MetricHandler):
    """
    Measure how well income and gender can be predicted from the features.

    NOTE: classify() fits and predicts on the same samples, so the reported accuracies
    are in-sample (training) accuracies.
    """

    def calculate(self,
                  feature_array: np.ndarray,
                  sample_metadata_df: pd.DataFrame) -> pd.DataFrame:
        """Return per-sample income/gender predictions next to their targets."""
        # Predict income and gender from latent features
        income_target = sample_metadata_df.income.to_numpy()
        gender_target = sample_metadata_df.type.to_numpy()
        return pd.DataFrame(
            {
                'income_prediction': classify(feature_array, income_target),
                'income_target': income_target,
                'gender_prediction': classify(feature_array, gender_target),
                'gender_target': gender_target,
            },
            index=sample_metadata_df.index,
        )

    def reduce(self,
               results_df: pd.DataFrame) -> pd.Series:
        """Return income and gender classification accuracies."""
        # Import explicitly: `import sklearn` alone does not guarantee that the
        # sklearn.metrics submodule has been loaded.
        from sklearn.metrics import accuracy_score
        return pd.Series({
            'income_accuracy': accuracy_score(results_df.income_target.to_numpy(),
                                              results_df.income_prediction.to_numpy()),
            'gender_accuracy': accuracy_score(results_df.gender_target.to_numpy(),
                                              results_df.gender_prediction.to_numpy()),
        })
