#!/usr/bin/env python3

import os
import json
import numpy as np
import pandas as pd 
from functools import reduce
from adjustText import adjust_text
from pandas.io.json import json_normalize
import matplotlib.pyplot as plt

# Short display labels for model families. Keys are the model prefix of a run
# directory name (the text before the first "-"). The map is applied with
# pandas .map(), so any model family not listed here becomes NaN in the
# written tables and heat-map labels.
model_rename_map = {
    "CBOW_NEG": "CBOW",
    "CBOW_BTLx_Sequential": "BTLx",
    "CBOW_BTL_Sequential": "CBOWBTL",
    "CBOW_BTL_Aggregate": "BTLa",
    "SkipGram_NEG": "SG",
    "SG_BTL_Sequential": "SGBTL",
    "CBOW_HybridModel": "Hybrid"
}

# Short labels for analogy-method columns. Used to relabel the per-model
# category x method table in analogy_heat_map, and for the file name and axis
# labels in analogy_scatter_plot (where a missing key raises KeyError).
# Prefixes: ADD = seqadd / 3cosadd, L1 = seqadd_l1, IA = ignore_a,
# OB = only_b, AO = add_opposite, RV = reverse, RVB = reverse_only_b,
# PD = pairdirection, MUL = 3cosmul.
# For the "accuracy_seqadd*" keys the parenthesised expression appears to spell
# out the operand order of that sequential variant (TODO confirm); keys without
# "seqadd" map to the bare prefix.
column_rename_map = {
    # sequential addition, all operand orders, and plain 3cosadd
    "accuracy_seqadd": "ADD(b-a+a*)",
    "accuracy_seqadd_1": "ADD(b+a*-a)",
    "accuracy_seqadd_2": "ADD(a*-a+b)",
    "accuracy_seqadd_3": "ADD(a*+b-a)",
    "accuracy_seqadd_4": "ADD(-a+b+a*)",
    "accuracy_seqadd_5": "ADD(-a+a*+b)",
    "accuracy_3cosadd": "ADD",
    # seqadd_l1 variants
    "accuracy_seqadd_l1": "L1(b-a+a*)",
    "accuracy_seqadd_l1_1": "L1(b+a*-a)",
    "accuracy_seqadd_l1_2": "L1(a*-a+b)",
    "accuracy_seqadd_l1_3": "L1(a*+b-a)",
    "accuracy_seqadd_l1_4": "L1(-a+b+a*)",
    "accuracy_seqadd_l1_5": "L1(-a+a*+b)",
    # ignore_a / only_b baselines
    "accuracy_seqadd_ignore_a": "IA(b+a*)",
    "accuracy_seqadd_ignore_a_1": "IA(a*+b)",
    "accuracy_ignore_a": "IA",
    "accuracy_seqadd_only_b": "OB ",  # trailing space appears deliberate: keeps this label distinct from "OB" below
    "accuracy_only_b": "OB",
    # add_opposite variants
    "accuracy_seqadd_add_opposite": "AO(b-a*+a)",
    "accuracy_seqadd_add_opposite_1": "AO(b+a-a*)",
    "accuracy_seqadd_add_opposite_2": "AO(a-a*+b)",
    "accuracy_seqadd_add_opposite_3": "AO(a+b-a*)",
    "accuracy_seqadd_add_opposite_4": "AO(-a*+b+a)",
    "accuracy_seqadd_add_opposite_5": "AO(-a*+a+b)",
    "accuracy_add_opposite": "AO",
    # reverse variants
    "accuracy_seqadd_reverse": "RV(b*-a*+a(=b))",
    "accuracy_seqadd_reverse_1": "RV(b*+a-a*(=b))",
    "accuracy_seqadd_reverse_2": "RV(a-a*+b*(=b))",
    "accuracy_seqadd_reverse_3": "RV(a+b*-a*(=b))",
    "accuracy_seqadd_reverse_4": "RV(-a*+b*+a(=b))",
    "accuracy_seqadd_reverse_5": "RV(-a*+a+b*(=b))",
    "accuracy_reverse": "RV",
    "accuracy_seqadd_reverse_only_b": "RVB(b*(=b))",
    "accuracy_reverse_only_b": "RVB",
    # other methods
    "accuracy_pairdirection": "PD",
    "accuracy_3cosmul": "MUL",
}

# Keys removed from each run's hyper.json before it is merged into the
# benchmark table (the per-model loop filters model_data with this list).
# Entries that are commented out are therefore kept and show up as columns
# in benchmark_<dataset>.csv.
irrelevant_keys = [
    # truly irrelevant (e.g. fluid / run-bookkeeping) variables
    "mode","start_epoch","path","force_reload",
    # actually relevant, but removed since we don't explore these hyperparameters
    "comment",
    "shuffle",
    "subsampling",
    "subsampling_threshold",
    # "epochs",
    # "batch_size",
    # "lr",
    "context_before",
    "context_after",
    # "embedding",
    # "initialization",
    "negative_sample",
    "linear_dim",
    "loss",
    "threads",
    # "model",
    "optimizer",
    "min_occurrence",
    "train",  # note: removing "train" means the dataset-name lookup on df["train"] normally falls back to "ccc"
    "valid",
    "test",
    # "annealing_schedule",
    "annealing_max",
    "annealing_min",
    # "annealing_start",
    "annealing_end",
    "straight_through",
    # "beta",
    "initial_state",
    # "affine",
    "noise"
]

def rename_old_to_new(model_folder, model_folder_path):
    """
    Migrate a run directory from old-style file names to the fixed names.

    Old-style runs stored their files as ``<model_folder>.json``,
    ``<model_folder>.model``, ``<model_folder>-time-memory.json`` and
    ``<model_folder>-performance.json``. Each of those that is present in
    ``model_folder_path`` is renamed to ``hyper.json``, ``model.json``,
    ``time-memory.json`` and ``performance.json`` respectively.
    """
    legacy_to_current = (
        (f"{model_folder}.json", "hyper.json"),
        (f"{model_folder}.model", "model.json"),
        (f"{model_folder}-time-memory.json", "time-memory.json"),
        (f"{model_folder}-performance.json", "performance.json"),
    )
    for legacy_name, current_name in legacy_to_current:
        # List the directory on every step: an earlier rename changes its contents.
        if legacy_name not in os.listdir(model_folder_path):
            continue
        source = os.path.join(model_folder_path, legacy_name)
        target = os.path.join(model_folder_path, current_name)
        os.rename(source, target)


def analogy_heat_map(df, feature, topk, aggregation="mean", exclusion="exclude_original", plot=False):
    """
    Build an analogy-accuracy table from a wide analogy dataframe, save it as
    CSV and optionally render it as a heat map.

    :param df: wide analogy dataframe, one row per model run. Must contain a
        "model" column, a "model_fullname" column, and dotted score columns
        whose first field is the analogy category, whose fourth field is the
        method and whose last field is "accuracy_<topk>".
    :param feature: "category" -> one column per category (aggregated over
        all of that category's score columns); "method" -> one column per
        method (aggregated over categories); any other value is treated as a
        model full name and yields a category x method table for that model.
    :param topk: str, which topk accuracy to use, e.g. "top1", "top5", "top10"
    :param aggregation: aggregation passed to ``DataFrame.agg`` (e.g. "mean")
    :param exclusion: str, "exclude_original" or "include_original"
    :param plot: if True, also save a seaborn heat map PNG
    The table is written to ``<output_path>/<title>.csv`` using the
    module-level ``output_path``. The caller's ``df`` is not modified.
    """
    # Index rows by model name without mutating the caller's frame
    # (assigning df.index in place would leak back to the caller).
    df = df.set_index("model", drop=False).sort_index()

    score_suffix = "accuracy_" + topk
    score_cols = [
        col for col in df.columns
        if col.split(".")[-1] == score_suffix and exclusion in col and "final" not in col
    ]
    model_fullname = df["model_fullname"]
    df = df[score_cols]
    pivot_df = df.copy()

    if feature in ("category", "method"):
        # Collapse the score columns into one column per category (field 0)
        # or per method (field 3). Grouping on the exact field avoids a
        # substring test pulling in columns of a different category whose
        # name happens to contain this one.
        field = 0 if feature == "category" else 3
        groups = {}
        for col in df.columns:
            groups.setdefault(col.split(".")[field], []).append(col)
        for name, cols in groups.items():
            pivot_df[name] = df[cols].agg(aggregation, axis="columns")
            pivot_df = pivot_df.drop(columns=cols)

    if feature == "method":
        # combine discrete and non discrete methods with their parallel version:
        # CBOW_NEG rows take their seqadd_* scores from the matching
        # non-seqadd column, which is then dropped.
        if "CBOW_NEG" in pivot_df.index:
            # Boolean mask + .loc writes into pivot_df itself; the previous
            # chained assignment (pivot_df[col]["CBOW_NEG"] = ...) may write
            # into a temporary and is a no-op under Copy-on-Write.
            is_cbow = pivot_df.index == "CBOW_NEG"
            del_cols = []
            for col in pivot_df.columns:
                if "seqadd" in col and col not in ("accuracy_seqadd_l1", "accuracy_seqadd"):
                    source = col.replace("seqadd_", "")
                    pivot_df.loc[is_cbow, col] = pivot_df.loc[is_cbow, source].to_numpy()
                    del_cols.append(source)
            pivot_df.loc[is_cbow, "accuracy_seqadd"] = pivot_df.loc[is_cbow, "accuracy_3cosadd"].to_numpy()
            del_cols.append("accuracy_3cosadd")
            pivot_df = pivot_df.drop(columns=del_cols)

        # rename all discrete methods to regular method names
        rename_dict = {
            col: col.replace("accuracy_", "").replace("seqadd_", "")
            for col in pivot_df.columns
        }
        rename_dict["accuracy_seqadd"] = "3cosadd"
        rename_dict["accuracy_seqadd_l1"] = "seqadd_l1"
        pivot_df = pivot_df.rename(columns=rename_dict)

    elif feature != "category":
        # Model specific: one record per score column of that model, pivoted
        # into a category x method table (no aggregation).
        run = df[model_fullname == feature]
        records = [
            {
                "category": col.split(".")[0],
                "method": col.split(".")[3],
                # .iloc: the index holds model names, so [0] would be a label lookup
                "accuracy": run[col].iloc[0],
            }
            for col in run.columns
        ]
        pivot_df = pd.DataFrame(records).pivot(index="category", columns="method", values="accuracy")

    if feature in ("category", "method"):
        title = f"Analogy_accuracy_{topk}_by_{feature}_{exclusion}_agg_by_{aggregation}"
        pivot_df.index = pivot_df.index.map(model_rename_map)
        # models become columns (DataFrame.swapaxes is deprecated; .T is equivalent)
        pivot_df = pivot_df.T
    else:
        title = f"analogy_accuracy_{topk}_{exclusion}_{feature}"
        pivot_df.columns = pivot_df.columns.map(column_rename_map)
        if "Sequential" in feature:
            # drop the non-seqadd method columns for sequential models
            pivot_df = pivot_df.drop(columns=["ADD", "IA", "AO", "RV", "MUL", "RVB", "OB", "PD"])

    pivot_df = pivot_df.round(4)
    pivot_df.to_csv(f"{output_path}/{title}.csv")

    if plot:
        import seaborn as sns
        sns.set()
        plt.figure(figsize=(16, 9))
        ax = sns.heatmap(pivot_df, annot=True, annot_kws={"size": 7}, xticklabels="auto")
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
        plt.title(title)
        plt.tight_layout()
        plt.savefig(f"{output_path}/{title}.png", format='png', dpi=150, bbox_inches='tight')
        plt.close()


def analogy_scatter_plot(df, model_name_pairs, method_name_pairs, topk, exclusion="exclude_original", plot=False, legends=None):
    """
    Scatter-plot per-category analogy scores of two (model, method) combinations.

    For every pair, the x coordinate is the score of ``model_names[0]`` with
    ``method_names[0]`` and the y coordinate that of ``model_names[1]`` with
    ``method_names[1]``, with one point per analogy category ("track"). With
    several pairs, black arrows connect each category's point in the previous
    pair to its point in the current pair.

    :param df: long-format analogy dataframe with the columns "model", "path",
        "track", "embedding", "exclude?", "method", "scoretype" and "score";
        only "embedding0" rows are used
    :param model_name_pairs: list of (x_model, y_model) run directory names,
        matched against the "path" column
    :param method_name_pairs: list of (x_method, y_method) method names, one
        per model pair; they must be keys of ``column_rename_map``
    :param topk: str, which topk accuracy to use, e.g. "top1", "top5", "top10"
    :param exclusion: str, "exclude_original" or "include_original"
    :param plot: if True, draw the plot and save it to
        ``<output_path>/<title>.png`` (module-level ``output_path``)
    :param legends: legend labels, one per pair (default: no labels)
    :raises ValueError: if no (model pair, method pair) combination is given
    """
    legends = [] if legends is None else legends
    pairs = list(zip(model_name_pairs, method_name_pairs))
    if not pairs:
        raise ValueError("analogy_scatter_plot needs at least one model/method pair")

    # New frame rather than assigning df.index, so the caller's frame is untouched.
    df = df.set_index("model", drop=False).sort_index()
    df = df[
        (df["exclude?"] == exclusion)
        & (df["scoretype"] == f"accuracy_{topk}")
        & (df["embedding"] == "embedding0")
    ]
    print(df.head())
    print(df.columns)

    # One score array per "<model><method>" key, sorted by category so that
    # position i is the same category in every array (assumes each
    # model/method has exactly one score per category).
    accuracies = {}
    categories = np.sort(df["track"].unique())
    for model_names, method_names in pairs:
        print("method_names", method_names)
        for model, method in zip(model_names[:2], method_names[:2]):
            rows = df[(df["path"] == model) & (df["method"] == method)]
            accuracies[model + method] = rows.sort_values(by="track")["score"].values

    # The title and axis labels describe the last pair.
    last_models, last_methods = pairs[-1]
    x_model, y_model = last_models[0], last_models[1]
    x_method, y_method = last_methods[0], last_methods[1]
    # y_model (not x_model twice) so the file name identifies both models.
    title = (f"analogy_accuracy_between_{x_model}_{column_rename_map[x_method]}"
             f"_and_{y_model}_{column_rename_map[y_method]}_{exclusion}")

    def arrow(src, dst, color):
        plt.arrow(x=src[0], y=src[1],
                  dx=dst[0] - src[0], dy=dst[1] - src[1],
                  color=color,
                  length_includes_head=True)

    if plot:
        texts = []
        plt.figure(figsize=(9, 9))
        colors = ["red", "blue", "green", "brown", "black"]
        for idx, (model_names, method_names) in enumerate(pairs):
            color = colors[idx % len(colors)]  # cycle rather than IndexError past 5 pairs
            x_key = model_names[0] + method_names[0]
            y_key = model_names[1] + method_names[1]
            plt.scatter(x_key, y_key, data=accuracies, c=color)
            if idx > 0:
                prev_models, prev_methods = pairs[idx - 1]
                prev_x = accuracies[prev_models[0] + prev_methods[0]]
                prev_y = accuracies[prev_models[1] + prev_methods[1]]
            for i, category in enumerate(categories):
                x, y = accuracies[x_key][i], accuracies[y_key][i]
                texts.append(
                    plt.annotate(category, (x, y + 0.01), xytext=(5, 2),
                                 textcoords='offset points', ha='right', va='bottom', c=color)
                )
                if idx > 0:
                    # arrow from this category's point in the previous pair
                    arrow((prev_x[i], prev_y[i]), (x, y), "black")
        plt.legend(legends)
        # y = x reference diagonal
        plt.plot([0, 1], [0, 1], color="black", lw=2, linestyle='solid')

        plt.xlabel(x_model.split('-')[0] + ", " + column_rename_map[x_method])
        plt.ylabel(y_model.split('-')[0] + ", " + column_rename_map[y_method])
        adjust_text(texts)
        plt.savefig(f"{output_path}/{title}.png", format='png')
        plt.close()


def calculate_analogy_total_accuracy(df, method="accuracy_3cosadd", topk="top1"):
    """
    Given an analogy performance df, calculate the overall analogy accuracy.

    ``df`` is the json-normalised "analogy" section of one performance.json,
    with a single row. For every category ``c`` it reads the ``c.total``
    column and the hit column ``c.embedding0.exclude_original.<method>.<topk>``.
    The aggregate "final" category is ignored. The result is
    sum(hits) / sum(totals) over all categories.

    :param method: analogy method name, e.g. "accuracy_3cosadd"
    :param topk: last field of the hit column, e.g. "top1", "top5", "top10"
    :return: overall accuracy as a float
    :raises KeyError: if a category lacks its total or hit column
    :raises ValueError: if no category is found, or df does not hold exactly one row
    """
    def is_relevant(col):
        # every "*total*" column, plus the exclude_original hit columns of
        # the requested method/topk outside the aggregate "final" category
        if "total" in col:
            return True
        return (col.split(".")[-1] == topk and "exclude_original" in col
                and "final" not in col and method in col)

    # "final.total" is the aggregate over all categories; skipping it when
    # present (instead of dropping it unconditionally) avoids a KeyError
    # when the file has no such column.
    kept = [col for col in df.columns if is_relevant(col) and col != "final.total"]
    categories = sorted({col.split(".")[0] for col in kept})
    if not categories:
        raise ValueError("no analogy categories found")
    total = sum(df[f"{c}.total"] for c in categories)
    hits = sum(df[f"{c}.embedding0.exclude_original.{method}.{topk}"] for c in categories)
    # .item() extracts the single value; float(Series) is deprecated in pandas
    return float((hits / total).item())

def deep_get(dictionary, *keys):
    """
    Follow ``keys`` through nested dicts and return the value found.

    Returns None as soon as a level is missing or is not a dict, instead of
    raising. A malformed section therefore yields None for that metric
    rather than an AttributeError that would discard every metric of the run.
    """
    return reduce(lambda d, key: d.get(key) if isinstance(d, dict) else None, keys, dictionary)


# Text-classification datasets as (short name, key in performance.json),
# in the column order of the report.
_TC_DATASETS = (
    ("SCI", "twenty_ng_sci"),
    ("RELI", "twenty_ng_reli"),
    ("SPORT", "twenty_ng_sport"),
    ("COMP", "twenty_ng_comp"),
    ("MS", "movie_sentiment"),
)
# Text-classification score variants as (short-name suffix, total-column
# prefix, aggregation key, split key), in the column order of the report.
_TC_VARIANTS = (
    ("", "TC", "sum", "val_accuracy"),
    ("_test", "TC_test", "sum", "test_accuracy"),
    ("_m", "TC_m", "mean", "val_accuracy"),
    ("_mtest", "TC_m_test", "mean", "test_accuracy"),
)
# Weights for pooling per-dataset scores into the *_total columns
# (presumably the dataset sizes - TODO confirm). Order matches the original
# summation order so floating-point results are unchanged.
_SIM_WEIGHTS = (("RW", 2034), ("MT", 287), ("WS", 353), ("WSS", 203),
                ("WSR", 252), ("SM", 999), ("MEN", 3000))
_TC_WEIGHTS = (("SCI", 1994), ("SPORT", 1987), ("RELI", 1995),
               ("COMP", 1981), ("MS", 9142))
# Model families whose analogy accuracy is reported with accuracy_3cosadd;
# every other family is reported with accuracy_seqadd_5.
_THREECOSADD_FAMILIES = ("CBOW_NEG", "Hybrid", "SkipGram")


def _weighted_mean(scores, weights, suffix=""):
    """Weighted mean of ``scores[name + suffix]``; TypeError if any score is None."""
    weighted_sum = sum(scores[name + suffix] * weight for name, weight in weights)
    return weighted_sum / sum(weight for _, weight in weights)


def _plain_mean(scores, names, suffix=""):
    """Unweighted mean of ``scores[name + suffix]``; TypeError if any score is None."""
    values = [scores[name + suffix] for name in names]
    return sum(values) / len(values)


def collect_performance_data(data, model_folder=None):
    """
    Flatten one run's performance.json into a flat dict of summary metrics.

    :param data: parsed performance.json. Missing sections give None values.
    :param model_folder: run directory name, used to choose the analogy
        method that is reported. Defaults to the module-level
        ``model_folder`` set by the per-model loop, which was the previous
        behaviour.
    :return: dict of metric name -> value (None where unavailable)
    """
    if model_folder is None:
        model_folder = globals().get("model_folder", "")

    performance_data = {
        "val_loss": deep_get(data, "valid", "loss"),
        "val_acc1": deep_get(data, "valid", "acc1"),
        "val_acc5": deep_get(data, "valid", "acc5"),
        "val_acc10": deep_get(data, "valid", "acc10"),
        # word similarity
        "MEN": deep_get(data, "word_similarity", "bruni_men"),
        "MT": deep_get(data, "word_similarity", "radinsky_mturk"),
        "RW": deep_get(data, "word_similarity", "luong_rare"),
        "SM": deep_get(data, "word_similarity", "sim999"),
        "WSR": deep_get(data, "word_similarity", "ws353_relatedness"),
        "WSS": deep_get(data, "word_similarity", "ws353_similarity"),
        "WS": deep_get(data, "word_similarity", "ws353"),
    }
    # text classification: SCI, RELI, SPORT, COMP, MS for each variant
    for suffix, _, aggregation, split in _TC_VARIANTS:
        for short_name, dataset in _TC_DATASETS:
            performance_data[short_name + suffix] = deep_get(
                data, "text_classification", dataset, aggregation, split)

    # Pooled totals: a missing (None) component makes the total None.
    try:
        performance_data["SIM_total"] = _weighted_mean(performance_data, _SIM_WEIGHTS)
    except TypeError:
        performance_data["SIM_total"] = None
    tc_names = [name for name, _ in _TC_WEIGHTS]
    for suffix, prefix, _, _ in _TC_VARIANTS:
        try:
            total = _weighted_mean(performance_data, _TC_WEIGHTS, suffix)
            normalized = _plain_mean(performance_data, tc_names, suffix)
        except TypeError:
            # set both keys so every run has the same columns
            total = normalized = None
        performance_data[f"{prefix}_total"] = total
        performance_data[f"{prefix}_total_normalized"] = normalized

    uses_3cosadd = any(family in model_folder for family in _THREECOSADD_FAMILIES)
    method = "accuracy_3cosadd" if uses_3cosadd else "accuracy_seqadd_5"
    try:
        analogy_performance = pd.json_normalize(data.get("analogy", []))
        analogy_scores = [
            calculate_analogy_total_accuracy(analogy_performance, method=method, topk=topk)
            for topk in ("top1", "top5", "top10")
        ]
    except Exception:
        # best effort: a missing or malformed analogy section reports None
        analogy_scores = [None, None, None]
    performance_data.update(zip(("analogy_acc1", "analogy_acc5", "analogy_acc10"), analogy_scores))
    return performance_data

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str)
parser.add_argument('--pivot', action="store_true",
    help="if extracting detail analogy task results with pivot table")
parser.add_argument('--scatter', action="store_true",
    help="if extracting analogy accuracy of specified method from two models")

args = parser.parse_args()

# Results go to "<path>_result". normpath only removes trailing separators
# (and redundant "." / "//"), so an absolute input path stays absolute. The
# previous strip('/') also removed the leading '/' and wrote a relative
# directory under the current working directory.
output_path = f"{os.path.normpath(args.path)}_result"
os.makedirs(output_path, exist_ok=True)

# Per-run summary rows (list of dicts) and the long-format analogy table,
# both filled by the per-model loop.
df = []
analogy_df = pd.DataFrame()

def load_json(directory, name):
    """
    Load ``directory/name`` as JSON.

    Returns the parsed object. If the file does not exist (or the path is not
    a regular file, or ``directory`` itself is missing), returns an empty dict
    so callers can treat absent result files as "no data".
    """
    path = os.path.join(directory, name)
    if not os.path.isfile(path):
        return {}
    # JSON is UTF-8; don't depend on the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)

# Walk the run directories: one summary row per run in `df`, plus the
# long-format analogy scores. Analogy frames are gathered in a list and
# concatenated once after the loop, because DataFrame.append was removed in
# pandas 2.0 and appending inside the loop is quadratic.
analogy_frames = []
analogy_columns = ["model","path","track","embedding","exclude?","method","scoretype","score"]
dropped_keys = set(irrelevant_keys)
# NOTE: the loop variable must stay named `model_folder`;
# collect_performance_data reads it as a module-level global.
for model_folder in sorted(os.listdir(args.path)):
    if not any(tag in model_folder for tag in ("CBOW", "BTL", "Hybrid", "SkipGram")):
        continue
    model_folder_path = os.path.join(args.path, model_folder)
    if not os.path.isdir(model_folder_path):
        # a plain file whose name matches one of the tags
        continue
    print(f"processing data:{model_folder}")
    # rename old code file names:
    # rename_old_to_new(model_folder, model_folder_path)

    model_data = load_json(model_folder_path, "hyper.json")
    model_data = {k: v for k, v in model_data.items() if k not in dropped_keys}
    time_data = load_json(model_folder_path, "time-memory.json")
    data = load_json(model_folder_path, "performance.json")

    model = model_folder.split("-")[0]
    analogy_wide = pd.json_normalize(data.get("analogy", []))
    rows = []
    for col in analogy_wide.columns:
        # e.g. "final.embedding0.exclude_original.accuracy_seqadd.accuracy_top1"
        #   -> track, embedding, exclude?, method, scoretype
        parts = col.split(".")
        # keep only 5-field score columns; skip the raw top1/top5/top10 fields
        if len(parts) != 5 or parts[-1] in ("top1", "top10", "top5"):
            continue
        rows.extend([model, model_folder, *parts, score] for score in analogy_wide[col].tolist())
    analogy_frames.append(pd.DataFrame(rows, columns=analogy_columns))

    try:
        performance_data = collect_performance_data(data)
    except Exception:
        performance_data = {}
    df.append({"model": model, "path": model_folder, **time_data, **model_data, **performance_data})

if analogy_frames:
    analogy_df = pd.concat([analogy_df, *analogy_frames])

df = pd.DataFrame(df)

# Dataset label for the output file names: the second "/"-separated component
# of the first available "train" path (presumably "<root>/<dataset>/..."),
# otherwise "ccc". NOTE: "train" is listed in irrelevant_keys and stripped
# from the hyper-parameters, so this normally falls back to "ccc".
dataset_name = "ccc"
if "train" in df.columns:
    train_paths = df["train"].dropna()
    if not train_paths.empty:
        path_parts = str(train_paths.iloc[0]).split("/")
        if len(path_parts) > 1:
            dataset_name = path_parts[1]

# Long-format analogy scores, ordered by model family, with short model labels.
analogy_df = analogy_df.sort_values(by=["model"])
analogy_df = analogy_df.round(4)
analogy_df["model"] = analogy_df["model"].map(model_rename_map)
analogy_df.to_csv(f"{output_path}/analogy_{dataset_name}.csv")

# if args.pivot:
#     for model_folder in os.listdir(args.path):
#         print("plotting the heat map:",model_folder)
#         analogy_heat_map(analogy_df, model_folder, "top10", aggregation="mean", exclusion="exclude_original", plot=True)
if args.scatter:
    # Earlier configuration: best analogy models by embeddings
    # (legends presumably the embedding sizes), accuracy_3cosadd vs accuracy_seqadd_5:
    # model_name_pairs = [
    #     ["CBOW_NEG-ab09e264c6fd454ae87a418e034186c1", "CBOW_BTL_Sequential-5897474224791aa750466b4c36a77693"],
    #     ["CBOW_NEG-a478b776c3de8b72bf2c3ca002949c9b", "CBOW_BTL_Sequential-920379b95fa08aa3c26ff80446ceb1ff"],
    #     ["CBOW_NEG-b2cdce148bd074251e4c19bbadd98f76", "CBOW_BTL_Sequential-0f338df683ac64ee6771d25047de22da"],
    #     ["CBOW_NEG-9c9a1ff10d5a10759fa29fdd1f72377e", "CBOW_BTL_Sequential-3fb06adb69aa461ffbca38390e7dfa98"]
    # ]
    # method_names = [["accuracy_3cosadd","accuracy_seqadd_5"]] * 4
    # analogy_scatter_plot(analogy_df, model_name_pairs, method_names, "top10",
    #     exclusion="exclude_original", plot=True, legends=["200", "500", "1000", "2000"])

    print("Usage note for this mode: Use 'scatter' directory, which has symlinks to the selected models")

    # Compare one CBOW_NEG run against one CBOW_BTL_Sequential run under
    # three analogy methods; one model pair per method pair.
    method_names = [
        ["accuracy_only_b","accuracy_seqadd_only_b"],
        ["accuracy_ignore_a","accuracy_seqadd_ignore_a_1"],
        ["accuracy_3cosadd","accuracy_seqadd_5"],
    ]
    compared_models = ["CBOW_NEG-b2cdce148bd074251e4c19bbadd98f76", "CBOW_BTL_Sequential-0f338df683ac64ee6771d25047de22da"]
    model_name_pairs = [list(compared_models) for _ in method_names]
    analogy_scatter_plot(analogy_df, model_name_pairs, method_names,
        "top10",
        exclusion="exclude_original",
        plot=True,
        legends=["only b", "ignore a", "add"]
        )
    # Scatter mode only produces the plot; stop before writing the benchmark
    # CSV. exit() is a site-module helper meant for the interactive prompt and
    # is not guaranteed to exist (e.g. `python -S`); SystemExit always is.
    raise SystemExit

# Per-run benchmark summary: rows ordered by model family, values rounded to
# 4 decimals, and model families shown under their short labels.
out_file_name = f"{output_path}/benchmark_{dataset_name}.csv"
df = (
    df.sort_values(by=["model"])
    .round(4)
    .assign(model=lambda frame: frame["model"].map(model_rename_map))
)
df.to_csv(out_file_name, index=True)
