from functools import partial
from warnings import catch_warnings, filterwarnings

from gymnasium.envs.registration import register

from offline.envs.d4rl.utils import d4rl_make_env_and_load_data_dict
from offline.envs.d4rl.wrapper import GymEnvWrapper
from offline.envs.registration import register_dataset


# AntMaze task ids registered by this module. Every release (v0, v1, v2)
# provides the same six maze/dataset layouts; v0 additionally provides
# "antmaze-eval-*" variants. Ordering: v0, eval-v0, v1, v2, and within
# each group the six layouts in the order listed below.
ANT_MAZE_DATASET_NAMES = tuple(
    f"antmaze-{prefix}{layout}-{version}"
    for prefix, version in (("", "v0"), ("eval-", "v0"), ("", "v1"), ("", "v2"))
    for layout in (
        "umaze",
        "umaze-diverse",
        "medium-play",
        "medium-diverse",
        "large-diverse",
        "large-play",
    )
)

# Subset of ANT_MAZE_DATASET_NAMES (same relative order) on the small
# "umaze" maze; register_ant_maze_datasets gives these a shorter episode
# cap. Derived from the master list so the two cannot drift apart.
SMALL_ENVIRONMENTS = tuple(
    name for name in ANT_MAZE_DATASET_NAMES if "-umaze" in name
)


def ant_maze_make_env_and_load_data_dict(
    dataset_name: str, env_id: str, **kwargs
):
    """Create an AntMaze environment and load its dataset.

    Thin wrapper around ``d4rl_make_env_and_load_data_dict`` that silences
    ``UserWarning``s raised from the ``gym.spaces.box`` module during loading.
    It also sets ``data_dict["dones"]`` to the dataset's ``"timeouts"`` entry.

    Args:
        dataset_name: Name of the dataset to load.
        env_id: Environment id passed through to the underlying loader.
        **kwargs: Forwarded unchanged to ``d4rl_make_env_and_load_data_dict``.

    Returns:
        The ``(env, data_dict, infos)`` triple from the underlying loader,
        with ``data_dict["dones"]`` overwritten as described above.
    """
    # Suppress only the Box-space UserWarnings, and only for this call; the
    # previous warning filters are restored when the block exits.
    with catch_warnings():
        filterwarnings(
            "ignore",
            module="gym.spaces.box",
            category=UserWarning,
        )
        env, data_dict, infos = d4rl_make_env_and_load_data_dict(
            dataset_name,
            env_id,
            **kwargs,
        )

    # "dones" becomes the very same object as "timeouts" (an alias, not a
    # copy), so only timeout flags mark episode ends in the returned data.
    timeouts = data_dict["timeouts"]
    data_dict["dones"] = timeouts

    return env, data_dict, infos


def register_ant_maze_datasets():
    """Register every AntMaze task as a Gymnasium env and as a dataset.

    Each name in ``ANT_MAZE_DATASET_NAMES`` gets two registrations, both
    using the name itself as the id:

    * A Gymnasium env built through ``GymEnvWrapper(gym_env_id=name)``.
      The episode step cap is 700 for names in ``SMALL_ENVIRONMENTS`` and
      1000 for all others.
    * A dataset entry whose loader is
      ``ant_maze_make_env_and_load_data_dict``, pre-bound to the name.
      Its reference scores are 0 (min) and 1 (max).
    """
    for name in ANT_MAZE_DATASET_NAMES:
        # Small (umaze) mazes get a shorter horizon than the rest.
        horizon = 1000
        if name in SMALL_ENVIRONMENTS:
            horizon = 700

        register(
            id=name,
            entry_point=GymEnvWrapper,
            max_episode_steps=horizon,
            kwargs={"gym_env_id": name},
        )

        # Bind the name now so each registry entry loads its own dataset.
        loader = partial(
            ant_maze_make_env_and_load_data_dict,
            dataset_name=name,
            env_id=name,
        )
        register_dataset(
            name=name,
            make_env_and_load_data_dict_fn=loader,
            ref_max_score=1,
            ref_min_score=0,
        )
