"""
Evaluate a bug-fixer model served behind an OpenAI-compatible vLLM endpoint on BugBench.

For each example:
  - Build a BugBench-style task (question + buggy_solution + test + entry_point)
  - Prompt the model with the standard bug-fixer prompt (`examples/bugs/prompts.py`)
  - Score the model output with `RewardCodeFn` using BigCodeBench/BugBench runner

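Reports bug-fix pass@1 over the evaluated slice; with `--do_codegen`, also reports
BigCodeBench codegen pass@1 for the rows whose task_id matches a BigCodeBench task.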
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import re
from dataclasses import asdict, dataclass
from typing import Any, Optional

from datasets import load_dataset

from examples.bugs.prompts import _build_bug_fixer_prompt, _build_code_generation_prompt
from rllm.data.utils import fetch_live_code_bench_system_prompt
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.agents.agent import Episode, Step, Trajectory
from rllm.rewards.code_reward import RewardCodeFn
from rllm.rewards.reward_types import RewardConfig
from rllm.workflows.workflow import TerminationReason, Workflow


def _resolve_api_key(base_url: str, explicit_api_key: Optional[str]) -> str:
    api_key = (explicit_api_key or os.getenv("OPENAI_API_KEY", "")) or ""
    api_key = str(api_key).strip()
    is_openai_api = "api.openai.com" in str(base_url)
    if is_openai_api and not api_key:
        raise ValueError(
            "model_url points to the OpenAI API, but OPENAI_API_KEY is missing/empty. "
            "Please export OPENAI_API_KEY or pass --api_key explicitly."
        )
    return api_key or "EMPTY"


def _to_fenced_python(code_or_text: str) -> str:
    s = (code_or_text or "").strip("\n")
    if s.startswith("```"):
        return s
    return f"```python\n{s}\n```"


def _truncate_instruct_prompt(instruct_prompt: str) -> str:
    pattern = r"You should write self-contained code starting with\s*:?\s*"
    truncated = re.split(pattern, instruct_prompt or "", maxsplit=1)[0]
    return truncated.rstrip()


def _maybe_build_full_code(code_prompt: str, solution_or_body: str) -> str:
    cp = (code_prompt or "").strip("\n")
    sol = (solution_or_body or "").strip("\n")
    sol_lower = sol.lstrip().lower()
    if "```" in sol or sol_lower.startswith("import ") or sol_lower.startswith("from ") or "def " in sol:
        return sol
    if cp:
        return cp + "\n" + sol
    return sol


def _to_task(example: dict[str, Any], idx: int) -> Optional[dict[str, Any]]:
    """
    Adapt a raw HF BugBench row into the task schema expected by RewardCodeFn + prompts.

    Args:
        example: One dataset row. The keys read are "buggy", "instruct_prompt",
            "code_prompt", "task_id", "entry_point" and "test".
        idx: Position of the row in the dataset; stored as "index" and used to
            synthesize a uid when the row has no usable "task_id".

    Returns:
        The task dict, or None when the row has no non-empty string under "buggy"
        (common in partially-processed datasets).
    """
    # Buggy code is expected under the column name "buggy". Check it first so that
    # rows which will be skipped do not pay for prompt construction.
    buggy_body = example.get("buggy")
    if not (isinstance(buggy_body, str) and buggy_body.strip()):
        return None

    instruct_prompt = example.get("instruct_prompt", "") or ""
    code_prompt = example.get("code_prompt", "") or ""
    question = fetch_live_code_bench_system_prompt(_truncate_instruct_prompt(instruct_prompt), code_prompt)

    buggy_solution = _to_fenced_python(_maybe_build_full_code(code_prompt, buggy_body))

    # `dict.get(key, default)` only uses the default when the key is absent; a row whose
    # "task_id" column is present but None/blank would otherwise yield the uid "None" or "",
    # which collides across rows and breaks resume bookkeeping.
    task_id = example.get("task_id")
    if task_id is None or not str(task_id).strip():
        task_id = f"bugbench_{idx}"
    entry_point = example.get("entry_point")

    return {
        "uid": str(task_id),
        "index": int(idx),
        "data_source": "bugbench",
        "question": question,
        "buggy": buggy_body,
        "buggy_solution": buggy_solution,
        "ground_truth": example.get("test", ""),
        "entry_point": entry_point,
        "starter_code": code_prompt,
        "instruct_prompt": instruct_prompt,
        "code_prompt": code_prompt,
        "metadata": {"func_name": entry_point} if entry_point is not None else {},
    }


@dataclass
class EvalRow:
    """One per-example result record; serialized with `asdict` as a single JSONL line."""

    # Identity of the BugBench example.
    uid: str
    index: int
    # Bug-fix metrics (BugBench).
    bugfix_passed: bool
    bugfix_reward: float
    # Codegen metrics (BigCodeBench, matched by task_id). These stay None when
    # codegen was not evaluated for this example.
    codegen_uid: Optional[str] = None
    codegen_passed: Optional[bool] = None
    codegen_reward: Optional[float] = None
    # Whether codegen was requested but skipped for this example, and why.
    codegen_skipped: bool = False
    codegen_skip_reason: Optional[str] = None
    # Error message associated with the bug-fix scoring, if any.
    error: Optional[str] = None
    # Text artifacts kept for inspection: the buggy input and the model outputs.
    buggy_solution: Optional[str] = None
    fixed_solution: Optional[str] = None
    codegen_solution: Optional[str] = None


def _read_done_uids_from_jsonl(path: str) -> set[str]:
    done: set[str] = set()
    if not path or not os.path.exists(path):
        return done
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = (line or "").strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except Exception:
                continue
            uid = obj.get("uid", None)
            if isinstance(uid, str) and uid:
                done.add(uid)
    return done


class BugBenchEvalWorkflow(Workflow):
    """
    Parallelizable workflow evaluation for:
      - BugBench bug fixing pass@1
      - (optional) BigCodeBench codegen pass@1 (matched by task_id)

    Each input task is independent, so it can be massively parallelized via:
      - in-process concurrency (`--n_parallel`)
      - sharding across jobs (`--num_shards/--shard_id`)
      - resuming from JSONL (`--resume`)
    """

    def __init__(
        self,
        rollout_engine,
        executor,
        reward_config: Optional[RewardConfig] = None,
        system_prompt: Optional[str] = None,
        do_codegen: bool = True,
        debug: bool = False,
        **kwargs,
    ):
        """
        Args:
            rollout_engine: Engine used to query the model; forwarded to `Workflow`.
            executor: Executor forwarded to `Workflow`.
            reward_config: Configuration for `RewardCodeFn`; a default `RewardConfig()` is used when omitted.
            system_prompt: Optional system message prepended to every model request.
            do_codegen: Whether to also run the codegen evaluation when a matching task is supplied.
            debug: When set, print the bug-fix prompt and stop in the debugger before querying the model.
            **kwargs: Forwarded to `Workflow.__init__`.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_fn = RewardCodeFn(reward_config or RewardConfig())
        self.system_prompt = system_prompt
        self.do_codegen = bool(do_codegen)
        self.debug = bool(debug)

    @staticmethod
    def _split_think(raw_response: str) -> tuple[str, str]:
        """
        Split a raw model response into (thought, action) around "</think>".

        The split only happens when the closing tag occurs exactly once; the thought
        then includes the tag itself. In every other case the thought is empty and
        the whole stripped response is the action.
        """
        if (raw_response or "").count("</think>") == 1:
            thought, sep, action = (raw_response or "").partition("</think>")
            return (thought + sep).strip(), action.strip()
        return "", (raw_response or "").strip()

    async def _attempt(
        self,
        name: str,
        scoring_task: dict,
        user_prompt: str,
        extra_info: Optional[dict[str, Any]] = None,
    ) -> tuple[Trajectory, Any, str]:
        """
        Run one single-turn attempt: query the model with `user_prompt`, score the answer
        against `scoring_task`, and wrap the result in a one-step trajectory called `name`.

        `extra_info` entries are placed in the step's `info` ahead of "reward_metadata".

        Returns (trajectory, reward output of `self.reward_fn`, solution text without the thought).
        """
        messages: list[dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        model_out = await self.rollout_engine.get_model_response(messages)
        raw = (model_out.text or model_out.content or "").strip()
        thought, solution = self._split_think(raw)

        # NOTE(review): presumably `run_in_executor` keeps the (blocking) test run off the
        # event loop — confirm against `Workflow.run_in_executor`.
        reward_out = await self.run_in_executor(self.reward_fn, scoring_task, solution)

        step = Step(
            chat_completions=messages + [{"role": "assistant", "content": solution}],
            thought=thought,
            action=solution,
            model_response=raw,
            model_output=model_out,
            reward=float(reward_out.reward),
            info={**(extra_info or {}), "reward_metadata": reward_out.metadata or {}},
        )
        trajectory = Trajectory(name=name, steps=[step], reward=float(step.reward))
        return trajectory, reward_out, solution

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """
        Evaluate one example: always the bug-fix attempt, plus the codegen attempt when
        codegen is enabled and the task carries a matched "bcb_task".

        Args:
            task: Dict with "bugbench_task" (required) and optional "bcb_task" /
                "codegen_skip_reason" entries.
            uid: Identifier used as the episode id.

        Returns:
            An `Episode` holding the trajectories, per-example metrics and the text artifacts.

        Raises:
            ValueError: If the BugBench task has no non-empty "buggy_solution".
        """
        self.reset(task, uid)

        bugbench_task = task["bugbench_task"]
        bcb_task = task.get("bcb_task", None)
        codegen_skip_reason = task.get("codegen_skip_reason", None)

        buggy_solution = str(bugbench_task.get("buggy_solution", "") or "")
        if not buggy_solution.strip():
            # Shouldn't happen if `_to_task` filtered correctly, but keep evaluation robust.
            raise ValueError('Task is missing "buggy_solution" (expected to be derived from HF row "buggy")')

        # -------------------------
        # BugBench bug fixing
        # -------------------------
        bugfix_prompt = _build_bug_fixer_prompt(problem=str(bugbench_task.get("question") or ""), buggy_code=buggy_solution)
        if self.debug:
            print(bugfix_prompt)
            breakpoint()
        bugfix_traj, bugfix_reward_out, fixed_solution = await self._attempt(
            "bugfix",
            bugbench_task,
            bugfix_prompt,
            extra_info={"buggy_solution": buggy_solution},
        )
        bugfix_passed = bool(bugfix_reward_out.is_correct)

        # -------------------------
        # BigCodeBench codegen (optional)
        # -------------------------
        codegen_traj: Optional[Trajectory] = None
        codegen_passed: Optional[bool] = None
        codegen_reward: Optional[float] = None
        codegen_solution: Optional[str] = None

        if self.do_codegen and bcb_task is not None:
            codegen_prompt = _build_code_generation_prompt(bcb_task)
            codegen_prompt += "\n\nReturn only the full Python solution inside a single ```python``` block."
            codegen_traj, codegen_reward_out, codegen_solution = await self._attempt("codegen", bcb_task, codegen_prompt)
            codegen_passed = bool(codegen_reward_out.is_correct)
            codegen_reward = float(codegen_reward_out.reward)

        metrics: dict[str, Any] = {
            "bugfix_pass": float(bugfix_passed),
            "bugfix_reward": float(bugfix_reward_out.reward),
            "codegen_enabled": float(self.do_codegen),
            "codegen_matched": float(bcb_task is not None),
        }
        if codegen_skip_reason:
            metrics["codegen_skip_reason"] = str(codegen_skip_reason)
        if codegen_passed is not None:
            metrics["codegen_pass"] = float(codegen_passed)
        if codegen_reward is not None:
            metrics["codegen_reward"] = float(codegen_reward)

        # Only a dict-shaped metadata can carry an "error" entry.
        reward_metadata = bugfix_reward_out.metadata
        bugfix_error = reward_metadata.get("error") if isinstance(reward_metadata, dict) else None

        episode = Episode(
            id=uid,
            task=task,
            termination_reason=TerminationReason.UNKNOWN,
            is_correct=bool(bugfix_passed),
            trajectories=[t for t in [bugfix_traj, codegen_traj] if t is not None],
            metrics=metrics,
            info={
                "bugbench_uid": str(bugbench_task.get("uid", "")),
                "bugbench_index": int(bugbench_task.get("index", -1)),
                "buggy_solution": buggy_solution,
                "fixed_solution": fixed_solution,
                "codegen_solution": codegen_solution,
                "bugfix_error": bugfix_error,
            },
        )
        return episode

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Mark the episode correct iff its "bugfix_pass" metric is truthy (codegen is not considered)."""
        m = episode.metrics or {}
        episode.is_correct = bool(m.get("bugfix_pass", 0.0))


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script."""
    p = argparse.ArgumentParser()
    p.add_argument("--hf_dataset", type=str, default="anonymous/bugbench", help="HF dataset name (default: anonymous/bugbench)")
    p.add_argument("--hf_split", type=str, default="train", help="HF split name (default: train)")
    p.add_argument("--bcb_dataset", type=str, default="anonymous/bigcodebench", help="HF BigCodeBench dataset name (default: anonymous/bigcodebench)")
    p.add_argument("--bcb_split", type=str, default="v0.1.0_hf", help="HF BigCodeBench split name (default: v0.1.0_hf)")
    p.add_argument("--start", type=int, default=0)
    p.add_argument("--max_examples", type=int, default=1200)
    p.add_argument("--output_jsonl", type=str, default=None)
    p.add_argument("--resume", action="store_true", help="If set, skip uids already present in --output_jsonl and append new rows")
    p.add_argument("--n_parallel", type=int, default=64, help="Number of parallel workflow tasks to run in-process")
    p.add_argument("--retry_limit", type=int, default=2, help="Workflow retry limit for transient failures")
    p.add_argument("--do_codegen", action="store_true", help="Also evaluate BigCodeBench codegen pass@1 when task_id matches")
    p.add_argument("--debug", action="store_true", help="Print prompts and drop into a debugger (disables clean parallel runs)")
    p.add_argument("--num_shards", type=int, default=1, help="Total number of shards (for workflow engines / job arrays)")
    p.add_argument("--shard_id", type=int, default=0, help="Shard id in [0, num_shards)")

    # Bug-fixer model endpoint (BugBench)
    p.add_argument("--fix_model", type=str, required=True, help="Bug-fixer model name exposed by the vLLM server")
    p.add_argument("--fix_model_url", type=str, default="http://localhost:30000/v1", help="OpenAI-compatible base URL for bug-fixer vLLM")
    p.add_argument("--fix_api_key", type=str, default=None, help="Bug-fixer API key (defaults to OPENAI_API_KEY)")
    p.add_argument("--fix_temperature", type=float, default=0.6)
    p.add_argument("--fix_top_p", type=float, default=0.95)
    p.add_argument("--fix_max_tokens", type=int, default=2048)
    p.add_argument("--fix_system_prompt", type=str, default=None)
    return p


def _build_shard_tasks(
    hf: Any,
    bcb_by_task_id: dict[str, dict[str, Any]],
    start: int,
    end: int,
    num_shards: int,
    shard_id: int,
    done_uids: set[str],
) -> tuple[list[dict[str, Any]], list[str], int]:
    """
    Select the rows of `hf` in [start, end) that belong to this shard and still need evaluation.

    Returns (workflow tasks, their uids in the same order, number of rows skipped because
    they had no usable "buggy" code).
    """
    tasks: list[dict[str, Any]] = []
    task_ids: list[str] = []
    skipped_empty_buggy = 0

    for i in range(start, end):
        # Round-robin sharding by dataset index.
        if i % num_shards != shard_id:
            continue

        bug_task = _to_task(dict(hf[i]), i)
        if bug_task is None:
            skipped_empty_buggy += 1
            continue
        uid = str(bug_task.get("uid"))
        if uid in done_uids:
            continue

        bcb_task: Optional[dict[str, Any]] = None
        codegen_skip_reason: Optional[str] = None
        bcb_row = bcb_by_task_id.get(uid)
        if bcb_row is None:
            codegen_skip_reason = "no_matching_bigcodebench_task_id"
        else:
            bcb_task = {
                "uid": str(bcb_row.get("task_id", uid)),
                "index": int(bcb_row.get("index", i)),
                "data_source": "bigcodebench",
                "question": bcb_row.get("question", "") or "",
                "instruct_prompt": bcb_row.get("instruct_prompt", "") or "",
                "code_prompt": bcb_row.get("code_prompt", "") or "",
                "starter_code": bcb_row.get("code_prompt", "") or "",
                "ground_truth": bcb_row.get("test", ""),
                "entry_point": bcb_row.get("entry_point", None),
            }

        tasks.append({"bugbench_task": bug_task, "bcb_task": bcb_task, "codegen_skip_reason": codegen_skip_reason})
        task_ids.append(uid)

    return tasks, task_ids, skipped_empty_buggy


def _episode_to_row(ep: Episode, do_codegen: bool) -> EvalRow:
    """Flatten a finished episode into the `EvalRow` record written to the output JSONL."""
    metrics = ep.metrics or {}
    info = ep.info or {}
    task = ep.task or {}
    bugbench_task = task.get("bugbench_task", {}) or {}

    codegen_uid: Optional[str] = None
    codegen_passed: Optional[bool] = None
    codegen_reward: Optional[float] = None
    codegen_skipped = False
    codegen_skip_reason: Optional[str] = None

    if do_codegen:
        if bool(metrics.get("codegen_matched", 0.0)):
            codegen_passed = bool(metrics.get("codegen_pass", 0.0))
            codegen_reward = float(metrics.get("codegen_reward", 0.0))
            bcb_task = task.get("bcb_task", None) or {}
            if bcb_task:
                codegen_uid = str(bcb_task.get("uid"))
        else:
            codegen_skipped = True
            codegen_skip_reason = str(metrics.get("codegen_skip_reason", "no_matching_bigcodebench_task_id"))

    return EvalRow(
        uid=str(bugbench_task.get("uid", info.get("bugbench_uid", ""))),
        index=int(bugbench_task.get("index", info.get("bugbench_index", -1))),
        bugfix_passed=bool(metrics.get("bugfix_pass", 0.0)),
        bugfix_reward=float(metrics.get("bugfix_reward", 0.0)),
        codegen_uid=codegen_uid,
        codegen_passed=codegen_passed,
        codegen_reward=codegen_reward,
        codegen_skipped=bool(codegen_skipped),
        codegen_skip_reason=codegen_skip_reason,
        error=info.get("bugfix_error", None),
        buggy_solution=info.get("buggy_solution", None),
        fixed_solution=info.get("fixed_solution", None),
        codegen_solution=info.get("codegen_solution", None),
    )


async def main_async() -> None:
    """
    Parse the CLI arguments, evaluate the selected BugBench slice/shard through
    `AgentWorkflowEngine`, optionally write one JSONL row per example, and print
    the pass@1 summary.

    Raises:
        ValueError: If --num_shards / --shard_id are inconsistent (or the API key check fails).
    """
    args = _build_arg_parser().parse_args()

    if args.num_shards <= 0:
        raise ValueError("--num_shards must be >= 1")
    if args.shard_id < 0 or args.shard_id >= args.num_shards:
        raise ValueError("--shard_id must be in [0, num_shards)")

    hf = load_dataset(args.hf_dataset, split=args.hf_split)
    end = min(len(hf), args.start + args.max_examples)

    # Load BigCodeBench once and build task_id -> row mapping.
    bcb = load_dataset(args.bcb_dataset, split=args.bcb_split)
    bcb_by_task_id: dict[str, dict[str, Any]] = {}
    for ex in bcb:
        tid = ex.get("task_id", None)
        if tid is None:
            continue
        bcb_by_task_id[str(tid)] = dict(ex)

    fix_engine = OpenAIEngine(
        model=args.fix_model,
        tokenizer=None,
        base_url=args.fix_model_url,
        api_key=_resolve_api_key(args.fix_model_url, args.fix_api_key),
        sampling_params={"temperature": args.fix_temperature, "top_p": args.fix_top_p, "max_tokens": args.fix_max_tokens},
        verbose=False,
    )

    # Read the already-finished uids before (re)opening the file for writing.
    done_uids: set[str] = set()
    if args.resume and args.output_jsonl:
        done_uids = _read_done_uids_from_jsonl(args.output_jsonl)

    out_f = None
    if args.output_jsonl:
        os.makedirs(os.path.dirname(args.output_jsonl) or ".", exist_ok=True)
        mode = "a" if args.resume and os.path.exists(args.output_jsonl) else "w"
        out_f = open(args.output_jsonl, mode, encoding="utf-8")

    total = 0
    bugfix_ok = 0
    codegen_total = 0
    codegen_ok = 0

    # Everything after the file is opened runs under try/finally so the handle is released
    # on the early "nothing to do" return and on any failure while preparing or running tasks.
    try:
        # Build the task list for this shard.
        tasks, task_ids, skipped_empty_buggy = _build_shard_tasks(
            hf,
            bcb_by_task_id,
            int(args.start),
            int(end),
            int(args.num_shards),
            int(args.shard_id),
            done_uids,
        )

        if not tasks:
            print(
                "No examples selected (check --start/--max_examples, --num_shards/--shard_id, --resume, "
                'and whether the dataset has non-empty "buggy" rows).'
            )
            return
        if skipped_empty_buggy:
            print(f"Skipped {skipped_empty_buggy} rows with missing/empty 'buggy' in this shard/slice.")

        engine = AgentWorkflowEngine(
            workflow_cls=BugBenchEvalWorkflow,
            workflow_args={
                "reward_config": RewardConfig(),
                "system_prompt": args.fix_system_prompt,
                "do_codegen": bool(args.do_codegen),
                "debug": bool(args.debug),
            },
            rollout_engine=fix_engine,
            config=None,
            n_parallel_tasks=int(args.n_parallel),
            retry_limit=int(args.retry_limit),
            raise_on_error=False,
        )

        episodes = await engine.execute_tasks(tasks, task_ids=task_ids)

        for ep in episodes:
            row = _episode_to_row(ep, bool(args.do_codegen))

            # `codegen_passed` is only set when codegen actually ran for this example.
            if row.codegen_passed is not None:
                codegen_total += 1
                codegen_ok += int(bool(row.codegen_passed))

            total += 1
            bugfix_ok += int(bool(row.bugfix_passed))

            if out_f is not None:
                out_f.write(json.dumps(asdict(row)) + "\n")

            if total % 25 == 0:
                bugfix_rate = bugfix_ok / total if total else 0.0
                codegen_rate = (codegen_ok / codegen_total) if codegen_total else 0.0
                matched_str = f"(matched={codegen_total})" if args.do_codegen else "(codegen disabled)"
                print(f"[{total} eval'd] bugfix_pass@1={bugfix_rate:.3f}  codegen_pass@1={codegen_rate:.3f} {matched_str}")
    finally:
        if out_f is not None:
            out_f.close()

    print("\n" + "=" * 80)
    print(
        f"BugBench hf_dataset={args.hf_dataset} hf_split={args.hf_split} "
        f"slice=[{args.start}:{end}) shard={args.shard_id}/{args.num_shards} evaluated_examples={total}"
    )
    if total:
        print(f"bugfix pass@1: {bugfix_ok/total:.6f} ({bugfix_ok}/{total})")
    else:
        # Avoid a ZeroDivisionError when the engine returned no episodes.
        print("bugfix pass@1: n/a (no episodes were returned)")
    if args.do_codegen:
        if codegen_total > 0:
            print(f"codegen pass@1 (BigCodeBench, matched by task_id): {codegen_ok/codegen_total:.6f} ({codegen_ok}/{codegen_total})")
        else:
            print("codegen pass@1 (BigCodeBench): n/a (no task_id matches found in this shard/slice)")
    else:
        print("codegen pass@1 (BigCodeBench): disabled (pass --do_codegen to enable)")
    print("=" * 80)


def main() -> None:
    """Synchronous entry point: run `main_async` to completion via `asyncio.run`."""
    evaluation = main_async()
    asyncio.run(evaluation)


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()


