#!/usr/bin/env python3
"""
Batch version of your simplification runner.

What it does
------------
1) Reads input problems (JSONL/JSON).
2) Builds a batch-input JSONL (one API request per line) targeting either:
     - /v1/responses   (preferred)  OR
     - /v1/chat/completions (fallback)
3) Uploads the batch file and creates a Batch job (24h window).
4) Polls until completion, downloads the results JSONL.
5) Writes your output JSONL with records:
   {"id", "original", "model", "response_raw", "timestamp"}

With --extract-json-blocks it also writes each ```json code block to
<output dir>/json_blocks/<id>_v<N>.json (blocks that fail to parse are saved as
<id>_v<N>.raw.txt), just like your original script.

Usage
-----
python batch_runner.py \
  --input data/in.jsonl \
  --output outputs/out.jsonl \
  --model gpt-5-mini \
  --endpoint responses \
  --max-output-tokens 2000 \
  --temperature 0.2 \
  --extract-json-blocks

Notes
-----
- Batch is ~50% cheaper than synchronous calls and can take up to 24h. Results
  come back in an output file whose line order is not guaranteed; each result is
  matched back to its input via custom_id ("prob-<input index>").
- endpoint choices: "responses" (=> /v1/responses) or "chat" (=> /v1/chat/completions)
- Requires: openai>=1.40
"""

import argparse
import json
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from openai import OpenAI  # pip install openai>=1.40

# System prompt for EntailmentBank-style records, whose user payload carries the
# keys "instruction", "input" and "output". The text is sent verbatim as the
# system message. It asks the model to return a ladder of progressively
# simplified versions, each in a fenced ```json block with a "min_steps" count.
# Editing this text changes model behavior, so treat it as code.
ENTAILMENT_BANK_META_PROMPT = """
You are an expert at reasoning question simplification.
I will provide you with a reasoning problem in JSON format that contains:

- "instruction": the solving instruction
- "input": the context and question
- "output": the reasoning chain and final answer

Your task is to automatically generate a progressive difficulty ladder of simplified versions of this problem.
Each new version should make the reasoning easier by moving more intermediate conclusions (from the reasoning steps in the output) directly into the input context.
Stop when the problem has become trivial (e.g., the final hypothesis is already in the input).
For each version, also output the minimum number of reasoning steps required to reach the final answer from that version’s input.
Treat a reasoning step as a necessary inferential move that derives a new statement from previous facts/conclusions (e.g., one arithmetic operation, one logical implication, one factual lookup from the provided context).
Count merged paraphrases/restatements as 0 additional steps; do not double-count trivially equivalent rewrites.
When multiple independent sub-derivations are needed before a final combination, count each indispensable sub-derivation as one step.
The count must be a non-negative integer; use 0 for a trivial version where the answer is directly stated in the input.
Ensure monotonic non-increase across versions (later versions should never require more steps than earlier ones).

Guidelines:

1. Identify all intermediate conclusions (int1, int2, …) in the original reasoning chain.
2. Create Version 1 as the original (no added intermediates).
3. Then generate subsequent versions, each time inserting one or more intermediates into the input.
4. You may decide the number of versions automatically — fewer if the chain is short, more if it is long.
5. For each version, output in a fenced JSON code block with the following keys:
   - "instruction"
   - "input"
   - "answer" (string, the final answer to the problem)
   - "reasoning" (string, the reasoning chain leading to the answer)
   - "min_steps" (integer, the minimum number of steps to reach the answer)
   - "min_steps_note" (a short explanation explaining the count)
6. Precede each block with a Markdown label like:
   ## Version N — [difficulty descriptor]
   Then immediately follow with:
   ```json
   { ... }
   ```

Goal: produce a set of progressively easier problems, where the solver needs fewer reasoning steps at each level, and report the minimum required steps for each version.
"""

# System prompt for GSM8K-style math records, whose user payload carries the keys
# "question" and "answer". The text is sent verbatim as the system message. It
# asks the model for a ladder of simplified versions, each in a fenced ```json
# block with "question", "answer", "reasoning", "min_steps" and
# "min_steps_note". Editing this text changes model behavior, so treat it as code.
GSM_META_PROMPT = """
You are an expert at math word problem simplification.
I will provide you with a math problem in JSON format that contains:

- "question": the text of the problem
- "answer": the worked-out reasoning and final numeric answer

Your task is to automatically generate a *progressive difficulty ladder* of simplified versions of this problem.
Each new version should make the reasoning easier by moving more intermediate results (from the solution steps in the answer) directly into the problem statement.
Stop when the problem has become trivial (e.g., the final numeric answer is already stated in the problem).

For each version, also output the **minimum number of reasoning steps** required to reach the final answer from that version’s problem statement.
- Treat a *reasoning step* as a necessary mathematical operation or logical inference (e.g., one arithmetic operation, one fraction simplification, one comparison).
- Do not double-count trivial rewrites or restatements.
- When multiple sub-calculations are required before combining, count each indispensable sub-calculation as one step.
- The count must be a non-negative integer; use **0** when the answer is already stated in the problem.
- Ensure the counts are **monotonic non-increasing** across versions (later versions should never require more steps than earlier ones).

Guidelines:

1. Identify all intermediate results (e.g., partial sums, multiplications, divisions) in the original worked-out solution.
2. Create **Version 1** as the original (no added intermediates).
3. Then generate subsequent versions, each time inserting one or more intermediate results directly into the problem statement.
4. You may decide the number of versions automatically — fewer if the chain is short, more if it is long.
5. For each version, output in a fenced JSON code block with the following keys:
   - "question" (string, the modified problem statement)
   - "answer" (string, the final numeric answer only)
   - "reasoning" (string, the reasoning steps leading to the answer)
   - "min_steps" (integer, the minimum number of steps required)
   - "min_steps_note" (short explanation for the count)
6. Precede each block with a Markdown label like:
   ## Version N — [difficulty descriptor]
   Then immediately follow with:
   ```json
   { ... }
   ```

Goal: produce a set of progressively easier GSM8K problems, where the solver needs fewer reasoning steps at each level, and report the minimum required steps for each version.
"""

# Pattern for fenced ```json blocks whose body is a single JSON object.
# Group 1 captures the object text. Notes on the pattern:
# - DOTALL lets "." cross newlines.
# - The lazy ".*?" ends at the first "}" that is followed (after optional
#   whitespace) by the closing fence, so nested braces inside the object stay
#   in the capture.
# - The fence tag is matched case-sensitively ("json" only).
JSON_BLOCK_RE = re.compile(
    r"```json\s*(\{.*?\})\s*```",
    flags=re.DOTALL,
)


def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read input records from a JSON Lines or a JSON file.

    * ``.jsonl``: one JSON value per line; blank or whitespace-only lines are skipped.
    * ``.json``: the file must hold a top-level list, which is returned as-is.

    The suffix comparison is case-insensitive and files are decoded as UTF-8.

    Raises:
        ValueError: for any other suffix, or when a ``.json`` file's top-level
            value is not a list. JSON decode errors propagate unchanged.
    """
    suffix = path.suffix.lower()

    if suffix == ".json":
        with path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if not isinstance(payload, list):
            raise ValueError(f"Expected a list in {path}, got {type(payload)}")
        return payload

    if suffix != ".jsonl":
        raise ValueError(f"Unsupported file extension: {path.suffix}")

    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(text) for text in (raw.strip() for raw in handle) if text]


def append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
    """Append ``obj`` to ``path`` as a single UTF-8 JSON line.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as sink:
        serialized = json.dumps(obj, ensure_ascii=False)
        sink.write(serialized + "\n")


def load_completed_ids(path: Path) -> set:
    """Collect the integer ``id`` values already recorded in an output JSONL.

    A missing file yields an empty set. A line is skipped when it:

    * is not valid JSON,
    * has no ``id`` key, or
    * has an ``id`` that ``int()`` cannot convert.

    This is deliberately best-effort, so that one malformed line (for example a
    truncated write) does not abort a resume.
    """
    if not path.exists():
        return set()

    seen = set()
    with path.open("r", encoding="utf-8") as records:
        for raw in records:
            try:
                record = json.loads(raw)
                if "id" in record:
                    seen.add(int(record["id"]))
            except Exception:
                continue
    return seen


def model_supports_temperature(model_name: str) -> bool:
    """Return whether requests for ``model_name`` should include ``temperature``.

    OpenAI reasoning models reject non-default sampling parameters. Callers
    therefore drop ``temperature`` for them instead of letting every batch
    request fail. This is a case-insensitive substring heuristic:

    * names containing "o4", "o3", "o1", "omni" or "reasoning" are unsupported;
    * GPT-5 family names (e.g. "gpt-5", "gpt-5-mini", "gpt-5-nano") are also
      unsupported. The script's default model is one of these, so temperature
      used to be sent and rejected. The exception is the non-reasoning
      "gpt-5-chat" variants.

    Returns True for everything else.
    """
    name = model_name.lower()
    blocked_substrings = ("o4", "o3", "o1", "omni", "reasoning")
    if any(s in name for s in blocked_substrings):
        return False
    if "gpt-5" in name and "chat" not in name:
        return False
    return True


def build_batch_line_responses(
    *, custom_id: str, model: str, meta_prompt: str, user_payload: Dict[str, Any],
    max_output_tokens: Optional[int], temperature: Optional[float], reasoning_effort: Optional[str]
) -> Dict[str, Any]:
    """Build one Batch API request line addressed to ``POST /v1/responses``.

    The system message is ``meta_prompt``. The user message is ``user_payload``
    serialized to JSON with non-ASCII characters kept verbatim. Optional fields
    are added only when set:

    * ``max_output_tokens`` and ``temperature`` when not None;
    * ``reasoning.effort`` when ``reasoning_effort`` is truthy and not "none".
    """
    conversation = [
        {"role": "system", "content": meta_prompt},
        {"role": "user", "content": json.dumps(user_payload, ensure_ascii=False)},
    ]
    optional_fields = (
        ("max_output_tokens", max_output_tokens),
        ("temperature", temperature),
    )

    request_body: Dict[str, Any] = {"model": model, "input": conversation}
    request_body.update((key, value) for key, value in optional_fields if value is not None)

    wants_reasoning = bool(reasoning_effort) and reasoning_effort != "none"
    if wants_reasoning:
        request_body["reasoning"] = {"effort": reasoning_effort}

    return dict(custom_id=custom_id, method="POST", url="/v1/responses", body=request_body)


def build_batch_line_chat(
    *, custom_id: str, model: str, meta_prompt: str, user_payload: Dict[str, Any],
    max_output_tokens: Optional[int], temperature: Optional[float],
    reasoning_effort: Optional[str] = None,
) -> Dict[str, Any]:
    """One JSON line targeting /v1/chat/completions.

    The system message is ``meta_prompt``. The user message is ``user_payload``
    serialized to JSON with non-ASCII characters kept verbatim.

    Args:
        max_output_tokens: sent as ``max_completion_tokens`` when not None.
        temperature: sent when not None.
        reasoning_effort: optional. Sent as the top-level ``reasoning_effort``
            when truthy and not "none". Defaults to None (omitted), which keeps
            the previous request shape for existing callers.
    """
    body: Dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": meta_prompt},
            {"role": "user", "content": json.dumps(user_payload, ensure_ascii=False)},
        ],
    }
    if temperature is not None:
        body["temperature"] = temperature
    # `max_tokens` is deprecated in Chat Completions and rejected by reasoning
    # models (o-series, gpt-5*). `max_completion_tokens` is the current cap and
    # also bounds reasoning tokens.
    if max_output_tokens is not None:
        body["max_completion_tokens"] = max_output_tokens
    if reasoning_effort and reasoning_effort != "none":
        body["reasoning_effort"] = reasoning_effort

    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": body,
    }


def extract_text_from_responses_body(resp_body: Dict[str, Any]) -> str:
    """
    Robustly extract text from a /v1/responses output (batch line -> response.body).

    If the body has an ``output_text`` string, that string is returned stripped.
    Otherwise the function walks ``output`` and joins the stripped, non-empty
    text chunks with newlines. For each output item it collects:

    * the item-level ``text`` string, if any;
    * each part of the item's ``content`` list, using the part's ``text``
      string or, failing that, its ``content`` string.

    Items of type "reasoning" are skipped so model reasoning is not mixed into
    the answer. Non-dict items and parts are ignored rather than raising.
    Returns "" when nothing usable is found (or the body is not a dict).
    """
    if not isinstance(resp_body, dict):
        return ""

    output_text = resp_body.get("output_text")
    if isinstance(output_text, str):
        return output_text.strip()

    out_chunks: List[str] = []
    for item in resp_body.get("output", []) or []:
        # The original called item.get() on non-dict items and crashed.
        if not isinstance(item, dict) or item.get("type") == "reasoning":
            continue

        item_text = item.get("text")
        if isinstance(item_text, str) and item_text.strip():
            out_chunks.append(item_text.strip())

        content = item.get("content")
        if not isinstance(content, list):
            continue
        for part in content:
            if not isinstance(part, dict):
                continue
            if isinstance(part.get("text"), str) and part["text"].strip():
                out_chunks.append(part["text"].strip())
            elif isinstance(part.get("content"), str) and part["content"].strip():
                out_chunks.append(part["content"].strip())
    return "\n".join(out_chunks).strip()


def extract_text_from_chat_body(resp_body: Dict[str, Any]) -> str:
    """
    Extract text from a /v1/chat/completions output (batch line -> response.body).

    Returns the stripped ``content`` of the first choice's message. If that
    lookup fails for any reason (e.g. no ``choices``, an empty list, or a null
    ``content``), the whole body is returned as a JSON string instead.
    """
    try:
        choices = resp_body["choices"]
        first_message = choices[0]["message"]
        reply = first_message["content"]
        return reply.strip()
    except Exception:
        # Fall back to the raw payload so it can still be inspected later.
        return json.dumps(resp_body)


# Batch statuses after which polling stops.
_TERMINAL_BATCH_STATUSES = ("completed", "failed", "cancelled", "expired")


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string plus a trailing "Z".

    This is the same format the script always wrote. It is built from an aware
    datetime because ``datetime.utcnow()`` is deprecated since Python 3.12.
    """
    from datetime import timezone  # only this helper needs it

    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options."""
    p = argparse.ArgumentParser(description="Batch Progressive Simplification")
    p.add_argument("--input", required=True, help="Path to input JSONL/JSON")
    p.add_argument("--output", required=True, help="Path to output JSONL")
    p.add_argument("--model", default="gpt-5-mini", help="Model name")
    p.add_argument("--endpoint", choices=["responses", "chat"], default="responses",
                   help="Batch endpoint to use: 'responses' -> /v1/responses, 'chat' -> /v1/chat/completions")
    p.add_argument("--max-output-tokens", type=int, default=8192)
    p.add_argument("--temperature", type=float, default=None)
    p.add_argument("--reasoning-effort", choices=["none", "low", "medium", "high"], default="none",
                   help="Only for responses endpoint & supported models")
    p.add_argument("--extract-json-blocks", action="store_true",
                   help="Extract ```json code blocks into <output dir>/json_blocks/")
    p.add_argument("--resume", action="store_true",
                   help="Skip IDs already present in --output (resume-safe)")
    p.add_argument("--poll-interval", type=float, default=10.0, help="Seconds between status checks")
    p.add_argument("--completion-window", choices=["24h"], default="24h")
    p.add_argument("--workdir", default="openai_api_workdir", help="Folder to write batch files/results")
    p.add_argument("--dataset", choices=["auto", "gsm", "entailment_bank"], default="auto",
                   help="Input format/prompt; 'auto' picks 'gsm' when the input path contains 'gsm'")
    return p.parse_args()


def _detect_dataset(input_path: str, choice: str) -> str:
    """Resolve --dataset. 'auto' means 'gsm' if the input path mentions gsm, else 'entailment_bank'."""
    if choice != "auto":
        return choice
    return "gsm" if "gsm" in input_path.lower() else "entailment_bank"


def _build_user_payload(problem: Dict[str, Any], dataset_name: str) -> Dict[str, Any]:
    """Pick the fields of one input record that the dataset's meta prompt describes."""
    if dataset_name == "gsm":
        return {"question": problem["question"], "answer": problem["answer"]}
    return {
        "instruction": problem["instruction"],
        "input": problem["input"],
        "output": problem["output"],
    }


def _build_batch_lines(
    args: argparse.Namespace,
    problems: List[Dict[str, Any]],
    todo_indices: List[int],
    dataset_name: str,
    temperature: Optional[float],
) -> List[Dict[str, Any]]:
    """Build one batch request per pending problem, tagged with custom_id "prob-<index>"."""
    meta_prompt = GSM_META_PROMPT if dataset_name == "gsm" else ENTAILMENT_BANK_META_PROMPT
    lines: List[Dict[str, Any]] = []
    for i in todo_indices:
        common: Dict[str, Any] = dict(
            custom_id=f"prob-{i}",
            model=args.model,
            meta_prompt=meta_prompt,
            user_payload=_build_user_payload(problems[i], dataset_name),
            max_output_tokens=args.max_output_tokens,
            temperature=temperature,
        )
        if args.endpoint == "responses":
            lines.append(build_batch_line_responses(**common, reasoning_effort=args.reasoning_effort))
        else:
            lines.append(build_batch_line_chat(**common))
    return lines


def _poll_until_done(client: Any, job: Any, poll_interval: float) -> Any:
    """Poll the batch until it reaches a terminal status and return the final job object.

    On Ctrl-C it prints the batch id and re-raises. This script never cancels
    the job, so it keeps running server-side.
    """
    try:
        while True:
            job = client.batches.retrieve(job.id)
            print(f"[{_utc_timestamp()}] status={job.status}")
            if job.status in _TERMINAL_BATCH_STATUSES:
                return job
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        print(f"\nStopped polling; batch {job.id} was not cancelled by this script "
              "(inspect or cancel it through the Batches API).", file=sys.stderr)
        raise


def _download_file(client: Any, file_id: str, dest: Path) -> None:
    """Save the raw bytes of an OpenAI file to ``dest``."""
    dest.write_bytes(client.files.content(file_id).content)


def _read_result_lines(path: Path) -> List[Any]:
    """Parse a batch output/error JSONL into ``(index, record)`` pairs.

    Only lines whose custom_id is "prob-<int>" (as built by _build_batch_lines)
    are kept. Blank and foreign lines are skipped.
    """
    pairs: List[Any] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            cid = obj.get("custom_id") or ""
            if not cid.startswith("prob-"):
                continue
            try:
                index = int(cid.split("-")[1])
            except (IndexError, ValueError):
                continue
            pairs.append((index, obj))
    return pairs


def _failure_reason(obj: Dict[str, Any]) -> Optional[str]:
    """Describe why a batch result line failed, or return None if it succeeded.

    A line counts as failed when any of these hold:

    * the top-level ``error`` is set;
    * ``response`` is missing or not an object;
    * ``response.status_code`` is present and is not 200.
    """
    error = obj.get("error")
    response = obj.get("response")
    status_code = response.get("status_code") if isinstance(response, dict) else None
    if not error and isinstance(response, dict) and status_code in (None, 200):
        return None
    if not error and isinstance(response, dict) and isinstance(response.get("body"), dict):
        error = response["body"].get("error")
    if isinstance(error, dict):
        error = error.get("message") or error
    return f"status_code={status_code} error={error}"


def _is_truncated(body: Dict[str, Any], endpoint: str) -> bool:
    """Return True if the response reports that generation stopped early."""
    if endpoint == "responses":
        return body.get("status") == "incomplete"
    choices = body.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return False
    return choices[0].get("finish_reason") == "length"


def _report_failures(failures: List[Any]) -> None:
    """Print failed requests (at most 10 in detail) to stderr."""
    if not failures:
        return
    print(f"{len(failures)} request(s) failed and were NOT written to the output "
          "(rerun with --resume to retry them):", file=sys.stderr)
    ordered = sorted(failures, key=lambda pair: pair[0])
    for index, reason in ordered[:10]:
        print(f"  prob-{index}: {reason}", file=sys.stderr)
    if len(ordered) > 10:
        print(f"  ... and {len(ordered) - 10} more", file=sys.stderr)


def _extract_json_blocks(index: int, text: str, json_blocks_dir: Path) -> None:
    """Write each ```json block in ``text`` to <dir>/<index>_v<N>.json.

    Blocks that fail to parse are written verbatim to <dir>/<index>_v<N>.raw.txt.
    """
    blocks = JSON_BLOCK_RE.findall(text)
    if not blocks:
        return
    json_blocks_dir.mkdir(parents=True, exist_ok=True)
    for vi, block in enumerate(blocks, start=1):
        try:
            parsed = json.loads(block)
        except ValueError:
            (json_blocks_dir / f"{index}_v{vi}.raw.txt").write_text(block, encoding="utf-8")
            continue
        with (json_blocks_dir / f"{index}_v{vi}.json").open("w", encoding="utf-8") as f:
            json.dump(parsed, f, ensure_ascii=False, indent=2)


def _write_outputs(
    results: Dict[int, str],
    problems: List[Dict[str, Any]],
    model: str,
    out_path: Path,
    json_blocks_dir: Path,
    extract_json_blocks: bool,
) -> None:
    """Append one output record per result (sorted by index); optionally dump its ```json blocks."""
    for i in sorted(results):
        text = results[i]
        append_jsonl(out_path, {
            "id": i,
            "original": problems[i],
            "model": model,
            "response_raw": text,
            "timestamp": _utc_timestamp(),
        })
        if extract_json_blocks:
            _extract_json_blocks(i, text, json_blocks_dir)


def main():
    """Build, submit and poll one Batch job, then append the parsed results to --output.

    Only requests that succeeded are written. Failed ones (from the output file
    or the batch error file) are reported on stderr and left out, so a later
    ``--resume`` run retries them.
    """
    args = _parse_args()

    in_path = Path(args.input)
    out_path = Path(args.output)
    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    json_blocks_dir = out_path.parent / "json_blocks"

    if not os.getenv("OPENAI_API_KEY"):
        print("ERROR: Please set OPENAI_API_KEY.", file=sys.stderr)
        sys.exit(1)

    problems = load_jsonl(in_path)

    # Resume support: skip indices already present in the output file.
    completed = load_completed_ids(out_path) if args.resume else set()
    todo_indices = [i for i in range(len(problems)) if i not in completed]
    if not todo_indices:
        print("Nothing to do (everything already in output).")
        return

    temp_arg: Optional[float] = None
    if args.temperature is not None:
        if model_supports_temperature(args.model):
            temp_arg = args.temperature
        else:
            print(f"WARNING: --temperature ignored; model '{args.model}' does not accept it.",
                  file=sys.stderr)
    if args.endpoint == "chat" and args.reasoning_effort != "none":
        print("WARNING: --reasoning-effort is only applied with --endpoint responses.", file=sys.stderr)

    dataset_name = _detect_dataset(args.input, args.dataset)
    print(f"Loaded {len(problems)} problems from {in_path}, {len(todo_indices)} to do (dataset={dataset_name}).")

    batch_lines = _build_batch_lines(args, problems, todo_indices, dataset_name, temp_arg)
    batch_in_path = workdir / f"batch_input_{int(time.time())}.jsonl"
    with batch_in_path.open("w", encoding="utf-8") as f:
        for obj in batch_lines:
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    client = OpenAI()

    # Upload the batch file; the handle is closed once the upload returns.
    with batch_in_path.open("rb") as fh:
        batch_file = client.files.create(file=fh, purpose="batch")
    print(f"Uploaded batch file: {batch_file.id} -> {batch_in_path}")

    endpoint_path = "/v1/responses" if args.endpoint == "responses" else "/v1/chat/completions"
    job = client.batches.create(
        input_file_id=batch_file.id,
        endpoint=endpoint_path,
        completion_window=args.completion_window,
    )
    print(f"Created batch: {job.id} | status={job.status}")

    job = _poll_until_done(client, job, args.poll_interval)
    if job.status != "completed":
        print(f"Batch job {job.id} finished with status '{job.status}'. Partial results may exist.", file=sys.stderr)

    # Requests that failed are listed in a separate error file; the original
    # script never looked at it.
    failures: List[Any] = []
    error_file_id = getattr(job, "error_file_id", None)
    if error_file_id:
        error_path = workdir / f"batch_errors_{job.id}.jsonl"
        _download_file(client, error_file_id, error_path)
        print(f"Wrote batch errors: {error_path}", file=sys.stderr)
        for index, obj in _read_result_lines(error_path):
            failures.append((index, _failure_reason(obj) or "listed in error file"))

    if not job.output_file_id:
        _report_failures(failures)
        print("No output_file_id on batch job. Exiting.", file=sys.stderr)
        return

    batch_result_path = workdir / f"batch_output_{job.id}.jsonl"
    _download_file(client, job.output_file_id, batch_result_path)
    print(f"Wrote batch results: {batch_result_path}")

    # Each line: {"custom_id": "...", "response": {"status_code": ..., "body": {...}}, "error": ...}
    results: Dict[int, str] = {}
    truncated: List[int] = []
    for index, obj in _read_result_lines(batch_result_path):
        reason = _failure_reason(obj)
        if reason is not None:
            # Do not record failures as completed, or --resume would skip them forever.
            failures.append((index, reason))
            continue
        body = obj["response"].get("body")
        if not isinstance(body, dict):
            body = {}
        if _is_truncated(body, args.endpoint):
            truncated.append(index)
        if args.endpoint == "responses":
            results[index] = extract_text_from_responses_body(body)
        else:
            results[index] = extract_text_from_chat_body(body)

    _write_outputs(results, problems, args.model, out_path, json_blocks_dir, args.extract_json_blocks)

    _report_failures(failures)
    if truncated:
        shown = ", ".join(f"prob-{i}" for i in sorted(truncated)[:10])
        print(f"WARNING: {len(truncated)} response(s) were incomplete (e.g. hit the output-token "
              f"limit) and may be cut off: {shown}", file=sys.stderr)
    print(f"Done. Wrote {len(results)} items to {out_path}")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Exit with 130 (128 + SIGINT), the conventional status for Ctrl-C, so
        # shell chains such as `batch_runner.py && next_step` do not treat an
        # abort as success. The original exited with status 0.
        print("\nInterrupted by user, exiting…", file=sys.stderr)
        sys.exit(130)
