import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colormaps


def get_labels():
    """Return a fresh dict mapping algorithm identifiers to legend labels.

    A new dict is built on every call, so callers may modify the result
    (for instance with ``dict.update``) without affecting later calls.
    """
    label_pairs = (
        ("true", "True Graph"),
        # FLOP with an increasing number of restarts.
        ("flop-restarts=0-lambda=2.0-randomstart=False", "FLOP (0 restarts)"),
        ("flop-restarts=20-lambda=2.0-randomstart=False", "FLOP (20 restarts)"),
        ("flop-restarts=100-lambda=2.0-randomstart=False", "FLOP (100 restarts)"),
        ("flop-restarts=500-lambda=2.0-randomstart=False", "FLOP (500 restarts)"),
        # FLOP baseline variants.
        ("flop_baseline_lazygs-restarts=0-lambda=2.0", "FLOP (w/o Cholesky updates)"),
        ("flop_baseline_naivegs-restarts=0-lambda=2.0", "FLOP (naive)"),
        # BOSS with an increasing number of restarts.
        ("boss-restarts=0-lambda=2.0", "BOSS"),
        ("boss-restarts=20-lambda=2.0", "BOSS (20 restarts)"),
        ("boss-restarts=100-lambda=2.0", "BOSS (100 restarts)"),
        # Remaining algorithms.
        ("pc-alpha=0.01", "PC"),
        ("ges-lambda=2.0", "GES"),
        ("dagma-lambda=0.02", "DAGMA"),
        ("dagma_nonlinear-lambda=0.02", "DAGMA nonlinear"),
        ("exact-lambda=2.0", "Exact"),
        ("lingam", "LiNGAM"),
    )
    return dict(label_pairs)


def get_priorities():
    """Return a fresh dict mapping algorithm identifiers to sort priorities.

    The plotting functions in this file order algorithms by descending
    priority, so a larger number means the algorithm is handled earlier.
    """
    flop = {
        "flop-restarts=0-lambda=2.0-randomstart=False": 10_000,
        "flop-restarts=20-lambda=2.0-randomstart=False": 9_999,
        "flop-restarts=100-lambda=2.0-randomstart=False": 9_998,
        "flop-restarts=500-lambda=2.0-randomstart=False": 9_997,
        "flop-restarts=0-lambda=2.0-randomstart=True": 9_500,
    }
    flop_baselines = {
        "flop_baseline_lazygs-restarts=0-lambda=2.0": 11_000,
        "flop_baseline_naivegs-restarts=0-lambda=2.0": 12_000,
    }
    boss = {
        "boss-restarts=0-lambda=2.0": 9_000,
        "boss-restarts=20-lambda=2.0": 8_999,
        "boss-restarts=100-lambda=2.0": 8_998,
    }
    others = {
        "pc-alpha=0.01": 8_000,
        "ges-lambda=2.0": 7_000,
        "dagma-lambda=0.02": 5_000,
        "exact-lambda=2.0": 6_000,
    }
    # Merge in the listed order so the key order of the result stays stable.
    return {"true": 0, **flop, **flop_baselines, **boss, **others}


def mean_runtime_plot(
    result_path, title, suppress_algos=("true",), overwrite_labels=None, save_name=None
):
    """Plot the mean run-time of every algorithm against the number of nodes.

    One line is drawn per algorithm, together with a shaded band that spans
    one standard deviation above and below the mean.

    Args:
        result_path: CSV file to read. It needs an ``algo`` column, a
            ``graph`` column whose values contain ``er-<digits>-`` (the
            digits are used as the number of nodes), and the run-time in its
            ninth column.
        title: Title of the figure.
        suppress_algos: A prefix or an iterable of prefixes. Rows whose
            lower-cased ``algo`` value starts with one of them are dropped.
        overwrite_labels: Optional mapping from ``algo`` value to legend
            label; its entries take precedence over ``get_labels()``.
        save_name: File name below ``plots/`` the figure is saved to. When it
            is ``None`` the figure is shown instead.
    """
    # Series.str.startswith accepts a str or a tuple of str only, so other
    # iterables (e.g. a list) are converted.
    if isinstance(suppress_algos, str):
        prefixes = suppress_algos
    else:
        prefixes = tuple(suppress_algos)

    df = pd.read_csv(result_path)
    # Copy the filtered rows so the column assignments below write to an
    # independent frame instead of a slice (avoids pandas' chained-assignment
    # warning).
    df = df[~df["algo"].str.lower().str.startswith(prefixes)].copy()

    # NOTE(review): the run-time column is selected by position; presumably
    # this is the "runtime" column that runtime_accuracy_plot reads by name
    # -- confirm against the CSV layout.
    run_time_col = df.columns[8]
    df["nodes"] = df["graph"].str.extract(r"er-(\d+)-", expand=False).astype(int)

    # Values that cannot be parsed become NaN and are skipped by mean/std.
    df[run_time_col] = pd.to_numeric(df[run_time_col], errors="coerce")

    agg = df.groupby(["algo", "nodes"])[run_time_col].agg(["mean", "std"]).reset_index()

    plt.figure(figsize=(15, 6))

    labels = get_labels()
    labels.update(overwrite_labels or {})

    priorities = get_priorities()

    # Highest priority first. Algorithms without a priority entry are drawn
    # last instead of raising a KeyError.
    groups_sorted = sorted(
        agg.groupby("algo"),
        key=lambda item: priorities.get(item[0], float("-inf")),
        reverse=True,
    )

    for algo, group in groups_sorted:
        group = group.sort_values("nodes")
        x = group["nodes"]
        y = group["mean"]
        y_std = group["std"]

        # Fall back to the raw identifier when no label is configured.
        plt.plot(x, y, linestyle="-", label=labels.get(algo, algo))
        plt.fill_between(x, y - y_std, y + y_std, alpha=0.3)

    plt.tick_params(axis="x", labelsize=14)
    plt.tick_params(axis="y", labelsize=14)
    plt.xlabel("Number of Nodes", fontsize=18, labelpad=10)
    plt.ylabel("Run-time (seconds)", fontsize=18, labelpad=10)
    plt.title(title, fontsize=18, pad=15)
    plt.legend(title="Algorithm", fontsize=14, title_fontsize=16)
    plt.grid(True, linestyle="--", alpha=0.6)
    if save_name is not None:
        plt.savefig(f"plots/{save_name}", bbox_inches="tight")
        plt.close()
    else:
        plt.show()


def runtime_accuracy_plot(
    result_path,
    title,
    accuracy_col,
    accuracy_label,
    suppress_algos=("true",),
    overwrite_labels=None,
    filter_graph=None,
    filter_samples=None,
    filter_data_type=None,
    start_zero=True,
    save_name=None,
):
    """Scatter an accuracy metric against run-time, one colour per algorithm.

    Every row of the result file is drawn as a faint dot and the per-algorithm
    mean as a diamond marker, which is the only entry shown in the legend.

    Args:
        result_path: CSV file to read. It needs the columns ``algo``,
            ``runtime`` and ``accuracy_col``; ``graph`` and ``data`` are
            needed only when the matching filter is used.
        title: Title of the figure.
        accuracy_col: Name of the column plotted on the y-axis.
        accuracy_label: Label of the y-axis.
        suppress_algos: A prefix or an iterable of prefixes. Rows whose
            lower-cased ``algo`` value starts with one of them are dropped.
        overwrite_labels: Optional mapping from ``algo`` value to legend
            label; its entries take precedence over ``get_labels()``.
        filter_graph: If given, keep only rows where the part of ``graph``
            after its first hyphen is in this collection.
        filter_samples: If given, keep only rows where the first run of
            digits followed by a hyphen in ``data``, read as an int, is in
            this collection.
        filter_data_type: If given, keep only rows where the part of ``data``
            after its first hyphen is in this collection.
        start_zero: Whether the y-axis is forced to start at 0.
        save_name: File name below ``plots/`` the figure is saved to. When it
            is ``None`` the figure is shown instead.
    """
    # Series.str.startswith accepts a str or a tuple of str only, so other
    # iterables (e.g. a list) are converted.
    if isinstance(suppress_algos, str):
        prefixes = suppress_algos
    else:
        prefixes = tuple(suppress_algos)

    df = pd.read_csv(result_path)
    # Each filter step copies its result so that the next column assignment
    # writes to an independent frame instead of a slice (avoids pandas'
    # chained-assignment warning).
    df = df[~df["algo"].str.lower().str.startswith(prefixes)].copy()
    if filter_graph is not None:
        df["graph_name"] = df["graph"].str.extract(r"-(.+)", expand=False)
        df = df[df["graph_name"].isin(filter_graph)].copy()
    if filter_samples is not None:
        df["samples"] = df["data"].str.extract(r"(\d+)-", expand=False).astype(int)
        df = df[df["samples"].isin(filter_samples)].copy()
    if filter_data_type is not None:
        df["data_type"] = df["data"].str.extract(r"-(.+)", expand=False)
        df = df[df["data_type"].isin(filter_data_type)].copy()

    priorities = get_priorities()

    agg = (
        df.groupby("algo")[["runtime", accuracy_col]].agg(["mean", "std"]).reset_index()
    )
    # Flatten the (column, statistic) header first so the sort below refers
    # to a plain "algo" column. The std columns are computed but not drawn.
    agg.columns = [
        "algo",
        "runtime_mean",
        "runtime_std",
        "accuracy_mean",
        "accuracy_std",
    ]
    # Highest priority first; algorithms without a priority map to NaN and
    # are placed last.
    agg = agg.sort_values(
        by="algo", key=lambda col: col.map(priorities), ascending=False
    )

    plt.figure(figsize=(8, 5))

    labels = get_labels()
    labels.update(overwrite_labels or {})

    algorithms = agg["algo"].unique()
    palette = colormaps["tab10"].colors
    # Wrap around the palette so that more algorithms than palette entries
    # reuse colours instead of being left without one.
    colors = {
        algo: palette[index % len(palette)] for index, algo in enumerate(algorithms)
    }

    for algo, df_algo in df.groupby("algo"):
        plt.scatter(
            df_algo["runtime"],
            df_algo[accuracy_col],
            alpha=0.2,  # faint
            s=30,
            c=[colors[algo]],
            label="_nolegend_",
            zorder=1,
        )

    for _, row in agg.iterrows():
        plt.plot(
            row["runtime_mean"],
            row["accuracy_mean"],
            marker="D",  # diamond
            markersize=9,
            color=colors[row["algo"]],
            markeredgecolor="black",
            lw=0,
            zorder=2,
            # Fall back to the raw identifier when no label is configured.
            label=labels.get(row["algo"], row["algo"]),
        )

    if start_zero:
        plt.ylim(bottom=0)
    plt.tick_params(axis="x", labelsize=14)
    plt.tick_params(axis="y", labelsize=14)
    plt.xlabel("Run-time (seconds)", fontsize=18, labelpad=10)
    plt.ylabel(accuracy_label, fontsize=18, labelpad=10)
    plt.title(title, fontsize=18, pad=15)
    plt.legend(title="Algorithm", fontsize=14, title_fontsize=16)
    plt.grid(True, linestyle="--", alpha=0.6, which="both")
    if save_name is not None:
        plt.savefig(f"plots/{save_name}", bbox_inches="tight")
        plt.close()
    else:
        plt.show()


def _render_all_figures():
    """Draw and save every figure this script produces, in a fixed order.

    The mean run-time figure comes first; the run-time/accuracy scatter plots
    follow in the order in which they are listed in ``figures``.
    """
    mean_runtime_plot(
        result_path="results/large.csv",
        title="ER, avg. degree 16, 1000 samples",
        save_name="large.pdf",
        overwrite_labels={
            "flop-restarts=0-lambda=2.0-randomstart=False": "FLOP (optimized)",
        },
    )

    # Metrics as (accuracy_col, accuracy_label) pairs.
    shd = ("shd", "SHD")
    aid = ("aid", "AID")
    bic = ("bic", "BIC")
    dagma_loss = ("dagma-loss", "DAGMA loss")

    # Result files that feed more than one figure.
    dense_csv = "results/dense.csv"
    default_csv = "results/default.csv"
    sf_csv = "results/sf.csv"
    bnlearn_csv = "results/bnlearn.csv"
    large_csv = "results/large_accuracy.csv"
    exact_csv = "results/er_exact.csv"
    assembly_csv = "results/causalAssembly.csv"
    nonlinear_csv = "results/nonlinear.csv"

    # Titles shared by several figures.
    default_title = "ER, 50 nodes, avg. degree 8, 1000 samples"
    sf_title = "SF, 50 nodes, density param. 4, 1000 samples"
    pathfinder_title = "Pathfinder (bnlearn repository)"
    assembly_title = "causalAssembly, 5000 samples"
    mlp_title = "Nonlinear (MLP)"
    gp_title = "Nonlinear (GP)"

    # Algorithm-name prefixes hidden in the reduced figures.
    hide_lingam = ("true", "lingam")
    hide_nonlinear = ("true", "lingam", "dagma_nonlinear")
    hide_nonlinear_pc_ges = ("true", "lingam", "dagma_nonlinear", "pc", "ges")

    # Each entry: (result_path, title, metric, save_name, extra keyword arguments).
    figures = [
        (
            "results/chain.csv",
            "Path, 50 nodes, 1000 samples",
            shd,
            "chain.pdf",
            dict(
                overwrite_labels={
                    "flop-restarts=0-lambda=2.0-randomstart=False": "FLOP (initial order)",
                    "flop-restarts=0-lambda=2.0-randomstart=True": "FLOP (random start)",
                },
            ),
        ),
        (
            dense_csv,
            "ER, 25 nodes, avg. degree 16, 1000 samples",
            shd,
            "dense_1000.pdf",
            dict(
                overwrite_labels={"boss-restarts=0-lambda=2.0": "BOSS (0 restarts)"},
                filter_samples=[1_000],
            ),
        ),
        (
            dense_csv,
            "ER, 25 nodes, avg. degree 16, 50000 samples",
            shd,
            "dense_50000.pdf",
            dict(
                overwrite_labels={"boss-restarts=0-lambda=2.0": "BOSS (0 restarts)"},
                filter_samples=[50_000],
            ),
        ),
        (default_csv, default_title, shd, "default.pdf", {}),
        (default_csv, default_title, aid, "default_aid.pdf", {}),
        (default_csv, default_title, bic, "default_bic.pdf", {}),
        (default_csv, default_title, dagma_loss, "default_dagma.pdf", {}),
        (sf_csv, sf_title, shd, "sf.pdf", {}),
        (sf_csv, sf_title, aid, "sf_aid.pdf", {}),
        (
            "results/uniform.csv",
            "ER, 50 nodes, avg. degree 8, uniform noise",
            shd,
            "uniform.pdf",
            {},
        ),
        (
            "results/raw.csv",
            "ER, 50 nodes, avg. degree 8, unstandardized",
            shd,
            "raw.pdf",
            {},
        ),
        (
            "results/onion.csv",
            "ER, 50 nodes, avg. degree 8, DaO",
            shd,
            "onion.pdf",
            {},
        ),
        (
            bnlearn_csv,
            "Alarm (bnlearn repository)",
            shd,
            "alarm.pdf",
            dict(filter_graph=["alarm"]),
        ),
        (
            bnlearn_csv,
            "Barley (bnlearn repository)",
            shd,
            "barley.pdf",
            dict(filter_graph=["barley"]),
        ),
        (
            bnlearn_csv,
            "Mildew (bnlearn repository)",
            shd,
            "mildew.pdf",
            dict(filter_graph=["mildew"]),
        ),
        (
            bnlearn_csv,
            pathfinder_title,
            shd,
            "pathfinder.png",
            dict(filter_graph=["pathfinder"]),
        ),
        (
            bnlearn_csv,
            pathfinder_title,
            bic,
            "pathfinder_bic.png",
            dict(filter_graph=["pathfinder"]),
        ),
        (
            large_csv,
            "ER, 100 nodes, avg. degree 8, 1000 samples",
            shd,
            "er_100.pdf",
            dict(filter_graph=["100-8"]),
        ),
        (
            large_csv,
            "ER, 250 nodes, avg. degree 8, 1000 samples",
            shd,
            "er_250.pdf",
            dict(filter_graph=["250-8"]),
        ),
        (
            large_csv,
            "ER, 500 nodes, avg. degree 8, 1000 samples",
            shd,
            "er_500.pdf",
            dict(filter_graph=["500-8"]),
        ),
        (
            exact_csv,
            "ER, 25 nodes, avg. degree 4, 1000 samples",
            shd,
            "er_25_4.pdf",
            dict(filter_graph=["25-4"]),
        ),
        (
            exact_csv,
            "ER, 25 nodes, avg. degree 8, 1000 samples",
            shd,
            "er_25_8.pdf",
            dict(filter_graph=["25-8"]),
        ),
        (assembly_csv, assembly_title, shd, "causalAssembly.png", {}),
        (
            assembly_csv,
            assembly_title,
            shd,
            "causalAssemblyNoLingam.png",
            dict(suppress_algos=hide_lingam),
        ),
        (
            nonlinear_csv,
            mlp_title,
            shd,
            "nonlinear_mlp.png",
            dict(filter_data_type=["mlp"]),
        ),
        (
            nonlinear_csv,
            gp_title,
            shd,
            "nonlinear_gp.png",
            dict(filter_data_type=["gp"]),
        ),
        (
            nonlinear_csv,
            mlp_title,
            shd,
            "nonlinear_mlp_reduced_algos.png",
            dict(filter_data_type=["mlp"], suppress_algos=hide_nonlinear),
        ),
        (
            nonlinear_csv,
            gp_title,
            shd,
            "nonlinear_gp_reduced_algos.png",
            dict(filter_data_type=["gp"], suppress_algos=hide_nonlinear),
        ),
        (
            nonlinear_csv,
            mlp_title,
            bic,
            "nonlinear_mlp_reduced_algos_bic.png",
            dict(
                filter_data_type=["mlp"],
                suppress_algos=hide_nonlinear_pc_ges,
                start_zero=False,
            ),
        ),
        (
            nonlinear_csv,
            gp_title,
            bic,
            "nonlinear_gp_reduced_algos_bic.png",
            dict(
                filter_data_type=["gp"],
                suppress_algos=hide_nonlinear_pc_ges,
                start_zero=False,
            ),
        ),
    ]

    for result_path, title, (accuracy_col, accuracy_label), save_name, options in figures:
        runtime_accuracy_plot(
            result_path,
            title,
            accuracy_col,
            accuracy_label,
            save_name=save_name,
            **options,
        )


_render_all_figures()
