import gymnasium as gym
import jax

from environments.base import BaseEnv


class LunarLander(BaseEnv):
    """Gymnasium ``LunarLander-v2`` environment plugged into ``BaseEnv``.

    The underlying environment is created with ``render_mode="rgb_array"``,
    so Gymnasium's ``render()`` produces image frames rather than opening a
    window.
    """

    def __init__(self, env_key: jax.random.PRNGKey) -> None:
        """Build the Gymnasium environment and initialise the base wrapper.

        Args:
            env_key: JAX PRNG key forwarded unchanged to ``BaseEnv``.
                NOTE(review): presumably used by ``BaseEnv`` to seed the
                environment; confirm against ``environments.base``.
        """
        # NOTE(review): "LunarLander-v2" is deprecated in newer Gymnasium
        # releases (replaced by "LunarLander-v3"). It is kept as-is because
        # the pinned Gymnasium version is not visible here.
        lander = gym.make("LunarLander-v2", render_mode="rgb_array")
        super().__init__(env_key, lander)
