import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
import pprint
import math
from load_all_result import load_all_result
import numpy as np

# Construction phases whose "<phase>_ms" columns are converted to seconds.
# The order here is the key order of the dict returned by extract_time_info.
_TIME_PHASES = (
    "trained_xgboost_model_train_time",
    "calibration_time",
    "configuration_time",
    "add_keys_time",
)


def extract_time_info(row):
    """Collect the construction-phase timings of one result row, in seconds.

    For every phase ``p`` in ``_TIME_PHASES`` whose column ``f"{p}_ms"`` is
    present in ``row`` and holds a non-missing value, the result contains
    ``f"{p}_s"`` = value / 1000.

    Args:
        row: Mapping-like record indexed by column name, e.g. a
            ``pandas.Series`` yielded by ``DataFrame.iterrows`` or a dict.

    Returns:
        dict[str, float]: Seconds per phase. Phases that are absent or
        missing (NaN / None / pd.NA) are omitted.
    """
    res = {}
    for phase in _TIME_PHASES:
        key_ms = f"{phase}_ms"
        if key_ms not in row:
            continue
        value = row[key_ms]
        # pd.isna covers NaN as well as None/pd.NA; np.isnan raises TypeError on None.
        if pd.isna(value):
            continue
        res[f"{phase}_s"] = float(value) / 1000.0
    return res

def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one."""
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip bare file names
        os.makedirs(parent, exist_ok=True)


def _line_style_kwargs(vis_info):
    """Build the matplotlib ``plot`` keyword arguments from a target's ``vis_info``."""
    return {
        "label": vis_info["label"],
        "marker": vis_info["marker"],
        "markersize": vis_info["markersize"],
        "linestyle": vis_info["linestyle"],
        "color": vis_info["color"],
        "alpha": vis_info.get("alpha", 1.0),
    }


def _select_target_rows(df_d, conds, dataset):
    """Return the rows of ``df_d`` that satisfy every condition in ``conds``."""
    df_dm = df_d
    for cond_key, cond_val in conds.items():
        if cond_key == "trained_xgboost_folder_path":
            # The folder path condition is a template parameterized by the dataset name.
            df_dm = df_dm[df_dm[cond_key] == cond_val.format(dataset=dataset)]
        elif cond_key == "max_depth":
            # max_depth is not a column: it is matched as a literal substring of the
            # folder path. Rows with a missing path are treated as non-matching
            # (otherwise the NaN mask would make boolean indexing raise).
            mask = df_dm["trained_xgboost_folder_path"].str.contains(
                cond_val, regex=False, na=False
            )
            df_dm = df_dm[mask]
        else:
            df_dm = df_dm[df_dm[cond_key] == cond_val]
    return df_dm


def _construction_times(df_dm, d_step):
    """Return ``(rounds, total_seconds)`` sorted by round, keeping rounds divisible by ``d_step``."""
    df_sorted = df_dm.sort_values(by="trained_xgboost_num_boost_round")
    rounds = df_sorted["trained_xgboost_num_boost_round"]
    ds, totals = [], []
    for d, (_, row) in zip(rounds, df_sorted.iterrows()):
        if d % d_step != 0:
            continue
        ds.append(d)
        totals.append(sum(extract_time_info(row).values()))
    return ds, totals


def _save_legend_only(targets, legend_path):
    """Save a standalone figure containing only the legend for ``targets``."""
    fig, ax = plt.subplots(figsize=(1, 1))
    for target in targets:
        # Empty lines: they exist only to populate the legend entries.
        ax.plot([], [], **_line_style_kwargs(target["vis_info"]))
    ax.axis('off')
    legend = ax.legend(loc='center', bbox_to_anchor=(0.5, 0.5), ncol=1, fontsize=18)
    # The legend's extent is only known after the canvas has been drawn.
    fig.canvas.draw()
    bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    bbox = bbox.expanded(1.05, 1.05)
    _ensure_parent_dir(legend_path)
    fig.savefig(legend_path, bbox_inches=bbox)
    plt.close(fig)


def plot(result_df, vis_setting):
    """Plot total construction time vs. boosting rounds, one figure per dataset.

    Args:
        result_df: DataFrame of results with at least the columns ``dataset``,
            ``trained_xgboost_num_boost_round``, ``trained_xgboost_folder_path``
            (when a ``max_depth`` condition is used), the columns named in each
            target's ``conds``, and the timing columns read by
            ``extract_time_info``.
        vis_setting: dict with
            - ``targets_list``: list of ``{"conds": {...}, "vis_info": {...}}``;
              ``vis_info`` needs ``label``, ``marker``, ``markersize``,
              ``linestyle``, ``color`` and optionally ``alpha``.
            - ``fig_path``: output path template containing ``{dataset}``.
            - ``d_step`` (optional, default 10): only rounds divisible by it
              are plotted.

    Side effects:
        Writes one figure per dataset to ``fig_path.format(dataset=...)`` and
        a legend-only figure to ``fig_path`` with ``{dataset}`` -> ``legend``.
    """
    d_step = vis_setting.get("d_step", 10)
    targets = vis_setting["targets_list"]

    for dataset in result_df["dataset"].unique():
        df_d = result_df[result_df["dataset"] == dataset]

        # axes.axisbelow is read when the axes are created, so set it first.
        plt.rcParams['axes.axisbelow'] = True
        fig, ax = plt.subplots()
        ax.grid(True, which="both", ls="--", c='gray')

        for target in targets:
            df_dm = _select_target_rows(df_d, target["conds"], dataset)
            ds, totals = _construction_times(df_dm, d_step)
            ax.plot(ds, totals, **_line_style_kwargs(target["vis_info"]))

        ax.set_xlabel(r"$D$, $\bar{D}$", fontsize=22)
        ax.set_ylabel("Construction Time (s)", fontsize=22)
        for side in ('right', 'left', 'top', 'bottom'):
            ax.spines[side].set_color('black')

        fig_path = vis_setting["fig_path"].format(dataset=dataset)
        _ensure_parent_dir(fig_path)
        print(f"=== Saving figure to {fig_path} ===")
        fig.savefig(fig_path, bbox_inches='tight')
        plt.close(fig)

    _save_legend_only(targets, vis_setting["fig_path"].replace("{dataset}", "legend"))

if __name__ == "__main__":
    # Plot construction time vs. number of boosting rounds for each target FPR F.
    parser = argparse.ArgumentParser(description="Plot results and save figures.")
    # Kept for CLI compatibility; this script does not read the models directory.
    parser.add_argument("--models_dir", type=str, default="models", help="Directory containing the trained models (unused by this script). Defaults to 'models'.")
    parser.add_argument("--results_dir", type=str, default="results", help="Directory containing the results to plot. Defaults to 'results'.")
    parser.add_argument("--fig_dir", type=str, default="fig", help="Directory to save the figures. Defaults to 'fig'.")
    args = parser.parse_args()

    result_df = load_all_result(args.results_dir)

    # (method-specific conditions, line style) for each plotted method.
    methods = [
        (
            {"model_type": "clbf", "lambda": 1.0, "mu": 0.0},
            {"label": "CLBF (Proposed)", "marker": "o", "markersize": 6.5, "linestyle": "-", "color": "tab:red"},
        ),
        (
            {"model_type": "plbf"},
            {"label": "PLBF", "marker": "D", "markersize": 8, "linestyle": "-.", "color": "tab:green"},
        ),
        (
            {"model_type": "sandwichedlbf"},
            {"label": "Sandwiched LBF", "marker": "^", "markersize": 8, "linestyle": ":", "color": "tab:brown"},
        ),
    ]

    vis_settings = []
    for F in [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]:
        F_str = f"{F}".replace(".", "_")
        # Conditions shared by every method for this F.
        shared_conds = {
            "pos_query_ratio": 0.0,
            "query_num": 40000,
            "F": F,
            "max_depth": "max_depth_4_",
        }
        vis_settings.append({
            "targets_list": [
                {"conds": {**method_conds, **shared_conds}, "vis_info": vis_info}
                for method_conds, vis_info in methods
            ],
            # "{dataset}" stays a literal placeholder; plot() fills it per dataset.
            "fig_path": os.path.join(
                fig_dir_path := args.fig_dir,
                "D_construction_time",
                "D_construction_time_{dataset}" + f"_F_{F_str}.pdf",
            ) if False else os.path.join(
                args.fig_dir,
                "D_construction_time",
                "D_construction_time_{dataset}" + f"_F_{F_str}.pdf",
            ),
        })

    for vis_setting in vis_settings:
        plot(result_df, vis_setting)
