import numpy as np
import gymnasium as gym
from gymnasium.core import ObsType
from .simulation_env import AeroSimulationEnv


class AeroFutureSimulationEnv(AeroSimulationEnv):
    """
    Extension of AeroSimulationEnv that includes future target information in the observation space.

    The observation dict gains a ``"future_targets"`` entry: a float32 vector of
    length ``num_future_targets`` holding ``target_tilt_fn`` evaluated at
    ``current_time + k * future_target_interval`` for ``k = 1 .. num_future_targets``
    (the current target itself is not repeated). When ``norm_observation`` is
    enabled the values are divided by ``MAX_OBSERVATION_VALUES["target"]``.

    :param num_future_targets: Number of future target points to include in the observation.
        Must be non-negative.
    :param future_target_interval: Time (seconds) between each future target point.
    :param kwargs: Arguments passed to the parent AeroSimulationEnv.
    :raises ValueError: If ``num_future_targets`` is negative.
    """

    def __init__(
            self,
            num_future_targets: int = 5,
            future_target_interval: float = 0.2,
            **kwargs
    ):
        if num_future_targets < 0:
            raise ValueError(
                f"num_future_targets must be non-negative, got {num_future_targets}"
            )

        # Assign these before the parent constructor runs so the overridden
        # `state` property is usable even if the parent calls it during init.
        self.num_future_targets = num_future_targets
        self.future_target_interval = future_target_interval

        super().__init__(**kwargs)

        new_spaces = self.observation_space.spaces.copy()

        if self.norm_observation:
            low, high = -1.0, 1.0
        else:
            limit = self.MAX_OBSERVATION_VALUES["target"]
            low, high = -limit, limit

        new_spaces["future_targets"] = gym.spaces.Box(
            low=low,
            high=high,
            shape=(self.num_future_targets,),
            dtype=np.float32
        )

        self.observation_space = gym.spaces.Dict(new_spaces)

    @property
    def state(self) -> ObsType:
        """
        Returns the parent state dictionary extended with ``"future_targets"``.

        ``"future_targets"`` is a float32 array of shape ``(num_future_targets,)``
        containing ``target_tilt_fn(current_time + k * future_target_interval)``
        for ``k = 1 .. num_future_targets``, normalized by
        ``MAX_OBSERVATION_VALUES["target"]`` when ``norm_observation`` is set.
        Values are not clipped here, so targets beyond the max limit may fall
        outside the declared Box bounds.
        """
        state = super().state

        # Offsets k * interval for k = 1 .. num_future_targets.
        offsets = np.arange(1, self.num_future_targets + 1) * self.future_target_interval
        future_times = self.current_time + offsets

        # reshape pins the result to the Box shape declared in __init__, even if
        # target_tilt_fn returns one-element arrays instead of scalars.
        future_vals = np.asarray(
            [self.target_tilt_fn(t) for t in future_times],
            dtype=np.float32
        ).reshape(self.num_future_targets)

        if self.norm_observation:
            future_vals = future_vals / self.MAX_OBSERVATION_VALUES["target"]

        state["future_targets"] = future_vals

        return state


class AeroVelocitySimulationEnv(AeroSimulationEnv):
    """
    Extension of AeroSimulationEnv that includes velocity of the target in the observation space.

    The observation dict gains a ``"target_velocity"`` entry: a float32 array of
    shape ``(1,)`` holding ``target_tilt_fn.get_derivative(current_time)``,
    divided by ``MAX_OBSERVATION_VALUES["target_velocity"]`` when
    ``norm_observation`` is enabled.

    :param kwargs: Arguments passed to the parent AeroSimulationEnv.
    """

    # Parent limits plus the bound used for the target-velocity observation.
    MAX_OBSERVATION_VALUES = AeroSimulationEnv.MAX_OBSERVATION_VALUES | {
        "target_velocity": 1.0,  # TODO determine appropriate max value
    }

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        new_spaces = self.observation_space.spaces.copy()

        if self.norm_observation:
            low, high = -1.0, 1.0
        else:
            limit = self.MAX_OBSERVATION_VALUES["target_velocity"]
            low, high = -limit, limit

        new_spaces["target_velocity"] = gym.spaces.Box(
            low=low,
            high=high,
            shape=(1,),
            dtype=np.float32
        )

        self.observation_space = gym.spaces.Dict(new_spaces)

    @property
    def state(self) -> ObsType:
        """
        Returns the parent state dictionary extended with ``"target_velocity"``.

        ``"target_velocity"`` is a float32 array of shape ``(1,)`` containing the
        derivative of the target tilt function at ``current_time``, normalized by
        ``MAX_OBSERVATION_VALUES["target_velocity"]`` when ``norm_observation``
        is set. Values are not clipped here.
        """
        state = super().state

        velocity = self.target_tilt_fn.get_derivative(self.current_time)

        if self.norm_observation:
            velocity = velocity / self.MAX_OBSERVATION_VALUES["target_velocity"]

        # reshape(1) accepts a scalar, 0-d or one-element array and always yields
        # shape (1,) to match the Box above; np.array([velocity]) would produce
        # (1, 1) if get_derivative returned a one-element array.
        state["target_velocity"] = np.asarray(velocity, dtype=np.float32).reshape(1)

        return state
