from envs.custom_dmc_tasks import cheetah
from envs.custom_dmc_tasks import walker
from envs.custom_dmc_tasks import hopper
from envs.custom_dmc_tasks import quadruped
from envs.custom_dmc_tasks import jaco
from envs.custom_dmc_tasks import point_mass_maze
from envs.custom_dmc_tasks import shadowhand
from envs.custom_dmc_tasks import franka
from envs.custom_dmc_tasks import ur5e


def make(
    domain, task, task_kwargs=None, environment_kwargs=None, visualize_reward=False
):
    """Build a custom DMC environment by delegating to the matching domain module.

    Args:
        domain: Domain name; one of "cheetah", "walker", "point_mass_maze",
            "hopper", "quadruped", "shadowhand" or "ur5e".
        task: Task name, passed positionally to the domain module's ``make``.
        task_kwargs: Forwarded unchanged to the domain module's ``make``.
        environment_kwargs: Forwarded unchanged to the domain module's ``make``.
        visualize_reward: Forwarded unchanged to the domain module's ``make``.

    Returns:
        Whatever the selected domain module's ``make`` returns (presumably a
        dm_control environment -- not visible from this file).

    Raises:
        ValueError: If ``domain`` is not one of the supported names.
    """
    # Resolve the domain module first so every domain shares one identical
    # call site below. Only the matching branch references its module name,
    # so an unknown domain is rejected without touching any of them.
    # NOTE(review): `franka` is imported by this module but is not dispatched
    # here, and `jaco` is only reachable via `make_jaco`; confirm intended.
    if domain == "cheetah":
        module = cheetah
    elif domain == "walker":
        module = walker
    elif domain == "point_mass_maze":
        module = point_mass_maze
    elif domain == "hopper":
        module = hopper
    elif domain == "quadruped":
        module = quadruped
    elif domain == "shadowhand":
        module = shadowhand
    elif domain == "ur5e":
        module = ur5e
    else:
        # Report the *domain* that failed to resolve; naming only the task
        # hid the actual problem from the caller.
        raise ValueError(f"domain {domain!r} not found (task {task!r})")

    return module.make(
        task,
        task_kwargs=task_kwargs,
        environment_kwargs=environment_kwargs,
        visualize_reward=visualize_reward,
    )


def make_jaco(task, obs_type, seed):
    """Build a Jaco environment via ``jaco.make``.

    Jaco is not routed through ``make`` above; it takes an observation type
    and a seed instead of the task/environment kwargs used by other domains.

    Args:
        task: Task name, passed positionally to ``jaco.make``.
        obs_type: Observation type, passed positionally to ``jaco.make``.
        seed: Seed value, passed positionally to ``jaco.make``.

    Returns:
        Whatever ``jaco.make`` returns.
    """
    env = jaco.make(task, obs_type, seed)
    return env
