import gym
from gym.envs.registration import register as gym_register_env
from ray.tune.registry import register_env as ray_register_env

from offline_rl.envs.bouncing_balls_env import BouncingBallsEnv
from offline_rl.envs.custom_reacher_env import CustomReacherEnv
from offline_rl.envs.line_env import LineEnv
from offline_rl.envs.maze_env import MazeEnv
from offline_rl.envs.point_maze_env import JsonWritablePointMazeEnv


def _make_point_maze_creator(direction, include_vel):
    """Return an RLlib env creator for one PointMaze variant.

    ``direction`` and ``include_vel`` are bound as arguments of this factory,
    so each returned creator keeps its own values. Lambdas defined directly
    inside a loop would instead see whatever the loop variables hold when they
    are finally called. The original code only avoided that bug because Ray
    happens to pickle the creator at registration time.

    Args:
        direction: Value forwarded as ``direction=`` to JsonWritablePointMazeEnv.
        include_vel: Value forwarded as ``include_vel=`` to JsonWritablePointMazeEnv.

    Returns:
        A callable taking ``env_config`` and returning a JsonWritablePointMazeEnv.
    """

    def _creator(env_config):
        # env_config entries are forwarded as extra keyword arguments; a key
        # named "direction" or "include_vel" would collide with the bound ones.
        return JsonWritablePointMazeEnv(
            direction=direction,
            include_vel=include_vel,
            **env_config,
        )

    return _creator


def register_imitation_envs_with_rllib():
    """Register the four PointMaze imitation variants with RLlib.

    Registers ``imitation/PointMaze{Left,Right}{,Vel}-v0``. "Left"/"Right"
    map to ``direction`` 0/1, and the "Vel" suffix sets ``include_vel=True``.
    """
    directions = {"Left": 0, "Right": 1}
    velocity_options = {"": False, "Vel": True}
    for dname, dval in directions.items():
        for vname, vval in velocity_options.items():
            ray_register_env(
                f"imitation/PointMaze{dname}{vname}-v0",
                _make_point_maze_creator(dval, vval),
            )


def register_ray_envs():
    """Register every custom environment with RLlib's env registry.

    Each creator receives RLlib's ``env_config``. MazeEnv is constructed with
    no arguments (its ``env_config`` is discarded); the other envs get the
    ``env_config`` entries as keyword arguments. The PointMaze imitation
    variants are registered last.

    NOTE(review): MazeEnv ignoring env_config differs from the other envs —
    confirm this is intentional.
    """
    creators = {
        "MazeEnv-v0": lambda env_config: MazeEnv(),
        "LineEnv-v0": lambda env_config: LineEnv(**env_config),
        "BouncingBallsEnv-v0": lambda env_config: BouncingBallsEnv(**env_config),
        "CustomReacherEnv-v0": lambda env_config: CustomReacherEnv(**env_config),
    }
    # dicts preserve insertion order, so registration order is unchanged.
    for env_name, creator in creators.items():
        ray_register_env(env_name, creator)
    register_imitation_envs_with_rllib()


def _registered_gym_env_ids():
    """Return the set of environment ids currently in gym's registry.

    Older gym releases expose ``registry.all()``, which yields EnvSpec objects.
    Newer releases make ``registry`` a dict keyed by env id and drop ``all()``.
    Both forms are accepted here.
    """
    registry = gym.envs.registry
    all_specs = getattr(registry, "all", None)
    if callable(all_specs):
        return {spec.id for spec in all_specs()}
    return set(registry)


def register_gym_envs():
    """Register the custom environments with gym, skipping ids already present.

    Safe to call repeatedly. Each id is checked individually, so a partially
    populated registry (e.g. some ids registered elsewhere) is completed rather
    than skipped wholesale. Re-registering an existing id may raise or warn
    depending on the gym version, which is why existing ids are skipped.
    """
    entry_points = {
        "MazeEnv-v0": "offline_rl.envs.maze_env:MazeEnv",
        "LineEnv-v0": "offline_rl.envs.line_env:LineEnv",
        "BouncingBallsEnv-v0": "offline_rl.envs.bouncing_balls_env:BouncingBallsEnv",
        "CustomReacherEnv-v0": "offline_rl.envs.custom_reacher_env:CustomReacherEnv",
    }
    # Snapshot the registry once; set membership avoids rescanning it per id.
    already_registered = _registered_gym_env_ids()
    for env_id, entry_point in entry_points.items():
        if env_id not in already_registered:
            gym_register_env(id=env_id, entry_point=entry_point)


def load_custom_envs() -> None:
    """Make every custom environment available to both RLlib and gym.

    RLlib creators are registered first, followed by the gym entry points.
    """
    for register in (register_ray_envs, register_gym_envs):
        register()
