#!/usr/bin/env python3
"""
Fetch GPT responses for every question in a Parquet file
and append (idx, question, answer) as JSON lines to disk.

append_hint is a boolean flag to append a hint to the question.
hint_text is a string to append to the question.
seed is a random seed to sample the questions.
test is an integer to sample the questions.
workers is an integer to set the number of workers.
model is a string to set the model.
jsonl is a string to set the output file.

Usage:
  OPENAI_API_KEY=... python3 fetch_responses.py \
      --parquet /path/to/file.parquet \
      --jsonl    responses.jsonl \
      --workers  10 \
      --model    gpt-5 \
      --append_hint \
      --seed     42 \
      --test     30
"""
import argparse, json, os, sys, time, re, ast
import numpy as np

# -------- helpers --------
def _first_content_field(obj):
    """Return ``str(obj["content"])`` for a dict, or that of the first such dict in a list.

    Returns None when *obj* is neither, or carries no ``content`` key, so callers
    can distinguish "no message found" from an empty-string content.
    """
    if isinstance(obj, list):
        for item in obj:
            if isinstance(item, dict) and "content" in item:
                return str(item["content"])
        return None
    if isinstance(obj, dict) and "content" in obj:
        return str(obj["content"])
    return None


def extract_first_content(value) -> str:
    """Extract the first ``message.content`` string from common prompt formats.

    Accepted inputs:
      * a dict with a ``content`` key, or a list whose first such dict is used;
      * a numpy array of those (converted with ``tolist()``);
      * a string holding any of the above as JSON, as a Python literal, or as a
        numpy repr such as ``"array([{'content': '...'}], dtype=object)"``.

    When no ``content`` can be found, ``str(value)`` is returned. In particular,
    plain question text that merely *parses* as JSON or a Python literal
    (``"2.50"``, ``"null"``, ``'{"a": 1}'``) is returned verbatim instead of
    being re-rendered from the parsed value.
    """
    found = _first_content_field(value)
    if found is not None:
        return found

    # Parquet list columns may arrive as numpy arrays rather than Python lists.
    if isinstance(value, np.ndarray):
        try:
            return extract_first_content(value.tolist())
        except Exception:
            pass

    if isinstance(value, str):
        s = value.strip()
        # Strip a numpy repr wrapper, e.g. "array([{'content': '...'}], dtype=object)".
        m = re.match(r"^array\((.*)\)$", s, flags=re.DOTALL)
        if m:
            s = re.sub(r",\s*dtype=object\s*$", "", m.group(1)).strip()

        # Try JSON first, then Python literal syntax (single-quoted dict reprs).
        for parse in (json.loads, ast.literal_eval):
            try:
                parsed = parse(s)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, str):
                # Doubly-encoded payload such as '"[{...}]"': unwrap one layer.
                return extract_first_content(parsed)
            found = _first_content_field(parsed)
            if found is not None:
                return found

        # Last resort for malformed/truncated text: grab the content field,
        # ending at the quote that matches the opening one (so apostrophes
        # inside a double-quoted value do not cut it short).
        m = re.search(r"""["']content["']\s*:\s*(["'])(.*?)(?<!\\)\1""", s, flags=re.DOTALL)
        if m:
            return m.group(2)
    return str(value)

from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt, RetryError

# -------- retry-wrapped API call --------
# (connect, read) timeouts in seconds for each OpenRouter request. Without a
# timeout a stalled connection can block a worker thread forever. The read
# timeout is deliberately generous because a non-streaming completion may take
# many minutes to produce.
_REQUEST_TIMEOUT = (30.0, 1800.0)


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def _post_chat_completion(headers: dict, payload: dict, timeout) -> dict:
    """Send one chat-completion request to OpenRouter and return the decoded JSON body.

    Each request is retried on its own (up to 6 attempts, randomized exponential
    backoff capped at 60 s), so a transient failure on one sample does not force
    the samples already received to be requested again.

    Raises:
        RuntimeError: on a non-200 status, a non-object body, or an ``error``
            object without ``choices`` (raised inside the retry loop, so each
            triggers another attempt).
        tenacity.RetryError: once all attempts have failed.
    """
    resp = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    if resp.status_code != 200:
        # Include the server's message for visibility in the final error.
        try:
            err = resp.json()
        except Exception:
            err = {"text": resp.text}
        raise RuntimeError(f"OpenRouter error {resp.status_code}: {err}")

    data = resp.json()
    if not isinstance(data, dict):
        raise RuntimeError(f"OpenRouter returned an unexpected body: {data!r}")
    # NOTE(review): guards against an error object delivered with HTTP 200 and
    # no choices; such a reply used to be saved silently as an empty response.
    if data.get("error") and not data.get("choices"):
        raise RuntimeError(f"OpenRouter error (HTTP 200): {data['error']}")
    return data


def ask_gpt(
    idx: int,
    original_question: str,
    input_text: str,
    model: str,
    api_key: str,
    n_per_question: int,
    allow_fallbacks: bool,
    provider_order: list[str] | None,
    *,
    timeout: float | tuple[float, float] = _REQUEST_TIMEOUT,
) -> dict:
    """Query *model* ``n_per_question`` times with *input_text* and build one output record.

    Args:
        idx: row index of the question, copied into the record.
        original_question: question text before any hint was appended.
        input_text: the prompt actually sent as the single user message.
        model: OpenRouter model name.
        api_key: bearer token for the API.
        n_per_question: number of separate requests to issue (at least 1).
        allow_fallbacks: adds ``allow_fallbacks`` to the payload (and to the
            provider object when *provider_order* is given).
        provider_order: optional preferred provider order.
        timeout: ``requests`` timeout for each HTTP call.

    Returns:
        A dict with ``idx``, ``question`` (the prompt sent), ``original_question``
        and ``model``. With one sample it adds ``response``, ``response_id``,
        ``created``, ``finish_reason`` and, when reported, ``usage``; otherwise
        the list-valued ``responses``, ``response_ids``, ``created_list``,
        ``finish_reasons`` and ``usages`` (``None`` where usage was absent).

    Raises:
        tenacity.RetryError: when any single request still fails after retries.
    """
    # Headers and payload are identical for every sample, so build them once.
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [
            {"role": "user", "content": input_text}
        ],
    }
    # Top-level allow_fallbacks supported by OpenRouter
    if allow_fallbacks:
        payload["allow_fallbacks"] = True
    # Optional provider object
    if provider_order:
        payload["provider"] = {"order": provider_order, "allow_fallbacks": bool(allow_fallbacks)}

    responses_list = []
    usages_list = []
    response_ids = []
    created_list = []
    finish_reasons = []

    # Minimal implementation: issue n separate requests
    for _ in range(max(1, int(n_per_question))):
        data = _post_chat_completion(headers, payload, timeout)
        choice = (data.get("choices") or [{}])[0]
        message = choice.get("message") or {}
        responses_list.append(message.get("content") or "")

        response_ids.append(data.get("id"))
        created_list.append(data.get("created"))
        finish_reasons.append(choice.get("finish_reason"))

        usage = data.get("usage") or {}
        if usage:
            usages_list.append({
                "prompt_tokens": usage.get("prompt_tokens"),
                "completion_tokens": usage.get("completion_tokens"),
                "total_tokens": usage.get("total_tokens"),
            })
        else:
            usages_list.append(None)

    result = {
        "idx": idx,
        "question": input_text,
        "original_question": original_question,
        "model": model,
    }

    if len(responses_list) == 1:
        result.update({
            "response": responses_list[0],
            "response_id": response_ids[0],
            "created": created_list[0],
            "finish_reason": finish_reasons[0],
        })
        if usages_list[0] is not None:
            result["usage"] = usages_list[0]
    else:
        result.update({
            "responses": responses_list,
            "response_ids": response_ids,
            "created_list": created_list,
            "finish_reasons": finish_reasons,
            "usages": usages_list,
        })

    return result

# -------- main workflow --------
# Instructions some datasets already embed in the question. They are removed
# before --hint_text is appended so the model does not receive two competing
# answer-format requests. Compiled once instead of once per row.
_EXISTING_HINT_PATTERNS = [
    re.compile(
        r"(?i)think\s+step\s+by\s+step,?\s*and\s*put\s+your\s+final\s+answer\s+within\s*\\boxed\{\}\.?"
    ),
    re.compile(
        r"(?i)think\s+step\s+by\s+step\.?\s*put\s+your\s+final\s+answer\s+within\s*\\boxed\{\}\.?"
    ),
]


def _load_done_indices(path: Path) -> set:
    """Return the ``idx`` values of the records already present in a JSONL file.

    Blank lines and lines that are not JSON objects with an ``idx`` key (for
    example a record cut short because an earlier run was killed mid-write)
    are skipped with a warning instead of aborting the resume.
    """
    done = set()
    # errors="replace": only ``idx`` is needed, so undecodable bytes in old
    # records must not prevent resuming.
    with path.open(encoding="utf-8", errors="replace") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                done.add(json.loads(line)["idx"])
            except (ValueError, KeyError, TypeError) as exc:
                print(f"WARNING: skipping unreadable line {lineno} in {path}: {exc}", file=sys.stderr)
    return done


def _missing_trailing_newline(path: Path) -> bool:
    """True if *path* exists, is non-empty, and its last byte is not a newline."""
    if not path.exists() or path.stat().st_size == 0:
        return False
    with path.open("rb") as f:
        f.seek(-1, os.SEEK_END)
        return f.read(1) != b"\n"


def _build_prompt(question: str, append_hint: bool, hint_text: str) -> str:
    """Return the prompt to send: *question*, optionally with embedded hints replaced by *hint_text*."""
    if not append_hint:
        return question
    cleaned = question
    for pat in _EXISTING_HINT_PATTERNS:
        cleaned = pat.sub("", cleaned).strip()
    return f"{cleaned}\n\n{hint_text}"


def _extract_ground_truth(row):
    """Return ``row["reward_model"]["ground_truth"]`` if present, else None."""
    try:
        reward_model = row.get("reward_model") if hasattr(row, "get") else row["reward_model"]
    except Exception:
        return None
    if isinstance(reward_model, dict):
        return reward_model.get("ground_truth")
    return None


def _json_default(obj):
    """Make numpy scalars/arrays (e.g. inside a parquet-derived ground truth) JSON-serialisable."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Command-line entry point: fetch responses for every pending question.

    Reads the Parquet file and optionally samples it (--test/--seed). Rows whose
    ``idx`` is already recorded in the output JSONL are skipped. The remaining
    prompts are sent concurrently through ``ask_gpt``, and one JSON record (plus
    ``ground_truth``) is appended per successful question. Failures are logged
    to stderr and not written, so a later run retries them.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True, help="input .parquet file")
    ap.add_argument("--jsonl",   default="./outputs/responses.jsonl", help="output .jsonl file (append-safe)")
    ap.add_argument("--workers", type=int, default=10, help="#parallel requests")
    ap.add_argument("--model",   default="openai/gpt-4o-mini", help="OpenRouter model name (e.g., openai/gpt-4o-mini)")
    ap.add_argument("--n_per_question", type=int, default=1, help="number of responses per question")
    ap.add_argument("--test", type=int, help="sanity check: only process first N questions")
    ap.add_argument("--seed", type=int, default=42, help="random seed for test sampling")
    ap.add_argument(
        "--allow_fallbacks",
        action="store_true",
        help="allow model/provider fallbacks on OpenRouter",
    )
    ap.add_argument(
        "--provider",
        action="append",
        help="provider name to prefer in order (may repeat). If omitted, no provider object is sent",
    )
    ap.add_argument(
        "--append_hint",
        action="store_true",
        help="append concise boxed-answer hint to each question",
    )
    ap.add_argument(
        "--hint_text",
        type=str,
        default=(
            "Think step by step, and output a clear step-by-step solution with the final answer/result. After the step-by-step reasoning, output only the final "
            "numeric/symbolic result inside \\boxed{...} on the last line. Keep the boxed "
            "content under 30 characters and do not add any text after the box."
        ),
        help="custom hint text to append when --append_hint is set",
    )
    args = ap.parse_args()

    # Prefer OPENROUTER_API_KEY; fallback to OPENAI_API_KEY to ease transition
    api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        sys.exit("ERROR: set OPENROUTER_API_KEY (preferred) or OPENAI_API_KEY env-var first")

    df = pd.read_parquet(args.parquet)

    # Apply test limit if specified (random sample with seed)
    if args.test:
        n = min(args.test, len(df))
        df = df.sample(n=n, random_state=args.seed)
        print(f"TEST MODE: Sampling {len(df)} questions with seed {args.seed}")

    total = len(df)

    # Determine which indices are already done (if resuming)
    out_path = Path(args.jsonl)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    resuming = out_path.exists()
    done = _load_done_indices(out_path) if resuming else set()
    # Count only records belonging to the (possibly sampled) current frame so
    # the progress counter never exceeds ``total``.
    completed = sum(1 for i in df.index if i in done)
    if resuming:
        print(f"Resuming – {completed}/{total} already completed")

    # tuples: (idx, original_question, prompt_to_send, ground_truth)
    questions = []
    for i, row in df.iterrows():
        if i in done:
            continue
        question = extract_first_content(row["prompt"])
        prompt_to_send = _build_prompt(question, args.append_hint, args.hint_text)
        questions.append((int(i), question, prompt_to_send, _extract_ground_truth(row)))

    # A record truncated by a killed run would otherwise be glued to the first
    # new record on the same line, corrupting both.
    needs_separator = _missing_trailing_newline(out_path)

    # ThreadPoolExecutor works well for I/O bound HTTP requests
    with ThreadPoolExecutor(max_workers=args.workers) as pool, \
         out_path.open("a", encoding="utf-8") as fout:
        if needs_separator:
            fout.write("\n")

        futures = {
            pool.submit(
                ask_gpt,
                idx,
                original_question,
                prompt_to_send,
                args.model,
                api_key,
                args.n_per_question,
                bool(args.allow_fallbacks),
                args.provider,
            ): (idx, ground_truth)
            for idx, original_question, prompt_to_send, ground_truth in questions
        }

        for fut in as_completed(futures):
            i, ground_truth = futures[fut]
            try:
                record = fut.result()
                # Attach ground truth to the saved record if available
                record["ground_truth"] = ground_truth
                fout.write(json.dumps(record, ensure_ascii=False, default=_json_default) + "\n")
                fout.flush()
            except Exception as e:
                # failure is already retried; if still failing, log and continue
                if isinstance(e, RetryError) and getattr(e, "last_attempt", None):
                    final_exc = e.last_attempt.exception()
                    print(f"[{i}] ERROR: {final_exc}", file=sys.stderr)
                else:
                    print(f"[{i}] ERROR: {e}", file=sys.stderr)
            else:
                completed += 1
                print(f"[{completed}/{total}] idx={i} ✓")

    print("All done!")


if __name__ == "__main__":
    main()