import pandas as pd
import numpy as np
from scipy.interpolate import CubicSpline
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import matplotlib.gridspec as gridspec
import seaborn as sns
from adjustText import adjust_text
import pickle
import regex as re

from SEPAL.downstream_evaluation import TARGETS, DATASETS_NAMES
from SEPAL import SEPAL_DIR

pastel_palette = sns.color_palette("pastel")

# Colours are drawn from seaborn's "husl" palette. Matplotlib named colours
# (https://matplotlib.org/stable/gallery/color/named_colors.html) or a
# ListedColormap such as "Pastel1" are possible alternatives.
colors = sns.color_palette("husl", n_colors=4).as_hex()

# One colour per downstream (evaluation) task.
downstream_palette = dict(
    zip(
        ("US elections", "Housing prices", "US accidents", "Movie revenues"),
        colors,
    )
)

my_colors = sns.color_palette("husl", n_colors=5).as_hex()

# One colour per embedding method.
models_palette = dict(
    zip(("SEPAL", "Relational", "Contextual", "Random", "PBG"), my_colors)
)


# Path string of each knowledge-graph folder (under SEPAL_DIR) -> display name.
datasets_names = {
    f"{SEPAL_DIR}/datasets/knowledge_graphs/{folder}": display_name
    for folder, display_name in DATASETS_NAMES.items()
}

# Lowercase model identifier -> display name.
MODEL_NAMES = dict(
    transe="TransE",
    distmult="DistMult",
    mure="MuRE",
    tucker="TuckER",
    hole="HolE",
)


def make_score_df(filters=None):
    """
    Build the table of downstream scores enriched with checkpoint metadata.

    Reads ``downstream_scores.parquet`` and ``checkpoints_sepal.parquet`` from
    ``SEPAL_DIR``, left-joins them on "id", derives display columns
    ("downstream_task", "source_dataset", "mean_score", "std_scores",
    "method", "emb_model_name") and finally filters the rows.

    Parameters
    ----------
    filters : dict, optional
        Mapping ``column -> value``; only rows where every column equals its
        value are kept. Defaults to ``{"embed_dim": 100}``.

    Returns
    -------
    pandas.DataFrame
        The preprocessed and filtered scores.
    """
    # A None sentinel avoids sharing one mutable dict between calls.
    if filters is None:
        filters = {"embed_dim": 100}

    # Load prediction scores
    scores = pd.read_parquet(SEPAL_DIR / "downstream_scores.parquet")

    # Load checkpoint info
    checkpoints = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")

    # Merge scores and checkpoint info (clashing checkpoint columns get "_x")
    scores = scores.merge(checkpoints, on="id", how="left", suffixes=("", "_x"))

    # Preprocess scores: readable task / dataset names and per-run statistics
    scores["downstream_task"] = scores["target_file"].map(TARGETS)
    scores["source_dataset"] = scores["triples_dir"].map(datasets_names)
    scores["mean_score"] = [arr.mean() for arr in scores["scores"]]
    scores["std_scores"] = [arr.std() for arr in scores["scores"]]

    # Set Methods
    scores.loc[scores["partitioning"] == "sepal_subgraph", "method"] = "SEPAL"
    scores.loc[scores["emb_model_name"] == "fastrp", "method"] = "Contextual"
    scores.loc[scores["emb_model_name"] == "random", "method"] = "Random"
    scores.loc[
        scores["core_prop"].isna()
        & (scores["emb_model_name"].isin(["distmult", "transe", "mure", "tucker", "hole"])),
        "method",
    ] = "Relational"

    # Set embedding models' names
    scores.loc[scores["emb_model_name"] == "fastrp", "emb_model_name"] = "FastRP"
    scores.loc[scores["emb_model_name"] == "random", "emb_model_name"] = (
        "Gaussian Noise"
    )
    scores.loc[scores["method"] == "SEPAL", "emb_model_name"] = (
        scores["method"] + " " + scores["embed_method"].map(MODEL_NAMES)
    )
    scores.loc[scores["method"] == "Relational", "emb_model_name"] = scores[
        "emb_model_name"
    ].map(MODEL_NAMES)

    # Distinguish SEPAL with and without partitioning
    scores.loc[
        (scores["method"] == "SEPAL") & (scores["num_subgraphs"] == 1), "method"
    ] = (scores["method"] + " (no partition)")

    # Merge dim and embed_dim columns: fall back on "dim" where "embed_dim" is missing
    scores.loc[scores["embed_dim"].isna(), "embed_dim"] = scores["dim"]

    # Filter out experiments
    for k, v in filters.items():
        scores = scores[scores[k] == v]

    return scores


def get_best_model(scores, grouping_variables):
    """
    Keep the rows of the best performing model of each group, best first.

    The "mean_score" values are summed per (group, id). Within each group the
    id with the largest total is selected, and only the rows of ``scores``
    belonging to selected ids are returned, sorted by decreasing total.

    Parameters
    ----------
    scores : pandas.DataFrame
        Must contain "id", "mean_score" and every column listed in
        ``grouping_variables``.
    grouping_variables : list of str
        Columns defining the groups. Should not contain "id".

    Returns
    -------
    pandas.DataFrame
        The filtered and sorted rows of ``scores``.
    """
    group_keys = list(grouping_variables)

    # Total score of every (group, id). The column is selected explicitly:
    # the first positional argument of GroupBy.sum is `numeric_only`, not a
    # column name.
    df2 = scores.groupby(group_keys + ["id"], as_index=False)["mean_score"].sum()

    # Row label of the best total inside each group -> id to keep
    best_rows = df2.groupby(group_keys)["mean_score"].idxmax().values
    ids_to_keep = df2.loc[best_rows, "id"]
    scores = scores[scores["id"].isin(ids_to_keep)]

    # Sort results. Ascending order means that, if an id appears more than
    # once, its largest total is the one retained in the lookup table.
    df2 = df2.sort_values("mean_score", ascending=True)
    total_by_id = dict(zip(df2["id"], df2["mean_score"]))
    scores = scores.sort_values(
        "id",
        key=lambda s: s.map(total_by_id),
        ascending=False,
    )
    return scores


def make_scores_relative(scores):
    """
    Rescale "mean_score" so that the best score of each downstream task is 1.

    The dataframe is modified in place and also returned.
    """
    for task_name in scores["downstream_task"].unique():
        # Rows belonging to the current task
        is_task = scores["downstream_task"] == task_name

        # Highest score reached on that task
        top_score = scores[is_task]["mean_score"].max()

        # Divide the scores of that task by its best score
        scores.loc[is_task, "mean_score"] = scores["mean_score"] / top_score
    return scores


## Times

# Time distribution for SEPAL
def plot_sepal_times():
    """
    Plot the time spent in each SEPAL stage for every source dataset.

    Uses the best SEPAL run of each source dataset, draws the total, subgraph,
    embedding and propagation times on a log-scaled y-axis and saves the
    figure to ``figures/sepal_times.pdf`` under ``SEPAL_DIR``.
    """
    # Best SEPAL run per source dataset, in reversed ranking order
    sepal_scores = make_score_df()
    sepal_scores = sepal_scores[sepal_scores["method"] == "SEPAL"]
    sepal_scores = get_best_model(sepal_scores, ["source_dataset"])[::-1]

    # One row of timings per run
    time_columns = ["total_time", "subgraph_time", "embed_time", "propagation_time"]
    times = sepal_scores[time_columns + ["source_dataset"]].drop_duplicates(
        ignore_index=True
    )

    # Solid line for the total, dashed lines for the individual stages
    shared_style = dict(x="source_dataset", logy=True, marker="o", markersize=5)
    ax = times.plot(
        y=["total_time"],
        linewidth=2,
        zorder=3,
        figsize=(3, 5),
        **shared_style,
    )
    times.plot(
        y=time_columns[1:],
        linestyle="--",
        linewidth=1,
        alpha=0.6,
        ax=ax,
        **shared_style,
    )

    # Axis labels
    ax.set_xlabel("")
    ax.set_ylabel("Time (s)", fontsize=14)

    # Legend placed below the axes
    ax.legend(
        labels=["Total", "Subgraph", "Embedding", "Propagation"],
        bbox_to_anchor=(0.45, -0.25),
        loc="lower center",
        borderaxespad=0.0,
        ncol=2,
        fontsize=10,
    )

    # Remove the top and right sides of the box
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    # Customize ticks
    ax.tick_params(axis="both", which="both", labelsize=14)

    # Keep only the x ticks that carry a label
    ax.set_xticks([0, 1, 2])

    # Grid on the major y ticks
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)

    # Fainter horizontal lines on the minor graduations of the log scale
    low_limit, high_limit = ax.get_ylim()
    for y_position in ax.get_yticks(minor=True):
        if low_limit < y_position < high_limit:
            ax.axhline(
                y_position, color="gray", linestyle="--", linewidth=0.5, alpha=0.2
            )

    plt.savefig(SEPAL_DIR / "figures/sepal_times.pdf", bbox_inches="tight")


# Overall time of different methods
def plot_times():
    """
    Plot the total computation time of each embedding model per source dataset.

    Combines the best run of every (source dataset, embedding model) pair with
    the PBG checkpoints, draws the times on a log-scaled y-axis and saves the
    figure to ``figures/computation_times.pdf`` under ``SEPAL_DIR``.
    """
    # Best run of every (dataset, model) pair, without Core Yago4
    best_scores = make_score_df()
    best_scores = get_best_model(best_scores, ["source_dataset", "emb_model_name"])
    best_scores = best_scores[best_scores["source_dataset"] != "Core Yago4"]

    # PBG checkpoints, expressed with the same columns as the scores
    pbg_info = pd.read_parquet(SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet")
    pbg_info = pbg_info.assign(
        total_time=pbg_info["training_time"],
        emb_model_name="PBG",
        source_dataset=pbg_info["data"].map(DATASETS_NAMES),
    )
    pbg_info = pbg_info[pbg_info["subset"] == "all"]

    # Stack both sources, drop duplicates and sort by dataset name
    columns = ["total_time", "emb_model_name", "source_dataset"]
    times = (
        pd.concat([best_scores[columns], pbg_info[columns]])
        .drop_duplicates(ignore_index=True)
        .sort_values(by="source_dataset")
    )

    # Figure with a log-scaled y-axis
    plt.figure(figsize=(3, 5))
    plt.yscale("log")

    # First pass draws the markers only, second pass the faint connecting lines
    line_options = dict(
        data=times,
        x="source_dataset",
        y="total_time",
        hue="emb_model_name",
        sort=False,
        dashes=False,
    )
    ax = sns.lineplot(
        style="emb_model_name",
        linestyle="",
        markers=["o", "v", "v", "s"],
        markersize=7,
        alpha=1,
        **line_options,
    )
    sns.lineplot(alpha=0.3, legend=False, ax=ax, **line_options)

    # Set labels
    plt.xlabel("")
    plt.ylabel("Time (s)", fontsize=14)

    # Legend placed below the axes
    plt.legend(
        title="",
        bbox_to_anchor=(0.45, -0.25),
        loc="lower center",
        borderaxespad=0.0,
        ncol=2,
        fontsize=10,
    )

    # Remove box
    sns.despine()

    # Customize ticks
    plt.tick_params(axis="both", which="both", labelsize=14)

    # Grid on the major y ticks
    plt.grid(axis="y", linestyle="--", alpha=0.7)

    # Fainter horizontal lines on the minor graduations of the log scale
    low_limit, high_limit = ax.get_ylim()
    for y_position in ax.get_yticks(minor=True):
        if low_limit < y_position < high_limit:
            ax.axhline(
                y_position, color="gray", linestyle="--", linewidth=0.5, alpha=0.2
            )

    plt.savefig(SEPAL_DIR / "figures/computation_times.pdf", bbox_inches="tight")


## Downstream evaluation

# Stacked horizontal barplots
def hbar(relative_scores=True):
    """
    Draw stacked horizontal bar plots of the downstream scores.

    The figure has one outer panel per source dataset and, inside it, one
    inner panel per method. Each bar is an embedding model and stacks its mean
    score on the four downstream tasks. The figure is saved under
    ``SEPAL_DIR / "figures"`` as ``relative_hbar.pdf`` or ``hbar.pdf``.

    Parameters
    ----------
    relative_scores : bool, default True
        If True, scores are first rescaled with ``make_scores_relative`` and
        the axis label, annotation position and file name change accordingly.
    """
    # Load scores
    scores = make_score_df()
    # Strip the "SEPAL " prefix from the model names
    scores["emb_model_name"] = scores["emb_model_name"].str.removeprefix("SEPAL ")

    # Load PBG scores
    pbg_scores = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/downstream_scores_pbg.parquet"
    )
    pbg_scores["downstream_task"] = pbg_scores["target_file"].map(TARGETS)
    # NOTE(review): every PBG run is labelled "DistMult" — presumably the model
    # PBG was trained with; confirm.
    pbg_scores["emb_model_name"] = "DistMult"
    pbg_scores["mean_score"] = [arr.mean() for arr in pbg_scores["scores"]]
    pbg_scores["source_dataset"] = pbg_scores["data"].map(DATASETS_NAMES)
    pbg_scores["method"] = "PBG"

    # Make data to plot
    features = [
        "downstream_task",
        "emb_model_name",
        "mean_score",
        "source_dataset",
        "method",
        "id",
    ]
    df = pd.concat([scores[features], pbg_scores[features]]).reset_index(drop=True)

    # Make relative scores if necessary
    if relative_scores:
        df = make_scores_relative(df)

    # Keep best performing embeddings for each method
    df = get_best_model(df, ["emb_model_name", "source_dataset", "method"])

    # Remove Core Yago4
    df = df[df["source_dataset"] != "Core Yago4"]

    ## Plot the results
    # Source datasets in reversed order of appearance
    sources = df["source_dataset"].unique()[::-1]
    n_sources = len(sources)
    # Column (stacking) order of the downstream tasks
    order = ["Movie revenues", "US accidents", "US elections", "Housing prices"]

    fig = plt.figure(figsize=(8, 9))
    # Outer panel heights: number of rows of each source divided by the number of tasks
    height_ratios = [(df["source_dataset"] == s).sum() / len(order) for s in sources]
    outer = gridspec.GridSpec(
        nrows=n_sources, ncols=1, height_ratios=height_ratios, figure=fig, hspace=0.3
    )

    for i in range(n_sources):
        # Methods present for this source dataset; one inner panel each
        groups = df[df["source_dataset"] == sources[i]]["method"].unique()
        n_groups = len(groups)
        # Inner panel heights: number of rows of each method in this source
        heights = [
            ((df["source_dataset"] == sources[i]) & (df["method"] == groups[j])).sum()
            for j in range(n_groups)
        ]
        inner = gridspec.GridSpecFromSubplotSpec(
            nrows=n_groups,
            ncols=1,
            subplot_spec=outer[i],
            hspace=0.2,
            height_ratios=heights,
        )

        for j in range(n_groups):
            df_to_plot = df[
                (df["source_dataset"] == sources[i]) & (df["method"] == groups[j])
            ]
            # One row per embedding model, one column per downstream task
            df_to_plot = pd.pivot_table(
                data=df_to_plot,
                index=["emb_model_name"],
                columns=["downstream_task"],
                values="mean_score",
                sort=False,
            )
            ax = plt.Subplot(fig, inner[j])
            # Title only on the first inner panel of each source dataset
            if j == 0:
                ax.set_title(sources[i], fontweight="bold")
            # x label only on the very last panel of the figure
            if j == n_groups - 1 and i == n_sources - 1:
                if relative_scores:
                    ax.set_xlabel(
                        "Relative cumulative mean cross-validation score (R2)"
                    )
                else:
                    ax.set_xlabel("Cumulative mean cross-validation score (R2)")

            # Columns in `order`, rows reversed before plotting
            df_to_plot[order][::-1].plot(
                ax=ax,
                kind="barh",
                stacked=True,
                ylabel="Mean cross-validation score",
                xlabel="",
                rot=0,
                edgecolor="black",
                alpha=0.7,
            )
            # Replace the y label by the numbered method name
            ax.set_ylabel(
                f"{j+1}. " + groups[j], weight="bold", rotation=0, ha="left", color="k"
            )
            # Height of the inner cell in figure coordinates (y1 - y0)
            box_height = inner[j].get_position(fig).get_points()[1, 1] - inner[j].get_position(fig).get_points()[0, 1]
            ax.yaxis.set_label_coords(-0.45, 5*box_height)
            ax.get_legend().remove()
            # Dotted vertical line at x = 0
            ax.vlines(x=0, ymin=-1, ymax=10, linestyles="dotted", colors="k")
            # The first axes is the reference; all others share its x-axis
            if i == 0 and j == 0:
                a_ref = ax
            else:
                ax.sharex(a_ref)
            fig.add_subplot(ax)

    # Single figure-level legend built from the reference axes
    handles, labels = a_ref.get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        title="Evaluation dataset",
        loc="upper center",
        bbox_to_anchor=(0.5, 0.05),
        fancybox=True,
        ncol=4,
    )

    # "Rankings" header; its x position depends on the score scale
    if relative_scores:
        a_ref.annotate(
            "Rankings",
            xy=(0, 0),
            xytext=(-2.1, 2.2),
            color="k",
            fontsize=14,
            weight="bold",
        )
    else:
        a_ref.annotate(
            "Rankings",
            xy=(0, 0),
            xytext=(-1.14, 2.2),
            color="k",
            fontsize=14,
            weight="bold",
        )

    # Save the figure
    if relative_scores:
        plt.savefig(SEPAL_DIR / "figures/relative_hbar.pdf", bbox_inches="tight")
    else:
        plt.savefig(SEPAL_DIR / "figures/hbar.pdf", bbox_inches="tight")

    return


# Waterfall chart


## Link prediction

def preprocess_tlp_scores(scores, side, filters=None):
    """
    Flatten raw link-prediction results and attach checkpoint metadata.

    The "head", "tail" and "both" columns of ``scores`` hold nested results.
    They are unstacked into a "Side" column, the "realistic" entry of each
    result is expanded into one column per metric, and the rows of the
    requested side are kept. The result is left-joined on "id" with
    ``checkpoints_sepal.parquet`` and a "method" display column is added.

    Parameters
    ----------
    scores : pandas.DataFrame
        Raw results with columns "id", "filtered", "sampled", "num_negatives",
        "head", "tail" and "both".
    side : str
        Prediction side to keep: "head", "tail" or "both".
    filters : dict, optional
        Mapping ``column -> collection of accepted values``; rows whose value
        is not in the collection are dropped. Defaults to no filtering.

    Returns
    -------
    pandas.DataFrame
        One row per run, with the metrics renamed to "mr", "mrr", "r1", "r10"
        and "r50".
    """
    # A None sentinel avoids sharing one mutable dict between calls.
    if filters is None:
        filters = {}

    run_columns = ["id", "filtered", "sampled", "num_negatives"]

    # One row per (run, side); "value" holds the nested results of that side
    scores = pd.melt(
        scores,
        id_vars=run_columns,
        value_vars=["head", "tail", "both"],
        var_name="Side",
        ignore_index=True,
    )
    # Expand the nested results into one column per entry
    scores = pd.concat(
        [scores.drop(["value"], axis=1), scores["value"].apply(pd.Series)], axis=1
    )
    # Keep only the "realistic" entry
    scores = pd.melt(
        scores,
        id_vars=run_columns + ["Side"],
        value_vars=["realistic"],
        var_name="Type",
        ignore_index=True,
    )[run_columns + ["Side", "value"]]
    # Expand that entry into one column per metric
    scores = pd.concat(
        [scores.drop(["value"], axis=1), scores["value"].apply(pd.Series)], axis=1
    )
    scores = scores[scores["Side"] == side]
    scores = scores[
        run_columns
        + [
            "count",
            "arithmetic_mean_rank",
            "inverse_harmonic_mean_rank",
            "hits_at_1",
            "hits_at_10",
            "hits_at_50",
        ]
    ]
    # Rename without inplace=True: `scores` is a filtered slice at this point,
    # and modifying it in place triggers SettingWithCopyWarning.
    scores = scores.rename(
        columns={
            "arithmetic_mean_rank": "mr",
            "inverse_harmonic_mean_rank": "mrr",
            "hits_at_1": "r1",
            "hits_at_10": "r10",
            "hits_at_50": "r50",
        }
    )

    # Merge with checkpoint info
    checkpoints = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")
    scores = scores.merge(checkpoints, on="id", how="left", suffixes=("", "_x"))
    scores.loc[scores["partitioning"] == "sepal_subgraph", "method"] = (
        "SEPAL (" + scores["embed_method"].map(MODEL_NAMES) + ")"
    )
    scores.loc[scores["partitioning"].isna(), "method"] = scores["embed_method"].map(
        MODEL_NAMES
    )

    # Filter out experiments
    for k, v in filters.items():
        scores = scores[scores[k].isin(v)]
    return scores


def get_best_tlp_model(scores, grouping_variables):
    """
    Return the id of the model with the highest total MRR in each group.

    Parameters
    ----------
    scores : pandas.DataFrame
        Must contain "id", "mrr" and every column in ``grouping_variables``.
    grouping_variables : list of str
        Columns defining the groups. Should not contain "id".

    Returns
    -------
    pandas.Series
        The selected ids, one per group.
    """
    group_keys = list(grouping_variables)

    # Total MRR of every (group, id). The column is selected explicitly: the
    # first positional argument of GroupBy.sum is `numeric_only`, not a
    # column name.
    totals = scores.groupby(group_keys + ["id"], as_index=False)["mrr"].sum()

    # Row label of the best total inside each group
    best_rows = totals.groupby(group_keys)["mrr"].idxmax().values
    return totals.loc[best_rows, "id"]

# Transductive
def plot_tlp_results(side="tail"):
    """
    Print LaTeX tables of the transductive link-prediction results.

    The best model of each (data, method) pair is selected on the validation
    scores and reported with its test scores, next to the PBG baseline. One
    table per dataset is printed to standard output.

    Parameters
    ----------
    side : str, default "tail"
        Prediction side reported in the tables.
    """
    # Load results
    val_scores = pd.read_parquet(SEPAL_DIR / "val_lp_scores.parquet")
    test_scores = pd.read_parquet(SEPAL_DIR / "test_lp_scores.parquet")

    # Preprocess scores
    val_scores = preprocess_tlp_scores(val_scores, side)
    test_scores = preprocess_tlp_scores(test_scores, side)

    # Select best performing models on validation set
    ids = get_best_tlp_model(val_scores, ["data", "method"])
    test_scores = test_scores[test_scores.id.isin(ids)]

    # Load PBG results and expand the nested "metrics" column
    scores_pbg = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/test_lp_scores_pbg.parquet"
    )
    scores_pbg = pd.concat(
        [scores_pbg.drop(["metrics"], axis=1), scores_pbg["metrics"].apply(pd.Series)],
        axis=1,
    )
    # Copy the filtered frame so the column assignment below does not trigger
    # SettingWithCopyWarning.
    scores_pbg = scores_pbg[scores_pbg["sampled"]].copy()
    scores_pbg["method"] = "PyTorch-BigGraph"
    checkpoints_pbg = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    )
    scores_pbg = scores_pbg.merge(
        checkpoints_pbg, on="id", how="left", suffixes=("", "_x")
    )
    scores_pbg["total_time"] = scores_pbg["training_time"]
    scores_pbg = scores_pbg.rename(columns={"pos_rank": "mr"})
    scores_pbg = scores_pbg[scores_pbg["side"] == side]

    # Merge results ("total_time" is computed above but not reported)
    features = [
        "data",
        "method",
        "count",
        "mr",
        "mrr",
        "r1",
        "r10",
        "r50",
    ]
    scores = pd.concat([test_scores[features], scores_pbg[features]]).reset_index(drop=True)

    # One table per dataset: methods as rows, metrics as columns
    for data in scores["data"].unique():
        subscores = scores[scores["data"] == data]
        subscores = pd.melt(
            subscores,
            id_vars=["method"],
            value_vars=["mr", "mrr", "r1", "r10", "r50"],
            var_name="metric",
        )
        subscores = subscores.pivot(index="method", columns="metric", values="value")
        print(
            subscores.to_latex(
                float_format="%.4g",
                caption=f"Transductive link prediction results on {DATASETS_NAMES[data]}, based on realistic ranks among sampled negatives.",
            )
        )

    return



## Summary
# Radar/spider chart
def spider_chart():
    """Placeholder for a radar/spider chart: does nothing and returns None."""
    return None


# Pareto optimality
def pareto_plot(relative_scores=True):
    """
    Scatter cumulative score against computation time with Pareto frontiers.

    Every run is one point (x: total time on a log scale, y: sum of its mean
    scores). For each source dataset the Pareto frontier, i.e. the runs that
    beat every faster run, is drawn as a dotted step line. The figure is saved
    under ``SEPAL_DIR / "figures"`` as ``relative_pareto_plot.pdf`` or
    ``pareto_plot.pdf``.

    Parameters
    ----------
    relative_scores : bool, default True
        If True, scores are first rescaled with ``make_scores_relative``.
    """
    # Load scores
    scores = make_score_df()

    # Make relative scores if necessary
    if relative_scores:
        scores = make_scores_relative(scores)

    # Cumulative score of every run. The column is selected explicitly: the
    # first positional argument of GroupBy.sum is `numeric_only`, not a
    # column name.
    data = scores.groupby(
        ["method", "source_dataset", "id", "total_time"], as_index=False
    )["mean_score"].sum()[["method", "source_dataset", "mean_score", "total_time"]]
    data = data.rename(columns={"method": "Method", "source_dataset": "Dataset"})

    # Sort the DataFrame by 'total_time' in ascending order for each dataset
    sorted_data = data.sort_values(by=["Dataset", "total_time"], ascending=[True, True])

    # Compute Pareto frontier points for each dataset: walking from the
    # fastest to the slowest run, keep each run that improves on the best
    # score seen so far.
    pareto_frontiers = {}
    for dataset, group in sorted_data.groupby("Dataset"):
        pareto_frontier_points = []
        max_score = float("-inf")

        for total_time, mean_score in zip(group["total_time"], group["mean_score"]):
            if mean_score > max_score:
                max_score = mean_score
                pareto_frontier_points.append((total_time, mean_score))

        pareto_frontiers[dataset] = pareto_frontier_points

    # create a matplotlib figure
    fig, ax = plt.subplots(1, 1, figsize=(6, 5))

    # Set the log scale for the x-axis
    plt.xscale("log")

    # create an axes-level plot
    sns.scatterplot(
        data=data,
        x="total_time",
        y="mean_score",
        hue="Method",
        style="Dataset",
        s=30,
        ax=ax,
        alpha=0.8,
    )
    sns.despine()
    sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc="center left", frameon=False)

    # Vertical offset of the frontier labels, in score units
    label_drop = 0.1 if relative_scores else 0.05

    # Draw Pareto frontiers
    for dataset, pareto_points in pareto_frontiers.items():
        # A dataset whose scores are all missing has no frontier to draw
        if not pareto_points:
            continue
        pareto_x, pareto_y = zip(*pareto_points)
        plt.step(
            pareto_x,
            pareto_y,
            color="black",
            linestyle=":",
            label=f"{dataset} Pareto Frontier",
            where="post",
            lw=1,
        )
        # The label is anchored on the second-to-last frontier point, or on
        # the only point when the frontier has a single one.
        anchor = -2 if len(pareto_points) >= 2 else -1
        plt.text(
            pareto_x[anchor] * 1.5,
            pareto_y[anchor] - label_drop,
            f"{dataset} PF",
            fontsize=8,
        )

    # Customize ticks
    plt.tick_params(axis="both", which="both", labelsize=14)

    # Add labels and legend
    plt.xlabel("Computation time (s)", fontsize=14)
    if relative_scores:
        plt.ylabel("Relative cumulative R2 score", fontsize=14)
    else:
        plt.ylabel("Cumulative R2 score", fontsize=14)

    # Save the plot
    if relative_scores:
        plt.savefig(SEPAL_DIR / "figures/relative_pareto_plot.pdf", bbox_inches="tight")
    else:
        plt.savefig(SEPAL_DIR / "figures/pareto_plot.pdf", bbox_inches="tight")

    return


## Model analysis

# Core subgraph coverage
def plot_core_coverage():
    """
    Plot downstream coverage against core proportion for each knowledge graph.

    Reads ``datasets/evaluation/core_coverage.parquet``, draws one panel per
    value of the "data" column on a 2x2 grid (one line per downstream table),
    marks the sizes of the smaller graphs with dotted vertical lines and saves
    the figure to ``figures/core_coverage.pdf`` under ``SEPAL_DIR``.
    """

    def mark_size(ax, x, color, ymax=100):
        # Dotted vertical line at a reference graph size
        ax.vlines(x=x, ymin=0, ymax=ymax, linestyles="dotted", colors=color)

    def note(ax, text, xy, color):
        # Small coloured text naming a reference line
        ax.annotate(text, xy=xy, color=color, fontsize=8)

    # Panel titles
    data_names = {
        "mini_yago3_lcc": "Mini Yago3",
        "yago3_lcc": "Yago3",
        "yago4_lcc": "Yago4",
        "yago4_with_full_ontology": "Yago4 and taxonomy",
    }

    # Load results and express both quantities as percentages
    core_coverage = pd.read_parquet(
        SEPAL_DIR / "datasets/evaluation/core_coverage.parquet"
    )
    for column in ("proportion", "coverage"):
        core_coverage[column] *= 100

    # 2x2 grid of panels sharing both axes
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 6), sharex=True, sharey=True)

    # One panel per knowledge graph
    for i, data_value in enumerate(core_coverage["data"].unique()):
        row, col = divmod(i, 2)
        ax = axes[row, col]

        # One line per downstream table for this knowledge graph
        sns.lineplot(
            x="proportion",
            y="coverage",
            hue="downstream_table",
            data=core_coverage[core_coverage["data"] == data_value],
            ax=ax,
        )

        # Vertical lines and annotations
        if data_value == "mini_yago3_lcc":
            mark_size(ax, 100, "lightsalmon")
            note(ax, "Mini Yago3\n           size", (76, 90), "lightsalmon")
        elif data_value == "yago3_lcc":
            mark_size(ax, 100 * 129493 / 2570716, "lightsalmon")
            note(ax, "Mini Yago3 size", (-3, 100), "lightsalmon")
            mark_size(ax, 100, "plum")
            note(ax, "Yago3\n   size", (85, 1), "plum")
        elif data_value == "yago4_lcc":
            mark_size(ax, 100 * 2570716 / 37959466, "plum", ymax=96)
            note(ax, "Yago3 size", (-2, 99), "plum")
            mark_size(ax, 100, "tan")
            note(ax, "Yago4\n   size", (85, 1), "tan")
        elif data_value == "yago4_with_full_ontology":
            mark_size(ax, 100 * 2570716 / 66915744, "plum")
            note(ax, "Yago3 size", (-3, 100), "plum")
            mark_size(ax, 100 * 37959466 / 66915744, "tan")
            note(ax, "Yago4\nsize", (59, 1), "tan")

        # Keep the legend entries for the shared legend, then drop the local one
        handles, labels = ax.get_legend_handles_labels()
        ax.get_legend().remove()

        # Axis labels and title
        ax.set_xlabel("Core proportion (%)")
        ax.set_ylabel("Downstream coverage (%)")
        ax.set_title(data_names[data_value])

    # Common legend below the panels (entries of the last panel drawn)
    fig.legend(
        handles=handles,
        labels=labels,
        title="Downstream Table",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, -0.08),
        fancybox=True,
    )

    # Adjust layout
    plt.tight_layout()

    # Save plot
    plt.savefig(SEPAL_DIR / "figures/core_coverage.pdf", bbox_inches="tight")


def yago4_downstream_concentration():
    """
    Plot coverage's derivative and identify high downstream concentration areas in the graph.

    For each downstream table of the "yago4_lcc" graph, the derivative of the
    coverage with respect to the core proportion is computed by finite
    differences, the mean growth is subtracted, and the result is interpolated
    with a cubic spline. The intervals where the spline is positive are shaded
    and summarised in a bottom panel. The figure is saved to
    ``figures/yago4_downstream_concentration.pdf`` under ``SEPAL_DIR``.
    """
    # Load results
    core_coverage = pd.read_parquet(
        SEPAL_DIR / "datasets/evaluation/core_coverage.parquet"
    ).drop_duplicates(ignore_index=True)
    core_coverage["proportion"] *= 100
    core_coverage["coverage"] *= 100
    core_coverage = core_coverage[core_coverage["data"] == "yago4_lcc"]
    tables = core_coverage["downstream_table"].unique()

    # Set up the figure: a 2x2 grid of curves on top, a summary panel below
    fig = plt.figure(figsize=(6, 8), constrained_layout=True)
    outer = gridspec.GridSpec(
        nrows=2, ncols=1, hspace=0.1, height_ratios=[4, 1], figure=fig
    )
    inner = gridspec.GridSpecFromSubplotSpec(
        nrows=2, ncols=2, subplot_spec=outer[0], height_ratios=[1, 1]
    )

    # Set color cycle. pyplot.get_cmap is used because matplotlib.cm.get_cmap
    # was removed in Matplotlib 3.9.
    fill_colors = plt.get_cmap("Pastel1").colors
    markercolors = ["tomato", "royalblue", "g", "purple"]

    # Iterate over unique 'downstream_table' values
    intervals = []
    for i, downstream_table_value in enumerate(tables):
        # Calculate subplot position
        row = i // 2
        col = i % 2
        ax = plt.Subplot(fig, inner[row, col])

        # Filter DataFrame for the specific 'downstream_table' value
        subset = core_coverage[
            core_coverage["downstream_table"] == downstream_table_value
        ][["proportion", "coverage"]]

        # Compute coverage's derivative (the first row has no predecessor -> NaN)
        subset = subset.sort_values("proportion", ascending=True).reset_index(drop=True)
        subset["d_coverage"] = subset["coverage"].diff() / subset["proportion"].diff()

        # Subtract mean growth
        subset["relative_growth"] = subset["d_coverage"] - (
            max(subset["coverage"]) / 100
        )

        # Raw finite-difference values
        plot = sns.lineplot(
            x="proportion",
            y="relative_growth",
            data=subset,
            ax=ax,
            color=fill_colors[i],
            lw=0.5,
            marker="x",
            markersize=3,
            mec=markercolors[i],
        )

        # Interpolate values between data points, skipping the NaN first row
        cs = CubicSpline(
            subset["proportion"].loc[1:], subset["relative_growth"].loc[1:]
        )

        # Generate a smooth curve
        x_smooth = np.linspace(0, 100, 1000)
        y_smooth = cs(x_smooth)
        ax.plot(x_smooth, y_smooth, lw=1, color="k")

        # Find approximate zeros: sample points followed by a sign change
        zeros = x_smooth[:-1][y_smooth[:-1] * y_smooth[1:] < 0]

        # Frontiers of the areas of high downstream entities concentration;
        # the plot edges are added when the curve is positive there.
        frontiers = zeros.tolist()
        if cs(0) > 0 and 0 not in zeros:
            frontiers = [0] + frontiers
        if cs(100) > 0 and 100 not in zeros:
            frontiers = frontiers + [100]
        intervals.append(frontiers)

        # Consecutive pairs of frontiers delimit the shaded areas
        for k in range(len(frontiers) // 2):
            ax.axvspan(
                frontiers[2 * k], frontiers[2 * k + 1], color=fill_colors[i], alpha=0.5
            )

        # Set axis labels
        plot.set_xlabel("Core proportion (%)")
        plot.set_ylabel("$c'_d(p) - c_d(1)$")

        # Set plot title
        ax.set_title(downstream_table_value, color=markercolors[i])

        # Add subplot
        fig.add_subplot(ax)

    ## Draw high concentration intervals
    ax = plt.Subplot(fig, outer[1])

    # One horizontal level per downstream table, one thick segment per interval
    for i, frontier in enumerate(intervals):
        for k in range(len(frontier) // 2):
            ax.plot(
                [frontier[2 * k], frontier[2 * k + 1]],
                [i, i],
                color=fill_colors[i],
                linewidth=10,
            )

    # Set axis limits
    ax.set_xlim(0, 100)
    ax.set_ylim(-1, len(intervals))

    # Customize the plot
    ax.set_yticks(range(len(intervals)))
    ax.set_yticklabels(list(tables))
    ax.grid(True, axis="y", linestyle="--", alpha=0.7)
    ax.set_xlabel("Core proportion (%)")
    ax.set_title("Areas of high downstream entities concentration")

    # Add subplot
    fig.add_subplot(ax)

    # Save plot
    plt.savefig(
        SEPAL_DIR / "figures/yago4_downstream_concentration.pdf", bbox_inches="tight"
    )

    return

# Parallel coordinates plot


# Core subgraph selection strategy
def core_strategy(relative_scores=True):
    """
    Compare core-subgraph selection strategies with stacked bar plots.

    Each bar is an embedding model combined with its "core_selection" value
    and stacks the mean score obtained on every downstream task; there is one
    panel per source dataset. The figure is saved under
    ``SEPAL_DIR / "figures"`` as ``relative_core_strategy.pdf`` or
    ``core_strategy.pdf``.

    Parameters
    ----------
    relative_scores : bool, default True
        If True, scores are first rescaled with ``make_scores_relative``.
    """
    # Label every run "<model> <core selection>"; runs without a core
    # selection value end up with a missing label and are dropped.
    scores = make_score_df()
    scores["emb_model_name"] = scores["emb_model_name"] + " " + scores["core_selection"]
    scores = scores[scores["emb_model_name"].notna()]

    # Data to plot
    df = scores[
        ["downstream_task", "emb_model_name", "mean_score", "source_dataset", "id"]
    ].reset_index(drop=True)

    # Make relative scores if necessary
    if relative_scores:
        df = make_scores_relative(df)

    # Keep the best run of every (label, source dataset) pair
    df = get_best_model(df, ["emb_model_name", "source_dataset"])

    ## Plot the results: one panel per source dataset
    fig, axs = plt.subplots(nrows=3, sharex=True, figsize=(7, 4))
    bar_style = dict(
        kind="barh",
        stacked=True,
        xlabel="",
        rot=0,
        edgecolor="black",
        alpha=0.7,
    )

    for i, data in enumerate(df["source_dataset"].unique()):
        table = pd.pivot_table(
            data=df[df["source_dataset"] == data],
            index=["emb_model_name"],
            columns=["downstream_task"],
            values="mean_score",
        )
        panel = axs[i]
        table.plot(ax=panel, **bar_style)
        panel.set_title(data, fontweight="bold")
        panel.get_legend().remove()

    # Single legend below the figure, built from the last panel
    last_panel = axs[2]
    handles, labels = last_panel.get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        title="Evaluation dataset",
        loc="upper center",
        bbox_to_anchor=(0.5, 0),
        fancybox=True,
        ncol=4,
    )

    # x label and output file depend on the score scale
    if relative_scores:
        x_label = "Relative cumulative mean cross-validation score (R2)"
        output_path = SEPAL_DIR / "figures/relative_core_strategy.pdf"
    else:
        x_label = "Cumulative mean cross-validation score (R2)"
        output_path = SEPAL_DIR / "figures/core_strategy.pdf"

    last_panel.set_xlabel(x_label)
    fig.subplots_adjust(hspace=0.5)
    plt.savefig(output_path, bbox_inches="tight")


# Propagation strategy
def propagation_strategy(relative_scores=True):
    """Plot cumulative downstream scores for each propagation strategy.

    Draws a three-panel figure of stacked horizontal bars, one panel per
    source dataset. Each bar is an "<embedding model> <propagation type>"
    label and each bar segment is a downstream task. The figure is written
    as a PDF under ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    relative_scores : bool, default=True
        When true, the scores are passed through ``make_scores_relative``
        before plotting, and the x-axis label and the output file name use
        the "relative" wording.
    """
    # Build the method label: embedding model name + propagation type
    results = make_score_df()
    results["emb_model_name"] = (
        results["emb_model_name"] + " " + results["propagation_type"]
    )
    # Discard the rows for which that label is missing
    results = results[results["emb_model_name"].notna()]

    # Restrict to the columns used for plotting
    kept_columns = [
        "downstream_task",
        "emb_model_name",
        "mean_score",
        "source_dataset",
        "id",
    ]
    df = results[kept_columns].reset_index(drop=True)

    if relative_scores:
        df = make_scores_relative(df)

    # get_best_model is expected to keep the best embeddings of each method
    df = get_best_model(df, ["emb_model_name", "source_dataset"])

    ## Plot the results
    fig, axs = plt.subplots(nrows=3, sharex=True, figsize=(7, 4))
    bar_style = dict(
        kind="barh",
        stacked=True,
        xlabel="",
        rot=0,
        edgecolor="black",
        alpha=0.7,
    )

    for row, dataset in enumerate(df["source_dataset"].unique()):
        panel = axs[row]
        per_dataset = df[df["source_dataset"] == dataset]
        task_scores = per_dataset.pivot_table(
            index=["emb_model_name"],
            columns=["downstream_task"],
            values="mean_score",
        )
        task_scores.plot(ax=panel, **bar_style)
        panel.set_title(dataset, fontweight="bold")
        # Per-panel legends are replaced by one figure-level legend below
        panel.get_legend().remove()

    # Shared legend, built from the handles of the bottom panel
    bottom_panel = axs[2]
    handles, labels = bottom_panel.get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        title="Evaluation dataset",
        loc="upper center",
        bbox_to_anchor=(0.5, 0),
        fancybox=True,
        ncol=4,
    )

    # Axis label and output file both depend on the scoring mode
    if relative_scores:
        x_label = "Relative cumulative mean cross-validation score (R2)"
        output_file = SEPAL_DIR / "figures/relative_propagation_strategy.pdf"
    else:
        x_label = "Cumulative mean cross-validation score (R2)"
        output_file = SEPAL_DIR / "figures/propagation_strategy.pdf"

    bottom_panel.set_xlabel(x_label)
    fig.subplots_adjust(hspace=0.5)
    plt.savefig(output_file, bbox_inches="tight")


# Mixed effect of core and propagation strategies
def mixed_effect(relative_scores=True):
    """Plot cumulative downstream scores for core x propagation combinations.

    Each method label is "<embedding model> <core selection> <propagation
    type>". The figure has three stacked-bar panels (one per source dataset,
    one bar segment per downstream task) and is saved as a PDF under
    ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    relative_scores : bool, default=True
        When true, ``make_scores_relative`` is applied to the scores, and the
        x-axis label and output file name use the "relative" wording.
    """
    # Label every run with its core selection and propagation type
    all_scores = make_score_df()
    label_parts = (
        all_scores["emb_model_name"],
        all_scores["core_selection"],
        all_scores["propagation_type"],
    )
    all_scores["emb_model_name"] = (
        label_parts[0] + " " + label_parts[1] + " " + label_parts[2]
    )
    # Keep only the runs for which the full label exists
    labelled = all_scores[all_scores["emb_model_name"].notna()]

    # Columns needed for the plot
    plot_columns = [
        "downstream_task",
        "emb_model_name",
        "mean_score",
        "source_dataset",
        "id",
    ]
    df = labelled[plot_columns].reset_index(drop=True)

    if relative_scores:
        df = make_scores_relative(df)

    # get_best_model is expected to keep the best embeddings of each method
    df = get_best_model(df, ["emb_model_name", "source_dataset"])

    ## Plot the results
    fig, axs = plt.subplots(nrows=3, sharex=True, figsize=(7, 4))

    source_datasets = df["source_dataset"].unique()
    for position, source in enumerate(source_datasets):
        current_ax = axs[position]
        pivoted = df[df["source_dataset"] == source].pivot_table(
            index=["emb_model_name"],
            columns=["downstream_task"],
            values="mean_score",
        )
        pivoted.plot(
            ax=current_ax,
            kind="barh",
            stacked=True,
            xlabel="",
            rot=0,
            edgecolor="black",
            alpha=0.7,
        )
        current_ax.set_title(source, fontweight="bold")
        # A single figure-level legend is drawn afterwards
        current_ax.get_legend().remove()

    # Figure-level legend built from the last panel's handles
    last_ax = axs[2]
    legend_handles, legend_labels = last_ax.get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        title="Evaluation dataset",
        loc="upper center",
        bbox_to_anchor=(0.5, 0),
        fancybox=True,
        ncol=4,
    )

    last_ax.set_xlabel(
        "Relative cumulative mean cross-validation score (R2)"
        if relative_scores
        else "Cumulative mean cross-validation score (R2)"
    )

    fig.subplots_adjust(hspace=0.5)

    target = (
        SEPAL_DIR / "figures/relative_mixed_effect.pdf"
        if relative_scores
        else SEPAL_DIR / "figures/mixed_effect.pdf"
    )
    plt.savefig(target, bbox_inches="tight")


# Number of subgraphs


# Pie chart with types in core subgraph (and total graph)
def types_pie_chart(KG="Yago4"):
    """Draw side-by-side donut charts of entity types for a KG and its core.

    The left chart shows the type distribution over all entities of ``KG``,
    the right one over the entities of ``"Core " + KG``. Types whose share is
    below a fixed threshold are merged into an "Other" wedge. The figure is
    saved as a PDF under ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    KG : str, default="Yago4"
        Display name of the knowledge graph. Only "Yago4" and "Yago4.5" have
        a type table wired in here.

    Raises
    ------
    ValueError
        If ``KG`` is not one of the supported names. Previously this fell
        through and failed later with an unbound ``types`` variable.
    KeyError
        If ``KG`` or ``"Core " + KG`` is not a value of ``datasets_names``.
    """
    # Load types
    if KG == "Yago4":
        types = pd.read_parquet(
            SEPAL_DIR / "datasets/knowledge_graphs/yago4/yagoTypes.parquet"
        )
    elif KG == "Yago4.5":
        types = pd.read_parquet(
            SEPAL_DIR / "datasets/knowledge_graphs/yago4.5/yagoTypes.parquet"
        )
        bw_types = pd.read_parquet(
            SEPAL_DIR / "datasets/knowledge_graphs/yago4.5/yagoTypes_BW.parquet"
        )
        types = pd.concat([types, bw_types])
    else:
        raise ValueError(
            f"Unsupported KG {KG!r}: expected 'Yago4' or 'Yago4.5'."
        )

    # Map display names back to dataset directories (built once, used twice)
    path_by_name = {name: path for path, name in datasets_names.items()}

    # Load KG entities
    with open(f"{path_by_name[KG]}/metadata.pkl", "rb") as f:
        metadata = pickle.load(f)
    entity_list = list(metadata["entity_to_idx"].keys())

    # Load core entities
    with open(f"{path_by_name['Core ' + KG]}/metadata.pkl", "rb") as f:
        core_metadata = pickle.load(f)
    core_entity_list = list(core_metadata["entity_to_idx"].keys())

    # Get KG types
    types = types[types.entity.isin(entity_list)]

    # Get core types (a subset of the KG types selected just above)
    core_types = types[types.entity.isin(core_entity_list)]

    ## Make plot
    def create_pie_chart(ax, df, title):
        """Draw one donut chart of ``df["type"]`` frequencies on ``ax``."""
        # Relative frequency of each 'type'
        type_counts = df["type"].value_counts(normalize=True)
        # Group infrequent types under the label 'Other'
        threshold = 0.0325
        mask = type_counts < threshold
        type_counts["Other"] = type_counts[mask].sum()
        # The new 'Other' entry must survive the filtering below
        mask["Other"] = False
        type_counts = type_counts[~mask]
        # Plot pie chart; the regex inserts a space before inner capitals
        ax.pie(
            type_counts,
            labels=[
                re.sub(r"(\w)([A-Z])", r"\1 \2", label) for label in type_counts.index
            ],
            autopct="%1.0f%%",
            pctdistance=0.8,
            radius=1,
            wedgeprops=dict(width=0.6, edgecolor="k", linewidth=1.5),
            colors=sns.color_palette("Accent"),
            textprops=dict(weight="bold"),
            startangle=90,
        )
        ax.axis("equal")  # Equal aspect ratio ensures that pie is drawn as a circle.
        # Title goes in the hole of the donut, one word per line
        ax.text(
            0,
            0,
            title.replace(" ", "\n"),
            ha="center",
            va="center",
            fontsize=14,
            fontweight="bold",
        )

    # Create subplots
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))

    # Create pie charts
    create_pie_chart(axs[0], types, KG)
    create_pie_chart(axs[1], core_types, f"Core {KG}")
    fig.tight_layout()

    # Save plot
    plt.savefig(
        SEPAL_DIR / f"figures/{KG}_types.pdf".replace(" ", "_").lower(),
        bbox_inches="tight",
    )
    return


def supertypes_pie_chart(KG="Yago4.5"):
    """Draw nested donut charts of supertypes and types for a KG and its core.

    For both ``KG`` and ``"Core " + KG`` the outer ring shows the supertype
    distribution (rare supertypes merged into "Other") and the inner ring the
    first ten types after sorting by supertype and type frequency. The
    figure is saved as a PDF under ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    KG : str, default="Yago4.5"
        Display name of the knowledge graph. The supertype table is always
        read from the ``yago4.5`` directory, whatever ``KG`` is.

    Raises
    ------
    KeyError
        If ``KG`` or ``"Core " + KG`` is not a value of ``datasets_names``.
    """
    # Load supertypes
    supertypes = pd.read_parquet(
        SEPAL_DIR / "datasets/knowledge_graphs/yago4.5/yagoSuperTypes.parquet"
    )

    # Map display names back to dataset directories (built once, used twice)
    path_by_name = {name: path for path, name in datasets_names.items()}

    # Load KG entities
    with open(f"{path_by_name[KG]}/metadata.pkl", "rb") as f:
        metadata = pickle.load(f)
    entity_list = list(metadata["entity_to_idx"].keys())

    # Load core entities
    with open(f"{path_by_name['Core ' + KG]}/metadata.pkl", "rb") as f:
        core_metadata = pickle.load(f)
    core_entity_list = list(core_metadata["entity_to_idx"].keys())

    # Get KG and core types. Explicit copies: columns are reassigned below,
    # which must not write into a slice of `supertypes`.
    types = supertypes[supertypes.entity.isin(entity_list)].copy()
    core_types = supertypes[supertypes.entity.isin(core_entity_list)].copy()

    # Sort dataframes by most frequent supertypes and types: categories are
    # ordered by frequency in the whole KG, so both frames sort the same way
    supertype_frequency_rank = types["supertype"].value_counts().index.to_list()
    type_frequency_rank = types["type"].value_counts().index.to_list()
    for frame in (types, core_types):
        frame["supertype"] = (
            frame["supertype"]
            .astype("category")
            .cat.set_categories(supertype_frequency_rank)
        )
        frame["type"] = (
            frame["type"].astype("category").cat.set_categories(type_frequency_rank)
        )

    types = types.sort_values(by=["supertype", "type"])
    core_types = core_types.sort_values(by=["supertype", "type"])

    ## Make plot
    def create_nested_pie_chart(ax, df, title, size=0.3, legend=False):
        """Draw the supertype (outer) and type (inner) rings on ``ax``."""
        # Relative frequency of each 'type' and 'supertype', in order of
        # appearance in the sorted frame; only the first 10 types are drawn
        type_counts = df["type"].value_counts(normalize=True)[df["type"].unique()][:10]
        supertype_counts = df["supertype"].value_counts(normalize=True)[
            df["supertype"].unique()
        ]

        # Group infrequent supertypes under the label 'Other'
        threshold = 0.02
        mask = supertype_counts < threshold
        supertype_counts["Other"] = supertype_counts[mask].sum()
        # The new 'Other' entry must survive the filtering below
        mask["Other"] = False
        supertype_counts = supertype_counts[~mask]

        # Set colors: 4 greens + 6 purples = 10 inner wedges at most
        cmap = plt.colormaps["Accent"]
        outer_colors = cmap.colors
        cmaps = [plt.cm.Greens, plt.cm.Purples]
        inner_colors = [cmaps[0](alpha) for alpha in np.linspace(0.7, 0.2, 4)] + [
            cmaps[1](alpha) for alpha in np.linspace(0.6, 0.2, 6)
        ]

        # Both rings start at the same angle, derived from the share of the
        # first supertype. `.iloc` is required: the index holds labels, and
        # integer keys on such a Series are no longer treated as positions.
        start_angle = -180 * supertype_counts.iloc[0]

        # Plot nested pie chart
        ax.pie(
            supertype_counts,
            labels=[
                re.sub(r"(\w)([A-Z])", r"\1 \2", label)
                for label in supertype_counts.index
            ],
            autopct=lambda p: f"{p:.0f}%" if p > 2.5 else "",
            pctdistance=0.88,
            labeldistance=1.1,
            radius=1,
            colors=outer_colors,
            wedgeprops=dict(width=2 * size, edgecolor="k", linewidth=1.5),
            textprops=dict(weight="bold"),
            startangle=start_angle,
        )
        # normalize=False: the type fractions are drawn as-is, so the inner
        # ring only covers the share of the types that were kept
        inner_pie = ax.pie(
            type_counts,
            autopct=lambda p: f"{p:.0f}%" if p > 2.5 else "",
            pctdistance=0.8,
            radius=1 - size,
            colors=inner_colors,
            wedgeprops=dict(width=size, edgecolor="k", linewidth=1.5),
            normalize=False,
            textprops=dict(weight="bold"),
            startangle=start_angle,
        )
        ax.axis("equal")  # Equal aspect ratio ensures that pie is drawn as a circle.
        # Title goes in the hole of the donut, one word per line
        ax.text(
            0,
            0,
            title.replace(" ", "\n"),
            ha="center",
            va="center",
            fontsize=16,
            fontweight="bold",
        )
        if legend:
            # Legend for the inner ring only (its wedges carry no labels)
            ax.legend(
                inner_pie[0],
                [
                    re.sub(r"(\w)([A-Z])", r"\1 \2", label)
                    for label in type_counts.index
                ],
                ncol=5,
                bbox_to_anchor=(0.98, -0.05),
                frameon=False,
            )

    # Create subplots
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Create pie charts
    create_nested_pie_chart(axs[0], types, KG)
    create_nested_pie_chart(axs[1], core_types, f"Core {KG}", legend=True)

    # Save plot
    plt.savefig(
        SEPAL_DIR / f"figures/{KG}_supertypes.pdf".replace(" ", "_").lower(),
        bbox_inches="tight",
    )

    return


# Core subgraph size effect
def core_variation_plot(data="Mini Yago3"):
    """Plot downstream score and coverage against the core proportion.

    Uses the "SEPAL (no partition)" runs of the source dataset ``data``. A
    2x2 grid is drawn with one panel per downstream task: the mean score on
    the left y-axis (task colour) and the downstream coverage, scaled by
    100, on a twin right y-axis (gray). The figure is saved as a PDF under
    ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    data : str, default="Mini Yago3"
        Value of the ``source_dataset`` column to keep.
    """
    # Scores of the non-partitioned SEPAL runs only
    all_scores = make_score_df(filters={})
    sepal_scores = all_scores[all_scores.method == "SEPAL (no partition)"]
    sepal_scores = sepal_scores.reset_index(drop=True)

    # Coverage table, de-duplicated and converted from fraction to percent
    coverage_table = pd.read_parquet(
        SEPAL_DIR / "datasets/evaluation/core_coverage.parquet"
    )
    coverage_table = coverage_table.drop_duplicates(ignore_index=True)
    coverage_table["coverage"] = 100 * coverage_table["coverage"]

    # Attach the coverage to each score row
    merged = sepal_scores.merge(
        coverage_table,
        how="left",
        left_on=["data", "downstream_task", "core_prop"],
        right_on=["data", "downstream_table", "proportion"],
    )
    wanted = ["downstream_task", "core_prop", "mean_score", "coverage", "source_dataset"]
    df = merged[wanted]

    # Keep the requested source dataset
    df = df[df["source_dataset"] == data]

    # One panel per downstream task, laid out on a 2x2 grid
    fig, axes = plt.subplots(
        nrows=2, ncols=2, figsize=(7, 6), sharex=True, sharey=False
    )

    for position, task in enumerate(df["downstream_task"].unique()):
        grid_row, grid_col = divmod(position, 2)
        score_ax = axes[grid_row, grid_col]
        coverage_ax = score_ax.twinx()
        task_color = downstream_palette[task]
        task_rows = df[df["downstream_task"] == task]

        # Coverage curve on the twin (right-hand) axis
        sns.lineplot(
            x="core_prop",
            y="coverage",
            data=task_rows,
            color="gray",
            alpha=0.7,
            ax=coverage_ax,
        )

        # Score curve on the main (left-hand) axis
        sns.lineplot(
            x="core_prop",
            y="mean_score",
            data=task_rows,
            color=task_color,
            alpha=0.7,
            ax=score_ax,
        )

        score_ax.set_ylim(0, 1)

        # Score axis: labels and ticks in the task colour
        score_ax.set_xlabel("Core proportion (%)")
        score_ax.set_ylabel("Mean cross-validation score (R2)", color=task_color)
        score_ax.tick_params(axis="y", labelcolor=task_color)

        # Coverage axis: labels and ticks in gray
        coverage_ax.set_ylabel(
            "Downstream coverage (%)", color="gray", rotation=-90, labelpad=12
        )
        coverage_ax.tick_params(axis="y", labelcolor="gray")

        score_ax.set_title(task, color=task_color)

    plt.tight_layout()

    # Save plot
    output_file = (
        SEPAL_DIR / f"figures/{data}_core_variation.pdf".replace(" ", "_").lower()
    )
    plt.savefig(output_file, bbox_inches="tight")


# n_propagation_passes effect
def propagation_passes_effect(data="mini_yago3_lcc", side="tail"):
    """Plot validation link-prediction metrics against propagation steps.

    Reads ``val_lp_scores.parquet``, preprocesses it with
    ``preprocess_tlp_scores`` (DistMult, ``sepal_subgraph`` partitioning) and
    keeps the rows of dataset ``data``. ``mrr``, ``r1``, ``r10`` and ``r50``
    are drawn as dashed lines on the current axes; ``mr`` is drawn on a twin
    y-axis. The figure is neither saved nor returned.

    Parameters
    ----------
    data : str, default="mini_yago3_lcc"
        Value of the ``data`` column to keep.
    side : str, default="tail"
        Forwarded unchanged to ``preprocess_tlp_scores``.
    """
    steps_column = "n_propagation_steps"

    # Load and filter the validation scores
    raw_scores = pd.read_parquet(SEPAL_DIR / "val_lp_scores.parquet")
    run_filters = {
        "embed_method": ["distmult"],
        "partitioning": ["sepal_subgraph"],
    }
    val_scores = preprocess_tlp_scores(raw_scores, side, filters=run_filters)
    val_scores = val_scores[val_scores["data"] == data]

    # Columns to plot
    val_scores = val_scores[[steps_column, "mr", "mrr", "r1", "r10", "r50"]]

    # Bounded metrics share one axis: reshape to long format, one hue each
    bounded_metrics = val_scores[[steps_column, "mrr", "r1", "r10", "r50"]]
    long_format = bounded_metrics.melt(
        id_vars="n_propagation_steps", var_name="metric", value_name="score"
    )
    sns.lineplot(
        data=long_format,
        x="n_propagation_steps",
        y="score",
        hue="metric",
        marker="o",
        linestyle="--",
    )

    # The mean rank goes on a twin y-axis
    rank_axis = plt.twinx()
    sns.lineplot(
        data=val_scores[[steps_column, "mr"]],
        x="n_propagation_steps",
        y="mr",
        marker="o",
        ax=rank_axis,
    )

# Computation time against diffusion_stop parameter

# Performance against number of epochs and batch size: scatter plot colored by 'mean_score', with n_epoch (or lr?) and batch_size on the x and y axes, training_time as the dot size, and marker shape distinguishing Adam from sparse Adam


## Baseline analysis
def distmult_batch_size(KG="Core Yago4"):
    """Plot the effect of batch size when training DistMult on ``KG``.

    Relative scores of the DistMult runs (embedding dimension 100, learning
    rate 1e-3) are summed per run and drawn against training time on a log
    x-axis, with one line per batch size. Every point is annotated with its
    number of epochs. The plot is drawn on the current pyplot figure and
    saved as a PDF under ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    KG : str, default="Core Yago4"
        Source dataset used to filter the scores, also used in the title and
        in the output file name.
    """
    # Load scores
    scores = make_score_df(
        filters={
            "embed_dim": 100,
            "method": "Relational",
            "emb_model_name": "DistMult",
            "source_dataset": KG,
            "lr": 1e-3,
        }
    )

    # Make scores relative
    scores = make_scores_relative(scores)

    # Preprocess dtypes
    scores["batch_size"] = scores["batch_size"].astype("int")
    scores["num_epochs"] = scores["num_epochs"].astype("int")

    # Log transform batch size (to get a nice colormap)
    scores["batch_size"] = np.log2(scores["batch_size"]).astype("int")

    # Sum scores across datasets. `numeric_only` is spelled out: the first
    # positional argument of GroupBy.sum is that flag, not a column name.
    scores = scores.groupby(
        ["id", "total_time", "batch_size", "num_epochs"], as_index=False
    ).sum(numeric_only=True)[["mean_score", "total_time", "batch_size", "num_epochs"]]

    # Set the log scale for the x-axis
    plt.xscale("log")

    # Make plot
    palette = sns.color_palette("flare", as_cmap=True)
    lineplot = sns.lineplot(
        data=scores,
        x="total_time",
        y="mean_score",
        hue="batch_size",
        palette=palette,
        legend="full",
    )

    # Add annotations for num_epochs, coloured like the line they belong to.
    # With a single batch size the range is zero; fall back to the start of
    # the colormap instead of dividing by zero (which gives a NaN colour).
    smallest = scores["batch_size"].min()
    spread = scores["batch_size"].max() - smallest
    texts = []
    for run in scores.itertuples(index=False):
        position = (run.batch_size - smallest) / spread if spread > 0 else 0.0
        texts.append(
            plt.text(
                run.total_time,
                run.mean_score,
                str(run.num_epochs),
                fontsize=8,
                ha="center",
                va="bottom",
                color=palette(position),
            )
        )

    # Adjust text positions to avoid overlaps
    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="black"))

    # Set labels and title
    plt.xlabel("Training time (s)")
    plt.ylabel("Relative cumulative R2 score")
    plt.title(
        f"Effect of batch size when training DistMult on {KG} with learning rate 1e-3"
    )
    # Legend entries are log2(batch size): convert back to actual sizes
    handles, labels = lineplot.get_legend_handles_labels()
    plt.legend(
        handles,
        [2 ** int(label) for label in labels],
        title="Batch Size",
        loc="upper left",
        ncol=2,
    )

    # Save figure
    plt.savefig(
        SEPAL_DIR
        / f"figures/{KG}_distmult_batch_size_effect.pdf".replace(" ", "_").lower(),
        bbox_inches="tight",
    )

    return


def distmult_lr(KG="Core Yago4", batch_size=65536):
    """Plot the effect of the learning-rate schedule when training DistMult.

    Relative scores of the DistMult runs (embedding dimension 100, given
    batch size) are summed per run and drawn against training time on a log
    x-axis. Three families of runs are overlaid, each with its own palette
    and legend: constant learning rates, ``ExponentialLR`` schedules (one
    line per ``gamma``) and every other scheduler. Every point is annotated
    with its number of epochs. The plot is drawn on the current pyplot
    figure and saved as a PDF under ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    KG : str, default="Core Yago4"
        Source dataset used to filter the scores, also used in the title and
        in the output file name.
    batch_size : int, default=65536
        Batch size used to filter the scores, also used in the title and in
        the output file name.
    """
    # Load scores
    scores = make_score_df(
        filters={
            "embed_dim": 100,
            "method": "Relational",
            "emb_model_name": "DistMult",
            "source_dataset": KG,
            "batch_size": batch_size,
        }
    )

    # Make scores relative
    scores = make_scores_relative(scores)

    # Preprocess dtypes
    scores["batch_size"] = scores["batch_size"].astype("int")
    scores["num_epochs"] = scores["num_epochs"].astype("int")

    no_scheduler = scores.lr_scheduler.isna()
    scheduled = scores[~no_scheduler]

    ## Preprocess constant lr data
    # Copy before adding/replacing columns so we never write into a slice
    df1 = scores[no_scheduler].copy()

    # Log transform learning rate (to get a nice colormap)
    df1["lr"] = np.log2(df1["lr"]).astype("float")

    # Sum scores across datasets. `numeric_only` is spelled out: the first
    # positional argument of GroupBy.sum is that flag, not a column name.
    df1 = df1.groupby(["id", "total_time", "lr", "num_epochs"], as_index=False).sum(
        numeric_only=True
    )[["mean_score", "total_time", "lr", "num_epochs"]]

    ## Preprocess exponential lr data
    df2 = scheduled[scheduled.lr_scheduler == "ExponentialLR"].copy()
    df2["gamma"] = df2["lr_scheduler_kwargs"].apply(lambda x: x["gamma"])
    df2 = df2.groupby(
        ["id", "total_time", "lr", "num_epochs", "gamma"], as_index=False
    ).sum(numeric_only=True)[["mean_score", "total_time", "lr", "num_epochs", "gamma"]]

    ## Preprocess other scheduled lr data
    df3 = scheduled[scheduled.lr_scheduler != "ExponentialLR"]
    df3 = df3.groupby(
        ["id", "total_time", "lr", "num_epochs", "lr_scheduler"], as_index=False
    ).sum(numeric_only=True)[
        ["mean_score", "total_time", "lr", "num_epochs", "lr_scheduler"]
    ]

    # Set the log scale for the x-axis
    plt.xscale("log")

    ## Make plot
    # Plot constant lr
    palette1 = sns.color_palette("flare", as_cmap=True)
    lineplot = sns.lineplot(
        data=df1,
        x="total_time",
        y="mean_score",
        hue="lr",
        palette=palette1,
        legend="full",
    )
    handles1, labels1 = lineplot.get_legend_handles_labels()
    n1 = len(handles1)

    # Plot exponential lr (same axes: its legend entries come after the
    # first n1 ones)
    palette2 = sns.color_palette("blend:#8cd9b3,#206040", as_cmap=True)
    sns.lineplot(
        data=df2,
        x="total_time",
        y="mean_score",
        hue="gamma",
        legend="full",
        palette=palette2,
        alpha=0.8,
    )
    handles, labels = lineplot.get_legend_handles_labels()
    handles2, labels2 = handles[n1:], labels[n1:]
    n2 = len(handles2)

    # Plot other scheduled lr
    palette3 = sns.color_palette("colorblind")
    sns.lineplot(
        data=df3,
        x="total_time",
        y="mean_score",
        hue="lr_scheduler",
        legend="full",
        palette=palette3,
        alpha=0.8,
    )
    handles, labels = lineplot.get_legend_handles_labels()
    handles3, labels3 = handles[n1 + n2 :], labels[n1 + n2 :]

    # Add annotations for num_epochs, coloured like the line they belong to.
    # When a family holds a single hue value its range is zero; fall back to
    # the start of the colormap instead of dividing by zero (NaN colour).
    lr_low = df1["lr"].min()
    lr_span = df1["lr"].max() - lr_low
    texts1 = []
    for run in df1.itertuples(index=False):
        position = (run.lr - lr_low) / lr_span if lr_span > 0 else 0.0
        texts1.append(
            plt.text(
                run.total_time,
                run.mean_score,
                str(run.num_epochs),
                fontsize=8,
                ha="center",
                va="bottom",
                color=palette1(position),
            )
        )

    gamma_low = df2["gamma"].min()
    gamma_span = df2["gamma"].max() - gamma_low
    texts2 = []
    for run in df2.itertuples(index=False):
        position = (run.gamma - gamma_low) / gamma_span if gamma_span > 0 else 0.0
        texts2.append(
            plt.text(
                run.total_time,
                run.mean_score,
                str(run.num_epochs),
                fontsize=8,
                ha="center",
                va="bottom",
                color=palette2(position),
            )
        )

    # Schedulers are coloured by order of first appearance in df3
    scheduler_rank = {
        name: rank for rank, name in enumerate(df3["lr_scheduler"].unique())
    }
    texts3 = []
    for run in df3.itertuples(index=False):
        texts3.append(
            plt.text(
                run.total_time,
                run.mean_score,
                str(run.num_epochs),
                fontsize=8,
                ha="center",
                va="bottom",
                color=palette3[scheduler_rank[run.lr_scheduler]],
            )
        )

    # Adjust text positions to avoid overlaps
    adjust_text(texts1, arrowprops=dict(arrowstyle="-", color="black"))
    adjust_text(texts2, arrowprops=dict(arrowstyle="-", color="black"))
    adjust_text(texts3, arrowprops=dict(arrowstyle="-", color="black"))

    # Create legends for each lineplot. Constant-lr labels are log2(lr):
    # convert them back to learning rates.
    legend1 = plt.legend(
        handles1,
        [round(2 ** float(label), 3) for label in labels1],
        title="Constant LR",
        loc="upper left",
        bbox_to_anchor=(1, 1),
        ncol=2,
    )
    # Raw f-string: "\g" is not a valid escape in a regular string literal
    legend2 = plt.legend(
        handles2,
        [rf"$\gamma={label}$" for label in labels2],
        title="Exponential decay\nwith $LR_0=0.02$",
        loc="lower left",
        bbox_to_anchor=(1, 0.4),
        ncol=2,
    )
    plt.legend(
        handles3,
        labels3,
        title="Other schedules",
        loc="lower left",
        bbox_to_anchor=(1, 0),
    )

    # Each plt.legend call replaces the axes legend, so the first two must be
    # re-attached as artists. The last one is still the active legend and is
    # not re-added, which would draw it twice.
    plt.gca().add_artist(legend1)
    plt.gca().add_artist(legend2)

    # Set labels and title
    plt.xlabel("Training time (s)", fontsize=12)
    plt.ylabel("Relative cumulative R2 score", fontsize=12)
    plt.title(
        f"Effect of Learning Rate when training DistMult on {KG} with batch size {batch_size}"
    )

    # Save figure
    plt.savefig(
        SEPAL_DIR
        / f"figures/{KG}_distmult_lr_effect_bs_{batch_size}.pdf".replace(
            " ", "_"
        ).lower(),
        bbox_inches="tight",
    )

    return


# Correlation between validation loss and evaluation score


def distmult_score_vs_loss(KG="Core Yago4"):
    """Plot the evaluation score of DistMult runs against their final loss.

    Relative scores of the DistMult runs (embedding dimension 100) are summed
    per run. One line is drawn per (batch size, learning rate, gamma)
    configuration, with the last recorded training loss on the x-axis. The
    figure is saved as a PDF under ``SEPAL_DIR / "figures"``.

    Parameters
    ----------
    KG : str, default="Core Yago4"
        Source dataset used to filter the scores, also used in the title and
        in the output file name.
    """
    # Load scores
    scores = make_score_df(
        filters={
            "embed_dim": 100,
            "method": "Relational",
            "emb_model_name": "DistMult",
            "source_dataset": KG,
        }
    )

    # Make scores relative
    scores = make_scores_relative(scores)

    # Make needed features: the last entry of `training_losses` is what this
    # plot calls the validation loss; gamma stays None without kwargs
    scores["validation_loss"] = scores["training_losses"].apply(lambda x: x[-1])
    scores["gamma"] = scores["lr_scheduler_kwargs"].apply(
        lambda x: x if x is None else x["gamma"]
    )

    # Sum scores across datasets. `numeric_only` is spelled out: the first
    # positional argument of GroupBy.sum is that flag, not a column name.
    # dropna=False keeps the runs whose gamma is missing.
    scores = scores.groupby(
        ["id", "num_epochs", "validation_loss", "batch_size", "lr", "gamma"],
        as_index=False,
        dropna=False,
    ).sum(numeric_only=True)[
        ["mean_score", "validation_loss", "batch_size", "lr", "gamma"]
    ]

    # Make plot: one line per training configuration
    fig, ax = plt.subplots(figsize=(5, 5))
    for _, group in scores.groupby(["batch_size", "lr", "gamma"], dropna=False):
        sns.lineplot(
            data=group,
            x="validation_loss",
            y="mean_score",
            lw=1,
            alpha=0.7,
            ax=ax,
        )

    # Set labels and title on the axes created above
    ax.set_xlabel("Validation loss", fontsize=12)
    ax.set_ylabel("Evaluation score", fontsize=12)
    ax.set_title(f"DistMult, {KG}")

    # Save figure
    fig.savefig(
        SEPAL_DIR
        / f"figures/{KG}_distmult_score_vs_loss.pdf".replace(" ", "_").lower(),
        bbox_inches="tight",
    )

    return
