import re
import json
import torch
from tqdm import tqdm
from datasets import load_dataset
from argparse import ArgumentParser

def read_pred_file(input_file):
    """Load a JSON-Lines prediction file into a single dict.

    Every non-blank line must hold one JSON object. The objects are merged
    in file order, so a key that appears on several lines keeps the value
    from the last one.

    Args:
        input_file: Path of the prediction file to read.

    Returns:
        A dict containing the merged key/value pairs of all lines.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    preds = {}
    with open(input_file, "r") as fin:
        # Iterate the file object directly instead of materialising every
        # line with readlines().
        for line in fin:
            # Blank lines (e.g. a trailing empty line) carry no record and
            # would otherwise make json.loads raise.
            if not line.strip():
                continue
            preds.update(json.loads(line))
    return preds

def clean_answer(options, answer):
    """Reduce a free-form answer to the first option label it mentions.

    Labels are tried in the iteration order of ``options``. Before each
    check, every literal "Option" is removed from the answer and only the
    text before the first ":" is kept. The label is then looked up with a
    plain substring test, so a label that merely occurs inside a longer
    word also counts as a hit.

    Args:
        options: Mapping from option label to option text.
        answer: The answer string to inspect.

    Returns:
        The first label found in the trimmed answer, or ``None`` when no
        label is found.
    """
    for label, _ in options.items():
        # The trimming is re-applied to its own result on every pass; the
        # trimmed text is what the next label is checked against.
        answer = answer.replace("Option", "").split(":")[0]
        if label in answer:
            return label
    return None

def match(options, gold_answer, answer):
    """Decide whether a predicted answer agrees with the gold option label.

    If ``answer`` is a list, only its first element is considered. The
    first stand-alone letter A-D (or a letter A-D directly followed by
    whitespace or a colon) is taken as the predicted label and compared
    with ``gold_answer``. When no such letter exists, the prediction counts
    as correct only if the gold option's text occurs verbatim in the answer.

    Args:
        options: Mapping from option label to option text.
        gold_answer: Label of the correct option.
        answer: Predicted answer, either a string or a list of strings.

    Returns:
        ``True`` if the prediction is judged correct, ``False`` otherwise.
    """
    text = answer[0] if type(answer) is list else answer
    found = re.search(r"\b[A-D]\b|[A-D](?=\s|:)", text)
    if found is not None:
        return found.group() == gold_answer
    # No label letter present: fall back to matching the gold option's text.
    return options[gold_answer] in text

def compare(options, answer_1, answer_2):
    """Tell whether two predicted answers resolve to the same option label.

    A list answer is represented by its first element. Both answers are
    reduced to a label with ``clean_answer``; they agree only when both
    labels were found and are equal.

    Args:
        options: Mapping from option label to option text.
        answer_1: First prediction, a string or a list of strings.
        answer_2: Second prediction, a string or a list of strings.

    Returns:
        ``True`` when both answers yield the same label, ``False`` when
        they differ or when either answer yields no label.
    """
    first, second = (
        ans[0] if type(ans) is list else ans for ans in (answer_1, answer_2)
    )
    label_1 = clean_answer(options, first)
    label_2 = clean_answer(options, second)
    return label_1 is not None and label_2 is not None and label_1 == label_2
        

def main():
    """Score multiple-choice predictions from the command line.

    With only ``--input_file``, each prediction is checked against the gold
    answer of the matching dataset entry and the accuracy is printed.

    With ``--input_file_2`` as well, the two prediction files are compared
    entry by entry. The per-entry agreement flag is appended to
    ``outputs/inference_time/post_hoc_preprocess.txt`` as one JSON object
    per line, and the conflict rate is printed.

    Dataset entries without a prediction (in either file, when comparing)
    are skipped and do not count towards the denominator.
    """
    parser = ArgumentParser()
    # Required: without it the script would only fail later in open(None).
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--input_file_2", type=str, default=None)

    args = parser.parse_args()

    with open("data/viquae/multiple_choice_data.json", "r") as fin:
        dataset = json.load(fin)

    if args.input_file_2 is None:
        preds = read_pred_file(args.input_file)
        cnt_correct = 0
        cnt = 0
        for data in dataset:
            pred = preds.get(data["id"])
            if pred is None:
                continue
            options = data["multiple_choices"]
            gold_answer = data["multiple_choices_answer"]
            if match(options, gold_answer, pred):
                cnt_correct += 1
            cnt += 1
        # Guard against dividing by zero when no prediction id is found in
        # the dataset.
        if cnt == 0:
            print("Accuracy: n/a (no predictions matched a dataset id)")
        else:
            print(f"Accuracy: {cnt_correct / cnt}")
        return

    preds_1 = read_pred_file(args.input_file)
    preds_2 = read_pred_file(args.input_file_2)
    cnt_conflict = 0
    cnt = 0
    # Open the output once for the whole run instead of once per entry.
    with open("outputs/inference_time/post_hoc_preprocess.txt", "a") as fout:
        for data in dataset:
            data_id = data["id"]
            pred_1 = preds_1.get(data_id)
            pred_2 = preds_2.get(data_id)
            if pred_1 is None or pred_2 is None:
                continue
            flag = compare(data["multiple_choices"], pred_1, pred_2)
            if not flag:
                cnt_conflict += 1
            cnt += 1
            fout.write(f"{json.dumps({data_id: flag})}\n")
    # Same zero-denominator guard as in the accuracy branch.
    if cnt == 0:
        print("Conflict Rate: n/a (no entry has a prediction in both files)")
    else:
        print(f"Conflict Rate: {cnt_conflict / cnt}")

# Run the evaluation only when the file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()