import json
import torch
from tqdm import tqdm
from datasets import load_dataset
from argparse import ArgumentParser

from src.evaluate.overlap import evaluate_overlap
from src.evaluate.model_alignment import evaluate_llama3

def read_pred_file(input_file):
    """Load predictions from a JSON Lines file into a single dictionary.

    Every non-blank line must hold one JSON object; the objects are merged
    in file order with ``dict.update``, so when the same key appears on
    several lines the last occurrence wins.

    Args:
        input_file: Path of the prediction file to read.

    Returns:
        A dict combining the key/value pairs of every line in the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    preds = {}
    # JSON text is UTF-8 by convention; do not depend on the platform default.
    with open(input_file, "r", encoding="utf-8") as fin:
        # Iterate the handle lazily instead of materialising every line.
        for line in fin:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no prediction
                # and would otherwise make json.loads raise.
                continue
            preds.update(json.loads(line))
    return preds

def _load_viquae_examples(splits=("train", "validation", "test")):
    """Return the rows of the requested ViQuAE splits as one flat list."""
    # Load the dataset once and index every split from the same object,
    # instead of calling load_dataset again for each split.
    dataset_dict = load_dataset("PaulLerner/viquae_dataset")
    examples = []
    for split in splits:
        examples.extend(dataset_dict[split])
    return examples


def _first_prediction(pred):
    """Return the first element when ``pred`` is a list, else ``pred`` itself."""
    return pred[0] if isinstance(pred, list) else pred


def _score(method, mode, questions, first, second):
    """Dispatch to the overlap metric or the Llama-based evaluator."""
    if method == "overlap":
        return evaluate_overlap(first, second)
    return evaluate_llama3(mode, questions, first, second)


def main():
    """Command-line entry point for scoring prediction files.

    Two modes are selected by the presence of ``--input_file_2``:

    * Without it ("accuracy" mode), predictions from ``--input_file`` are
      compared with each example's ``output.original_answer`` and the result
      is printed as ``Accuracy: ...``.
    * With it ("conflict" mode), predictions from the two files are compared
      with each other and the result is printed as ``Conflict: ...``.

    Only examples whose ``id`` has a prediction in every supplied file are
    evaluated. When a prediction is a list, its first element is used.

    ``--method overlap`` selects ``evaluate_overlap``; any other value,
    including omitting the flag, selects ``evaluate_llama3``.
    """
    parser = ArgumentParser()
    # Required: without it the script would only fail later, inside open().
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--input_file_2", type=str, default=None)
    parser.add_argument("--method", choices=["overlap", "llama"])

    args = parser.parse_args()

    dataset = _load_viquae_examples()
    eval_questions = []

    if args.input_file_2 is None:
        mode = "accuracy"
        preds = read_pred_file(args.input_file)

        eval_golds = []
        eval_preds = []
        for data in dataset:
            pred = preds.get(data["id"])
            if pred is None:
                # No prediction for this example: leave it out of the score.
                continue
            eval_questions.append(data["original_question"])
            eval_golds.append(data["output"]["original_answer"])
            eval_preds.append(_first_prediction(pred))

        accuracy = _score(args.method, mode, eval_questions, eval_golds, eval_preds)
        print(f"Accuracy: {accuracy}")
    else:
        mode = "conflict"
        preds_1 = read_pred_file(args.input_file)
        preds_2 = read_pred_file(args.input_file_2)

        eval_preds_1 = []
        eval_preds_2 = []
        for data in dataset:
            data_id = data["id"]
            pred_1 = preds_1.get(data_id)
            pred_2 = preds_2.get(data_id)
            if pred_1 is None or pred_2 is None:
                # Only compare examples answered in both prediction files.
                continue
            eval_questions.append(data["original_question"])
            eval_preds_1.append(_first_prediction(pred_1))
            eval_preds_2.append(_first_prediction(pred_2))

        conflict = _score(args.method, mode, eval_questions, eval_preds_1, eval_preds_2)
        print(f"Conflict: {conflict}")
    
    
    

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()