from gym.envs.registration import register
from d4rl.gym_mujoco import (
    HALFCHEETAH_EXPERT_SCORE, HALFCHEETAH_RANDOM_SCORE, ANT_EXPERT_SCORE, ANT_RANDOM_SCORE,
    WALKER_EXPERT_SCORE, WALKER_RANDOM_SCORE, HOPPER_EXPERT_SCORE, HOPPER_RANDOM_SCORE
)

from d4rl_additions import dangerous_ant_maze

def _register_adversarial_locomotion_envs():
    """Register the ``<robot>-adv-<quality>-v0`` MuJoCo locomotion tasks.

    One task is registered for each robot in the table below combined with
    each dataset quality (``random``, ``medium``, ``expert``). Robots form
    the outer loop and qualities the inner loop, so registration order is
    halfcheetah random/medium/expert, then walker2d, then hopper.

    Every task uses 1000-step episodes and is marked non-deprecated. It
    normalizes scores against the matching D4RL random/expert reference
    scores and points at a local HDF5 file named
    ``d4rl_additions/datasets/<robot>_adv_<quality>.hdf5``.
    """
    # Columns: (id prefix and dataset stem,
    #           factory name in d4rl_additions.gym_envs,
    #           reference min score, reference max score)
    robots = (
        ('halfcheetah', 'get_cheetah_env',
         HALFCHEETAH_RANDOM_SCORE, HALFCHEETAH_EXPERT_SCORE),
        ('walker2d', 'get_walker_env',
         WALKER_RANDOM_SCORE, WALKER_EXPERT_SCORE),
        ('hopper', 'get_hopper_env',
         HOPPER_RANDOM_SCORE, HOPPER_EXPERT_SCORE),
    )
    qualities = ('random', 'medium', 'expert')

    for robot, factory, min_score, max_score in robots:
        for quality in qualities:
            register(
                id=f'{robot}-adv-{quality}-v0',
                entry_point=f'd4rl_additions.gym_envs:{factory}',
                max_episode_steps=1000,
                kwargs={
                    'deprecated': False,
                    'ref_min_score': min_score,
                    'ref_max_score': max_score,
                    # NOTE(review): relative path; presumably resolved
                    # against the working directory by the env factory.
                    # Confirm.
                    'dataset_filename': f'd4rl_additions/datasets/{robot}_adv_{quality}.hdf5',
                },
            )


_register_adversarial_locomotion_envs()

def _register_dangerous_antmaze_envs():
    """Register the sparse- and dense-reward dangerous U-maze AntMaze tasks.

    The two variants share every setting except ``id`` and ``reward_type``:
    - the ``dangerous_ant_maze.U_MAZE_TEST`` layout, scaled by 4.0;
    - 700-step episodes;
    - eval mode with v2 resets and no non-zero reset;
    - the same D4RL ant_maze_v2 u-maze dataset URL;
    - a 0.0..1.0 reference score range.

    Registration order is sparse first, then dense.
    """
    variants = (
        ('dangerous-antmaze-umaze-v0', 'sparse'),
        ('dangerous-antmaze-umaze-dense-v0', 'dense'),
    )
    for env_id, reward_type in variants:
        register(
            id=env_id,
            entry_point='d4rl_additions.dangerous_ant_maze:make_dangerous_ant_maze_env',
            max_episode_steps=700,
            kwargs={
                'maze_map': dangerous_ant_maze.U_MAZE_TEST,
                'reward_type': reward_type,
                # NOTE(review): the dense variant also uses the
                # "..._sparse_fixed.hdf5" dataset. Confirm this is intended.
                'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5',
                'non_zero_reset': False,
                'eval': True,
                'maze_size_scaling': 4.0,
                'ref_min_score': 0.0,
                'ref_max_score': 1.0,
                'v2_resets': True,
            },
        )


_register_dangerous_antmaze_envs()


def modify_d4rl_dataset(dataset, env):
    """Apply dangerous-maze-specific changes to a D4RL dataset, if relevant.

    If ``env`` exposes a ``wrapped_env`` attribute that is a
    ``dangerous_ant_maze.DangerousMazeEnv``, the dataset and the outer
    ``env`` are passed to ``dangerous_ant_maze.modify_dataset``. Otherwise
    the dataset is left untouched.

    Args:
        dataset: Dataset loaded for ``env``. It is forwarded as-is to
            ``dangerous_ant_maze.modify_dataset``.
        env: The (possibly wrapped) environment the dataset belongs to.

    Returns:
        The same ``dataset`` object. The return value of ``modify_dataset``
        is ignored, so its changes are presumably made in place. TODO
        confirm.
    """
    # Read the attribute that is actually inspected. Guarding on a
    # differently-named attribute would either skip wrappers that expose
    # only ``wrapped_env`` or raise AttributeError on objects that have
    # the guard attribute but not ``wrapped_env``.
    inner_env = getattr(env, 'wrapped_env', None)
    if isinstance(inner_env, dangerous_ant_maze.DangerousMazeEnv):
        dangerous_ant_maze.modify_dataset(dataset, env)
    return dataset
