import re
import json
from argparse import ArgumentParser

from src.utils.load_dataset import get_datas

def read_pred_file(input_file):
    preds = {}
    with open(input_file, "r") as fin:
        for line in fin.readlines():
            preds.update(json.loads(line))
    return preds

def clean_answer(answer):
    pattern = r"\b[A-J]\b|[A-J](?=\s|:|.)"
    match = re.search(pattern, answer)
    if match is None:
        return None
    return match.group()

def match_answer(choices, gold_answer, prediction):
    prediction = prediction.split("\n")[-1]
    try:
        prediction = json.loads(prediction.strip())["answer"]
    except:
        pass
    gold_choice = choices[gold_answer - 1]
    if gold_choice in prediction or prediction in gold_choice:
        return True
    print(gold_choice)
    print(prediction)
    input()
    return False
    

def evaluate(choices, gold_answer, predict_answer):
    if type(gold_answer) is int:
        gold_answer = chr(ord("A") + gold_answer)
    cleaned_predict_answer = clean_answer(predict_answer)
    if cleaned_predict_answer is None:
        # print(predict_answer)
        # return match_answer(choices, gold_answer, predict_answer)
        return False
    if cleaned_predict_answer == gold_answer:
        return True
    return False

def compare(answer_1, answer_2):
    cleaned_answer_1 = clean_answer(answer_1)
    cleaned_answer_2 = clean_answer(answer_2)
    if cleaned_answer_1 is None or cleaned_answer_2 is None:
        return False
    if cleaned_answer_1 == cleaned_answer_2:
        return True
    return False

def main():
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, default="humaneval")
    parser.add_argument("--file", type=str)
    
    args = parser.parse_args()
    
    # load datas and preds
    datas, _ = get_datas(args.dataset)
    preds = read_pred_file(args.file)
    
    cnt_valid = 0
    cnt_correct = 0
    for data in datas:
        _ = data["id"]
        choices = data["choices"]
        pred = preds.get(f"{_}")
        if pred is None:
            continue
        answer = data["answer"]
        if evaluate(choices, answer, pred):
            cnt_correct += 1
        cnt_valid += 1
    print(f"All sample number: {cnt_valid}")
    print(f"Correct sample number: {cnt_correct}")
    print(f"Accuracy: {cnt_correct / cnt_valid}")
        
if __name__ == "__main__":
    main()