import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
import pprint
import math
from load_all_result import load_all_result
import numpy as np

def extract_time_info(row):
    """Convert the millisecond timing fields of one result row into seconds.

    Args:
        row: A mapping-like object (e.g. a ``pandas.Series`` or a ``dict``)
            that supports ``key in row`` and ``row[key]``.

    Returns:
        A dict with one ``<name>_s`` entry (float, seconds) for every
        ``<name>_ms`` timing field that is present in ``row`` and holds a
        non-missing value. Fields that are absent or missing (NaN/None) are
        left out entirely, so callers can test for presence with ``in``.
    """
    time_columns = (
        "trained_xgboost_model_train_time_ms",
        "calibration_time_ms",
        "configuration_time_ms",
        "add_keys_time_ms",
    )

    res = {}
    for column in time_columns:
        # pd.isna (unlike np.isnan) also accepts None and other non-float
        # objects, so a missing value stored as None is skipped, not a crash.
        if column in row and not pd.isna(row[column]):
            seconds_key = column[: -len("_ms")] + "_s"
            res[seconds_key] = float(row[column]) / 1000.0
    return res


# Timing components in stacking order (bottom to top):
# (key produced by extract_time_info, bar colour, legend label).
_TIME_COMPONENTS = (
    ("add_keys_time_s", "blue", "Add Keys"),
    ("configuration_time_s", "green", "Configuration"),
    ("calibration_time_s", "red", "Scoring"),
    ("trained_xgboost_model_train_time_s", "orange", "XGBoost Training"),
)


def _select_row(df, conds, dataset):
    """Return the single row of ``df`` matching every condition, or None.

    ``conds`` maps column names to required values. The value given for
    ``trained_xgboost_folder_path`` is a template whose ``{dataset}``
    placeholder is filled in before comparing.

    Returns None when no row matches (including when a condition names a
    column that ``df`` does not have). Exits with status 1 when more than one
    row matches, because the choice of row would be arbitrary.
    """
    selected = df
    for cond_key, cond_val in conds.items():
        if cond_key not in selected.columns:
            # No loaded result carries this column, so nothing can match.
            return None
        if cond_key == "trained_xgboost_folder_path":
            cond_val = cond_val.format(dataset=dataset)
        selected = selected[selected[cond_key] == cond_val]

    if len(selected) == 0:
        return None
    if len(selected) > 1:
        print("[Warning] Multiple data found for the following condition:")
        pprint.pprint(conds)
        raise SystemExit(1)
    return selected.iloc[0]


def _draw_stacked_bar(ax, x_pos, time_info):
    """Draw one stacked bar at ``x_pos`` and return its total height in seconds."""
    bottom = 0
    for key, color, label in _TIME_COMPONENTS:
        if key in time_info:
            # zorder=3 keeps the bars above the grid (drawn with zorder=0).
            ax.bar(x_pos, time_info[key], bottom=bottom, color=color, label=label, zorder=3)
            bottom += time_info[key]
    return bottom


def _save_and_close(fig, fig_path):
    """Save ``fig`` to ``fig_path`` (creating its directory if any) and close it."""
    out_dir = os.path.dirname(fig_path)
    if out_dir:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(out_dir, exist_ok=True)
    print(f"=== Saving figure to {fig_path} ===")
    fig.savefig(fig_path, bbox_inches='tight')
    plt.close(fig)


def plot(result_df, vis_setting):
    """Save one stacked time-breakdown bar chart per dataset, plus a legend figure.

    Args:
        result_df: DataFrame of results. Must have a ``dataset`` column and
            the columns referenced by the targets' conditions; timing columns
            are read through ``extract_time_info``.
        vis_setting: Dict with
            ``targets_list``: list of dicts, each holding ``conds`` (column
            name -> required value, selecting exactly one row per dataset)
            and ``vis_info`` (with the tick ``label``);
            ``fig_path``: output path template containing ``{dataset}``.

    For every dataset a figure is written to ``fig_path`` with ``{dataset}``
    filled in. Targets with no matching row are reported and skipped. A
    separate legend-only figure is written to ``fig_path`` with ``{dataset}``
    replaced by ``legend``.

    Raises:
        SystemExit: If a target's conditions match more than one row.
    """
    for dataset in result_df["dataset"].unique():
        df_d = result_df[result_df["dataset"] == dataset]

        # NOTE(review): this is a substring test on the whole path template,
        # so a figure directory whose name contains "all" also gets the tall
        # layout — confirm that is acceptable.
        figsize = (12, 18) if "all" in vis_setting["fig_path"] else (8, 6)
        fig, ax = plt.subplots(figsize=figsize)
        ax.grid(True, which="both", ls="--", c='gray', zorder=0)

        plotted = []
        for target in vis_setting["targets_list"]:
            row = _select_row(df_d, target["conds"], dataset)
            if row is None:
                print("[Warning] No data found for the following condition:")
                pprint.pprint(target["conds"])
                continue
            plotted.append((target, extract_time_info(row)))

        for target_idx, (target, time_info) in enumerate(plotted):
            total = _draw_stacked_bar(ax, target_idx, time_info)
            print(target['vis_info']['label'].replace('\n', ' '), f": {total:.2f} s")

        plt.ylabel("Time [s]", fontsize=22)
        # Tick labels come from the targets that were actually drawn, so a
        # skipped target cannot shift the labels onto the wrong bars.
        plt.xticks(
            range(len(plotted)),
            [target["vis_info"]["label"] for target, _ in plotted],
            rotation=90,
            fontsize=20,
        )
        plt.yticks(fontsize=16)

        for side in ('right', 'left', 'top', 'bottom'):
            ax.spines[side].set_color('black')

        _save_and_close(fig, vis_setting["fig_path"].format(dataset=dataset))

    # Legend-only figure, listed top-to-bottom in the same order as the stack.
    legend_entries = [(color, label) for _, color, label in reversed(_TIME_COMPONENTS)]
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color, _ in legend_entries]
    legend_labels = [label for _, label in legend_entries]
    fig, ax = plt.subplots(figsize=(3.5, 2.0))
    ax.axis('off')
    ax.legend(legend_handles, legend_labels, loc="center", fontsize=16)
    _save_and_close(fig, vis_setting["fig_path"].replace("{dataset}", "legend"))

# Values of the result columns used to select which rows get plotted.
F = 0.001
BIT_SIZE_OF_ADA_BF = 1600000
PHBF_BITARRAY_SIZE = 1600000
HABF_BITS_PER_KEY = 15
POS_QUERY_RATIO = 0.0
QUERY_NUM = 40000

# num_boost_round values of the trained XGBoost models; shown as "D=..." in labels.
BOOST_ROUNDS = (100, 10, 1)
# hash_count values plotted for PHBF; shown as "k=..." in labels.
PHBF_HASH_COUNTS = (30, 20, 10)


def _xgboost_folder(models_dir, max_depth, num_boost_round):
    """Return the model-folder path template; ``{dataset}`` is left unfilled."""
    folder = f"{{dataset}}/xgboost/max_depth_{max_depth}_num_boost_round_{num_boost_round}_eta_03"
    return os.path.join(models_dir, folder)


def _target(label, model_type, leading=None, trailing=None):
    """Build one target dict: row-selection conditions plus its tick label.

    ``leading`` / ``trailing`` are extra conditions placed before / after the
    ``pos_query_ratio`` and ``query_num`` conditions shared by every target.
    """
    conds = {"model_type": model_type}
    conds.update(leading or {})
    conds["pos_query_ratio"] = POS_QUERY_RATIO
    conds["query_num"] = QUERY_NUM
    conds.update(trailing or {})
    return {"conds": conds, "vis_info": {"label": label}}


def _xgboost_family(models_dir, max_depth, model_type, label_prefix, leading):
    """Return one target per entry of BOOST_ROUNDS for an XGBoost-backed model type."""
    return [
        _target(
            f"{label_prefix}\n(D={num_boost_round})",
            model_type,
            leading,
            {"trained_xgboost_folder_path": _xgboost_folder(models_dir, max_depth, num_boost_round)},
        )
        for num_boost_round in BOOST_ROUNDS
    ]


def _clbf_target(models_dir, max_depth):
    """Return the CLBF target, which always uses the 100-round XGBoost model."""
    return _target(
        "CLBF\n(Proposed)",
        "clbf",
        {"F": F, "lambda": 1.0, "mu": 0.0},
        {"trained_xgboost_folder_path": _xgboost_folder(models_dir, max_depth, 100)},
    )


def _bloom_target():
    """Return the plain Bloom filter target."""
    return _target("Bloom Filter", "bloom", {"F": F})


def _basic_targets(models_dir, max_depth):
    """Return the targets of the compact figure: CLBF, PLBF, Sandwiched LBF, Bloom."""
    return (
        [_clbf_target(models_dir, max_depth)]
        + _xgboost_family(models_dir, max_depth, "plbf", "Fast PLBF", {"F": F})
        + _xgboost_family(models_dir, max_depth, "sandwichedlbf", "Sandwiched LBF", {"F": F})
        + [_bloom_target()]
    )


def _all_targets(models_dir, max_depth):
    """Return the targets of the "all" figure, adding Ada-BF, HABF and PHBF."""
    return (
        [_clbf_target(models_dir, max_depth)]
        + _xgboost_family(models_dir, max_depth, "plbf", "Fast PLBF", {"F": F})
        + _xgboost_family(
            models_dir, max_depth, "disjointadabf", "disjoint Ada-BF",
            {"bit_size_of_Ada_BF": BIT_SIZE_OF_ADA_BF},
        )
        + _xgboost_family(models_dir, max_depth, "sandwichedlbf", "Sandwiched LBF", {"F": F})
        + [_bloom_target()]
        + [_target("HABF", "habf", {"bits_per_key": HABF_BITS_PER_KEY})]
        + [
            _target(
                f"PHBF\n(k={hash_count})",
                "phbf",
                {"bitarray_size": PHBF_BITARRAY_SIZE},
                {"hash_count": hash_count},
            )
            for hash_count in PHBF_HASH_COUNTS
        ]
    )


def build_vis_settings(models_dir, fig_dir):
    """Return the list of figure settings consumed by ``plot``.

    The list holds, in order: the compact figure and the "all" figure (both
    using the max_depth=4 XGBoost models), then one compact figure for each
    max_depth in (1, 2, 4, 6). Every ``fig_path`` and every
    ``trained_xgboost_folder_path`` keeps a literal ``{dataset}`` placeholder.
    """
    def fig_path(suffix):
        return os.path.join(fig_dir, f"time_hist/time_hist_{{dataset}}{suffix}.pdf")

    vis_settings = [
        {"targets_list": _basic_targets(models_dir, 4), "fig_path": fig_path("")},
        {"targets_list": _all_targets(models_dir, 4), "fig_path": fig_path("_all")},
    ]
    for max_depth in (1, 2, 4, 6):
        vis_settings.append(
            {
                "targets_list": _basic_targets(models_dir, max_depth),
                "fig_path": fig_path(f"_max_depth_{max_depth}"),
            }
        )
    return vis_settings


def main():
    """Parse the command line, load all results and write every figure."""
    parser = argparse.ArgumentParser(description="Plot results and save figures.")
    parser.add_argument(
        "--models_dir", type=str, default="models",
        help="Directory the trained models were saved to (used to match result rows). Defaults to 'models'.",
    )
    parser.add_argument(
        "--results_dir", type=str, default="results",
        help="Directory to load the results from. Defaults to 'results'.",
    )
    parser.add_argument(
        "--fig_dir", type=str, default="fig",
        help="Directory to save the figures. Defaults to 'fig'.",
    )
    args = parser.parse_args()

    result_df = load_all_result(args.results_dir)
    for vis_setting in build_vis_settings(args.models_dir, args.fig_dir):
        plot(result_df, vis_setting)


if __name__ == "__main__":
    main()
