"""
Countdown task environment for SkyRL.
Based on TinyZero: https://github.com/Jiayi-Pan/TinyZero

Task: Given a target number and a list of numbers, create an equation using basic
arithmetic operations (+, -, *, /) to reach the target. Each number can only be used once.
"""

import logging
from typing import Dict, Any

from omegaconf import DictConfig

from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput
from skyrl_gym.envs.countdown import utils


class CountdownEnv(BaseTextEnv):
    """
    Single-step environment for the Countdown task.

    The agent is given a target number and a list of available numbers and must
    produce an equation that evaluates to the target. The whole solution is
    submitted as one action; the environment scores it and ends the episode.

    Attributes:
        ground_truth: The raw ground-truth mapping (contains 'target' and 'numbers').
        target: The value stored under ground_truth['target'].
        available_numbers: The value stored under ground_truth['numbers'].
        correct_score: Score passed to the scorer as ``correct_score``.
        format_score: Score passed to the scorer as ``format_score``.
    """

    def __init__(self, env_config: DictConfig = None, extras: Dict[str, Any] = None):
        """
        Build the environment from the per-sample ``extras`` payload.

        Args:
            env_config: Optional config; ``correct_score`` and ``format_score`` are
                read from it with ``.get`` and default to 1.0 and 0.1 respectively.
            extras: Per-sample data. Must contain either 'reward_model' (standard)
                or 'reward_spec' (legacy); 'reward_model' wins when both are present.
                That entry must contain 'ground_truth' with 'target' and 'numbers'.

        Raises:
            ValueError: If the reward specification or any required ground-truth
                field is missing.
        """
        super().__init__()

        if extras is None:
            extras = {}

        # Support both 'reward_model' (standard) and 'reward_spec' (legacy) for compatibility
        if "reward_model" in extras:
            reward_info = extras["reward_model"]
        elif "reward_spec" in extras:
            reward_info = extras["reward_spec"]
        else:
            raise ValueError("Either 'reward_model' or 'reward_spec' field is required in extras")

        # Explicit raises instead of `assert`: assertions are removed under `python -O`,
        # which would let malformed samples through to fail later with unrelated errors.
        if "ground_truth" not in reward_info:
            raise ValueError("ground_truth is required in reward_model/reward_spec field")

        ground_truth = reward_info["ground_truth"]

        for required_key in ("target", "numbers"):
            if required_key not in ground_truth:
                raise ValueError(f"ground_truth must contain '{required_key}'")

        self.ground_truth = ground_truth
        self.target = ground_truth["target"]
        self.available_numbers = ground_truth["numbers"]

        # Optional configuration
        self.correct_score = 1.0
        self.format_score = 0.1

        if env_config is not None:
            self.correct_score = env_config.get("correct_score", 1.0)
            self.format_score = env_config.get("format_score", 0.1)
            # Debug-level log (off by default) instead of a randomly sampled print:
            # deterministic, filterable, and it does not consume the global RNG.
            logging.getLogger(__name__).debug(
                "CountdownEnv env_config: %s, format_score set to: %s",
                env_config,
                self.format_score,
            )

    def _get_reward(self, action: str) -> float:
        """
        Calculate reward for the agent's action.

        Delegates to ``utils.compute_score`` with this environment's ground truth
        and the configured ``correct_score`` / ``format_score``.

        Args:
            action: The agent's generated solution text

        Returns:
            The score returned by ``utils.compute_score``. Presumably
            ``correct_score`` for a correct equation, ``format_score`` for a
            well-formed but wrong one, and 0.0 otherwise -- confirm against
            ``utils.compute_score``.
        """
        return utils.compute_score(
            solution_text=action,
            ground_truth=self.ground_truth,
            correct_score=self.correct_score,
            format_score=self.format_score,
        )

    def step(self, action: str) -> BaseTextEnvStepOutput:
        """
        Execute one step in the environment.

        For countdown task, this is always a single-step episode - the agent
        provides its full solution and receives a reward.

        Args:
            action: The agent's generated solution text

        Returns:
            BaseTextEnvStepOutput with the computed reward, ``done=True``, no
            observations, and metadata echoing the target and available numbers.
        """
        done = True  # Countdown is always single-step
        reward = self._get_reward(action)

        # No additional observations needed for countdown task
        return BaseTextEnvStepOutput(
            observations=[],
            reward=reward,
            done=done,
            metadata={
                "target": self.target,
                "available_numbers": self.available_numbers,
            },
        )
