"""
Async-safe evaluation version of GeneratorFixerWorkflow.

This file wraps blocking reward_function calls in run_in_executor to avoid
blocking the asyncio event loop during parallel evaluation.

Use this for evaluation scripts (run_generator_fixer_flow.py).
Training code should continue using generator_fixer_flow.py.
"""
from __future__ import annotations

import asyncio
import functools
import hashlib
import math
import random
from typing import Any, Dict, List, Optional, Tuple

from concurrent.futures import ThreadPoolExecutor

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import (
    _extract_failed_test_output,
)
from examples.bugs.code_embedding import (
    CodeEmbeddingConfig,
    CodeEmbedder,
    KNNBugSimilarity,
    ReferencePool,
)
from examples.bugs_refactor.components import (
    BugGenerator,
    BugGeneratorConfig,
    BugFixer,
    BugFixerConfig,
)
from examples.bugs_refactor.utils import (
    normalize_task_info,
    check_bug_validity,
    _get_problem,
    _get_reference_solution,
    _get_pregenerated_bug,
    _shaped_generator_reward_from_solve_rate,
)


class GeneratorFixerWorkflowEval(Workflow):
    """Async-safe evaluation version of GeneratorFixerWorkflow.

    This version wraps blocking reward_function calls in run_in_executor
    to prevent blocking the asyncio event loop during parallel evaluation.

    Train-time: generator creates bug -> fixer fixes it (K attempts for solve-rate).
    Validation-time (BugBench style): if task contains a pre-generated bug
    (buggy_solution/buggy/buggy_code), skip generation and evaluate fixer on that held-out bug.
    """

    # Recognised values of ``episode_success_mode``; anything else is treated as "bugfix".
    _EPISODE_SUCCESS_MODES = ("bugfix", "codegen")

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        generator_system_prompt: Optional[str] = None,
        fixer_system_prompt: Optional[str] = None,
        evaluate_codegen: bool = True,

        # Validation behavior
        use_pregenerated_bugs_in_validation: bool = True,

        # Training behavior: optionally use human/pregenerated bugs if present on the task.
        # This is useful when you *mix in* a separate bug dataset into the training set.
        use_pregenerated_bugs_in_training: bool = False,
        pregenerated_bug_train_probability: float = 1.0,

        # Defines what Episode.is_correct means (and thus val/pass@k):
        # - "bugfix": success iff (bug_valid and fixer_pass_any)
        # - "codegen": success iff fixer_codegen_pass (validation only; falls back to bugfix on train)
        # Any other value is treated as "bugfix" (a warning is printed at construction).
        episode_success_mode: str = "bugfix",

        # Solve-rate evaluation knobs: number of independent fixer attempts per bug.
        fixer_attempts_train: int = 8,
        fixer_attempts_val: int = 1,

        # Generator reward shaping knobs (SSR-like)
        generator_reward_mode: str = "band",  # "band" | "smooth" | "binary"
        solve_rate_band_low: float = 0.05,
        solve_rate_band_high: float = 0.25,
        gen_alpha_extreme: float = 0.2,        # penalty when solve_rate in {0,1}
        gen_invalid_bug_reward: float = -1.0,  # penalty for invalid bug

        # Fixer reward style
        fixer_reward_pm1: bool = False,       # False => {0,1}, True => {-1,+1}

        # Include failed test output in fixer prompts
        include_failed_test_output: bool = True,

        # Code-embedding similarity (optional auxiliary reward)
        use_code_embedding_similarity: bool = False,
        code_embedding_reward_weight: float = 0.3,
        code_embedding_model_name: str = "voyage-code-3",
        code_embedding_embed_mode: str = "diff",  # "diff" | "buggy"
        code_embedding_include_problem: bool = False,
        code_embedding_top_k: int = 5,

        # Pools: you can pass tasks directly OR load pools from disk OR pass pre-built pools.
        # If no target-pool source is given, the target pool stays empty and
        # code-embedding similarity is disabled.
        code_embedding_reference_bugs: Optional[List[Dict[str, Any]]] = None,
        code_embedding_negative_bugs: Optional[List[Dict[str, Any]]] = None,
        code_embedding_target_pool_path: Optional[str] = None,
        code_embedding_negative_pool_path: Optional[str] = None,
        # Pre-built pool objects (avoids rebuilding for each parallel workflow instance)
        code_embedding_target_pool: Optional[Any] = None,
        code_embedding_negative_pool: Optional[Any] = None,

        # If negative pool exists, use margin = target - negative and sigmoid it
        code_embedding_use_margin: bool = True,
        code_embedding_margin_temperature: float = 10.0,

        **kwargs,
    ):
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.generator = BugGenerator(rollout_engine, BugGeneratorConfig(system_prompt=generator_system_prompt))
        self.fixer = BugFixer(
            rollout_engine,
            BugFixerConfig(
                system_prompt=fixer_system_prompt,
                include_failed_test_output=bool(include_failed_test_output),
            ),
        )

        self.evaluate_codegen = bool(evaluate_codegen)
        self.use_pregenerated_bugs_in_validation = bool(use_pregenerated_bugs_in_validation)
        self.use_pregenerated_bugs_in_training = bool(use_pregenerated_bugs_in_training)
        self.pregenerated_bug_train_probability = float(pregenerated_bug_train_probability)
        self.episode_success_mode = str(episode_success_mode).lower().strip()
        if self.episode_success_mode not in self._EPISODE_SUCCESS_MODES:
            print(
                f"[GeneratorFixerWorkflowEval] WARNING: unknown episode_success_mode="
                f"{self.episode_success_mode!r}; treating it as 'bugfix'."
            )
        elif self.episode_success_mode == "codegen" and not self.evaluate_codegen:
            # codegen_pass is only ever computed when evaluate_codegen is on.
            print(
                "[GeneratorFixerWorkflowEval] WARNING: episode_success_mode='codegen' but "
                "evaluate_codegen=False; validation episodes will never be marked correct."
            )

        self.fixer_attempts_train = max(1, int(fixer_attempts_train))
        self.fixer_attempts_val = max(1, int(fixer_attempts_val))

        self.generator_reward_mode = str(generator_reward_mode)
        self.solve_rate_band_low = float(solve_rate_band_low)
        self.solve_rate_band_high = float(solve_rate_band_high)
        self.gen_alpha_extreme = float(gen_alpha_extreme)
        self.gen_invalid_bug_reward = float(gen_invalid_bug_reward)

        self.fixer_reward_pm1 = bool(fixer_reward_pm1)

        # ---------------------------
        # Code-embedding similarity (aux reward)
        # ---------------------------
        self.use_code_embedding_similarity = bool(use_code_embedding_similarity)
        self.code_embedding_reward_weight = float(code_embedding_reward_weight)
        self.code_embedding_use_margin = bool(code_embedding_use_margin)
        self.code_embedding_margin_temperature = float(code_embedding_margin_temperature)

        self._code_embedder: Optional[Any] = None
        self._code_knn_target: Optional[Any] = None
        self._code_knn_negative: Optional[Any] = None

        if not self.use_code_embedding_similarity:
            return

        # Defensive: only relevant if the code_embedding imports are ever made optional.
        if CodeEmbeddingConfig is None or CodeEmbedder is None or KNNBugSimilarity is None or ReferencePool is None:
            print("[CodeEmbedding] WARNING: code_embedding deps not available. Disabling.")
            self.use_code_embedding_similarity = False
            return

        top_k = int(code_embedding_top_k)
        emb_cfg = CodeEmbeddingConfig(
            enabled=True,
            reward_weight=self.code_embedding_reward_weight,
            model_name=str(code_embedding_model_name),
            embed_mode=str(code_embedding_embed_mode),
            include_problem=bool(code_embedding_include_problem),
            top_k=top_k,
        )
        self._code_embedder = CodeEmbedder(emb_cfg)

        # Target pool: 1) pre-built pool object, 2) load from disk, 3) build from tasks.
        self._code_knn_target = KNNBugSimilarity(self._code_embedder, top_k=top_k)
        self._populate_knn_pool(
            self._code_knn_target,
            label="TARGET",
            prebuilt_pool=code_embedding_target_pool,
            pool_path=code_embedding_target_pool_path,
            bugs=code_embedding_reference_bugs,
            build_method="build_target_pool",
        )

        # Negative pool (optional), same priority order.
        if (
            code_embedding_negative_pool is not None
            or code_embedding_negative_pool_path
            or code_embedding_negative_bugs
        ):
            self._code_knn_negative = KNNBugSimilarity(self._code_embedder, top_k=top_k)
            self._populate_knn_pool(
                self._code_knn_negative,
                label="NEGATIVE",
                prebuilt_pool=code_embedding_negative_pool,
                pool_path=code_embedding_negative_pool_path,
                bugs=code_embedding_negative_bugs,
                build_method="build_negative_pool",
            )

        # Final sanity: without a non-empty target pool there is nothing to score against.
        if self._safe_len(getattr(self._code_knn_target, "target_pool", None)) == 0:
            print("[CodeEmbedding] WARNING: Enabled but TARGET pool is empty. Disabling.")
            self.use_code_embedding_similarity = False

    @classmethod
    def _populate_knn_pool(
        cls,
        knn: Any,
        *,
        label: str,
        prebuilt_pool: Optional[Any],
        pool_path: Optional[str],
        bugs: Optional[List[Dict[str, Any]]],
        build_method: str,
    ) -> None:
        """Fill ``knn.target_pool`` from the first available source.

        Priority: 1) ``prebuilt_pool`` object, 2) ``ReferencePool.load(pool_path)``,
        3) ``getattr(knn, build_method)(list(bugs))``. Load/build failures are
        printed as warnings instead of raised, so the caller can decide what an
        empty pool means.
        """
        if prebuilt_pool is not None:
            knn.target_pool = prebuilt_pool
            print(f"[CodeEmbedding] Using pre-built {label} pool (n={cls._safe_len(prebuilt_pool)})")
        elif pool_path:
            try:
                pool = ReferencePool.load(str(pool_path))
                knn.target_pool = pool
                print(f"[CodeEmbedding] Loaded {label} pool from {pool_path} (n={cls._safe_len(pool)})")
            except Exception as e:
                print(f"[CodeEmbedding] WARNING: Failed loading {label} pool: {e}")
        elif bugs:
            try:
                getattr(knn, build_method)(list(bugs))
                built = getattr(knn, "target_pool", None)
                print(f"[CodeEmbedding] Built {label} pool (n={cls._safe_len(built)})")
            except Exception as e:
                print(f"[CodeEmbedding] WARNING: Failed building {label} pool: {e}")

    def _normalize_task_info(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Delegate to the shared ``normalize_task_info`` helper."""
        return normalize_task_info(task)

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Logistic function; only exponentiates non-positive values, so math.exp cannot overflow."""
        if x >= 0:
            z = math.exp(-x)
            return 1.0 / (1.0 + z)
        z = math.exp(x)
        return z / (1.0 + z)

    @staticmethod
    def _safe_len(x: Any) -> int:
        """Return ``len(x)``, or 0 when ``x`` is None or has no usable length."""
        if x is None:
            return 0
        try:
            return len(x)
        except Exception:
            return 0

    def _score_code_embedding_aux_sync(self, problem: str, buggy_code: str) -> Tuple[float, Dict[str, Any]]:
        """Score ``buggy_code`` against the embedding pools (blocking; run it in an executor).

        Returns:
          (aux_reward, metadata). With a non-empty negative pool and margin mode enabled,
          aux_reward = sigmoid(temperature * (target_score - negative_score)), which lies in (0, 1).
          Otherwise it is the raw target-pool score; its range depends on
          ``KNNBugSimilarity.score_similarity`` (presumably [0, 1] — TODO confirm).
          Returns (0.0, {"disabled": True}) when similarity scoring is off.
        """
        if not self.use_code_embedding_similarity or self._code_knn_target is None:
            return 0.0, {"disabled": True}

        target_score, target_meta = self._code_knn_target.score_similarity(problem, buggy_code)

        neg_pool = getattr(self._code_knn_negative, "target_pool", None) if self._code_knn_negative is not None else None
        if self.code_embedding_use_margin and self._code_knn_negative is not None and self._safe_len(neg_pool) > 0:
            neg_score, neg_meta = self._code_knn_negative.score_similarity(problem, buggy_code)
            margin = float(target_score - neg_score)
            aux = float(self._sigmoid(float(self.code_embedding_margin_temperature) * margin))
            return aux, {
                "target_score": float(target_score),
                "negative_score": float(neg_score),
                "margin": float(margin),
                "margin_temperature": float(self.code_embedding_margin_temperature),
                "target_meta": target_meta,
                "negative_meta": neg_meta,
                "used_margin": True,
            }

        # Fallback: just use the target score.
        return float(target_score), {
            "target_score": float(target_score),
            "target_meta": target_meta,
            "used_margin": False,
        }

    def _run_reward_sync(self, task_info: Dict[str, Any], action: str) -> RewardOutput:
        """Synchronous wrapper for reward function (to be run in executor)."""
        return self.reward_function(task_info=task_info, action=action)

    async def _reward_async(
        self,
        loop: asyncio.AbstractEventLoop,
        task_info: Dict[str, Any],
        code: str,
    ) -> RewardOutput:
        """Run the blocking reward function on ``self.executor`` without blocking ``loop``.

        Exceptions raised by the reward function propagate to the caller.
        """
        return await loop.run_in_executor(
            self.executor,
            functools.partial(self._run_reward_sync, task_info, code),
        )

    def _select_pregenerated_bug(self, task: Dict[str, Any], uid: str, is_validation: bool) -> Optional[str]:
        """Return the task's pre-generated bug if this episode should use it, else None.

        Validation always uses an available pre-generated bug when enabled. Training uses it
        with probability ``pregenerated_bug_train_probability``. The draw is seeded from an MD5
        of ``uid``, so it is deterministic per episode and does not touch the global RNG.
        """
        allowed = (
            self.use_pregenerated_bugs_in_validation if is_validation else self.use_pregenerated_bugs_in_training
        )
        if not allowed:
            return None

        cand = _get_pregenerated_bug(task)
        if cand is None or is_validation:
            return cand

        p = max(0.0, min(1.0, float(self.pregenerated_bug_train_probability)))
        if p >= 1.0:
            return cand
        if p <= 0.0:
            return None
        seed = int(hashlib.md5(uid.encode("utf-8")).hexdigest(), 16) % (2**32)
        return cand if random.Random(seed).random() < p else None

    def _fixer_reward(self, passed: bool) -> float:
        """Per-attempt fixer reward: 1.0 on pass, else -1.0 (pm1 style) or 0.0."""
        if passed:
            return 1.0
        return -1.0 if self.fixer_reward_pm1 else 0.0

    def _compute_generator_reward(
        self,
        bug_valid: bool,
        solve_rate: float,
        code_embed_score: float,
        code_embed_scored: bool,
    ) -> float:
        """Shaped generator reward from the solve rate, optionally blended with embedding similarity.

        The blend ``(1 - w) * base + w * score`` (w clamped to [0, 1]) applies whenever the
        similarity was actually computed for this bug, regardless of the score's value. This
        avoids a jump in reward between a score of exactly 0 and a tiny positive score.
        """
        base = _shaped_generator_reward_from_solve_rate(
            bug_valid=bug_valid,
            solve_rate=solve_rate,
            mode=self.generator_reward_mode,
            band_low=self.solve_rate_band_low,
            band_high=self.solve_rate_band_high,
            alpha_extreme=self.gen_alpha_extreme,
            invalid_bug_reward=self.gen_invalid_bug_reward,
        )
        if not (self.use_code_embedding_similarity and code_embed_scored):
            return base
        w = max(0.0, min(1.0, float(self.code_embedding_reward_weight)))
        return (1.0 - w) * base + w * float(code_embed_score)

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Run one generator -> fixer episode and return it with rewards, metrics, and correctness.

        Steps:
          1. Pick the bug: a pre-generated one from the task, or a fresh generator rollout.
          2. Check bug validity with the reward function.
          3. If the bug is valid, run K fixer attempts concurrently and evaluate each one.
          4. Optionally run a plain code-generation eval (validation only).
          5. Optionally score a generated, valid bug with code-embedding similarity.
          6. Assign per-step rewards.

        All blocking reward and embedding calls run on ``self.executor``.
        """
        self.reset(task, uid)

        is_validation = bool(getattr(self.rollout_engine, "validate", False))
        task_info = self._normalize_task_info(task)
        loop = asyncio.get_running_loop()

        pregenerated_bug = self._select_pregenerated_bug(task, uid, is_validation)

        # ---------------------------
        # BUG SOURCE
        # ---------------------------
        bug_traj: Optional[Trajectory] = None
        bug_step: Optional[Step] = None
        if pregenerated_bug is None:
            bug_traj = await self.generator.generate_bug(task)
            bug_step = bug_traj.steps[0]
            buggy_code = bug_step.action
        else:
            buggy_code = pregenerated_bug

        # Short stable id for debugging; tolerate a missing action instead of crashing here.
        bug_id = hashlib.md5(str(buggy_code or "").encode("utf-8")).hexdigest()[:12]

        # ---------------------------
        # BUG VALIDITY (via reward_fn)
        # ---------------------------
        try:
            bug_reward_output = await self._reward_async(loop, task_info, buggy_code)
        except Exception as e:
            bug_reward_output = RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={"error": f"bug reward error: {e}"},
            )

        bug_meta = bug_reward_output.metadata or {}
        bug_valid, _ = check_bug_validity(
            bug_meta=bug_meta,
            bug_reward_output=bug_reward_output,
            compile_errors_invalid=True,
        )
        total_tests = bug_meta.get("total_tests")
        passed_tests = bug_meta.get("passed_tests")

        failed_test_output: Optional[str] = None
        if self.fixer.config.include_failed_test_output:
            failed_test_output = _extract_failed_test_output(bug_meta)

        # ---------------------------
        # SOLVER FIX (K attempts -> solve_rate)
        # ---------------------------
        fixer_attempts = self.fixer_attempts_val if is_validation else self.fixer_attempts_train
        fixer_trajs: List[Trajectory] = []
        fixer_passes: List[bool] = []

        # Only run the fixer on valid bugs, as judged by check_bug_validity.
        if bug_valid:
            fixer_trajs = list(
                await asyncio.gather(
                    *(
                        self.fixer.fix_bug(
                            task,
                            buggy_code=buggy_code,
                            failed_test_output=failed_test_output,
                        )
                        for _ in range(fixer_attempts)
                    )
                )
            )

            async def _fix_passes(fixed_code: str) -> bool:
                # Best-effort: a reward-function failure counts as a failed fix, not a failed episode.
                try:
                    out = await self._reward_async(loop, task_info, fixed_code)
                    return bool(out.is_correct)
                except Exception:
                    return False

            fixer_passes = list(await asyncio.gather(*(_fix_passes(t.steps[0].action) for t in fixer_trajs)))

        num_pass = sum(1 for passed in fixer_passes if passed)
        solve_rate = float(num_pass) / float(fixer_attempts) if fixer_attempts > 0 else 0.0
        fixer_pass_any = num_pass > 0
        fixer_pass_all = fixer_attempts > 0 and num_pass == fixer_attempts

        # Each attempt is rewarded on whether that specific attempt passed.
        for traj, passed in zip(fixer_trajs, fixer_passes):
            traj.steps[0].reward = self._fixer_reward(passed)

        # ---------------------------
        # OPTIONAL CODEGEN EVAL (VAL ONLY)
        # ---------------------------
        codegen_pass = False
        if self.evaluate_codegen and is_validation:
            codegen_traj = await self.fixer.generate_code(task)
            codegen_step = codegen_traj.steps[0]
            try:
                codegen_reward_output = await self._reward_async(loop, task_info, codegen_step.action)
                codegen_pass = bool(codegen_reward_output.is_correct)
            except Exception:
                codegen_pass = False
            # No training reward for codegen eval
            codegen_step.reward = 0.0

        # ---------------------------
        # CODE-EMBEDDING BUG SIMILARITY (optional auxiliary reward)
        # ---------------------------
        code_embed_score = 0.0
        code_embed_meta: Dict[str, Any] = {}
        code_embed_scored = False

        # Only score bugs the generator produced; pre-generated bugs carry no generator signal.
        if self.use_code_embedding_similarity and pregenerated_bug is None:
            if bug_valid:
                problem = _get_problem(task)
                try:
                    # Embedding backends (remote API / local model) can be slow: keep off the loop.
                    code_embed_score, code_embed_meta = await loop.run_in_executor(
                        self.executor,
                        functools.partial(self._score_code_embedding_aux_sync, problem, buggy_code),
                    )
                    code_embed_scored = not code_embed_meta.get("disabled", False)
                except Exception as e:
                    print(f"[CodeEmbedding] WARNING: similarity scoring failed for bug {bug_id}: {e}")
                    code_embed_score = 0.0
                    code_embed_meta = {"skipped": True, "reason": "embedding_error", "error": str(e)}
            else:
                # Don't let a similarity bonus soften the invalid-bug penalty.
                code_embed_score, code_embed_meta = 0.0, {"skipped": True, "reason": "bug_invalid"}

        # ---------------------------
        # ASSIGN GENERATOR RL REWARD (TRAINING SIGNAL)
        # ---------------------------
        # bug_step exists only when the generator ran (no pre-generated bug).
        if bug_step is not None:
            bug_step.reward = float(
                self._compute_generator_reward(bug_valid, solve_rate, code_embed_score, code_embed_scored)
            )

        # ---------------------------
        # BUILD EPISODE
        # ---------------------------
        trajectories: List[Trajectory] = []
        if bug_traj is not None:
            trajectories.append(bug_traj)
        trajectories.extend(fixer_trajs)

        metrics: Dict[str, Any] = {
            "fixer_pass": float(solve_rate),
            "bug_valid": float(bug_valid),
        }
        if total_tests is not None:
            metrics["bug_total_tests"] = int(total_tests)
        if passed_tests is not None:
            metrics["bug_passed_tests"] = int(passed_tests)

        # Pre-generated bug => no generator step => report 0.0.
        metrics["generator_reward"] = float(bug_step.reward) if bug_step is not None else 0.0

        if self.evaluate_codegen and is_validation:
            metrics["fixer_codegen_pass"] = float(codegen_pass)

        if self.use_code_embedding_similarity and code_embed_meta:
            metrics["code_embed_score"] = float(code_embed_score)
            if "target_score" in code_embed_meta:
                metrics["code_embed_target_score"] = float(code_embed_meta["target_score"])
            if "negative_score" in code_embed_meta:
                metrics["code_embed_negative_score"] = float(code_embed_meta["negative_score"])
            if "margin" in code_embed_meta:
                metrics["code_embed_margin"] = float(code_embed_meta["margin"])

        # Inputs for assign_episode_correctness, stored so correctness can be recomputed later.
        episode_info: Dict[str, Any] = {
            "fixer_pass_any": bool(fixer_pass_any),
            "fixer_pass_all": bool(fixer_pass_all),
            "is_validation": bool(is_validation),
            "bug_valid": bool(bug_valid),
            "codegen_pass": bool(codegen_pass),
            "workflow": "generator_fixer",
            "bug_id": bug_id,
        }

        episode = Episode(
            id=uid,
            task=task,
            trajectories=trajectories,
            is_correct=False,  # set below
            metrics=metrics,
            info=episode_info,
        )

        self.assign_episode_correctness(episode)
        return episode

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Set ``episode.is_correct`` from the flags stored in ``episode.info``.

        - Validation with mode "codegen": correct iff the plain code-generation attempt passed.
        - Otherwise (training, mode "bugfix", or any unrecognised mode): correct iff the bug
          was valid and at least one fixer attempt passed.
        """
        info = episode.info or {}
        is_validation = bool(info.get("is_validation", False))

        if is_validation and self.episode_success_mode == "codegen":
            episode.is_correct = bool(info.get("codegen_pass", False))
        else:
            episode.is_correct = bool(info.get("bug_valid", False)) and bool(info.get("fixer_pass_any", False))