#!/usr/bin/env python3
"""
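Evaluate a base code-generation model on BigCodeBench, then run a bug-fixer model
on every base solution that fails its unit tests, and report pass@1 before and after
the fix (a base solution that already passes also counts as passing after the fix).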

This script uses:
  - `examples.bugs.prompts._build_code_generation_prompt` for base prompting
  - `examples.bugs.prompts._build_bug_fixer_prompt` for bug-fixer prompting
  - `rllm.rewards.code_reward.RewardCodeFn` for unit test execution

It works with any OpenAI-compatible inference server (including OpenAI, vLLM, etc.)
via `rllm.engine.OpenAIEngine`.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from typing import Optional, Any

from examples.bugs.prompts import _build_bug_fixer_prompt, _build_code_generation_prompt
from rllm.data.utils import fetch_live_code_bench_system_prompt
from rllm.engine import OpenAIEngine
from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.code_reward import RewardCodeFn, extract_code_from_model
from rllm.rewards.reward_types import RewardConfig
from rllm.workflows.workflow import TerminationReason, Workflow

from datasets import load_dataset  # type: ignore[import-not-found]


def _resolve_api_key(base_url: str, explicit_api_key: Optional[str]) -> str:
    """Pick the API key for an OpenAI-compatible endpoint.

    The explicit (CLI) key wins when it is non-blank; otherwise ``OPENAI_API_KEY`` from the
    environment is used. Surrounding whitespace is stripped from either source.

    Args:
        base_url: Endpoint base URL; used only to detect the official OpenAI API.
        explicit_api_key: Key passed on the command line, or None.

    Returns:
        The resolved key, or the placeholder ``"EMPTY"`` when no key is available
        (local servers such as vLLM typically ignore it).

    Raises:
        ValueError: if ``base_url`` points at api.openai.com and no key could be resolved.
    """
    # Strip *before* choosing the source, so a blank/whitespace-only CLI value still falls
    # back to the environment instead of masking it.
    api_key = str(explicit_api_key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
    is_openai_api = "api.openai.com" in str(base_url)
    if is_openai_api and not api_key:
        raise ValueError(
            "base_url points to the OpenAI API, but OPENAI_API_KEY is missing/empty. "
            "Please export OPENAI_API_KEY or pass --*_api_key explicitly."
        )
    return api_key or "EMPTY"


def _format_code_block(code: str) -> str:
    """Wrap ``code`` in a fenced Markdown ```python block, dropping trailing whitespace."""
    fence_open, fence_close = "```python", "```"
    return "\n".join((fence_open, code.rstrip(), fence_close))


def _maybe_prepend_starter(code: str, starter_code: Optional[str]) -> str:
    """Prefix ``code`` with ``starter_code`` unless the starter is blank or already present.

    BigCodeBench often expects solutions to start from the provided starter code. A
    non-string or whitespace-only starter leaves ``code`` untouched. So does a starter
    whose stripped text already appears verbatim in ``code``, to avoid duplication.
    """
    if isinstance(starter_code, str):
        stripped_starter = starter_code.strip()
        if stripped_starter and stripped_starter not in code:
            return starter_code.rstrip() + "\n\n" + code.lstrip()
    return code


def _to_task(example: dict, idx: int) -> dict:
    """Adapt a raw HF BigCodeBench row into the minimal task schema used by prompts/reward.

    Args:
        example: One dataset row as a dict. The fields read are ``task_id``, ``instruct_prompt``,
            ``code_prompt``, ``entry_point`` and ``test``; all are optional here.
        idx: Position of the row in the split. It is stored as ``index`` and used for the
            fallback uid.

    Returns:
        A task dict with keys uid, index, data_source, question, ground_truth, entry_point,
        starter_code, code_prompt, instruct_prompt and metadata.
    """
    instruct_prompt = str(example.get("instruct_prompt") or "")
    code_prompt = str(example.get("code_prompt") or "")
    entry_point = example.get("entry_point", None)
    test_code = example.get("test", None)

    # Build the LCB-style 'question' used by bug-fixer prompt in the bugs workflows.
    # This mirrors `examples/bugs/data_processing/prepare_bigcodebench_data.py`.
    question = fetch_live_code_bench_system_prompt(instruct_prompt, code_prompt) if instruct_prompt else ""

    # Fall back to a positional uid when task_id is absent *or* null/empty. Otherwise every such
    # row would get the same uid ("None" or ""), colliding in resume bookkeeping and engine task ids.
    task_id = example.get("task_id")
    uid = str(task_id) if task_id is not None else ""
    if not uid:
        uid = f"bigcodebench_{idx}"

    return {
        "uid": uid,
        "index": int(idx),
        "data_source": "bigcodebench",
        "question": question,
        "ground_truth": test_code,
        "entry_point": entry_point,
        "starter_code": code_prompt,
        "code_prompt": code_prompt,
        "instruct_prompt": instruct_prompt,
        "metadata": {"func_name": entry_point} if entry_point is not None else {},
    }


@dataclass
class EvalRow:
    """Per-task evaluation result, serialized with ``dataclasses.asdict`` into the output JSONL.

    Field declaration order is the key order of the emitted JSON object.
    """

    # Task identity.
    uid: str
    index: int
    # Pass/fail flags and scalar rewards for the base solution and the post-fix solution.
    base_pass: bool
    fixed_pass: bool
    base_reward: float
    fixed_reward: float
    # Optional diagnostics: error messages (reward metadata or engine errors) and the graded code.
    base_error: str | None = None
    fixed_error: str | None = None
    base_code: str | None = None
    fixed_code: str | None = None


async def _generate_one(engine: OpenAIEngine, messages: list[dict], **gen_kwargs) -> str:
    """Query ``engine`` once with ``messages`` and return the stripped completion text.

    ``gen_kwargs`` are forwarded to ``engine.get_model_response``. The response's ``text``
    is preferred (it may include thought delimiters). When that is empty, ``content`` is
    used, and when both are empty the result is ``""``.
    """
    response = await engine.get_model_response(messages, **gen_kwargs)
    completion = response.text
    if not completion:
        completion = response.content or ""
    return completion.strip()


def _read_done_uids_from_jsonl(path: str) -> set[str]:
    """Collect the ``uid`` values already recorded in a results JSONL file (for ``--resume``).

    Args:
        path: Path to the JSONL file. An empty path or a missing file yields an empty set.

    Returns:
        The set of non-empty string ``uid`` fields found in JSON-object lines. Blank lines,
        unparsable lines (e.g. a truncated last line from an interrupted run), and JSON values
        that are not objects are skipped.
    """
    done: set[str] = set()
    if not path or not os.path.exists(path):
        return done
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            # A valid JSON line need not be an object; `.get` would raise on lists/strings/numbers.
            if not isinstance(obj, dict):
                continue
            uid = obj.get("uid")
            if isinstance(uid, str) and uid:
                done.add(uid)
    return done


class BigCodeBenchBugFixerWorkflow(Workflow):
    """
    Workflow for:
      1) Base code generation on BigCodeBench tasks
      2) Bug-fixer edit if the base solution fails

    Designed to run under `AgentWorkflowEngine` for massive in-process parallelism,
    plus sharding/resume at the script level.

    Each `run` returns an Episode with one trajectory ("base_codegen") when the base solution
    passes, or two ("base_codegen", "bugfix") when the bug-fixer was invoked.
    `Episode.metrics` holds base_pass / fixed_pass / base_reward / fixed_reward / bugfix_called.
    `Episode.info` holds the graded code and any reward error messages.
    """

    def __init__(
        self,
        rollout_engine: OpenAIEngine,
        executor: ThreadPoolExecutor,
        *,
        fix_engine: OpenAIEngine,
        reward_executor: ThreadPoolExecutor,
        reward_config: Optional[RewardConfig] = None,
        base_system_prompt: Optional[str] = None,
        fix_system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            rollout_engine: Engine for the base code-generation model (forwarded to `Workflow`).
            executor: Executor forwarded to `Workflow`.
            fix_engine: Engine for the bug-fixer model.
            reward_executor: Executor on which unit-test grading runs (see `_reward`).
            reward_config: Config for `RewardCodeFn`; defaults to `RewardConfig()` when None.
            base_system_prompt: Optional system message for the base model.
            fix_system_prompt: Optional system message for the bug-fixer model.
            **kwargs: Forwarded to `Workflow.__init__`.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.fix_engine = fix_engine
        self.reward_executor = reward_executor
        self.reward_fn = RewardCodeFn(reward_config or RewardConfig())
        self.base_system_prompt = base_system_prompt
        self.fix_system_prompt = fix_system_prompt

    async def _reward(self, task_info: dict, action: str):
        """Grade ``action`` against ``task_info`` via `RewardCodeFn`, off the event loop."""
        # BigCodeBench correctness checking forks subprocesses internally, so keep this bounded.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.reward_executor, lambda: self.reward_fn(task_info=task_info, action=action))

    @staticmethod
    def _chat_messages(system_prompt: Optional[str], user_prompt: str) -> list[dict[str, str]]:
        """Build a chat message list: an optional system message, then the user prompt."""
        msgs: list[dict[str, str]] = []
        if system_prompt:
            msgs.append({"role": "system", "content": system_prompt})
        msgs.append({"role": "user", "content": user_prompt})
        return msgs

    @staticmethod
    async def _generate_solution(engine: OpenAIEngine, messages: list[dict[str, str]]) -> tuple[Any, str, str]:
        """Query ``engine`` and return ``(model_output, raw_text, graded_code)``.

        ``raw_text`` prefers the output's ``text`` over ``content``. ``graded_code`` is the
        extracted code re-wrapped in a ```python fence, or the raw text when no code could be
        extracted.
        """
        out = await engine.get_model_response(messages)
        raw = (out.text or out.content or "").strip()
        extracted = extract_code_from_model(raw) or ""
        code = _format_code_block(extracted) if extracted else raw
        return out, raw, code

    @staticmethod
    def _single_step_trajectory(
        name: str,
        messages: list[dict[str, str]],
        code: str,
        raw: str,
        model_output: Any,
        reward_out: Any,
        info: dict,
    ) -> Trajectory:
        """Wrap one generate-and-grade round as a single-step Trajectory named ``name``."""
        step = Step(
            chat_completions=messages + [{"role": "assistant", "content": code}],
            thought="",
            action=code,
            model_response=raw,
            model_output=model_output,
            reward=float(reward_out.reward),
            info=info,
        )
        return Trajectory(name=name, steps=[step], reward=float(step.reward))

    @staticmethod
    def _reward_error(reward_out: Any) -> Any:
        """Return ``reward_out.metadata["error"]`` when metadata is a dict, else None."""
        metadata = reward_out.metadata
        return metadata.get("error") if isinstance(metadata, dict) else None

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """Run base generation and, if it fails its tests, one bug-fixer attempt on ``task``."""
        self.reset(task, uid)

        # ----------------
        # 1) Base code-gen
        # ----------------
        base_prompt = _build_code_generation_prompt(task)
        base_prompt += "\n\nReturn only the full Python solution inside a single ```python``` block."
        base_msgs = self._chat_messages(self.base_system_prompt, base_prompt)
        base_out, base_raw, base_code = await self._generate_solution(self.rollout_engine, base_msgs)

        base_reward_out = await self._reward(task, base_code)
        base_pass = bool(base_reward_out.is_correct)
        trajectories = [
            self._single_step_trajectory(
                "base_codegen",
                base_msgs,
                base_code,
                base_raw,
                base_out,
                base_reward_out,
                info={"reward_metadata": base_reward_out.metadata or {}},
            )
        ]

        # -----------------
        # 2) Bug-fixer edit (only if base fails)
        # -----------------
        bugfix_called = not base_pass
        if base_pass:
            # Do not call the bug-fixer model. For pass@1, treat this as a "fixed" pass.
            fix_code = base_code
            fixed_reward_out = base_reward_out
            fixed_pass = True
        else:
            problem = str(task.get("question") or task.get("instruct_prompt") or "")
            fix_msgs = self._chat_messages(self.fix_system_prompt, _build_bug_fixer_prompt(problem, base_code))
            fix_out, fix_raw, fix_code = await self._generate_solution(self.fix_engine, fix_msgs)

            fixed_reward_out = await self._reward(task, fix_code)
            fixed_pass = bool(fixed_reward_out.is_correct)
            trajectories.append(
                self._single_step_trajectory(
                    "bugfix",
                    fix_msgs,
                    fix_code,
                    fix_raw,
                    fix_out,
                    fixed_reward_out,
                    info={"reward_metadata": fixed_reward_out.metadata or {}, "base_pass": base_pass},
                )
            )

        return Episode(
            id=uid,
            task=task,
            termination_reason=TerminationReason.UNKNOWN,
            is_correct=bool(fixed_pass),
            trajectories=trajectories,
            metrics={
                "base_pass": float(base_pass),
                "fixed_pass": float(fixed_pass),
                "base_reward": float(base_reward_out.reward),
                "fixed_reward": float(fixed_reward_out.reward),
                "bugfix_called": float(bugfix_called),
            },
            info={
                "base_code": base_code,
                "fixed_code": fix_code,
                "base_error": self._reward_error(base_reward_out),
                "fixed_error": self._reward_error(fixed_reward_out),
            },
        )

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Mark the episode correct iff the post-fix solution passed (``metrics["fixed_pass"]``)."""
        m = episode.metrics or {}
        episode.is_correct = bool(m.get("fixed_pass", 0.0))


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for the base-model + bug-fixer BigCodeBench evaluation."""
    p = argparse.ArgumentParser()
    p.add_argument("--hf_dataset", type=str, default="anonymous/bigcodebench", help="HF dataset name (default: anonymous/bigcodebench)")
    p.add_argument("--hf_split", type=str, default="v0.1.0_hf", help="HF split name (default: v0.1.0_hf)")
    p.add_argument("--max_examples", type=int, default=1200, help="Max number of examples to evaluate (default: 1200)")
    p.add_argument("--start", type=int, default=0, help="Start index within the split (default: 0)")
    p.add_argument("--output_jsonl", type=str, default=None, help="Optional output JSONL path for per-example results")
    p.add_argument("--resume", action="store_true", help="If set, skip uids already present in --output_jsonl and append new rows")
    p.add_argument("--n_parallel", type=int, default=64, help="Number of parallel tasks to run in-process (bounded by model semaphore)")
    p.add_argument("--reward_workers", type=int, default=8, help="Max parallel unit-test executions (keep small; BigCodeBench uses subprocesses)")
    p.add_argument("--retry_limit", type=int, default=2, help="Retries per example on transient failures (API/timeouts)")
    p.add_argument("--num_shards", type=int, default=1, help="Total number of shards (for job arrays)")
    p.add_argument("--shard_id", type=int, default=0, help="Shard id in [0, num_shards)")

    # Base model (codegen)
    p.add_argument("--base_model", type=str, required=True, help="Base model name (OpenAI-compatible)")
    p.add_argument("--base_model_url", type=str, default="https://api.openai.com/v1", help="Base model base_url")
    p.add_argument("--base_api_key", type=str, default=None, help="Base model API key (defaults to OPENAI_API_KEY)")
    p.add_argument("--base_temperature", type=float, default=0.2)
    p.add_argument("--base_top_p", type=float, default=0.95)
    p.add_argument("--base_max_tokens", type=int, default=2048)
    p.add_argument("--base_system_prompt", type=str, default=None)

    # Bug-fixer model
    p.add_argument("--fix_model", type=str, required=True, help="Bug-fixer model name (OpenAI-compatible)")
    p.add_argument("--fix_model_url", type=str, default="http://localhost:30001/v1", help="Bug-fixer base_url")
    p.add_argument("--fix_api_key", type=str, default=None, help="Bug-fixer API key (defaults to OPENAI_API_KEY)")
    p.add_argument("--fix_temperature", type=float, default=0.6)
    p.add_argument("--fix_top_p", type=float, default=0.95)
    p.add_argument("--fix_max_tokens", type=int, default=2048)
    p.add_argument("--fix_system_prompt", type=str, default=None)
    return p


def _make_engine(
    model: str,
    base_url: str,
    api_key: Optional[str],
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> OpenAIEngine:
    """Construct an OpenAI-compatible engine with fixed sampling params; the key goes through `_resolve_api_key`."""
    return OpenAIEngine(
        model=model,
        tokenizer=None,
        base_url=base_url,
        api_key=_resolve_api_key(base_url, api_key),
        sampling_params={"temperature": temperature, "top_p": top_p, "max_tokens": max_tokens},
        verbose=False,
    )


def _select_tasks(
    hf: Any,
    start: int,
    end: int,
    num_shards: int,
    shard_id: int,
    done_uids: set[str],
) -> list[dict[str, Any]]:
    """Adapt rows in [start, end) with ``i % num_shards == shard_id``, skipping uids in ``done_uids``."""
    tasks: list[dict[str, Any]] = []
    for i in range(start, end):
        if i % num_shards != shard_id:
            continue
        task = _to_task(dict(hf[i]), i)
        uid = str(task.get("uid", f"bigcodebench_{i}"))
        if uid in done_uids:
            continue
        tasks.append(task)
    return tasks


def _episode_to_row(ep: Episode) -> EvalRow:
    """Flatten a finished Episode into an EvalRow.

    Missing metrics default to 0/False. For episodes that terminated with ERROR, any missing
    base/fixed error falls back to the engine's ``info["error"]["error_message"]``.
    """
    t = ep.task if isinstance(ep.task, dict) else {}
    metrics = ep.metrics or {}
    info = ep.info or {}

    base_error = info.get("base_error", None)
    fixed_error = info.get("fixed_error", None)
    if (base_error is None or fixed_error is None) and ep.termination_reason == TerminationReason.ERROR:
        err = info.get("error")
        # Guard the shape: a non-dict error payload must not crash result collection.
        err_msg = err.get("error_message", None) if isinstance(err, dict) else None
        if base_error is None:
            base_error = err_msg
        if fixed_error is None:
            fixed_error = err_msg

    return EvalRow(
        uid=str(t.get("uid", "")),
        index=int(t.get("index", -1)),
        base_pass=bool(metrics.get("base_pass", 0.0)),
        fixed_pass=bool(metrics.get("fixed_pass", 0.0)),
        base_reward=float(metrics.get("base_reward", 0.0)),
        fixed_reward=float(metrics.get("fixed_reward", 0.0)),
        base_error=base_error,
        fixed_error=fixed_error,
        base_code=info.get("base_code", None),
        fixed_code=info.get("fixed_code", None),
    )


async def main_async() -> None:
    """Evaluate the base model and the bug-fixer on a (sharded) BigCodeBench slice and print pass@1.

    CLI arguments are read from ``sys.argv``. When ``--output_jsonl`` is given, one JSON line is
    written per evaluated task. The file is appended to when ``--resume`` is set and the file
    exists, and truncated otherwise.

    Raises:
        ValueError: on an invalid ``--num_shards``, ``--shard_id`` or ``--start``.
    """
    args = _build_arg_parser().parse_args()

    if args.num_shards <= 0:
        raise ValueError("--num_shards must be >= 1")
    if not 0 <= args.shard_id < args.num_shards:
        raise ValueError("--shard_id must be in [0, num_shards)")
    if args.start < 0:
        # A negative start would silently index rows from the end of the split.
        raise ValueError("--start must be >= 0")

    hf = load_dataset(args.hf_dataset, split=args.hf_split)
    end = min(len(hf), args.start + args.max_examples)

    base_engine = _make_engine(
        args.base_model, args.base_model_url, args.base_api_key, args.base_temperature, args.base_top_p, args.base_max_tokens
    )
    fix_engine = _make_engine(
        args.fix_model, args.fix_model_url, args.fix_api_key, args.fix_temperature, args.fix_top_p, args.fix_max_tokens
    )

    # Resources are initialized to None up front so the `finally` below releases exactly what was
    # acquired, including on the early "no examples selected" return.
    out_f = None
    reward_executor: Optional[ThreadPoolExecutor] = None
    engine: Optional[AgentWorkflowEngine] = None
    total = base_ok = fixed_ok = 0

    try:
        if args.output_jsonl:
            os.makedirs(os.path.dirname(args.output_jsonl) or ".", exist_ok=True)
            mode = "a" if args.resume and os.path.exists(args.output_jsonl) else "w"
            out_f = open(args.output_jsonl, mode, encoding="utf-8")

        done_uids: set[str] = set()
        if args.resume and args.output_jsonl:
            done_uids = _read_done_uids_from_jsonl(args.output_jsonl)

        tasks = _select_tasks(hf, args.start, end, args.num_shards, args.shard_id, done_uids)
        if not tasks:
            print("No examples selected (check --start/--max_examples, --num_shards/--shard_id, and/or --resume).")
            return

        reward_executor = ThreadPoolExecutor(max_workers=max(1, args.reward_workers))
        engine = AgentWorkflowEngine(
            workflow_cls=BigCodeBenchBugFixerWorkflow,
            workflow_args={
                "fix_engine": fix_engine,
                "reward_executor": reward_executor,
                "reward_config": RewardConfig(),
                "base_system_prompt": args.base_system_prompt,
                "fix_system_prompt": args.fix_system_prompt,
            },
            rollout_engine=base_engine,
            config=None,
            n_parallel_tasks=args.n_parallel,
            retry_limit=args.retry_limit,
            raise_on_error=False,
        )

        task_ids = [str(t.get("uid", "")) for t in tasks]
        episodes = await engine.execute_tasks(tasks, task_ids=task_ids)

        for ep in episodes:
            row = _episode_to_row(ep)

            total += 1
            base_ok += int(row.base_pass)
            fixed_ok += int(row.fixed_pass)

            if out_f is not None:
                out_f.write(json.dumps(asdict(row)) + "\n")

            if total % 25 == 0:
                print(f"[{total} eval'd] base_pass@1={base_ok/total:.3f} fixed_pass@1={fixed_ok/total:.3f}")

    finally:
        if out_f is not None:
            out_f.close()
        if engine is not None:
            engine.shutdown()
        if reward_executor is not None:
            reward_executor.shutdown(wait=True, cancel_futures=True)

    if total == 0:
        print("\nNo examples evaluated. Check --start/--max_examples.")
        return

    print("\n" + "=" * 80)
    print(
        f"BigCodeBench hf_dataset={args.hf_dataset} hf_split={args.hf_split} "
        f"slice=[{args.start}:{end}) shard={args.shard_id}/{args.num_shards} evaluated_examples={total}"
    )
    print(f"Base  pass@1: {base_ok/total:.6f} ({base_ok}/{total})")
    print(f"Fixed pass@1: {fixed_ok/total:.6f} ({fixed_ok}/{total})")
    print("=" * 80)


def main() -> None:
    """Synchronous CLI entry point: run ``main_async`` to completion on a fresh event loop."""
    evaluation = main_async()
    asyncio.run(evaluation)


if __name__ == "__main__":
    # Only run the evaluation when executed as a script, not on import.
    main()


