import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
import pprint
import math
from load_all_result import load_all_result

def _filter_rows(df, conds, dataset, substring_max_depth=True):
    """Return the rows of ``df`` that satisfy every entry of ``conds``.

    Args:
        df: DataFrame to filter.
        conds: Mapping of condition key to required value.
        dataset: Dataset name substituted into the ``{dataset}`` placeholder
            of a ``"trained_xgboost_folder_path"`` condition.
        substring_max_depth: When True, a ``"max_depth"`` condition is matched
            as a pattern inside the ``trained_xgboost_folder_path`` column
            instead of being compared against a ``max_depth`` column.

    Returns:
        The filtered DataFrame.
    """
    for cond_key, cond_val in conds.items():
        if cond_key == "trained_xgboost_folder_path":
            df = df[df[cond_key] == cond_val.format(dataset=dataset)]
        elif substring_max_depth and cond_key == "max_depth":
            df = df[df["trained_xgboost_folder_path"].str.contains(cond_val)]
        else:
            df = df[df[cond_key] == cond_val]
    return df


def _split_model_sizes(df):
    """Return ``(ml_model_sizes, total_model_sizes)`` for the rows of ``df``.

    The ML-model part is ``model_size_kb`` minus ``bloom_filter_mem_sum`` when
    that column exists, otherwise the whole ``model_size_kb``. If either
    series contains missing values, ``trained_xgboost_model_size_kb`` is used
    for the ML-model part instead.
    """
    total_model_sizes = df["model_size_kb"]
    if "bloom_filter_mem_sum" in df.columns:
        ml_model_sizes = total_model_sizes - df["bloom_filter_mem_sum"]
    else:
        ml_model_sizes = total_model_sizes

    # Explicit check rather than ``assert`` + bare ``except``: asserts are
    # stripped under ``python -O``, which would silently disable the fallback.
    if total_model_sizes.isnull().any() or ml_model_sizes.isnull().any():
        ml_model_sizes = df["trained_xgboost_model_size_kb"]
    return ml_model_sizes, total_model_sizes


def _save_legend(vis_setting, with_clbf):
    """Save a figure that contains only the legend.

    The output path is the part of ``vis_setting["fig_path"]`` that precedes
    ``"_{dataset}_"`` with ``"_legend.pdf"`` appended.
    """
    alpha = 0.3
    legend_handles = []
    legend_labels = []
    for color, label in [("tab:orange", "ML Model"), ("tab:blue", "Bloom Filter")]:
        legend_handles.append(
            plt.Rectangle((0, 0), 1, 1, fc=color, edgecolor='none', alpha=alpha)
        )
        legend_labels.append(label)
    if with_clbf:
        legend_handles.append(
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor="tab:red", markersize=8)
        )
        legend_labels.append("CLBF (Proposed)")

    fig, ax = plt.subplots(figsize=(2.5, 1.0))
    ax.axis('off')
    ax.legend(legend_handles, legend_labels, loc="center", fontsize=16)
    fig_path = vis_setting["fig_path"].split("_{dataset}_")[0] + "_legend.pdf"
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    plt.savefig(fig_path, bbox_inches='tight')
    plt.close()


def plot(result_df, vis_setting):
    """Plot memory usage against the number of boosting rounds per dataset.

    For every dataset in ``result_df`` one figure is saved. Each target in
    ``vis_setting["targets_list"]`` contributes a stacked area plot (ML-model
    size below, remaining size up to ``model_size_kb`` above) and, when the
    target has ``"clbf_conds"``, a single red marker for the matching row.
    Afterwards a legend-only figure is saved.

    Args:
        result_df: DataFrame with at least the columns ``dataset``,
            ``trained_xgboost_num_boost_round`` and ``model_size_kb``, plus
            whatever columns the conditions refer to.
        vis_setting: Dict with
            ``"targets_list"``: list of dicts holding ``"conds"`` and
            optionally ``"clbf_conds"`` (condition key -> value), and
            ``"fig_path"``: output path template containing ``{dataset}``.

    Raises:
        ValueError: If no row matches a target's ``"conds"`` for a dataset.
        AssertionError: If ``"clbf_conds"`` does not select exactly one row.
    """
    targets = vis_setting["targets_list"]

    for dataset in result_df["dataset"].unique():
        df_d = result_df[result_df["dataset"] == dataset]

        plt.rcParams['axes.axisbelow'] = True
        plt.grid(True, which="both", ls="--", c='gray')

        for target in targets:
            df_dm = _filter_rows(df_d, target["conds"], dataset)
            if df_dm.empty:
                # Unpacking an empty zip below would raise an opaque
                # ValueError; raise the same type with a useful message.
                raise ValueError(
                    f"No rows match conds {target['conds']!r} for dataset {dataset!r}"
                )

            ml_model_sizes, total_model_sizes = _split_model_sizes(df_dm)

            # Sort the three series together by the x value.
            ds, ml_model_sizes, total_model_sizes = zip(*sorted(zip(
                df_dm["trained_xgboost_num_boost_round"],
                ml_model_sizes,
                total_model_sizes,
            )))

            plt.fill_between(
                ds,
                0,
                ml_model_sizes,
                label="XGBoost Model",
                color="tab:orange",
                alpha=0.3
            )
            plt.fill_between(
                ds,
                ml_model_sizes,
                total_model_sizes,
                label="Bloom Filter",
                color="tab:blue",
                alpha=0.3
            )
            for sizes, color in (
                (ml_model_sizes, "tab:orange"),
                (total_model_sizes, "tab:blue"),
            ):
                plt.plot(
                    ds,
                    sizes,
                    marker="x",
                    markersize=6.5,
                    linestyle="-.",
                    color=color,
                )

            # CLBF
            if "clbf_conds" not in target:
                continue

            # "max_depth" gets no special treatment for the CLBF conditions.
            clbf_df = _filter_rows(
                df_d, target["clbf_conds"], dataset, substring_max_depth=False
            )
            assert len(clbf_df) == 1, (
                f"expected exactly one CLBF row for dataset {dataset!r}, "
                f"got {len(clbf_df)}"
            )
            clbf_total_model_size = clbf_df["model_size_kb"].values[0]
            clbf_d = clbf_df["d"].values[0]

            plt.plot(
                [clbf_d], [clbf_total_model_size],
                marker="o",
                markersize=8,
                color="tab:red",
            )

        if "/plbf/" in vis_setting["fig_path"]:
            plt.xlabel(r"$D$", fontsize=22)
        if "/clbf/" in vis_setting["fig_path"]:
            plt.xlabel(r"$\bar{D}$", fontsize=22)
        plt.ylabel("Memory Usage [kB]", fontsize=22)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.xlim(0, None)
        plt.ylim(0, None)

        ax = plt.gca()
        for side in ['right', 'left', 'top', 'bottom']:
            ax.spines[side].set_color('black')

        fig_path = vis_setting["fig_path"].format(dataset=dataset)
        os.makedirs(os.path.dirname(fig_path), exist_ok=True)
        plt.savefig(fig_path, bbox_inches='tight')
        plt.close()

    # Only plot the legend. Decide on the CLBF entry from the settings rather
    # than from a loop variable, which is undefined when nothing was iterated.
    _save_legend(vis_setting, any("clbf_conds" in target for target in targets))

def build_vis_settings(models_dir, fig_dir):
    """Build the list of visualization settings passed to ``plot``.

    Settings are produced in pairs (PLBF figure, then CLBF figure):
    first for every F in a fixed list with max depth 4, then for F = 0.001
    with max depths 1, 2, 4 and 6 (these add a ``_max_depth_<n>`` suffix to
    the figure file name).

    Args:
        models_dir: Directory prefix used to build the
            ``trained_xgboost_folder_path`` condition of the CLBF marker.
        fig_dir: Directory under which the figure path templates are placed.

    Returns:
        A list of dicts, each with ``"targets_list"`` and ``"fig_path"``.
        Paths keep a literal ``{dataset}`` placeholder.
    """
    def settings_pair(F, max_depth, fig_suffix):
        F_str = f"{F}".replace(".", "_")
        depth_tag = f"max_depth_{max_depth}_"
        query_conds = {"pos_query_ratio": 0.0, "query_num": 40000, "F": F}
        clbf_conds = {"model_type": "clbf", "lambda": 1.0, "mu": 0.0, **query_conds}

        def fig_path(model_type):
            return os.path.join(
                fig_dir,
                "D_model_size_and_bf_size/" + model_type
                + "/D_model_size_and_bf_size_" + model_type + "_{dataset}"
                + f"_F_{F_str}" + fig_suffix + ".pdf",
            )

        plbf_setting = {
            "targets_list": [
                {
                    "conds": {
                        "model_type": "plbf",
                        **query_conds,
                        "max_depth": depth_tag,
                    },
                    "clbf_conds": {
                        **clbf_conds,
                        "trained_xgboost_folder_path": os.path.join(
                            models_dir,
                            "{dataset}/xgboost/max_depth_" + str(max_depth)
                            + "_num_boost_round_100_eta_03",
                        ),
                    },
                }
            ],
            "fig_path": fig_path("plbf"),
        }
        clbf_setting = {
            "targets_list": [
                {"conds": {**clbf_conds, "max_depth": depth_tag}}
            ],
            "fig_path": fig_path("clbf"),
        }
        return [plbf_setting, clbf_setting]

    vis_settings = []
    for F in [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]:
        vis_settings.extend(settings_pair(F, 4, ""))
    for F in [0.001]:
        for max_depth in [1, 2, 4, 6]:
            vis_settings.extend(
                settings_pair(F, max_depth, "_max_depth_" + str(max_depth))
            )
    return vis_settings


def main():
    """Parse the command line, load all results and save every figure."""
    parser = argparse.ArgumentParser(description="Plot results and save figures.")
    parser.add_argument(
        "--models_dir", type=str, default="models",
        help="Directory the models were saved to (used to match recorded "
             "XGBoost folder paths). Defaults to 'models'.",
    )
    parser.add_argument(
        "--results_dir", type=str, default="results",
        help="Directory to load the results from. Defaults to 'results'.",
    )
    parser.add_argument(
        "--fig_dir", type=str, default="fig",
        help="Directory to save the figures. Defaults to 'fig'.",
    )
    args = parser.parse_args()

    result_df = load_all_result(args.results_dir)

    for vis_setting in build_vis_settings(args.models_dir, args.fig_dir):
        plot(result_df, vis_setting)


if __name__ == "__main__":
    main()
