import gym
import jax
import numpy as np


def make(env_id: str):
    """Build a grayscale Atari environment wrapped in ``AtariImagePreprocessor``.

    The environment is created with ``gym.make`` using a fixed set of ALE
    options (grayscale observations, ``frameskip=4``,
    ``repeat_action_probability=0.0``, minimal action set). The wrapper then
    resizes every observation to an 84x84x1 uint8 image.

    Args:
        env_id: Gym environment id. Presumably an ALE Atari id that accepts the
            keyword options below (TODO confirm against callers).

    Returns:
        The wrapped environment.
    """
    # ALE-specific options forwarded verbatim to gym.make.
    atari_options = dict(
        obs_type='grayscale',
        frameskip=4,
        repeat_action_probability=0.0,
        full_action_space=False,
    )
    base_env = gym.make(env_id, **atari_options)
    return AtariImagePreprocessor(base_env)


class AtariImagePreprocessor(gym.ObservationWrapper):
    """Observation wrapper that resizes frames to 84x84x1 uint8 images.

    Each observation produced by the wrapped environment is passed through
    ``preprocess_atari_image``. ``observation_space`` is replaced with a
    ``Box`` describing the resized frames.
    """

    def __init__(self, env):
        super().__init__(env)
        # Advertise the post-processing shape and dtype, not the raw frame's.
        resized_shape = [84, 84, 1]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=resized_shape,
            dtype=np.uint8,
        )

    def observation(self, obs):
        """Return ``obs`` after resizing via ``preprocess_atari_image``."""
        processed = preprocess_atari_image(obs)
        return processed


def preprocess_atari_image(image):
    """Resize a single-channel frame to 84x84 and append a channel axis.

    Nearest-neighbour sampling is done in NumPy, so the result is a host-side
    ``np.ndarray`` rather than a JAX device array. For output index ``i``, the
    source index is ``floor((i + 0.5) * in_size / out_size``. This is the same
    rule ``jax.image.resize`` uses for its NEAREST method, so pixel values are
    unchanged from the previous JAX-based implementation.

    Args:
        image: 2-D array-like of shape ``(height, width)``. Presumably a
            grayscale ALE frame (e.g. 210x160 uint8); TODO confirm against
            the env configuration.

    Returns:
        ``np.ndarray`` of shape ``(84, 84, 1)`` with the same dtype as
        ``image``.

    Raises:
        ValueError: if ``image`` is not 2-D.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(
            f'expected a 2-D (height, width) image, got shape {image.shape}')

    out_height, out_width = 84, 84
    in_height, in_width = image.shape

    # floor((i + 0.5) * in / out), computed in exact integer arithmetic to
    # avoid any floating-point rounding at sample boundaries.
    rows = ((2 * np.arange(out_height) + 1) * in_height) // (2 * out_height)
    cols = ((2 * np.arange(out_width) + 1) * in_width) // (2 * out_width)

    # Broadcasted fancy indexing gathers an (out_height, out_width) grid.
    resized = image[rows[:, None], cols[None, :]]
    return resized[..., None]
