from __future__ import annotations

import asyncio
import functools
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine, OpenAIEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow, TerminationReason

from examples.bugs.prompts import (
    _build_bug_fixer_prompt,
    _build_code_generation_prompt,
    _extract_failed_test_output,
)


def _resolve_api_key(base_url: str, explicit_api_key: Optional[str]) -> str:
    """Resolve OpenAI-compatible API key, allowing dummy keys for non-OpenAI endpoints."""
    api_key = (explicit_api_key or os.getenv("OPENAI_API_KEY", "")) or ""
    api_key = str(api_key).strip()
    if "api.openai.com" in str(base_url) and not api_key:
        raise ValueError(
            "solver_base_url points to the OpenAI API, but OPENAI_API_KEY is missing/empty. "
            "Please export OPENAI_API_KEY or pass solver_api_key explicitly."
        )
    return api_key or "EMPTY"


def _normalize_task_info(task: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap extra_info if present and ensure reward_fn-compatible keys exist."""
    task_info = task.get("extra_info", task)
    if "ground_truth" not in task_info and "test" in task_info:
        task_info = dict(task_info)
        task_info["ground_truth"] = task_info["test"]
    return task_info


def _pass_ratio(meta: Dict[str, Any]) -> Optional[float]:
    """Return passed_tests / total_tests when available."""
    try:
        total = int(meta.get("total_tests", 0))
        passed = int(meta.get("passed_tests", 0))
        return passed / total if total > 0 else None
    except Exception:
        return None


@dataclass
class FrozenSolverConfig:
    """Prompting options for the frozen (never trained) solver.

    Attributes:
        system_prompt: optional system message placed before the user prompt;
            skipped when ``None`` or empty.
        append_python_only_instruction: when true, the code-generation prompt is
            suffixed with an instruction to answer with a single fenced Python block.
    """

    system_prompt: Optional[str] = None
    append_python_only_instruction: bool = True


@dataclass
class FixerConfig:
    """Prompting options for the trainable bug fixer.

    Attributes:
        system_prompt: optional system message placed before the user prompt;
            skipped when ``None`` or empty.
        include_failed_test_output: forwarded to the fixer prompt builder; the
            workflow also only extracts failing-test output when this is set.
        max_failed_test_output_chars: passed as ``max_chars`` when the workflow
            extracts failing-test output from the solver's reward metadata.
    """

    system_prompt: Optional[str] = None
    include_failed_test_output: bool = False
    max_failed_test_output_chars: int = 4000


class FrozenSolver:
    """Static solver that writes a program for a task from scratch.

    It only queries its rollout engine; nothing it produces is returned as a
    training trajectory by this module.
    """

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[FrozenSolverConfig] = None):
        self.rollout_engine = rollout_engine
        self.config = config if config else FrozenSolverConfig()

    async def synthesize(self, task: Dict[str, Any], uid: str) -> tuple[str, ModelOutput]:
        """Generate a candidate solution for ``task``.

        Args:
            task: task payload handed to the code-generation prompt builder.
            uid: episode id (currently unused).

        Returns:
            ``(code, model_output)`` where ``code`` is the stripped ``content`` of
            the reply (``""`` when the reply has no content). The raw reply is not
            parsed, so ``code`` may still contain markdown fences.
        """
        user_prompt = _build_code_generation_prompt(task)
        if self.config.append_python_only_instruction:
            user_prompt = user_prompt + "\n\nReturn only the full Python solution inside a single ```python``` block."

        system_prompt = self.config.system_prompt
        conversation: List[Dict[str, str]] = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )
        conversation.append({"role": "user", "content": user_prompt})

        response: ModelOutput = await self.rollout_engine.get_model_response(conversation)
        return (response.content or "").strip(), response


class BugFixer:
    """Trainable repair model: given a problem and a buggy program, proposes a fix.

    Each call yields a single-step ``Trajectory`` named ``"fixer"`` whose step
    records the full chat (prompt plus assistant reply) and uses the raw reply as
    the action. ``fix`` does not set the step's reward.
    """

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[FixerConfig] = None):
        self.rollout_engine = rollout_engine
        self.config = config if config else FixerConfig()

    async def fix(
        self,
        task: Dict[str, Any],
        buggy_code: str,
        uid: str,
        failed_test_output: Optional[str] = None,
    ) -> Trajectory:
        """Ask the fixer model to repair ``buggy_code`` for ``task``.

        Args:
            task: task payload; the problem statement is the first non-empty of
                ``instruct_prompt``, ``complete_prompt`` and ``problem`` (else ``""``).
            buggy_code: program (or raw solver reply) to repair.
            uid: episode id (currently unused).
            failed_test_output: optional test log forwarded to the prompt builder.

        Returns:
            A one-step trajectory whose step ``action`` is the stripped model reply
            (``text`` preferred over ``content``).
        """
        problem_statement: Any = ""
        for field_name in ("instruct_prompt", "complete_prompt", "problem"):
            candidate = task.get(field_name)
            if candidate:
                problem_statement = candidate
                break

        user_prompt = _build_bug_fixer_prompt(
            str(problem_statement),
            buggy_code,
            include_failed_test_output=bool(self.config.include_failed_test_output),
            failed_test_output=failed_test_output,
        )

        system_prompt = self.config.system_prompt
        conversation: List[Dict[str, str]] = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )
        conversation.append({"role": "user", "content": user_prompt})

        response: ModelOutput = await self.rollout_engine.get_model_response(conversation)
        reply = (response.text or response.content or "").strip()

        step = Step(
            chat_completions=[*conversation, {"role": "assistant", "content": reply}],
            thought="",
            action=reply,
            model_response=reply,
            model_output=response,
        )
        return Trajectory(name="fixer", steps=[step])


class FixerWorkflow(Workflow):
    """Train a fixer on a frozen solver's code synthesis failures.

    Two modes:

    (A) Failure-only (only_train_on_failures=True) [default]
      1) Frozen solver synthesizes code for the task.
      2) Reward function evaluates solver output.  (unit tests run)
      3) If solver fails, trainable fixer repairs the solver output.
      4) Reward function evaluates fixed output.   (unit tests run again)
      5) Fixer reward = 1 iff (solver failed) AND (fix passes all tests);
         otherwise 0 (or -1 when reward_pm1=True).

    (B) One-shot (only_train_on_failures=False)
      1) Frozen solver synthesizes code for the task.
      2) Fixer always proposes a refined/fixed program.
      3) Reward function evaluates ONLY the final submitted code ONCE.
         (No solver evaluation, no intermediate tests.)
      Fixer reward = 1 if the submission passes, else 0 (or -1 with reward_pm1).

    Validate mode (``rollout_engine.validate`` truthy or ``validate=True`` passed to
    ``run``): a pre-generated buggy program from the task (``buggy_solution`` /
    ``buggy_sampled_solution`` / ``buggy``) replaces the solver when present, and the
    fixer runs even if the solver's code already passes.

    The reward function is synchronous (it runs unit tests), so it is dispatched to
    ``self.executor`` when the base class provides one (otherwise asyncio's default
    executor) instead of blocking the event loop shared by concurrent rollouts.

    Training notes:
      - Only the fixer trajectory is returned (frozen solver is not trained).
      - If only_train_on_failures=True, solver successes yield empty episodes (dropped) in training mode.
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        # Frozen solver configuration.
        solver_rollout_engine: Optional[RolloutEngine] = None,
        solver_model: Optional[str] = None,
        solver_base_url: Optional[str] = None,
        solver_api_key: Optional[str] = None,
        solver_temperature: float = 0.6,
        solver_top_p: float = 0.95,
        solver_system_prompt: Optional[str] = None,
        # Fixer configuration.
        fixer_system_prompt: Optional[str] = None,
        # Training behavior.
        only_train_on_failures: bool = True,
        reward_pm1: bool = False,
        *,
        include_failed_test_output: bool = False,
        max_failed_test_output_chars: int = 4000,
        **kwargs,
    ):
        """Build the frozen solver and the trainable fixer.

        Args:
            rollout_engine: engine for the trainable fixer.
            executor: thread pool forwarded to ``Workflow``.
            reward_function: called as ``reward_function(task_info=..., action=code)``.
            solver_rollout_engine: ready-made engine for the frozen solver; if
                ``None``, an ``OpenAIEngine`` is built from ``solver_model``.
            solver_model / solver_base_url / solver_api_key / solver_temperature /
                solver_top_p: settings for that ``OpenAIEngine`` (base URL defaults
                to the OpenAI API).
            solver_system_prompt / fixer_system_prompt: optional system messages.
            only_train_on_failures: ``True`` selects mode (A), ``False`` mode (B).
            reward_pm1: use -1 instead of 0 as the failure reward.
            include_failed_test_output: add the solver's failing-test output to the
                fixer prompt (mode (A) only, since mode (B) never tests the solver).
            max_failed_test_output_chars: size cap passed to the extractor.
            **kwargs: forwarded to ``Workflow.__init__``.

        Raises:
            ValueError: neither ``solver_rollout_engine`` nor ``solver_model`` given,
                or the OpenAI API is targeted without an API key.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.only_train_on_failures = bool(only_train_on_failures)
        self.reward_pm1 = bool(reward_pm1)

        # Build frozen solver engine if not provided.
        if solver_rollout_engine is None and solver_model:
            solver_base_url = solver_base_url or "https://api.openai.com/v1"
            solver_rollout_engine = OpenAIEngine(
                model=str(solver_model),
                tokenizer=None,
                base_url=str(solver_base_url),
                api_key=_resolve_api_key(str(solver_base_url), solver_api_key),
                sampling_params={"temperature": float(solver_temperature), "top_p": float(solver_top_p)},
                max_prompt_length=getattr(rollout_engine, "max_prompt_length", 8192),
                max_response_length=getattr(rollout_engine, "max_response_length", 8192),
                verbose=False,
            )

        if solver_rollout_engine is None:
            raise ValueError("FixerWorkflow requires a frozen solver. Provide solver_rollout_engine or solver_model.")

        self.frozen_solver = FrozenSolver(
            solver_rollout_engine,
            FrozenSolverConfig(system_prompt=solver_system_prompt),
        )
        self.fixer = BugFixer(
            rollout_engine,
            FixerConfig(
                system_prompt=fixer_system_prompt,
                include_failed_test_output=bool(include_failed_test_output),
                max_failed_test_output_chars=int(max_failed_test_output_chars),
            ),
        )

    async def _obtain_solver_code(
        self, task_info: Dict[str, Any], uid: str, validate_mode: bool
    ) -> tuple[str, Optional[ModelOutput], str]:
        """Return ``(code, solver_model_output_or_None, bug_source)`` for the fixer to repair."""
        if validate_mode:
            buggy = task_info.get("buggy_solution") or task_info.get("buggy_sampled_solution") or task_info.get("buggy") or ""
            if buggy:
                return str(buggy), None, "pregenerated"
        solver_code, solver_out = await self.frozen_solver.synthesize(task_info, uid)
        return solver_code, solver_out, "generated"

    async def _evaluate(
        self, task_info: Dict[str, Any], code: str
    ) -> tuple[bool, Dict[str, Any], Optional[float]]:
        """Run the reward function on ``code`` off the event loop.

        Returns ``(passed, metadata, pass_ratio)``. Any exception from the reward
        function is recorded as a failure with ``{"error": str(e)}`` metadata
        rather than propagated.
        """
        loop = asyncio.get_running_loop()
        call = functools.partial(self.reward_function, task_info=task_info, action=code)
        try:
            # Unit tests can take seconds; awaiting in a worker keeps other rollouts running.
            reward_out: RewardOutput = await loop.run_in_executor(getattr(self, "executor", None), call)
        except Exception as e:
            reward_out = RewardOutput(reward=0.0, is_correct=False, metadata={"error": str(e)})
        meta = reward_out.metadata or {}
        return bool(reward_out.is_correct), meta, _pass_ratio(meta)

    def _fixer_reward(self, solver_pass: Optional[bool], fix_pass: bool) -> float:
        """Scalar reward for the fixer's submission.

        ``solver_pass`` is ``None`` in one-shot mode. A truthy value is only possible
        in failure-only validate mode (training exits early); there was nothing to
        fix, so the fixer receives the failure reward regardless of its output.
        """
        failure_reward = -1.0 if self.reward_pm1 else 0.0
        if solver_pass:
            return failure_reward
        return 1.0 if fix_pass else failure_reward

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        self.reset(task, uid)
        task_info = _normalize_task_info(task)

        # If only_train_on_failures=False => one-shot mode.
        one_shot_mode = not self.only_train_on_failures
        validate_mode = bool(getattr(self.rollout_engine, "validate", False)) or bool(kwargs.get("validate", False))

        solver_code, solver_out, bug_source = await self._obtain_solver_code(task_info, uid, validate_mode)

        # ---------------------------
        # (A) Failure-only: evaluate solver/buggy code (first test run).
        #     One-shot mode never tests the solver output, so these stay unset.
        # ---------------------------
        solver_pass: Optional[bool] = None
        solver_meta: Dict[str, Any] = {}
        solver_ratio: Optional[float] = None

        if not one_shot_mode:
            solver_pass, solver_meta, solver_ratio = await self._evaluate(task_info, solver_code)

            # Early exit in training: the solver passed, so there is no failure to learn from.
            if solver_pass and not validate_mode:
                early_metrics: Dict[str, Any] = {
                    "solver_pass": 1.0,
                    "fixer_pass": 0.0,
                    "solver_pass_ratio": 1.0,
                    "fixer_pass_ratio": 0.0,
                }
                # NOTE: is_correct=True here reflects the solver's success; assign_episode_correctness
                # is not applied to these trajectory-less episodes.
                return Episode(
                    id=uid,
                    task=task,
                    trajectories=[],
                    termination_reason=TerminationReason.ENV_DONE,
                    is_correct=True,
                    metrics=early_metrics,
                    info={
                        "buggy_code": solver_code,
                        "bug_source": bug_source,
                        "solver_reward_metadata": solver_meta,
                    },
                )

        # ---------------------------
        # Run fixer (always in one-shot; in failure-only only after a solver failure,
        # or in validate mode).
        # ---------------------------
        failed_test_output = None
        if solver_pass is False and self.fixer.config.include_failed_test_output:
            failed_test_output = _extract_failed_test_output(
                solver_meta,
                max_chars=self.fixer.config.max_failed_test_output_chars,
            )

        fixer_traj = await self.fixer.fix(
            task_info,
            buggy_code=solver_code,
            uid=uid,
            failed_test_output=failed_test_output,
        )
        fixer_step = fixer_traj.steps[0]
        fixed_code = fixer_step.action

        # Evaluate final submitted code.
        # - One-shot: this is the ONLY test run.
        # - Failure-only: this is the second test run.
        fix_pass, fix_meta, fix_ratio = await self._evaluate(task_info, fixed_code)

        fixer_reward = self._fixer_reward(solver_pass, fix_pass)
        fixer_step.reward = fixer_traj.reward = float(fixer_reward)

        # Metrics: only numeric values, so downstream averaging never sees None.
        metrics: Dict[str, Any] = {"fixer_pass": float(fix_pass)}
        if one_shot_mode:
            metrics["one_shot"] = 1.0
        else:
            metrics["solver_pass"] = float(bool(solver_pass))
            if solver_ratio is not None:
                metrics["solver_pass_ratio"] = solver_ratio

        if fix_ratio is not None:
            metrics["fixer_pass_ratio"] = fix_ratio

        episode = Episode(
            id=uid,
            task=task,
            trajectories=[fixer_traj],
            # Provisional; assign_episode_correctness below sets the final value.
            is_correct=fixer_reward > 0.0,
            metrics=metrics,
            info={
                "buggy_code": solver_code,
                "bug_source": bug_source,
                "solver_reward_metadata": (solver_meta if (not one_shot_mode) else None),
                "fixed_reward_metadata": fix_meta,
                "solver_model_output": solver_out.model_dump() if solver_out and hasattr(solver_out, "model_dump") else None,
            },
        )
        self.assign_episode_correctness(episode)
        return episode

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Set ``episode.is_correct`` from its metrics according to the training mode."""
        m = episode.metrics or {}
        if self.only_train_on_failures:
            # "Correct" means: solver failed AND fixer succeeded.
            episode.is_correct = bool(m.get("fixer_pass", 0.0)) and not bool(m.get("solver_pass", 0.0))
        else:
            # One-shot: "Correct" means final submitted code passed.
            episode.is_correct = bool(m.get("fixer_pass", 0.0))

