"""Math domain environment for SkyRL training.

Provides MathEnv class for math problem solving with advisor feedback.
Evaluates generated solutions for correctness.
"""

from typing import Any, Dict, List, Optional, Tuple

from omegaconf import DictConfig

from ..env_base import BaseAdvisorEnv
from .config import (
    STUDENT_SYSTEM_PROMPT,
    STUDENT_INSTRUCTION,
    BASELINE_SYSTEM_PROMPT,
    BASELINE_INSTRUCTION,
    extract_answer,
    compute_correctness_score,
)


class MathEnv(BaseAdvisorEnv):
    """Environment for math problem solving with advisor feedback using 2-step flow.

    Supplies the math-specific baseline and student prompts and scores the
    final response by extracting an answer and comparing it against the
    ground truth.

    NOTE(review): ``original_question``, ``ground_truth`` and
    ``final_response`` are read by this class but never assigned here —
    presumably ``BaseAdvisorEnv`` populates them; confirm against the base.
    """

    def __init__(
        self, env_config: DictConfig, extras: Optional[Dict[str, Any]] = None
    ):
        """Initialize the environment.

        Args:
            env_config: Environment configuration, forwarded unchanged to
                ``BaseAdvisorEnv.__init__``.
            extras: Extra per-environment data, forwarded to
                ``BaseAdvisorEnv.__init__``. When omitted (or ``None``), a new
                empty dict is created for this instance.
        """
        # A literal ``{}`` default is evaluated once at definition time and
        # would be the same object for every MathEnv built without ``extras``;
        # any mutation by the base class would then leak across instances.
        if extras is None:
            extras = {}
        super().__init__(env_config, extras)

    def _build_baseline_prompt(
        self, prompt: List[Dict[str, str]]
    ) -> Tuple[List[Dict[str, str]], str]:
        """Build prompt for baseline model to solve the math problem.

        The incoming ``prompt`` is not used; the user message is built solely
        from ``BASELINE_INSTRUCTION`` formatted with ``self.original_question``.

        Returns:
            A ``(messages, user_content)`` pair: a system/user chat message
            list, and the formatted user-message text on its own.
        """
        formatted_prompt = BASELINE_INSTRUCTION.format(problem=self.original_question)
        return [
            {"role": "system", "content": BASELINE_SYSTEM_PROMPT},
            {"role": "user", "content": formatted_prompt},
        ], formatted_prompt

    def _build_advisor_prompt(
        self, prompt: List[Dict[str, str]]
    ) -> Tuple[List[Dict[str, str]], str]:
        """Build the advisor prompt by delegating to ``BaseAdvisorEnv``.

        The math domain adds no advisor-specific customization; this override
        only forwards ``prompt`` to the base implementation and returns its
        result unchanged.
        """
        return super()._build_advisor_prompt(prompt)

    def _build_student_prompt(
        self, advisor_feedback: str
    ) -> Tuple[List[Dict[str, str]], str]:
        """Build prompt for student model to solve the math problem.

        Args:
            advisor_feedback: Advisor text inserted into ``STUDENT_INSTRUCTION``
                alongside ``self.original_question``.

        Returns:
            A ``(messages, user_content)`` pair: a system/user chat message
            list, and the formatted user-message text on its own.
        """
        formatted_prompt = STUDENT_INSTRUCTION.format(
            problem=self.original_question,
            advisor_feedback=advisor_feedback,
        )

        return [
            {"role": "system", "content": STUDENT_SYSTEM_PROMPT},
            {"role": "user", "content": formatted_prompt},
        ], formatted_prompt

    def _compute_step(self) -> Tuple[float, bool, Dict[str, Any]]:
        """Compute reward based on correctness of the final answer.

        Extracts an answer from ``self.final_response`` with ``extract_answer``
        and scores it against ``self.ground_truth`` with
        ``compute_correctness_score``.

        Returns:
            ``(reward, True, {})``. The boolean is always ``True`` (presumably
            the episode-done flag — confirm against ``BaseAdvisorEnv``) and the
            info dict is always empty. If extraction or scoring raises, the
            error is printed and the reward is ``0.0``.
        """
        try:
            extracted_answer = extract_answer(self.final_response)
            reward = compute_correctness_score(extracted_answer, self.ground_truth)
        except Exception as e:
            # Deliberate best-effort boundary: a malformed response or scorer
            # failure costs this sample its reward instead of propagating out
            # of the environment step.
            print(f"Error computing correctness reward: {e}")
            return 0.0, True, {}

        return reward, True, {}

    def _get_metadata(self) -> Dict[str, Any]:
        """Return the base-class metadata plus the ground-truth answer.

        Sets (overwriting any existing value) the ``"other_info"`` key to a
        human-readable string containing ``self.ground_truth``.
        """
        metadata = super()._get_metadata()
        metadata["other_info"] = f"Ground Truth Answer: {self.ground_truth}"
        return metadata
