import os
import json
import warnings
import numpy as np
from itertools import product
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

def friendly_name(llm):
    """Return the short, lowercase display name for a raw model identifier.

    A few verbose model names are first shortened (case-sensitively), then the
    whole name is lowercased and the ``-hf``, ``-32k`` and ``-instruct``
    fragments are removed wherever they occur.

    Args:
        llm: raw model identifier, e.g. a metrics file stem.

    Returns:
        The normalised name, e.g. ``"CodeLlama-13b-Instruct-hf"`` ->
        ``"codellama-13b"``.
    """
    # These renames match mixed-case spellings, so they must run before lower().
    for verbose, short in (
        ("WizardLM-13B-V1.2", "wizardlm-13b"),
        ("Nous-Hermes-Llama2-13b", "nous-hermes-13b"),
    ):
        llm = llm.replace(verbose, short)
    llm = llm.lower()
    for fragment in ("-hf", "-32k", "-instruct"):
        llm = llm.replace(fragment, "")
    return llm

def load_pivot_metrics(dirs=None, sub_dirs=None, return_dataframe=False):
    """Flatten per-LLM metric JSON files into long-format rows.

    Every ``<domain_dir>/<setting_dir>/*.json`` file is read. In each row:

    * the LLM is the file stem passed through ``friendly_name``;
    * the domain is the second ``_``-separated field of the domain directory's
      own name (``data_dailylifeapis`` -> ``dailylifeapis``);
    * the setting is the setting directory name with ``metrics`` replaced by
      ``main``.

    Each file is expected to hold a dict keyed by ``"<structure>_<num_tools>"``
    whose values map metric names to values. Metrics whose name contains
    ``per_type`` are nested one level deeper (``type -> leaf metric -> value``)
    and are flattened into the metric name ``"<metric>><type>><leaf_metric>"``.
    Files that cannot be read or parsed, or whose top level is not a dict, are
    skipped with a warning.

    Args:
        dirs: domain directories; defaults to ``["data_dailylifeapis"]``.
        sub_dirs: setting sub-directories inside each domain directory;
            defaults to ``["metrics_alignment_all"]``.
        return_dataframe: return a ``pandas.DataFrame`` instead of a list.

    Returns:
        Rows ``[llm, domain, setting, structure, num_tools, metric, value]``,
        as a list or as a DataFrame with those column names.
    """
    if dirs is None:
        dirs = ["data_dailylifeapis"]
    if sub_dirs is None:
        sub_dirs = ["metrics_alignment_all"]

    pivot_table = []
    for domain_dir in dirs:
        # Parse only the directory's own name so "../data_x" or "data_x/" work.
        domain = os.path.basename(os.path.normpath(domain_dir)).split("_")[1]
        for setting_dir in sub_dirs:
            metrics_dir = os.path.join(domain_dir, setting_dir)
            setting = setting_dir.replace("metrics", "main")
            # sorted() keeps row order (and thus "later row wins" lookups) stable.
            for file_name in sorted(os.listdir(metrics_dir)):
                if not file_name.endswith(".json"):
                    continue
                llm = friendly_name(os.path.splitext(file_name)[0])
                path = os.path.join(metrics_dir, file_name)
                try:
                    with open(path, "r", encoding="utf-8") as f:
                        result = json.load(f)
                except (OSError, ValueError) as err:  # ValueError covers JSON/Unicode decode errors
                    warnings.warn(f"skipping unreadable metrics file {path}: {err}")
                    continue
                if not isinstance(result, dict):
                    warnings.warn(f"skipping metrics file {path}: top-level JSON value is not a dict")
                    continue
                for structure_numtools, metrics in result.items():
                    structure, num_tools = structure_numtools.split("_")
                    prefix = [llm, domain, setting, structure, num_tools]
                    for metric, value in metrics.items():
                        if "per_type" in metric:
                            for type_name, per_type_value in value.items():
                                for leaf_metric, leaf_value in per_type_value.items():
                                    pivot_table.append(prefix + [metric + ">" + type_name + ">" + leaf_metric, leaf_value])
                        else:
                            pivot_table.append(prefix + [metric, value])
    if return_dataframe:
        return pd.DataFrame(pivot_table, columns=["llm", "domain", "setting", "structure", "num_tools", "metric", "value"])
    return pivot_table

def filter_pivot_table(pivot_table, dimension_values, schema):
    """Keep only the rows whose ``schema`` columns equal ``dimension_values``.

    ``schema`` names the columns to match, and ``dimension_values`` gives the
    required value for each of them, position by position. Every name in
    ``schema`` must be a column of ``pivot_table`` (``ValueError`` otherwise).
    Supplying more values than schema names raises ``IndexError``.
    """
    column_names = list(pivot_table.columns)
    # Resolve each requested dimension among the columns; list.index raises
    # ValueError for a name that is not present.
    target_columns = [column_names[column_names.index(name)] for name in schema]

    subset = pivot_table
    for position, required in enumerate(dimension_values):
        column = target_columns[position]
        keep = subset[column] == required
        subset = subset[keep]
    return subset

def get_metrics_by_xy(pivot_table: pd.DataFrame, x_dimension_values, y_dimension_values, X_schema, Y_schema):
    """Return ``{metric: value}`` for the rows matching both the x and y selections.

    The table is narrowed with ``filter_pivot_table``, first by ``X_schema`` /
    ``x_dimension_values`` and then by ``Y_schema`` / ``y_dimension_values``.
    Each remaining row contributes its second-to-last column as the key and its
    last column as the value. If several rows share a key, the later row wins.
    """
    selected = filter_pivot_table(
        filter_pivot_table(pivot_table, x_dimension_values, X_schema),
        y_dimension_values,
        Y_schema,
    )
    return {record[-2]: record[-1] for record in selected.values}

def retrieve_metrics(x, Y, Y_schema="structure.num_tools.metric", return_printable_table=True, dirs=None, sub_dirs=None):
    """Cross-tabulate metric values from the table built by ``load_pivot_metrics``.

    Rows are every combination of the values of the ``x`` dimensions; columns
    are the entries of ``Y``.

    Args:
        x: dot-separated pivot-table column names (e.g. ``"domain.llm"``) whose
            value combinations form the rows; ``""`` yields a single row.
        Y: selectors, each a dot-separated string with one field per
            ``Y_schema`` entry. The last field is the metric name; the others
            pin the remaining dimensions.
        Y_schema: dot-separated column names describing the fields of each
            ``Y`` entry.
        return_printable_table: choose the return form (see Returns).
        dirs, sub_dirs: forwarded to ``load_pivot_metrics``; they default to
            ``["data_dailylifeapis"]`` and ``["metrics_alignment_all"]``.

    Returns:
        If ``return_printable_table`` is true: ``(result_table, row_names)``.
        ``result_table`` is a float array of shape ``(n_rows, len(Y))`` and
        ``row_names`` lists each row's ``x`` values. With more than one ``x``
        dimension, rows are ordered by the first dimension (descending) and
        then by the last column's value (descending). Otherwise they are
        ordered by the last column's value, descending.

        Otherwise: a long-format DataFrame whose columns are the ``x`` dims,
        the joined ``Y_schema``, and ``"value"``.

    Cells with no matching metric (or a value that cannot be stored as a
    float) are reported on stdout and set to 0.
    """
    if dirs is None:
        dirs = ["data_dailylifeapis"]
    if sub_dirs is None:
        sub_dirs = ["metrics_alignment_all"]
    pivot_table = load_pivot_metrics(return_dataframe=True, dirs=dirs, sub_dirs=sub_dirs)
    dimensions = list(pivot_table.columns)
    known_metrics = set(pivot_table["metric"].values)

    Y_schema = Y_schema.split(".")
    for Y_i in Y:
        assert len(Y_i.split(".")) == len(Y_schema), "Y_i " + Y_i + " does not match Y_schema " + ".".join(Y_schema)

    X_schema = x.split(".") if x else []
    cascade_X = []
    result_table_height = 1
    for X_i in X_schema:
        assert X_i in dimensions, "dimension " + X_i + " not in pivot table"
        # Sorted (not raw set order) so the row order is reproducible between runs.
        values = sorted(set(pivot_table[X_i].values), key=str)
        result_table_height *= len(values)
        cascade_X.append(values)

    retrieved_pivot_table = []
    retrieved_pivot_table_dimensions = X_schema + [".".join(Y_schema)] + ["value"]
    result_table = np.zeros((result_table_height, len(Y)))
    row_names = []

    for row_id, x_dimension_values in enumerate(product(*cascade_X)):
        x_dimension_values = list(x_dimension_values)
        row_names.append(x_dimension_values)
        for col_id, Y_i in enumerate(Y):
            *y_dimension_values, metric = Y_i.split(".")
            assert metric in known_metrics, "metric " + metric + " not in pivot table"

            filtered_metrics = get_metrics_by_xy(pivot_table, x_dimension_values, y_dimension_values, X_schema, Y_schema)
            # Decide the cell value once so the long-format row always has
            # exactly one value, even if storing it in result_table fails.
            try:
                value = filtered_metrics[metric]
                result_table[row_id, col_id] = value
            except (KeyError, TypeError, ValueError) as e:
                value = 0
                print(e)
                print(f"Not found @ ({'.'.join(x_dimension_values)}, {'.'.join(y_dimension_values)}, {metric}), set to 0")
            retrieved_pivot_table.append(x_dimension_values + [Y_i, value])

    if not return_printable_table:
        return pd.DataFrame(retrieved_pivot_table, columns=retrieved_pivot_table_dimensions)

    if row_names and len(row_names[0]) > 1:
        # Group by the first x dimension, best last-column score first. A
        # numeric tuple key orders values >= 100 and negative values
        # correctly, which a formatted-string key does not.
        order = sorted(
            range(len(row_names)),
            key=lambda i: (row_names[i][0], result_table[i, -1]),
            reverse=True,
        )
    else:
        order = np.argsort(result_table[:, -1])[::-1]
    result_table = result_table[order]
    row_names = [row_names[i] for i in order]
    return result_table, row_names

def print_table(x, Y, Y_schema="structure.num_tools.metric", type="percent", dirs=None, sub_dirs=None):
    """Print a LaTeX table body (header line plus one line per row) to stdout.

    Cells are separated by ``" & "`` and every line ends with ``" \\\\"``.
    Row-label cells are left-aligned to 25 characters.

    Args:
        x: dot-separated pivot-table columns whose value combinations form the
            rows (see ``retrieve_metrics``).
        Y: metric selectors, one per column, each matching ``Y_schema``.
        Y_schema: dot-separated column names describing the fields of each
            ``Y`` entry.
        type: cell format. ``"percent"`` multiplies by 100 and prints two
            decimals without the ``%`` sign; ``"float"`` prints two decimals;
            ``"int"`` prints no decimals. The name shadows the builtin but is
            kept because callers pass it by keyword.
        dirs, sub_dirs: forwarded to ``retrieve_metrics``; they default to
            ``["data_dailylifeapis"]`` and ``["metrics_alignment_all"]``.

    Raises:
        NotImplementedError: ``type`` is not a supported format. This is
            checked before any data is loaded or printed.
    """
    formatters = {
        "percent": lambda v: f"{v:.2%}".replace("%", ""),
        "float": lambda v: f"{v:.2f}",
        "int": lambda v: f"{v:.0f}",
    }
    if type not in formatters:
        raise NotImplementedError(f"unsupported cell type: {type!r}")
    format_cell = formatters[type]

    if dirs is None:
        dirs = ["data_dailylifeapis"]
    if sub_dirs is None:
        sub_dirs = ["metrics_alignment_all"]
    result_table, row_names = retrieve_metrics(x, Y, Y_schema, return_printable_table=True, dirs=dirs, sub_dirs=sub_dirs)

    header_cells = "".join("{:25s} & ".format(x_i) for x_i in x.split("."))
    print(header_cells + " & ".join(Y) + " \\\\")

    for row_name, row_values in zip(row_names, result_table):
        label = " & ".join("{:25s}".format(part) for part in row_name)
        cells = " & ".join(format_cell(value) for value in row_values)
        print(label + " & " + cells + " \\\\")

def print_paper_table():
    """Print every LaTeX table of the paper, one titled section after another.

    Each section prints a ``##### <title> #####`` banner followed by the table
    produced by ``print_table``. Consecutive sections are separated by blank
    lines.
    """
    all_domain_dirs = ["data_dailylifeapis", "data_huggingface", "data_multimedia"]
    alignment_sub_dirs = ["metrics_alignment_all"]
    argument_metrics = (
        "argument_task_argname_binary_f1_no_matching",
        "argument_task_argname_value_binary_f1_no_matching",
    )

    # Table 1: step-level text similarity, per domain.
    step_columns = [
        f"{domain}.main_alignment_all.overall.overall.{metric}"
        for domain in ("huggingface", "multimedia", "dailylifeapis")
        for metric in ("step_rouge1", "step_rouge2", "step_bertscore_f1")
    ]
    # Table 2: node / link metrics, per graph structure.
    graph_columns = [
        "main_alignment_all." + selector
        for selector in (
            "single.overall.node_micro_f1_no_matching",
            "chain.overall.node_micro_f1_no_matching",
            "chain.overall.link_binary_f1",
            "chain.overall.edit_distance",
            "dag.overall.node_micro_f1_no_matching",
            "dag.overall.link_binary_f1",
            "overall.overall.node_micro_f1_no_matching",
            "overall.overall.link_binary_f1",
        )
    ]
    # Table 3: argument-name / argument-value metrics, per graph structure.
    argument_columns = [
        f"main_alignment_all.{structure}.overall.{metric}"
        for structure in ("single", "chain", "dag", "overall")
        for metric in argument_metrics
    ]
    # Table 4: headline dailylifeapis metrics across the demo settings below.
    fewshot_columns = [
        f"dailylifeapis.overall.overall.{metric}"
        for metric in (
            "step_rougeL",
            "node_micro_f1_no_matching",
            "link_binary_f1",
            *argument_metrics,
        )
    ]
    fewshot_sub_dirs = [
        "metrics_reformat_by_gpt-3.5-turbo_alignment_all",
        "metrics_use_demos_1_reformat_by_gpt-3.5-turbo_alignment_all",
        "metrics_use_demos_2_reformat_by_gpt-3.5-turbo_alignment_all",
    ]

    # (banner title, row dimensions, column selectors, extra print_table kwargs)
    sections = [
        ("Table 1", "llm", step_columns, {
            "Y_schema": "domain.setting.structure.num_tools.metric",
            "dirs": all_domain_dirs,
            "sub_dirs": alignment_sub_dirs,
            "type": "percent",
        }),
        ("Table 2", "domain.llm", graph_columns, {
            "Y_schema": "setting.structure.num_tools.metric",
            "dirs": all_domain_dirs,
            "sub_dirs": alignment_sub_dirs,
        }),
        ("Table 3", "domain.llm", argument_columns, {
            "Y_schema": "setting.structure.num_tools.metric",
            "dirs": all_domain_dirs,
            "sub_dirs": alignment_sub_dirs,
        }),
        ("Supports", "domain.structure", ["gpt-4.overall.all_samples"], {
            "Y_schema": "llm.num_tools.metric",
            "dirs": all_domain_dirs,
            "sub_dirs": alignment_sub_dirs,
            "type": "int",
        }),
        # NOTE(review): a second "Table 3" section that relies on print_table's
        # default dirs/sub_dirs, unlike the first one; confirm it is intended.
        ("Table 3", "domain.llm", argument_columns, {
            "Y_schema": "setting.structure.num_tools.metric",
        }),
        ("Table 4", "setting.llm", fewshot_columns, {
            "Y_schema": "domain.structure.num_tools.metric",
            "dirs": ["data_dailylifeapis"],
            "sub_dirs": fewshot_sub_dirs,
        }),
    ]

    for position, (title, row_dims, columns, options) in enumerate(sections):
        if position:
            print("\n\n")
        print(f"##### {title} #####")
        print_table(row_dims, columns, **options)


def plot_fewshot():
    """Plot selected dailylifeapis scores against the number of demonstrations.

    Loads the 0-, 1- and 2-demo reformatted settings for ``data_dailylifeapis``
    and keeps codellama-13b, gpt-4 and text-davinci-003 with the Rouge-L,
    n-F1 and v-F1 metrics. It draws one line per LLM, with one marker style
    per metric, writes ``fewshot.pdf`` and shows the figure.
    """
    sub_dirs = [
        "metrics_reformat_by_gpt-3.5-turbo_alignment_all",
        "metrics_use_demos_1_reformat_by_gpt-3.5-turbo_alignment_all",
        "metrics_use_demos_2_reformat_by_gpt-3.5-turbo_alignment_all",
    ]
    metrics = {
        "dailylifeapis.overall.overall.step_rougeL": "Rouge-L",
        "dailylifeapis.overall.overall.node_micro_f1_no_matching": "$n$-F1",
        "dailylifeapis.overall.overall.link_binary_f1": "$e$-F1",
        "dailylifeapis.overall.overall.argument_task_argname_binary_f1_no_matching": "$t$-F1",
        "dailylifeapis.overall.overall.argument_task_argname_value_binary_f1_no_matching": "$v$-F1",
    }
    dirs = ["data_dailylifeapis"]
    pivot_table = retrieve_metrics(
        "setting.llm",
        list(metrics.keys()),
        Y_schema="domain.structure.num_tools.metric",
        dirs=dirs,
        sub_dirs=sub_dirs,
        return_printable_table=False,
    )
    # retrieve_metrics returns the columns [setting, llm, <Y selector>, value].
    pivot_table.columns = ["Few-shot", "LLM", "Metric", "Performance"]
    pivot_table["Metric"] = pivot_table["Metric"].apply(lambda x: metrics[x])

    def map_fewshot(x):
        """Map a setting name to its ``"<k>-shot"`` label."""
        if "use_demos_2" in x:
            return "2-shot"
        elif "use_demos_1" in x:
            return "1-shot"
        else:
            return "0-shot"

    pivot_table["Few-shot"] = pivot_table["Few-shot"].apply(map_fewshot)
    pivot_table = pivot_table.sort_values(by=["Few-shot", "Performance"], ascending=[True, False])
    # Scale to percent for PercentFormatter (assumes scores are fractions in [0, 1]).
    pivot_table["Performance"] = pivot_table["Performance"] * 100

    filter_llms = ["codellama-13b", "gpt-4", "text-davinci-003"]
    filter_metrics = ["Rouge-L", "$n$-F1", "$v$-F1"]
    pivot_table = pivot_table[
        pivot_table["LLM"].isin(filter_llms) & pivot_table["Metric"].isin(filter_metrics)
    ]
    # Preview the first rows of the data being plotted.
    print(pivot_table.head())

    plt.figure(figsize=(8, 5))

    sns.set_theme(style="ticks")
    sns.set_context("poster")

    sns.lineplot(x="Few-shot", y="Performance",
                 hue="LLM", style="Metric", markers=True,
                 data=pivot_table)

    plt.legend(loc='lower right', fontsize=14, frameon=True, ncol=2)
    plt.xlabel("")
    plt.ylabel("")
    plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter())
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    plt.tight_layout()
    plt.savefig("fewshot.pdf")

    plt.show()


def plot_overall():
    """Draw a horizontal grouped bar chart of dailylifeapis metrics per LLM.

    Shows n-F1, e-F1, t-F1 and v-F1 (as percentages) for a fixed set of LLMs.
    LLMs are ranked by their highest single score. The chart is written to
    ``overall.pdf`` and then shown.
    """
    sns.set_theme(style="ticks")
    sns.set_context("poster")

    metrics = {
        "dailylifeapis.overall.overall.node_micro_f1_no_matching": "$n$-F1",
        "dailylifeapis.overall.overall.link_binary_f1": "$e$-F1",
        "dailylifeapis.overall.overall.argument_task_argname_binary_f1_no_matching": "$t$-F1",
        "dailylifeapis.overall.overall.argument_task_argname_value_binary_f1_no_matching": "$v$-F1",
    }
    dirs = ["data_dailylifeapis"]
    sub_dirs = ["metrics_alignment_all"]
    retrieved_pivot_table = retrieve_metrics(
        "llm",
        list(metrics.keys()),
        Y_schema="domain.structure.num_tools.metric",
        dirs=dirs,
        sub_dirs=sub_dirs,
        return_printable_table=False,
    )
    # retrieve_metrics returns the columns [llm, <Y selector>, value].
    retrieved_pivot_table.columns = ["LLM", "Metric", "Performance"]
    retrieved_pivot_table["Metric"] = retrieved_pivot_table["Metric"].apply(lambda x: metrics[x])
    # Scale to percent for PercentFormatter (assumes scores are fractions in [0, 1]).
    retrieved_pivot_table["Performance"] = retrieved_pivot_table["Performance"] * 100
    filter_llms = ['gpt-4', 'text-davinci-003', 'gpt-3.5-turbo', 'nous-hermes-13b', 'codellama-13b', 'wizardlm-13b', 'vicuna-13b-v1.5']
    retrieved_pivot_table = retrieved_pivot_table[retrieved_pivot_table["LLM"].isin(filter_llms)]
    retrieved_pivot_table = retrieved_pivot_table.sort_values(by=["Performance"], ascending=[False])
    # Unique LLMs in order of first appearance, i.e. ranked by their best score.
    llms = list(dict.fromkeys(retrieved_pivot_table["LLM"]))

    colors = ['#ccebc5', '#ff9f9b', '#a1c9f4', '#d0bbff']
    palette = dict(zip(metrics.values(), colors))

    g = sns.catplot(
        data=retrieved_pivot_table,
        x='Performance',
        y='LLM',
        hue='Metric',
        kind='bar',
        orient='h',
        # Pin the category order so the y tick labels set below match the bars.
        order=llms,
        aspect=1,
        height=11,
        # NOTE(review): `ci` is deprecated since seaborn 0.12 in favour of
        # `errorbar=None`; switch once the minimum seaborn version allows it.
        ci=None,
        dodge=True,
        legend=False,
        palette=palette,
    )

    plt.xlim(0, 100)
    xlabel_fontsize = 22

    plt.legend(loc='lower right', fontsize=18, frameon=True, ncol=1)
    plt.grid(axis='x', linestyle='--', linewidth=1, alpha=0.5)

    g.set_axis_labels("", "", fontsize=xlabel_fontsize).set_titles("{col_name}", size=22)

    g.set_xticklabels(fontsize=22)
    g.set_yticklabels(g.axes[0][0].get_yticklabels(), fontsize=22)
    ax = plt.gca()
    ax.spines['top'].set_visible(True)
    ax.spines['right'].set_visible(True)
    ax.set_yticklabels(llms, rotation=45, va='center')
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())

    plt.tight_layout()
    plt.savefig('overall.pdf', bbox_inches='tight', format='pdf')
    plt.show()


# Run the full analysis only when executed as a script, so the helpers above
# can be imported without loading data, printing tables or opening figures.
if __name__ == "__main__":
    print_paper_table()
    plot_fewshot()
    plot_overall()
