from functools import partial
from pathlib import Path

import gymnasium
from gymnasium.envs.registration import WrapperSpec, register
import tomlkit

from offline.envs.registration import register_dataset
from offline.envs.utils import DATA_ROOT, load_data_dict


# Directories scanned by register_custom_datasets_and_envs for per-dataset and
# per-environment TOML configs. Both paths are relative, so they resolve
# against the process's current working directory at the time they are used.
DATASETS_CONFIG_ROOT = Path("config", "custom", "datasets")
ENVS_CONFIG_ROOT = Path("config", "custom", "envs")


def custom_make_env_and_load_data_dict_fn(dataset_name: str, env_id: str):
    """Create the environment for a custom dataset and load its data.

    Args:
        dataset_name: Dataset name. Hyphens are replaced with underscores to
            form the data file name ``<name>.hdf5``, which is looked up under
            ``DATA_ROOT``.
        env_id: Gymnasium environment id passed to ``gymnasium.make``.

    Returns:
        A ``(env, data_dict, infos)`` tuple. ``infos`` holds every
        ``data_dict`` entry whose key starts with ``"infos/"``, re-keyed with
        that prefix removed. Those prefixed entries also remain in
        ``data_dict``.

    Raises:
        Whatever ``gymnasium.make`` or ``load_data_dict`` raises. If loading
        the data fails, the already-created environment is closed first.
    """
    env = gymnasium.make(env_id)
    file_name = dataset_name.replace("-", "_") + ".hdf5"
    try:
        data_dict = load_data_dict(DATA_ROOT / file_name)
    except BaseException:
        # Don't leak the freshly created environment (and whatever resources
        # it holds) when the data file is missing or unreadable.
        env.close()
        raise
    infos = {
        key.removeprefix("infos/"): value
        for key, value in data_dict.items()
        if key.startswith("infos/")
    }
    return env, data_dict, infos


def register_custom_dataset(dataset_name: str, env_id: str, **kwargs):
    """Register *dataset_name* with a loader bound to *env_id*.

    The loader is ``custom_make_env_and_load_data_dict_fn`` with both
    arguments pre-bound via ``functools.partial``. Nothing is created or read
    until that loader is actually called. Any extra keyword arguments are
    forwarded unchanged to ``register_dataset``.
    """
    loader = partial(
        custom_make_env_and_load_data_dict_fn,
        dataset_name=dataset_name,
        env_id=env_id,
    )
    register_dataset(
        name=dataset_name, make_env_and_load_data_dict_fn=loader, **kwargs
    )


def _iter_toml_configs(root: Path):
    """Yield the parsed contents of every ``*.toml`` file directly under *root*.

    Files are visited in sorted path order so registration is deterministic.
    ``Path.iterdir`` yields entries in arbitrary, filesystem-dependent order.
    Subdirectories and non-TOML files are skipped. A missing *root* yields
    nothing.
    """
    if not root.is_dir():
        return
    for path in sorted(root.glob("*.toml")):
        if not path.is_file():
            continue
        with open(path, encoding="utf-8") as config:
            document = tomlkit.load(config)
        # Yield only after the file is closed so it isn't held open while the
        # caller runs registration code.
        yield document


def register_custom_datasets_and_envs():
    """Register every custom dataset and environment described by TOML files.

    Each ``*.toml`` file under ``DATASETS_CONFIG_ROOT`` is unpacked as keyword
    arguments into ``register_custom_dataset``. Each one under
    ``ENVS_CONFIG_ROOT`` is unpacked into ``register_custom_envs``. Files are
    processed in sorted order, and missing config directories are ignored.

    NOTE(review): both roots are relative paths, so they are resolved against
    the current working directory at call time.
    """
    for config in _iter_toml_configs(DATASETS_CONFIG_ROOT):
        register_custom_dataset(**config)  # type: ignore
    for config in _iter_toml_configs(ENVS_CONFIG_ROOT):
        register_custom_envs(**config)  # type: ignore


def register_custom_envs(
    additional_wrappers=(), max_episode_steps: int | None = None, **kwargs
):
    """Register a Gymnasium environment from plain config values.

    Args:
        additional_wrappers: Iterable of mappings. Each mapping is unpacked
            as keyword arguments into a ``WrapperSpec``, and the specs are
            passed to ``register`` as a tuple.
        max_episode_steps: Optional step limit, coerced to a plain ``int``
            when given. NOTE(review): presumably this strips non-builtin
            int-like values such as TOML items; confirm.
        **kwargs: Forwarded unchanged to ``gymnasium.envs.registration.register``.
    """
    wrapper_specs = tuple(
        [WrapperSpec(**spec_kwargs) for spec_kwargs in additional_wrappers]
    )
    step_limit = None
    if max_episode_steps is not None:
        step_limit = int(max_episode_steps)
    register(
        additional_wrappers=wrapper_specs, max_episode_steps=step_limit, **kwargs
    )
