#!/usr/bin/env python3
"""
Process a JSONL dataset of reasoning problems and generate progressively easier versions
using the OpenAI Responses API.

Features:
- Async with bounded concurrency
- Robust retries with exponential backoff on 429/5xx
- Resume-safe: skips items already completed in the output JSONL
- Per-item logs and optional JSON extraction
- Clean, minimal dependencies (only openai + tqdm)

Input format: a .jsonl file with one object per line (a .json file holding a list of such objects is also accepted):
{"instruction": "...", "input": "...", "output": "..."}

Output format (JSONL), one object per line:
{"id": <zero-based index>, "original": {...}, "model": "...", "response_raw": "<markdown + JSON blocks>", "timestamp": "..."}
Optionally extracts JSON blocks to `<output dir>/json_blocks/<id>_v<k>.json`
(blocks that fail to parse are saved as `<id>_v<k>.raw.txt`).
"""

import argparse
import asyncio
import json
import os
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

from openai import AsyncOpenAI, RateLimitError, APIStatusError, BadRequestError # pip install openai>=1.40
from tqdm.asyncio import tqdm as tqdm_async

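# Earlier version of the prompt (no step counts), kept for reference only; the active prompt is META_PROMPT below.
# META_PROMPT = """You are an expert at reasoning question simplification.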
# I will provide you with a reasoning problem in JSON format that contains:
# - "instruction": the solving instruction
# - "input": the context and question
# - "output": the reasoning chain and final answer

# Your task is to automatically generate a *progressive difficulty ladder* of simplified versions of this problem. 
# Each new version should make the reasoning easier by moving more intermediate conclusions (from the reasoning steps in the output) directly into the input context. 
# Stop when the problem has become trivial (e.g., the final hypothesis is already in the input).

# Guidelines:
# 1. Identify all intermediate conclusions (int1, int2, …) in the reasoning chain.
# 2. Create Version 1 as the original (no added intermediates).
# 3. Then generate subsequent versions, each time inserting one or more intermediates into the input.
# 4. You may decide the number of versions automatically — fewer if the chain is short, more if it is long.
# 5. For each version, output **a fenced JSON code block** with the same keys:
#    - "instruction"
#    - "input"
#    - "output"
# 6. Precede each block with a Markdown label like:
#    ## Version N — [difficulty descriptor]
#    Then immediately follow with:
#    ```json
#    { ... }
#    ```

# Goal: produce a set of progressively easier problems, where the solver needs fewer reasoning steps at each level.
# """

# System prompt sent with every request. It asks the model for a ladder of
# progressively easier versions of the problem, each emitted as a fenced
# ```json block (the format JSON_BLOCK_RE extracts) together with the minimum
# number of reasoning steps that version still requires.
META_PROMPT="""
You are an expert at reasoning question simplification.
I will provide you with a reasoning problem in JSON format that contains:

- "instruction": the solving instruction
- "input": the context and question
- "output": the reasoning chain and final answer

Your task is to automatically generate a progressive difficulty ladder of simplified versions of this problem.
Each new version should make the reasoning easier by moving more intermediate conclusions (from the reasoning steps in the output) directly into the input context.
Stop when the problem has become trivial (e.g., the final hypothesis is already in the input).
For each version, also output the minimum number of reasoning steps required to reach the final answer from that version’s input.
Treat a reasoning step as a necessary inferential move that derives a new statement from previous facts/conclusions (e.g., one arithmetic operation, one logical implication, one factual lookup from the provided context).
Count merged paraphrases/restatements as 0 additional steps; do not double-count trivially equivalent rewrites.
When multiple independent sub-derivations are needed before a final combination, count each indispensable sub-derivation as one step.
The count must be a non-negative integer; use 0 for a trivial version where the answer is directly stated in the input.
Ensure monotonic non-increase across versions (later versions should never require more steps than earlier ones).

Guidelines:

1. Identify all intermediate conclusions (int1, int2, …) in the original reasoning chain.
2. Create Version 1 as the original (no added intermediates).
3. Then generate subsequent versions, each time inserting one or more intermediates into the input.
4. You may decide the number of versions automatically — fewer if the chain is short, more if it is long.
5. For each version, output in a fenced JSON code block with the following keys:
   - "instruction"
   - "input"
   - "answer" (string, the final answer to the problem)
   - "reasoning" (string, the reasoning chain leading to the answer)
   - "min_steps" (integer, the minimum number of steps to reach the answer)
   - "min_steps_note" (a short explanation explaining the count)
6. Precede each block with a Markdown label like:
   ## Version N — [difficulty descriptor]
   Then immediately follow with:
   ```json
   { ... }
   ```

Goal: produce a set of progressively easier problems, where the solver needs fewer reasoning steps at each level, and report the minimum required steps for each version.
"""

# Regex to extract fenced JSON code blocks (```json ... ```).
# The single capture group is the object text between the fences. DOTALL lets
# the body span multiple lines; the lazy `.*?` ends the match at the first `}`
# that is followed only by whitespace and a closing fence, so consecutive
# blocks in one response are returned as separate matches by findall().
JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)


def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """
    Read the list of problem objects stored at ``path``.

    A ``.jsonl`` file is parsed one line at a time (lines that are empty after
    stripping are ignored); a ``.json`` file must hold a single top-level
    list, which is returned as-is. The extension is compared
    case-insensitively.

    Raises:
        ValueError: if the extension is neither ``.jsonl`` nor ``.json``, or
            if a ``.json`` file does not contain a list.
    """
    extension = path.suffix.lower()

    if extension == ".jsonl":
        with path.open("r", encoding="utf-8") as handle:
            stripped_lines = (raw.strip() for raw in handle)
            return [json.loads(text) for text in stripped_lines if text]

    if extension == ".json":
        with path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if not isinstance(payload, list):
            raise ValueError(f"Expected a list in {path}, got {type(payload)}")
        return payload

    raise ValueError(f"Unsupported file extension: {path.suffix}")


def append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
    """Append ``obj`` to ``path`` as a single JSON line, creating parent directories if missing."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open(mode="a", encoding="utf-8") as sink:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        serialized = json.dumps(obj, ensure_ascii=False)
        sink.write(f"{serialized}\n")


def load_completed_ids(path: Path) -> set:
    """
    Collect the integer ``id`` values recorded in the output JSONL at ``path``.

    Returns an empty set when the file does not exist. Lines are skipped,
    rather than aborting the scan, when they are blank, are not valid JSON
    (e.g. a line cut short by an interrupted run), are not JSON objects, have
    no ``"id"`` key, or carry an ``id`` that cannot be converted with ``int()``.
    """
    done: set = set()
    if not path.exists():
        return done
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Partial/corrupt line; leave it for later cleanup.
                continue
            if not isinstance(obj, dict) or "id" not in obj:
                continue
            try:
                done.add(int(obj["id"]))
            except (TypeError, ValueError, OverflowError):
                # id is null, non-numeric text, a container, or +/-Infinity.
                continue
    return done

def model_supports_temperature(model_name: str) -> bool:
    """
    Heuristically decide whether ``model_name`` should be sent a temperature.

    Returns False when the lower-cased name contains any of the markers
    below, True otherwise. The list is intentionally conservative; the
    runtime fallback is expected to handle models it misses.
    """
    lowered = model_name.lower()
    for marker in ("o4", "o3", "o1", "omni", "reasoning"):
        if marker in lowered:
            return False
    return True


def _extract_response_text(resp: Any) -> str:
    """Return the text found in a Responses API result, or "" when none is found."""
    # Prefer the convenience field when it is present and non-empty.
    convenience = getattr(resp, "output_text", None)
    if convenience:
        return convenience.strip()

    chunks: List[str] = []
    # `output` may be missing or None; treat both as "no items".
    for item in getattr(resp, "output", None) or []:
        # Case 1: the item itself carries text.
        text = getattr(item, "text", None)
        if isinstance(text, str) and text.strip():
            chunks.append(text.strip())
            continue

        # Case 2: message-style item with a list of content parts.
        content = getattr(item, "content", None)
        if not isinstance(content, list):
            continue
        for part in content:
            # Use the part's `.text`, falling back to a string `.content`.
            for field in ("text", "content"):
                value = getattr(part, field, None)
                if isinstance(value, str) and value.strip():
                    chunks.append(value.strip())
                    break

    return "\n".join(chunks).strip()


async def call_model(
    client: AsyncOpenAI,
    model: str,
    meta_prompt: str,
    problem: Dict[str, Any],
    max_output_tokens: int,
    temperature: Optional[float],  # None means "do not send a temperature"
    reasoning: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Send one problem to the Responses API and return the response text.

    The system message is ``meta_prompt``; the user message is a JSON dump of
    the problem's ``instruction``, ``input`` and ``output`` fields.

    Parameter fallback: when the API answers with a ``BadRequestError`` whose
    message contains "unsupported parameter" and names one of the optional
    parameters still being sent (``temperature``, ``reasoning``,
    ``max_output_tokens``, checked in that order), that parameter is dropped
    and the call is repeated. This continues until the call succeeds or no
    named parameter is left, so a model that rejects several parameters is
    handled too. For any other ``BadRequestError`` the call is retried once
    with ``temperature`` and ``reasoning`` removed, if either was being sent;
    otherwise the error is re-raised.

    Returns:
        The stripped response text. This is an empty string when the response
        contains no text the extractor recognizes.

    Raises:
        KeyError: if ``problem`` lacks ``instruction``, ``input`` or ``output``.
        BadRequestError: if the request is still rejected after the fallbacks.
        Any other exception from the client is propagated unchanged.
    """
    user_message = json.dumps(
        {
            "instruction": problem["instruction"],
            "input": problem["input"],
            "output": problem["output"],
        },
        ensure_ascii=False,
    )

    required: Dict[str, Any] = {
        "model": model,
        "input": [
            {"role": "system", "content": meta_prompt},
            {"role": "user", "content": user_message},
        ],
    }

    # Optional knobs; insertion order is the order in which they are matched
    # against an "unsupported parameter" error message.
    optional: Dict[str, Any] = {}
    if temperature is not None:
        optional["temperature"] = temperature
    if reasoning:
        optional["reasoning"] = reasoning
    optional["max_output_tokens"] = max_output_tokens

    blind_strip_done = False
    # Every extra iteration removes at least one key from `optional`, so the
    # loop is bounded by the number of optional parameters.
    while True:
        try:
            resp = await client.responses.create(**required, **optional)
            break
        except BadRequestError as e:
            msg = (getattr(e, "message", None) or str(e)).lower()

            rejected = None
            if "unsupported parameter" in msg:
                rejected = next((name for name in optional if name in msg), None)
            if rejected is not None:
                del optional[rejected]
                continue

            # Unrecognized 400: as a last resort, strip temperature and
            # reasoning once. Repeating an identical request would be pointless.
            strippable = [name for name in ("temperature", "reasoning") if name in optional]
            if strippable and not blind_strip_done:
                blind_strip_done = True
                for name in strippable:
                    del optional[name]
                continue
            raise

    return _extract_response_text(resp)


async def worker_one(
    semaphore: asyncio.Semaphore,
    client: AsyncOpenAI,
    model: str,
    problem: Dict[str, Any],
    problem_id: int,
    out_path: Path,
    max_output_tokens: int,
    temperature: Optional[float],
    extract_json_blocks: bool,
    json_blocks_dir: Path,
    max_retries: int,
    base_backoff: float,
    reasoning_cfg: Optional[Dict[str, Any]],
) -> None:
    """
    Process exactly one problem (problem_id) under concurrency control.

    Steps, all performed while holding ``semaphore``:
    1. Call the model with ``META_PROMPT``. ``RateLimitError`` and
       ``APIStatusError`` with a 5xx status are retried up to ``max_retries``
       times, sleeping ``base_backoff * 2 ** (attempt - 1)`` seconds before
       each retry. Any other error, or exhausting the retries, re-raises.
    2. Append one record (id, original problem, model, raw response text, UTC
       timestamp) to ``out_path``.
    3. If ``extract_json_blocks`` is set, write each fenced JSON block of the
       response to ``json_blocks_dir`` as ``<id>_v<k>.json``; a block that
       cannot be parsed or written is saved verbatim as ``<id>_v<k>.raw.txt``.
    """
    async with semaphore:
        attempt = 0
        while True:
            try:
                text = await call_model(
                    client=client,
                    model=model,
                    meta_prompt=META_PROMPT,
                    problem=problem,
                    max_output_tokens=max_output_tokens,
                    temperature=temperature,
                    reasoning=reasoning_cfg,
                )
                break
            except (RateLimitError, APIStatusError) as e:
                # Only rate limits and server-side (5xx) failures are
                # transient; other statuses are re-raised immediately.
                retryable = isinstance(e, RateLimitError) or 500 <= e.status_code < 600
                if not retryable:
                    raise
                attempt += 1
                if attempt > max_retries:
                    raise
                await asyncio.sleep(base_backoff * (2 ** (attempt - 1)))

        # Timezone-aware UTC timestamp, rendered with a "Z" suffix
        # (datetime.utcnow() is deprecated since Python 3.12).
        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
        record = {
            "id": problem_id,
            "original": problem,
            "model": model,
            "response_raw": text,
            "timestamp": timestamp,
        }
        append_jsonl(out_path, record)

        if not extract_json_blocks:
            return
        blocks = JSON_BLOCK_RE.findall(text)
        if not blocks:
            return

        json_blocks_dir.mkdir(parents=True, exist_ok=True)
        for vi, block in enumerate(blocks, start=1):
            try:
                parsed = json.loads(block)
                with (json_blocks_dir / f"{problem_id}_v{vi}.json").open("w", encoding="utf-8") as f:
                    json.dump(parsed, f, ensure_ascii=False, indent=2)
            except Exception:
                # Deliberate best-effort: extraction is a convenience on top
                # of the already-saved record, so keep the raw text instead.
                with (json_blocks_dir / f"{problem_id}_v{vi}.raw.txt").open("w", encoding="utf-8") as f:
                    f.write(block)


async def main():
    """
    Command-line entry point: parse arguments, work out which input items are
    not yet present in the output file, and process them concurrently.

    Each item's id is its zero-based position in the input file. A failure of
    one item is reported on stderr and does not stop the remaining items; if
    any item failed, the process exits with status 1 after all items have been
    attempted, and a later run retries the failed ones.
    """
    parser = argparse.ArgumentParser(description="Progressive simplification runner")
    parser.add_argument("--input", required=True, help="Path to input JSONL")
    parser.add_argument("--output", required=True, help="Path to output JSONL")
    parser.add_argument("--model", default="gpt-5-mini", help="OpenAI model name")
    parser.add_argument("--concurrency", type=int, default=8, help="Max parallel requests")
    parser.add_argument("--max-output-tokens", type=int, default=8192)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--max-retries", type=int, default=6)
    parser.add_argument("--base-backoff", type=float, default=2.0)
    parser.add_argument(
        "--extract-json-blocks",
        action="store_true",
        help="Extract ```json code blocks into outputs/json_blocks/",
    )
    parser.add_argument(
        "--reasoning-effort",
        choices=["none", "low", "medium", "high"],
        default="none",
        help="Optional reasoning config for supported models",
    )
    args = parser.parse_args()

    # Semaphore(0) would block every worker forever and a negative value
    # raises inside asyncio; reject both up front with a clear message.
    if args.concurrency < 1:
        parser.error("--concurrency must be at least 1")

    in_path = Path(args.input)
    out_path = Path(args.output)
    json_blocks_dir = out_path.parent / "json_blocks"

    if not os.getenv("OPENAI_API_KEY"):
        print("ERROR: Please set OPENAI_API_KEY in your environment.", file=sys.stderr)
        sys.exit(1)

    problems = load_jsonl(in_path)
    completed = load_completed_ids(out_path)
    # (id, problem) pairs still to process; the id is the original position.
    todo = [(i, p) for i, p in enumerate(problems) if i not in completed]

    # Count "done" against the input so stale ids in the output file (ids
    # beyond the current input) do not inflate the number.
    already_done = len(problems) - len(todo)
    print(f"Total items: {len(problems)} | Already done: {already_done} | To do: {len(todo)}")

    if not todo:
        print("Nothing to do.")
        return

    client = AsyncOpenAI()

    # Optional reasoning config (some models support it)
    reasoning_cfg = None
    if args.reasoning_effort != "none":
        reasoning_cfg = {"effort": args.reasoning_effort}

    sem = asyncio.Semaphore(args.concurrency)

    # Drop the temperature for models the heuristic flags as not accepting it.
    temp_arg = args.temperature if (args.temperature is not None and model_supports_temperature(args.model)) else None

    failed_ids: List[int] = []

    async def run_guarded(problem_id: int, problem: Dict[str, Any]) -> None:
        """Run one worker; record and report its failure instead of aborting the batch."""
        try:
            await worker_one(
                semaphore=sem,
                client=client,
                model=args.model,
                problem=problem,
                problem_id=problem_id,  # keep original line index as ID
                out_path=out_path,
                max_output_tokens=args.max_output_tokens,
                temperature=temp_arg,  # may be None
                extract_json_blocks=args.extract_json_blocks,
                json_blocks_dir=json_blocks_dir,
                max_retries=args.max_retries,
                base_backoff=args.base_backoff,
                reasoning_cfg=reasoning_cfg,
            )
        except Exception as exc:
            # Without this guard a single permanently failing item would abort
            # the whole run, and would do so again on every resume.
            failed_ids.append(problem_id)
            # tqdm's write() prints without corrupting the progress bar.
            tqdm_async.write(f"[item {problem_id}] failed: {type(exc).__name__}: {exc}", file=sys.stderr)

    tasks = [run_guarded(problem_id, problem) for problem_id, problem in todo]

    # tqdm.asyncio manages progress; no manual pbar updates needed
    await tqdm_async.gather(*tasks, total=len(tasks), desc="Processing", unit="problem")

    if failed_ids:
        print(
            f"{len(failed_ids)} item(s) failed: {sorted(failed_ids)}. Re-run to retry them.",
            file=sys.stderr,
        )
        sys.exit(1)



# Script entry point: run the async pipeline to completion. Ctrl-C is caught
# so the user sees a short message instead of a traceback.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user, exiting…")
