from __future__ import annotations

import hashlib
import os
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from concurrent.futures import ThreadPoolExecutor

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine, OpenAIEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import _build_bug_generator_prompt, _build_bug_fixer_prompt
from examples.bugs.generator_flow import check_bug_validity


def _get_pregenerated_bug(task: Dict[str, Any]) -> Optional[str]:
    """Use held-out (pre-generated) buggy code when present (e.g., BugBench)."""
    for k in ("buggy_solution", "buggy", "buggy_code", "bug"):
        v = task.get(k, None)
        if isinstance(v, str) and v.strip():
            return v
    return None


@dataclass
class BugGeneratorConfig:
    """Settings for :class:`BugGenerator`."""

    # Optional system message placed before the user prompt; omitted when falsy.
    system_prompt: Optional[str] = None


@dataclass
class BugFixerConfig:
    """Settings for :class:`BugFixer`."""

    # Optional system message placed before the user prompt; omitted when falsy.
    system_prompt: Optional[str] = None


class BugGenerator:
    """Static bug generator (frozen) that produces buggy code from a correct reference solution."""

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[BugGeneratorConfig] = None):
        """Store the engine used for generation and the (optional) generator config.

        Args:
            rollout_engine: Engine whose ``get_model_response`` is awaited to produce bugs.
            config: Generator settings; a default ``BugGeneratorConfig`` is used when omitted.
        """
        self.rollout_engine = rollout_engine
        self.config = config or BugGeneratorConfig()

    @staticmethod
    def _extract_outputs(model_output: Any) -> Tuple[str, str]:
        """Derive ``(buggy_code, raw_response)`` from a model output.

        ``raw_response`` is ``text`` when truthy, otherwise ``content``, otherwise ``""``.
        ``buggy_code`` is ``content``; when the engine left ``content`` as ``None`` the
        raw response is used instead so callers always receive a string. An empty-string
        ``content`` is kept as-is.
        """
        raw_response = model_output.text or model_output.content or ""

        # We pass through `content` as the action for reward evaluation.
        # The code reward function typically extracts ```python``` blocks.
        buggy_code = model_output.content
        if buggy_code is None:
            # Never hand `None` to the reward function / fixer prompt.
            buggy_code = raw_response
        return buggy_code, raw_response

    async def generate_bug(self, task: Dict[str, Any], uid: str) -> Tuple[str, str]:
        """Ask the generator model for a buggy variant of the task's reference solution.

        Args:
            task: Mapping that must contain ``"question"`` and ``"reference_solution"``.
            uid: Episode identifier; accepted for interface symmetry, not used here.

        Returns:
            (buggy_code, raw_response)

        Raises:
            KeyError: If ``task`` lacks ``"question"`` or ``"reference_solution"``.
        """
        problem = task["question"]
        correct_code = task["reference_solution"]
        prompt = _build_bug_generator_prompt(problem, correct_code)

        messages: List[Dict[str, str]] = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": prompt})

        model_output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        return self._extract_outputs(model_output)


class BugFixer:
    """Trainable solver (driven by the workflow's rollout_engine) that repairs buggy code."""

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[BugFixerConfig] = None):
        """Keep the solver engine and its config (a default config is used when omitted)."""
        if not config:
            config = BugFixerConfig()
        self.rollout_engine = rollout_engine
        self.config = config

    @staticmethod
    def _split_think(raw_response: str) -> tuple[str, str]:
        """Separate a `<think>...</think>` preamble from the final answer (common for OSS models).

        The split only happens when `</think>` occurs exactly once. In that case the
        result is ``(text up to and including the tag, text after the tag)``, both
        stripped. Otherwise the thought is ``""`` and the whole stripped response is
        treated as the answer.
        """
        closing_tag = "</think>"
        if raw_response.count(closing_tag) != 1:
            return "", raw_response.strip()
        reasoning, _, answer = raw_response.partition(closing_tag)
        return (reasoning + closing_tag).strip(), answer.strip()

    async def fix_bug(self, task: Dict[str, Any], buggy_code: str, uid: str) -> Trajectory:
        """Query the solver for a fix and wrap the exchange in a one-step trajectory.

        Args:
            task: Mapping that must contain ``"question"``.
            buggy_code: Code shown to the solver alongside the problem statement.
            uid: Episode identifier; accepted but not used in this method.

        Returns:
            A ``Trajectory`` named ``"solver"`` holding a single ``Step`` whose
            ``action`` is the response with any think-block removed.
        """
        user_turn = {"role": "user", "content": _build_bug_fixer_prompt(task["question"], buggy_code)}
        system_prompt = self.config.system_prompt
        messages: List[Dict[str, str]] = (
            [{"role": "system", "content": system_prompt}, user_turn] if system_prompt else [user_turn]
        )

        model_output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        raw_response = model_output.text or model_output.content or ""
        thought, action = self._split_think(raw_response)

        # IMPORTANT: only the *solver* exchange goes into the trajectory, so the episode
        # never mixes tokenizations / policies with the static generator model.
        assistant_turn = {"role": "assistant", "content": action}
        return Trajectory(
            name="solver",
            steps=[
                Step(
                    chat_completions=[*messages, assistant_turn],
                    thought=thought,
                    action=action,
                    model_response=raw_response,
                    model_output=model_output,
                )
            ],
        )


class SolverWorkflow(Workflow):
    """Train a solver to repair bugs that come from a separate, static generator model.

    One episode proceeds as follows:
      1) Buggy code is obtained: either a pregenerated bug found in the task, or one
         produced by the static generator from (problem, reference_solution).
      2) The reward function scores the buggy code (used for metrics and for the
         "bug was actually incorrect" condition below).
      3) The trainable solver sees (problem, buggy_code) and proposes a fix.
      4) The reward function scores the fix. Solver reward is 1 iff
           - the buggy code was judged incorrect, AND
           - the fixed code is judged correct.

    Training integration notes:
      - Only the solver trajectory is returned in the Episode so that Verl tokenization
        uses the trainable solver policy/tokenizer.
      - The buggy code is kept in ``episode.info`` for debugging.
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        # --- static generator ---
        generator_rollout_engine: Optional[RolloutEngine] = None,
        generator_model: Optional[str] = None,
        generator_base_url: Optional[str] = None,
        generator_api_key: Optional[str] = None,
        generator_temperature: float = 0.6,
        generator_top_p: float = 0.95,
        generator_system_prompt: Optional[str] = None,
        # --- trainable solver ---
        solver_system_prompt: Optional[str] = None,
        # --- bug validity policy ---
        compile_errors_invalid: bool = True,

        # --- pregenerated bugs during training ---
        # Allows mixing in a separate "human bugs" dataset while still being able to
        # stochastically ignore pregenerated bugs found in task_info.
        use_pregenerated_bugs_in_training: bool = True,
        pregenerated_bug_train_probability: float = 1.0,
        **kwargs,
    ):
        """Wire up the reward function, the static generator and the trainable solver.

        A generator engine is taken from ``generator_rollout_engine`` when given;
        otherwise, if ``generator_model`` is truthy, an ``OpenAIEngine`` is built from
        the ``generator_*`` arguments. The API key falls back to the ``OPENAI_API_KEY``
        environment variable when ``generator_api_key`` is ``None``; an empty key is
        replaced by ``"EMPTY"`` for non-OpenAI base URLs.

        Raises:
            ValueError: If the OpenAI API is targeted without an API key, or if no
                generator engine could be obtained at all.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.compile_errors_invalid = bool(compile_errors_invalid)
        self.use_pregenerated_bugs_in_training = bool(use_pregenerated_bugs_in_training)
        self.pregenerated_bug_train_probability = float(pregenerated_bug_train_probability)

        # Build / attach the static generator engine (OpenAI-compatible).
        generator_engine = generator_rollout_engine
        if generator_engine is None and generator_model:
            base_url = str(generator_base_url or "https://api.openai.com/v1")

            key_source = generator_api_key
            if key_source is None:
                key_source = os.getenv("OPENAI_API_KEY", "")
            api_key = (key_source or "").strip()

            if not api_key:
                if "api.openai.com" in base_url:
                    raise ValueError(
                        "generator_model was provided with generator_base_url pointing to the OpenAI API, "
                        "but OPENAI_API_KEY is missing/empty. Please export OPENAI_API_KEY (or pass generator_api_key)."
                    )
                # Placeholder key for endpoints other than the OpenAI API.
                api_key = "EMPTY"

            generator_engine = OpenAIEngine(
                model=str(generator_model),
                tokenizer=None,  # static generator; use chat-completions endpoint
                base_url=base_url,
                api_key=api_key,
                sampling_params={
                    "temperature": float(generator_temperature),
                    "top_p": float(generator_top_p),
                },
                # No tokenizer => cannot enforce these, but keep for symmetry / logs.
                max_prompt_length=getattr(rollout_engine, "max_prompt_length", 8192),
                max_response_length=getattr(rollout_engine, "max_response_length", 8192),
            )

        if generator_engine is None:
            raise ValueError(
                "SolverWorkflow requires a static generator. Provide either generator_rollout_engine or generator_model."
            )

        self.generator = BugGenerator(
            generator_engine,
            BugGeneratorConfig(system_prompt=generator_system_prompt),
        )
        self.solver = BugFixer(rollout_engine, BugFixerConfig(system_prompt=solver_system_prompt))

    def _select_pregenerated_bug(self, candidate: Optional[str], uid: str, is_validation: bool) -> Optional[str]:
        """Decide whether the task's pregenerated bug (if any) is used for this episode.

        Validation always uses an available bug. During training the bug is used only
        when ``use_pregenerated_bugs_in_training`` is set, and then with probability
        ``pregenerated_bug_train_probability`` (clamped to [0, 1]). The random draw is
        seeded from an MD5 hash of ``uid``, so the same uid always yields the same choice.
        """
        if candidate is None:
            return None
        if is_validation:
            return candidate
        if not self.use_pregenerated_bugs_in_training:
            return None

        keep_probability = max(0.0, min(1.0, float(self.pregenerated_bug_train_probability)))
        if keep_probability >= 1.0:
            return candidate
        if keep_probability <= 0.0:
            return None

        seed = int(hashlib.md5(uid.encode("utf-8")).hexdigest(), 16) % (2**32)
        return candidate if random.Random(seed).random() < keep_probability else None

    def _score_code_safely(self, task_info: Dict[str, Any], code: Any, error_label: str) -> RewardOutput:
        """Run the reward function on ``code``; turn any exception into a failing RewardOutput.

        On failure the returned output has reward 0.0, ``is_correct=False`` and
        ``metadata["error"]`` set to ``"<error_label>: <exception>"``.
        """
        try:
            return self.reward_function(task_info=task_info, action=code)
        except Exception as exc:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={"error": f"{error_label}: {exc}"},
            )

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Execute one bug -> fix episode and return it with the solver trajectory only.

        Args:
            task: Task mapping. Its ``"extra_info"`` entry, when present, is used as the
                task info for bug lookup, generation and reward evaluation; otherwise
                ``task`` itself is used.
            uid: Episode identifier; also seeds the pregenerated-bug sampling.

        Returns:
            An ``Episode`` whose ``is_correct`` mirrors the solver reward, with metrics
            about the bug and the fix and the buggy code stored under ``info``.
        """
        self.reset(task, uid)

        is_validation = bool(getattr(self.rollout_engine, "validate", False))

        # Prefer pregenerated bugs (BugBench-style) when available.
        task_info = task.get("extra_info", task)
        pregenerated_bug = self._select_pregenerated_bug(
            _get_pregenerated_bug(task_info), uid, is_validation
        )
        used_pregenerated_bug = pregenerated_bug is not None

        # 1) Obtain buggy code: reuse the pregenerated one or ask the static generator.
        if used_pregenerated_bug:
            buggy_code, bug_raw = pregenerated_bug, "(pregenerated_bug)"
        else:
            buggy_code, bug_raw = await self.generator.generate_bug(task_info, uid)

        # 2) Score the buggy code (metrics / gating).
        bug_reward_output = self._score_code_safely(task_info, buggy_code, "bug reward error")
        bug_meta = bug_reward_output.metadata or {}
        bug_valid, has_compile_error = check_bug_validity(
            bug_meta=bug_meta,
            bug_reward_output=bug_reward_output,
            compile_errors_invalid=self.compile_errors_invalid,
        )
        bug_incorrect = not bool(bug_reward_output.is_correct)

        # 3) Let the trainable solver attempt a fix.
        # NOTE(review): the solver is given `task` while the generator and reward
        # function are given `task_info` — confirm this asymmetry is intended.
        solver_traj = await self.solver.fix_bug(task, buggy_code, uid)
        solver_step = solver_traj.steps[0]

        # 4) Score the proposed fix.
        solver_reward_output = self._score_code_safely(
            task_info, solver_step.action, "solver reward error"
        )
        solver_pass = bool(solver_reward_output.is_correct)

        # A fix only counts when the bug was actually incorrect to begin with.
        solver_reward = float(bug_incorrect and solver_pass)
        solver_step.reward = solver_reward
        solver_traj.reward = solver_reward

        metrics: Dict[str, Any] = {
            "bug_incorrect": float(bug_incorrect),
            "bug_valid": float(bug_valid),
            "bug_has_compile_error": float(has_compile_error),
            "solver_pass": float(solver_pass),
            "solver_reward": solver_reward,
        }
        if used_pregenerated_bug:
            metrics["used_pregenerated_bug"] = 1.0

        # Copy test counts through when the reward metadata provides them.
        fixed_meta = solver_reward_output.metadata or {}
        for metric_name, meta, meta_key in (
            ("bug_total_tests", bug_meta, "total_tests"),
            ("bug_passed_tests", bug_meta, "passed_tests"),
            ("fixed_total_tests", fixed_meta, "total_tests"),
            ("fixed_passed_tests", fixed_meta, "passed_tests"),
        ):
            if meta_key in meta:
                metrics[metric_name] = int(meta[meta_key])

        episode = Episode(
            id=uid,
            task=task,
            trajectories=[solver_traj],
            is_correct=bool(solver_reward),
            metrics=metrics,
            info={
                "buggy_code": buggy_code,
                "bug_generator_raw": bug_raw,
                "bug_source": "pregenerated" if used_pregenerated_bug else "generated",
            },
        )
        self.assign_episode_correctness(episode)
        return episode

    def assign_episode_correctness(self, episode: Episode) -> None:
        """No-op: ``run`` already sets ``episode.is_correct`` when building the Episode."""
        pass


