import os
from typing import Callable, Dict, Tuple, List

import numpy as np
import scipy as sp
import pandas as pd
from tqdm import tqdm
from sklearn.manifold import TSNE
from umap import UMAP

from spc.model import Embedder
from spc.dataset import LabelledDataset
from spc.visualization import save_interactive_scatter_plot
from spc.dfconst import META_DF_COLUMNS, PLATE_COLUMN, TREATMENT_COLUMN, MOA_COLUMN, COMPOUND_NAME_COLUMN,\
    CONTROL_MOA_NAME, COMPOUND_UM_COLUMN
from spc.map.map import mean_average_precision_score
from spc.aggregate import aggregate_samples


def knn1(
        embeddings: np.ndarray,
        df: pd.DataFrame,
        target: str = 'moa',
        metric: str = 'cosine',
) -> float:
    """Return the leave-one-out 1-nearest-neighbour accuracy for ``target``.

    For every row of ``embeddings`` the closest *other* row (under ``metric``)
    is looked up and the two rows' ``df[target]`` values are compared.

    Args:
        embeddings: 2-D array with one embedding per row.
        df: Metadata with one row per embedding, matched by position (the
            index labels of ``df`` are ignored).
        target: Column of ``df`` holding the label to compare.
        metric: Distance metric name forwarded to ``scipy`` ``cdist``.

    Returns:
        Fraction of samples whose nearest neighbour carries the same label.
        ``0.0`` when fewer than two samples are given, because no neighbour
        exists in that case.

    Raises:
        ValueError: If ``embeddings`` and ``df`` have different row counts.
    """
    embeddings = np.asarray(embeddings)
    # positional label array: avoids rebuilding a filtered DataFrame per sample
    labels = df[target].to_numpy()
    n_samples = embeddings.shape[0]
    if len(labels) != n_samples:
        raise ValueError(
            f'embeddings has {n_samples} rows but df has {len(labels)} rows'
        )
    if n_samples < 2:
        # nothing to compare against; also avoids dividing by zero
        return 0.0

    same_label = 0
    for e in range(n_samples):
        dists = sp.spatial.distance.cdist(embeddings[e:e + 1], embeddings, metric=metric)[0]
        # exclude the sample itself so argmin picks the closest *other* sample
        dists[e] = np.inf
        same_label += int(labels[e] == labels[np.argmin(dists)])
    return same_label / n_samples


def not_same_compound_score(
        embeddings: np.ndarray,
        compound_labels: np.ndarray,
        moa_labels: np.ndarray,
        metric: str = 'cosine',
) -> float:
    """Return the not-same-compound (NSC) 1-nearest-neighbour MoA accuracy.

    For each sample, the nearest neighbour is searched only among samples
    with a different compound label; the sample counts as correct when that
    neighbour has the same MoA label.

    Args:
        embeddings: 2-D array-like with one embedding per row.
        compound_labels: Compound label per row (array-like, same length).
        moa_labels: MoA label per row (array-like, same length).
        metric: Distance metric name forwarded to ``scipy`` ``cdist``.

    Returns:
        Fraction of scored samples whose neighbour shares their MoA label.
        Samples with no valid neighbour (no row with a different compound)
        are skipped; ``0.0`` is returned when no sample could be scored.
    """
    embeddings = np.asarray(embeddings)
    # accept lists as well as arrays; element-wise != needs an ndarray
    compound_labels = np.asarray(compound_labels)
    moa_labels = np.asarray(moa_labels)

    same_moa = 0
    total = 0
    for e in range(embeddings.shape[0]):
        candidates = np.flatnonzero(compound_labels != compound_labels[e])
        if candidates.size == 0:
            # no sample of another compound: argmin over nothing would raise
            continue

        dists = sp.spatial.distance.cdist(embeddings[e:e + 1], embeddings[candidates], metric=metric)
        # map the position inside the candidate subset back to a row index
        nn_idx = candidates[np.argmin(dists)]

        same_moa += int(moa_labels[e] == moa_labels[nn_idx])
        total += 1
    if total == 0:
        return 0.0
    return same_moa / total


def not_same_compound_batch_score(
        embeddings: np.ndarray,
        compound_labels: np.ndarray,
        moa_labels: np.ndarray,
        batch_labels: np.ndarray,
        metric: str = 'cosine',
) -> float:
    """Return the not-same-compound-or-batch (NSCB) 1-NN MoA accuracy.

    For each sample, the nearest neighbour is searched only among samples
    that differ in both compound label and batch label; the sample counts as
    correct when that neighbour has the same MoA label.

    Args:
        embeddings: 2-D array-like with one embedding per row.
        compound_labels: Compound label per row (array-like, same length).
        moa_labels: MoA label per row (array-like, same length).
        batch_labels: Batch label per row (array-like, same length).
        metric: Distance metric name forwarded to ``scipy`` ``cdist``.

    Returns:
        Fraction of scored samples whose neighbour shares their MoA label.
        Samples with no valid neighbour (no row differing in both compound
        and batch) are skipped; ``0.0`` is returned when no sample could be
        scored.
    """
    embeddings = np.asarray(embeddings)
    # accept lists as well as arrays; element-wise != needs an ndarray
    compound_labels = np.asarray(compound_labels)
    moa_labels = np.asarray(moa_labels)
    batch_labels = np.asarray(batch_labels)

    same_moa = 0
    total = 0
    for e in range(embeddings.shape[0]):
        other_compound = compound_labels != compound_labels[e]
        other_batch = batch_labels != batch_labels[e]
        candidates = np.flatnonzero(other_compound & other_batch)
        if candidates.size == 0:
            # no eligible neighbour: argmin over nothing would raise
            continue

        dists = sp.spatial.distance.cdist(embeddings[e:e + 1], embeddings[candidates], metric=metric)
        # map the position inside the candidate subset back to a row index
        nn_idx = candidates[np.argmin(dists)]

        same_moa += int(moa_labels[e] == moa_labels[nn_idx])
        total += 1
    if total == 0:
        return 0.0
    return same_moa / total


def calculate_bbbc021_metrics(well_embeddings, well_meta_df) -> Tuple:
    """Treatment level nsc/nscb (BBBC021 only).

    Well embeddings are collapsed to one embedding per treatment (element-wise
    median over the wells of each ``TREATMENT_COLUMN`` group); the metadata of
    the first well in each group is kept as the treatment's metadata.

    Args:
        well_embeddings: 2-D array with one embedding per row of
            ``well_meta_df`` (matched by position).
        well_meta_df: Well-level metadata containing ``META_DF_COLUMNS``.

    Returns:
        ``(nsc, nscb)`` as computed by ``not_same_compound_score`` and
        ``not_same_compound_batch_score`` with the cosine metric.
    """
    # positional row indices of the wells belonging to each treatment
    groups = well_meta_df.groupby([TREATMENT_COLUMN]).indices

    treatment_embeddings = []
    representative_rows = []
    for indices in tqdm(groups.values(), total=len(groups)):
        treatment_embeddings.append(np.median(well_embeddings[indices, :], axis=0))
        # first well of the group supplies the treatment's metadata
        representative_rows.append(indices[0])
    treatment_embeddings = np.array(treatment_embeddings)
    # build the metadata frame once instead of concatenating inside the loop
    treatment_meta_df = well_meta_df.iloc[representative_rows][list(META_DF_COLUMNS)].reset_index(drop=True)

    # BBBC021 - DMSO and unknown are not used in metric reporting
    moa = treatment_meta_df[MOA_COLUMN]
    nsc_valid_samples = ((moa != 'unknown') & (moa != 'DMSO')).to_numpy()
    treatment_embeddings = treatment_embeddings[nsc_valid_samples]
    treatment_meta_df = treatment_meta_df[nsc_valid_samples]

    nsc = not_same_compound_score(
        embeddings=treatment_embeddings,
        compound_labels=treatment_meta_df[COMPOUND_NAME_COLUMN].values,
        moa_labels=treatment_meta_df[MOA_COLUMN].values,
        metric='cosine',
    )

    # Kinase inhibitors and Cholesterol-lowering are only in one batch
    moa = treatment_meta_df[MOA_COLUMN]
    nscb_valid_samples = ((moa != 'Kinase inhibitors') & (moa != 'Cholesterol-lowering')).to_numpy()
    nscb_treatment_embeddings = treatment_embeddings[nscb_valid_samples]
    nscb_treatment_meta_df = treatment_meta_df[nscb_valid_samples]

    nscb = not_same_compound_batch_score(
        embeddings=nscb_treatment_embeddings,
        compound_labels=nscb_treatment_meta_df[COMPOUND_NAME_COLUMN].values,
        moa_labels=nscb_treatment_meta_df[MOA_COLUMN].values,
        # batch = text before the first '_' of the plate name (the "week",
        # per the original comment - confirm against the plate naming scheme)
        batch_labels=nscb_treatment_meta_df[PLATE_COLUMN].astype(str).str.split('_').str[0].values,
        metric='cosine',
    )

    return nsc, nscb


def calculate_map(embeddings, meta_df, pos_class, neg_diffby):
    """Compute the class-averaged mean average precision.

    ``mean_average_precision_score`` is called with ``pos_sameby=pos_class``
    and ``neg_diffby=neg_diffby``; its ``average_precision`` column is then
    averaged per ``pos_class`` value and those class means are averaged again.

    Returns:
        ``(map_score, map_by_class)`` where ``map_by_class`` is a DataFrame of
        per-class mean average precision sorted in descending order.
        On any exception the error is printed and ``(0, None)`` is returned.
    """
    try:
        per_sample_ap = mean_average_precision_score(
            meta_df, embeddings,
            pos_sameby=pos_class, neg_diffby=neg_diffby,
            pos_diffby=[], neg_sameby=[],
        )
        # mean AP of each class, as a two-column frame
        class_means = per_sample_ap.copy().groupby(by=[pos_class])['average_precision'].mean()
        map_by_class = class_means.reset_index()
        # best class first, with a clean 0..n-1 index
        map_by_class = map_by_class.sort_values(by='average_precision', ascending=False)
        map_by_class = map_by_class.reset_index(drop=True)
        return map_by_class['average_precision'].mean(), map_by_class
    except Exception as err:
        # best-effort metric: report the problem and fall back to a zero score
        print(err)
        return 0, None


def evaluate_model(
        experiment_folder: str,
        model: Embedder,
        dataset: LabelledDataset,
        embedding_fn: Callable,
        save_visualizations: bool,
        save_embeddings: bool,
        report_bbbc021_metrics: bool,
        prefix: str = '',
) -> Dict[str, float]:
    """Embed ``dataset`` with ``model`` and compute evaluation metrics.

    Args:
        experiment_folder: Directory that receives plots and saved embeddings.
        model: Model passed to ``embedding_fn``.
        dataset: Dataset passed to ``embedding_fn``; its ``get_df()`` supplies
            the metadata used for aggregation.
        embedding_fn: Called as ``embedding_fn(model, dataset)`` to produce
            the embeddings.
        save_visualizations: When true, write t-SNE and UMAP scatter plots of
            the well embeddings as ``{prefix}tsne.html`` / ``{prefix}umap.html``.
        save_embeddings: When true, save the raw embeddings to
            ``{prefix}embeddings.npy``.
        report_bbbc021_metrics: When true, add the ``nsc`` / ``nscb`` metrics.
        prefix: String prepended to every metric key and output file name.

    Returns:
        Mapping from metric name to value.
    """
    embeddings = embedding_fn(model, dataset)

    # --- well level aggregation ---
    well_embeddings, well_meta_df = aggregate_samples(embeddings, dataset.get_df(), ['plate', 'well'])
    # make sure all controls correctly labelled as the same concentration
    is_control_well = well_meta_df[MOA_COLUMN] == CONTROL_MOA_NAME
    well_meta_df.loc[is_control_well, COMPOUND_UM_COLUMN] = 0

    # --- treatment level aggregation ---
    treatment_embeddings, treatment_meta_df = aggregate_samples(embeddings, dataset.get_df(), [TREATMENT_COLUMN])
    # remove controls from treatment level embeddings (will only be one sample)
    is_treated = treatment_meta_df[MOA_COLUMN] != CONTROL_MOA_NAME
    treatment_embeddings = treatment_embeddings[is_treated]
    treatment_meta_df = treatment_meta_df[is_treated].reset_index(drop=True)

    metrics = {}

    # 1-NN accuracy on well embeddings: to moa, then to treatment
    knn_jobs = (
        (f'{prefix}well-1nn-moa', MOA_COLUMN),
        (f'{prefix}well-1nn-treatment', TREATMENT_COLUMN),
    )
    for metric_name, label_column in knn_jobs:
        metrics[metric_name] = knn1(well_embeddings, well_meta_df, label_column)

    # mAP: same treatment (wells), same moa (wells), same moa (treatments)
    map_jobs = (
        (f'{prefix}well-map-treatment', well_embeddings, well_meta_df, TREATMENT_COLUMN),
        (f'{prefix}well-map-moa', well_embeddings, well_meta_df, MOA_COLUMN),
        (f'{prefix}treatment-map-moa', treatment_embeddings, treatment_meta_df, MOA_COLUMN),
    )
    for metric_name, level_embeddings, level_meta_df, label_column in map_jobs:
        score, _ = calculate_map(level_embeddings, level_meta_df, label_column, label_column)
        metrics[metric_name] = score

    if report_bbbc021_metrics:
        nsc, nscb = calculate_bbbc021_metrics(well_embeddings, well_meta_df)
        metrics[f'{prefix}nsc'] = nsc
        metrics[f'{prefix}nscb'] = nscb

    if save_visualizations:
        # both projections are fitted before anything is written to disk
        tsne12 = TSNE(n_components=2, metric='cosine', perplexity=30.0).fit_transform(well_embeddings)
        um12 = UMAP(n_components=2, metric='cosine', min_dist=0.5, n_neighbors=20).fit_transform(well_embeddings)

        plots = (
            (f'{prefix}tsne.html', tsne12),
            (f'{prefix}umap.html', um12),
        )
        for file_name, projection in plots:
            save_interactive_scatter_plot(
                save_fpath=os.path.join(experiment_folder, file_name),
                embeddings2d=projection,
                meta_df=well_meta_df,
                plot_label_type=MOA_COLUMN,
                hover_label_types=META_DF_COLUMNS,
            )

    if save_embeddings:
        embeddings_path = os.path.join(experiment_folder, f'{prefix}embeddings.npy')
        np.save(embeddings_path, embeddings)

    return metrics

