import argparse
import matplotlib.pyplot as plt
import os
import pandas as pd
import pprint
import math
from load_all_result import load_all_result

def _select_target_rows(df, conds, dataset):
    """Return the rows of ``df`` whose columns equal every value in ``conds``.

    The ``"trained_xgboost_folder_path"`` condition is a path template and is
    formatted with ``dataset`` before being compared.
    """
    for cond_key, cond_val in conds.items():
        if cond_key == "trained_xgboost_folder_path":
            cond_val = cond_val.format(dataset=dataset)
        df = df[df[cond_key] == cond_val]
    return df


def _ratio_fig_path(fig_path):
    """Return ``fig_path`` with ``_ratio`` inserted before its file extension.

    Only the final extension is touched, so dots elsewhere in the path
    (``./fig``, ``../fig``, dataset names containing dots) are preserved.
    """
    root, ext = os.path.splitext(fig_path)
    return f"{root}_ratio{ext}"


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if the path has one."""
    parent = os.path.dirname(path)
    # A bare file name has no parent directory; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _save_legend(fig_path_template):
    """Save a standalone figure that contains only the legend of the two bands.

    The output path is ``fig_path_template`` with ``{dataset}`` replaced by
    ``legend``.
    """
    legends = [
        ("tab:orange", "XGBoost Model"),
        ("tab:blue", "Bloom Filter"),
    ]
    alpha = 0.3  # same alpha as the filled bands in plot()
    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, fc=color, edgecolor='none', alpha=alpha)
        for color, _ in legends
    ]
    legend_labels = [label for _, label in legends]
    fig, ax = plt.subplots(figsize=(3.5, 1.0))
    ax.axis('off')
    ax.legend(legend_handles, legend_labels, loc="center", fontsize=16)
    fig_path = fig_path_template.replace("{dataset}", "legend")
    _ensure_parent_dir(fig_path)
    print(f"=== Saving figure to {fig_path} ===")
    plt.savefig(fig_path, bbox_inches='tight')
    plt.close()


def plot(result_df, vis_setting, ratio = False):
    """Plot a stacked memory-usage breakdown against ``F`` for every dataset.

    One figure is saved per value of ``result_df["dataset"]``.

    Each entry of ``vis_setting["targets_list"]`` selects the rows whose
    columns equal the values in its ``"conds"`` mapping. The selected rows are
    drawn as two stacked bands:

    * the XGBoost part, ``model_size_kb - bloom_filter_mem_sum`` (or the
      whole ``model_size_kb`` when that column is absent);
    * the Bloom-filter part stacked on top, up to ``model_size_kb``.

    Finally, a separate legend-only figure is saved.

    Args:
        result_df: Result table. It must contain ``dataset``, ``F``,
            ``model_size_kb`` and every column named in the targets' ``conds``.
        vis_setting: Mapping with ``"targets_list"`` (list of dicts with a
            ``"conds"`` mapping) and ``"fig_path"`` (path template containing
            ``{dataset}``).
        ratio: If True, plot both bands as fractions of ``model_size_kb``.
            Each figure is then saved as ``<name>_ratio<ext>``.
    """
    datasets = result_df["dataset"].unique()

    for dataset in datasets:
        df_d = result_df[result_df["dataset"] == dataset]

        plt.rcParams['axes.axisbelow'] = True
        # Start from a fresh figure so nothing already open gets drawn over.
        plt.figure()
        plt.grid(True, which="both", ls="--", c='gray')

        for target in vis_setting["targets_list"]:
            df_dm = _select_target_rows(df_d, target["conds"], dataset)
            df_dm = df_dm.sort_values(by="F")
            fprs = df_dm["F"]
            total_model_sizes = df_dm["model_size_kb"]
            # Without a Bloom-filter column, the whole size counts as ML model.
            if "bloom_filter_mem_sum" in df_dm.columns:
                bf_sizes = df_dm["bloom_filter_mem_sum"]
            else:
                bf_sizes = 0
            ml_model_sizes = total_model_sizes - bf_sizes

            if ratio:
                # NOTE: rows with model_size_kb == 0 become NaN here (0 / 0).
                ml_model_sizes = ml_model_sizes / total_model_sizes
                total_model_sizes = total_model_sizes / total_model_sizes

            # Lower band: ML model size, from 0 up to ml_model_sizes.
            plt.fill_between(
                fprs,
                0,
                ml_model_sizes,
                label="XGBoost Model",
                color="tab:orange",
                alpha=0.3
            )

            # Upper band: Bloom filter, stacked from the ML size up to the total.
            plt.fill_between(
                fprs,
                ml_model_sizes,
                total_model_sizes,
                label="Bloom Filter",
                color="tab:blue",
                alpha=0.3
            )

            plt.plot(
                fprs,
                ml_model_sizes,
                marker="x",
                markersize=6.5,
                linestyle="-.",
                color="tab:orange",
            )

            plt.plot(
                fprs,
                total_model_sizes,
                marker="x",
                markersize=6.5,
                linestyle="-.",
                color="tab:blue",
            )

        plt.xlabel(r"$F$", fontsize=22)
        if ratio:
            plt.ylabel("Memory Usage Ratio", fontsize=22)
            plt.ylim(0, 1.1)
        else:
            plt.ylabel("Memory Usage [kB]", fontsize=22)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.xscale("log")
        # Pin the bottom at 0. A top of None leaves the current top limit
        # unchanged: 1.1 for ratio plots, matplotlib's automatic value otherwise.
        plt.ylim(0, None)

        ax = plt.gca()
        for side in ['right', 'left', 'top', 'bottom']:
            ax.spines[side].set_color('black')

        fig_path = vis_setting["fig_path"].format(dataset=dataset)
        if ratio:
            fig_path = _ratio_fig_path(fig_path)
        _ensure_parent_dir(fig_path)
        print(f"Saving figure to {fig_path}")
        plt.savefig(fig_path, bbox_inches='tight')
        plt.close()

    _save_legend(vis_setting["fig_path"])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Plot results and save figures.")
    parser.add_argument("--models_dir", type=str, default="models", help="Directory the trained models were saved to. Defaults to 'models'.", required=False)
    parser.add_argument("--results_dir", type=str, default="results", help="Directory containing the results. Defaults to 'results'.", required=False)
    parser.add_argument("--fig_dir", type=str, default="fig", help="Directory to save the figures. Defaults to 'fig'.", required=False)
    # Opt-in: the default run produces only the absolute-size figures, as before.
    parser.add_argument("--ratio", action="store_true", help="Also save memory-usage ratio figures (file names get a '_ratio' suffix).")
    args = parser.parse_args()
    models_dir = args.models_dir
    result_dir = args.results_dir
    fig_dir = args.fig_dir

    result_df = load_all_result(result_dir)

    vis_settings = [
        {
            "targets_list": [
                {
                    # Rows must match every key/value pair below;
                    # "{dataset}" in the folder path is filled in by plot().
                    "conds": {
                        "model_type": "clbf",
                        "lambda": 1.0,
                        "mu": 0.0,
                        "pos_query_ratio": 0.0,
                        "query_num": 40000,
                        "trained_xgboost_folder_path": os.path.join(models_dir, "{dataset}/xgboost/max_depth_4_num_boost_round_100_eta_03")
                    },
                    # NOTE(review): plot() does not read "vis_info"; its marker
                    # and line style are hard-coded there.
                    "vis_info": {
                        "marker": "x",
                        "markersize": 6.5,
                        "linestyle": "-"
                    }
                }
            ],
            "fig_path": os.path.join(fig_dir, "fpr_memory_ratio/fpr_memory_ratio_{dataset}.pdf")
        },
    ]

    for vis_setting in vis_settings:
        plot(result_df, vis_setting)
        if args.ratio:
            plot(result_df, vis_setting, ratio=True)
