import gym
from gym.envs import register
import gym_classics
gym_classics.register('gym')
from gym_classics.envs.abstract.gridworld import Gridworld

from tasks.envs import atari
from tasks.envs.wrappers import OneHot, OneHotCoordinate


def make(env_id: str):
    """Create the environment identified by ``env_id`` with project wrappers applied.

    IDs recognised by ``is_atari`` are handed off to ``atari.make`` unchanged.
    All other IDs are built with ``gym.make`` and then wrapped:

    * Gym Classics ``Gridworld`` environments get ``OneHotCoordinate``.
    * Otherwise, a ``Discrete`` observation space gets ``OneHot``.
    * IDs starting with ``'LunarLander'`` get a ``TimeLimit`` of 1,000 steps.

    Args:
        env_id: Registered Gym environment ID.

    Returns:
        The (possibly wrapped) environment.
    """
    if is_atari(env_id):
        # Atari games have their own construction pipeline.
        return atari.make(env_id)

    env = gym.make(env_id)

    # Choose at most one observation-encoding wrapper. The gridworld check
    # comes first, so gridworlds never fall through to the generic one-hot.
    if isinstance(env.unwrapped, Gridworld):
        observation_wrapper = OneHotCoordinate
    elif isinstance(env.observation_space, gym.spaces.Discrete):
        observation_wrapper = OneHot
    else:
        observation_wrapper = None

    if observation_wrapper is not None:
        env = observation_wrapper(env)

    # NOTE(review): gym's own LunarLander registration may already apply a
    # TimeLimit; confirm whether this additional cap is intentional.
    if env_id.startswith('LunarLander'):
        env = gym.wrappers.TimeLimit(env, max_episode_steps=1_000)

    return env


def is_atari(env_id: str):
    """Return True when ``env_id`` is an ALE v5 Atari ID such as ``'ALE/Pong-v5'``.

    Args:
        env_id: Gym environment ID to classify.

    Returns:
        True only if the ID is in the ``ALE/`` namespace and ends with ``-v5``.
    """
    in_ale_namespace = env_id.startswith('ALE/')
    is_version_5 = env_id.endswith('-v5')
    return in_ale_namespace and is_version_5


### Environment registration:
# Makes the project's SplitGridworld available via gym.make('SplitGridworld-v0');
# the class is imported lazily from the entry-point string.

register(
    id='SplitGridworld-v0',
    entry_point='tasks.envs.split_gridworld:SplitGridworld',
)
