import os.path

import gymnasium as gym
from minigrid.wrappers import ViewSizeWrapper

from dcrl.envs.minigrid_wrapper import MinigridObservationWrapper
from dcrl.envs.miniworld_wrapper import MiniworldObservationWrapper


def make_minigrid_env(
    env_key,
    seed=None,
    idx=None,
    capture_video=False,
    log_dir=None,
    render_mode=None,
    view_size=None,
):
    """Return a zero-argument callable that builds a wrapped MiniGrid environment.

    Nothing is constructed until the returned callable is invoked. When it is
    invoked, it does the following:

    1. Creates ``env_key`` via ``gym.make``.
    2. Optionally applies ``ViewSizeWrapper``.
    3. Applies ``MinigridObservationWrapper`` and ``RecordEpisodeStatistics``.
    4. Applies ``RecordVideo`` for the environment with ``idx == 0`` when
       ``capture_video`` is set.
    5. Resets the environment with ``seed``, seeds its action and observation
       spaces, and returns it.

    Args:
        env_key: Environment id passed to ``gym.make``.
        seed: Value passed to ``env.reset(seed=...)`` and to the ``seed()``
            method of both the action and observation spaces.
        idx: Index of this environment; only ``idx == 0`` records video.
        capture_video: When true, the environment is created with the
            ``"rgb_array"`` render mode (overriding ``render_mode``), and the
            ``idx == 0`` environment is wrapped in ``RecordVideo``.
        log_dir: Directory under which videos go into a ``videos``
            subdirectory; must be a path when ``capture_video`` is set and
            ``idx == 0``.
        render_mode: Render mode used when ``capture_video`` is false.
        view_size: If not ``None``, passed as ``agent_view_size`` to
            ``ViewSizeWrapper``.

    Returns:
        A callable taking no arguments that returns the wrapped environment.
    """

    def _init_env():
        # Recording video needs frames, so capture forces "rgb_array".
        if capture_video:
            mode = "rgb_array"
        else:
            mode = render_mode
        wrapped = gym.make(env_key, render_mode=mode)

        if view_size is not None:
            wrapped = ViewSizeWrapper(wrapped, agent_view_size=view_size)

        wrapped = gym.wrappers.RecordEpisodeStatistics(
            MinigridObservationWrapper(wrapped)
        )

        # Only the first environment records, to avoid duplicate videos.
        if capture_video and idx == 0:
            video_dir = os.path.join(log_dir, "videos")
            wrapped = gym.wrappers.RecordVideo(wrapped, video_dir)

        wrapped.reset(seed=seed)
        wrapped.action_space.seed(seed)
        wrapped.observation_space.seed(seed)
        return wrapped

    return _init_env


def make_miniworld_env(
    env_key,
    seed=None,
    idx=None,
    capture_video=False,
    log_dir=None,
    render_mode=None,
):
    """Return a zero-argument callable that builds a wrapped MiniWorld environment.

    When the returned callable is invoked, it does the following:

    1. Creates ``env_key`` via ``gym.make``.
    2. Applies ``MiniworldObservationWrapper`` and ``RecordEpisodeStatistics``.
    3. Applies ``RecordVideo`` for the environment with ``idx == 0`` when
       ``capture_video`` is set.
    4. Resets the environment with ``seed``, seeds its action and observation
       spaces, and returns it.

    Args:
        env_key: Environment id passed to ``gym.make``.
        seed: Value passed to ``env.reset(seed=...)`` and to the ``seed()``
            method of both the action and observation spaces.
        idx: Index of this environment; only ``idx == 0`` records video.
        capture_video: When true, the environment is created with the
            ``"rgb_array"`` render mode (overriding ``render_mode``), and the
            ``idx == 0`` environment is wrapped in ``RecordVideo``.
        log_dir: Directory under which videos go into a ``videos``
            subdirectory; must be a path when ``capture_video`` is set and
            ``idx == 0``.
        render_mode: Render mode used when ``capture_video`` is false.

    Returns:
        A callable taking no arguments that returns the wrapped environment.

    Raises:
        ImportError: If the ``miniworld`` package is not installed. This is
            raised when this factory is called, not only when the thunk runs.
    """
    # ``miniworld`` is imported only for its side effect of registering its
    # environment ids with gymnasium. Importing it here fails fast if the
    # dependency is missing.
    import miniworld  # noqa

    def thunk():
        # Import again inside the thunk. Vectorized envs may invoke it in a
        # fresh worker process (e.g. the "spawn"/"forkserver" start methods),
        # where the factory-time import above never ran, so ``gym.make`` would
        # not find the miniworld ids. Re-importing a loaded module is cheap.
        import miniworld  # noqa

        # Recording video needs frames, so capture forces "rgb_array".
        render_mode_ = "rgb_array" if capture_video else render_mode
        env = gym.make(env_key, render_mode=render_mode_)
        env = MiniworldObservationWrapper(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        # Only the first environment records, to avoid duplicate videos.
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, os.path.join(log_dir, "videos"))
        env.reset(seed=seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
