import json
from tqdm import tqdm
import fire
from utils.grader import check_is_correct
from utils.parser import extract_answer
from openai import OpenAI

def evaluate_prediction(model_pred: str, gold_answer: str) -> bool:
    """Grade one raw model completion against the reference answer.

    The final answer is pulled out of the free-form completion with
    ``extract_answer``; that extracted answer and ``gold_answer`` are then
    handed to ``check_is_correct``, whose verdict is returned as-is.
    """
    return check_is_correct(extract_answer(model_pred), gold_answer)

def request_openai_api(prompt, client, model_name="DeepSeek-R1-Distill-Llama-8B_0.25", max_tokens=1024):
    """Send one chat-completion request and return the assistant's reply text.

    Args:
        prompt: User message content (the problem statement).
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model_name: Model identifier sent with the request.
        max_tokens: Currently NOT forwarded to the request (kept for signature
            compatibility), so whatever default limit the server applies is used.

    Returns:
        The text of the first choice, or ``""`` when the reply carries no
        content (the OpenAI SDK types ``message.content`` as optional).
    """
    system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # NOTE(review): max_tokens is intentionally not passed to create(): its
    # default of 1024 would truncate step-by-step reasoning and change the
    # pass@1 numbers produced by existing callers. Forward it only together
    # with a default suited to reasoning models.
    completion = client.chat.completions.create(model=model_name, messages=messages)
    content = completion.choices[0].message.content
    # Normalise a missing reply to "" so answer extraction always receives a str.
    return content if content is not None else ""

def compute_pass_at_1(data_path: str, output_path: str, client, model_name: str = "DeepSeek-R1-Distill-Llama-8B_0.25") -> float:
    """Query the model once per problem in a JSONL file and report pass@1.

    Each non-blank line of ``data_path`` must be a JSON object with
    ``"problem"`` and ``"answer"`` keys. Every problem is sent once through
    ``request_openai_api`` and graded with ``evaluate_prediction``. One JSON
    record per problem (problem, gold answer, raw prediction, correctness) is
    appended to ``output_path``.

    Returns:
        Fraction of problems graded correct (0.0 for an empty dataset).
    """
    with open(data_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
        dataset = [json.loads(line) for line in f if line.strip()]

    correct = 0
    total = len(dataset)

    # Append mode: records from earlier runs in the same output file are kept.
    with tqdm(total=total, desc="Evaluating", ncols=80) as pbar, open(output_path, "a", encoding="utf-8") as fout:
        for i, item in enumerate(dataset, 1):
            gold_answer = item["answer"]
            problem = item["problem"]
            model_pred = request_openai_api(problem, client, model_name=model_name)
            # tqdm.write prints above the bar instead of corrupting it like print().
            tqdm.write(f"model_pred {model_pred}")

            is_correct = evaluate_prediction(model_pred, gold_answer)

            result = {
                "problem": problem,
                "gold_answer": gold_answer,
                "model_pred": model_pred,
                "is_correct": is_correct,
            }
            # enumerate starts at 1, so i == 1 is the first record: show it as a sanity check.
            if i == 1:
                tqdm.write(str(result))
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")

            if is_correct:
                correct += 1
            # Running accuracy over the i problems processed so far.
            pbar.set_description(f"pass@1: {correct / i:.2%}")
            pbar.update(1)

    final_pass1 = correct / total if total > 0 else 0.0
    print(f"\nFinal pass@1: {final_pass1:.2%}")
    return final_pass1

def main(data_path: str, output_path: str, api_base: str = "http://localhost:8000/v1", model_name: str = "DeepSeek-R1-Distill-Llama-8B_0.25"):
    """CLI entry point: evaluate ``model_name`` served at ``api_base`` on ``data_path``.

    Builds an OpenAI-compatible client for ``api_base`` and passes it to
    ``compute_pass_at_1``, which appends per-problem records to ``output_path``.
    """
    client_options = {
        # Placeholder key; presumably the self-hosted endpoint does not check it — confirm.
        "api_key": 'Dummy',
        "base_url": api_base,
        # NOTE(review): the OpenAI client takes this in seconds (~33 hours) — confirm intended.
        "timeout": 120000.0,
    }
    compute_pass_at_1(data_path, output_path, client=OpenAI(**client_options), model_name=model_name)

if __name__ == "__main__":
    # Expose main()'s parameters as command-line arguments via python-fire.
    fire.Fire(component=main)
