from gym.envs.registration import register
import numpy as np

from .kitchen.kitchen_multitask import (
    KitchenMultiTaskEnv,
    KitchenMultiTaskMultistageEnv,
    KitchenMixed,
)
from .kitchen.kitchen_multistage import KitchenMultistageEnv
from .maze import MazeEnv
from environments.metaworld.metaworld_gym import MetaWorldEnv
from environments.metaworld.metaworld_cds import DoorOpenEnvV2
from .reacher import (
    ReacherSlowObstacleEnv,
    ReacherRandomizedEnv,
    GymReacherEnv,
    ReacherClusterEnv,
    ReacherRewardScaleEnv,
    ReacherRewardShiftEnv,
    ReacherRewardShapedEnv,
    ReacherGoalClusterEnv,
)
from .reacher_multistage import ReacherMultistageMTEnv
from .recsim_envs import RecSimEnv
from .walker_velocity import WalkerVelocityMTEnv
from environments.mujoco.walker2d_multitask import Walker2dMTEnv, Walker2dForwardEnv


# from environments.dnc_environments import (
#     create_deterministic,
#     create_env_partitions,
#     create_stochastic,
# )

# Walker
# ----------------------------------------

def _register_walker_envs():
    """Register the Walker velocity-tracking and Walker2d multi-task environments.

    Every environment in this group gets a 1000-step gym ``TimeLimit``. A row
    whose kwargs slot is ``None`` is registered without any ``kwargs``
    argument, exactly like a bare ``register`` call.
    """
    velocity_entry = "environments:WalkerVelocityMTEnv"
    multitask_entry = "environments:Walker2dMTEnv"

    # Each row: (gym id, entry point, constructor kwargs or None).
    # ``np.linspace`` is evaluated per row so no two specs share one array.
    rows = (
        (
            "WalkerVelocityMT5-v0",
            velocity_entry,
            {
                "target_velocities": np.linspace(2.1, 3, 5),
                "velocity_reward_weight": 1.0,
                "walker_reward_weight": 0,
            },
        ),
        (
            "WalkerVelocityMT5-v2",
            velocity_entry,
            {
                "target_velocities": np.linspace(2.1, 3, 5),
                "velocity_reward_weight": 10.0,
                "walker_reward_weight": 1.0,
            },
        ),
        (
            "WalkerVelocityMT3-v0",
            velocity_entry,
            {
                "target_velocities": np.linspace(1.0, 2.5, 3),
                "velocity_reward_weight": 1.0,
                "walker_reward_weight": 1.0,
                "velocity_bonus": 1.0,
                "velocity_bonus_range": 0.2,
            },
        ),
        ("Walker2dMT4-v0", multitask_entry, None),
        ("Walker2dMT4-v1", multitask_entry, {"prob_apply_force": [0.1, 0.1, 0.1, 0.1]}),
        ("Walker2dMT4-v2", multitask_entry, {"prob_apply_force": [0.2, 0.2, 0.2, 0.2]}),
        (
            "Walker2dForward-v0",
            "environments:Walker2dForwardEnv",
            {"task_id": 0, "n_total_tasks": 4},
        ),
    )

    for env_id, entry_point, env_kwargs in rows:
        options = {"id": env_id, "entry_point": entry_point, "max_episode_steps": 1000}
        if env_kwargs is not None:
            options["kwargs"] = env_kwargs
        register(**options)


_register_walker_envs()


def _register_recsim_and_reacher_envs():
    """Register the RecSim and Reacher environment family.

    The single-stage environments take no constructor kwargs and get a
    50-step gym ``TimeLimit``. The multistage Reacher variants share one
    entry point, get a 90-step limit, and differ only in their kwargs.
    """
    single_stage = {
        "RecSim-v0": "environments:RecSimEnv",
        "GymReacher-v0": "environments:GymReacherEnv",
        "ReacherRandomized-v0": "environments:ReacherRandomizedEnv",
        "ReacherObstacle-v0": "environments:ReacherSlowObstacleEnv",
        "ReacherCluster-v0": "environments:ReacherClusterEnv",
        "ReacherRewardScale-v0": "environments:ReacherRewardScaleEnv",
        "ReacherRewardShift-v0": "environments:ReacherRewardShiftEnv",
        "ReacherRewardShaped-v0": "environments:ReacherRewardShapedEnv",
        # This one can choose a specific goal cluster.
        "ReacherGoalCluster-v0": "environments:ReacherGoalClusterEnv",
    }
    for env_id, entry_point in single_stage.items():
        register(id=env_id, entry_point=entry_point, max_episode_steps=50)

    # (gym id, constructor kwargs or None for "pass no kwargs").
    multistage_variants = [
        ("ReacherMultistage-v0", None),
        ("ReacherMultistageRandom-v0", {"random_reset": True}),
        ("ReacherMultistageFixed-v0", {"fixed_stage_length": True}),
    ]
    for env_id, env_kwargs in multistage_variants:
        extra = {} if env_kwargs is None else {"kwargs": env_kwargs}
        register(
            id=env_id,
            entry_point="environments:ReacherMultistageMTEnv",
            max_episode_steps=90,
            **extra,
        )


_register_recsim_and_reacher_envs()


# Kitchen
# ----------------------------------------
def _register_kitchen_envs():
    """Register the Kitchen multi-task, skill-space and multistage environments.

    Each table row is (gym id, entry point, gym ``TimeLimit`` steps,
    constructor kwargs). Every Kitchen spec passes ``kwargs``; the
    ``max_episode_steps`` inside kwargs is the environment's own setting and
    is kept separate from the gym limit.
    """
    multitask = "environments:KitchenMultiTaskEnv"
    multistage = "environments:KitchenMultiTaskMultistageEnv"

    def skill(benchmark_type, horizon, accumulate_reward):
        """Kwargs for a skill-space multi-task variant whose inner step limit is *horizon*."""
        return {
            "benchmark_type": benchmark_type,
            "max_episode_steps": horizon,
            "use_skill_space": True,
            "accumulate_reward": accumulate_reward,
        }

    def ms3_tasks():
        """Return a fresh task-name list so no two specs share a mutable object."""
        return [["slide cabinet-open", "kettle-push", "bottom burner-on"]]

    # NOTE(review): for KitchenSkillSingle-v0 (15 vs 150), KitchenSkillSingle-v1
    # (10 vs 100) and KitchenSkillMS3-v0 (28 vs 280) the env's own
    # max_episode_steps is 10x the gym limit, whereas every other skill-space
    # variant uses equal values. Kept as-is; confirm which convention is intended.
    table = [
        ("KitchenMT10-v0", multitask, 70, {"benchmark_type": "MT10"}),
        # NOTE(review): the original author marked accumulate_reward=True here as "not sure".
        ("KitchenSkillMT3-v0", multitask, 30, skill("MT3", 30, True)),
        ("KitchenSkillMT3-v1", multitask, 30, skill("MT3", 30, False)),
        ("KitchenSkillMT3-v2", multitask, 30, skill("MT3_v2", 30, False)),
        ("KitchenSkillMT3-v3", multitask, 30, skill("MT3_v3", 30, False)),
        ("KitchenSkillHard-v0", multitask, 30, skill("HARD", 30, False)),
        ("KitchenSkillMT5-v0", multitask, 30, skill("MT5", 30, False)),
        ("KitchenSkillMT2-v0", multitask, 30, skill("MT2", 30, True)),
        ("KitchenSkillMT4-v0", multitask, 30, skill("MT4", 30, True)),
        ("KitchenSkillMT4-v1", multitask, 25, skill("MT4", 25, True)),
        ("KitchenSkillMT4-v2", multitask, 20, skill("MT4", 20, True)),
        ("KitchenSingle-v0", multitask, 70, {"benchmark_type": "MT1"}),
        (
            "KitchenSingleLong-v0",
            multitask,
            140,
            {"benchmark_type": "MT1", "max_episode_steps": 140},
        ),
        (
            "KitchenSkillSingle-v0",
            multitask,
            15,
            {"benchmark_type": "MT1", "max_episode_steps": 150, "use_skill_space": True},
        ),
        (
            "KitchenSkillSingle-v1",
            multitask,
            10,
            {
                "benchmark_type": "MT1",
                "max_episode_steps": 100,
                "use_skill_space": True,
                "accumulate_reward": False,
            },
        ),
        (
            "kitchenmixed-v0",
            "environments:KitchenMixed",
            28,
            {"max_episode_steps": 28, "use_skill_space": True, "accumulate_reward": True},
        ),
        ("KitchenAll-v0", multitask, 70, {"benchmark_type": "All"}),
        (
            "KitchenCabinet-v0",
            multitask,
            140,
            {"benchmark_type": "CABINET", "max_episode_steps": 140},
        ),
        (
            "KitchenMTEasy-v0",
            multitask,
            140,
            {"benchmark_type": "MT_EASY", "max_episode_steps": 140},
        ),
        (
            "KitchenMTEasy5-v0",
            multitask,
            140,
            {"benchmark_type": "MT_EASY5", "max_episode_steps": 140},
        ),
        (
            "KitchenMicrowave-v0",
            multitask,
            140,
            {"benchmark_type": "MICROWAVE", "max_episode_steps": 140},
        ),
        (
            "KitchenMS3-v0",
            multistage,
            400,
            {"task_names": ms3_tasks(), "max_episode_steps": 400},
        ),
        (
            "KitchenSkillMS3-v0",
            multistage,
            28,
            {
                "task_names": ms3_tasks(),
                "max_episode_steps": 280,
                "use_skill_space": True,
                "accumulate_reward": True,
            },
        ),
        (
            "KitchenSkillMS3-v1",
            multistage,
            28,
            {
                "task_names": ms3_tasks(),
                "max_episode_steps": 28,
                "use_skill_space": True,
                "accumulate_reward": False,
            },
        ),
    ]

    for env_id, entry_point, horizon, env_kwargs in table:
        register(
            id=env_id,
            entry_point=entry_point,
            max_episode_steps=horizon,
            kwargs=env_kwargs,
        )


_register_kitchen_envs()


# Meta-World
# ----------------------------------------
def _register_metaworld_envs():
    """Register the Meta-World benchmarks.

    All of them use ``MetaWorldEnv`` with a 500-step gym ``TimeLimit``; they
    differ only in the ``benchmark_type`` passed to the constructor.
    """
    benchmark_by_id = {
        "MetaWorldCDS-v1": "CDS_v1",
        "MetaWorldCDS-v2": "CDS",
        "MetaWorldMT10-v2": "MT10",
        "MetaWorldMT50-v2": "MT50",
        "MetaWorldSingle-v2": "MT1",
        "MetaWorldMT3-v2": "MT3",
    }
    for env_id, benchmark_type in benchmark_by_id.items():
        register(
            id=env_id,
            entry_point="environments:MetaWorldEnv",
            max_episode_steps=500,
            kwargs={"benchmark_type": benchmark_type},
        )


_register_metaworld_envs()


# Walker2d
# ----------------------------------------
def _register_walker2d_mujoco_envs():
    """Register the Walker2d environments implemented in ``environments.mujoco``.

    None of these pass constructor kwargs; only the gym step limit varies.
    The last three combine behaviours of the earlier ones, as noted per row.
    """
    specs = (
        ("Walker2dForward-v1", "environments.mujoco:Walker2dForwardEnv", 1000),
        ("Walker2dBackward-v1", "environments.mujoco:Walker2dBackwardEnv", 1000),
        ("Walker2dBalance-v1", "environments.mujoco:Walker2dBalanceEnv", 1000),
        ("Walker2dJump-v1", "environments.mujoco:Walker2dJumpEnv", 1000),
        ("Walker2dCrawl-v1", "environments.mujoco:Walker2dCrawlEnv", 1000),
        # Forward and backward
        ("Walker2dPatrol-v1", "environments.mujoco:Walker2dPatrolEnv", 10000),
        # Forward and jump
        ("Walker2dHurdle-v1", "environments.mujoco:Walker2dHurdleEnv", 5000),
        # Forward, jump and crawl
        (
            "Walker2dObstacleCourse-v1",
            "environments.mujoco:Walker2dObstacleCourseEnv",
            5000,
        ),
    )
    for env_id, entry_point, step_limit in specs:
        register(id=env_id, entry_point=entry_point, max_episode_steps=step_limit)


_register_walker2d_mujoco_envs()

# Jaco
# ----------------------------------------
def _register_jaco_envs():
    """Register the Jaco arm environments from ``environments.mujoco``."""

    def add(env_id, entry_point, step_limit, **env_kwargs):
        """Register one Jaco env; ``kwargs`` is passed only when some were given."""
        spec = {"id": env_id, "entry_point": entry_point, "max_episode_steps": step_limit}
        if env_kwargs:
            spec["kwargs"] = env_kwargs
        register(**spec)

    add("JacoReach-v1", "environments.mujoco:JacoReachEnv", 100)
    add("JacoReachNoFinger-v1", "environments.mujoco:JacoReachEnv", 200, no_finger=True)
    add("JacoReachMT4-v1", "environments.mujoco:JacoReachMT4Env", 200)

    multistage_mt = "environments.mujoco:JacoReachMultistageMTEnv"
    add("JacoReachMultistage-v1", multistage_mt, 200, version=1)
    add("JacoReachMultistageRew-v1", multistage_mt, 200, version=1, time_reward=False)
    add("JacoReachMultistage-v2", multistage_mt, 200, version=2)
    add("JacoReachMultistage-v3", multistage_mt, 200, version=3)

    add("JacoReachMT5-v1", "environments.mujoco:JacoReachMT5Env", 200)
    add("JacoReachMT3-v1", "environments.mujoco:JacoReachMT3Env", 200)
    add(
        "JacoStay-v1",
        "environments.mujoco:JacoReachMultistageEnv",
        200,
        task_id=0,
        num_tasks=1,
        goal_locations=[None, None, None],
        sparse_reward=True,
        time_reward=False,
        include_task_id=False,
    )
    add("JacoPick-v1", "environments.mujoco:JacoPickEnv", 200)
    add("JacoCatch-v1", "environments.mujoco:JacoCatchEnv", 200)
    add("JacoToss-v1", "environments.mujoco:JacoTossEnv", 200, with_rot=0)
    add("JacoHit-v1", "environments.mujoco:JacoHitEnv", 200, with_rot=0)
    # Keep picking up
    add("JacoKeepPick-v1", "environments.mujoco:JacoKeepPickEnv", 1000)
    # Keep catching
    add("JacoKeepCatch-v1", "environments.mujoco:JacoKeepCatchEnv", 1000)
    # Serve
    add("JacoServe-v1", "environments.mujoco:JacoServeEnv", 300, with_rot=0)


_register_jaco_envs()

# Maze
# ----------------------------------------
def _register_maze_envs():
    """Register the multi-task Maze environments.

    Every maze env uses ``MazeMultitaskEnv``. The same horizon is used both as
    the gym ``TimeLimit`` and as the env's own ``max_episode_steps`` kwarg.
    """
    # (gym id, maze_spec, num_tasks, horizon, extra constructor kwargs)
    variants = (
        ("Maze20-10-v0", 20, 10, 1000, {}),
        ("Maze20-20-v0", 20, 20, 1000, {}),
        ("MazeMedium-v0", "MEDIUM_MAZE", 3, 250, {}),
        ("MazeMedium-v1", "MEDIUM_MAZE", 3, 200, {}),
        ("MazeMediumSingle-v0", "MEDIUM_MAZE", 1, 250, {}),
        ("MazeLarge-v0", "LARGE_MAZE", 5, 600, {}),
        ("MazeLarge-10-v0", "LARGE_MAZE", 10, 600, {}),
        ("MazeLarge-20-v0", "LARGE_MAZE", 20, 600, {}),
        ("MazeLarge-20-v1", "LARGE_MAZE", 20, 400, {}),
        ("MazeMT10-v0", "MT10", 10, 500, {"position_only": True}),
        ("MazeMT10-v1", "MT10", 10, 300, {}),
    )
    for env_id, maze_spec, num_tasks, horizon, extra in variants:
        # Build a fresh kwargs dict per spec; ``extra`` keys come last.
        env_kwargs = {
            "maze_spec": maze_spec,
            "num_tasks": num_tasks,
            "max_episode_steps": horizon,
        }
        env_kwargs.update(extra)
        register(
            id=env_id,
            entry_point="environments.maze:MazeMultitaskEnv",
            max_episode_steps=horizon,
            kwargs=env_kwargs,
        )


_register_maze_envs()