import json
import re
import os
import time
from openai import OpenAI
import random

data_file = "datasets/gsm8k_test.jsonl"
# Number of problems to draw, and the seed used to draw them. A fixed seed makes
# the chosen subset reproducible across runs (set SAMPLE_SEED = None for a fresh
# random draw each run).
SAMPLE_SIZE = 1000
SAMPLE_SEED = 0

with open(data_file, "r", encoding="utf-8") as f:
    # One JSON object per line; blank lines (which json.loads rejects) are skipped.
    data = [json.loads(line) for line in f if line.strip()]
# Cap k at the dataset size: random.sample raises ValueError if k > len(population).
# A private Random instance leaves the global `random` state untouched.
data = random.Random(SAMPLE_SEED).sample(data, min(SAMPLE_SIZE, len(data)))

_FINAL_ANSWER_RE = re.compile(r"####\s*(.*)")


def parse_final_answer(answer_str):
    """
    Return the text following the first ``####`` marker in a GSM8K answer.

    Whitespace (including newlines) directly after the marker is skipped; the
    captured text runs to the end of that line and is then stripped.
    Returns None when the string contains no ``####`` marker.
    """
    found = _FINAL_ANSWER_RE.search(answer_str)
    return None if found is None else found.group(1).strip()

# A complete fenced block: an opening ``` plus an optional info string
# (e.g. "python", "python3", "py"), the body, and the closing ```.
_FENCED_BLOCK_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


def strip_code_block_markers(code_str):
    """
    Extract executable code from a model reply that may use Markdown fences.

    If the reply contains one or more complete fenced blocks, their bodies are
    returned joined by a blank line and any prose around them is discarded.
    The model is asked for a function *and* a call, which it may put in
    separate blocks. Otherwise, a leading fence line (with any info string)
    and a trailing ``` are removed, which covers an unterminated block.

    Args:
        code_str: Raw model reply.

    Returns:
        The code with fence markers removed and surrounding whitespace stripped.
    """
    blocks = _FENCED_BLOCK_RE.findall(code_str)
    if blocks:
        return "\n\n".join(block.strip() for block in blocks).strip()
    # No complete fenced block: strip a leading fence line / trailing fence if present.
    code_str = re.sub(r"^```[^\n]*\n", "", code_str)
    code_str = re.sub(r"```$", "", code_str)
    return code_str.strip()

def test_generated_code(code_str):
    """
    Execute model-generated code and report whether it ran without raising.

    The code runs in a single fresh namespace used as both globals and locals.
    This keeps top-level imports and helper functions visible inside the
    functions the code defines. With separate globals/locals dicts,
    ``import math`` followed by a function that uses ``math`` raises NameError.
    ``__name__`` is set to ``"__main__"`` so that an
    ``if __name__ == "__main__":`` self-check block also runs.

    NOTE(review): this executes untrusted model output in-process with no
    sandbox and no timeout; an infinite loop in the generated code will hang
    the run.

    Args:
        code_str: Python source to execute.

    Returns:
        (True, None) if execution finished without raising, otherwise
        (False, "<ExceptionType>: <message>"). The type name is included
        because e.g. a bare failing ``assert`` has an empty message.
    """
    namespace = {"__name__": "__main__"}
    try:
        exec(code_str, namespace)
    except (Exception, SystemExit) as exc:
        # SystemExit is caught so generated code calling exit()/sys.exit()
        # cannot terminate the whole generation run.
        return False, f"{type(exc).__name__}: {exc}"
    return True, None


# Not a pytest test despite the ``test_`` prefix; keep collectors from picking it up.
test_generated_code.__test__ = False

# Output location and request settings for this run.
OUTPUT_FILE = "datasets/gsm8k_test_with_code_1000.jsonl"
MODEL = "o3"
SYSTEM_MESSAGE = "You are a helpful assistant that writes Python code for math problems."
# The SDK retries transient failures (connection errors, 408/409/429, 5xx) with
# exponential backoff. Non-retryable ones (bad request, auth) surface
# immediately instead of being retried forever.
API_MAX_RETRIES = 8
# Attempts for appending one record to OUTPUT_FILE before giving up loudly.
WRITE_ATTEMPTS = 5

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), max_retries=API_MAX_RETRIES)

# Every record is kept in memory and also appended to OUTPUT_FILE as soon as it
# is produced, so an interrupted run keeps everything written so far.
results = []


def _build_prompt(question, answer):
    """Return the user prompt asking for a generalized solver plus a self-checking call."""
    return f"""
Given the following math word problem and its full worked solution, write a Python function that solves the problem for all possible initial values (i.e., make the numbers in the problem parameters). The function should take the relevant parameters as arguments and return the answer. Do not use hardcoded values from the example; generalize the solution. Please ensure that the function throws an exception when it receives improper parameter values.

Problem:
{question}

Full Solution:
{answer}

Write only the Python function code. Then write a function call with the initial values from the problem description and assert the answer is equal to the ground truth answer.
"""


def _request_code(prompt):
    """
    Ask the model for a solution to *prompt*.

    Returns:
        (reply_text, None) on success, where reply_text is the stripped
        message content; or (None, error_message) if the request raised
        (after the SDK's own retries) or the model returned no content.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "developer", "content": SYSTEM_MESSAGE},
                {"role": "user", "content": prompt},
            ],
        )
    except Exception as e:  # top-level boundary: record the failure, continue with next entry
        print(f"Error: {e}")
        return None, str(e)
    content = completion.choices[0].message.content
    # content may be None (e.g. a refusal). An empty string would also exec
    # cleanly and be recorded as a spurious "success", so treat both as failures.
    if not content or not content.strip():
        print("Error: empty response from model")
        return None, "empty response from model"
    return content.strip(), None


def _append_record(record, path=OUTPUT_FILE):
    """Append *record* as one JSON line to *path*, retrying a few times on I/O errors."""
    # Serialize before opening so the line is written with a single write call.
    line = json.dumps(record, ensure_ascii=False) + "\n"
    for attempt in range(1, WRITE_ATTEMPTS + 1):
        try:
            with open(path, "a", encoding="utf-8") as f:
                f.write(line)
            return
        except OSError as e:
            print(f"Error: {e}")
            if attempt == WRITE_ATTEMPTS:
                raise  # fail loudly rather than silently dropping results
            time.sleep(10)


total = len(data)
for i, entry in enumerate(data):
    question = entry.get("question", "")
    answer = entry.get("answer", "")
    final_answer = parse_final_answer(answer)

    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print(f"Final answer: {final_answer}")

    prompt = _build_prompt(question, answer)
    print(f"Prompt: {prompt}")

    code, gen_error = _request_code(prompt)
    if gen_error is None:
        print(f"Code: {code}")
        # Remove code block markers if present, then run the model's own assertion.
        code_clean = strip_code_block_markers(code)
        ok, err = test_generated_code(code_clean)
        code_test_result = "success" if ok else f"error: {err}"
    else:
        # Nothing to execute: record the failure instead of exec-ing the error text.
        code = f"Error: {gen_error}"
        code_clean = ""
        code_test_result = f"error: generation failed: {gen_error}"

    record = {
        "question": question,
        "answer": answer,
        "parsed_final_answer": final_answer,
        "generated_code": code,
        "code_clean": code_clean,
        "code_test_result": code_test_result,
    }
    results.append(record)

    print(f"Processed {i + 1}/{total} entries. Code test result: {code_test_result}")
    time.sleep(1)  # To avoid rate limits

    _append_record(record)

print(f"Saved results to {OUTPUT_FILE}")
