from __future__ import annotations

import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import _build_code_generation_prompt


@dataclass
class CodeGenConfig:
    """Configuration for `CodeGenerator`."""

    # Optional system message; `CodeGenerator.generate` prepends it to the
    # conversation only when it is truthy (neither None nor empty).
    system_prompt: Optional[str] = None


class CodeGenerator:
    """Lightweight wrapper around the rollout engine for code generation.

    Builds a single-turn conversation (optional system prompt + user prompt),
    queries the rollout engine once, and packages the reply as a one-step
    `Trajectory`.
    """

    def __init__(self, rollout_engine: RolloutEngine, config: Optional[CodeGenConfig] = None):
        """Store the engine and configuration.

        Args:
            rollout_engine: Engine whose `get_model_response` is awaited.
            config: Generation settings; a default `CodeGenConfig` is used
                when omitted.
        """
        self.rollout_engine = rollout_engine
        self.config = config or CodeGenConfig()

    @staticmethod
    def _split_think(raw_response: Optional[str]) -> tuple[str, str]:
        """Split `<think>...</think>` from the final answer if present.

        Args:
            raw_response: Raw model text; `None` or empty is tolerated.

        Returns:
            `(thought, action)`, both stripped. `thought` keeps the closing
            `</think>` tag. When the response does not contain exactly one
            `</think>`, `thought` is empty and the whole response is the
            action.
        """
        if not raw_response:
            # Missing/empty completion: nothing to split (avoids calling
            # string methods on None).
            return "", ""
        if raw_response.count("</think>") == 1:
            thought, sep, action = raw_response.partition("</think>")
            return (thought + sep).strip(), action.strip()
        return "", raw_response.strip()

    async def generate(self, task: Dict[str, Any], uid: str) -> Trajectory:
        """Generate code for `task` and return it as a one-step trajectory.

        Args:
            task: Task dictionary passed to `_build_code_generation_prompt`.
            uid: Episode identifier; currently unused here and kept for
                interface compatibility with callers.

        Returns:
            A `Trajectory` named ``"code_generator"`` containing one `Step`.
        """
        prompt = _build_code_generation_prompt(task)

        messages: List[Dict[str, str]] = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": prompt})

        model_output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        # NOTE(review): content is presumably Optional on ModelOutput (e.g. an
        # empty completion) -- coerce to "" so the step is still well-formed.
        raw_response = model_output.content or ""
        thought, action = self._split_think(raw_response)

        # The recorded assistant turn holds only the final answer; the think
        # block is kept separately in `thought` / `model_response`.
        chat_completions = messages + [{"role": "assistant", "content": action}]
        step = Step(
            chat_completions=chat_completions,
            thought=thought,
            action=action,
            model_response=raw_response,
            model_output=model_output,
        )

        return Trajectory(name="code_generator", steps=[step])


class CodeGenWorkflow(Workflow):
    """Workflow that trains a single agent for *pure* code generation.

    For each episode:
      1) The model generates code for a task.
      2) The reward function runs unit tests and returns correctness.

    This is used for:
      - Train: DeepCoder (registered as `deepcoder_bugs/train`)
      - Val/Test: BigCodeBench (registered as `bigcodebench/test`)
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """Set up the workflow.

        Args:
            rollout_engine: Engine used for generation; also forwarded to the
                base `Workflow`.
            executor: Thread pool forwarded to the base `Workflow`.
            reward_function: Callable invoked as
                `reward_function(task_info=..., action=...)` to score the
                generated code.
            system_prompt: Optional system message for the code generator.
            **kwargs: Extra options forwarded to the base `Workflow`.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.codegen = CodeGenerator(rollout_engine, CodeGenConfig(system_prompt=system_prompt))

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Run one generate-then-score episode.

        Args:
            task: Task dictionary. Its ``"extra_info"`` entry, when present,
                is what the reward function receives; otherwise the task
                itself is passed.
            uid: Episode identifier, used as the returned episode's id.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An `Episode` with a single trajectory whose step/trajectory
            reward and metrics come from the reward function's output.
        """
        self.reset(task, uid)

        traj = await self.codegen.generate(task, uid)
        step = traj.steps[0]

        # Evaluate correctness via unit tests.
        task_info = task.get("extra_info", task)
        try:
            reward_output: RewardOutput = self.reward_function(task_info=task_info, action=step.action)
        except Exception as e:
            # Best effort: a crashing reward function must not abort the
            # rollout, so score the episode as zero. Log with traceback so the
            # failure is distinguishable from genuinely failing tests (the
            # metadata below is not propagated into the episode metrics).
            logging.getLogger(__name__).exception(
                "codegen reward function failed for uid=%s", uid
            )
            reward_output = RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={"error": f"codegen reward error: {e}"},
            )

        step.reward = float(reward_output.reward)
        traj.reward = float(step.reward)

        meta = reward_output.metadata or {}
        metrics: Dict[str, Any] = {
            "codegen_reward": float(reward_output.reward),
            "codegen_pass": float(reward_output.is_correct),
        }
        # Optional test statistics: copied only when the reward function
        # reported them.
        if "total_tests" in meta:
            metrics["total_tests"] = int(meta["total_tests"])
        if "passed_tests" in meta:
            metrics["passed_tests"] = int(meta["passed_tests"])
        if "all_passed" in meta:
            metrics["all_passed"] = float(bool(meta["all_passed"]))

        episode = Episode(
            id=uid,
            task=task,
            trajectories=[traj],
            is_correct=bool(reward_output.is_correct),
            metrics=metrics,
        )
        self.assign_episode_correctness(episode)
        return episode

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Leave `episode` untouched.

        `run` already sets `is_correct` from the reward output when it builds
        the episode. NOTE(review): presumably this overrides a base-class
        default that would recompute correctness -- confirm against `Workflow`.
        """
        pass
