# Original implementation: https://github.com/abaisero/gym-gridverse
#
####
#
# Extended to informed POMDPs by anonymous authors (2025)
#
####


from typing import Optional, Tuple

import numpy.random as rnd

from gym_gridverse.action import Action
from gym_gridverse.debugging import gv_debug
from gym_gridverse.envs import InnerEnv
from gym_gridverse.envs.observation_functions import ObservationFunction
from gym_gridverse.envs.reset_functions import ResetFunction
from gym_gridverse.envs.reward_functions import RewardFunction
from gym_gridverse.envs.terminating_functions import TerminatingFunction
from gym_gridverse.envs.transition_functions import (
    TransitionFunction,
    transition_with_copy,
)
from gym_gridverse.observation import Observation
from gym_gridverse.rng import make_rng
from gym_gridverse.spaces import ActionSpace, ObservationSpace, StateSpace
from gym_gridverse.state import State


class GridWorld(InnerEnv):
    """Implementation of the InnerEnv interface.

    The environment dynamics are fully specified by the functions given at
    construction time (reset, transition, observation, reward, termination).
    All stochastic functions receive the environment's own RNG, which can be
    (re)seeded through :py:meth:`set_seed`.
    """

    def __init__(
        self,
        state_space: StateSpace,
        action_space: ActionSpace,
        observation_space: ObservationSpace,
        reset_function: ResetFunction,
        transition_function: TransitionFunction,
        observation_function: ObservationFunction,
        reward_function: RewardFunction,
        termination_function: TerminatingFunction,
    ):
        """Initializes a GridWorld from the given components.

        Args:
            state_space (StateSpace): space used to validate states.
            action_space (ActionSpace): space used to validate actions.
            observation_space (ObservationSpace): space used to validate
                observations.
            reset_function (ResetFunction): called as
                ``reset_function(rng=...)`` to produce an initial state.
            transition_function (TransitionFunction): state dynamics, applied
                through ``transition_with_copy``.
            observation_function (ObservationFunction): called as
                ``observation_function(state, rng=...)`` to produce an
                observation.
            reward_function (RewardFunction): called as
                ``reward_function(state, action, next_state)``.
            termination_function (TerminatingFunction): called as
                ``termination_function(state, action, next_state)``.
        """

        # TODO: maybe add a parameter to avoid calls to `contain` everywhere
        # (or maybe a global setting)

        self._reset_function = reset_function
        self._transition_function = transition_function
        self._observation_function = observation_function
        self._reward_function = reward_function
        self._termination_function = termination_function

        # Stays None until `set_seed` is called; passed as-is to the
        # component functions.
        self._rng: Optional[rnd.Generator] = None

        super().__init__(state_space, action_space, observation_space)

    def set_seed(self, seed: Optional[int] = None):
        """Replaces the environment RNG with one built by ``make_rng(seed)``.

        Args:
            seed (Optional[int]): seed forwarded to ``make_rng``.
        """
        self._rng = make_rng(seed)

    def functional_reset(self) -> State:
        """Returns a new initial state produced by the reset function.

        Returns:
            State: the initial state

        Raises:
            ValueError: if ``gv_debug()`` is on and the state is not contained
                in the state space.
        """
        state = self._reset_function(rng=self._rng)
        if gv_debug() and not self.state_space.contains(state):
            raise ValueError('state does not satisfy state_space')

        return state

    def functional_step(
        self, state: State, action: Action
    ) -> Tuple[State, float, bool]:
        """Computes the next state, reward and terminal flag.

        Args:
            state (State): the state to transition from
            action (Action): the action to apply

        Returns:
            Tuple[State, float, bool]: next state, reward and terminal flag

        Raises:
            ValueError: if the action is not contained in the action space,
                or, when ``gv_debug()`` is on, if the input state or the next
                state is not contained in the state space.
        """
        if gv_debug() and not self.state_space.contains(state):
            raise ValueError('state does not satisfy state_space')
        # NOTE: unlike the state checks, the action check is always performed.
        if not self.action_space.contains(action):
            raise ValueError(f'action {action} does not satisfy action-space')

        # NOTE(review): presumably transitions a copy so that `state` is left
        # untouched for the reward/termination calls below -- confirm against
        # `transition_with_copy`.
        next_state = transition_with_copy(
            self._transition_function,
            state,
            action,
            rng=self._rng,
        )

        if gv_debug() and not self.state_space.contains(next_state):
            raise ValueError('next_state does not satisfy state_space')

        reward = self._reward_function(state, action, next_state)
        terminal = self._termination_function(state, action, next_state)

        return (next_state, reward, terminal)

    def functional_observation(self, state: State) -> Observation:
        """Returns an observation of the given state.

        Args:
            state (State): the state to observe

        Returns:
            Observation: output of the observation function

        Raises:
            ValueError: if ``gv_debug()`` is on and the observation is not
                contained in the observation space.
        """
        observation = self._observation_function(state, rng=self._rng)
        if gv_debug() and not self.observation_space.contains(observation):
            raise ValueError('observation does not satisfy observation_space')

        return observation

# Informed Grid World
class InformedGridWorld(GridWorld):
    """GridWorld variant that additionally exposes privileged information.

    On top of the regular observation, an information function maps the
    state to an "information" value, which is validated against a dedicated
    information space.
    """

    def __init__(
        self,
        state_space: StateSpace,
        action_space: ActionSpace,
        observation_space: ObservationSpace,
        information_space: ObservationSpace,
        reset_function: ResetFunction,
        transition_function: TransitionFunction,
        observation_function: ObservationFunction,
        information_function: ObservationFunction,
        reward_function: RewardFunction,
        termination_function: TerminatingFunction,
    ):
        """Initializes an InformedGridWorld from the given components.

        Args:
            state_space (StateSpace):
            action_space (ActionSpace):
            observation_space (ObservationSpace):
            information_space (ObservationSpace): space used to validate the
                output of the information function.
            reset_function (ResetFunction):
            transition_function (TransitionFunction):
            observation_function (ObservationFunction):
            information_function (ObservationFunction): called as
                ``information_function(state, rng=...)``.
            reward_function (RewardFunction):
            termination_function (TerminatingFunction):
        """

        super().__init__(
            state_space=state_space,
            action_space=action_space,
            observation_space=observation_space,
            reset_function=reset_function,
            transition_function=transition_function,
            observation_function=observation_function,
            reward_function=reward_function,
            termination_function=termination_function,
        )

        self.information_space = information_space
        self._information_function = information_function

        # Cached privileged information; None means "not generated yet for
        # the current state".
        self._information: Optional[Observation] = None

    def functional_information(self, state: State) -> Observation:
        """Returns the information associated with the given state.

        Args:
            state (State): the state to extract information from

        Returns:
            Observation: output of the information function

        Raises:
            ValueError: if ``gv_debug()`` is on and the information is not
                contained in the information space.
        """
        info = self._information_function(state, rng=self._rng)

        if gv_debug():
            if not self.information_space.contains(info):
                raise ValueError(
                    'information does not satisfy information_space'
                )

        return info

    @property
    def information(self) -> Observation:
        """Returns the current information

        Internally calls :py:meth:`functional_information` to generate the
        current information based on the current state.  The information is
        generated lazily, such that at most one information is generated for
        each state.  As a consequence, this will return the same information
        until the state is reset/updated, even if the information function is
        stochastic.

        Returns:
            Observation: privileged information
        """
        cached = self._information

        # Generate once and remember it, because the information function
        # may be stochastic.
        if cached is None:
            cached = self.functional_information(self.state)
            self._information = cached

        return cached

    def reset(self):
        """Resets the state

        Internally calls :py:meth:`functional_reset` to reset the state;  also
        clears the cached observation and information, so that updated ones
        will be generated upon request.
        """
        self._state = self.functional_reset()

        # Invalidate everything derived from the previous state.
        self._observation = None
        self._information = None

    def step(self, action: Action) -> Tuple[float, bool]:
        """Runs the dynamics for one timestep, and returns reward and done flag

        Internally calls :py:meth:`functional_step` to update the state;  also
        clears the cached observation and information, so that updated ones
        will be generated upon request.

        Args:
            action (Action): the chosen action to apply

        Returns:
            Tuple[float, bool]: reward and terminal
        """
        next_state, reward, done = self.functional_step(self.state, action)
        self._state = next_state

        # Invalidate everything derived from the previous state.
        self._observation = None
        self._information = None

        return reward, done


