import os
from typing import List, Union, Callable

import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.manifold import TSNE
from umap import UMAP

from spc.model import Embedder
from spc.dataset import LabelledDataset
from spc.aggregate import aggregate_samples
from spc.dfconst import CONTROL_MOA_NAME, META_DF_COLUMNS, MOA_COLUMN, COMPOUND_UM_COLUMN,\
    PLATE_COLUMN, WELL_COLUMN


def save_combined_visuals(
        dataset1: LabelledDataset,
        dataset2: LabelledDataset,
        experiment_folder: str,
        embedding_fn: Callable,
        model: Embedder,
) -> None:
    """Save joint t-SNE and UMAP scatter plots for two datasets.

    Each dataset is embedded with ``embedding_fn(model, dataset)`` and the
    result is passed to ``aggregate_samples`` keyed on the plate and well
    columns (presumably yielding one embedding per well -- confirm against
    ``spc.aggregate``). The aggregated embeddings and metadata of both
    datasets are concatenated, projected to 2-D with t-SNE and UMAP (both
    using the cosine metric), and written to ``combined_tsne.html`` and
    ``combined_umap.html`` inside ``experiment_folder``.

    Args:
        dataset1: First dataset to embed.
        dataset2: Second dataset to embed.
        experiment_folder: Directory receiving the HTML files; created if it
            does not exist yet.
        embedding_fn: Callable invoked as ``embedding_fn(model, dataset)``
            that returns the per-sample embeddings.
        model: Embedder handed to ``embedding_fn``.
    """
    # Create the output directory up front so a missing folder does not only
    # surface after the expensive embedding / projection work. An empty path
    # means "current directory" for os.path.join, so there is nothing to create.
    if experiment_folder:
        os.makedirs(experiment_folder, exist_ok=True)

    embedding_parts = []
    meta_parts = []
    for dataset in (dataset1, dataset2):
        sample_embeddings = embedding_fn(model, dataset)
        part_embeddings, part_meta_df = aggregate_samples(
            sample_embeddings,
            dataset.get_df(),
            [PLATE_COLUMN, WELL_COLUMN],
        )
        embedding_parts.append(part_embeddings)
        meta_parts.append(part_meta_df)

    well_embeddings = np.concatenate(embedding_parts)
    well_meta_df = pd.concat(meta_parts).reset_index(drop=True)
    # make sure all controls correctly labelled as the same concentration
    well_meta_df.loc[well_meta_df[MOA_COLUMN] == CONTROL_MOA_NAME, COMPOUND_UM_COLUMN] = 0

    # scikit-learn requires perplexity < n_samples; clamp the default of 30
    # so small experiments still produce a plot. Results are unchanged
    # whenever there are more than 30 wells.
    n_wells = well_embeddings.shape[0]
    perplexity = float(min(30.0, max(n_wells - 1, 1)))

    tsne = TSNE(n_components=2, metric='cosine', perplexity=perplexity)
    tsne12 = tsne.fit_transform(well_embeddings)
    um = UMAP(n_components=2, metric='cosine', min_dist=0.5, n_neighbors=20)
    um12 = um.fit_transform(well_embeddings)

    projections = (
        ('combined_tsne.html', tsne12),
        ('combined_umap.html', um12),
    )
    for file_name, embeddings2d in projections:
        save_interactive_scatter_plot(
            save_fpath=os.path.join(experiment_folder, file_name),
            embeddings2d=embeddings2d,
            meta_df=well_meta_df,
            plot_label_type=MOA_COLUMN,
            hover_label_types=META_DF_COLUMNS,
        )


def save_interactive_scatter_plot(
        save_fpath: str,
        embeddings2d: np.ndarray,
        meta_df: pd.DataFrame,
        plot_label_type: str,
        hover_label_types: List[str]
):
    """Write an interactive plotly scatter plot of 2-D embeddings to HTML.

    Args:
        save_fpath: Path of the HTML file to write.
        embeddings2d: Array whose first two columns are used as the x and y
            coordinates of the points.
        meta_df: Metadata frame providing the colour and hover columns
            (assumed to be row-aligned with ``embeddings2d`` -- verify
            against the caller).
        plot_label_type: Column of ``meta_df`` used to colour the points.
        hover_label_types: Columns of ``meta_df`` shown in the hover box.
    """
    frame = meta_df.copy()

    xs = embeddings2d[:, 0]
    ys = embeddings2d[:, 1]

    # Build the hover payload and its template side by side so that the
    # customdata index in the template always matches the column position.
    hover_columns = []
    hover_lines = ["ColX: %{x}", "ColY: %{y}"]
    for position, column in enumerate(hover_label_types):
        hover_columns.append(frame[column])
        hover_lines.append(column + f": %{{customdata[{position}]}}")

    fig = px.scatter(
        x=xs,
        y=ys,
        color=frame[plot_label_type],
        # controls are always drawn in black; other labels use Light24
        color_discrete_map={CONTROL_MOA_NAME: 'black'},
        color_discrete_sequence=px.colors.qualitative.Light24,
        title=f"{plot_label_type}.html",
        custom_data=hover_columns,
    )

    fig.update_layout(plot_bgcolor='white')

    # Leave a small margin around the outermost points on both axes.
    margin = 0.5
    fig.update_xaxes(range=[xs.min() - margin, xs.max() + margin])
    fig.update_yaxes(range=[ys.min() - margin, ys.max() + margin])

    fig.update_traces(hovertemplate="<br>".join(hover_lines))

    # "cdn" keeps the file small by referencing plotly.js instead of embedding it.
    fig.write_html(save_fpath, include_plotlyjs="cdn")


