import json
import argparse
from metrics import parse_answer, extract_answers
from model.inference_model import OpenAILLM, HuggingLocalLoraLLM, HuggingLocalLLM
from utils.data_split import read_jsonl, write_jsonl
from tqdm import tqdm
import multiprocessing as mp
import os
import time

def get_answer(data, model=None):
    """Run the model on a single record's question.

    Intended as a per-record worker (e.g. for ``multiprocessing.Pool.imap``),
    which is why it takes a single positional record argument.

    Args:
        data: A record containing a ``'question'`` key.
        model: Object exposing ``run(question)``. When omitted, falls back to
            the module-level ``llm`` bound by the script's ``__main__``
            section, so the one-argument call form keeps working.

    Returns:
        Tuple ``(question, text)``, where ``text`` is whatever ``model.run``
        returned for that question.

    Raises:
        NameError: if ``model`` is omitted and no module-level ``llm`` exists
            (e.g. the module was imported, or a worker was spawned rather than
            forked).
    """
    question = data['question']
    if model is None:
        model = llm
    # NOTE(review): the batched path in __main__ calls llm.run() with a list of
    # prompts; confirm run() also accepts a single string as used here.
    text = model.run(question)
    return question, text


def _build_llm(pretrain_dir, lora_dir, max_tokens):
    """Instantiate the inference backend selected by the CLI arguments.

    A non-empty ``lora_dir`` selects ``HuggingLocalLoraLLM``. Otherwise a
    ``pretrain_dir`` containing "qwen" (case-insensitive) selects
    ``HuggingLocalLLM``, and anything else is treated as an ``OpenAILLM``
    model name.
    """
    if lora_dir != "":
        return HuggingLocalLoraLLM(pretrain_dir, lora_dir, max_tokens=max_tokens)
    if 'qwen' in pretrain_dir.lower():
        return HuggingLocalLLM(pretrain_dir, max_tokens=max_tokens)
    return OpenAILLM(model_name=pretrain_dir, max_tokens=max_tokens)


def _run_inference(model, data, save_path, batch_size):
    """Generate predictions for every record and write them as JSON lines.

    Each output line is ``{"predict_full": <model output>}``, in the same order
    as ``data``. Output is written to ``save_path + '.tmp'`` and renamed to
    ``save_path`` only after every batch succeeds. An interrupted run
    therefore never leaves a partial ``save_path`` that a later run would
    reuse as if it were complete.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
        RuntimeError: if the model returns a different number of outputs than
            prompts, which would misalign predictions with ground truths.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    tmp_path = save_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        for start in tqdm(range(0, len(data), batch_size)):
            prompts = [d['question'] for d in data[start:start + batch_size]]
            texts = list(model.run(prompts))
            if len(texts) != len(prompts):
                raise RuntimeError(
                    f"model returned {len(texts)} outputs for {len(prompts)} prompts"
                )
            for t in texts:
                f.write(json.dumps({'predict_full': t}) + '\n')
    os.replace(tmp_path, save_path)


def _score(predicts, gts, match):
    """Score predictions against ground truths pairwise.

    A pair is correct only when the prediction is truthy and
    ``match(prediction, ground_truth)`` is truthy. ``match`` is not called for
    falsy predictions.

    Returns:
        ``(correct, scores)``: the number of correct pairs, plus one
        ``{'score': 1}`` or ``{'score': 0}`` dict per compared pair
        (comparison stops at the shorter input).
    """
    scores = [{'score': 1 if (p and match(p, g)) else 0} for p, g in zip(predicts, gts)]
    correct = sum(s['score'] for s in scores)
    return correct, scores


def _accuracy_line(correct, total):
    """Format the accuracy summary line; reports 0.0 rather than dividing by zero."""
    ratio = correct / total if total else 0.0
    return f"Accuracy: {correct} / {total} = {ratio}"


if __name__ == '__main__':
    # `em` (exact match) is used for scoring but was never imported. Import it
    # from the same module as the other metric helpers, so a missing scorer
    # fails here, before any model is loaded or run.
    from metrics import em

    parser = argparse.ArgumentParser(description='evaluation')

    parser.add_argument("--data_dir", type=str, default="test.jsonl")
    parser.add_argument("--save_path", type=str, default="prediction_file_path")
    parser.add_argument("--score_path", type=str, default="score_file_path")
    parser.add_argument("--acc_path", type=str, default="acc_file_path")

    parser.add_argument("--pretrain_dir", type=str, default="Qwen2.5-7B")
    parser.add_argument("--lora_dir", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_tokens", type=int, default=256)

    args = parser.parse_args()
    print('Args in experiment:')
    print(args)

    data = read_jsonl(args.data_dir)

    # An existing prediction file is reused as a cache: the model is loaded and
    # run only when save_path is absent.
    if not os.path.exists(args.save_path):
        # Bound at module level because get_answer() reads the global `llm`.
        llm = _build_llm(args.pretrain_dir, args.lora_dir, args.max_tokens)
        _run_inference(llm, data, args.save_path, args.batch_size)

        # Disabled alternative: per-question inference through get_answer.
        # pool = mp.Pool(processes=24)
        # for q, text in tqdm(pool.imap(get_answer, data)):
        #     with open(args.save_path, 'a+', encoding='utf-8') as f:
        #         json_str = {'question': q, 'predict_full': text}
        #         f.write(json.dumps(json_str) + '\n')

    gts = [parse_answer(d['answer']) for d in data]
    predicts = [extract_answers(d['predict_full']) for d in read_jsonl(args.save_path)]
    if len(predicts) != len(gts):
        print(
            f"Warning: {args.save_path} has {len(predicts)} predictions but "
            f"{args.data_dir} has {len(gts)} examples; only the first "
            f"{min(len(predicts), len(gts))} pairs are scored."
        )

    correct, scores = _score(predicts, gts, em)

    # Per-example 0/1 scores, one JSON object per line, in data order.
    with open(args.score_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(s) + '\n' for s in scores)

    acc_str = _accuracy_line(correct, len(data))
    with open(args.acc_path, 'a+', encoding='utf-8') as f:
        f.write("---------------------------------\n")
        f.write(args.lora_dir + '\n' + args.save_path + '\n' + args.data_dir + '\n' + acc_str + '\n\n')

    print(args.save_path)
    print(acc_str)
