import os
import json
import argparse
from tqdm import tqdm
from openai import OpenAI


def generate_response_api(client, original, rule, ride, model_name):
    """Ask a chat model which of two rewrites of a math problem is better.

    The judging instructions and the three problem texts are sent as a single
    system message through ``client.chat.completions.create``.

    Args:
        client: Object exposing ``chat.completions.create`` (an OpenAI-style
            client).
        original: Text of the original problem.
        rule: Text of the rule-generated rewrite; shown to the judge as
            "Rewrite A".
        ride: Text of the other rewrite; shown to the judge as "Rewrite B".
        model_name: Model identifier forwarded to the API.

    Returns:
        The text of the first choice, with surrounding whitespace stripped.

    Raises:
        RuntimeError: If the response has no choices or the first choice has
            no text content.
    """
    # NOTE: the first sentence labels A as rule-generated and B as
    # model-generated so that it matches the slots filled below
    # (REWRITE_A=rule, REWRITE_B=ride). Slot positions are deliberately left
    # as they were so the meaning of an "A"/"B" verdict does not change.
    system_prompt = """
You are a rigorous math-problem rewrite judge. Compare two rewrites (A: rule-generated, B: model-generated) of the same original math problem and decide which rewrite is better for robustness testing. Do not solve any problem. Think silently and return the final preference.

Goal
Select the rewrite that best preserves the core skill being tested while introducing meaningful structural perturbations (for robustness), with a problem that is clear, consistent, and fully solvable (preferably single-solution), and with difficulty ≥ the original (no shortcut reasoning).

Required Criteria
1) Completeness & Solvability: The rewritten problem must be fully specified, mathematically consistent, and solvable, preferably with a single correct solution.
2) Core-Concept Consistency: Keep the core concept/skill of the original unchanged; numbers/labels/secondary conditions may change if the core assessment target is equivalent.
3) Clarity & No Leakage: Language must be clear, unambiguous, error-free, with no leaked conclusions, intermediate results, or hints that reduce the intended reasoning burden.
4) Structural Perturbation: Pure surface-level paraphrase without structural disturbance is not a good rewrite for robustness; prefer changes in structure/order/representation that do not alter the core skill or answer.
5) Difficulty Constraint: The new problem’s difficulty must be ≥ the original; it must not introduce shortcuts or reduce the required reasoning depth.

Forbidden
- Changing the problem type, core concept/skill, or final answer relative to the original.
- Inconsistent or conflicting data/units/constraints, or introducing multi-solution/unsolvable setups unless the original had that intent.
- Any answer leakage or revealing key intermediate results.
- Only superficial paraphrasing with no structural perturbation.
- Introducing reasoning shortcuts that lower difficulty.

Tie-breaking priorities
1) Prioritize Completeness & Solvability and Core-Concept Consistency.
2) Then meaningful structural perturbation and no leakage.
3) Then difficulty ≥ original without shortcuts.
4) If both are comparable, output a tie.

Process
- Do not solve the problem or output computations.
- Analyze A and B against the criteria and forbidden list.
- Decide the winner using the tie-breaking priorities.

Inputs
- Original problem:
{ORIGINAL_PROBLEM}

- Rewrite A:
{REWRITE_A}

- Rewrite B:
{REWRITE_B}
""".strip()

    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {
                "role": "system",
                # str.format does not re-parse substituted values, so braces
                # inside the problem texts (e.g. LaTeX) are inserted verbatim.
                "content": system_prompt.format(
                    ORIGINAL_PROBLEM=original,
                    REWRITE_A=rule,
                    REWRITE_B=ride,
                ),
            }
        ],
        stream=False,
    )

    # Fail with a descriptive message instead of an opaque IndexError /
    # AttributeError when the API returns nothing usable.
    choices = getattr(resp, "choices", None)
    if not choices:
        raise RuntimeError("API response contained no choices")
    content = choices[0].message.content
    if content is None:
        raise RuntimeError("API response contained no text content")
    return content.strip()


def load_existing_ids(output_file):
    """Collect the ``"id"`` values already stored in a JSON results file.

    Args:
        output_file: Path of the JSON file to inspect.

    Returns:
        A set holding the ``"id"`` of every dict entry that has that key.
        The set is empty when the file does not exist or does not contain
        valid JSON. Entries that are not dicts are ignored.
    """
    if not os.path.exists(output_file):
        return set()

    with open(output_file, "r", encoding="utf-8") as handle:
        try:
            records = json.load(handle)
        except json.JSONDecodeError:
            # Unparseable file: report nothing as finished.
            return set()
        return {
            entry["id"]
            for entry in records
            if isinstance(entry, dict) and "id" in entry
        }


def append_to_json(output_file, record):
    """Append *record* to the JSON array stored in *output_file*.

    The file is created when it does not exist. If the existing content is
    not valid JSON it is treated as an empty list (as before). The existing
    top-level value is expected to be a list.

    The new content is serialised first and written to a sibling ``.tmp``
    file that is then moved over *output_file* with ``os.replace``. The
    previous implementation truncated the file in place, so a serialisation
    error or an interruption mid-write left a corrupt file whose contents
    were discarded on the next call.

    Args:
        output_file: Path of the JSON file to update.
        record: JSON-serialisable object to append.

    Raises:
        TypeError: If *record* is not JSON-serialisable; *output_file* is
            left untouched in that case.
        OSError: If reading or writing the files fails.
    """
    data = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = []
    data.append(record)

    # Serialise before touching the disk so a bad record cannot damage
    # results that were already saved.
    payload = json.dumps(data, ensure_ascii=False, indent=2)

    tmp_path = os.fspath(output_file) + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, output_file)
    except BaseException:
        # Do not leave a stray temp file behind on failure.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _load_questions(path):
    """Return the "question" field ("" when absent) of each item in the JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        items = json.load(f)
    return [item.get("question", "") for item in items]


def run(args):
    """Judge every (original, rule rewrite, ride rewrite) triple and save the verdicts.

    Client configuration is read from the environment: ``OPENAI_API_KEY``,
    ``OPENAI_BASE_URL``, ``OPENAI_TIMEOUT`` (default ``"12000"``, parsed as
    an int) and ``OPENAI_MODEL`` (default ``"gpt-5-2025-08-07"``).

    The three input files are paired by position. Indices already present in
    the output file (as reported by ``load_existing_ids``) are skipped, so an
    interrupted run can be resumed. A failing API call does not abort the
    run: the record is stored with a prediction of ``"[ERROR] <message>"``.

    Args:
        args: Namespace with ``rule_path``, ``ride_path``, ``original_path``
            and ``output_file`` attributes.
    """
    import sys  # local import: used only for the length-mismatch warning

    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL")
    timeout_s = int(os.environ.get("OPENAI_TIMEOUT", "12000"))
    model_name = os.environ.get("OPENAI_MODEL", "gpt-5-2025-08-07")

    # Only pass what was explicitly configured so the client's own defaults
    # apply otherwise.
    client_kwargs = {"timeout": timeout_s}
    if api_key:
        client_kwargs["api_key"] = api_key
    if base_url:
        client_kwargs["base_url"] = base_url
    client = OpenAI(**client_kwargs)

    rule = _load_questions(args.rule_path)
    ride = _load_questions(args.ride_path)
    original = _load_questions(args.original_path)

    sizes = {"original": len(original), "rule": len(rule), "ride": len(ride)}
    n = min(sizes.values())
    if len(set(sizes.values())) > 1:
        # Items are paired by position; differing lengths usually mean the
        # files are misaligned, so make the truncation visible.
        print(
            f"[WARN] input files differ in length {sizes}; "
            f"only the first {n} items will be processed",
            file=sys.stderr,
        )

    finished_ids = load_existing_ids(args.output_file)

    # zip stops at the shortest list, i.e. after exactly n triples.
    triples = enumerate(zip(original, rule, ride))
    for i, (orig_q, rule_q, ride_q) in tqdm(triples, total=n):
        if i in finished_ids:
            continue
        try:
            prediction = generate_response_api(client, orig_q, rule_q, ride_q, model_name)
        except Exception as e:
            # Best effort: keep going and record the failure in the output.
            prediction = f"[ERROR] {e}"
        record = {
            "id": i,
            "original_question": orig_q,
            "rule_question": rule_q,
            "ride_question": ride_q,
            "prediction": prediction,
        }
        append_to_json(args.output_file, record)


if __name__ == "__main__":
    # Command-line entry point: all four paths are mandatory options.
    cli = argparse.ArgumentParser()
    for flag in ("--rule-path", "--ride-path", "--original-path", "--output-file"):
        cli.add_argument(flag, required=True)
    run(cli.parse_args())
