import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
import pprint
import math
from load_all_result import load_all_result

def _ensure_parent_dir(path):
    """Create the parent directory of *path* if it has one.

    ``os.makedirs("")`` raises, so a bare file name (no directory part) is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _logscale_path(path):
    """Return *path* with ``_logscale`` inserted just before its file extension.

    Only the extension is touched, so dots elsewhere in the path (e.g. ``./fig``)
    are preserved.
    """
    root, ext = os.path.splitext(path)
    return root + "_logscale" + ext


def _select_rows(df_d, target, dataset):
    """Return the rows of *df_d* (one dataset's results) that belong to *target*'s series.

    Each ``target["conds"]`` entry is applied as a filter:
    - ``trained_xgboost_folder_path``: exact match after substituting ``{dataset}``;
    - ``trained_xgboost_folder_path_contain``: literal substring of ``trained_xgboost_folder_path``;
    - list value: the column value must be one of the entries;
    - tuple value: inclusive ``(low, high)`` range;
    - anything else: equality.
    The result is sorted by ``lambda`` then ``trained_xgboost_num_boost_round``.
    """
    df_dm = df_d
    for cond_key, cond_val in target["conds"].items():
        if cond_key == "trained_xgboost_folder_path":
            df_dm = df_dm[df_dm[cond_key] == cond_val.format(dataset=dataset)]
        elif cond_key == "trained_xgboost_folder_path_contain":
            # na=False: rows without a folder path must not turn the mask into NaN.
            df_dm = df_dm[
                df_dm["trained_xgboost_folder_path"].str.contains(cond_val, regex=False, na=False)
            ]
        elif isinstance(cond_val, list):
            df_dm = df_dm[df_dm[cond_key].isin(cond_val)]
        elif isinstance(cond_val, tuple):
            low, high = cond_val
            df_dm = df_dm[df_dm[cond_key].between(low, high)]
        else:
            df_dm = df_dm[df_dm[cond_key] == cond_val]

    # NOTE(review): hard-coded per-dataset lower bound on PHBF bitarray_size;
    # the rationale is not visible in this file.
    if target["conds"]["model_type"] == "phbf":
        min_bits = {"url": 1600000, "ember": 3200000}.get(dataset)
        if min_bits is not None:
            df_dm = df_dm[df_dm["bitarray_size"] >= min_bits]

    return df_dm.sort_values(by=["lambda", "trained_xgboost_num_boost_round"])


def _draw_series(ax, xs, ys, vis_info):
    """Draw one series on *ax*: a scatter when ``vis_info["linestyle"]`` is None, else a line plot."""
    style = {
        "label": vis_info["label"],
        "marker": vis_info["marker"],
        "color": vis_info["color"],
        "alpha": vis_info.get("alpha", 1.0),
    }
    if vis_info["linestyle"] is None:
        # scatter's `s` is a marker area (points**2), not a diameter, hence the rescaling.
        ax.scatter(xs, ys, s=vis_info["markersize"] * 7.5, **style)
    else:
        ax.plot(xs, ys, markersize=vis_info["markersize"], linestyle=vis_info["linestyle"], **style)


def _save_legend(vis_setting):
    """Render only the legend of *vis_setting* and save it with ``{dataset}`` replaced by ``legend``."""
    fig, ax = plt.subplots(figsize=(1, 1))
    for target in vis_setting["targets_list"]:
        vis_info = target["vis_info"]
        linestyle = vis_info["linestyle"]
        ax.plot(
            [],
            [],
            label=vis_info["label"],
            marker=vis_info["marker"],
            markersize=vis_info["markersize"],
            # Scatter-only series (linestyle None) must not get a line in the legend;
            # ax.plot(linestyle=None) would fall back to the default solid line.
            linestyle="none" if linestyle is None else linestyle,
            color=vis_info["color"],
            alpha=vis_info.get("alpha", 1.0),
        )
    ax.axis('off')
    legend = ax.legend(loc='center', bbox_to_anchor=(0.5, 0.5), ncol=1, fontsize=18)
    # A draw is needed before the legend's window extent is known.
    fig.canvas.draw()
    bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    bbox = bbox.expanded(1.05, 1.05)
    legend_path = vis_setting["fig_path"].replace("{dataset}", "legend")
    _ensure_parent_dir(legend_path)
    fig.savefig(legend_path, bbox_inches=bbox)
    plt.close(fig)


def plot(result_df, vis_setting, y_log_scale=False):
    """Plot memory usage vs. average reject time, one figure per dataset, plus a legend figure.

    Args:
        result_df: results table with at least the columns ``dataset``, ``model_size_kb``,
            ``test_time_ms``, ``lambda``, ``trained_xgboost_num_boost_round`` and every
            column named in the targets' ``conds`` (``bitarray_size`` for PHBF targets).
        vis_setting: dict with
            - ``targets_list``: list of ``{"conds": {...}, "vis_info": {...}}``. ``conds``
              selects the rows of one series (must include ``model_type`` and ``query_num``).
              ``vis_info`` holds ``label``, ``marker``, ``markersize``, ``linestyle``
              (``None`` draws a scatter), ``color`` and an optional ``alpha``.
            - ``fig_path``: output path template containing ``{dataset}``. A file name
              containing "all" gets a taller (6.4 x 9.6 in) figure.
        y_log_scale: if True, use a log y axis and add ``_logscale`` before the file extension.
    """
    # Must be set before the axes are created: axes read this rcParam when they are built.
    plt.rcParams['axes.axisbelow'] = True
    tall = "all" in os.path.basename(vis_setting["fig_path"])

    for dataset in result_df["dataset"].unique():
        df_d = result_df[result_df["dataset"] == dataset]

        fig, ax = plt.subplots(figsize=(6.4, 9.6) if tall else None)
        ax.grid(True, which="both", ls="--", c='gray')

        for target in vis_setting["targets_list"]:
            conds = target["conds"]
            df_dm = _select_rows(df_d, target, dataset)
            query_num = conds["query_num"]

            xs = df_dm["model_size_kb"].tolist()
            # ms -> ns, averaged over query_num queries (matches the y-axis label).
            ys = (df_dm["test_time_ms"] * 1000000 / query_num).tolist()

            print("Dataset: ", dataset, ", Model Type: ", conds["model_type"])
            print("Ds: ", df_dm["trained_xgboost_num_boost_round"].tolist())
            print("lambdas: ", df_dm["lambda"].tolist())
            print("xs: ", xs)
            print("ys: ", ys)

            _draw_series(ax, xs, ys, target["vis_info"])

        ax.set_xlabel("Memory Usage [kB]", fontsize=22)
        ax.set_ylabel("Average Reject Time [ns]", fontsize=22)
        if y_log_scale:
            ax.set_yscale("log")
        # Applied after the scale change so log-scale (incl. minor) tick labels match too.
        ax.tick_params(axis="both", which="both", labelsize=16)

        for side in ("right", "left", "top", "bottom"):
            ax.spines[side].set_color('black')

        fig_path = vis_setting["fig_path"].format(dataset=dataset)
        if y_log_scale:
            fig_path = _logscale_path(fig_path)
        _ensure_parent_dir(fig_path)
        print(f"Saving figure to {fig_path}")
        fig.savefig(fig_path, bbox_inches='tight')
        plt.close(fig)

    _save_legend(vis_setting)

def _vis(label, marker, markersize, linestyle, color, alpha=None):
    """Build a ``vis_info`` dict for ``plot``; ``alpha`` is only included when given."""
    vis_info = {
        "label": label,
        "marker": marker,
        "markersize": markersize,
        "linestyle": linestyle,
        "color": color,
    }
    if alpha is not None:
        vis_info["alpha"] = alpha
    return vis_info


def _core_targets(models_dir, F, max_depth, pos_query_ratio, query_num, num_boost_rounds):
    """Return the CLBF, Bloom filter, PLBF and Sandwiched-LBF series for one XGBoost max_depth."""
    depth_tag = f"max_depth_{max_depth}_"
    return [
        {
            "conds": {
                "model_type": "clbf",
                "F": F,
                "mu": 0.0,
                "pos_query_ratio": pos_query_ratio,
                "query_num": query_num,
                # Exact folder match (100-round model only); plot() substitutes "{dataset}".
                "trained_xgboost_folder_path": os.path.join(
                    models_dir, "{dataset}/xgboost/" + depth_tag + "num_boost_round_100_eta_03"
                ),
            },
            "vis_info": _vis("CLBF (Proposed)", "o", 6.5, "-", "tab:red"),
        },
        {
            "conds": {
                "model_type": "bloom",
                "F": F,
                "pos_query_ratio": pos_query_ratio,
                "query_num": query_num,
            },
            "vis_info": _vis("Bloom Filter", "s", 6.5, "-", "tab:blue"),
        },
        {
            "conds": {
                "model_type": "plbf",
                "F": F,
                "pos_query_ratio": pos_query_ratio,
                "query_num": query_num,
                "trained_xgboost_folder_path_contain": depth_tag,
                "trained_xgboost_num_boost_round": list(num_boost_rounds),
            },
            "vis_info": _vis("PLBF", "D", 6.5, "-.", "tab:green"),
        },
        {
            "conds": {
                "model_type": "sandwichedlbf",
                "F": F,
                "pos_query_ratio": pos_query_ratio,
                "query_num": query_num,
                "trained_xgboost_folder_path_contain": depth_tag,
                "trained_xgboost_num_boost_round": list(num_boost_rounds),
            },
            "vis_info": _vis("Sandwiched LBF", "^", 6.5, ":", "tab:brown"),
        },
    ]


def _extra_targets(F, max_depth, pos_query_ratio, query_num, num_boost_rounds):
    """Return the disjoint Ada-BF, HABF and PHBF (k = 30/20/10) series shown only in the "_all" figures."""
    targets = [
        {
            "conds": {
                "model_type": "disjointadabf",
                # Ada-BF is selected by its measured FPR lying within +/-25% of F.
                "fpr": (F * 0.75, F * 1.25),
                "pos_query_ratio": pos_query_ratio,
                "query_num": query_num,
                "trained_xgboost_folder_path_contain": f"max_depth_{max_depth}_",
                "trained_xgboost_num_boost_round": list(num_boost_rounds),
            },
            # linestyle None -> drawn as a scatter by plot().
            "vis_info": _vis("disjoint Ada-BF", "X", 8, None, "tab:purple"),
        },
        {
            "conds": {
                "model_type": "habf",
                "bits_per_key": 15,
                "pos_query_ratio": pos_query_ratio,
                "query_num": query_num,
            },
            "vis_info": _vis("HABF", ">", 6.5, "-", "tab:orange"),
        },
    ]
    phbf_styles = [
        (30, _vis("PHBF (k=30)", "v", 8, "-", "tab:pink")),
        (20, _vis("PHBF (k=20)", "v", 6, "-.", "tab:pink", alpha=0.8)),
        (10, _vis("PHBF (k=10)", "v", 4.0, ":", "tab:pink", alpha=0.6)),
    ]
    for hash_count, vis_info in phbf_styles:
        targets.append(
            {
                "conds": {
                    "model_type": "phbf",
                    "pos_query_ratio": pos_query_ratio,
                    "query_num": query_num,
                    "hash_count": hash_count,
                },
                "vis_info": vis_info,
            }
        )
    return targets


def build_vis_settings(
    models_dir,
    fig_dir,
    F=0.001,
    max_depths=(1, 2, 4, 6),
    pos_query_ratios=(0.0,),
    query_num=40000,
    num_boost_rounds=(1, 2, 5, 10, 20, 50, 100),
    headline_max_depth=4,
):
    """Return the list of ``vis_setting`` dicts consumed by ``plot``.

    The list holds one headline figure (max_depth ``headline_max_depth``,
    pos_query_ratio 0.0). Then, for every (pos_query_ratio, max_depth) pair, it
    holds a core-comparison figure and an ``_all`` figure that adds Ada-BF,
    HABF and PHBF.

    Args:
        models_dir: directory the trained models were saved under; joined into the
            exact CLBF folder path compared against ``trained_xgboost_folder_path``.
        fig_dir: root directory for the output figures.
        F: target false-positive rate used to select results.
        max_depths: XGBoost max_depth values to produce figures for.
        pos_query_ratios: positive-query ratios to produce figures for. Ratios other
            than 0.0 get a ``_pos_query_ratio...`` file-name tag so figures do not
            overwrite each other.
        query_num: number of queries per timing run, used as a row filter.
        num_boost_rounds: boosting rounds accepted for PLBF / Sandwiched LBF / Ada-BF.
        headline_max_depth: max_depth of the headline figure.
    """
    fig_prefix = os.path.join(fig_dir, "memory_query_time/memory_query_time_{dataset}")
    vis_settings = [
        {
            "targets_list": _core_targets(
                models_dir, F, headline_max_depth, 0.0, query_num, num_boost_rounds
            ),
            "fig_path": fig_prefix + ".pdf",
        }
    ]
    for pos_query_ratio in pos_query_ratios:
        pos_tag = (
            ""
            if pos_query_ratio == 0.0
            else "_pos_query_ratio" + str(pos_query_ratio).replace(".", "")
        )
        for max_depth in max_depths:
            stem = f"{fig_prefix}_max_depth{max_depth}{pos_tag}"
            vis_settings.append(
                {
                    "targets_list": _core_targets(
                        models_dir, F, max_depth, pos_query_ratio, query_num, num_boost_rounds
                    ),
                    "fig_path": stem + ".pdf",
                }
            )
            vis_settings.append(
                {
                    "targets_list": _core_targets(
                        models_dir, F, max_depth, pos_query_ratio, query_num, num_boost_rounds
                    )
                    + _extra_targets(F, max_depth, pos_query_ratio, query_num, num_boost_rounds),
                    "fig_path": stem + "_all.pdf",
                }
            )
    return vis_settings


def main():
    """Parse CLI arguments, load all results and render every memory / query-time figure."""
    parser = argparse.ArgumentParser(description="Plot results and save figures.")
    parser.add_argument("--models_dir", type=str, default="models", help="Directory the trained models were saved to (must match the paths recorded in the results). Defaults to 'models'.", required=False)
    parser.add_argument("--results_dir", type=str, default="results", help="Directory containing the experiment results. Defaults to 'results'.", required=False)
    parser.add_argument("--fig_dir", type=str, default="fig", help="Directory to save the figures. Defaults to 'fig'.", required=False)
    args = parser.parse_args()

    result_df = load_all_result(args.results_dir)
    for vis_setting in build_vis_settings(args.models_dir, args.fig_dir):
        plot(result_df, vis_setting, y_log_scale=True)


if __name__ == "__main__":
    main()
