from __future__ import annotations

from dataclasses import dataclass
import hashlib
import random
from typing import Any, Dict, List, Optional, Tuple

from concurrent.futures import ThreadPoolExecutor

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import (
    _build_bug_generator_prompt,
    _build_bug_fixer_prompt,
    _build_code_generation_prompt,
    _extract_failed_test_output,
)
from examples.bugs.llm_judge import BugSimilarityJudge, BugSimilarityJudgeConfig
from examples.bugs.code_embedding import (
    CodeEmbeddingConfig,
    CodeEmbedder,
    KNNBugSimilarity,
    ReferencePool,
)


# ---------------------------
# Task helpers
# ---------------------------

def _get_problem(task: Dict[str, Any]) -> str:
    """Extract problem description, handling different dataset schemas.
    
    Supports:
    - Preprocessed format: question (LCB-formatted)
    - BugBench format: instruct_prompt, complete_prompt
    - HumanEval format: prompt
    - MBPP format: text
    - Generic: problem, description
    """
    # Check in extra_info first if present (some datasets nest fields there)
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("question", "instruct_prompt", "complete_prompt", "prompt", "text", "problem", "description"):
            val = extra_info.get(key)
            if isinstance(val, str) and val.strip():
                return val
    
    # Check top-level task keys
    for key in ("question", "instruct_prompt", "complete_prompt", "prompt", "text", "problem", "description", "code_prompt"):
        val = task.get(key)
        if isinstance(val, str) and val.strip():
            return val
    
    return ""


def _get_reference_solution(task: Dict[str, Any]) -> str:
    """Extract reference solution, handling different dataset schemas.
    
    Supports:
    - Preprocessed format: reference_solution
    - BugBench format: canonical_solution
    - HumanEval format: canonical_solution
    - Generic: solution, code, correct_code
    """
    # Check in extra_info first
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
            val = extra_info.get(key)
            if isinstance(val, str) and val.strip():
                return val
    
    # Check top-level keys
    for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
        val = task.get(key)
        if isinstance(val, str) and val.strip():
            return val
    
    return ""


def _get_pregenerated_bug(task: Dict[str, Any]) -> Optional[str]:
    """Extract pregenerated buggy code if present in the task.
    
    Supports various dataset schemas:
    - BugBench: buggy_solution, buggy
    - Generic: buggy_code, bug
    """
    # Check in extra_info first
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
            val = extra_info.get(key)
            if isinstance(val, str) and val.strip():
                return val
    
    # Check top-level keys
    for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
        val = task.get(key)
        if isinstance(val, str) and val.strip():
            return val
    
    return None


def _get_data_source(task: Dict[str, Any]) -> str:
    """Infer or extract data_source for reward function routing.
    
    The reward function uses data_source to determine which test runner to use.
    BugBench format uses the "bugbench" runner (same as bigcodebench).
    """
    # Check in extra_info first
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        ds = extra_info.get("data_source")
        if isinstance(ds, str) and ds.strip():
            return ds.strip()
    
    # Check top-level
    ds = task.get("data_source")
    if isinstance(ds, str) and ds.strip():
        return ds.strip()
    
    # Infer from task structure
    # BugBench format: has entry_point, test (string), and characteristic prompt fields
    has_entry_point = bool(
        task.get("entry_point") or 
        (extra_info and extra_info.get("entry_point"))
    )
    has_bugbench_prompts = any(
        task.get(k) or (extra_info and extra_info.get(k))
        for k in ("instruct_prompt", "complete_prompt", "code_prompt")
    )
    has_string_test = isinstance(task.get("test"), str) or (
        extra_info and isinstance(extra_info.get("test"), str)
    )
    
    # BugBench/BigCodeBench: entry_point is the strongest signal
    if has_entry_point:
        return "bugbench"
    
    # BugBench format without explicit entry_point but with characteristic fields
    if has_bugbench_prompts and has_string_test:
        return "bugbench"
    
    # Metadata with func_name indicates BigCodeBench-style
    metadata = task.get("metadata") or (extra_info.get("metadata") if extra_info else None) or {}
    if isinstance(metadata, dict) and metadata.get("func_name"):
        return "bugbench"
    
    # Default to livecodebench (works for DeepCoder-style tasks)
    return "livecodebench"


def _get_entry_point(task: Dict[str, Any]) -> Optional[str]:
    """Extract entry_point for BigCodeBench/BugBench test execution."""
    # Check in extra_info first
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        ep = extra_info.get("entry_point")
        if isinstance(ep, str) and ep.strip():
            return ep.strip()
        # Also check metadata
        meta = extra_info.get("metadata", {})
        if isinstance(meta, dict):
            ep = meta.get("func_name") or meta.get("entry_point")
            if isinstance(ep, str) and ep.strip():
                return ep.strip()
    
    # Check top-level
    ep = task.get("entry_point")
    if isinstance(ep, str) and ep.strip():
        return ep.strip()
    
    # Check metadata
    meta = task.get("metadata", {})
    if isinstance(meta, dict):
        ep = meta.get("func_name") or meta.get("entry_point")
        if isinstance(ep, str) and ep.strip():
            return ep.strip()
    
    return None


def _get_ground_truth(task: Dict[str, Any]) -> Any:
    """Extract ground_truth (tests) for reward function evaluation.
    
    Supports:
    - Preprocessed: ground_truth
    - BugBench: test (string)
    - HumanEval: test (string)
    - MBPP: test_list (list of strings)
    - LiveCodeBench: ground_truth (dict with inputs/outputs)
    """
    # Check in extra_info first
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("ground_truth", "test", "test_list", "tests"):
            val = extra_info.get(key)
            if val is not None and val != "":
                return val
    
    # Check top-level
    for key in ("ground_truth", "test", "test_list", "tests"):
        val = task.get(key)
        if val is not None and val != "":
            return val
    
    return None


# ---------------------------
# Role wrappers
# ---------------------------

@dataclass
class BugGeneratorConfig:
    """Configuration for the bug-generator role."""

    # Optional system message for the generator conversation; None means no
    # system message is configured.
    system_prompt: Optional[str] = None


@dataclass
class BugFixerConfig:
    """Configuration for the bug-fixer (solver) role."""

    # Optional system message for the solver conversation; None means no
    # system message is configured.
    system_prompt: Optional[str] = None
    # Flag controlling whether failed test output is included in the fixer
    # prompt (off by default).
    include_failed_test_output: bool = False


class BugGenerator:
    """Bug-generator role: asks the rollout engine to corrupt a correct solution."""

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[BugGeneratorConfig] = None):
        # Fall back to default settings when no (or a falsy) config is supplied.
        self.config = config or BugGeneratorConfig()
        self.rollout_engine = rollout_engine

    async def generate_bug(self, task: Dict[str, Any], uid: str) -> Trajectory:
        """Produce a buggy variant of the task's reference solution.

        The task is expected to expose a problem description and a correct
        solution under one of the keys understood by ``_get_problem`` and
        ``_get_reference_solution``.

        Args:
            task: Task record.
            uid: Episode identifier (currently unused by this method).

        Returns:
            A single-step ``Trajectory`` named ``"bug_generator"`` whose step
            action is the raw model response content.

        Raises:
            KeyError: If the task carries no reference solution.
        """
        problem = _get_problem(task)
        correct_code = _get_reference_solution(task)
        if not correct_code:
            raise KeyError("Task missing reference_solution/canonical_solution required for bug generation.")

        # Conversation sent to the model: optional system turn + user prompt.
        system_prompt = self.config.system_prompt
        conversation: List[Dict[str, str]] = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )
        conversation.append(
            {"role": "user", "content": _build_bug_generator_prompt(problem, correct_code)}
        )

        reply: ModelOutput = await self.rollout_engine.get_model_response(conversation)
        buggy_code = reply.content

        # Record the full exchange, including the assistant turn, on the step.
        transcript = [*conversation, {"role": "assistant", "content": buggy_code}]
        only_step = Step(
            chat_completions=transcript,
            action=buggy_code,
            model_output=reply,
        )
        return Trajectory(name="bug_generator", steps=[only_step])


class BugFixer:
    """Bug-fixer (solver) role: repairs buggy code or writes code from scratch."""

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[BugFixerConfig] = None):
        # Fall back to default settings when no (or a falsy) config is supplied.
        self.config = config or BugFixerConfig()
        self.rollout_engine = rollout_engine

    async def _single_turn(self, user_prompt: str, trajectory_name: str) -> Trajectory:
        """Send one user prompt to the engine and wrap the reply in a one-step trajectory."""
        conversation: List[Dict[str, str]] = []
        system_prompt = self.config.system_prompt
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.append({"role": "user", "content": user_prompt})

        reply: ModelOutput = await self.rollout_engine.get_model_response(conversation)
        completion = reply.content

        # Record the full exchange, including the assistant turn, on the step.
        transcript = [*conversation, {"role": "assistant", "content": completion}]
        only_step = Step(
            chat_completions=transcript,
            action=completion,
            model_output=reply,
        )
        return Trajectory(name=trajectory_name, steps=[only_step])

    async def fix_bug(
        self,
        task: Dict[str, Any],
        buggy_code: str,
        uid: str,
        failed_test_output: Optional[str] = None,
    ) -> Trajectory:
        """Ask the model to repair ``buggy_code`` for the given task.

        Args:
            task: Task record; its problem text is read via ``_get_problem``.
            buggy_code: Code to be repaired.
            uid: Episode identifier (currently unused by this method).
            failed_test_output: Optional failing-test output forwarded to the
                prompt builder together with the config flag
                ``include_failed_test_output``.

        Returns:
            A single-step ``Trajectory`` named ``"bug_fixer"``.
        """
        user_prompt = _build_bug_fixer_prompt(
            _get_problem(task),
            buggy_code,
            include_failed_test_output=self.config.include_failed_test_output,
            failed_test_output=failed_test_output,
        )
        return await self._single_turn(user_prompt, "bug_fixer")

    async def generate_code(self, task: Dict[str, Any], uid: str) -> Trajectory:
        """Ask the model to write a solution from scratch (plain code generation).

        Args:
            task: Task record handed to ``_build_code_generation_prompt``.
            uid: Episode identifier (currently unused by this method).

        Returns:
            A single-step ``Trajectory`` named ``"code_generator"``.
        """
        return await self._single_turn(_build_code_generation_prompt(task), "code_generator")


# ---------------------------
# Bug validity
# ---------------------------

def check_bug_validity(
    bug_meta: Dict[str, Any],
    bug_reward_output: RewardOutput,
    compile_errors_invalid: bool = True,
) -> Tuple[bool, bool]:
    """Decide whether a generated bug is valid and whether it failed to compile.

    A valid bug must:
      1) Have no compilation errors (if compile_errors_invalid=True)
      2) Fail at least one unit test (passed_tests < total_tests OR all_passed=False)

    Compile errors are detected heuristically from the ``error_message`` of
    each entry in ``bug_meta["test_results"]``: any message containing
    ``"error during testing:"``, or one of the compile-ish markers below
    (unless the message also says ``"wrong answer"``, which is a legitimate
    failing test). Entries that are not dicts are skipped.

    Args:
        bug_meta: Reward metadata; the keys read are ``total_tests``,
            ``passed_tests``, ``all_passed`` and ``test_results``.
        bug_reward_output: Reward result; only ``is_correct`` is read, and
            only when the metadata does not settle the question.
        compile_errors_invalid: When True, a detected compile error makes the
            bug invalid regardless of test counts.

    Returns:
        ``(bug_valid, has_compile_error)``.
    """
    # Built once per call rather than once per test result.
    compile_markers = (
        "syntax", "syntaxerror", "compilation", "compile error",
        "cannot compile", "indentation", "invalid syntax",
        "unexpected", "eof", "unterminated", "was never closed",
        "nameerror", "typeerror", "attributeerror", "import error",
        "module not found", "indentationerror",
    )

    has_compile_error = False
    test_results = bug_meta.get("test_results", [])
    if isinstance(test_results, list):
        for entry in test_results:
            if not isinstance(entry, dict):
                # Malformed entry (None, bare string, ...): nothing to inspect.
                continue

            message = str(entry.get("error_message", "") or "").lower()
            if not message:
                continue

            if "error during testing:" in message:
                has_compile_error = True
                break

            # "wrong answer" is a legit failing test, never a compile error.
            if "wrong answer" in message:
                continue

            if any(marker in message for marker in compile_markers):
                has_compile_error = True
                break

    total_tests = bug_meta.get("total_tests")
    passed_tests = bug_meta.get("passed_tests")
    all_passed = bug_meta.get("all_passed", False)

    if compile_errors_invalid and has_compile_error:
        bug_valid = False
    elif total_tests is not None and passed_tests is not None and total_tests > 0:
        bug_valid = passed_tests < total_tests
    elif all_passed is False:
        # Also taken when "all_passed" is missing (it defaults to False).
        bug_valid = True
    else:
        # No usable counts: fall back to the reward verdict. A compile error
        # cannot reach this branch while compile_errors_invalid is True, so
        # no further compile check is needed here.
        bug_valid = not bug_reward_output.is_correct

    return bug_valid, has_compile_error


# ---------------------------
# SSR-style generator reward shaping
# ---------------------------

def _shaped_generator_reward_from_solve_rate(
    *,
    bug_valid: bool,
    solve_rate: float,
    mode: str,
    band_low: float,
    band_high: float,
    alpha_extreme: float,
    invalid_bug_reward: float,
) -> float:
    """SSR-like shaping: reward 'frontier' examples and penalize degeneracy.

    Modes:
      - "binary": old behavior: reward iff (bug_valid and solve_rate < 1.0)
      - "band":   +1 if solve_rate in [band_low, band_high],
                  -alpha if solve_rate in {0,1}, else 0
      - "smooth": peak in the middle of the band, 0 outside, -alpha at {0,1}

    Notes:
      - solve_rate in [0,1], computed from K solver attempts.
      - Penalizing solve_rate==0 prevents "unsolvable" bugs dominating.
      - Penalizing solve_rate==1 prevents "too easy" / no-op bugs dominating.
    """
    mode = (mode or "band").lower().strip()
    solve_rate = max(0.0, min(1.0, float(solve_rate)))

    if not bug_valid:
        return float(invalid_bug_reward)

    # Extremes (degenerate) get explicit penalty.
    if solve_rate <= 0.0 or solve_rate >= 1.0:
        return float(-alpha_extreme)

    if mode == "binary":
        # old-ish: reward if not trivially solvable by everyone
        return 1.0 if solve_rate < 1.0 else 0.0

    band_low = float(band_low)
    band_high = float(band_high)
    if band_low > band_high:
        band_low, band_high = band_high, band_low

    if mode == "band":
        return 1.0 if (band_low <= solve_rate <= band_high) else 0.0

    if mode == "smooth":
        # Piecewise linear bump: 0 outside band, 1 at band midpoint.
        if solve_rate < band_low or solve_rate > band_high:
            return 0.0
        mid = 0.5 * (band_low + band_high)
        half = max(1e-8, 0.5 * (band_high - band_low))
        # 1 at mid, 0 at edges
        return float(max(0.0, 1.0 - abs(solve_rate - mid) / half))

    # Fallback
    return 1.0 if (band_low <= solve_rate <= band_high) else 0.0


# ---------------------------
# Workflow
# ---------------------------

class GeneratorSolverWorkflow(Workflow):
    """Train-time: generator creates bug -> solver fixes it (K attempts for solve-rate).
    Validation-time (BugBench style): if task contains a pre-generated bug
    (buggy_solution/buggy/buggy_code), skip generation and evaluate solver on that held-out bug.
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        generator_system_prompt: Optional[str] = None,
        solver_system_prompt: Optional[str] = None,
        evaluate_codegen: bool = True,

        # Validation behavior
        use_pregenerated_bugs_in_validation: bool = True,

        # Training behavior: optionally use human/pregenerated bugs if present on the task.
        # This is useful when you *mix in* a separate bug dataset into the training set.
        use_pregenerated_bugs_in_training: bool = False,
        pregenerated_bug_train_probability: float = 1.0,

        # Defines what Episode.is_correct means (and thus val/pass@k):
        # - "bugfix": success iff (bug_valid and solver_pass_any)
        # - "codegen": success iff solver_codegen_pass (validation only; falls back to bugfix on train)
        episode_success_mode: str = "bugfix",

        # NEW: solve-rate evaluation knobs
        solver_attempts_train: int = 8,
        solver_attempts_val: int = 1,

        # NEW: generator reward shaping knobs (SSR-like)
        generator_reward_mode: str = "band",  # "band" | "smooth" | "binary"
        solve_rate_band_low: float = 0.05,
        solve_rate_band_high: float = 0.25,
        gen_alpha_extreme: float = 0.2,        # penalty when solve_rate in {0,1}
        gen_invalid_bug_reward: float = -1.0,  # penalty for invalid bug

        # NEW: solver reward style
        solver_reward_pm1: bool = False,       # False => {0,1}, True => {-1,+1}

        # Include failed test output in solver prompts
        include_failed_test_output: bool = True,

        # LLM-as-judge for bug similarity scoring
        use_bug_similarity_judge: bool = False,
        bug_similarity_reward_weight: float = 0.5,  # Weight for auxiliary reward
        judge_system_prompt: Optional[str] = None,
        # Optional: run the judge against a separate OpenAI-compatible endpoint/model
        # (e.g., a vLLM OpenAI server), instead of using the training rollout engine.
        judge_base_url: Optional[str] = None,
        judge_model_name: Optional[str] = None,
        # Reference bugs for similarity comparison (list of tasks with buggy_solution)
        # If None, the judge will be disabled even if use_bug_similarity_judge=True
        reference_bugs: Optional[List[Dict[str, Any]]] = None,
        # Number of target bugs to compare against (averaged for final score)
        bug_similarity_n_targets: int = 3,

        # Code-embedding similarity (optional auxiliary reward)
        use_code_embedding_similarity: bool = False,
        code_embedding_reward_weight: float = 0.3,
        code_embedding_model_name: str = "voyage-code-3",
        code_embedding_include_problem: bool = True,
        code_embedding_top_k: int = 5,

        # Pools: you can pass tasks directly OR load pools from disk
        # If code_embedding_reference_bugs is None, we will fall back to `reference_bugs`
        code_embedding_reference_bugs: Optional[List[Dict[str, Any]]] = None,
        code_embedding_negative_bugs: Optional[List[Dict[str, Any]]] = None,
        code_embedding_target_pool_path: Optional[str] = None,
        code_embedding_negative_pool_path: Optional[str] = None,

        # If negative pool exists, use margin = target - negative and sigmoid it
        code_embedding_use_margin: bool = True,
        code_embedding_margin_temperature: float = 10.0,

        **kwargs,
    ):
        """Set up the generator/solver roles, reward knobs and optional auxiliary scorers.

        ``rollout_engine``, ``executor`` and ``**kwargs`` are forwarded to
        ``Workflow.__init__``. The same ``rollout_engine`` backs both the
        ``BugGenerator`` and the ``BugFixer``, and is the default engine for
        the similarity judge.

        Normalisation applied to arguments:
          - boolean/float/str knobs are coerced with ``bool``/``float``/``str``;
          - ``episode_success_mode`` is lower-cased and stripped;
          - ``solver_attempts_train``, ``solver_attempts_val`` and
            ``bug_similarity_n_targets`` are clamped to at least 1.

        Optional components (all failures are printed as warnings, never raised):
          - Code-embedding similarity: when enabled, a ``CodeEmbedder`` and a
            target ``KNNBugSimilarity`` are created. The target pool is loaded
            from ``code_embedding_target_pool_path`` if given, otherwise built
            from ``code_embedding_reference_bugs`` (falling back to
            ``reference_bugs``). A negative pool is set up only when a path or
            task list for it is supplied. The feature is switched off again if
            the target pool ends up empty.
          - Reference bug pool: one entry per ``reference_bugs`` task that
            carries a pregenerated bug.
          - Bug-similarity judge: when ``judge_model_name`` is given, a
            dedicated ``OpenAIEngine`` is attempted (falling back to
            ``rollout_engine`` on error); the judge itself is only created when
            the reference bug pool is non-empty.

        Side effects: prints status/warning lines to stdout and may read the
        ``OPENAI_API_KEY`` environment variable.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.generator = BugGenerator(rollout_engine, BugGeneratorConfig(system_prompt=generator_system_prompt))
        self.solver = BugFixer(
            rollout_engine,
            BugFixerConfig(
                system_prompt=solver_system_prompt,
                include_failed_test_output=bool(include_failed_test_output),
            ),
        )

        self.evaluate_codegen = bool(evaluate_codegen)
        self.use_pregenerated_bugs_in_validation = bool(use_pregenerated_bugs_in_validation)
        self.use_pregenerated_bugs_in_training = bool(use_pregenerated_bugs_in_training)
        self.pregenerated_bug_train_probability = float(pregenerated_bug_train_probability)
        self.episode_success_mode = str(episode_success_mode).lower().strip()

        # At least one solver attempt is always made.
        self.solver_attempts_train = max(1, int(solver_attempts_train))
        self.solver_attempts_val = max(1, int(solver_attempts_val))

        self.generator_reward_mode = str(generator_reward_mode)
        self.solve_rate_band_low = float(solve_rate_band_low)
        self.solve_rate_band_high = float(solve_rate_band_high)
        self.gen_alpha_extreme = float(gen_alpha_extreme)
        self.gen_invalid_bug_reward = float(gen_invalid_bug_reward)

        self.solver_reward_pm1 = bool(solver_reward_pm1)

        # LLM-as-judge for bug similarity
        self.use_bug_similarity_judge = bool(use_bug_similarity_judge)
        self.bug_similarity_reward_weight = float(bug_similarity_reward_weight)
        self.bug_similarity_n_targets = max(1, int(bug_similarity_n_targets))
        self.bug_similarity_judge: Optional[BugSimilarityJudge] = None
        # Default judge engine; may be replaced by a dedicated engine below.
        self.judge_rollout_engine: RolloutEngine = rollout_engine

        # ---------------------------
        # Code-embedding similarity (aux reward)
        # ---------------------------
        self.use_code_embedding_similarity = bool(use_code_embedding_similarity)
        self.code_embedding_reward_weight = float(code_embedding_reward_weight)
        self.code_embedding_use_margin = bool(code_embedding_use_margin)
        self.code_embedding_margin_temperature = float(code_embedding_margin_temperature)

        self._code_embedder: Optional[Any] = None
        self._code_knn_target: Optional[Any] = None
        self._code_knn_negative: Optional[Any] = None

        if self.use_code_embedding_similarity:
            # NOTE(review): the top-of-file import of these names is unconditional,
            # so this guard only fires if that module itself exports None - confirm.
            if CodeEmbedder is None or KNNBugSimilarity is None:
                print("[CodeEmbedding] WARNING: examples.bugs.code_embedding could not be imported. Disabling.")
                self.use_code_embedding_similarity = False
            else:
                # Build embedder
                emb_cfg = CodeEmbeddingConfig(
                    enabled=True,
                    reward_weight=self.code_embedding_reward_weight,
                    model_name=str(code_embedding_model_name),
                    include_problem=bool(code_embedding_include_problem),
                    top_k=int(code_embedding_top_k),
                )
                self._code_embedder = CodeEmbedder(emb_cfg)

                # Target pool
                self._code_knn_target = KNNBugSimilarity(self._code_embedder, top_k=int(code_embedding_top_k))

                # Load pool from disk if provided, else build from tasks
                if code_embedding_target_pool_path:
                    try:
                        pool = ReferencePool.load(str(code_embedding_target_pool_path))
                        self._code_knn_target.reference_pool = pool
                        print(f"[CodeEmbedding] Loaded TARGET pool from {code_embedding_target_pool_path} (n={len(pool)})")
                    except Exception as e:
                        # Best-effort: a failed load is reported; there is no fallback to
                        # building from tasks when a path was given.
                        print(f"[CodeEmbedding] WARNING: Failed loading TARGET pool: {e}")
                else:
                    target_tasks = code_embedding_reference_bugs if code_embedding_reference_bugs is not None else reference_bugs
                    if target_tasks:
                        try:
                            self._code_knn_target.build_reference_pool(list(target_tasks))
                            print(f"[CodeEmbedding] Built TARGET pool (n={len(self._code_knn_target.reference_pool)})")
                        except Exception as e:
                            print(f"[CodeEmbedding] WARNING: Failed building TARGET pool: {e}")

                # Negative pool (optional)
                if code_embedding_negative_pool_path or code_embedding_negative_bugs:
                    self._code_knn_negative = KNNBugSimilarity(self._code_embedder, top_k=int(code_embedding_top_k))

                    # A pool path takes precedence over an in-memory task list.
                    if code_embedding_negative_pool_path:
                        try:
                            pool = ReferencePool.load(str(code_embedding_negative_pool_path))
                            self._code_knn_negative.reference_pool = pool
                            print(f"[CodeEmbedding] Loaded NEGATIVE pool from {code_embedding_negative_pool_path} (n={len(pool)})")
                        except Exception as e:
                            print(f"[CodeEmbedding] WARNING: Failed loading NEGATIVE pool: {e}")
                    else:
                        try:
                            self._code_knn_negative.build_reference_pool(list(code_embedding_negative_bugs or []))
                            print(f"[CodeEmbedding] Built NEGATIVE pool (n={len(self._code_knn_negative.reference_pool)})")
                        except Exception as e:
                            print(f"[CodeEmbedding] WARNING: Failed building NEGATIVE pool: {e}")

                # Final sanity
                # NOTE(review): assumes reference_pool supports len() even when no
                # pool was loaded or built - confirm against KNNBugSimilarity.
                if self._code_knn_target is None or len(self._code_knn_target.reference_pool) == 0:
                    print("[CodeEmbedding] WARNING: Enabled but TARGET pool is empty. Disabling.")
                    self.use_code_embedding_similarity = False

        # Build reference bug pool from provided reference_bugs
        self.reference_bug_pool: List[Dict[str, Any]] = []
        if reference_bugs:
            for ref_task in reference_bugs:
                bug = _get_pregenerated_bug(ref_task)
                # Tasks without a pregenerated bug are left out of the pool.
                if bug:
                    self.reference_bug_pool.append({
                        "bug": bug,
                        "problem": _get_problem(ref_task),
                        # Stored under "ground_truth", but the value is the task's
                        # reference solution.
                        "ground_truth": _get_reference_solution(ref_task),
                        "uid": ref_task.get("uid", ref_task.get("task_id", "")),
                    })

        # Optionally create a dedicated judge engine (OpenAI-compatible HTTP API).
        # NOTE: OpenAI-compatible APIs require a model name; base_url alone isn't sufficient.
        if self.use_bug_similarity_judge and judge_base_url and not judge_model_name:
            print(
                "[BugSimilarityJudge] WARNING: judge_base_url was provided but judge_model_name is empty; "
                "falling back to training rollout engine."
            )

        if self.use_bug_similarity_judge and judge_model_name:
            try:
                # Imported lazily so the dependency is only needed when a
                # dedicated judge engine is requested.
                import os

                from rllm.engine.rollout.openai_engine import OpenAIEngine

                self.judge_rollout_engine = OpenAIEngine(
                    model=str(judge_model_name or ""),
                    tokenizer=None,  # judge is API-only; use chat completions endpoint
                    max_prompt_length=8192,
                    max_response_length=2048,
                    base_url=str(judge_base_url or "https://api.openai.com/v1"),
                    api_key=os.getenv("OPENAI_API_KEY") or "EMPTY",
                    sampling_params={"temperature": 0.3, "top_p": 0.95},
                    verbose=False,
                )
                print(
                    f"[BugSimilarityJudge] Using dedicated judge engine: "
                    f"model={judge_model_name!r} base_url={judge_base_url!r}"
                )
            except Exception as e:
                print(
                    f"[BugSimilarityJudge] WARNING: Failed to initialize dedicated judge engine "
                    f"(model={judge_model_name!r} base_url={judge_base_url!r}): {e}. "
                    "Falling back to training rollout engine."
                )
                self.judge_rollout_engine = rollout_engine

        # The judge is only instantiated when there is something to compare against.
        if self.use_bug_similarity_judge and self.reference_bug_pool:
            self.bug_similarity_judge = BugSimilarityJudge(
                self.judge_rollout_engine,
                BugSimilarityJudgeConfig(
                    enabled=True,
                    reward_weight=self.bug_similarity_reward_weight,
                    system_prompt=judge_system_prompt,
                ),
            )
            print(f"[BugSimilarityJudge] Initialized with {len(self.reference_bug_pool)} reference bugs, comparing against {self.bug_similarity_n_targets} targets per bug")
        elif self.use_bug_similarity_judge:
            print("[BugSimilarityJudge] WARNING: Enabled but no reference bugs provided. Judge will be disabled.")

    def _normalize_task_info(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize task into a format the reward function expects.
        
        The reward function (RewardCodeFn) expects:
        - data_source: str (e.g., "livecodebench", "bigcodebench", "bugbench")
        - ground_truth: test cases (format depends on data_source)
        - entry_point: str (optional, required for bigcodebench/bugbench)
        
        This method handles tasks from different datasets with varying schemas.
        """
        # Start with extra_info if present, otherwise use task itself
        base_info = task.get("extra_info", task)
        
        # Build normalized task_info
        task_info = dict(base_info) if isinstance(base_info, dict) else {}
        
        # Ensure ground_truth is set
        if "ground_truth" not in task_info or task_info["ground_truth"] is None:
            gt = _get_ground_truth(task)
            if gt is not None:
                task_info["ground_truth"] = gt
        
        # Ensure data_source is set (reward function uses this for routing)
        if "data_source" not in task_info or not task_info["data_source"]:
            task_info["data_source"] = _get_data_source(task)
        
        # Ensure entry_point is set if available (needed for bigcodebench/bugbench)
        if "entry_point" not in task_info or not task_info["entry_point"]:
            ep = _get_entry_point(task)
            if ep:
                task_info["entry_point"] = ep
        
        return task_info

    def _set_episode_is_correct(
        self,
        episode: Episode,
        *,
        is_validation: bool,
        bug_valid: bool,
        solver_pass_any: bool,
        solver_pass_all: bool,
        codegen_pass: bool,
    ) -> None:
        mode = self.episode_success_mode
        if is_validation:
            if mode == "codegen":
                episode.is_correct = bool(codegen_pass)
            elif mode == "bugfix":
                # Note that solver_attempts_val should be 1
                episode.is_correct = solver_pass_all
        else:
            # Training: do not reward trivial bugs or too hard bugs
            episode.is_correct = bool(bug_valid and solver_pass_any)
    
    @staticmethod
    def _sigmoid(x: float) -> float:
        # numerically stable-ish for typical temps/margins
        if x >= 0:
            z = math.exp(-x)
            return 1.0 / (1.0 + z)
        else:
            z = math.exp(x)
            return z / (1.0 + z)

    def _score_code_embedding_aux_sync(self, problem: str, buggy_code: str) -> Tuple[float, Dict[str, Any]]:
        """
        Returns:
          aux_reward in [0,1] and metadata
        """
        if not self.use_code_embedding_similarity or self._code_knn_target is None:
            return 0.0, {"disabled": True}

        # score against target pool
        target_score, target_meta = self._code_knn_target.score_similarity(problem, buggy_code)

        # optionally score against negative pool and turn into margin-based probability
        if (
            self.code_embedding_use_margin
            and self._code_knn_negative is not None
            and len(self._code_knn_negative.reference_pool) > 0
        ):
            neg_score, neg_meta = self._code_knn_negative.score_similarity(problem, buggy_code)
            margin = float(target_score - neg_score)
            z = float(self.code_embedding_margin_temperature) * margin
            aux = float(self._sigmoid(z))
            meta = {
                "target_score": float(target_score),
                "negative_score": float(neg_score),
                "margin": float(margin),
                "margin_temperature": float(self.code_embedding_margin_temperature),
                "target_meta": target_meta,
                "negative_meta": neg_meta,
            }
            return aux, meta

        # fallback: just use target score
        return float(target_score), {
            "target_score": float(target_score),
            "target_meta": target_meta,
            "used_margin": False,
        }


    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Run one bug-generation / bug-fixing episode for ``task``.

        Steps, in order:
          1. Choose the bug source: a pregenerated bug taken from the task
             (when allowed for the current phase) or a fresh one produced by
             ``self.generator``.
          2. Score the buggy code with ``self.reward_function`` and derive
             bug validity through ``check_bug_validity``.
          3. Let ``self.solver`` attempt a fix K times and compute the solve
             rate; assign a per-attempt solver reward.
          4. Optionally (validation only) evaluate plain code generation.
          5. Optionally compute auxiliary similarity rewards (LLM judge and/or
             code embeddings) for bugs produced by the generator.
          6. Assign the generator reward and assemble the ``Episode``.

        Args:
            task: Task dictionary passed through to the generator, solver and
                reward function helpers.
            uid: Episode identifier; also seeds the deterministic per-episode
                sampling so results do not depend on the global RNG.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``Episode`` holding the generator trajectory (if a bug was
            generated), all solver trajectories, the optional codegen
            trajectory, and the collected metrics.
        """
        # Function-scope import shared by BOTH the judge and the embedding
        # branches. Importing inside only one branch makes ``asyncio`` a local
        # name that is unbound (UnboundLocalError) whenever that branch is
        # skipped but the other one runs.
        import asyncio

        self.reset(task, uid)

        is_validation = bool(getattr(self.rollout_engine, "validate", False))
        task_info = self._normalize_task_info(task)

        def uid_rng() -> random.Random:
            # Stable deterministic sampling per-episode (avoid global RNG drift).
            seed = int(hashlib.md5(uid.encode("utf-8")).hexdigest(), 16) % (2**32)
            return random.Random(seed)

        # ---------------------------
        # PREGENERATED BUG SELECTION
        # ---------------------------
        pregenerated_bug: Optional[str] = None
        allow_pregenerated = (
            self.use_pregenerated_bugs_in_validation
            if is_validation
            else self.use_pregenerated_bugs_in_training
        )
        if allow_pregenerated:
            cand = _get_pregenerated_bug(task)
            if cand is not None:
                if is_validation:
                    pregenerated_bug = cand
                else:
                    # In training, use the pregenerated bug only with the
                    # configured probability (clamped to [0, 1]).
                    p = max(0.0, min(1.0, float(self.pregenerated_bug_train_probability)))
                    if p >= 1.0:
                        pregenerated_bug = cand
                    elif p > 0.0 and uid_rng().random() < p:
                        pregenerated_bug = cand

        # ---------------------------
        # BUG SOURCE
        # ---------------------------
        bug_traj: Optional[Trajectory] = None
        bug_step: Optional[Step] = None

        if pregenerated_bug is None:
            bug_traj = await self.generator.generate_bug(task, uid)
            bug_step = bug_traj.steps[0]
            buggy_code = bug_step.action
        else:
            buggy_code = pregenerated_bug

        # ---------------------------
        # BUG VALIDITY (via reward_fn)
        # ---------------------------
        try:
            bug_reward_output = self.reward_function(task_info=task_info, action=buggy_code)
        except Exception as e:
            # Best effort: a crashing reward function is recorded as a failed
            # evaluation instead of aborting the whole episode.
            bug_reward_output = RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={"error": f"bug reward error: {e}"},
            )

        bug_meta = bug_reward_output.metadata or {}
        bug_valid, has_compile_error = check_bug_validity(
            bug_meta=bug_meta,
            bug_reward_output=bug_reward_output,
            compile_errors_invalid=True,
        )

        # Extract failed test output for solver prompt (if enabled)
        failed_test_output: Optional[str] = None
        if self.solver.config.include_failed_test_output:
            failed_test_output = _extract_failed_test_output(bug_meta)

        # ---------------------------
        # SOLVER FIX (K attempts -> solve_rate)
        # ---------------------------
        solver_attempts = self.solver_attempts_val if is_validation else self.solver_attempts_train

        solver_trajs: List[Trajectory] = []
        solver_passes: List[bool] = []

        for k in range(solver_attempts):
            t = await self.solver.fix_bug(task, buggy_code, f"{uid}_solve{k}", failed_test_output=failed_test_output)
            # Ensure unique names (useful for debugging/logging)
            t.name = f"bug_fixer_{k}"
            solver_trajs.append(t)

            fixed_code = t.steps[0].action
            try:
                out = self.reward_function(task_info=task_info, action=fixed_code)
                sp = bool(out.is_correct)
            except Exception:
                # A reward-function failure counts as a failed attempt.
                sp = False
            solver_passes.append(sp)

        num_pass = sum(1 for passed in solver_passes if passed)
        # Guard against a configured attempt count of 0: no attempts means a
        # solve rate of 0 and "all passed" is False, rather than a
        # ZeroDivisionError.
        has_attempts = solver_attempts > 0
        solve_rate = float(num_pass) / float(solver_attempts) if has_attempts else 0.0
        solver_pass_any = num_pass > 0
        solver_pass_all = has_attempts and num_pass == solver_attempts

        # Assign per-attempt solver rewards (conditioned on bug_valid!)
        for traj, passed in zip(solver_trajs, solver_passes):
            if not bug_valid:
                r = -1.0 if self.solver_reward_pm1 else 0.0
            elif self.solver_reward_pm1:
                r = 1.0 if passed else -1.0
            else:
                r = 1.0 if passed else 0.0
            traj.steps[0].reward = float(r)

        # ---------------------------
        # OPTIONAL CODEGEN EVAL (VAL ONLY)
        # ---------------------------
        codegen_traj: Optional[Trajectory] = None
        codegen_pass = False
        if self.evaluate_codegen and is_validation:
            codegen_traj = await self.solver.generate_code(task, uid)
            codegen_step = codegen_traj.steps[0]
            generated_code = codegen_step.action
            try:
                codegen_reward_output = self.reward_function(task_info=task_info, action=generated_code)
                codegen_pass = bool(codegen_reward_output.is_correct)
            except Exception:
                codegen_pass = False
            # No training reward for codegen eval
            codegen_step.reward = 0.0

        # ---------------------------
        # LLM-AS-JUDGE BUG SIMILARITY (optional auxiliary reward)
        # ---------------------------
        bug_similarity_score = 0.0
        bug_similarity_meta: Dict[str, Any] = {}

        # Only compute similarity if:
        # 1. Judge is enabled and has reference bugs
        # 2. We generated a bug (not using pregenerated)
        use_judge = (
            self.bug_similarity_judge is not None
            and bool(self.reference_bug_pool)
            and pregenerated_bug is None
        )
        if use_judge:
            # Sample N reference bugs from the pool (deterministic based on uid)
            rng = uid_rng()

            # Cap n_targets at pool size (sample() cannot draw more than the
            # pool holds) and floor it at 0 so a negative setting cannot raise.
            n_targets = max(0, min(self.bug_similarity_n_targets, len(self.reference_bug_pool)))
            sampled_refs = rng.sample(self.reference_bug_pool, n_targets)

            generated_problem = _get_problem(task)
            generated_ground_truth = _get_reference_solution(task)

            # Score similarity against one target bug
            async def score_one(ref_entry: Dict[str, Any], idx: int) -> Tuple[float, Dict[str, Any]]:
                score, meta = await self.bug_similarity_judge.score_similarity(
                    generated_problem=generated_problem,
                    generated_bug=buggy_code,
                    target_problem=ref_entry.get("problem", ""),
                    target_bug=ref_entry["bug"],
                    uid=f"{uid}_judge_{idx}",
                    generated_ground_truth=generated_ground_truth,
                    target_ground_truth=ref_entry.get("ground_truth", ""),
                )
                meta["reference_uid"] = ref_entry.get("uid", "")
                return score, meta

            # Run all comparisons concurrently
            results = await asyncio.gather(*[
                score_one(ref, i) for i, ref in enumerate(sampled_refs)
            ])

            # Average the scores
            scores = [score for score, _ in results]
            bug_similarity_score = sum(scores) / len(scores) if scores else 0.0

            # Aggregate metadata
            bug_similarity_meta = {
                "n_targets": n_targets,
                "individual_scores": scores,
                "avg_score": bug_similarity_score,
                "reference_uids": [meta.get("reference_uid", "") for _, meta in results],
                "individual_reasoning": [meta.get("reasoning", "") for _, meta in results],
            }

        # ---------------------------
        # CODE-EMBEDDING BUG SIMILARITY (optional auxiliary reward)
        # ---------------------------
        code_embed_score = 0.0
        code_embed_meta: Dict[str, Any] = {}

        # only reward generator when it generated the bug
        use_code_embedding = bool(self.use_code_embedding_similarity) and pregenerated_bug is None
        if use_code_embedding:
            # If bug is invalid, we usually don't want style reward to override "invalid" penalty.
            if bug_valid:
                loop = asyncio.get_running_loop()
                problem = _get_problem(task)

                # run in threadpool to avoid blocking event loop (Voyage/local HF can be slow)
                code_embed_score, code_embed_meta = await loop.run_in_executor(
                    self.executor,
                    lambda: self._score_code_embedding_aux_sync(problem, buggy_code),
                )
            else:
                code_embed_score, code_embed_meta = 0.0, {"skipped": True, "reason": "bug_invalid"}

        # ---------------------------
        # ASSIGN GENERATOR RL REWARD (TRAINING SIGNAL)
        # ---------------------------
        if pregenerated_bug is not None:
            # No generator step happened, so there's no generator RL signal to assign.
            generator_reward = 0.0
            generator_base_reward = 0.0
        else:
            generator_base_reward = _shaped_generator_reward_from_solve_rate(
                bug_valid=bug_valid,
                solve_rate=solve_rate,
                mode=self.generator_reward_mode,
                band_low=self.solve_rate_band_low,
                band_high=self.solve_rate_band_high,
                alpha_extreme=self.gen_alpha_extreme,
                invalid_bug_reward=self.gen_invalid_bug_reward,
            )

            # Combine base reward with optional auxiliary similarity rewards (judge + embeddings)
            aux_scores: List[float] = []
            aux_weights: List[float] = []

            if use_judge:
                aux_scores.append(float(bug_similarity_score))
                aux_weights.append(float(self.bug_similarity_reward_weight))

            if use_code_embedding:
                aux_scores.append(float(code_embed_score))
                aux_weights.append(float(self.code_embedding_reward_weight))

            weight_sum = float(sum(aux_weights))
            if aux_scores and weight_sum > 0:
                # The auxiliary share of the final reward is capped at 1.0;
                # within that share the scores are weight-averaged.
                w_total = min(1.0, weight_sum)
                aux_avg = float(sum(w * s for w, s in zip(aux_weights, aux_scores)) / weight_sum)
                generator_reward = (1.0 - w_total) * generator_base_reward + w_total * aux_avg
            else:
                generator_reward = generator_base_reward

        if bug_step is not None:
            bug_step.reward = float(generator_reward)

        # ---------------------------
        # BUILD EPISODE
        # ---------------------------
        trajectories: List[Trajectory] = []
        if bug_traj is not None:
            trajectories.append(bug_traj)
        trajectories.extend(solver_trajs)
        if codegen_traj is not None:
            trajectories.append(codegen_traj)

        metrics: Dict[str, Any] = {"solver_solve_rate": float(solve_rate)}
        if self.evaluate_codegen:
            metrics["solver_codegen_pass"] = float(codegen_pass)

        # Training-only embedding metrics. These are independent of
        # evaluate_codegen; nesting them under that flag would silently drop
        # them whenever codegen evaluation is turned off.
        if not is_validation and self.use_code_embedding_similarity and code_embed_meta:
            metrics["code_embed_score"] = float(code_embed_score)
            # helpful extras if present
            for meta_key in ("target_score", "negative_score", "margin"):
                if meta_key in code_embed_meta:
                    metrics[f"code_embed_{meta_key}"] = float(code_embed_meta[meta_key])

        episode = Episode(
            id=uid,
            task=task,
            trajectories=trajectories,
            is_correct=False,  # set below
            metrics=metrics,
        )

        self._set_episode_is_correct(
            episode,
            is_validation=is_validation,
            bug_valid=bug_valid,
            solver_pass_any=solver_pass_any,
            solver_pass_all=solver_pass_all,
            codegen_pass=codegen_pass,
        )

        self.assign_episode_correctness(episode)
        return episode

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Do nothing: ``is_correct`` is set explicitly in ``_set_episode_is_correct``."""
        return None
