from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from concurrent.futures import ThreadPoolExecutor

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine
from rllm.engine import OpenAIEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import _build_bug_generator_prompt, _build_bug_fixer_prompt


@dataclass
class BugGeneratorConfig:
    """Settings for the bug-generator role.

    Attributes:
        system_prompt: Optional system message placed before the user prompt.
            When it is falsy, the generator conversation has no system message.
    """

    system_prompt: Optional[str] = None  # None -> no system message


@dataclass
class BugFixerConfig:
    """Settings for the bug-fixer (solver) role.

    Attributes:
        system_prompt: Optional system message placed before the user prompt.
            When it is falsy, the solver conversation has no system message.
    """

    system_prompt: Optional[str] = None  # None -> no system message


class BugGenerator:
    """Thin adapter that asks a rollout engine to inject a bug into a reference solution."""

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[BugGeneratorConfig] = None):
        self.rollout_engine = rollout_engine
        self.config = config or BugGeneratorConfig()

    async def generate_bug(self, task: Dict[str, Any], uid: str) -> Trajectory:
        """Produce a buggy variant of the task's correct DeepCoder solution.

        Args:
            task: Must provide "question" (problem statement) and
                "reference_solution" (known-correct code).
            uid: Episode identifier; accepted for interface symmetry, not used here.

        Returns:
            A single-step "bug_generator" trajectory whose action is the raw
            model response content.
        """
        user_prompt = _build_bug_generator_prompt(task["question"], task["reference_solution"])

        system_prompt = self.config.system_prompt
        conversation: List[Dict[str, str]] = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )
        conversation.append({"role": "user", "content": user_prompt})

        response: ModelOutput = await self.rollout_engine.get_model_response(conversation)
        candidate = response.content

        # Record the full exchange (prompt messages + assistant reply) on the step.
        return Trajectory(
            name="bug_generator",
            steps=[
                Step(
                    chat_completions=[*conversation, {"role": "assistant", "content": candidate}],
                    action=candidate,
                    model_output=response,
                )
            ],
        )


class BugFixer:
    """Thin adapter that asks a rollout engine to repair buggy code (the solver role)."""

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[BugFixerConfig] = None):
        self.rollout_engine = rollout_engine
        self.config = config or BugFixerConfig()

    async def fix_bug(self, task: Dict[str, Any], buggy_code: str, uid: str) -> Trajectory:
        """Ask the model to repair ``buggy_code`` so that all tests pass.

        Args:
            task: Must provide "question" (problem statement).
            buggy_code: Code to repair, inserted into the solver prompt.
            uid: Episode identifier; accepted for interface symmetry, not used here.

        Returns:
            A single-step "bug_fixer" trajectory whose action is the raw model
            response content.
        """
        user_prompt = _build_bug_fixer_prompt(task["question"], buggy_code)

        system_prompt = self.config.system_prompt
        conversation: List[Dict[str, str]] = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )
        conversation.append({"role": "user", "content": user_prompt})

        response: ModelOutput = await self.rollout_engine.get_model_response(conversation)
        repaired = response.content

        # Record the full exchange (prompt messages + assistant reply) on the step.
        return Trajectory(
            name="bug_fixer",
            steps=[
                Step(
                    chat_completions=[*conversation, {"role": "assistant", "content": repaired}],
                    action=repaired,
                    model_output=response,
                )
            ],
        )


def check_bug_validity(
    bug_meta: Dict[str, Any],
    bug_reward_output: RewardOutput,
    compile_errors_invalid: bool = True,
) -> tuple[bool, bool]:
    """
    Check if a bug is valid based on metadata from reward evaluation.

    A valid bug must:
    1. Have no compilation errors (if compile_errors_invalid=True)
    2. Fail at least one unit test. The evidence used, in order of preference:
       - ``passed_tests < total_tests`` when both counts are present and total_tests > 0
       - an explicit ``all_passed`` of False
       - otherwise, ``bug_reward_output.is_correct`` being False
       A *missing* ``all_passed`` key is treated as "unknown", never as "tests failed".

    Args:
        bug_meta: Metadata dictionary from RewardOutput containing test results
        bug_reward_output: The RewardOutput object; only ``is_correct`` is read,
            and only when the metadata has no usable per-test information.
        compile_errors_invalid: If True, compilation errors make the bug invalid.
                               If False, compilation errors are still considered valid bugs.

    Returns:
        tuple: (bug_valid, has_compile_error) where:
            - bug_valid: True if bug is valid, False otherwise
            - has_compile_error: True if compilation errors were detected

    Examples:
        Valid bug metadata:
            {
                'all_passed': False,
                'test_results': [{
                    'error': None,
                    'error_message': 'Wrong answer...',  # Test failure (valid!)
                    'passed': False
                }],
                'total_tests': 1,
                'passed_tests': 0
            }

        Invalid bug (compile error) metadata:
            {
                'all_passed': False,
                'test_results': [{
                    'error': None,
                    'error_message': "Error during testing: '(' was never closed",
                    'output': None,
                    'passed': False
                }],
                'total_tests': 1,
                'passed_tests': 0
            }
    """
    total_tests = bug_meta.get("total_tests")
    passed_tests = bug_meta.get("passed_tests")
    # Default to None ("unknown"), NOT False: treating a missing key as "tests failed"
    # would mark code that passed every test as a valid bug.
    all_passed = bug_meta.get("all_passed")

    # Substrings (lower-cased) treated as signs that the code did not compile/run.
    # They are only consulted when the message is not a "Wrong answer" test failure.
    compile_patterns = (
        "syntax", "syntaxerror", "compilation", "compile error",
        "cannot compile", "indentation", "invalid syntax",
        "unexpected", "eof", "unterminated", "was never closed",
        "nameerror", "typeerror", "attributeerror", "import error",
        "module not found", "indentationerror",
    )

    # Key distinction:
    # - Valid bug: error_message contains "Wrong answer" (code runs but gives wrong output)
    # - Compile error: error_message contains "Error during testing:" (code doesn't compile/run)
    test_results = bug_meta.get("test_results", [])
    has_compile_error = False
    if isinstance(test_results, list):
        for test in test_results:
            if not isinstance(test, dict):
                continue  # malformed entry: skip instead of raising AttributeError
            message = str(test.get("error_message") or "").lower()
            if not message:
                continue
            if "error during testing:" in message:
                has_compile_error = True
                break
            if "wrong answer" not in message and any(p in message for p in compile_patterns):
                has_compile_error = True
                break

    if compile_errors_invalid and has_compile_error:
        bug_valid = False
    elif total_tests is not None and passed_tests is not None and total_tests > 0:
        # Compiles and runs; valid iff at least one test fails.
        bug_valid = passed_tests < total_tests
    elif all_passed is False:
        # Explicitly reported test failure.
        bug_valid = True
    else:
        # No usable per-test info: fall back to the overall verdict. The
        # "compile error makes it invalid" case was already handled above.
        bug_valid = not bug_reward_output.is_correct

    return bug_valid, has_compile_error


class BugGeneratorWorkflow(Workflow):
    """Workflow that trains the BugGenerator against a static solver on DeepCoder tasks.

    High-level logic for each episode:

    1. BugGenerator sees the *correct* DeepCoder solution and produces a buggy version.
    2. We run unit tests on the buggy code:
       - A "valid bug" is one that causes **at least one** unit test to fail.
       - The bug must compile and run successfully (no syntax/compilation errors).
    3. Static solver (e.g., GPT-4o-mini) tries to fix the buggy code.
    4. We run unit tests on the fixed code.

    Rewards:
      - Generator gets reward 1.0 iff:
          * The bug is syntactically valid (no syntax/compilation errors)
          * The bug compiles and runs successfully
          * The bug causes at least one unit test to fail
          * The static solver FAILS to fix it (fixed code doesn't pass all tests)
      - Otherwise, reward 0.0.
      - If the reward function raises while scoring the bug, the bug is invalid.
      - If no solver is configured, the solver is treated as failing, so any
        valid bug earns 1.0.

    Bug validity is decided by ``check_bug_validity`` from the reward metadata:
      - no compile errors detected in the per-test error messages
      - at least one failing unit test (or, without per-test counts, an
        overall incorrect verdict from the reward function)
    NOTE(review): format errors (no code extracted from the model output) are not
    checked here directly; they only invalidate a bug if the reward function
    reports them in a form ``check_bug_validity`` recognizes — confirm.

    The underlying `reward_function` is a standard code reward that
    evaluates a (task, code) pair and returns a RewardOutput with:
      - is_correct = True iff all tests pass
      - metadata["total_tests"], metadata["passed_tests"] when available
      - metadata["error"], metadata["error_message"] for error information
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        generator_system_prompt: Optional[str] = None,
        solver_rollout_engine: Optional[RolloutEngine] = None,
        solver_model: Optional[str] = None,
        solver_base_url: Optional[str] = None,
        solver_temperature: float = 0.0,
        solver_top_p: float = 1.0,
        solver_max_prompt_length: Optional[int] = None,
        solver_max_response_length: Optional[int] = None,
        solver_system_prompt: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.generator = BugGenerator(rollout_engine, BugGeneratorConfig(system_prompt=generator_system_prompt))

        # Static solver uses a separate rollout engine (e.g., GPT-4o-mini).
        # In training, we typically want to specify this via Hydra config as:
        #   rllm.workflow.workflow_args.solver_model=gpt-4o-mini
        # (Optionally with solver_base_url / temperature / top_p, etc.)
        if solver_rollout_engine is None and solver_model:
            solver_base_url = solver_base_url or "https://api.openai.com/v1"

            # We don't need a tokenizer here; OpenAIEngine will use the chat completions endpoint.
            #
            # Auth behavior:
            # - If you're pointing at the OpenAI API, you must provide OPENAI_API_KEY.
            # - If you're pointing at a local OpenAI-compatible server (e.g., vLLM),
            #   we allow missing OPENAI_API_KEY and send a dummy token.
            api_key = (os.getenv("OPENAI_API_KEY") or "").strip()

            is_openai_api = "api.openai.com" in solver_base_url
            if is_openai_api and not api_key:
                raise ValueError(
                    "solver_model was provided with solver_base_url pointing to the OpenAI API, "
                    "but OPENAI_API_KEY is missing/empty. Please export OPENAI_API_KEY (or source it from your .env)."
                )
            if not api_key:
                api_key = "EMPTY"

            solver_rollout_engine = OpenAIEngine(
                model=solver_model,
                tokenizer=None,
                base_url=solver_base_url,
                api_key=api_key,
                # Fall back to the generator engine's limits (or 4096) when not given.
                max_prompt_length=(
                    solver_max_prompt_length
                    if solver_max_prompt_length is not None
                    else getattr(rollout_engine, "max_prompt_length", 4096)
                ),
                max_response_length=(
                    solver_max_response_length
                    if solver_max_response_length is not None
                    else getattr(rollout_engine, "max_response_length", 4096)
                ),
                sampling_params={
                    "temperature": float(solver_temperature),
                    "top_p": float(solver_top_p),
                },
            )

        if solver_rollout_engine is not None:
            self.solver = BugFixer(solver_rollout_engine, BugFixerConfig(system_prompt=solver_system_prompt))
        else:
            self.solver = None

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Execute bug generation against the static solver on a single DeepCoder task.

        Args:
            task: Task dict; must contain "question" and "reference_solution".
                If it has an "extra_info" entry, that entry (not the whole task)
                is what the reward function receives as ``task_info``.
            uid: Episode id.

        Returns:
            Episode with the generator trajectory (its step reward set) and, when
            a solver is configured, the solver trajectory.
        """
        self.reset(task, uid)

        # 1. Generate a bug.
        bug_traj = await self.generator.generate_bug(task, uid)
        bug_step = bug_traj.steps[0]
        buggy_code = bug_step.action

        # Resolved once, outside any try block, and reused for both scorings.
        task_info = task["extra_info"] if "extra_info" in task else task

        # 2. Score the bug.
        # NOTE(review): the reward function is called synchronously inside this
        # coroutine; if it executes tests it blocks the event loop — consider
        # offloading it to the workflow's executor.
        bug_reward_output, bug_reward_failed = self._evaluate(task_info, buggy_code, "bug")
        bug_meta = bug_reward_output.metadata or {}

        # Check bug validity (compile errors are considered invalid by default)
        bug_valid, has_compile_error = check_bug_validity(
            bug_meta=bug_meta,
            bug_reward_output=bug_reward_output,
            compile_errors_invalid=True,  # Set to False if compile errors should be considered valid bugs
        )
        if bug_reward_failed:
            # The error metadata has no per-test results and is_correct is False,
            # which check_bug_validity would count as a valid bug. A reward crash
            # must never earn the generator credit.
            bug_valid = False

        total_tests = bug_meta.get("total_tests")
        passed_tests = bug_meta.get("passed_tests")

        # 3. Static solver attempts to fix the buggy code. It runs even for invalid
        #    bugs so that solver_pass is always logged; gating on bug_valid would
        #    save solver calls at the cost of that metric.
        solver_pass = False
        solver_traj = None
        if self.solver is not None:
            solver_traj = await self.solver.fix_bug(task, buggy_code, uid)
            fixed_code = solver_traj.steps[0].action
            solver_reward_output, _ = self._evaluate(task_info, fixed_code, "solver")
            solver_pass = solver_reward_output.is_correct

        # 4. Generator succeeds iff the bug is valid AND the solver fails to fix it.
        is_correct = bool(bug_valid and not solver_pass)
        generator_reward = 1.0 if is_correct else 0.0
        bug_step.reward = generator_reward

        # Build episode
        trajectories = [bug_traj]
        if solver_traj is not None:
            trajectories.append(solver_traj)

        metrics: Dict[str, Any] = {
            "bug_valid": float(bug_valid),
            "generator_reward": float(generator_reward),
            "bug_has_compile_error": float(has_compile_error),
            "solver_pass": float(solver_pass),
        }
        if total_tests is not None:
            metrics["bug_total_tests"] = int(total_tests)
        if passed_tests is not None:
            metrics["bug_passed_tests"] = int(passed_tests)

        episode = Episode(
            id=uid,
            task=task,
            trajectories=trajectories,
            is_correct=is_correct,
            metrics=metrics,
        )
        self.assign_episode_correctness(episode)
        return episode

    def _evaluate(self, task_info: Any, action: Any, label: str) -> tuple[RewardOutput, bool]:
        """Score ``action`` with the reward function, never raising.

        Returns:
            (reward_output, failed). On any exception, ``reward_output`` is a
            zero-reward, incorrect RewardOutput whose metadata["error"] reads
            "<label> reward error: <exc>", and ``failed`` is True.
        """
        try:
            return self.reward_function(task_info=task_info, action=action), False
        except Exception as e:  # deliberate best-effort: a reward crash must not kill the rollout
            return (
                RewardOutput(
                    reward=0.0,
                    is_correct=False,
                    metadata={"error": f"{label} reward error: {e}"},
                ),
                True,
            )

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Optionally adjust episode.is_correct based on metrics."""
        # Right now we keep `is_correct` as set in `run`, but you can
        # override with more complex logic if you want.
        pass
