import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from SEPAL.downstream_evaluation import TARGETS, DATASETS_NAMES
from SEPAL import SEPAL_DIR


# Map each knowledge graph's triples directory (as a path string under
# SEPAL_DIR) to its display name; used below to label the "triples_dir"
# column of the downstream scores.
datasets_names = dict(
    (f"{str(SEPAL_DIR)}/datasets/knowledge_graphs/{dataset}", display_name)
    for dataset, display_name in DATASETS_NAMES.items()
)

# Display names of the embedding models, keyed by their `embed_method` value.
MODEL_NAMES = dict(
    transe="TransE",
    distmult="DistMult",
    mure="MuRE",
    tucker="TuckER",
    hole="HolE",
)


# Downstream tasks results figure
def downstream_results(normalize_scores=True):
    """Plot downstream-task scores of the best embeddings for each source KG.

    One horizontal stacked-bar subplot is drawn per source knowledge graph.
    Each bar is a method, and its segments are the mean cross-validation scores
    on the four downstream tables listed in ``order``. The (mean) ``total_time``
    of the selected embedding is written to the right of each bar, and the
    figure is saved as a PDF under ``SEPAL_DIR / "neurips_figures"``.

    Parameters
    ----------
    normalize_scores : bool, default True
        If True, each task's scores are divided by the best score on that task
        (see ``make_scores_normalize``). The axis label, x limits, time-label
        position and output file name are adapted accordingly.
    """
    # Load scores
    scores = make_score_df(
        filters={
            "embed_dim": [100],
            "embed_method": ["distmult", "fastrp"],
            "data": [
                "mini_yago3_lcc",
                "yago3_lcc",
                "yago4_lcc",
            ],
        }
    )

    # Keep only the columns needed for the figure. `.copy()` ensures the
    # normalization below never writes into a view of `scores`.
    features = [
        "downstream_task",
        "mean_score",
        "source_dataset",
        "method",
        "id",
        "total_time",
    ]
    df = scores[features].copy()

    if normalize_scores:
        df = make_scores_normalize(df)

    # Keep the best performing embedding for each (source KG, method) pair
    df = get_best_model(df, ["source_dataset", "method"])

    # Settings that depend on whether scores are normalized
    if normalize_scores:
        xlabel = "Cumulative normalized mean cross-validation score (R2)"
        xlim = (-0.1, 4.4)
        time_label_x = 4
        output_path = SEPAL_DIR / "neurips_figures/normalized_downstream_results.pdf"
    else:
        xlabel = "Cumulative mean cross-validation score (R2)"
        xlim = (-0.1, 2.4)
        time_label_x = 2.15
        output_path = SEPAL_DIR / "neurips_figures/downstream_results.pdf"

    ## Plot the results
    sources = df["source_dataset"].unique()[::-1]
    n_sources = len(sources)
    order = ["Movie revenues", "US accidents", "US elections", "Housing prices"]

    # Subplot heights are proportional to their number of bars (rows / tasks)
    height_ratios = [(df["source_dataset"] == s).sum() / len(order) for s in sources]
    # squeeze=False keeps `axs` indexable even when there is a single source
    fig, axs = plt.subplots(
        n_sources,
        figsize=(6, 4),
        sharex=True,
        squeeze=False,
        gridspec_kw={"height_ratios": height_ratios, "hspace": 0.6},
    )
    axs = axs[:, 0]

    for i, (ax, source) in enumerate(zip(axs, sources)):
        df_to_plot = pd.pivot_table(
            data=df[df["source_dataset"] == source],
            index=["method"],
            columns=["downstream_task"],
            values="mean_score",
            sort=False,
        )
        ax.set_title(source, fontsize=11)
        # Only the bottom subplot carries the (shared) x-axis label
        if i == n_sources - 1:
            ax.set_xlabel(xlabel)

        # Rows are reversed so that the first method ends up at the top
        df_to_plot[order][::-1].plot(
            ax=ax,
            kind="barh",
            stacked=True,
            ylabel="Mean cross-validation score",
            xlabel="",
            rot=0,
            edgecolor="black",
            alpha=0.9,
            color=sns.color_palette("colorblind", n_colors=4),
            zorder=3,
        )
        ax.get_legend().remove()
        ax.vlines(x=0, ymin=-1, ymax=10, linestyles="dotted", colors="k")

        # Add a grid under the bars
        ax.grid(axis="x", linestyle="--", alpha=0.7, zorder=0)

    # The x axis is shared, so setting it on one subplot sets it on all
    axs[-1].set_xlim(xlim)

    # A single figure-level legend (the tasks are the same in every subplot)
    handles, labels = axs[-1].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        title="Evaluation dataset",
        loc="upper center",
        bbox_to_anchor=(1.06, 0.65),
        fancybox=False,
        frameon=False,
        ncol=1,
    )

    # Write the computation time of each selected embedding next to its bar
    for ax, source in zip(axs, sources):
        df_source = df[df["source_dataset"] == source]
        method_order = df_source["method"].unique()
        mean_times = df_source.groupby("method")["total_time"].mean()
        # Bars were drawn in reversed order: the last method sits at y=0
        for y_position, method in enumerate(method_order[::-1]):
            total_time = mean_times.get(method)
            if pd.notna(total_time):
                ax.text(
                    time_label_x,
                    y_position,
                    human_time_duration(total_time),
                    ha="left",
                    va="center",
                    fontsize=8,
                    weight="bold",
                )

    # Save figure
    fig.savefig(output_path, bbox_inches="tight")


# Transductive link prediction results table
def tlp_results(side="tail"):
    """Print one LaTeX table of transductive link-prediction results per dataset.

    For every (dataset, method) pair, the checkpoint with the highest summed
    MRR on the validation scores is selected. Its test scores are reported
    together with the PyTorch-BigGraph test scores. All metrics come from the
    "realistic" ranks of the requested ``side``.

    Parameters
    ----------
    side : str, default "tail"
        Ranking side to report: one of the ``head``/``tail``/``both`` columns
        of the score files.
    """
    # Validation scores drive model selection; test scores are what we report
    raw_val = pd.read_parquet(SEPAL_DIR / "val_lp_scores.parquet")
    raw_test = pd.read_parquet(SEPAL_DIR / "test_lp_scores.parquet")
    val_scores = preprocess_tlp_scores(raw_val, side)
    test_scores = preprocess_tlp_scores(raw_test, side)

    best_ids = get_best_tlp_model(val_scores, ["data", "method"])
    selected = test_scores.loc[test_scores["id"].isin(best_ids)]

    # PyTorch-BigGraph baseline
    pbg_scores = preprocess_pbg_tlp_scores(
        pd.read_parquet(SEPAL_DIR / "test_lp_scores_pbg.parquet"), side
    )

    metrics = ["mr", "mrr", "r50"]
    kept_columns = ["data", "method", *metrics, "total_time"]
    results = pd.concat(
        [selected[kept_columns], pbg_scores[kept_columns]], ignore_index=True
    )

    # Human-readable durations and 4 significant figures for the metrics
    results["total_time"] = results["total_time"].map(human_time_duration)
    results[metrics] = results[metrics].applymap(lambda value: f"{value:.4g}")

    results = results.rename(
        columns={"mr": "MR", "mrr": "MRR", "r50": "Hits@50", "total_time": "Time"}
    )

    for data in results["data"].unique():
        long_format = results.loc[results["data"] == data].melt(
            id_vars=["method"],
            value_vars=["MR", "MRR", "Hits@50", "Time"],
            var_name="metric",
        )
        table = long_format.pivot(index="method", columns="metric", values="value")
        table = table[["MR", "Hits@50", "MRR", "Time"]]
        latex = table.to_latex(
            caption=f"Link prediction results on {DATASETS_NAMES[data]}, based on realistic ranks among sampled negatives.",
        )
        print(latex)


# Clustering benchmark
def clustering_benchmark(
    datasets=("mini_yago3_lcc", "yago3_lcc", "yago4.5_lcc", "yago4_lcc")
):
    """Print one LaTeX table per dataset comparing graph partitioning methods.

    SEPAL's subgraph-cover step (read from ``checkpoints_sepal.parquet``) is
    compared with clustering baselines (read from the parquet files under
    ``experiments/partitioning``). The comparison covers run time and memory
    usage, aggregated as mean and standard deviation per (dataset, method).
    Each method is also annotated with the hard-coded ``Connectedness`` and
    ``Size constraints`` flags defined below.

    Parameters
    ----------
    datasets : iterable of str
        Dataset names to include, with the ``_lcc`` suffix. The default is a
        tuple to avoid a shared mutable default argument.
    """
    datasets = list(datasets)

    # Load SEPAL partitioning results
    sepal_df = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")
    sepal_df = sepal_df[sepal_df["data"].isin(datasets)]
    sepal_df = sepal_df[sepal_df["partitioning"] == "sepal_subgraph"]
    sepal_df = sepal_df[~sepal_df["cover_subgraph_mem_usage"].isna()]
    # .copy(): the column subset is assigned into right after
    sepal_df = sepal_df[
        ["data", "cover_subgraph_time", "cover_subgraph_mem_usage"]
    ].copy()
    sepal_df["method"] = "SEPAL"
    sepal_df = sepal_df.rename(
        columns={"cover_subgraph_time": "time", "cover_subgraph_mem_usage": "memory"}
    )

    # Load baseline partitioning results
    df = pd.concat(
        [
            pd.read_parquet(
                SEPAL_DIR
                / "experiments/partitioning/application_oriented_results.parquet"
            ),
            pd.read_parquet(SEPAL_DIR / "experiments/partitioning/results.parquet"),
            pd.read_parquet(
                SEPAL_DIR
                / "experiments/partitioning/old_application_oriented_results.parquet"
            ),
        ]
    )
    df = df[~df["memory"].isna()].copy()
    # Unify method names; methods missing from this mapping become NaN and
    # are dropped just below
    df["method"] = df["method"].map(
        {
            "metis": "METIS",
            "LPA-ig": "LPA",
            "spectral": "SC",
            "louvain-ig": "Louvain",
            "louvain-nx": "Louvain",
            "eigen-ig": "LE",
            "infomap-ig": "Infomap",
            "leiden-ig": "Leiden",
        }
    )
    df = df[~df["method"].isna()].copy()
    # Baseline results store dataset names without the "_lcc" suffix
    df["data"] = df["data"] + "_lcc"
    df = df[df["data"].isin(datasets)]
    df = df[["data", "method", "time", "memory"]]

    # Merge results
    df = pd.concat([df, sepal_df]).reset_index(drop=True)

    # Mean and std per (dataset, method); produces two-level column labels
    df = df.groupby(["data", "method"], as_index=False).agg(
        {"time": ["mean", "std"], "memory": ["mean", "std"]}
    )

    # Add info about methods
    connectivity = {
        "SEPAL": True,
        "METIS": False,
        "LPA": True,
        "SC": False,
        "Louvain": False,
        "LE": False,
        "Infomap": False,
        "Leiden": True,
    }
    df["Connectedness"] = df["method"].map(connectivity)

    bounds = {
        "SEPAL": True,
        "METIS": True,
        "LPA": False,
        "SC": False,
        "Louvain": False,
        "LE": False,
        "Infomap": False,
        "Leiden": False,
    }
    df["Size constraints"] = df["method"].map(bounds)

    df = df[
        ["data", "method", "Connectedness", "Size constraints", "time", "memory"]
    ].copy()

    # Convert memory to GB (convert_memory_to_gb divides by 1024, so memory is
    # presumably stored in MB — confirm against the experiment writers)
    df["memory"] = df["memory"].applymap(convert_memory_to_gb)

    # Make data readable
    df["data"] = df["data"].map(DATASETS_NAMES)

    # Rename columns
    df = df.rename(columns={"time": "Time", "memory": "RAM usage"})

    # Print one table per dataset. Times are shown in minutes on the large
    # graphs and in seconds otherwise (assumes times are stored in seconds).
    for data in df["data"].unique():
        subscores = df[df["data"] == data]
        subscores = subscores.drop("data", axis=1)
        if data in ["YAGO4.5", "YAGO4"]:
            subscores["Time"] = subscores["Time"].applymap(convert_time_to_minutes)
        else:
            subscores["Time"] = subscores["Time"].applymap(convert_time_to_seconds)
        print(
            subscores.to_latex(
                caption=f"Performance of clustering methods on {data}.",
                index=False,
            )
        )
    return


# Entity coverage table
def entity_coverage():
    """Print a LaTeX table of entity coverage per downstream table and KG.

    Rows are downstream tables (``target``) and columns are knowledge graphs
    (``source``, with YAGO4.5 excluded). Values come from the ``coverage``
    column and are printed with one decimal.
    """
    # Load entity coverage
    df = pd.read_parquet(SEPAL_DIR / "datasets/evaluation/entity_coverage.parquet")

    # Remove YAGO4.5
    df = df[df["source"] != "YAGO4.5"]

    table = df.pivot(index="target", columns="source", values="coverage")

    def one_decimal(x):
        return f"{x:.1f}"

    # to_latex requires exactly one formatter per column, so build the list
    # from the actual number of sources instead of assuming four.
    print(
        table.to_latex(
            caption="Percentage of rows of the downstream table whose corresponding entity is in the knowledge graph.",
            formatters=[one_decimal] * len(table.columns),
        )
    )
    return


# Downstream tables statistics
def downstream_statistics():
    """Print a LaTeX table with the number of rows of each downstream table."""
    # TARGETS maps a table path (relative to SEPAL_DIR) to its display name
    rows = []
    for table_path, table_name in TARGETS.items():
        n_rows = len(pd.read_parquet(SEPAL_DIR / table_path))
        rows.append((table_name, n_rows))

    summary = pd.DataFrame(rows, columns=["Downstream table", "Number of rows"])
    latex = summary.T.to_latex(caption="Number of rows in the downstream tables.")
    print(latex)


# Preprocessing functions


def _melt_realistic_scores(scores, side):
    """Flatten raw link-prediction scores to one row per evaluation setting.

    Shared by ``preprocess_pbg_tlp_scores`` and ``preprocess_tlp_scores``.
    Each ``head``/``tail``/``both`` cell is expanded with ``apply(pd.Series)``.
    Its ``realistic`` entry is then expanded again, and only rows of ``side``
    are kept. Metric columns are renamed to mr, mrr, r1, r10 and r50.

    Parameters
    ----------
    scores : pandas.DataFrame
        Must have columns id, filtered, sampled, num_negatives, head, tail and
        both. The side cells presumably hold nested dicts of metrics (inferred
        from the two expansions below — confirm against the score writer).
    side : str
        One of "head", "tail" or "both".

    Returns
    -------
    pandas.DataFrame
        Columns id, filtered, sampled, num_negatives, count, mr, mrr, r1, r10
        and r50.
    """
    config_columns = ["id", "filtered", "sampled", "num_negatives"]

    # One row per (id, setting, side); expand the side cell into rank types
    scores = pd.melt(
        scores,
        id_vars=config_columns,
        value_vars=["head", "tail", "both"],
        var_name="Side",
        ignore_index=True,
    )
    scores = pd.concat(
        [scores.drop(["value"], axis=1), scores["value"].apply(pd.Series)], axis=1
    )

    # Keep only the "realistic" rank type and expand it into metric columns
    scores = pd.melt(
        scores,
        id_vars=config_columns + ["Side"],
        value_vars=["realistic"],
        var_name="Type",
        ignore_index=True,
    )[config_columns + ["Side", "value"]]
    scores = pd.concat(
        [scores.drop(["value"], axis=1), scores["value"].apply(pd.Series)], axis=1
    )

    scores = scores[scores["Side"] == side]
    scores = scores[
        config_columns
        + [
            "count",
            "arithmetic_mean_rank",
            "inverse_harmonic_mean_rank",
            "hits_at_1",
            "hits_at_10",
            "hits_at_50",
        ]
    ]
    return scores.rename(
        columns={
            "arithmetic_mean_rank": "mr",
            "inverse_harmonic_mean_rank": "mrr",
            "hits_at_1": "r1",
            "hits_at_10": "r10",
            "hits_at_50": "r50",
        }
    )


def preprocess_pbg_tlp_scores(scores, side):
    """Flatten PyTorch-BigGraph link-prediction scores and attach run metadata.

    Rows are merged (left join on ``id``) with
    ``baselines/PBG/checkpoints_pbg.parquet``. ``method`` is set to
    "PyTorch-BigGraph" and ``total_time`` is copied from ``training_time``.
    """
    scores = _melt_realistic_scores(scores, side)

    # Merge with checkpoint info
    checkpoints_pbg = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    )
    scores = scores.merge(checkpoints_pbg, on="id", how="left", suffixes=("", "_x"))
    scores["method"] = "PyTorch-BigGraph"
    scores["total_time"] = scores["training_time"]

    return scores


def preprocess_tlp_scores(scores, side, filters=None):
    """Flatten SEPAL/baseline link-prediction scores and attach run metadata.

    Rows are merged (left join on ``id``) with ``checkpoints_sepal.parquet``.
    Runs partitioned with "sepal_subgraph" are labelled "SEPAL"; runs without
    partitioning get the display name of their ``embed_method``.

    Parameters
    ----------
    scores : pandas.DataFrame
        Raw scores (see ``_melt_realistic_scores`` for the expected columns).
    side : str
        One of "head", "tail" or "both".
    filters : dict, optional
        Maps a column name to the list of accepted values. Defaults to
        ``{"embed_method": ["distmult"]}``; ``None`` avoids a shared mutable
        default argument.
    """
    if filters is None:
        filters = {"embed_method": ["distmult"]}

    scores = _melt_realistic_scores(scores, side)

    # Merge with checkpoint info
    checkpoints = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")
    scores = scores.merge(checkpoints, on="id", how="left", suffixes=("", "_x"))
    scores.loc[scores["partitioning"] == "sepal_subgraph", "method"] = "SEPAL"
    scores.loc[scores["partitioning"].isna(), "method"] = scores["embed_method"].map(
        MODEL_NAMES
    )

    # Filter out experiments
    for k, v in filters.items():
        scores = scores[scores[k].isin(v)]

    return scores


def get_best_tlp_model(scores, grouping_variables):
    """Return the id of the best link-prediction model in each group.

    Each model (``id``) is scored by the sum of its ``mrr`` values over all of
    its rows. Within each group defined by ``grouping_variables`` (which must
    not contain ``"id"``), the id with the highest total is selected.

    Returns
    -------
    pandas.Series
        The selected ids (named ``id``), one per group.
    """
    group_keys = list(grouping_variables)
    # Select the "mrr" column explicitly; `.sum("mrr")` would only pass "mrr"
    # as the `numeric_only` flag.
    totals = scores.groupby(group_keys + ["id"], as_index=False)["mrr"].sum()
    best_rows = totals.groupby(group_keys)["mrr"].idxmax().values
    return totals.loc[best_rows, "id"]


def _load_sepal_downstream_scores():
    """Load non-PBG downstream scores, merged (left on ``id``) with checkpoint info.

    Adds the downstream_task, source_dataset, mean_score, std_scores and method
    columns, and fills missing ``embed_dim`` values from ``dim``.
    """
    # Load prediction scores and checkpoint info, then merge them
    scores = pd.read_parquet(SEPAL_DIR / "downstream_scores.parquet")
    checkpoints = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")
    scores = scores.merge(checkpoints, on="id", how="left", suffixes=("", "_x"))

    # Readable task/dataset names and summary statistics of the CV scores
    scores["downstream_task"] = scores["target_file"].map(TARGETS)
    scores["source_dataset"] = scores["triples_dir"].map(datasets_names)
    scores["mean_score"] = [arr.mean() for arr in scores["scores"]]
    scores["std_scores"] = [arr.std() for arr in scores["scores"]]

    # Set methods. Order matters: where conditions overlap, later assignments
    # overwrite earlier ones.
    scores.loc[scores["partitioning"] == "sepal_subgraph", "method"] = "SEPAL"
    scores.loc[scores["embed_method"] == "fastrp", "method"] = "FastRP"
    scores.loc[scores["embed_method"] == "random", "method"] = "Random"
    scores.loc[
        scores["core_prop"].isna() & (scores["embed_method"].isin(MODEL_NAMES.keys())),
        "method",
    ] = scores["embed_method"].map(MODEL_NAMES)

    # Distinguish SEPAL with and without partitioning
    scores.loc[
        (scores["method"] == "SEPAL") & (scores["num_subgraphs"] == 1), "method"
    ] = (scores["method"] + " (1 subgraph)")

    # Merge dim and embed_dim columns
    scores.loc[scores["embed_dim"].isna(), "embed_dim"] = scores["dim"]
    return scores


def _load_pbg_downstream_scores():
    """Load PyTorch-BigGraph downstream scores, merged with PBG checkpoint info."""
    pbg_scores = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/downstream_scores_pbg.parquet"
    )
    pbg_scores["downstream_task"] = pbg_scores["target_file"].map(TARGETS)
    pbg_scores["mean_score"] = [arr.mean() for arr in pbg_scores["scores"]]
    pbg_scores["source_dataset"] = pbg_scores["data"].map(DATASETS_NAMES)
    checkpoints_pbg = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    )
    pbg_scores = pbg_scores.merge(
        checkpoints_pbg, on="id", how="left", suffixes=("", "_x")
    )
    pbg_scores["method"] = "PyTorch-BigGraph"
    pbg_scores["total_time"] = pbg_scores["training_time"]
    # Label runs whose first relation uses the "diagonal" operator as DistMult
    pbg_scores.loc[
        pbg_scores["relations"].apply(lambda x: x[0]["operator"]) == "diagonal",
        "embed_method",
    ] = "distmult"
    pbg_scores["embed_dim"] = pbg_scores["dimension"]
    return pbg_scores


def make_score_df(filters=None):
    """Load all downstream scores (SEPAL-side runs and PyTorch-BigGraph).

    Parameters
    ----------
    filters : dict, optional
        Maps a column name to the list of accepted values; rows whose value is
        not in the list are dropped. Defaults to ``{"embed_dim": [100]}``.
        ``None`` avoids a shared mutable default argument.

    Returns
    -------
    pandas.DataFrame
        The concatenated, filtered scores with a fresh RangeIndex before
        filtering.
    """
    if filters is None:
        filters = {"embed_dim": [100]}

    # Concatenate SEPAL and PBG scores
    scores = pd.concat(
        [_load_sepal_downstream_scores(), _load_pbg_downstream_scores()]
    ).reset_index(drop=True)

    # Filter out experiments
    for k, v in filters.items():
        scores = scores[scores[k].isin(v)]

    return scores


def get_best_model(scores, grouping_variables):
    """Keep the best model of each group and sort rows by model performance.

    Each model (``id``) is scored by the sum of its ``mean_score`` values over
    all of its rows (e.g. over downstream tasks). Within each group defined by
    ``grouping_variables``, only the rows of the highest-scoring id are kept.

    Parameters
    ----------
    scores : pandas.DataFrame
        Must contain ``id``, ``mean_score`` and the grouping columns.
    grouping_variables : list of str
        Columns defining the groups; must not contain ``"id"``.

    Returns
    -------
    pandas.DataFrame
        The rows of ``scores`` belonging to the selected ids, sorted by
        decreasing total score of their id.
    """
    group_keys = list(grouping_variables)
    # Select "mean_score" explicitly; `.sum("mean_score")` would only pass the
    # string as the `numeric_only` flag.
    totals = scores.groupby(group_keys + ["id"], as_index=False)["mean_score"].sum()
    best_rows = totals.groupby(group_keys)["mean_score"].idxmax().values
    ids_to_keep = totals.loc[best_rows, "id"]
    scores = scores[scores["id"].isin(ids_to_keep)]

    # Sort by decreasing total score of each id (max over groups if an id
    # happens to appear in several groups)
    total_by_id = totals.groupby("id")["mean_score"].max()
    return scores.sort_values(
        "id", key=lambda ids: ids.map(total_by_id), ascending=False
    )


def make_scores_normalize(scores):
    """Divide each task's ``mean_score`` by the best score on that task.

    When a task's best score is positive, its best method scores 1 after
    normalization. The best score is used as-is: a zero best score produces
    inf/NaN, and a negative one inverts the ordering of that task's scores.
    Rows whose ``downstream_task`` is NaN are left unchanged.

    Parameters
    ----------
    scores : pandas.DataFrame
        Must contain ``downstream_task`` and ``mean_score`` columns.

    Returns
    -------
    pandas.DataFrame
        A normalized copy; the input frame is not modified.
    """
    # Work on a copy so callers passing a slice/view are never mutated
    scores = scores.copy()
    for task in scores["downstream_task"].unique():
        is_task = scores["downstream_task"] == task
        best_score = scores.loc[is_task, "mean_score"].max()
        scores.loc[is_task, "mean_score"] = (
            scores.loc[is_task, "mean_score"] / best_score
        )
    return scores


# Utility functions


def human_time_duration(seconds):
    """Format a duration in seconds with a single ms / s / min / h unit.

    The unit is chosen after rounding, so a value never shows up as
    "60 s", "60 min" or "1000 ms"; it moves to the next unit instead
    (e.g. 59.6 -> "1 min"). NaN input raises ValueError (from ``round``).
    """
    if seconds < 1 and round(seconds * 1000) < 1000:
        return f"{round(seconds * 1000)} ms"
    if round(seconds) < 60:
        return f"{round(seconds)} s"
    if round(seconds / 60) < 60:
        return f"{round(seconds / 60)} min"
    return f"{round(seconds / 3600)} h"


def human_memory_usage(memory):
    """Format a memory amount (presumably in MB) using MB below 1024, else GB.

    Values are shown with 3 significant figures. Values that would need
    scientific notation at that precision (>= 999.5 in the chosen unit, e.g.
    1010 MB) are rounded to an integer instead.
    """
    if memory < 1024:
        value, unit = memory, "MB"
    else:
        value, unit = memory / 1024, "GB"
    if abs(value) >= 999.5:
        return f"{value:.0f} {unit}"
    return f"{value:.3g} {unit}"


def convert_memory_to_gb(memory):
    """Format a memory amount (presumably in MB) as GB with 3 significant figures.

    Values that would round to 1000 GB or more are printed as integers, to
    avoid the scientific notation that ``.3g`` produces for them.
    """
    gigabytes = memory / 1024
    if abs(gigabytes) >= 999.5:
        return f"{gigabytes:.0f} GB"
    return f"{gigabytes:.3g} GB"


def convert_time_to_minutes(time):
    """Format a duration in seconds as minutes with 3 significant figures.

    Values that would round to 1000 min or more are printed as integers, to
    avoid the scientific notation that ``.3g`` produces for them.
    """
    minutes = time / 60
    if abs(minutes) >= 999.5:
        return f"{minutes:.0f} min"
    return f"{minutes:.3g} min"


def convert_time_to_seconds(time):
    """Format a duration in seconds with 3 significant figures.

    Values that would round to 1000 s or more are printed as integers, to
    avoid the scientific notation that ``.3g`` produces for them
    (e.g. "1.23e+03 s").
    """
    if abs(time) >= 999.5:
        return f"{time:.0f} s"
    return f"{time:.3g} s"