import json
import argparse
from metrics import em, parse_answer, f1_score
from utils import read_jsonl, write_jsonl
from inference_model import OpenAILLM, HuggingLocalLoraLLM, HuggingLocalLLM
from tqdm import tqdm
import multiprocessing as mp
import time

def get_answer(data):
    """Run the model on a single example and pair the prompt with its output.

    Relies on a module-level ``llm`` object being bound before the call
    (the script binds it inside its ``__main__`` section).

    Args:
        data: Mapping that provides the prompt under the ``'question'`` key.

    Returns:
        A ``(question, text)`` tuple, where ``text`` is whatever
        ``llm.run(question)`` returned.
    """
    prompt = data['question']
    return prompt, llm.run(prompt)

def format_accuracy(em_correct, f1_correct, total):
    """Build the two-line EM / F1 accuracy report.

    Args:
        em_correct: Number of predictions counted as exact matches.
        f1_correct: Number of predictions counted as F1 hits.
        total: Number of examples used as the denominator.

    Returns:
        The report text: an ``EM Acc`` line and an ``F1 Acc`` line joined by
        a newline. When ``total`` is 0 both ratios are reported as ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    em_acc = em_correct / total if total else 0.0
    f1_acc = f1_correct / total if total else 0.0
    return (f"EM Acc: {em_correct} / {total} = {em_acc}\n"
            f"F1 Acc: {f1_correct} / {total} = {f1_acc}")


if __name__ == '__main__':
    # Pipeline: load data -> batched generation -> write predictions ->
    # score against references -> append an accuracy summary.
    parser = argparse.ArgumentParser(description='evaluation')

    parser.add_argument("--data_dir", type=str, default="test.jsonl")
    parser.add_argument("--save_path", type=str, default="prediction_file_path")
    parser.add_argument("--score_path", type=str, default="score_file_path")
    parser.add_argument("--acc_path", type=str, default="acc_file_path")

    parser.add_argument("--pretrain_dir", type=str, default="Qwen2.5-7B")
    parser.add_argument("--lora_dir", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=4)
    # The flag is now forwarded to the model. The default is 32 because that
    # is the limit the model was actually constructed with, so runs that do
    # not pass the flag generate exactly as before.
    parser.add_argument("--max_tokens", type=int, default=32)

    args = parser.parse_args()
    if args.batch_size < 1:
        # range() rejects a zero step and a negative step would silently
        # skip inference, so fail early with a clear usage error.
        parser.error("--batch_size must be a positive integer")
    print('Args in experiment:')
    print(args)

    data = read_jsonl(args.data_dir)

    # An empty --lora_dir selects the plain base model; otherwise the LoRA
    # variant is built from the base directory plus the adapter directory.
    if args.lora_dir == "":
        llm = HuggingLocalLLM(args.pretrain_dir, max_tokens=args.max_tokens)
    else:
        llm = HuggingLocalLoraLLM(args.pretrain_dir, args.lora_dir,
                                  max_tokens=args.max_tokens)

    # Open in 'w' mode: scoring below reads this file back and pairs its
    # lines positionally with `data`, so lines left over from an earlier run
    # would otherwise be scored in place of this run's predictions.
    with open(args.save_path, 'w', encoding='utf-8') as f:
        for i in tqdm(range(0, len(data), args.batch_size)):
            prompts = [d['question'] for d in data[i:i + args.batch_size]]
            texts = llm.run(prompts)
            for t in texts:
                f.write(json.dumps({'predict_full': t}) + '\n')
            # Flush per batch so finished predictions survive a crash.
            f.flush()

    gts = [d['answer'] for d in data]
    predicts = [parse_answer(d['predict_full']) for d in read_jsonl(args.save_path)]
    if len(predicts) != len(gts):
        # zip() below stops at the shorter sequence; make that visible.
        print(f"Warning: {len(predicts)} predictions for {len(gts)} examples; "
              f"only the first {min(len(predicts), len(gts))} pairs are scored")

    em_correct = 0
    f1_correct = 0
    scores = []
    for p, g in zip(predicts, gts):
        em_hit = 1 if em(p, g) else 0
        # A prediction counts as an F1 hit when the first element returned
        # by f1_score is positive (presumably the F1 value -- verify in metrics).
        f1_hit = 1 if f1_score(p, g)[0] > 0 else 0
        em_correct += em_hit
        f1_correct += f1_hit
        scores.append({'score': f1_hit, 'em_score': em_hit, 'f1_score': f1_hit})
    write_jsonl(scores, args.score_path)

    acc_str = format_accuracy(em_correct, f1_correct, len(data))
    # The accuracy file is a running log across experiments, so append.
    with open(args.acc_path, 'a', encoding='utf-8') as f:
        f.write("---------------------------------\n")
        f.write(args.lora_dir + '\n' + args.save_path + '\n' + args.data_dir + '\n' + acc_str + '\n\n')

    print(args.save_path)
    print(acc_str)
