from __future__ import annotations

import asyncio
import logging
import re
from typing import Any, Dict, List, Optional, Tuple

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import (
    _build_bug_generator_prompt,
    _build_bug_fixer_prompt,
    _extract_failed_test_output,
)
from examples.bugs_refactor.utils import (
    normalize_task_info,
    check_bug_validity,
    _get_problem,
    _get_reference_solution,
)


def _model_text(out: ModelOutput) -> str:
    """Extract text content from ModelOutput."""
    return out.content or ""


class BugGenerator:
    """Ask the model to inject a bug into a task's known-correct reference solution.

    The model's raw reply text is used verbatim as the buggy program; no code
    extraction is performed here.
    """

    def __init__(self, rollout_engine: RolloutEngine, **kwargs):
        # Extra keyword arguments are accepted for interface symmetry and ignored.
        self.rollout_engine = rollout_engine

    async def generate_bug(self, task: Dict[str, Any]) -> Trajectory:
        """Produce a single-step ``bug_generator`` trajectory for ``task``.

        The step's ``action`` is the model's text reply (``""`` if it had none),
        ``thought`` is the model's reasoning, and ``chat_completions`` holds the
        user prompt followed by the assistant turn.

        Raises:
            KeyError: if the task carries no reference solution to corrupt.
        """
        problem = _get_problem(task)
        reference = _get_reference_solution(task)
        if not reference:
            raise KeyError("Task missing reference_solution/canonical_solution required for bug generation.")

        conversation = [{"role": "user", "content": _build_bug_generator_prompt(problem, reference)}]
        response: ModelOutput = await self.rollout_engine.get_model_response(conversation)

        assistant_turn = {"role": "assistant", "content": response.content, "reasoning": response.reasoning}
        step = Step(
            chat_completions=[*conversation, assistant_turn],
            thought=response.reasoning,
            action=_model_text(response),
            model_output=response,
        )
        return Trajectory(name="bug_generator", steps=[step])


class BugFixer:
    """Ask the model to repair a buggy program, optionally showing failing-test output."""

    def __init__(self, rollout_engine: RolloutEngine, include_failed_test_output: bool = True, **kwargs):
        # Extra keyword arguments are accepted for interface symmetry and ignored.
        self.rollout_engine = rollout_engine
        # Forwarded to the prompt builder; controls whether failing-test output is shown.
        self.include_failed_test_output = include_failed_test_output

    async def generate_fix(
        self,
        task: Dict[str, Any],
        buggy_code: str,
        failed_test_output: Optional[str] = None,
    ) -> Trajectory:
        """Request one repair of ``buggy_code`` and wrap it as a ``bug_fixer`` trajectory.

        The step's ``action`` is the model's text reply (``""`` if it had none).
        """
        prompt = _build_bug_fixer_prompt(
            _get_problem(task),
            buggy_code,
            include_failed_test_output=self.include_failed_test_output,
            failed_test_output=failed_test_output,
        )
        conversation = [{"role": "user", "content": prompt}]
        response: ModelOutput = await self.rollout_engine.get_model_response(conversation)

        assistant_turn = {"role": "assistant", "content": response.content, "reasoning": response.reasoning}
        step = Step(
            chat_completions=[*conversation, assistant_turn],
            thought=response.reasoning,
            action=_model_text(response),
            model_output=response,
        )
        return Trajectory(name="bug_fixer", steps=[step])

    async def generate_fixes(
        self,
        task: Dict[str, Any],
        buggy_code: str,
        n_fixes: int = 2,
        failed_test_output: Optional[str] = None,
    ) -> List[Trajectory]:
        """Run ``n_fixes`` independent repair attempts concurrently.

        Returns the trajectories in submission order. If any attempt raises,
        ``asyncio.gather`` propagates the first exception.
        """
        attempts = [self.generate_fix(task, buggy_code, failed_test_output) for _ in range(n_fixes)]
        return await asyncio.gather(*attempts)


class FixJudge:
    """Judges and selects the best fix from multiple attempts."""

    # All <answer>...</answer> spans. The prompt itself shows "<answer>1</answer>"
    # as a format example, so a judge may echo it before giving its real answer.
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL)
    # First integer inside an answer span, so "2", " 2 " and "Fix 2" all parse.
    _INDEX_RE = re.compile(r"\d+")

    def __init__(self, rollout_engine: RolloutEngine, **kwargs):
        # Extra keyword arguments are accepted for interface symmetry and ignored.
        self.rollout_engine = rollout_engine

    async def judge_fixes(self, problem: str, buggy_code: str, fixes: List[str]) -> Trajectory:
        """Ask the model to pick the best of ``fixes``.

        Returns a single-step ``fix_judge`` trajectory whose ``action`` is the
        selected fix's code (see ``_parse_judge_response`` for the fallback).
        """
        messages = [{"role": "user", "content": self._create_judge_prompt(problem, buggy_code, fixes)}]
        output: ModelOutput = await self.rollout_engine.get_model_response(messages)

        return Trajectory(
            name="fix_judge",
            steps=[
                Step(
                    chat_completions=messages + [{"role": "assistant", "content": output.content, "reasoning": output.reasoning}],
                    thought=output.reasoning,
                    # ``output.content`` may be None; the parser tolerates that.
                    action=self._parse_judge_response(output.content, fixes),
                    model_output=output,
                )
            ],
        )

    def _parse_judge_response(self, response: Optional[str], fixes: List[str]) -> str:
        """Return the fix selected by the judge's 1-based ``<answer>`` index.

        Answer tags are examined from last to first, and the first one holding an
        in-range index wins, so an echoed format example earlier in the reply
        does not override the final answer.

        Args:
            response: raw judge text; ``None`` is treated as an empty reply.
            fixes: candidate fixes in the order they were shown to the judge.

        Returns:
            The selected fix, or ``fixes[0]`` when no usable index is found
            (``""`` when ``fixes`` is empty).
        """
        for answer_text in reversed(self._ANSWER_RE.findall(response or "")):
            index_match = self._INDEX_RE.search(answer_text)
            if index_match is None:
                continue
            fix_index = int(index_match.group())
            if 1 <= fix_index <= len(fixes):
                return fixes[fix_index - 1]
        # Fallback: return first fix if parsing fails
        return fixes[0] if fixes else ""

    def _create_judge_prompt(self, problem: str, buggy_code: str, fixes: List[str]) -> str:
        """Build the judge prompt listing the problem, buggy code and numbered fixes (1-based)."""
        prompt = f"""You are an expert code reviewer. Given a programming problem, buggy code, and multiple fix attempts, select the best fix.

Problem:
{problem}

Buggy Code:
```python
{buggy_code}
```

Fix Attempts:
"""
        prompt += "".join(f"\nFix {i}:\n```python\n{fix}\n```\n" for i, fix in enumerate(fixes, 1))

        prompt += """
A correct fix must:
1. Solve the original problem correctly
2. Fix all bugs in the buggy code
3. Handle all edge cases properly
4. Be syntactically correct Python code

Output the index of your selected fix within <answer>...</answer> tags, e.g., <answer>1</answer> for the first fix, <answer>2</answer> for the second fix, etc. If multiple fixes are correct, select the cleanest implementation."""
        return prompt


class GeneratorFixerWorkflow(Workflow):
    """
    Workflow where:
    1. Generator creates a buggy version of correct code
    2. Fixer generates multiple fix attempts in parallel
    3. Judge selects the best fix

    Rewards assigned to trajectory steps:
      * fixer: 1.0 if its fix passes, else 0.0 (-1.0 when ``fixer_reward_pm1``).
        Fixers only run when the generated bug is valid.
      * judge: 1.0 if the fix it selected passes, else 0.0.
      * generator: ``gen_invalid_bug_reward`` for an invalid bug; otherwise a
        band reward on the fixer solve rate: 0.2 when 0% or 100% of fixes pass,
        1.0 when the rate lies inside ``gen_target_band``, 0.5 elsewhere.

    The episode is ``is_correct`` when the judge's selection passes (judge
    enabled and at least one fix exists), otherwise when any fix passes.
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        reward_function: RewardFunction,
        n_fixes: int = 2,
        include_failed_test_output: bool = True,
        use_judge: bool = True,
        fixer_reward_pm1: bool = False,
        gen_invalid_bug_reward: float = -1.0,
        gen_target_band: Tuple[float, float] = (0.05, 0.25),
        **kwargs,
    ):
        """
        Args:
            rollout_engine: engine shared by generator, fixer and judge.
            reward_function: called as ``reward_function(task_info=..., action=code)``
                to score both the generated bug and each fix.
            n_fixes: number of parallel fix attempts per valid bug.
            include_failed_test_output: pass failing-test output of the bug to the fixer.
            use_judge: run the judge over the fix attempts.
            fixer_reward_pm1: use -1.0 instead of 0.0 as the failing-fix reward.
            gen_invalid_bug_reward: generator reward when the bug is invalid.
            gen_target_band: inclusive ``(low, high)`` fixer solve-rate range that
                earns the generator the full 1.0 reward. Solve rates are multiples
                of ``1 / n_fixes``, so the default band is only reachable with
                ``n_fixes >= 4``.
            **kwargs: forwarded to ``Workflow.__init__``.
        """
        super().__init__(rollout_engine=rollout_engine, **kwargs)
        self.reward_function = reward_function
        self.n_fixes = n_fixes
        self.use_judge = use_judge
        self.fixer_reward_pm1 = fixer_reward_pm1
        self.gen_invalid_bug_reward = gen_invalid_bug_reward
        self.gen_target_band = tuple(gen_target_band)

        self.generator = BugGenerator(rollout_engine)
        self.fixer = BugFixer(rollout_engine, include_failed_test_output=include_failed_test_output)
        self.judge = FixJudge(rollout_engine) if use_judge else None

    # NOTE(review): the reward function is invoked synchronously inside the async
    # ``run``; if it executes tests it blocks the event loop for that duration.
    # Consider offloading to an executor once its thread-safety is confirmed.

    def _score_bug(self, task_info: Any, buggy_code: str) -> RewardOutput:
        """Score the generated bug; a reward-function error becomes a failed result annotated with the error."""
        try:
            return self.reward_function(task_info=task_info, action=buggy_code)
        except Exception as e:
            logging.getLogger(__name__).warning("Reward function failed on generated bug: %r", e)
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={"error": f"bug reward error: {e}"},
            )

    def _passes(self, task_info: Any, code: str) -> bool:
        """Return whether ``code`` is judged correct; a reward-function error counts as a failure."""
        try:
            return bool(self.reward_function(task_info=task_info, action=code).is_correct)
        except Exception as e:
            # Best effort: keep the episode alive, but make infrastructure failures visible
            # since they are otherwise indistinguishable from a wrong fix.
            logging.getLogger(__name__).warning("Reward function failed on candidate fix: %r", e)
            return False

    def _generator_reward(self, bug_valid: bool, solve_rate: float) -> float:
        """Band reward for the generator: favour valid bugs that few, but not zero, fixers solve."""
        if not bug_valid:
            return self.gen_invalid_bug_reward
        if solve_rate == 0.0 or solve_rate == 1.0:
            return 0.2  # Penalty for trivial or impossible bugs
        low, high = self.gen_target_band
        if low <= solve_rate <= high:
            return 1.0  # Ideal difficulty range
        return 0.5  # Moderate reward

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Run one generate -> fix -> (judge) episode and return it with rewards and metrics attached."""
        self.reset(task, uid)
        task_info = normalize_task_info(task)
        problem = _get_problem(task)

        # Step 1: Generator creates buggy code, which is scored and validated.
        generator_trajectory = await self.generator.generate_bug(task)
        buggy_code = generator_trajectory.steps[0].action

        bug_reward_output = self._score_bug(task_info, buggy_code)
        bug_meta = bug_reward_output.metadata or {}
        bug_valid, _has_compile_error = check_bug_validity(
            bug_meta=bug_meta,
            bug_reward_output=bug_reward_output,
            compile_errors_invalid=True,
        )

        # Extract failed test output for fixer prompt
        failed_test_output: Optional[str] = None
        if self.fixer.include_failed_test_output:
            failed_test_output = _extract_failed_test_output(bug_meta)

        # Step 2: Fixer generates multiple fixes in parallel (only for valid bugs).
        fixer_trajectories: List[Trajectory] = []
        fixes: List[str] = []
        fixer_passes: List[bool] = []

        if bug_valid:
            fixer_trajectories = await self.fixer.generate_fixes(
                task, buggy_code, n_fixes=self.n_fixes, failed_test_output=failed_test_output
            )
            fail_reward = -1.0 if self.fixer_reward_pm1 else 0.0
            for traj in fixer_trajectories:
                fix = traj.steps[0].action
                fixes.append(fix)
                is_pass = self._passes(task_info, fix)
                fixer_passes.append(is_pass)
                traj.steps[0].reward = 1.0 if is_pass else fail_reward

        fixer_acc = sum(fixer_passes) / len(fixer_passes) if fixer_passes else 0.0
        fixer_pass_any = any(fixer_passes)

        # Step 3: Judge selects the best fix (if enabled and we have fixes).
        judge_trajectory: Optional[Trajectory] = None
        judge_acc = 0.0

        if self.use_judge and self.judge is not None and fixes:
            judge_trajectory = await self.judge.judge_fixes(problem, buggy_code, fixes)
            judge_pass = self._passes(task_info, judge_trajectory.steps[0].action)
            judge_trajectory.steps[0].reward = 1.0 if judge_pass else 0.0
            judge_acc = float(judge_pass)
            is_correct = judge_pass
        else:
            # Without judge, success is if any fix passes
            is_correct = fixer_pass_any

        generator_reward = self._generator_reward(bug_valid, fixer_acc)
        generator_trajectory.steps[0].reward = generator_reward

        trajectories = [generator_trajectory] + fixer_trajectories
        if judge_trajectory is not None:
            trajectories.append(judge_trajectory)

        metrics = {
            "bug_valid": float(bug_valid),
            "fixer_acc": fixer_acc,
            "fixer_pass_any": float(fixer_pass_any),
            "generator_reward": generator_reward,
        }
        if self.use_judge:
            metrics["judge_acc"] = judge_acc

        return Episode(
            id=uid,
            task=task,
            trajectories=trajectories,
            is_correct=is_correct,
            metrics=metrics,
        )
