import json
import os
from tqdm import tqdm
from openai import OpenAI
import argparse

def generate_response_api(client, question, system_prompt, model_name):
    """Send a single system + user chat turn and return the reply text.

    Args:
        client: An ``OpenAI`` client (anything exposing
            ``chat.completions.create``).
        question: Content of the user message.
        system_prompt: Content of the system message.
        model_name: Model identifier forwarded to the API.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped.

    Raises:
        ValueError: If the response has no choices, or the first choice's
            message content is ``None`` (the SDK types this field as
            optional).
        Any exception raised by ``client.chat.completions.create``
        propagates unchanged.
    """
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ],
        stream=False,
    )
    # Some backends return an empty/missing choices list; fail with a clear
    # message instead of an IndexError/TypeError.
    if not response.choices:
        raise ValueError("API response contained no choices")
    choice = response.choices[0]
    content = choice.message.content
    if content is None:
        # Previously surfaced as "'NoneType' object has no attribute 'strip'".
        finish_reason = getattr(choice, "finish_reason", None)
        raise ValueError(
            f"API returned no text content (finish_reason={finish_reason!r})"
        )
    return content.strip()

def load_existing_ids(output_file):
    """Return the set of non-``None`` record ids already saved in *output_file*.

    ``inference_api`` uses this to resume a run: items whose id is in the
    returned set are skipped.

    Args:
        output_file: Path to the JSON output file (a list of record dicts).

    Returns:
        A set of ids. It is empty if the file is missing, unparseable, or not
        a JSON list. Records whose ``"id"`` is missing or ``None`` are
        ignored: every record is written with an ``"id"`` key, possibly
        ``None``, and a ``None`` entry would make every id-less input item
        look "already finished".
    """
    if not os.path.exists(output_file):
        return set()
    with open(output_file, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            return set()
    # A top-level scalar (e.g. ``null``) is not iterable; treat as empty.
    if not isinstance(data, list):
        return set()
    return {
        item["id"]
        for item in data
        if isinstance(item, dict) and item.get("id") is not None
    }

def append_to_json(output_file, record):
    """Append *record* to the JSON list stored in *output_file*.

    Creates the file if it does not exist. An empty or unparseable existing
    file is treated as an empty list. The whole list is rewritten on every
    call.

    The new content is first written to ``<output_file>.tmp`` and then moved
    over the target with ``os.replace``. An interrupted or failed write
    therefore never leaves *output_file* truncated or half-written. The old
    in-place truncate-then-dump approach could do that, and the next call
    would then discard all earlier records as "undecodable".

    Args:
        output_file: Path to the JSON output file.
        record: A JSON-serializable object to append.
    """
    data = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = []
    data.append(record)

    tmp_path = f"{output_file}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # Same directory as the target, so the rename is a single atomic swap.
        os.replace(tmp_path, output_file)
    except BaseException:
        # Leave the existing output untouched; drop the partial temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def inference_api(args):
    """Run each dataset question through the chat API and save predictions.

    Configuration comes from environment variables:

    - ``OPENAI_API_KEY``
    - ``OPENAI_BASE_URL``
    - ``OPENAI_TIMEOUT``: seconds, default ``120``; may be fractional.
    - ``OPENAI_MODEL``: default ``"gpt-4o"``.

    The run is resumable. Items whose ``id`` is already present in the output
    file, or was written earlier in this run, are skipped. Items without an
    ``id`` cannot be matched against earlier output and are always processed.
    An exception from the API call is recorded as a prediction string
    starting with ``"[ERROR]"`` instead of aborting the run.

    Args:
        args: Namespace with ``data_path`` and ``output_file`` attributes.
            ``data_path`` presumably holds a JSON list of dicts with ``id``,
            ``question`` and ``answer`` keys (only ``.get`` is used on each
            item). TODO confirm against the dataset.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL")
    # float() accepts every value int() did, plus fractional seconds.
    timeout = float(os.environ.get("OPENAI_TIMEOUT", "120"))
    model_name = os.environ.get("OPENAI_MODEL", "gpt-4o")

    client = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout)

    with open(args.data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    finished_ids = load_existing_ids(args.output_file)

    system_prompt = "You are a mathematical reasoning assistant. Now solve the following question. The final answer should be marked with \\boxed{}"

    for item in tqdm(data):
        item_id = item.get("id")
        # Saved records always carry an "id" key (possibly None). Never treat
        # None as "done", or every id-less item would be skipped on resume.
        if item_id is not None and item_id in finished_ids:
            continue
        question = item.get("question", "")
        try:
            prediction = generate_response_api(client, question, system_prompt, model_name)
        except Exception as e:
            # Per-item boundary: record the failure and continue with the rest.
            prediction = f"[ERROR] {e}"
        record = {
            "id": item_id,
            "query": question,
            "answer": item.get("answer", ""),
            "prediction": prediction,
        }
        append_to_json(args.output_file, record)
        # Keep duplicate ids consistent with resume behaviour within one run.
        if item_id is not None:
            finished_ids.add(item_id)

if __name__ == "__main__":
    # Command-line entry point: parse the two file paths and run inference.
    parser = argparse.ArgumentParser(
        description="Generate chat-completion predictions for a JSON dataset "
        "(model and credentials are read from OPENAI_* environment variables)."
    )
    parser.add_argument(
        "--data-path",
        required=True,
        help="JSON file containing the items to answer.",
    )
    parser.add_argument(
        "--output-file",
        required=True,
        help="JSON file that predictions are appended to; existing ids are skipped.",
    )
    args = parser.parse_args()
    inference_api(args)
