#!/usr/bin/env python3
"""
Classify math problem domain(s) and difficulty using a prompt template.

Reads a Parquet file with columns:
- prompt_content: the problem statement (Question)
- solution: the official/author solution (Solution)
- data_source: the problem source (Source)

Builds an instruction prompt from `prompt_domain_classification.txt` and
queries the model through OpenRouter's chat completions API.

Outputs one JSON line per successfully classified row with keys:
- idx, model, template_sha256, response, response_id, created, finish_reason, usage

Usage example (test with 5 samples):
  OPENROUTER_API_KEY=... python3 classify_domain_difficulty.py \
      --parquet ./data/dataset.parquet \
      --model openai/gpt-5-mini \
      --workers 20 \
      --test 5 \
      --jsonl ./outputs/classify_openai_gpt-5-mini_test5_seed_123.jsonl
"""
import argparse, json, os, sys, hashlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt, RetryError


# -------- helpers --------
def extract_first_content(value) -> str:
    """Best-effort extraction of string content from common nested formats.

    Handled shapes, in order:

    * a dict with a ``"content"`` key, or a list whose first dict item with a
      ``"content"`` key is used (chat-message style);
    * an object whose type name mentions ``numpy`` (converted with
      ``.tolist()`` and re-examined);
    * a string holding the JSON / Python-literal / ``array(..., dtype=object)``
      representation of one of the above.

    Anything else is returned as ``str(value)``. In particular a plain string
    that merely *parses* as a literal (``"0.50"``, ``"[1,2]"``, ``"null"``) is
    returned unchanged rather than re-serialised, so numeric answers keep their
    original formatting.

    Args:
        value: Raw cell value (string, list, dict, numpy object, ...).

    Returns:
        The extracted content, or the string form of ``value`` when no
        ``content`` entry can be located.
    """
    import ast
    import re

    def _from_container(obj):
        """Return the first ``content`` entry of a dict / list of dicts, else None."""
        if isinstance(obj, dict):
            return str(obj["content"]) if "content" in obj else None
        if isinstance(obj, list):
            for item in obj:
                if isinstance(item, dict) and "content" in item:
                    return str(item["content"])
        return None

    found = _from_container(value)
    if found is not None:
        return found

    # Duck-typed numpy detection keeps numpy an optional dependency here.
    if "numpy" in str(type(value)):
        try:
            return extract_first_content(value.tolist())
        except (AttributeError, TypeError, ValueError):
            pass  # not convertible; fall back to the generic handling below

    if isinstance(value, str):
        text = value.strip()
        # Unwrap the repr of an object array: "array([...], dtype=object)".
        wrapped = re.match(r"^array\((.*)\)$", text, flags=re.DOTALL)
        if wrapped:
            text = re.sub(r",\s*dtype=object\s*$", "", wrapped.group(1)).strip()

        for parse in (json.loads, ast.literal_eval):
            try:
                parsed = parse(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, str):
                # A quoted string literal: unwrap one layer and look again.
                return extract_first_content(parsed)
            found = _from_container(parsed)
            if found is not None:
                return found
            # Scalars and content-less containers fall through on purpose:
            # returning str(parsed) would rewrite e.g. "0.50" as "0.5".

        # Last resort for almost-valid representations.
        loose = re.search(r"[\"']content[\"']\s*:\s*[\"'](.*?)[\"']", text, flags=re.DOTALL)
        if loose:
            return loose.group(1)
    return str(value)


def load_template_text(template_path: str) -> tuple[str, str]:
    """Read the prompt template and fingerprint it.

    The file is decoded as UTF-8; the digest is the SHA-256 of that decoded
    text re-encoded as UTF-8.

    Returns:
        A ``(template_text, sha256_hex)`` pair.
    """
    template_text = Path(template_path).read_text(encoding="utf-8")
    digest = hashlib.sha256(template_text.encode("utf-8"))
    return template_text, digest.hexdigest()


def build_prompt(template_text: str, question: str, solution: str | None, source: str | None) -> str:
    """Fill the classification template with the concrete problem.

    Replaces tokens if present; otherwise appends a minimal problem block.
    """
    filled = template_text
    placeholders_present = all(tok in template_text for tok in ["{{Question Here}}", "{{Solution Here}}", "{{Source Here}}"])
    if placeholders_present:
        filled = filled.replace("{{Question Here}}", question)
        filled = filled.replace("{{Solution Here}}", solution or "")
        filled = filled.replace("{{Source Here}}", source or "")
        return filled

    # Fallback: append a succinct block at the end.
    problem_block = (
        "\n\n<math problem>\n"
        "[Question]:\n" + question + "\n"
        "[Solution]:\n" + (solution or "") + "\n"
        "[Source]:\n" + (source or "") + "\n"
        "</math problem>\n"
    )
    return template_text + problem_block


# -------- retry-wrapped API call --------
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def ask_gpt(
    idx: int,
    input_text: str,
    model: str,
    api_key: str,
    allow_fallbacks: bool,
    provider_order: list[str] | None,
    timeout: float | tuple[float, float] = (10.0, 600.0),
) -> dict:
    """Send one prompt to OpenRouter's chat completions endpoint.

    The call is wrapped by tenacity: any exception raised here triggers a
    retry with randomised exponential back-off (1-60 s), for at most six
    attempts.

    Args:
        idx: Caller-side row identifier, echoed back in the result.
        input_text: Prompt sent as a single ``user`` message.
        model: OpenRouter model name.
        api_key: Bearer token for the ``Authorization`` header.
        allow_fallbacks: When true, sets ``allow_fallbacks`` in the payload.
        provider_order: Optional provider preference list, sent as
            ``provider.order`` together with ``provider.allow_fallbacks``.
        timeout: ``requests`` timeout, either one number or a
            ``(connect, read)`` pair in seconds. Without it a stalled
            connection would block a worker thread forever and never be
            retried.

    Returns:
        Dict with keys ``idx``, ``response``, ``response_id``, ``created``,
        ``finish_reason`` and ``usage``.

    Raises:
        RuntimeError: On a non-200 status, or a 200 body that carries an
            ``error`` entry or no ``choices``.
        requests.RequestException: On transport failures and timeouts.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload: dict = {
        "model": model,
        "messages": [
            {"role": "user", "content": input_text}
        ],
    }
    if allow_fallbacks:
        payload["allow_fallbacks"] = True
    if provider_order:
        payload["provider"] = {"order": provider_order, "allow_fallbacks": bool(allow_fallbacks)}

    resp = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    if resp.status_code != 200:
        try:
            err = resp.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML gateway page); report it raw.
            err = {"text": resp.text}
        raise RuntimeError(f"OpenRouter error {resp.status_code}: {err}")

    data = resp.json()
    # Defensive: a 200 body with an error entry or without choices carries no
    # answer. Raising lets the retry policy run and keeps the caller from
    # recording an empty response as a completed row.
    if data.get("error"):
        raise RuntimeError(f"OpenRouter error in 200 response: {data['error']}")
    choices = data.get("choices")
    if not choices:
        raise RuntimeError(f"OpenRouter response has no choices: {data}")

    choice = choices[0]
    message = choice.get("message") or {}
    return {
        "idx": idx,
        "response": message.get("content") or "",
        "response_id": data.get("id"),
        "created": data.get("created"),
        "finish_reason": choice.get("finish_reason"),
        "usage": data.get("usage") or None,
    }


# -------- main workflow --------
def _is_missing(value) -> bool:
    """Return True for None / NaN / NA scalars; never for lists, dicts or arrays."""
    if value is None:
        return True
    return bool(pd.api.types.is_scalar(value) and pd.isna(value))


def _load_done_indices(out_path: Path) -> set[int]:
    """Collect the ``idx`` values already recorded in an existing JSONL output file."""
    done: set[int] = set()
    # The file is written as UTF-8, so read it as UTF-8 regardless of locale.
    # Undecodable bytes are replaced so one corrupt line cannot abort a resume.
    with out_path.open(encoding="utf-8", errors="replace") as f:
        for line in f:
            try:
                done.add(int(json.loads(line)["idx"]))
            except (ValueError, KeyError, TypeError, OverflowError):
                # Blank, truncated or foreign line: ignore it.
                continue
    return done


def main():
    """Command-line entry point.

    Steps:
      1. Parse arguments and read the API key from ``OPENROUTER_API_KEY``
         (or ``OPENAI_API_KEY``).
      2. Load the Parquet file; with ``--test N`` keep a seeded random sample.
      3. Resolve the output JSONL path and collect the ``idx`` values already
         present there, so an interrupted run can be resumed.
      4. Build one prompt per remaining row from the template.
      5. Send the prompts through a thread pool and append one JSON line per
         successful response, flushing after each line.

    Exits with an error message when no API key is set or the dataset has no
    ``prompt_content`` column, and with status 1 when any request still failed
    after retries (re-running the same command retries only those rows).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", default="./data/dataset.parquet", help="input .parquet file")
    ap.add_argument("--jsonl", help="output .jsonl file (append-safe); if omitted, a filename is derived from model/seed/test")
    ap.add_argument("--workers", type=int, default=20, help="#parallel requests")
    ap.add_argument("--model", default="openai/gpt-5-mini", help="OpenRouter model name (e.g., openai/gpt-5-mini or openai/gpt-5)")
    ap.add_argument("--test", type=int, help="sanity check: only process N randomly sampled questions")
    ap.add_argument("--seed", type=int, default=123, help="random seed for test sampling")
    ap.add_argument("--allow_fallbacks", action="store_true", help="allow model/provider fallbacks on OpenRouter")
    ap.add_argument("--provider", action="append", help="provider name to prefer in order (may repeat)")
    ap.add_argument("--template", default="./prompts/prompt_domain_classification.txt", help="path to classification prompt template")
    args = ap.parse_args()

    api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        sys.exit("ERROR: set OPENROUTER_API_KEY (preferred) or OPENAI_API_KEY env-var first")

    # Load dataset
    df = pd.read_parquet(args.parquet)
    if "prompt_content" not in df.columns:
        # Fail before any paid request is made with a meaningless prompt.
        sys.exit(f"ERROR: {args.parquet} has no 'prompt_content' column")
    if args.test:
        n = min(int(args.test), len(df))
        df = df.sample(n=n, random_state=args.seed)
        print(f"TEST MODE: Sampling {len(df)} rows with seed {args.seed}")

    total = len(df)

    # Output path and resume support
    if args.jsonl:
        out_path = Path(args.jsonl)
    else:
        # Derive an output filename
        model_safe = args.model.replace('/', '_')
        test_tag = f"test{args.test}" if args.test else "all"
        out_path = Path("./outputs") / f"classify_{model_safe}_{test_tag}_seed_{args.seed}.jsonl"
    out_path.parent.mkdir(parents=True, exist_ok=True)

    done: set[int] = set()
    if out_path.exists():
        done = _load_done_indices(out_path)
        print(f"Resuming – {len(done)}/{total} already completed")

    # Load and hash the template once
    template_text, template_sha = load_template_text(args.template)

    # Prepare tasks: (row index, fully rendered prompt)
    tasks: list[tuple[int, str]] = []
    for i, row in df.iterrows():
        idx = int(i)
        if idx in done:
            continue
        q_raw = row["prompt_content"]
        if _is_missing(q_raw):
            # A missing question would otherwise be sent as the text "None"/"nan".
            print(f"[{idx}] SKIP: empty prompt_content", file=sys.stderr)
            continue
        s_raw = row.get("solution")
        src_raw = row.get("data_source")

        question = extract_first_content(q_raw)
        solution = None if _is_missing(s_raw) else extract_first_content(s_raw)
        source = None if _is_missing(src_raw) else str(src_raw)
        tasks.append((idx, build_prompt(template_text, question, solution, source)))

    # Execute
    failed = 0
    with ThreadPoolExecutor(max_workers=args.workers) as pool, out_path.open("a", encoding="utf-8") as fout:
        futures = {
            pool.submit(
                ask_gpt,
                idx,
                input_text,
                args.model,
                api_key,
                bool(args.allow_fallbacks),
                args.provider,
            ): idx
            for idx, input_text in tasks
        }

        for fut in as_completed(futures):
            idx = futures[fut]
            try:
                resp = fut.result()
                # Persist a compact record
                record = {
                    "idx": idx,
                    "model": args.model,
                    "template_sha256": template_sha,
                    "response": resp.get("response"),
                    "response_id": resp.get("response_id"),
                    "created": resp.get("created"),
                    "finish_reason": resp.get("finish_reason"),
                    "usage": resp.get("usage"),
                }
                fout.write(json.dumps(record, ensure_ascii=False) + "\n")
                # Flush per record so an interrupted run loses no finished rows.
                fout.flush()
                done.add(idx)
                print(f"[{len(done)}/{total}] idx={idx} ✓")
            except Exception as e:
                failed += 1
                cause = e
                # Surface the underlying error rather than tenacity's wrapper.
                if isinstance(e, RetryError) and getattr(e, "last_attempt", None):
                    cause = e.last_attempt.exception()
                print(f"[{idx}] ERROR: {cause}", file=sys.stderr)

    if failed:
        print(
            f"Finished with {failed} failed request(s); re-run the same command to retry them.",
            file=sys.stderr,
        )
        sys.exit(1)
    print("All done!")


# Script entry point: run the CLI and pass main()'s return value on as the exit status.
if __name__ == "__main__":
    sys.exit(main())


