import json
import os
from collections import OrderedDict
from typing import Tuple, Union

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap
from matplotlib.colors import Normalize
from matplotlib.patches import Patch
from sklearn.decomposition import PCA
from src.training.training_utils import (extract_all_embeddings,
                                         extract_holdout_embeddings)


def show_results_holdout(
    dataset: str,
    task: str,
    regressors: Tuple[str, ...],
    embedding_types: Tuple[str, ...],
    metric: str,
    split_key: str,
    active: bool,
    target: str = "",
):
    """Plot a metric for every split, regressor and embedding as a bar grid.

    Reads ``results/{dataset}/baseline_{task}_{insert}{dataset}.csv`` (where
    ``insert`` encodes ``active`` and ``target``) and keeps the rows whose
    ``split_type`` equals ``split_key``. It then draws a grid with one row
    per split (test, train, val) and one column per regressor. On the test
    row, unless ``split_key`` is ``"CV"`` or ``"random"``, a faint dashed
    line marks the best test score over all rows. The figure is written to
    ``figures/{dataset}/{task}/`` as PDF and PNG and then shown.

    Args:
        dataset: Dataset name; used in input/output paths and in the title.
        task: ``"classification"`` or ``"regression"``.
        regressors: Model names, one subplot column each. Rows whose
            ``model`` value is missing are included in every column.
        embedding_types: Only its length is used, to size the colour palette.
        metric: Metric suffix; the columns ``{split}_{metric}`` must exist.
        split_key: Value of the ``split_type`` column to plot.
        active: Whether to plot the "active sequences only" results.
        target: ``"target_class"`` or ``"target_class_2"`` for classification.

    Raises:
        ValueError: If the ``task``/``target``/``active`` combination does not
            match a supported results file. This is checked before any I/O.
    """
    # Resolve title and output name first so an invalid combination fails
    # before the CSV is read or a figure is created.
    title = (
        f"{dataset.upper()} {task} results. Metric: {metric.upper()}. "
        f"Split-strategy: {split_key}."
    )
    if target == "target_class" and task == "classification":
        title += " Active vs. inactive sequences."
        path_postfix = f"classification_results_{metric}_{split_key}"
    elif target == "target_class_2" and task == "classification":
        title += " Very active vs. active/inactive."
        path_postfix = f"classification_results_high_{metric}_{split_key}"
    elif active and task == "regression":
        title += " Regression on active sequences."
        path_postfix = f"regression_results_active_{metric}_{split_key}"
    elif not active and task == "regression":
        title += " Regression on all sequences."
        path_postfix = f"regression_results_{metric}_{split_key}"
    else:
        raise ValueError(
            f"Unsupported combination: task={task!r}, target={target!r}, "
            f"active={active!r}"
        )

    # Visualisation parameters
    sns.set_style(style="dark", rc={"ytick.left": True})
    color_palette = sns.color_palette("tab10")[: len(embedding_types)]
    splits = ("test", "train", "val")
    bounded_metrics = ("spearman", "pearson", "spearman_active", "mcc", "auroc")

    insert = "active_" if active else ""
    if target == "target_class_2":
        insert = "high_" + insert

    # Load data
    df = pd.read_csv(
        f"results/{dataset}/baseline_{task}_{insert}{dataset}.csv",
        index_col=0,
    )
    df = df[df["split_type"] == split_key].sort_values("embedding")
    best_test = df[f"test_{metric}"].max()

    # squeeze=False keeps `ax` 2-D even with a single regressor, so the
    # ax[i, j] indexing below works for any number of columns.
    fig, ax = plt.subplots(
        len(splits),
        len(regressors),
        figsize=(len(regressors) * 5, len(splits) * 5),
        squeeze=False,
    )
    for i, split in enumerate(splits):
        df_i = df[["model", "embedding", f"{split}_{metric}"]]
        for j, regressor in enumerate(regressors):
            ax_ij = ax[i, j]
            df_ij = df_i[(df_i["model"] == regressor) | df_i["model"].isnull()]
            sns.barplot(
                data=df_ij,
                x="embedding",
                y=f"{split}_{metric}",
                palette=color_palette,
                ax=ax_ij,
                errorbar="sd",  # replaces the deprecated ci="sd"
                errwidth=1,
            )

            # Adjust axes
            if metric in bounded_metrics:
                ax_ij.set_ylim([0, 1])
            ax_ij.set_ylabel("")
            ax_ij.set_xlabel("")
            if i == len(splits) - 1:
                # Only the bottom row keeps (rotated) embedding labels.
                ax_ij.tick_params("x", labelrotation=45)
            else:
                ax_ij.set_xticks([])
            if split == "test" and split_key not in ["CV", "random"]:
                xmin, xmax = ax_ij.get_xlim()
                ax_ij.hlines(
                    y=best_test,
                    xmin=xmin,
                    xmax=xmax,
                    colors="black",
                    alpha=0.1,
                    linestyles="--",
                )

    # Row labels and y-ticks
    y_top = df[f"test_{metric}"].max() + 0.1
    for i, split in enumerate(splits):
        ax[i, 0].set_ylabel(split.capitalize())
        for j in range(len(regressors)):
            if metric in ["mcc", "spearman"]:
                ax[i, j].set_yticks(np.arange(0, 1.2, 0.2))
            else:
                ax[i, j].set_yticks(np.linspace(0, y_top, 10))

    # Column titles
    for j, regressor in enumerate(regressors):
        ax[0, j].set_title(regressor)

    fig.suptitle(title, fontsize=16)

    # Save and show
    path = f"figures/{dataset}/{task}/{dataset}_{path_postfix}"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fig.savefig(f"{path}.pdf")
    fig.savefig(f"{path}.png")
    print(f"Saved plot in {path}.png")
    plt.show()


def show_CV_split_distribution(
    df: pd.DataFrame, threshold: Union[float, None], dataset: str, n_partitions: int
):
    """Plot the target distributions of each cross-validation partition.

    Each partition ``i`` gets one row of two panels:

    - a 12-bin histogram of ``target_reg`` over a shared value range;
    - a count plot of ``target_class``.

    Both panels use only the rows with ``part_{i} == 1``. The figure is saved
    as PDF and PNG under ``figures/{dataset}/splits/`` and then shown.

    Args:
        df: Data with columns ``target_reg``, ``target_class`` and
            ``part_0`` … ``part_{n_partitions - 1}``. The partition columns are
            compared to 1 and summed, so they are presumably 0/1 indicators.
        threshold: Value shown in the title only.
        dataset: Dataset name; used in the title and output path.
        n_partitions: Number of partition columns (K) to plot.
    """
    # Setup plotting
    sns.set_style("dark")
    # Ask for exactly n_partitions colours; seaborn cycles the palette when
    # more colours are requested than it has, instead of raising IndexError.
    color_palette = sns.color_palette("colorblind", n_partitions)
    y_limits = (df["target_reg"].min(), df["target_reg"].max())

    part_cols = [f"part_{i}" for i in range(n_partitions)]
    part_counts = df[part_cols].sum()
    n_eff = int(part_counts.sum())

    # squeeze=False keeps `ax` 2-D even when n_partitions == 1.
    fig, ax = plt.subplots(
        n_partitions, 2, figsize=(10, 5 * n_partitions), sharex="col", squeeze=False
    )

    for i, part_col in enumerate(part_cols):
        df_part = df[df[part_col] == 1]
        color = color_palette[i]

        # Histogram over target values
        sns.histplot(
            data=df_part,
            x="target_reg",
            ax=ax[i, 0],
            bins=12,
            binrange=y_limits,
            color=color,
            alpha=0.9,
        )
        ax[i, 0].set_title(f"Target values (partition {i + 1})")

        # Class counts
        sns.countplot(
            data=df_part,
            x="target_class",
            ax=ax[i, 1],
            color=color,
            alpha=0.9,
        )
        ax[i, 1].set_title(f"Binarized target (partition {i + 1})")

    fig.suptitle(
        f"{dataset.upper()} dataset overview.\n"
        f"Cross-validation partitions (K={n_partitions}) at threshold {threshold}.\n"
        f"N = {len(df)}, Neff = {n_eff}.\n"
        f"N_i = {part_counts.tolist()}/{n_eff}.",
        fontsize="x-large",
    )

    for axi in ax.flat:
        axi.set_ylabel("Count")
        axi.set_xlabel("")

    path = f"figures/{dataset}/splits/{dataset}_CV_K={n_partitions}_distribution"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fig.savefig(f"{path}.pdf")
    fig.savefig(f"{path}.png")
    plt.show()


def visualize_embedding(
    dataset: str,
    embedding_type: str,
    suffix: Union[str, None] = None,
    n_partitions: int = 3,
):
    """Plot 2-D PCA and UMAP projections of one embedding type.

    Embeddings come from ``extract_all_embeddings``. The sequences plotted are
    those in ``data/processed/{dataset}/{dataset}.csv`` assigned to exactly
    one of the ``part_0`` … ``part_{n_partitions - 1}`` columns. The 2x2 figure
    is laid out as follows:

    - columns: PCA (left) and UMAP (right);
    - rows: coloured by partition (top) and by regression target (bottom).

    It is saved as PDF and PNG under ``figures/{dataset}/embeddings/`` and
    then shown.

    Args:
        dataset: Dataset name; used for input/output paths and the title.
        embedding_type: Embedding identifier passed to the extractor.
        suffix: Extra identifier forwarded to the extractor; required for
            ``"EVE (z)"``.
        n_partitions: Number of ``part_*`` columns to consider (default 3).

    Raises:
        ValueError: If ``embedding_type`` is ``"EVE (z)"`` and ``suffix`` is
            None.
    """
    np.random.seed(42)
    if embedding_type == "EVE (z)" and suffix is None:
        raise ValueError("A suffix is required for 'EVE (z)' embeddings.")

    # Extract embeddings
    embeddings, y, _ = extract_all_embeddings(
        dataset=dataset,
        embedding_type=embedding_type,
        target="target_reg",
        suffix=suffix,
    )

    # Dimensionality reduction on all embeddings
    x_pca = PCA(n_components=2).fit_transform(embeddings)
    x_umap = umap.UMAP(random_state=42).fit_transform(embeddings)

    df = pd.read_csv(f"data/processed/{dataset}/{dataset}.csv", index_col=0)
    part_cols = [f"part_{i}" for i in range(n_partitions)]
    # Keep only sequences assigned to exactly one partition; copy so the new
    # column below is written to an independent frame.
    df = df[df[part_cols].sum(axis=1) == 1.0].copy()
    df["split"] = ""
    for col in part_cols:
        df.loc[df[col] == 1, "split"] = col

    # NOTE(review): assumes the CSV index labels are the row positions of the
    # arrays returned by extract_all_embeddings — confirm.
    rows = df.index.to_numpy()

    # Setup plotting
    sns.set_style("dark")
    s = 80

    df_plot = pd.DataFrame(
        {
            "x_pca": x_pca[rows, 0],
            "y_pca": x_pca[rows, 1],
            "x_umap": x_umap[rows, 0],
            "y_umap": x_umap[rows, 1],
            "target": y[rows],
            # `df` is already filtered, so its split column lines up with
            # `rows` directly; indexing it by `rows` again misaligned labels.
            "split": df["split"].to_numpy(),
        }
    )
    fig, ax = plt.subplots(2, 2, figsize=(10, 10))

    for col, (x, y_col, name) in enumerate(
        [("x_pca", "y_pca", "PCA"), ("x_umap", "y_umap", "UMAP")]
    ):
        sns.scatterplot(
            data=df_plot,
            x=x,
            y=y_col,
            hue="split",
            ax=ax[0, col],
            s=s,
            style="split",
            alpha=0.5,
        )
        sns.scatterplot(
            data=df_plot, x=x, y=y_col, hue="target", ax=ax[1, col], s=s, alpha=0.5
        )
        ax[0, col].set_title(name)
        ax[1, col].set_title(name)

    fig.suptitle(f"{embedding_type} embeddings on {dataset.upper()} dataset")
    fig.subplots_adjust(hspace=0.3)

    path = f"figures/{dataset}/embeddings/{dataset}_{embedding_type}_dimensionality_reduction"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fig.savefig(f"{path}.pdf")
    fig.savefig(f"{path}.png")
    print(f"Saved figure in {path}.<pdf,png>")
    plt.show()


def show_test_results(
    dataset: str,
    task: str,
    regressors: Tuple[str, ...],
    embedding_types: Tuple[str, ...],
    metric: str,
    split_key: str,
    active: bool,
    target: str = "",
):
    """Plot test-set scores as a compact bar chart, one panel per regressor.

    Reads ``results/{dataset}/baseline_{task}_{insert}{dataset}.csv`` and keeps
    rows that meet all three conditions:

    - the embedding is in ``embedding_types``;
    - the embedding has a display name in the internal mapping;
    - ``split_type`` equals ``split_key``.

    Embeddings are renamed to their display names and sorted. For each
    regressor it draws the ``test_{metric}`` bars with standard-error bars.
    The figure is saved as PDF and PNG under ``figures/{dataset}/{task}/``
    and then shown.

    Args:
        dataset: Dataset name; used in input/output paths.
        task: ``"classification"`` or ``"regression"``.
        regressors: Model names, one panel each (must be non-empty).
        embedding_types: Embeddings to include.
        metric: ``"spearman"`` or ``"mcc"``.
        split_key: Value of the ``split_type`` column to plot.
        active: Whether to plot the "active sequences only" results.
        target: ``"target_class"`` or ``"target_class_2"`` for classification.

    Raises:
        ValueError: If ``metric`` is unsupported, if the
            ``task``/``target``/``active`` combination is unsupported, or if
            ``regressors`` is empty. All three are checked before any I/O.
    """
    # Validate up front so bad arguments fail before reading data or plotting.
    y_labels = {"spearman": r"Spearman $\rho$", "mcc": r"Matthew's $\phi$"}
    if metric not in y_labels:
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {sorted(y_labels)}"
        )
    if target == "target_class" and task == "classification":
        path_postfix = f"classification_results_{metric}_{split_key}"
    elif target == "target_class_2" and task == "classification":
        path_postfix = f"classification_results_high_{metric}_{split_key}"
    elif active and task == "regression":
        path_postfix = f"regression_results_active_{metric}_{split_key}"
    elif not active and task == "regression":
        path_postfix = f"regression_results_{metric}_{split_key}"
    else:
        raise ValueError(
            f"Unsupported combination: task={task!r}, target={target!r}, "
            f"active={active!r}"
        )
    if not regressors:
        raise ValueError("At least one regressor is required.")

    # Visualisation parameters
    sns.set_style(style="dark")
    color_palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#17becf"]
    embedding_names = {
        "ONEHOT (MSA)": "MSA",
        "ESM-2": "ESM-2",
        "ESM-1B": "ESM-1B",
        "ESM-IF1": "ESM-IF1",
        "EVE (z)": "EVE",
        "AF2": "Evoformer",
    }

    insert = "active_" if active else ""
    if target == "target_class_2":
        insert = "high_" + insert

    # Load data
    df = pd.read_csv(
        f"results/{dataset}/baseline_{task}_{insert}{dataset}.csv",
        index_col=0,
    )
    keep = (
        df["embedding"].isin(embedding_types)
        & df["embedding"].isin(embedding_names)
        & (df["split_type"] == split_key)
    )
    # .copy() so the rename edits an independent frame rather than a slice
    # of the loaded data (avoids SettingWithCopyWarning).
    df = df[keep].copy()
    df["embedding"] = df["embedding"].replace(embedding_names)
    df = df.sort_values("embedding")[["model", "embedding", f"test_{metric}"]]

    n_cols = len(regressors)
    # Width scales with the panel count (5.5 in for three panels, as before);
    # squeeze=False keeps `ax` 2-D even for a single regressor.
    fig, ax = plt.subplots(
        1, n_cols, figsize=(5.5 * n_cols / 3, 2), sharey="all", squeeze=False
    )
    axes = ax[0]

    for ax_i, regressor in zip(axes, regressors):
        sns.barplot(
            data=df[df["model"] == regressor],
            x="embedding",
            y=f"test_{metric}",
            palette=color_palette,
            ax=ax_i,
            errorbar="se",
            errwidth=1.5,
            capsize=0.5,
            label="embedding",
            edgecolor="k",
            saturation=1.0,
        )
        # Both supported metrics (Spearman, MCC) are plotted on [0, 1].
        ax_i.set_ylim([0, 1])
        ax_i.set_yticks(np.arange(0, 1.2, 0.2))
        ax_i.set_xticks([])  # x tick labels are hidden
        ax_i.set_ylabel("")
        ax_i.set_xlabel("")
        ax_i.set_title(regressor)

    axes[0].set_ylabel(y_labels[metric])
    fig.tight_layout()

    # Save and show
    path = f"figures/{dataset}/{task}/{dataset}_test_{path_postfix}"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fig.savefig(f"{path}.pdf")
    fig.savefig(f"{path}.png")
    print(f"Saved plot in {path}.<png,pdf>")
    plt.show()
