# Original implementation: https://github.com/abaisero/asym-rlpo
#
####
#
# Extended to informed POMDPs by anonymous authors (2025)
#
####

from typing import Optional, Tuple

from .env import (
    Action,
    Environment,
    EnvironmentType,
    Latent,
    LatentType,
    Observation,

    # Privileged information
    Information
)
from .env_gv import make_gv_env
from .env_gym import make_gym_env


def make_env(
    id_or_path: str,
    *,
    latent_type: LatentType,
    max_episode_timesteps: Optional[int] = None,
) -> Environment:
    """Construct an informed environment from a gym id or a GV YAML path.

    ``id_or_path`` is first handed to ``make_gym_env``.  If that raises
    ``ValueError``, a notice is printed and the same identifier is handed to
    ``make_gv_env`` instead (any error from that second attempt propagates).

    Args:
        id_or_path: gym environment id, or path to a GV YAML environment.
        latent_type: forwarded unchanged to whichever factory builds the env.
        max_episode_timesteps: when given, the environment is wrapped in an
            ``InformedTimeLimitEnvironment`` with this step limit; when
            ``None`` the environment is returned unwrapped.

    Returns:
        The constructed environment, possibly time-limited.
    """
    try:
        built = make_gym_env(id_or_path, latent_type=latent_type)
    except ValueError:
        # Not a known gym id: fall back to interpreting it as a GV YAML env.
        print(
            f'Environment with id {id_or_path} not found.'
            ' Trying as a GV YAML environment.'
        )
        built = make_gv_env(id_or_path, latent_type=latent_type)

    if max_episode_timesteps is None:
        return built

    return InformedTimeLimitEnvironment(built, max_episode_timesteps)


class TimeLimitEnvironment:
    """Wrap an environment so that episodes end after a fixed number of steps.

    The wrapped environment's ``type``, ``latent_type``, ``latent_space``,
    ``action_space`` and ``observation_space`` are copied onto the wrapper.
    ``reset`` initializes the internal step counter, so it must be called
    before the first ``step``.  An episode is reported as done either when
    the wrapped environment says so or once ``max_timestep`` steps have been
    taken since the last ``reset``.
    """

    def __init__(self, env: Environment, max_timestep: int):
        self._env = env

        # Expose the wrapped environment's description unchanged.
        self.type = env.type
        self.latent_type = env.latent_type
        self.latent_space = env.latent_space
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        # Declared only; assigned by ``reset``.
        self._timestep: int
        self._max_timestep = max_timestep

    def seed(self, seed: Optional[int] = None) -> None:
        """Forward ``seed`` to the wrapped environment."""
        self._env.seed(seed)

    def reset(self) -> Tuple[Observation, Latent]:
        """Restart the step counter and reset the wrapped environment."""
        self._timestep = 0
        return self._env.reset()

    def step(self, action: Action) -> Tuple[Observation, Latent, float, bool]:
        """Step the wrapped environment, forcing ``done`` at the step limit."""
        # Count the step before delegating, so the limit includes this step.
        self._timestep = self._timestep + 1
        observation, latent, reward, env_done = self._env.step(action)
        limit_reached = self._timestep >= self._max_timestep
        return observation, latent, reward, env_done or limit_reached

    def render(self) -> None:
        """Forward rendering to the wrapped environment."""
        self._env.render()


# Informed environment variant with finite horizon
class InformedTimeLimitEnvironment(TimeLimitEnvironment):
    """Time-limited wrapper for informed environments.

    Behaves like ``TimeLimitEnvironment`` (episodes are forced to end after
    ``max_timestep`` steps), but additionally mirrors the wrapped
    environment's ``information_space``.  Its ``step`` unpacks and forwards
    the extra privileged-information element of the wrapped environment's
    step result.

    ``seed`` and ``render`` are inherited from ``TimeLimitEnvironment``
    unchanged; they were previously duplicated here verbatim.
    """

    def __init__(self, env: Environment, max_timestep: int):
        super().__init__(env, max_timestep)
        self.information_space = self._env.information_space

    def reset(self) -> Tuple[Observation, Latent, Information]:
        """Restart the step counter and reset the wrapped environment.

        The parent implementation already does both; this override exists
        only to declare the informed return type.
        """
        return super().reset()

    def step(
        self, action: Action
    ) -> Tuple[Observation, Latent, float, bool, Information]:
        """Step the wrapped environment, forcing ``done`` at the step limit.

        The wrapped environment's ``step`` must return a 5-tuple whose last
        element is the privileged information, which is passed through.
        """
        # Count the step before delegating, so the limit includes this step.
        self._timestep += 1
        observation, latent, reward, done, information = self._env.step(action)
        done = done or self._timestep >= self._max_timestep
        return observation, latent, reward, done, information

