"""
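Main experiment script: builds the comparison figures from the CSV results in results/ (the experiment runners in main() are currently commented out).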
"""
#%%
from collections import OrderedDict
from collections import OrderedDict
import alglab
import experiments
from matplotlib import pyplot as plt
import pandas as pd
import ast
import os
import numpy as np
import matplotlib.ticker as mticker

def ensure_dir(path):
    """Create the parent directory of the file *path* if it does not exist yet.

    Only ``os.path.dirname(path)`` is created. A bare filename has no
    directory component, so nothing needs to be created and nothing is done.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)


def get_metrics(folder):
    """Return the metrics to plot as ``(column, axis label, image path)`` triples.

    Each image path is *folder* followed by a fixed file name. When *folder*
    contains the substring ``"change"``, the two edge-removal metrics are
    appended after the default three.
    """
    # Disabled total-runtime columns: total_running_time_s,
    # add_edges_total_running_time_s, predict_total_running_time_s.
    specs = [
        ("adjusted_rand_index", "Adjusted Rand Index", "adjusted_rand_index.png"),
        ("add_edges_iter_running_time_s", "Edge update runtime (s)", "add_edge_running_time_iter.png"),
        ("predict_iter_running_time_s", "Predict runtime (s)", "predict_running_time_iter.png"),
    ]
    if 'change' in folder:
        specs += [
            ("remove_edges_total_running_time_s", "Remove Edges Total Running Time (s)", "running_time_remove_edges.png"),
            ("remove_edges_iter_running_time_s", "Remove edge update running time (s)", "remove_edge_running_time_iter.png"),
        ]
    return [(column, label, f"{folder}{image}") for column, label, image in specs]

# The 10 distinct colours of matplotlib's "tab10" qualitative palette.
tab10 = plt.get_cmap("tab10").colors

# Fixed colour per algorithm name (as it appears in the results' "algorithm"
# column), so each method is drawn in the same colour in every subplot.
colour_map = {
    algorithm: tab10[palette_index]
    for algorithm, palette_index in (
        ("LS24", 2),
        ("NaiveDynamicSC", 0),
        ("StaticCSC", 8),
        ("DynamicCSC", 6),
        ("FastFasterDynamicCSC", 9),
        ("FasterDynamicCSC", 3),
    )
}




def plot(results, col, ylabel, ax=None, label_prefix=None, logy=True):
    """Draw the per-iteration mean +/- std of one metric, one line per algorithm.

    The first iteration of every algorithm is dropped from the curve.

    Parameters
    ----------
    results : pandas.DataFrame
        Results with at least ``algorithm``, ``iter`` and *col* columns.
        Column names and ``algorithm`` values are stripped of surrounding
        whitespace in place. Metric values themselves are not modified.
    col : str
        Metric column to plot, coerced to numeric (unparsable -> NaN). For
        ``add_edges_iter_running_time_s`` the ``remove_edges_iter_running_time_s``
        column is added when present, so the curve shows the total edge-update
        time.
    ylabel : str
        Y-axis label and subplot title. ``"Adjusted Rand Index"`` is instead
        titled ``"Solution Quality"``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes.
    label_prefix : str, optional
        If given, legend labels read ``"<prefix> — <algorithm>"``.
    logy : bool
        If true, non-ARI metrics get a logarithmic y-axis. If false, the
        y-axis starts at 0.
    """
    # Font size is applied through rcParams so the ticks and the figure-level
    # legend that callers build afterwards pick it up too.
    fontsize = 14
    plt.rcParams.update({
        'font.size': fontsize,
        'axes.titlesize': fontsize,
        'axes.labelsize': fontsize,
        'xtick.labelsize': fontsize,
        'ytick.labelsize': fontsize,
        'legend.fontsize': fontsize,
    })

    if ax is None:
        ax = plt.gca()

    col = col.strip()
    results.columns = results.columns.str.strip()
    if "algorithm" in results.columns:
        results["algorithm"] = results["algorithm"].str.strip()

    # Work on a local series so repeated calls on the same frame never
    # accumulate the removal time into the caller's column.
    values = pd.to_numeric(results[col], errors='coerce')
    if col == "add_edges_iter_running_time_s":
        del_col = "remove_edges_iter_running_time_s"
        if del_col in results.columns:
            values = values + pd.to_numeric(results[del_col], errors='coerce')

    data = pd.DataFrame({
        "algorithm": results["algorithm"],
        "iter": results["iter"],
        "value": values,
    })

    for cat, group in data.groupby("algorithm"):
        # Positional drop of the first iteration.
        stats = group.groupby("iter")["value"].agg(["mean", "std"]).iloc[1:]
        means, stds = stats["mean"], stats["std"]

        label = f"{label_prefix} — {cat}" if label_prefix else cat
        # Unknown algorithms fall back to matplotlib's colour cycle. The band
        # reuses whatever colour the line actually got.
        (line,) = ax.plot(means.index, means, label=label, color=colour_map.get(cat))
        ax.fill_between(means.index, means - stds, means + stds, alpha=0.2, color=line.get_color())

    ax.set_xlabel("Iteration", fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)

    if not logy:
        ax.set_ylim(bottom=0)
    elif ylabel != "Adjusted Rand Index":
        ax.set_yscale("log")

    ax.set_title("Solution Quality" if ylabel == "Adjusted Rand Index" else ylabel)






def plot_multiple_datasets(dataset_files, labels=None, out_folder="figures/comparison/", filename=None, no_ari=False, logy=True):
    """Plot several result CSVs as a grid: one row per dataset, one column per metric.

    Parameters
    ----------
    dataset_files : list of str
        Paths to result CSVs. Fields are split only on commas outside
        ``[...]``, so bracketed list cells stay in one field.
    labels : list of str, optional
        Dataset labels. Each is drawn upper-cased and vertically to the left of
        its row. Defaults to the file basenames without extension.
    out_folder : str
        Directory the figure is written to. Also passed to ``get_metrics``.
    filename : str, optional
        Output name without extension (``.png`` is appended). Defaults to the
        non-empty labels joined by ``_``, or ``"comparison"`` if all are empty.
        ``"static_comp"`` also selects a fixed legend offset.
    no_ari : bool
        Drop the Adjusted Rand Index column.
    logy : bool
        Forwarded to ``plot``.
    """
    csv_sep = r',(?=(?:[^\[]*\[[^\]]*\])*[^\]]*$)'
    dfs = [
        pd.read_csv(p, sep=csv_sep, engine="python", skipinitialspace=True)
        for p in dataset_files
    ]

    if labels is None:
        labels = [os.path.splitext(os.path.basename(p))[0] for p in dataset_files]

    metrics = get_metrics(out_folder)
    if no_ari:
        metrics = [m for m in metrics if m[0] != "adjusted_rand_index"]

    n_rows, n_cols = len(dfs), len(metrics)
    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows),
                            constrained_layout=True, squeeze=False)

    # Read before this call's plot() updates rcParams.
    title_fs = max(8, int(plt.rcParams.get('font.size', 12)) + 2)

    for i, (df, lab) in enumerate(zip(dfs, labels)):
        df.columns = df.columns.str.strip()
        if "algorithm" in df.columns:
            df["algorithm"] = df["algorithm"].str.strip()

        for j, (col, ylabel, _filename) in enumerate(metrics):
            ax = axs[i][j]
            plot(df, col, ylabel, ax=ax, label_prefix=lab, logy=logy)
            if j == 0:
                # Dataset label, vertical, just outside the left edge of the row.
                ax.text(-0.38, 0.5, lab.upper(), transform=ax.transAxes, va='center', ha='center',
                        rotation='vertical', fontsize=title_fs, fontweight='bold')

    # Legend order and display names.
    # NOTE(review): the "[]" in some names looks like an unfilled citation
    # placeholder; confirm the intended text.
    algs_map = OrderedDict([
        ("NaiveDynamicSC", "Naive"),
        ("StaticCSC", "Static"),
        ("LS24", "[LS24]"),
        ("DynamicCSC", "Merge&Reduce []"),
        ("FasterDynamicCSC", "Ours"),
        ("FastFasterDynamicCSC", "Ours + []"),
    ])

    # One global legend entry per algorithm. Plot labels look like
    # "<dataset> — <algorithm>", so keep only the algorithm part.
    unique = OrderedDict()
    for ax in axs.flat:
        h, l = ax.get_legend_handles_labels()
        for hh, ll in zip(h, l):
            unique.setdefault(ll.rsplit('—', 1)[-1].strip(), hh)

    # Known algorithms follow algs_map order. Unknown ones go last in
    # first-seen order (sorted() is stable).
    rank = {name: k for k, name in enumerate(algs_map)}
    unique = OrderedDict(sorted(unique.items(), key=lambda item: rank.get(item[0], len(rank))))

    if unique:
        handles = list(unique.values())
        labels_all = [algs_map.get(name, name) for name in unique]
        ncol = min(6, len(labels_all))
        anchor_y = 1.15 if filename == "static_comp" else 1 + len(labels_all) * 0.02
        fig.legend(handles, labels_all, loc='upper center', ncol=ncol, bbox_to_anchor=(0.5, anchor_y),
                   prop={'size': plt.rcParams.get('legend.fontsize', 12)})

    if filename is None:
        filename = "_".join(lab for lab in labels if lab) or "comparison"
    outpath = os.path.join(out_folder, filename + ".png")
    out_dir = os.path.dirname(outpath)
    if out_dir:  # os.makedirs("") would raise
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.2)
    plt.close(fig)


def plot_multiple_datasets2(dataset_files, labels=None, out_folder="figures/comparison/", filename=None, no_ari=False, logy=True, log_axes=None):
    """Transposed grid of result CSVs: one row per metric, one column per dataset.

    Parameters
    ----------
    dataset_files : list of str
        Paths to result CSVs. Fields are split only on commas outside ``[...]``.
    labels : list of str, optional
        Column titles, drawn upper-cased on the top row. Defaults to the file
        basenames without extension.
    out_folder : str
        Directory the figure is written to. Also passed to ``get_metrics``.
    filename : str, optional
        Output name without extension (``.png`` is appended). Defaults to the
        labels joined by ``_`` plus ``"_transposed"``.
    no_ari : bool
        Drop the Adjusted Rand Index row.
    logy : bool
        Forwarded to ``plot``.
    log_axes : iterable of int, optional
        Flat, row-major subplot indices forced onto a log y-axis after
        plotting. Defaults to ``{5, 6, 7, 10}``. In the 3-metric x 5-dataset
        grid built by ``main``, these are the edge-update row for the first
        three datasets plus the predict row for the first dataset.
    """
    log_axes = frozenset((5, 6, 7, 10)) if log_axes is None else frozenset(log_axes)

    csv_sep = r',(?=(?:[^\[]*\[[^\]]*\])*[^\]]*$)'
    dfs = [
        pd.read_csv(p, sep=csv_sep, engine="python", skipinitialspace=True)
        for p in dataset_files
    ]
    # Normalise each frame once instead of once per (metric, dataset) pair.
    for df in dfs:
        df.columns = df.columns.str.strip()
        if "algorithm" in df.columns:
            df["algorithm"] = df["algorithm"].str.strip()

    if labels is None:
        labels = [os.path.splitext(os.path.basename(p))[0] for p in dataset_files]

    metrics = get_metrics(out_folder)
    if no_ari:
        metrics = [m for m in metrics if m[0] != "adjusted_rand_index"]

    n_rows = len(metrics)   # rows = metrics
    n_cols = len(dfs)       # cols = datasets
    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows),
                            constrained_layout=True, squeeze=False)

    title_fs = max(8, int(plt.rcParams.get('font.size', 12)) + 2)

    for j, (col, ylabel, _filename) in enumerate(metrics):      # rows
        for i, (df, lab) in enumerate(zip(dfs, labels)):        # columns
            ax = axs[j][i]
            plot(df, col, ylabel, ax=ax, label_prefix=lab, logy=logy)

            # Dataset titles on the top row only. Lower rows are identified by
            # their y-labels.
            if j == 0:
                ax.set_title(lab.upper(), fontsize=title_fs, fontweight='bold')
            else:
                ax.set_title("")

            # x-label only on the bottom row.
            ax.set_xlabel("Iteration" if j == n_rows - 1 else "")

            # y-label only on the leftmost column.
            if i == 0:
                ax.set_ylabel(ylabel, fontsize=title_fs)
            else:
                ax.set_ylabel("")

    for k, ax in enumerate(axs.flat):
        if k in log_axes:
            # Switch to log scale and recompute limits (plot() may have pinned
            # the bottom at 0).
            ax.set_yscale("log")
            ax.relim()
            ax.autoscale()
            ax.yaxis.set_major_locator(mticker.LogLocator(base=10, numticks=10))

    # Legend order and display names.
    # NOTE(review): "Ours + []" looks like an unfilled citation placeholder.
    algs_map = OrderedDict([
        ("NaiveDynamicSC", "Naive"),
        ("StaticCSC", "Static"),
        ("LS24", "[LS24]"),
        ("DynamicCSC", "Merge&Reduce"),
        ("FasterDynamicCSC", "Ours"),
        ("FastFasterDynamicCSC", "Ours + []"),
    ])

    # One legend entry per algorithm. Labels look like "<dataset> — <algorithm>".
    unique = OrderedDict()
    for ax in axs.flat:
        h, l = ax.get_legend_handles_labels()
        for hh, ll in zip(h, l):
            unique.setdefault(ll.rsplit('—', 1)[-1].strip(), hh)

    # Known algorithms follow algs_map order. Unknown ones go last, stably.
    rank = {name: k for k, name in enumerate(algs_map)}
    unique = OrderedDict(sorted(unique.items(), key=lambda item: rank.get(item[0], len(rank))))

    if unique:
        handles = list(unique.values())
        labels_all = [algs_map.get(name, name) for name in unique]
        fig.legend(handles, labels_all, loc='upper center', ncol=min(6, len(labels_all)),
                   bbox_to_anchor=(0.5, 1.05), prop={'size': plt.rcParams.get('legend.fontsize', 12)})

    if filename is None:
        filename = "_".join(labels) + "_transposed"
    outpath = os.path.join(out_folder, filename + ".png")
    out_dir = os.path.dirname(outpath)
    if out_dir:  # os.makedirs("") would raise
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.2)
    plt.close(fig)


def main():
    """Regenerate the comparison figures from the saved result CSVs.

    The experiment runners are disabled. Uncomment the relevant ``run_all()``
    calls to recompute the CSVs under ``results/`` first.
    """
    # Experiment runners (disabled):
    # experiments.knn_mnist_experiment_with_deletions.run_all()
    # experiments.knn_emnist_experiment_with_deletions.run_all()
    # experiments.sbm_change_clusters_experiment.run_all()
    # experiments.sbm_new_clusters_experiment.run_all()
    # experiments.sbm_new_clusters_experiment_nols24.run_all()
    # experiments.sbm_merge_clusters_experiment.run_all()
    # experiments.sbm_merge_clusters_experiment_nols24.run_all()
    # experiments.knn_mnist_experiment.run_all()
    # experiments.knn_emnist_experiment.run_all()
    # experiments.sbm_new_clusters_experiment_large.run_all()

    # (results CSV, column label) pairs for the transposed grid, in plot order.
    transposed_runs = (
        ("results/sbm_new_clusters.csv", "split_clusters"),
        ("results/sbm_merge_clusters.csv", "merge_clusters"),
        ("results/sbm_change_clusters.csv", "change_clusters"),
        ("results/knn_mnist_deletions.csv", "mnist"),
        ("results/knn_emnist_deletions.csv", "emnist"),
    )
    csv_paths, column_labels = map(list, zip(*transposed_runs))
    plot_multiple_datasets2(csv_paths, labels=column_labels, filename="comp", logy=False)

    # Single-row figure without the ARI column for sbm_new_clusters_large.csv.
    plot_multiple_datasets(
        ["results/sbm_new_clusters_large.csv"],
        labels=[""],
        filename="static_comp",
        no_ari=True,
        logy=False,
    )


if __name__ == "__main__":
    main()
# %%
