import os
import json

def judge(pred, ans):
    old_flag = True
    if not ans in pred:
        old_flag = False
    if "True" in pred:
        pred = pred.replace("True", "1")
    else:
        pred = pred.replace("False", "0")
    if ans == "False" or ans == "false":
        ans = "0"
    if ans == "True" or ans == "true":
        ans = "1"
    if ans == "None" or ans == "none":
        ans = "0"
    if ", " in ans:
        ans = ans.split(', ')
    if ans[-2:] == ".0":
        ans = ans[:-2]
    if not type(ans) == list:
        ans = [ans]
    new_flag = True
    for i in range(len(ans)):
        if not ans[i] in pred:
            new_flag = False
            break
    return (old_flag or new_flag)

logs_path = "<YOUR_LOGS_PATH>"
files = os.listdir(logs_path)

# read the files 
answer_book = "<YOUR_DATASET_PATH>"
with open(answer_book, 'r') as f:
    contents = json.load(f)
answers = {}
for i in range(len(contents)):
    answers[contents[i]['id']] = contents[i]['answer']

stats = {"total_num": 0, "correct": 0, "unfinished": 0, "incorrect": 0}

for file in files:
    if not file.split('.')[0] in answers.keys():
        continue
    with open(logs_path+file, 'r') as f:
        logs = f.read()
    split_logs = logs.split('\n----------------------------------------------------------\n')
    question = split_logs[0]
    answer = answers[file.split('.')[0]]
    if type(answer) == list:
        answer = ', '.join(answer)
    stats["total_num"] += 1
    if not "TERMINATE" in logs:
        stats["unfinished"] += 1
    else:
        if '"cell": "' in logs:
            last_code_start = logs.rfind('"cell": "')
            last_code_end = logs.rfind('"\n}')
            last_code = logs[last_code_start+9:last_code_end]
        else:
            last_code_end = logs.rfind('Solution:')
        prediction_end = logs.rfind('TERMINATE')
        prediction = logs[last_code_end:prediction_end]
        logs = logs.split('TERMINATE')[0]
        result = judge(prediction, answer)
        if result:
            stats["correct"] += 1
        else:
            stats["incorrect"] += 1

print(stats)
