"""Shared features extractor modules for RL algorithms."""
from abc import ABC, abstractmethod
from functools import partial

import gymnasium as gym
import jax
import jax.numpy as jnp
from stable_baselines3.common.preprocessing import get_flattened_obs_dim


class BaseFeaturesExtractor(ABC):
    """Base class for features extractors.

    A features extractor takes an observation as input and returns a
    vectorized representation of it. No assumptions are made about the
    kind of input the extractor accepts; the only requirement on the
    output is that it must be a float vector.

    For example, a CNN features extractor would take as input an image
    observation and return a vector of features extracted by the CNN.

    Args:
        observation_space: The observation space of the environment.
        features_dim: The dimension of the features extracted. Must be
            positive; the default of 0 is always rejected, so a real
            value has to be supplied.

    Raises:
        ValueError: If ``features_dim`` is not positive.
    """

    # Private Instance Attributes:
    #   _observation_space: The observation space of the environment.
    #   _features_dim: The dimension of the features extracted.
    _observation_space: gym.spaces.Space
    _features_dim: int

    def __init__(self, observation_space: gym.spaces.Space,
                 features_dim: int = 0) -> None:
        """Initialize a new features extractor.

        Raises:
            ValueError: If ``features_dim`` is not positive.
        """
        # Raise explicitly instead of asserting: assert statements are
        # removed when Python runs with -O, which would let an invalid
        # dimension through unnoticed.
        if not features_dim > 0:
            raise ValueError('The features_dim must be positive')
        self._observation_space = observation_space
        self._features_dim = features_dim

    @abstractmethod
    def __call__(self, observation: jax.typing.ArrayLike) -> jax.Array:
        """Extract features from the observation.

        Subclasses must override this method.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a string representation of the features extractor."""
        return (
            f'{type(self).__name__}('
            f'observation_space={self._observation_space}, '
            f'features_dim={self._features_dim})'
        )

    @property
    def features_dim(self) -> int:
        """Return the dimension of the features extracted."""
        return self._features_dim


class FlattenExtractor(BaseFeaturesExtractor):
    """Features extractor that merges every axis after the first into one.

    The size of the first axis of the observation is kept, and all
    remaining axes are collapsed into a single trailing axis. The
    features dimension is obtained from the observation space through
    ``get_flattened_obs_dim``.

    Args:
        observation_space: The observation space of the environment.
    """

    def __init__(self, observation_space: gym.spaces.Space) -> None:
        """Set up the extractor for the given observation space."""
        super().__init__(observation_space,
                         get_flattened_obs_dim(observation_space))

    # ``self`` (argument 0) is declared static for jax.jit.
    @partial(jax.jit, static_argnums=0)
    def __call__(self, observation: jax.typing.ArrayLike) -> jax.Array:
        """Reshape the observation to two axes, keeping the first one."""
        leading_size = observation.shape[0]
        return jnp.reshape(observation, (leading_size, -1))
