import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
import pprint
import math
from load_all_result import load_all_result

def _select_rows(df, conds, dataset):
    """Return the rows of `df` whose columns equal every value in `conds`."""
    for cond_key, cond_val in conds.items():
        if cond_key == "trained_xgboost_folder_path":
            # This condition is a template; fill in the dataset name first.
            cond_val = cond_val.format(dataset=dataset)
        df = df[df[cond_key] == cond_val]
    return df


def _line_kwargs(vis_info):
    """Translate a `vis_info` dict into keyword arguments for `plot`."""
    return {
        "label": vis_info["label"],
        "marker": vis_info["marker"],
        "markersize": vis_info["markersize"],
        "linestyle": vis_info["linestyle"],
        "color": vis_info["color"],
        "alpha": vis_info.get("alpha", 1.0),
    }


def _ensure_parent_dir(path):
    """Create the parent directory of `path` unless `path` is a bare file name."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip bare file names.
    if parent:
        os.makedirs(parent, exist_ok=True)


def plot(result_df, vis_setting):
    """Draw one memory-vs-FPR figure per dataset, plus a legend-only figure.

    Args:
        result_df: DataFrame with at least the columns "dataset",
            "model_size_kb" and "fpr", plus every column named in the
            targets' "conds" (and in "target", when given).
        vis_setting: dict with two keys:
            "fig_path": output path template containing "{dataset}".
                When the file name contains "all", a taller canvas
                (6.4 x 8.0) is used.
            "targets_list": list of dicts, one per curve, each holding
                "conds" (column -> required value), "vis_info" (label,
                marker, markersize, linestyle, color, optional alpha) and
                optionally "target" (column plotted on the y axis instead
                of "fpr").

    Side effects:
        Saves one figure per dataset to `fig_path` formatted with the
        dataset name, and one legend-only figure where "{dataset}" is
        replaced by "legend". Missing output directories are created.
    """
    fig_path_template = vis_setting["fig_path"]
    targets = vis_setting["targets_list"]
    # Only look at the file name so that a directory such as "all_figs"
    # does not switch every figure to the tall layout.
    use_tall_canvas = "all" in os.path.basename(fig_path_template)

    for dataset in result_df["dataset"].unique():
        df_d = result_df[result_df["dataset"] == dataset]

        # Must be set before the axes exists: the setting is picked up when
        # an axes is created, not when the grid is drawn.
        plt.rcParams['axes.axisbelow'] = True
        if use_tall_canvas:
            fig, ax = plt.subplots(figsize=(6.4, 8.0))
        else:
            fig, ax = plt.subplots()
        plt.grid(True, which="both", ls="--", c='gray')

        for target in targets:
            df_dm = _select_rows(df_d, target["conds"], dataset)
            df_dm = df_dm.sort_values(by="model_size_kb")
            plt.plot(
                df_dm["model_size_kb"],
                # Targets may override the y column; "fpr" is the default.
                df_dm[target.get("target", "fpr")],
                **_line_kwargs(target["vis_info"])
            )

        plt.xlabel("Memory Usage [kB]", fontsize=22)
        plt.ylabel("False Positive Rate", fontsize=22)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.yscale("log")
        plt.xlim(0, None)

        for side in ('right', 'left', 'top', 'bottom'):
            ax.spines[side].set_color('black')

        fig_path = fig_path_template.format(dataset=dataset)
        _ensure_parent_dir(fig_path)
        print(f"=== Saving figure to {fig_path} ===")
        plt.savefig(fig_path, bbox_inches='tight')
        plt.close(fig)

    # Legend-only figure: plot empty series so that only the handles exist.
    legend_fig, legend_ax = plt.subplots(figsize=(1, 1))
    for target in targets:
        legend_ax.plot([], [], **_line_kwargs(target["vis_info"]))
    legend_ax.axis('off')
    legend = legend_ax.legend(loc='center', bbox_to_anchor=(0.5, 0.5), ncol=1, fontsize=18)
    # The legend extent is only known after a draw; crop the output to it.
    legend_fig.canvas.draw()
    bbox = legend.get_window_extent().transformed(legend_fig.dpi_scale_trans.inverted())
    bbox = bbox.expanded(1.05, 1.05)
    legend_path = fig_path_template.replace("{dataset}", "legend")
    # The per-dataset loop may not have run (empty result_df), so make sure
    # the output directory exists here as well.
    _ensure_parent_dir(legend_path)
    legend_fig.savefig(legend_path, bbox_inches=bbox)
    plt.close(legend_fig)

def _xgboost_folder(models_dir, max_depth, num_boost_round):
    """Return the "{dataset}"-templated folder path of one trained XGBoost model."""
    return os.path.join(
        models_dir,
        f"{{dataset}}/xgboost/max_depth_{max_depth}_num_boost_round_{num_boost_round}_eta_03",
    )


def _make_target(model_type, label, marker, markersize, linestyle, color,
                 alpha=None, *, pre_conds=None, post_conds=None, target=None):
    """Build one curve description (filter conditions + line style) for `plot`.

    `pre_conds` / `post_conds` are inserted before / after the query
    conditions shared by every curve. "alpha" and "target" are only added
    to the result when given.
    """
    conds = {"model_type": model_type}
    if pre_conds:
        conds.update(pre_conds)
    conds["pos_query_ratio"] = 0.0
    conds["query_num"] = 40000
    if post_conds:
        conds.update(post_conds)

    vis_info = {
        "label": label,
        "marker": marker,
        "markersize": markersize,
        "linestyle": linestyle,
        "color": color,
    }
    if alpha is not None:
        vis_info["alpha"] = alpha

    entry = {"conds": conds, "vis_info": vis_info}
    if target is not None:
        entry["target"] = target
    return entry


def _bloom_target():
    """Curve for the plain Bloom filter baseline."""
    return _make_target("bloom", "Bloom Filter", "s", 6.5, "-", "tab:blue")


def _clbf_target(models_dir, max_depth=4, markersize=6.5, alpha=None):
    """Curve for CLBF built on the 100-round XGBoost model of the given depth."""
    return _make_target(
        "clbf", "CLBF (Proposed)", "o", markersize, "-", "tab:red", alpha,
        pre_conds={"lambda": 1.0, "mu": 0.0},
        post_conds={"trained_xgboost_folder_path": _xgboost_folder(models_dir, max_depth, 100)},
    )


def _learned_filter_targets(models_dir, model_type, label_prefix, marker, color,
                            max_depth=4, first_linestyle="-"):
    """Curves for one learned filter at D = 100, 10 and 1 boosting rounds."""
    # (num_boost_round, markersize, linestyle, alpha); smaller models are
    # drawn smaller and fainter.
    tiers = (
        (100, 8, first_linestyle, None),
        (10, 6.5, "-.", 0.8),
        (1, 4.0, ":", 0.6),
    )
    return [
        _make_target(
            model_type, f"{label_prefix} (D={num_boost_round})", marker,
            markersize, linestyle, color, alpha,
            post_conds={
                "trained_xgboost_folder_path": _xgboost_folder(models_dir, max_depth, num_boost_round)
            },
        )
        for num_boost_round, markersize, linestyle, alpha in tiers
    ]


def _phbf_targets():
    """Curves for PHBF at k = 30, 20 and 10 hash functions."""
    # (hash_count, markersize, linestyle, alpha); hash_count must match the
    # k shown in the label.
    tiers = (
        (30, 8, "-", None),
        (20, 6, "-.", 0.8),
        (10, 4.0, ":", 0.6),
    )
    return [
        _make_target(
            "phbf", f"PHBF (k={hash_count})", "v", markersize, linestyle,
            "tab:pink", alpha, post_conds={"hash_count": hash_count},
        )
        for hash_count, markersize, linestyle, alpha in tiers
    ]


def _habf_targets():
    """Curves for HABF: the default "fpr" column and the "fpr_on_val" column."""
    # NOTE(review): the second label says "Training Data" while the plotted
    # column is "fpr_on_val" - confirm that the column holds training-set FPR.
    return [
        _make_target("habf", "HABF", ">", 6.0, "-", "tab:orange"),
        _make_target(
            "habf", "(HABF FPR on Training Data)", ">", 6.0, ":", "tab:orange",
            0.8, target="fpr_on_val",
        ),
    ]


def build_vis_settings(models_dir, fig_dir):
    """Return the list of figure descriptions consumed by `plot`.

    Args:
        models_dir: directory prefix of the trained-model folder paths that
            are matched against the "trained_xgboost_folder_path" column.
        fig_dir: directory under which the figures are written.

    Returns:
        A list of dicts with keys "targets_list" and "fig_path", in this
        order: the main figure, the "all methods" figure, one figure per
        max_depth in (1, 2, 4, 6), and the max_depth ablation figure.
    """
    max_depths = (1, 2, 4, 6)

    def fig_path(suffix):
        return os.path.join(fig_dir, "memory_fpr/memory_fpr_{dataset}" + suffix + ".pdf")

    def plbf_targets(max_depth=4, first_linestyle="-"):
        return _learned_filter_targets(
            models_dir, "plbf", "PLBF", "D", "tab:green",
            max_depth=max_depth, first_linestyle=first_linestyle,
        )

    vis_settings = [
        {
            # Main figure. PLBF (D=100) is drawn dash-dotted here, unlike in
            # the other figures.
            "targets_list": [
                _clbf_target(models_dir),
                _bloom_target(),
                *plbf_targets(first_linestyle="-."),
            ],
            "fig_path": fig_path(""),
        },
        {
            # Comparison against every baseline.
            "targets_list": [
                _clbf_target(models_dir),
                _bloom_target(),
                *plbf_targets(),
                *_learned_filter_targets(
                    models_dir, "disjointadabf", "disjoint Ada-BF", "X", "tab:purple"
                ),
                *_learned_filter_targets(
                    models_dir, "sandwichedlbf", "Sandwiched LBF", "^", "tab:brown"
                ),
                *_phbf_targets(),
                *_habf_targets(),
            ],
            "fig_path": fig_path("_all"),
        },
    ]

    # One figure per XGBoost tree depth.
    for max_depth in max_depths:
        vis_settings.append(
            {
                "targets_list": [
                    _clbf_target(models_dir, max_depth=max_depth),
                    _bloom_target(),
                    *plbf_targets(max_depth=max_depth),
                ],
                "fig_path": fig_path(f"_max_depth_{max_depth}"),
            }
        )

    # Ablation figure: all depths overlaid, deeper trees drawn larger and
    # more opaque.
    max_depth_2_markersize_and_alpha_dict = {
        1: (3.0, 0.6),
        2: (4.0, 0.7),
        4: (5.0, 0.8),
        6: (6.5, 1.0),
    }
    target_list = [_bloom_target()]
    for max_depth in max_depths:
        markersize, alpha = max_depth_2_markersize_and_alpha_dict[max_depth]
        target_list.append(
            _clbf_target(models_dir, max_depth=max_depth, markersize=markersize, alpha=alpha)
        )
        target_list.append(
            _make_target(
                "plbf", "PLBF (D=10)", "D", markersize, "-.", "tab:green", alpha,
                post_conds={
                    "trained_xgboost_folder_path": _xgboost_folder(models_dir, max_depth, 10)
                },
            )
        )
    vis_settings.append(
        {
            "targets_list": target_list,
            "fig_path": fig_path("_max_depth_ablation"),
        }
    )

    return vis_settings


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Plot results and save figures.")
    parser.add_argument("--models_dir", type=str, default="models", help="Directory where the models were saved. Defaults to 'models'.", required=False)
    parser.add_argument("--results_dir", type=str, default="results", help="Directory to load the results from. Defaults to 'results'.", required=False)
    parser.add_argument("--fig_dir", type=str, default="fig", help="Directory to save the figures. Defaults to 'fig'.", required=False)
    args = parser.parse_args()

    result_df = load_all_result(args.results_dir)
    # Dump the combined table into the results directory that was actually
    # used, rather than a hard-coded "results" folder.
    result_df.to_csv(os.path.join(args.results_dir, "result_df.csv"))

    for vis_setting in build_vis_settings(args.models_dir, args.fig_dir):
        plot(result_df, vis_setting)
