import json
import argparse
import os.path

from metrics import em, parse_answer_gpt, extract_answer, math_equal
from model.inference_model import LocalLLM, LocalLoraLLM, OpenAILLM
from model.inference_model import HuggingLocalLoraLLM, HuggingLocalLLM
from utils.data_split import read_jsonl, write_jsonl
from tqdm import tqdm
from collections import defaultdict as ddict
import multiprocessing as mp
import time

def get_answer(data):
    """Ask the module-level ``llm`` to answer a single record's question.

    Args:
        data: a record mapping that must contain a ``'question'`` key.

    Returns:
        A ``(question, generated_text)`` tuple, where ``generated_text`` is
        whatever ``llm.run`` returns for that question.

    Note:
        Depends on a global ``llm`` that is bound before this is called;
        in this script it is only created inside the ``__main__`` section.
    """
    prompt = data['question']
    return prompt, llm.run(prompt)


if __name__ == '__main__':
    # Evaluation driver:
    #   1. generate a prediction for every question in --data_dir, unless
    #      --save_path already exists (then those predictions are reused);
    #   2. score predictions against the reference answers;
    #   3. append an accuracy report to --acc_path and write per-example
    #      scores to --score_path.
    parser = argparse.ArgumentParser(description='evaluation')

    parser.add_argument("--data_dir", type=str, default="test.jsonl")
    parser.add_argument("--save_path", type=str, default="prediction_file_path")
    parser.add_argument("--score_path", type=str, default="score_file_path")
    parser.add_argument("--acc_path", type=str, default="acc_file_path")

    parser.add_argument("--pretrain_dir", type=str, default="Qwen2.5-7B")
    parser.add_argument("--lora_dir", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_tokens", type=int, default=256)
    # Sampling temperature forwarded to every backend below. It was read as
    # ``args.am_t`` but never declared, which crashed with AttributeError.
    # 1.0 is the neutral (no-rescaling) value; pass a lower one for
    # near-greedy decoding.
    parser.add_argument("--am_t", type=float, default=1.0)

    args = parser.parse_args()
    print('Args in experiment:')
    print(args)

    data = read_jsonl(args.data_dir)

    if not os.path.exists(args.save_path):
        # Backend choice: a LoRA adapter on top of the base model, a local
        # Qwen checkpoint, or (for any other name) an OpenAI-served model.
        if args.lora_dir != "":
            llm = HuggingLocalLoraLLM(args.pretrain_dir, args.lora_dir, max_tokens=args.max_tokens, temperature=args.am_t)
        elif 'qwen' in args.pretrain_dir.lower():
            llm = HuggingLocalLLM(args.pretrain_dir, max_tokens=args.max_tokens, temperature=args.am_t, do_sample=True)
        else:
            llm = OpenAILLM(model_name=args.pretrain_dir, max_tokens=args.max_tokens, temperature=args.am_t)

        # Generation must stay inside this branch: when the prediction file
        # already exists no model is built, so ``llm`` would be undefined.
        # Each batch is appended and the file closed right away, so an
        # interrupted run keeps the predictions produced so far.
        for i in tqdm(range(0, len(data), args.batch_size)):
            prompts = [d['question'] for d in data[i:i + args.batch_size]]
            # assumes llm.run returns one text per prompt, in order — TODO confirm
            texts = llm.run(prompts)
            with open(args.save_path, 'a+', encoding='utf-8') as f:
                for t in texts:
                    f.write(json.dumps({'predict_full': t}) + '\n')

    gts = [{"type": d["type"], "ground_truth": extract_answer(d['answer'])} for d in data]
    predicts = [extract_answer(d['predict_full']) for d in read_jsonl(args.save_path)]
    if len(predicts) != len(gts):
        # zip() below truncates to the shorter list; surface a partial or
        # stale prediction file instead of silently mis-scoring it.
        print(f"Warning: {len(predicts)} predictions for {len(gts)} examples in {args.data_dir}")

    correct = 0
    correct_by_type = ddict(int)  # number of correct predictions per type
    count_by_type = ddict(int)    # number of scored examples per type
    scores = []
    for p, g in zip(predicts, gts):
        count_by_type[g["type"]] += 1
        tmp_score = 0
        if p and math_equal(p, g["ground_truth"]):
            tmp_score = 1
            correct += 1
            correct_by_type[g["type"]] += 1
        scores.append({'score': tmp_score})

    total = len(data)
    overall = correct / total if total else 0.0  # guard empty datasets
    acc_str = f"Accuracy: {correct} / {total} = {overall}\n"
    # Walk every type that was scored (not only those with a correct answer)
    # so types with 0 correct still appear in the breakdown.
    for k in count_by_type:
        acc_str += f"{k}: {correct_by_type[k]} / {count_by_type[k]} = {correct_by_type[k] / count_by_type[k]}\n"

    with open(args.acc_path, 'a+', encoding='utf-8') as f:
        f.write("---------------------------------\n")
        f.write('\n'.join([args.lora_dir, args.save_path, args.data_dir, acc_str]) + '\n\n')

    print(args.save_path)
    print(acc_str)
    write_jsonl(scores, args.score_path)

    print("done")
