"""
Evaluate a bug-fixer model served behind an OpenAI-compatible endpoint on LiveCodeBench-style bug fixing.

This is the LiveCodeBench-analogue of `examples/bugs/eval_bugbench.py`:
- Supports sharding, resuming, JSONL output
- Runs N-way parallel bug-fix requests
- Scores fixes with `RewardCodeFn` using the LCB runner (`data_source="livecodebench"`)
- Optionally also reports **codegen pass@1** (generate a fresh solution from the prompt only)

Note: We intentionally avoid importing the torch-dependent AgentWorkflowEngine stack here.

Default dataset: `anonymous/lcb_bugbench_weak`, split `test` (override with `--hf_dataset` / `--hf_split`).
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from typing import Any, Optional

from datasets import load_dataset

from examples.bugs.prompts import _build_bug_fixer_prompt, _build_code_generation_prompt
from rllm.engine import OpenAIEngine
from rllm.rewards.code_reward import RewardCodeFn, extract_code_from_model
from rllm.rewards.reward_types import RewardConfig


def _resolve_api_key(base_url: str, explicit_api_key: Optional[str]) -> str:
    """Pick the API key to send to an OpenAI-compatible endpoint.

    The explicit key wins; otherwise ``$OPENAI_API_KEY`` is used. Surrounding whitespace is
    stripped. Self-hosted endpoints usually ignore the key, so an empty result becomes the
    placeholder ``"EMPTY"``.

    Raises:
        ValueError: if ``base_url`` points at api.openai.com and no non-empty key is available.
    """
    candidate = explicit_api_key or os.getenv("OPENAI_API_KEY", "") or ""
    key = str(candidate).strip()
    if not key and "api.openai.com" in str(base_url):
        # The message names this script's actual CLI flags (--fix_model_url / --fix_api_key).
        raise ValueError(
            "--fix_model_url points to the OpenAI API, but OPENAI_API_KEY is missing/empty. "
            "Please export OPENAI_API_KEY or pass --fix_api_key explicitly."
        )
    return key or "EMPTY"


def _to_fenced_python(code_or_text: str) -> str:
    """Wrap ``code_or_text`` in a ```python fence unless it is already fenced.

    Leading/trailing newlines are trimmed first. ``None`` and ``""`` both yield an empty fenced block.
    """
    body = code_or_text.strip("\n") if code_or_text else ""
    already_fenced = body.startswith("```")
    return body if already_fenced else "```python\n" + body + "\n```"


def _extract_code_or_fallback(text: str) -> str:
    """Return the code extracted from a model response, or the whole text if none is found.

    Uses ``extract_code_from_model``; when it returns ``None`` the raw text is used instead.
    The result is whitespace-stripped. ``None`` input yields ``""``.
    """
    if text is None:
        return ""
    as_str = str(text)
    extracted = extract_code_from_model(as_str)
    chosen = as_str if extracted is None else extracted
    return chosen.strip()


def _maybe_build_full_code(code_prompt: str, solution_or_body: str) -> str:
    """Prepend separately-stored starter code to a body-only buggy snippet.

    A snippet is treated as a complete program, and returned unchanged, if any of these hold:
    - it starts (case-insensitively) with ``import``/``from``;
    - it contains a code fence, ``def ``, ``class `` or ``__name__``.
    Otherwise the starter code (if any) is joined in front of it with a newline. Both inputs
    have leading/trailing newlines trimmed; ``None`` is treated as ``""``.
    """
    starter = (code_prompt or "").strip("\n")
    snippet = (solution_or_body or "").strip("\n")
    head = snippet.lstrip().lower()
    looks_complete = head.startswith(("import ", "from ")) or any(
        marker in snippet for marker in ("```", "def ", "class ", "__name__")
    )
    if looks_complete or not starter:
        return snippet
    return f"{starter}\n{snippet}"


def _prepare_ground_truth_for_reward(ground_truth: Any) -> tuple[Optional[Any], Optional[str]]:
    """
    Return (tests_for_reward, error_message). If error_message is not None, ground_truth is unusable.

    Accepted inputs:
      - a list/dict: returned unchanged;
      - a JSON string that decodes to a list/dict. One extra level of string encoding is
        tolerated, e.g. ``json.dumps(json.dumps(tests))``.
    Every other input yields ``(None, <reason>)``, so callers can skip the row instead of
    handing a non-container to the reward function. This covers None, empty or whitespace
    strings, malformed JSON, JSON scalars such as ``null`` or ``42``, and other Python types.
    We return parsed objects (list/dict) to avoid downstream json.loads failures.
    """
    if ground_truth is None:
        return None, "ground_truth is None"
    if isinstance(ground_truth, (list, dict)):
        return ground_truth, None
    if not isinstance(ground_truth, str):
        return None, f"ground_truth has unsupported type: {type(ground_truth)}"

    s = ground_truth.strip()
    if not s:
        return None, "ground_truth is empty string"
    try:
        parsed = json.loads(s)
        if isinstance(parsed, str):
            # Double-encoded JSON: the outer document is a string holding the real payload.
            parsed = json.loads(parsed)
    except (ValueError, RecursionError) as e:
        # json.JSONDecodeError subclasses ValueError; RecursionError covers pathologically deep nesting.
        return None, f"ground_truth is not valid JSON string: {type(e).__name__}: {e}"
    if not isinstance(parsed, (list, dict)):
        # e.g. "null" -> None or "42" -> int: not a test specification.
        return None, f"ground_truth JSON is not a list/dict (got {type(parsed).__name__})"
    return parsed, None


def _to_task(example: dict[str, Any], idx: int, bug_field: str, skip_bad_ground_truth: bool) -> dict[str, Any]:
    """Normalize one dataset row into the task dict consumed by `_solve_one`.

    Two row schemas are accepted:
      - original LCB bugbench: uid / question / ground_truth / buggy_solution*
      - BugBench-style: task_id / instruct_prompt / complete_prompt / test / buggy / code_prompt

    Returns a full task (``skipped=False``) when the ground truth is usable. Otherwise it
    returns a minimal ``skipped=True`` stub if ``skip_bad_ground_truth`` is truthy, and
    raises ``ValueError`` if it is not.
    """
    uid = str(example.get("uid") or example.get("task_id") or f"lcb_bugbench_{idx}")

    # The first non-empty prompt text wins, in schema-preference order.
    question = ""
    for prompt_key in ("question", "complete_prompt", "instruct_prompt", "problem"):
        question = str(example.get(prompt_key) or "")
        if question:
            break

    # If a `buggy_solution*` column is absent, fall back to the BugBench-style `buggy` column.
    source_field = str(bug_field)
    buggy_value = example.get(bug_field, None)
    if buggy_value is None and source_field.startswith("buggy_solution") and "buggy" in example:
        buggy_value = example.get("buggy")
        source_field = "buggy"

    buggy_body = _extract_code_or_fallback("" if buggy_value is None else str(buggy_value))
    starter_code = str(example.get("code_prompt") or example.get("starter_code") or "")
    fenced_buggy = _to_fenced_python(_maybe_build_full_code(starter_code, buggy_body))

    gt_source = example.get("ground_truth", None)
    if gt_source is None and "test" in example:
        gt_source = example.get("test", None)
    tests, gt_error = _prepare_ground_truth_for_reward(gt_source)

    if gt_error is None:
        return {
            "uid": uid,
            "index": int(idx),
            "data_source": "livecodebench",
            "question": question,
            "buggy_solution": fenced_buggy,
            "ground_truth": tests,
            "bug_field": source_field,
            "skipped": False,
            "metadata": example.get("metadata", {}) or {},
        }

    if not skip_bad_ground_truth:
        raise ValueError(f"Bad ground_truth for uid={uid}: {gt_error}")
    return {
        "uid": uid,
        "index": int(idx),
        "bug_field": source_field,
        "skipped": True,
        "skip_reason": "bad_ground_truth",
        "ground_truth_error": gt_error,
    }


@dataclass
class EvalRow:
    """One output record: the bug-fix (and optional codegen) result for a single task.

    Rows are serialized with ``asdict`` and written one JSON object per line. Field order
    defines both the positional constructor and the key order of that output, so new fields
    should be appended at the end with a default.
    """

    # Identity of the dataset row: uid, absolute dataset index, and the column the buggy code came from.
    uid: str
    index: int
    bug_field: str
    # Bug-fix outcome; None for skipped rows.
    bugfix_passed: Optional[bool]
    bugfix_reward: Optional[float]
    # Code-generation outcome and fenced solution; left None unless --do_codegen is set (and for skipped rows).
    codegen_passed: Optional[bool] = None
    codegen_reward: Optional[float] = None
    codegen_solution: Optional[str] = None
    # True when the task was skipped (e.g. unusable ground truth); skip_reason says why.
    skipped: bool = False
    skip_reason: Optional[str] = None
    # The reward metadata's "error" on a scored fix, the exception text when all attempts
    # failed, or the ground-truth error for skipped rows.
    error: Optional[str] = None
    # Fenced buggy code shown to the model, and the fenced fix it returned (None if no fix was scored).
    buggy_solution: Optional[str] = None
    fixed_solution: Optional[str] = None


def _read_done_uids_from_jsonl(path: str) -> set[str]:
    """Collect the uids already recorded in a results JSONL file (used by --resume).

    Returns an empty set when ``path`` is empty or does not exist. The following lines are
    ignored rather than aborting the resume:
    - blank lines;
    - lines that are not valid JSON, e.g. a row truncated by a crash;
    - lines whose JSON value is not an object, or whose "uid" is not a non-empty string.
    """
    done: set[str] = set()
    if not path or not os.path.exists(path):
        return done
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError subclasses ValueError.
                continue
            if not isinstance(obj, dict):
                # Guard against e.g. "[1, 2]", which would otherwise fail on .get().
                continue
            uid = obj.get("uid")
            if isinstance(uid, str) and uid:
                done.add(uid)
    return done


async def _generate_and_score(
    *,
    engine: OpenAIEngine,
    reward_fn: RewardCodeFn,
    executor: ThreadPoolExecutor,
    task: dict[str, Any],
    messages: list[dict[str, str]],
    retry_limit: int,
) -> tuple[str, Any]:
    """Query the model, fence the extracted code, and score it with ``reward_fn``.

    Makes up to ``max(1, retry_limit + 1)`` attempts. Attempts are separated by a linear
    back-off of 0.2s * attempt number, and the last exception is re-raised if every attempt
    fails. Returns ``(fenced_solution, reward_output)``.
    """
    attempts = max(1, int(retry_limit) + 1)
    loop = asyncio.get_running_loop()
    for attempt in range(attempts):
        try:
            out = await engine.get_model_response(messages)
            raw = (out.text or out.content or "").strip()
            solution = _to_fenced_python(_extract_code_or_fallback(raw))
            # reward_fn is a blocking call; run it in the thread pool so the event loop keeps serving requests.
            reward_out = await loop.run_in_executor(executor, reward_fn, task, solution)
            return solution, reward_out
        except Exception:
            if attempt + 1 >= attempts:
                raise
            await asyncio.sleep(0.2 * (attempt + 1))
    raise AssertionError("unreachable: the retry loop always returns or raises")


async def _solve_one(
    *,
    engine: OpenAIEngine,
    reward_fn: RewardCodeFn,
    executor: ThreadPoolExecutor,
    sem: asyncio.Semaphore,
    task: dict[str, Any],
    retry_limit: int,
    do_codegen: bool,
) -> EvalRow:
    """Evaluate one task: a bug-fix request, plus an optional prompt-only codegen request.

    Skipped tasks (``task["skipped"]``) return a skipped row without calling the model.
    The bug-fix and codegen phases are retried independently, each via
    `_generate_and_score`. As a result, a codegen failure never discards or re-samples an
    already-scored bug fix, and codegen is still attempted when the bug fix fails.
    A phase that exhausts its retries is recorded as failed (passed=False, reward=0.0), and
    its exception text goes into ``row.error``. The row is always returned; nothing is raised.
    """
    uid = str(task.get("uid", ""))
    idx = int(task.get("index", -1))
    bug_field = str(task.get("bug_field", "buggy_solution"))

    if bool(task.get("skipped", False)):
        return EvalRow(
            uid=uid,
            index=idx,
            bug_field=bug_field,
            bugfix_passed=None,
            bugfix_reward=None,
            codegen_passed=None,
            codegen_reward=None,
            codegen_solution=None,
            skipped=True,
            skip_reason=str(task.get("skip_reason") or "skipped"),
            error=str(task.get("ground_truth_error") or ""),
        )

    question = str(task.get("question", "") or "")
    buggy_solution = str(task.get("buggy_solution", "") or "")
    buggy_code = _extract_code_or_fallback(buggy_solution)
    prompt = _build_bug_fixer_prompt(problem=question, buggy_code=buggy_code)
    messages = [{"role": "user", "content": prompt}]

    async with sem:
        try:
            fixed_solution, reward_out = await _generate_and_score(
                engine=engine,
                reward_fn=reward_fn,
                executor=executor,
                task=task,
                messages=messages,
                retry_limit=retry_limit,
            )
        except Exception as e:
            row = EvalRow(
                uid=uid,
                index=idx,
                bug_field=bug_field,
                bugfix_passed=False,
                bugfix_reward=0.0,
                skipped=False,
                error=f"{type(e).__name__}: {e}",
                buggy_solution=buggy_solution,
                fixed_solution=None,
            )
        else:
            metadata = reward_out.metadata
            row = EvalRow(
                uid=uid,
                index=idx,
                bug_field=bug_field,
                bugfix_passed=bool(reward_out.is_correct),
                bugfix_reward=float(reward_out.reward),
                skipped=False,
                skip_reason=None,
                error=metadata.get("error") if isinstance(metadata, dict) else None,
                buggy_solution=buggy_solution,
                fixed_solution=fixed_solution,
            )

        if bool(do_codegen):
            try:
                # Code generation from prompt only (no buggy code).
                codegen_prompt = _build_code_generation_prompt(task)
                codegen_prompt += "\n\nReturn only the full Python solution inside a single ```python``` block."
                codegen_messages = [{"role": "user", "content": codegen_prompt}]
                codegen_solution, codegen_reward_out = await _generate_and_score(
                    engine=engine,
                    reward_fn=reward_fn,
                    executor=executor,
                    task=task,
                    messages=codegen_messages,
                    retry_limit=retry_limit,
                )
            except Exception as e:
                row.codegen_passed = False
                row.codegen_reward = 0.0
                codegen_err = f"codegen {type(e).__name__}: {e}"
                # Keep any bug-fix error and append the codegen one.
                row.error = f"{row.error}; {codegen_err}" if row.error else codegen_err
            else:
                row.codegen_passed = bool(codegen_reward_out.is_correct)
                row.codegen_reward = float(codegen_reward_out.reward)
                row.codegen_solution = codegen_solution

    return row


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser: dataset slice, sharding, output/resume, and model endpoint options."""
    p = argparse.ArgumentParser()
    p.add_argument("--hf_dataset", type=str, default="anonymous/lcb_bugbench_weak")
    p.add_argument("--hf_split", type=str, default="test")
    p.add_argument(
        "--bug_field",
        type=str,
        default="buggy",
        help="Which column to treat as the buggy input code. "
        "For BugBench-style formatted datasets use 'buggy'. "
        "For original lcb_bugbench you can use buggy_solution / buggy_solution_weak / buggy_solution_strong.",
    )
    p.add_argument("--start", type=int, default=0)
    p.add_argument("--max_examples", type=int, default=5000)
    p.add_argument("--output_jsonl", type=str, default=None)
    p.add_argument("--resume", action="store_true")
    # BooleanOptionalAction keeps `--skip_bad_ground_truth` valid and adds
    # `--no-skip_bad_ground_truth` to fail fast on bad rows; the default stays True.
    p.add_argument(
        "--skip_bad_ground_truth",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Skip rows with missing/invalid ground_truth instead of raising.",
    )
    p.add_argument("--n_parallel", type=int, default=64, help="Max concurrent model requests")
    p.add_argument("--retry_limit", type=int, default=2)
    p.add_argument("--do_codegen", action="store_true", help="Also compute codegen pass@1 (generate from prompt only)")
    p.add_argument("--num_shards", type=int, default=1)
    p.add_argument("--shard_id", type=int, default=0)

    # Bug-fixer model endpoint
    p.add_argument("--fix_model", type=str, required=True)
    p.add_argument("--fix_model_url", type=str, default="http://localhost:30000/v1")
    p.add_argument("--fix_api_key", type=str, default=None)
    p.add_argument("--fix_temperature", type=float, default=0.6)
    p.add_argument("--fix_top_p", type=float, default=0.95)
    p.add_argument("--fix_max_tokens", type=int, default=2048)
    return p


def _select_tasks(
    hf: Any,
    *,
    start: int,
    end: int,
    num_shards: int,
    shard_id: int,
    bug_field: str,
    skip_bad_ground_truth: bool,
    done_uids: set[str],
) -> list[dict[str, Any]]:
    """Build tasks for dataset rows [start, end) owned by this shard, minus uids already in ``done_uids``."""
    tasks: list[dict[str, Any]] = []
    for i in range(start, end):
        if i % num_shards != shard_id:
            continue
        task = _to_task(dict(hf[i]), i, bug_field, skip_bad_ground_truth)
        if str(task.get("uid")) in done_uids:
            continue
        tasks.append(task)
    return tasks


async def main_async() -> None:
    """Run the LiveCodeBench bug-fix evaluation from the command line.

    Steps:
    1. Parse the CLI arguments.
    2. Select this shard's rows of ``[--start, --start + --max_examples)``. With --resume,
       rows whose uid is already present in --output_jsonl are dropped.
    3. Evaluate the rows with up to --n_parallel worker coroutines.
    4. Append one JSON object per row to --output_jsonl, flushed per row so --resume
       survives crashes.
    5. Print bug-fix (and optionally codegen) pass@1.
    The output file is opened (truncated unless resuming) only when at least one task is selected.

    Raises:
        ValueError: on invalid --num_shards / --shard_id / --n_parallel / --start, or on a
            row with unusable ground truth when --no-skip_bad_ground_truth is given.
    """
    args = _build_arg_parser().parse_args()

    if args.num_shards <= 0:
        raise ValueError("--num_shards must be >= 1")
    if args.shard_id < 0 or args.shard_id >= args.num_shards:
        raise ValueError("--shard_id must be in [0, num_shards)")
    if args.n_parallel <= 0:
        # asyncio.Semaphore(0) would make every worker wait forever.
        raise ValueError("--n_parallel must be >= 1")
    if args.start < 0:
        # A negative start would silently index rows from the end of the dataset.
        raise ValueError("--start must be >= 0")

    hf = load_dataset(args.hf_dataset, split=args.hf_split)
    end = min(len(hf), args.start + args.max_examples)

    fix_engine = OpenAIEngine(
        model=args.fix_model,
        tokenizer=None,
        base_url=args.fix_model_url,
        api_key=_resolve_api_key(args.fix_model_url, args.fix_api_key),
        sampling_params={"temperature": args.fix_temperature, "top_p": args.fix_top_p, "max_tokens": args.fix_max_tokens},
        verbose=False,
    )

    done_uids: set[str] = set()
    if args.resume and args.output_jsonl:
        done_uids = _read_done_uids_from_jsonl(args.output_jsonl)

    tasks = _select_tasks(
        hf,
        start=args.start,
        end=end,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        bug_field=str(args.bug_field),
        skip_bad_ground_truth=bool(args.skip_bad_ground_truth),
        done_uids=done_uids,
    )
    if not tasks:
        print("No examples selected (check --start/--max_examples, --num_shards/--shard_id, and/or --resume).")
        return

    reward_fn = RewardCodeFn(RewardConfig())
    sem = asyncio.Semaphore(args.n_parallel)
    executor = ThreadPoolExecutor(max_workers=min(args.n_parallel, 64))

    total_eval = 0
    ok = 0
    codegen_total = 0
    codegen_ok = 0
    skipped = 0

    out_f = None
    try:
        # Opened inside the try so it is closed on every exit path.
        if args.output_jsonl:
            os.makedirs(os.path.dirname(args.output_jsonl) or ".", exist_ok=True)
            mode = "a" if args.resume and os.path.exists(args.output_jsonl) else "w"
            out_f = open(args.output_jsonl, mode, encoding="utf-8")

        q: asyncio.Queue[Optional[dict[str, Any]]] = asyncio.Queue()
        for t in tasks:
            q.put_nowait(t)

        n_workers = min(args.n_parallel, len(tasks))
        for _ in range(n_workers):
            q.put_nowait(None)  # one sentinel per worker

        lock = asyncio.Lock()

        async def worker() -> None:
            nonlocal total_eval, ok, skipped, codegen_total, codegen_ok
            while True:
                t = await q.get()
                if t is None:
                    return
                row = await _solve_one(
                    engine=fix_engine,
                    reward_fn=reward_fn,
                    executor=executor,
                    sem=sem,
                    task=t,
                    retry_limit=int(args.retry_limit),
                    do_codegen=bool(args.do_codegen),
                )
                async with lock:
                    if row.skipped:
                        skipped += 1
                    else:
                        total_eval += 1
                        ok += int(bool(row.bugfix_passed))
                        if bool(args.do_codegen):
                            codegen_total += 1
                            codegen_ok += int(bool(row.codegen_passed))

                    if out_f is not None:
                        out_f.write(json.dumps(asdict(row)) + "\n")
                        # Flush per row so --resume sees every finished row after a crash.
                        out_f.flush()

                    done_n = total_eval + skipped
                    if done_n % 25 == 0:
                        rate = (ok / total_eval) if total_eval else 0.0
                        if bool(args.do_codegen):
                            c_rate = (codegen_ok / codegen_total) if codegen_total else 0.0
                            print(
                                f"[{done_n} processed] evaluated={total_eval} skipped={skipped} "
                                f"bugfix_pass@1={rate:.3f} codegen_pass@1={c_rate:.3f}"
                            )
                        else:
                            print(f"[{done_n} processed] evaluated={total_eval} skipped={skipped} bugfix_pass@1={rate:.3f}")

        # return_exceptions=True lets the surviving workers drain the queue and finish their
        # writes before a worker's exception is re-raised and the output file is closed.
        results = await asyncio.gather(*(worker() for _ in range(n_workers)), return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                raise result
    finally:
        if out_f is not None:
            out_f.close()
        executor.shutdown(wait=False, cancel_futures=True)

    done_n = total_eval + skipped
    print("\n" + "=" * 80)
    print(
        f"LCB BugBench hf_dataset={args.hf_dataset} hf_split={args.hf_split} bug_field={args.bug_field} "
        f"slice=[{args.start}:{end}) shard={args.shard_id}/{args.num_shards} processed={done_n} evaluated={total_eval} skipped={skipped}"
    )
    print(f"bugfix pass@1: {ok/max(total_eval,1):.6f} ({ok}/{total_eval})")
    if bool(args.do_codegen):
        print(f"codegen pass@1: {codegen_ok/max(codegen_total,1):.6f} ({codegen_ok}/{codegen_total})")
    print("=" * 80)


def main() -> None:
    """Console entry point: drive `main_async` to completion on a fresh event loop."""
    coroutine = main_async()
    asyncio.run(coroutine)


if __name__ == "__main__":
    main()


