import gym

from src.rllib.models.modelv2 import ModelV2
from src.rllib.utils.annotations import PublicAPI
from src.rllib.utils.typing import ModelConfigDict


@PublicAPI
class JAXModelV2(ModelV2):
    """ModelV2 base class for models written with JAX.

    This class only wires the constructor arguments into ``ModelV2`` and
    tags the instance with ``framework="jax"``. It is not a usable model on
    its own: a subclass has to provide the ``forward()`` implementation.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        """Set up the shared ModelV2 state for a JAX model.

        Args:
            obs_space: Observation space, forwarded unchanged to ModelV2.
            action_space: Action space, forwarded unchanged to ModelV2.
            num_outputs: Output size, forwarded unchanged to ModelV2.
            model_config: Model config dict, forwarded unchanged to ModelV2.
            name: Model name, forwarded unchanged to ModelV2.
        """
        # Forward positionally, in the same order ModelV2 receives them.
        # Only the framework tag is added, as a keyword argument.
        forwarded = (obs_space, action_space, num_outputs, model_config,
                     name)
        # Call the base initializer explicitly (not via super()) so that
        # initialization resolves exactly to ModelV2, independent of the MRO.
        ModelV2.__init__(self, *forwarded, framework="jax")
