import numpy as np 
import json 
import prettytable as pt
import click
import os
from evaluate import f1_score, node_hallucination_rate, prediction_loader


def _graph_scores(pred_graph, gt_graph, tool_nodes, tool_links):
    """Score one predicted graph against its ground-truth graph.

    Returns ``[node_f1, link_f1] + node_hr + link_hr``. ``node_hr`` is what
    ``node_hallucination_rate`` returns for the predicted nodes checked against
    ``tool_nodes``. ``link_hr`` is the same for the predicted links checked
    against ``tool_links``. The concatenation requires both to be lists.
    """
    node_f1 = f1_score(pred_graph["nodes"], gt_graph["nodes"])
    link_f1 = f1_score(pred_graph["links"], gt_graph["links"])
    node_hr = node_hallucination_rate(pred_graph["nodes"], tool_nodes)
    # The node hallucination helper is reused for links, checked against the
    # known tool links instead of the tool nodes.
    link_hr = node_hallucination_rate(pred_graph["links"], tool_links)
    return [node_f1, link_f1] + node_hr + link_hr


def bi_evaluate(dataset, llm, method, tool_nodes, tool_links):
    """Compare a refinement ``method`` against the direct-prediction baseline.

    Reads ``../prediction/<dataset>/<llm>/direct.json`` (baseline),
    ``../prediction/<dataset>/<llm>/<method>.json`` (refined predictions) and
    ``../data/<dataset>/data.json`` (ground truth). It scores every id found in
    the refined file and appends rows to the module-level ``table``:

    * a "Direct" baseline row, only while the module-level ``print_base`` is
      False. The flag is then set to True, so the baseline is emitted once
      until the caller resets it.
    * one row for ``method``.

    Returns None. Does nothing if either prediction file is missing or the
    refined file contains no ids.
    """
    global table, print_base

    raw_pred_file = f"../prediction/{dataset}/{llm}/direct.json"
    refine_pred_file = f"../prediction/{dataset}/{llm}/{method}.json"
    gt_file = f"../data/{dataset}/data.json"

    if not os.path.exists(raw_pred_file) or not os.path.exists(refine_pred_file):
        return

    alignment_ids = list(prediction_loader(refine_pred_file, "id").keys())
    if not alignment_ids:
        # With no ids, np.mean would produce a NaN scalar and the column
        # indexing below would raise IndexError; there is nothing to report.
        return

    raw_pred_dict = prediction_loader(raw_pred_file, content_type="graph")
    refine_pred_dict = prediction_loader(refine_pred_file, content_type="graph")
    efficiency_dict = prediction_loader(refine_pred_file, content_type="efficiency")
    gt_dict = prediction_loader(gt_file, content_type="graph")

    init_scores, refine_scores = [], []
    cost_times, query_times = [], []
    for data_id in alignment_ids:
        gt_graph = gt_dict[data_id]
        init_scores.append(_graph_scores(raw_pred_dict[data_id], gt_graph, tool_nodes, tool_links))
        refine_scores.append(_graph_scores(refine_pred_dict[data_id], gt_graph, tool_nodes, tool_links))
        cost_times.append(efficiency_dict[data_id]["cost_time"])
        query_times.append(efficiency_dict[data_id]["llm_query_times"])

    # Score columns: 0 = node F1, 1 = link F1, then the hallucination entries.
    # NOTE(review): reading indices 2 and 4 assumes node_hallucination_rate
    # returns a 2-element list (2 = first node entry, 4 = first link entry).
    if not print_base:
        # Emit the "Direct" baseline row only once; main() resets print_base
        # to False before each dataset.
        print_base = True
        avg_pred_score = np.round(np.mean(np.array(init_scores), axis=0), 4)
        raw_efficiency_dict = prediction_loader(raw_pred_file, content_type="efficiency")
        # NOTE(review): the direct run's cost_time is divided by 4; the reason
        # for this constant is not visible in this file -- confirm.
        avg_raw_time = np.round(np.mean(np.array([raw_efficiency_dict[data_id]["cost_time"] / 4 for data_id in alignment_ids])), 2)
        table.add_row([dataset, llm, "Direct", avg_pred_score[0], avg_pred_score[1], avg_pred_score[2], avg_pred_score[4], avg_raw_time, 1.0])

    avg_refine_score = np.round(np.mean(np.array(refine_scores), axis=0), 4)
    avg_cost_time = np.round(np.mean(np.array(cost_times)), 2)
    avg_query_time = np.round(np.mean(np.array(query_times)), 2)
    table.add_row([dataset, llm, method, avg_refine_score[0], avg_refine_score[1], avg_refine_score[2], avg_refine_score[4], avg_cost_time, avg_query_time])


def _load_tool_graph(dataset):
    """Load the tool vocabulary for ``dataset``.

    Returns ``(tool_nodes, tool_links)``:

    * ``tool_nodes``: the ``"id"`` of every entry in the ``"nodes"`` list of
      ``data/<dataset>/tool_desc.json``.
    * ``tool_links``: every entry in the ``"links"`` list of
      ``data/<dataset>/graph_desc.json``, rendered as ``"<source>, <target>"``.
    """
    # NOTE(review): these paths are relative to the working directory, while
    # bi_evaluate() reads from "../data" and "../prediction"; confirm which
    # working directory the script is meant to run from.
    with open(f"data/{dataset}/tool_desc.json", "r", encoding="utf-8") as f:
        tool_nodes = [node["id"] for node in json.load(f)["nodes"]]
    with open(f"data/{dataset}/graph_desc.json", "r", encoding="utf-8") as f:
        tool_links = [", ".join([link["source"], link["target"]]) for link in json.load(f)["links"]]
    return tool_nodes, tool_links


# `multiple=True` instead of `type=list`: click would apply list() to the raw
# string, turning `--datasets tmdb` into ['t', 'm', 'd', 'b'].
@click.command()
@click.option("--datasets", multiple=True, default=("huggingface", "multimedia", "dailylife", "tmdb"), help="Dataset name to evaluate; repeat the option for several datasets")
@click.option("--llm", type=str, default="CodeLlama-13b")
@click.option("--methods", multiple=True, default=("graphsearch_greedy", "graphsearch_adaptive", "graphsearch_beam_2", "graphsearch_beam_3", "lightgcn"), help="Refinement method to compare against the direct baseline; repeat the option for several methods")
def main(datasets, llm, methods):
    """Evaluate every (dataset, method) pair for one LLM and print the table."""
    global table, print_base
    for dataset in datasets:
        # Let bi_evaluate emit the "Direct" baseline row once for this dataset.
        print_base = False
        tool_nodes, tool_links = _load_tool_graph(dataset)
        for method in methods:
            bi_evaluate(dataset, llm, method, tool_nodes, tool_links)

    print(table)


if __name__ == "__main__":
    # Module-level state that main() and bi_evaluate() share via `global`:
    # the results table and the once-per-dataset baseline flag.
    print_base = False
    table = pt.PrettyTable(
        field_names=["Dataset", "LLM", "Method", "Node-F1", "Link-F1", "Node-Hall", "Link-Hall", "Time", "#Query"]
    )
    main()
