from typing import Optional

from src.wrappers import (
    VideoRecorder,
    SequentialMultiEnvWrapper,
    DMCEnv,
    EpisodeMonitor,
    RepeatAction,
    TakeKey,
    RGB2Gray,
    SinglePrecision,
    FrameStack,
    StickyActionEnv,
)

# Prefer gymnasium and fall back to the legacy ``gym`` package when the
# gymnasium import fails. Either way the module is bound to the name ``gym``.
try:
    import gymnasium as gym

    # gymnasium path: make_one_env skips the ``env.seed(seed)`` call.
    old_api = False
except Exception:
    import gym

    # Legacy gym path: make_one_env additionally calls ``env.seed(seed)``.
    old_api = True


def make_one_env(
    env_name: str,
    seed: int,
    save_folder: Optional[str] = None,
    add_episode_monitor: bool = True,
    action_repeat: int = 1,
    frame_stack: int = 1,
    from_pixels: bool = False,
    pixels_only: bool = True,
    image_size: int = 84,
    sticky: bool = False,
    gray_scale: bool = False,
    flatten: bool = True,
) -> gym.Env:
    """Create one environment and apply the configured wrapper stack.

    If ``env_name`` is found in the gym registry it is built with
    ``gym.make``; otherwise it is parsed as ``"<domain>-<task>"`` and built
    as a ``DMCEnv`` with ``task_kwargs={"random": seed}``.

    Wrappers are applied in this fixed order: ``FlattenObservation`` (only
    for ``Dict`` observation spaces), ``EpisodeMonitor``, ``RepeatAction``,
    ``RescaleAction`` to ``[-1, 1]`` (always), ``VideoRecorder``, then either
    the pixel pipeline (``PixelObservationWrapper`` -> ``TakeKey`` ->
    optional ``RGB2Gray``) or ``SinglePrecision``, then ``FrameStack`` and
    ``StickyActionEnv``.

    Args:
        env_name: Registered gym id, or a ``"<domain>-<task>"`` name.
        seed: Seed for the DMC task, the action/observation spaces and, on
            the legacy gym API, ``env.seed``.
        save_folder: If not ``None``, wrap with ``VideoRecorder`` using this
            folder.
        add_episode_monitor: Wrap with ``EpisodeMonitor`` when true.
        action_repeat: Wrap with ``RepeatAction`` when greater than 1.
        frame_stack: Wrap with ``FrameStack`` when greater than 1.
        from_pixels: Use rendered pixels as the observation when true.
        pixels_only: Forwarded to ``PixelObservationWrapper``.
        image_size: Height and width passed as render kwargs.
        sticky: Wrap with ``StickyActionEnv`` when true.
        gray_scale: Wrap with ``RGB2Gray`` (pixel observations only).
        flatten: Flatten ``Dict`` observation spaces when true.

    Returns:
        The fully wrapped environment.

    Raises:
        ValueError: If ``env_name`` is not registered and is not of the form
            ``"<domain>-<task>"``.
    """
    # Registries exposing ``.all()`` yield specs carrying an ``id``; if that
    # access fails, the registry is treated as a mapping keyed by env id.
    try:
        env_ids = {env_spec.id for env_spec in gym.envs.registry.all()}
    except Exception:
        env_ids = gym.envs.registry.keys()

    is_registered = env_name in env_ids

    # Only set for DMC environments; used below to pick the render camera.
    domain_name = None
    if is_registered:
        env = gym.make(env_name)
    else:
        name_parts = env_name.split("-")
        if len(name_parts) != 2:
            # Fail with an actionable message instead of an opaque
            # tuple-unpacking error.
            raise ValueError(
                f"Environment {env_name!r} is not registered with gym and "
                "is not of the form '<domain>-<task>'."
            )
        domain_name, task_name = name_parts
        env = DMCEnv(
            domain_name=domain_name,
            task_name=task_name,
            task_kwargs={"random": seed},
        )

    if flatten and isinstance(env.observation_space, gym.spaces.Dict):
        env = gym.wrappers.FlattenObservation(env)

    if add_episode_monitor:
        env = EpisodeMonitor(env)

    if action_repeat > 1:
        env = RepeatAction(env, action_repeat)

    env = gym.wrappers.RescaleAction(env, -1.0, 1.0)

    if save_folder is not None:
        env = VideoRecorder(env, save_folder=save_folder)

    if from_pixels:
        # DMC quadruped tasks render from camera 2; everything else
        # (including all registered gym envs) uses camera 0.
        camera_id = 2 if domain_name == "quadruped" else 0

        env = gym.wrappers.pixel_observation.PixelObservationWrapper(
            env,
            pixels_only=pixels_only,
            render_kwargs={
                "pixels": {
                    "height": image_size,
                    "width": image_size,
                    "camera_id": camera_id,
                }
            },
        )
        env = TakeKey(env, take_key="pixels")
        if gray_scale:
            env = RGB2Gray(env)
    else:
        env = SinglePrecision(env)

    if frame_stack > 1:
        env = FrameStack(env, num_stack=frame_stack)

    if sticky:
        env = StickyActionEnv(env)

    # ``env.seed`` is only called on the legacy gym import path.
    if old_api:
        env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    return env


def make_env(
    env_name: str,
    seed: int,
    eval_episodes: Optional[int] = None,
    save_folder: Optional[str] = None,
    add_episode_monitor: bool = True,
    action_repeat: int = 1,
    frame_stack: int = 1,
    from_pixels: bool = False,
    pixels_only: bool = True,
    image_size: int = 84,
    sticky: bool = False,
    gray_scale: bool = False,
    flatten: bool = True,
    eval_episode_length: int = 1000,
    num_envs: Optional[int] = None,
) -> gym.Env:
    """Create a single environment or a sequential batch of environments.

    With ``num_envs=None`` this returns the result of ``make_one_env``
    directly. Otherwise it builds ``num_envs`` environment factories, the
    i-th using seed ``seed + i``, and hands them together with the matching
    seed list to ``SequentialMultiEnvWrapper``.

    Args:
        env_name: Environment name, forwarded to ``make_one_env``.
        seed: Base seed; sub-environment ``i`` uses ``seed + i``.
        eval_episodes: Accepted but not used by this function.
        save_folder: Forwarded to ``make_one_env``.
        add_episode_monitor: Forwarded to ``make_one_env``.
        action_repeat: Forwarded to ``make_one_env``.
        frame_stack: Forwarded to ``make_one_env``.
        from_pixels: Forwarded to ``make_one_env``.
        pixels_only: Forwarded to ``make_one_env``.
        image_size: Forwarded to ``make_one_env``.
        sticky: Forwarded to ``make_one_env``.
        gray_scale: Forwarded to ``make_one_env``.
        flatten: Forwarded to ``make_one_env``.
        eval_episode_length: Accepted but not used by this function.
        num_envs: Number of sub-environments, or ``None`` for a single
            unwrapped environment.

    Returns:
        One environment, or a ``SequentialMultiEnvWrapper`` over
        ``num_envs`` environments.
    """

    def build(env_seed: int):
        # Single place that forwards the shared configuration.
        return make_one_env(
            env_name,
            env_seed,
            save_folder=save_folder,
            add_episode_monitor=add_episode_monitor,
            action_repeat=action_repeat,
            frame_stack=frame_stack,
            from_pixels=from_pixels,
            pixels_only=pixels_only,
            image_size=image_size,
            sticky=sticky,
            gray_scale=gray_scale,
            flatten=flatten,
        )

    if num_envs is None:
        return build(seed)

    env_seeds = [seed + i for i in range(num_envs)]
    # Bind each seed as a default argument: a plain closure over the loop
    # variable is late-binding, so every factory would otherwise see the
    # final seed once it is eventually called.
    env_fn_list = [
        lambda env_seed=env_seed: build(env_seed) for env_seed in env_seeds
    ]
    return SequentialMultiEnvWrapper(env_fn_list, env_seeds)
