import pickle
import networkx as nx
import random 
import numpy as np
import os
import time
import matplotlib.pyplot as plt
import math
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import PercentFormatter
# Apply seaborn's default theme globally, then switch the axes style to "white".
sns.set()
sns.set_style("white")

def set_seed(seed=4):
    """Seed Python's ``random`` module and NumPy's global RNG.

    Args:
        seed: seed passed to ``random.seed``. NumPy's global RNG is seeded
            with ``seed + 1``. Defaults to 4, the previously hard-coded value,
            so existing ``set_seed()`` calls behave exactly as before.
    """
    random.seed(seed)
    np.random.seed(seed + 1)

def latexify():
    """Set global matplotlib rcParams for bold, LaTeX-rendered publication plots.

    Enables ``text.usetex``, so a LaTeX installation providing the ``bm``,
    ``amsmath`` and ``amsfonts`` packages is needed when figures are drawn.
    Also hides the top/right spines, enlarges tick, axis-label and legend
    fonts, and doubles the marker size.

    Only global rc settings are changed; no figure or Axes is created. An
    Axes-level call such as ``plt.ticklabel_format`` would implicitly create a
    throwaway figure here, and it would not affect figures made later.
    """
    matplotlib.rcParams['text.usetex'] = True
    matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{bm} \usepackage{amsmath,amsfonts} \boldmath"
    matplotlib.rcParams['axes.spines.right'] = False
    matplotlib.rcParams['axes.spines.top'] = False
    plt.rc('font', family='serif', weight='bold')
    plt.rc('xtick', labelsize=28)
    plt.rc('ytick', labelsize=28)
    plt.rc('axes', linewidth=1, labelsize=34, labelweight='bold')
    # The legend keeps its frame patch (frameon=True) but draws no edge line
    # (edgecolor='none'), so no border is visible around it.
    plt.rc('legend', fontsize=20, handlelength=2, loc='lower right', markerscale=1.5,
           handletextpad=0.3, frameon=True, edgecolor='none')
    # Doubles the *current* value, so calling latexify() repeatedly compounds.
    plt.rc('lines', markersize=plt.rcParams['lines.markersize'] * 2)

# Affine cost-model coefficients (scale "A", offset "B"). In this script only
# AsymA/AsymB are applied, as x = AsymA * cost + AsymB; the Dot/Cos/Wjac pairs
# are presumably the same for other scorers (not used here).
DotA = DotB = 1
CosA = CosB = 1
AsymA = AsymB = 1
WjacA = WjacB = 1

# Column index of each metric inside one row of a results pickle.
# Column 0 is intentionally left unmapped (formerly "topKscore_10").
metric_dict_dec10 = {
    name: column
    for column, name in enumerate(
        [
            "num_evals",
            "conservative_map",
            "ndcg5",
            "ndcg10",
            "ndcg100",
            "ndcg1000",
            "ndcg5000",
            "ndcg10000",
            "ndcg20000",
        ],
        start=1,
    )
}

# LaTeX y-axis label for each plottable metric ("num_evals" is x-axis only).
metric_to_label = {
    "conservative_map": r"\textbf{MAP} $\longrightarrow$",
    "ndcg5": r"\textbf{NDCG} @ \textbf{5} $\longrightarrow$",
    "ndcg10": r"\textbf{NDCG} @ \textbf{10} $\longrightarrow$",
    "ndcg100": r"\textbf{NDCG} @ \textbf{100} $\longrightarrow$",
    "ndcg1000": r"\textbf{NDCG} @ \textbf{1K} $\longrightarrow$",
    "ndcg5000": r"\textbf{NDCG} @ $\mathbf{5K}$ $\longrightarrow$",
    "ndcg10000": r"\textbf{NDCG} @ $\mathbf{10K}$ $\longrightarrow$",
    "ndcg20000": r"\textbf{NDCG} @ $\mathbf{20K}$ $\longrightarrow$",
}

# Shared LaTeX x-axis label.
x_label_time = r"\textbf{\# Retrieved Graphs} $\longrightarrow$"

# Pickle filename suffix per method/baseline key. The full filename for a
# relevance mode and dataset is f"{dname}_{rel_mode}_{suffix}".
_method_file_suffixes = {
    "NANL-SA-GhashTr-SmED-QA-BCE3-same": "NANL_None_Ghash2_Trained_query_aware8BCE3muse10_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmED-QA-BCE3-diff": "NANL_None_Ghash2_Trained_query_aware8BCE3muse10_SubsetType_different_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmED-QA-BCE3L-same": "NANL_None_Ghash2_Trained_query_aware8BCE3-Lmuse10_FmapLoss_BCE3-L_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmED-QA-BCE3L-diff": "NANL_None_Ghash2_Trained_query_aware8BCE3-Lmuse10_SubsetType_different_FmapLoss_BCE3-L_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmNone-QA-BCE3-same": "NANL_None_Ghash2_Trained_query_aware8BCE3muse10_SmNone_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmNone-QA-BCE3-diff": "NANL_None_Ghash2_Trained_query_aware8BCE3muse10_SubsetType_different_SmNone_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmNone-QA-BCE3L-same": "NANL_None_Ghash2_Trained_query_aware8BCE3-Lmuse10_SmNone_FmapLoss_BCE3-L_parallel3_med.pkl",
    "NANL-SA-GhashTr-SmNone-QA-BCE3L-diff": "NANL_None_Ghash2_Trained_query_aware8BCE3-Lmuse10_SubsetType_different_SmNone_FmapLoss_BCE3-L_parallel3_med.pkl",
    "NANL-SA-RHTr-QA": "NANL_cos_RH_Trained_query_aware8_SepEval_SA_parallel3_med.pkl",
    "GEN-cos-RHTr-QA": "GEN_cos_RH_Trained_query_aware8_parallel3_med.pkl",
    "NANL-SA-FhashTr-QA-BCE3": "NANL_hinge_Fhash_Trained_query_aware8BCE3muse10_SepEval_SA_parallel3_med.pkl",
    "GEN-hinge-FhashTr-QA-BCE3": "GEN_hinge_Fhash_Trained_query_aware8BCE3muse10_parallel3_med.pkl",
    "NANL-SA-DPRHTr-QA": "NANL_dot_DPRH_Trained_query_aware8_SepEval_SA_parallel3_med.pkl",
    "GEN-dot-DPRHTr-QA": "GEN_dot_DPRH_Trained_query_aware8_parallel3_med.pkl",
    "faiss_cos": "NANL_Faiss_cosine.pkl",
    "faiss_l2": "NANL_Faiss_l2.pkl",
    "diskann_cos": "NANL_DiskANN_cosine.pkl",
    "diskann_l2": "NANL_DiskANN_l2.pkl",
    "random": "NANL_Random.pkl",
}

# filename_dict[rel_mode][dname][method_key] -> pickle filename (no directory).
filename_dict = {}
for rel_mode in ("sub_iso", "ged", "uneq_ged"):
    filename_dict[rel_mode] = {}
    for dname in ("ptc_fr", "ptc_fm", "ptc_mr", "cox2"):
        filename_dict[rel_mode][dname] = {
            method: f"{dname}_{rel_mode}_{suffix}"
            for method, suffix in _method_file_suffixes.items()
        }
            
dir_loc = "./allPklDumps/hashResultPickles"
# Best-effort sanity check: report every expected pickle missing from
# dir_loc without aborting (the hard failure, if any, happens when loading).
for rel_mode, per_dataset in filename_dict.items():
    for dname, per_method in per_dataset.items():
        for key, fname in per_method.items():
            path = f"{dir_loc}/{fname}"
            if os.path.exists(path):
                continue
            print(path)
            print("File not found")
rel_modes = ["sub_iso", "ged", "uneq_ged"]
datasets = ["ptc_fr", "ptc_mr", "ptc_fm", "cox2"]

# Each results table maps rel_mode -> dname -> unpickled results for one method.
NANL_SA_GhashTr_SmED_QA_BCE3_same = {rel_mode: {} for rel_mode in rel_modes}
NANL_SA_GhashTr_SmED_QA_BCE3_diff = {rel_mode: {} for rel_mode in rel_modes}
NANL_SA_GhashTr_SmED_QA_BCE3L_same = {rel_mode: {} for rel_mode in rel_modes}
NANL_SA_GhashTr_SmED_QA_BCE3L_diff = {rel_mode: {} for rel_mode in rel_modes}

NANL_SA_GhashTr_SmNone_QA_BCE3_same = {rel_mode: {} for rel_mode in rel_modes}
NANL_SA_GhashTr_SmNone_QA_BCE3_diff = {rel_mode: {} for rel_mode in rel_modes}
NANL_SA_GhashTr_SmNone_QA_BCE3L_same = {rel_mode: {} for rel_mode in rel_modes}
NANL_SA_GhashTr_SmNone_QA_BCE3L_diff = {rel_mode: {} for rel_mode in rel_modes}

NANL_SA_RHTr_QA = {rel_mode: {} for rel_mode in rel_modes}
GEN_cos_RHTr_QA = {rel_mode: {} for rel_mode in rel_modes}

NANL_SA_FhashTr_QA_BCE3 = {rel_mode: {} for rel_mode in rel_modes}
GEN_hinge_FhashTr_QA_BCE3 = {rel_mode: {} for rel_mode in rel_modes}

NANL_SA_DPRHTr_QA = {rel_mode: {} for rel_mode in rel_modes}
GEN_dot_DPRHTr_QA = {rel_mode: {} for rel_mode in rel_modes}

faiss_cos = {rel_mode: {} for rel_mode in rel_modes}
faiss_l2 = {rel_mode: {} for rel_mode in rel_modes}
diskann_cos = {rel_mode: {} for rel_mode in rel_modes}
diskann_l2 = {rel_mode: {} for rel_mode in rel_modes}
random_select = {rel_mode: {} for rel_mode in rel_modes}

dir_loc = "./allPklDumps/hashResultPickles"


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load trusted, locally
    # produced result dumps.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# filename_dict method key -> results table that receives the unpickled data.
_results_by_key = {
    "NANL-SA-GhashTr-SmED-QA-BCE3-same": NANL_SA_GhashTr_SmED_QA_BCE3_same,
    "NANL-SA-GhashTr-SmED-QA-BCE3-diff": NANL_SA_GhashTr_SmED_QA_BCE3_diff,
    "NANL-SA-GhashTr-SmED-QA-BCE3L-same": NANL_SA_GhashTr_SmED_QA_BCE3L_same,
    "NANL-SA-GhashTr-SmED-QA-BCE3L-diff": NANL_SA_GhashTr_SmED_QA_BCE3L_diff,
    "NANL-SA-GhashTr-SmNone-QA-BCE3-same": NANL_SA_GhashTr_SmNone_QA_BCE3_same,
    "NANL-SA-GhashTr-SmNone-QA-BCE3-diff": NANL_SA_GhashTr_SmNone_QA_BCE3_diff,
    "NANL-SA-GhashTr-SmNone-QA-BCE3L-same": NANL_SA_GhashTr_SmNone_QA_BCE3L_same,
    "NANL-SA-GhashTr-SmNone-QA-BCE3L-diff": NANL_SA_GhashTr_SmNone_QA_BCE3L_diff,
    "NANL-SA-RHTr-QA": NANL_SA_RHTr_QA,
    "GEN-cos-RHTr-QA": GEN_cos_RHTr_QA,
    "NANL-SA-FhashTr-QA-BCE3": NANL_SA_FhashTr_QA_BCE3,
    "GEN-hinge-FhashTr-QA-BCE3": GEN_hinge_FhashTr_QA_BCE3,
    "NANL-SA-DPRHTr-QA": NANL_SA_DPRHTr_QA,
    "GEN-dot-DPRHTr-QA": GEN_dot_DPRHTr_QA,
    "faiss_cos": faiss_cos,
    "faiss_l2": faiss_l2,
    "diskann_cos": diskann_cos,
    "diskann_l2": diskann_l2,
    "random": random_select,
}

# Load every (rel_mode, dataset, method) pickle, in the same order as before.
for rel_mode in rel_modes:
    for dname in datasets:
        for key, table in _results_by_key.items():
            table[rel_mode][dname] = _load_pickle(f"{dir_loc}/{filename_dict[rel_mode][dname][key]}")

xmetric = "num_evals"

########################################################

def plot_1(dname, rel_mode, metric, xmetric, save_legend=False):
    """Scatter-plot ``metric`` against retrieval cost for GraphHash and baselines.

    For one dataset / relevance mode this draws GraphHash
    (``NANL_SA_GhashTr_SmED_QA_BCE3L_diff``), FourierHashNet, RH (every 6th
    row only), DiskANN (L2), IVF (Faiss L2) and random selection. For each
    results row, x is ``AsymA * row[xmetric] + AsymB`` and y is ``row[metric]``;
    column indices come from ``metric_dict_dec10``.

    Args:
        dname: dataset key, e.g. ``"ptc_fr"``.
        rel_mode: relevance-mode key, e.g. ``"sub_iso"``.
        metric: y-axis metric; must be a key of both ``metric_dict_dec10`` and
            ``metric_to_label``.
        xmetric: x-axis metric; must be a key of ``metric_dict_dec10``.
        save_legend: if True, write only the legend to ``./perf/legend_1.pdf``
            and skip the plot itself.

    Otherwise, writes ``./perf/{dname}_{rel_mode}_{metric}_1_nored.pdf``. Every
    figure created here is closed before returning, even on error.
    """
    x_col = metric_dict_dec10[xmetric]
    y_col = metric_dict_dec10[metric]

    # (results table, row selector, colour, marker, legend label).
    # RH is subsampled to every 6th row, matching its "(Subsampled)" label.
    series = [
        (NANL_SA_GhashTr_SmED_QA_BCE3L_diff, slice(None), "Red", 'o', r"\textbf{\textsc{GraphHash}}"),
        (NANL_SA_FhashTr_QA_BCE3, slice(None), "limegreen", 's', r"\textbf{\textsc{FourierHashNet}}"),
        (NANL_SA_RHTr_QA, slice(None, None, 6), "Magenta", 's', r"\textbf{RH (Subsampled)}"),
        (diskann_l2, slice(None), "Orange", 'v', r"\textbf{DiskANN}"),
        (faiss_l2, slice(None), "Blue", '^', r"\textbf{IVF}"),
        (random_select, slice(None), "Grey", 'o', r"\textbf{Random}"),
    ]

    fig = plt.figure(figsize=(6, 6))
    try:
        for results, rows, color, marker, label in series:
            data = np.array(results[rel_mode][dname])[rows]
            sns.scatterplot(x=AsymA * data[:, x_col] + AsymB,
                            y=data[:, y_col],
                            color=color, marker=marker, label=label)

        if metric == "conservative_map":
            # Dashed red reference line at half the MAP in the first row of the
            # BCE3 "same" GraphHash results (the rationale for /2 is not
            # visible here; TODO confirm).
            map_col = metric_dict_dec10["conservative_map"]
            plt.axhline(y=NANL_SA_GhashTr_SmED_QA_BCE3_same[rel_mode][dname][0][map_col] / 2,
                        color='r', linestyle='--', linewidth=2)

        output_dir = "./perf"
        os.makedirs(output_dir, exist_ok=True)

        if save_legend:
            # Render the legend alone on a separate, axis-less figure.
            handles, labels = plt.gca().get_legend_handles_labels()
            legend_fig = plt.figure(figsize=(6, 1))
            try:
                legend_ax = legend_fig.add_subplot(111)
                legend_ax.axis('off')
                legend_ax.legend(handles, labels, loc='center', frameon=False, ncol=6)
                legend_fig.savefig(f"{output_dir}/legend_1.pdf", bbox_inches='tight', dpi=300)
            finally:
                plt.close(legend_fig)
            return

        # The legend is published separately (save_legend=True), so drop it here.
        legend = plt.gca().get_legend()
        if legend is not None:
            legend.remove()
        plt.ylabel(metric_to_label[metric], labelpad=3)
        plt.xlabel(x_label_time, labelpad=3)
        plt.xscale('linear')
        plt.xticks([0, 25000, 50000, 75000, 100000],
                   [r"\textbf{0}", r"\textbf{25K}", r"\textbf{50K}", r"\textbf{75K}", r"\textbf{100K}"])
        plt.grid(axis='y', linestyle='-', linewidth=1)
        plt.grid(axis='x', linestyle='-', linewidth=1)

        output_file = os.path.join(output_dir, f"{dname}_{rel_mode}_{metric}_1_nored.pdf")
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)

### TODO: scrub this off, as we do not include this?
def plot_2(dname, rel_mode, metric, xmetric, save_legend=False):
    """Scatter-plot ``metric`` against retrieval cost for GraphHash vs. GEN variants.

    For one dataset / relevance mode this draws GraphHash
    (``NANL_SA_GhashTr_SmED_QA_BCE3L_diff``), GEN + FourierHashNet and
    GEN + RH. For each results row, x is ``AsymA * row[xmetric] + AsymB`` and y
    is ``row[metric]``; column indices come from ``metric_dict_dec10``.

    Args:
        dname: dataset key, e.g. ``"ptc_fr"``.
        rel_mode: relevance-mode key, e.g. ``"sub_iso"``.
        metric: y-axis metric; must be a key of both ``metric_dict_dec10`` and
            ``metric_to_label``.
        xmetric: x-axis metric; must be a key of ``metric_dict_dec10``.
        save_legend: if True, write only the legend to ``./perf/legend_2.pdf``
            and skip the plot itself.

    Otherwise, writes ``./perf/{dname}_{rel_mode}_{metric}_2_nored.pdf``. Every
    figure created here is closed before returning, even on error.
    """
    x_col = metric_dict_dec10[xmetric]
    y_col = metric_dict_dec10[metric]

    # (results table, colour, marker, legend label).
    series = [
        (NANL_SA_GhashTr_SmED_QA_BCE3L_diff, "Red", 'o', r"\textsc{GraphHash}"),
        # (NANL_SA_FhashTr_QA_BCE3, "limegreen", 's', r"NANL $+$ \textsc{FourierHashNet}"),
        (GEN_hinge_FhashTr_QA_BCE3, "Orange", 's', r"GEN $+$ \textsc{FourierHashNet}"),
        # (GEN_dot_DPRHTr_QA, "Cyan", 's', r"GEN $+$ Dot Product RH"),
        (GEN_cos_RHTr_QA, "Magenta", 's', r"GEN $+$ RH Hashing"),
    ]

    fig = plt.figure(figsize=(6, 6))
    try:
        for results, color, marker, label in series:
            data = np.array(results[rel_mode][dname])
            sns.scatterplot(x=AsymA * data[:, x_col] + AsymB,
                            y=data[:, y_col],
                            color=color, marker=marker, label=label)

        if metric == "conservative_map":
            # Dashed red reference line at half the MAP in the first row of the
            # BCE3 "same" GraphHash results (the rationale for /2 is not
            # visible here; TODO confirm).
            map_col = metric_dict_dec10["conservative_map"]
            plt.axhline(y=NANL_SA_GhashTr_SmED_QA_BCE3_same[rel_mode][dname][0][map_col] / 2,
                        color='r', linestyle='--', linewidth=2)

        output_dir = "./perf"
        os.makedirs(output_dir, exist_ok=True)

        if save_legend:
            # Render the legend alone on a separate, axis-less figure.
            handles, labels = plt.gca().get_legend_handles_labels()
            legend_fig = plt.figure(figsize=(6, 1))
            try:
                legend_ax = legend_fig.add_subplot(111)
                legend_ax.axis('off')
                legend_ax.legend(handles, labels, loc='center', frameon=False, ncol=6)
                legend_fig.savefig(f"{output_dir}/legend_2.pdf", bbox_inches='tight', dpi=300)
            finally:
                plt.close(legend_fig)
            return

        # The legend is published separately (save_legend=True), so drop it here.
        legend = plt.gca().get_legend()
        if legend is not None:
            legend.remove()
        plt.ylabel(metric_to_label[metric], labelpad=3)
        plt.xlabel(x_label_time, labelpad=3)
        plt.xscale('linear')
        plt.xticks([0, 25000, 50000, 75000, 100000],
                   [r"\textbf{0}", r"\textbf{25K}", r"\textbf{50K}", r"\textbf{75K}", r"\textbf{100K}"])
        plt.grid(axis='y', linestyle='-', linewidth=1)
        plt.grid(axis='x', linestyle='-', linewidth=1)

        output_file = os.path.join(output_dir, f"{dname}_{rel_mode}_{metric}_2_nored.pdf")
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)

def plot_3(dname, rel_mode, metric, xmetric, save_legend=False):
    """Scatter-plot ``metric`` against retrieval cost for GraphHash vs. RH only.

    For one dataset / relevance mode this draws GraphHash
    (``NANL_SA_GhashTr_SmED_QA_BCE3L_diff``) and RH (``NANL_SA_RHTr_QA``, all
    rows). For each results row, x is ``AsymA * row[xmetric] + AsymB`` and y is
    ``row[metric]``; column indices come from ``metric_dict_dec10``.

    Args:
        dname: dataset key, e.g. ``"ptc_fr"``.
        rel_mode: relevance-mode key, e.g. ``"sub_iso"``.
        metric: y-axis metric; must be a key of both ``metric_dict_dec10`` and
            ``metric_to_label``.
        xmetric: x-axis metric; must be a key of ``metric_dict_dec10``.
        save_legend: if True, write only the legend to ``./perf/legend_3.pdf``
            and skip the plot itself. This file is distinct from plot_1's
            ``legend_1.pdf``, which has a different set of entries.

    Otherwise, writes ``./perf/{dname}_{rel_mode}_{metric}_3_nored.pdf``. Every
    figure created here is closed before returning, even on error.
    """
    x_col = metric_dict_dec10[xmetric]
    y_col = metric_dict_dec10[metric]

    # (results table, colour, marker, legend label).
    series = [
        (NANL_SA_GhashTr_SmED_QA_BCE3L_diff, "Red", 'o', r"\textbf{\textsc{GraphHash}}"),
        (NANL_SA_RHTr_QA, "Magenta", 's', r"\textbf{RH}"),
    ]

    fig = plt.figure(figsize=(6, 6))
    try:
        for results, color, marker, label in series:
            data = np.array(results[rel_mode][dname])
            sns.scatterplot(x=AsymA * data[:, x_col] + AsymB,
                            y=data[:, y_col],
                            color=color, marker=marker, label=label)

        if metric == "conservative_map":
            # Dashed red reference line at half the MAP in the first row of the
            # BCE3 "same" GraphHash results (the rationale for /2 is not
            # visible here; TODO confirm).
            map_col = metric_dict_dec10["conservative_map"]
            plt.axhline(y=NANL_SA_GhashTr_SmED_QA_BCE3_same[rel_mode][dname][0][map_col] / 2,
                        color='r', linestyle='--', linewidth=2)

        output_dir = "./perf"
        os.makedirs(output_dir, exist_ok=True)

        if save_legend:
            # Render the legend alone on a separate, axis-less figure.
            handles, labels = plt.gca().get_legend_handles_labels()
            legend_fig = plt.figure(figsize=(6, 1))
            try:
                legend_ax = legend_fig.add_subplot(111)
                legend_ax.axis('off')
                legend_ax.legend(handles, labels, loc='center', frameon=False, ncol=6)
                legend_fig.savefig(f"{output_dir}/legend_3.pdf", bbox_inches='tight', dpi=300)
            finally:
                plt.close(legend_fig)
            return

        # The legend is published separately (save_legend=True), so drop it here.
        legend = plt.gca().get_legend()
        if legend is not None:
            legend.remove()
        plt.ylabel(metric_to_label[metric], labelpad=3)
        plt.xlabel(x_label_time, labelpad=3)
        plt.xscale('linear')
        plt.xticks([0, 25000, 50000, 75000, 100000],
                   [r"\textbf{0}", r"\textbf{25K}", r"\textbf{50K}", r"\textbf{75K}", r"\textbf{100K}"])
        plt.grid(axis='y', linestyle='-', linewidth=1)
        plt.grid(axis='x', linestyle='-', linewidth=1)

        output_file = os.path.join(output_dir, f"{dname}_{rel_mode}_{metric}_3_nored.pdf")
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
    

if __name__ == "__main__":
    # Global LaTeX/bold styling must be set before any figure is created.
    latexify()

    # To regenerate a shared legend instead, e.g.:
    # plot_1("ptc_fr", "sub_iso", "ndcg10000", xmetric, save_legend=True)
    for dname in datasets:
        for rel_mode in ["uneq_ged", "sub_iso", "ged"]:
            # Other metrics used previously: "conservative_map", "ndcg10000", "ndcg5000".
            for metric in ["ndcg1000"]:
                # plot_1(dname, rel_mode, metric, xmetric)
                # plot_2(dname, rel_mode, metric, xmetric)
                plot_3(dname, rel_mode, metric, xmetric)
              
