# eval.py

from collections import defaultdict
import os, json

# Root directory scanned for <dataset>/<model>/ result files. Being relative, it
# resolves against the current working directory, not this file's location.
RESULT_DIR: str = "../../results"
# Name of the markdown summary table that is written inside RESULT_DIR.
OUTPUT_FILE: str = "result.md"

class Counter:
    """Running tally of accepted outcomes out of all recorded outcomes.

    Note: this is unrelated to ``collections.Counter``; it only tracks two ints.
    """

    def __init__(self) -> None:
        # Number of outcomes recorded as accepted, and number recorded overall.
        self.accepted = 0
        self.total = 0

    def step(self, new: bool):
        """Record one outcome; it counts as accepted when *new* is truthy."""
        self.total += 1
        self.accepted += 1 if new else 0

    def accuracy(self):
        """Return accepted / total, or 0.0 when nothing has been recorded yet."""
        return self.accepted / self.total if self.total else 0.0
        
def extract_to_dict(ans: str):
    """Parse an ``"entity: type; entity: type."`` answer string into a dict.

    The string is split on ``";"``, and each chunk on ``":"``. Chunks that do
    not split into exactly two pieces are ignored. Both pieces are stripped of
    surrounding whitespace, and a single trailing ``"."`` is removed from the
    type. If an entity appears more than once, its last type wins.
    """
    chunks = (piece.split(":") for piece in ans.split(";"))
    return {
        fields[0].strip(): fields[1].strip().removesuffix(".")
        for fields in chunks
        if len(fields) == 2
    }

def get_data(result_dir=None):
    """Score every result file found under the results directory.

    Walks ``result_dir`` (``RESULT_DIR`` when None). Each leaf directory laid
    out as ``<result_dir>/<dataset>/<model>/...`` is processed by reading the
    first file listed there as JSON Lines; blank lines are skipped. For every
    record, the reference text (``instance.references[0].output.text``) and the
    model completion (``request.result.completions[0].text`` with a leading
    ``"Answer: "`` removed) are parsed with ``extract_to_dict``. A record counts
    as correct only when the two parsed dicts are equal.

    Args:
        result_dir: Root directory to scan. ``None`` means ``RESULT_DIR``.

    Returns:
        ``(result_output, model_list, dataset_list)``. Here
        ``result_output[dataset][model]`` is the exact-match accuracy rounded
        to 3 decimals, and both lists are sorted so the table is deterministic.
    """
    if result_dir is None:
        result_dir = RESULT_DIR
    result_output = defaultdict(dict)
    dataset_set, model_set = set(), set()

    for dir_path, sub_dirs, files in os.walk(result_dir):
        # Only leaf directories hold result files; skip empty ones instead of
        # crashing on files[0].
        if sub_dirs or not files:
            continue
        # Take dataset/model from the path relative to the root. Indexing the
        # raw path breaks when result_dir itself contains separators
        # (e.g. "../../results" would yield ".." and "results").
        parts = os.path.relpath(dir_path, result_dir).split(os.sep)
        if len(parts) < 2:
            continue
        dataset_name, model_name = parts[0], parts[1]

        file_name = os.path.join(dir_path, files[0])
        with open(file_name, "r", encoding="utf-8") as fh:
            # Blank/trailing lines would make json.loads fail.
            lines = [line for line in fh if line.strip()]
        if not lines:
            continue

        req_count = Counter()
        for line in lines:
            data = json.loads(line)
            ref_ans = str(data["instance"]["references"][0]["output"]["text"])
            mod_ans = str(
                data["request"]["result"]["completions"][0]["text"]
            ).removeprefix("Answer: ")
            req_count.step(extract_to_dict(ref_ans) == extract_to_dict(mod_ans))

        dataset_set.add(dataset_name)
        model_set.add(model_name)
        result_output[dataset_name][model_name] = round(req_count.accuracy(), 3)

    return result_output, sorted(model_set), sorted(dataset_set)

if __name__ == "__main__":
    result_output, model_list, dataset_list = get_data()

    output_text = "|Dataset|" + "|".join(model_list) + "|\n"
    output_text += "|-" + "|-" * len(model_list) + "|\n"
    for dataset in dataset_list:
        output_text += "|" + dataset + "|"
        for model in model_list:
            if model in result_output[dataset]:
                output_text += str(result_output[dataset][model])
            output_text += "|"
        output_text += "\n"
        
    print(output_text, file=open(RESULT_DIR + "/" + OUTPUT_FILE, "w"))
