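# Select failed cases (default: up to 50, see --number) whose parsed model output
# differs from the parsed reference, and export them to an .xls sheet for review.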

import argparse, json, xlwt, os

# Root of the per-dataset/per-model results tree (relative path, so it is
# resolved against the current working directory).
RESULT_DIR: str = "../../results"
# Input: one JSON record per line inside <RESULT_DIR>/<dataset>/<model>/.
RESULT_FILE: str = "results.jsonl"
# Output workbook written next to RESULT_FILE.
XLS_FILE: str = "failure.xls"

def extract_to_dict(ans: str) -> dict[str, str]:
    """Parse an ``"entity: type; entity: type."`` answer string into a dict.

    Pairs are separated by ``;``. Within a pair, the text before the LAST
    ``:`` is the entity and the text after it is the type. Entities that
    themselves contain a colon (times, URLs, titles) are therefore kept
    instead of being silently dropped. Surrounding whitespace is stripped
    from both parts, and a single trailing ``.`` is removed from the type.
    Chunks without any ``:`` are ignored. If an entity appears more than
    once, the last type wins.

    Args:
        ans: Raw answer text (reference or model output).

    Returns:
        Mapping of entity string to type string.
    """
    res_dict = {}
    for pair in ans.split(";"):
        entity, sep, type_ent = pair.rpartition(":")
        if not sep:
            continue  # no "entity: type" separator in this chunk
        res_dict[entity.strip()] = type_ent.strip().removesuffix(".")
    return res_dict

def simplify_dict(ans: dict):
    """Drop entities that are nested inside a longer entity of the same type.

    A key is removed when it is a strict substring of some other key and both
    map to the same value, e.g. ``{"York": "LOC", "New York": "LOC"}`` keeps
    only ``"New York"``. Insertion order of the surviving keys is preserved.

    Args:
        ans: Mapping of entity string to type string.

    Returns:
        A new dict without the redundant nested entities.
    """
    redundant = {
        short
        for short, short_type in ans.items()
        for long, long_type in ans.items()
        if short != long and short in long and short_type == long_type
    }
    return {entity: etype for entity, etype in ans.items() if entity not in redundant}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export failed cases (parsed model output != parsed reference) to an .xls sheet."
    )
    parser.add_argument("--number", type=int, default=50,
                        help="maximum number of failed cases to export (default: 50)")
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    args = parser.parse_args()

    # Input results and the output workbook share one per-dataset/per-model directory.
    # NOTE(review): RESULT_DIR is relative, so it resolves against the current
    # working directory, not this script's location.
    result_dir = os.path.join(RESULT_DIR, args.dataset, args.model)
    file_name = os.path.join(result_dir, RESULT_FILE)
    output_file = os.path.join(result_dir, XLS_FILE)

    excel_wb = xlwt.Workbook(encoding='utf-8', style_compression=0)
    excel_ws = excel_wb.add_sheet("failure", True)  # True = cell_overwrite_ok
    for col, header in enumerate(("input", "reference", "model output", "note")):
        excel_ws.write(0, col, header)

    # Row 0 holds the headers; each recorded failure is written to the next row.
    record_count = 0
    with open(file_name, "r", encoding="utf-8") as result_f:
        for line in result_f:
            # Stop once the requested number of failed cases has been collected.
            if record_count >= args.number:
                break
            if not line.strip():
                continue  # tolerate blank lines in the .jsonl (json.loads would raise)
            data = json.loads(line)
            input_text = str(data["instance"]["input"]["text"])
            ref_ans = str(data["instance"]["references"][0]["output"]["text"])
            mod_ans = str(data["request"]["result"]["completions"][0]["text"]).removeprefix("Answer: ")
            ref_dict, mod_dict = extract_to_dict(ref_ans), extract_to_dict(mod_ans)
            # Optional: collapse nested same-type reference entities before comparing (disabled).
            # ref_dict = simplify_dict(ref_dict)
            if ref_dict != mod_dict:
                record_count += 1
                excel_ws.write(record_count, 0, input_text)
                excel_ws.write(record_count, 1, ref_ans)
                excel_ws.write(record_count, 2, mod_ans)

    if os.path.exists(output_file):
        os.remove(output_file)
    excel_wb.save(output_file)
    