from functools import partial
from pathlib import Path
from typing import Any

from gymnasium.core import Env
import numpy as np
import ogbench
import tomlkit

from offline.envs.registration import register_dataset
from offline.envs.utils import DATA_ROOT


def load_ogbench_datasets() -> tuple[str, ...]:
    """Read the OGBench dataset names to register.

    The names come from the ``datasets`` entry of ``ogbench.toml``, which
    lives in the same directory as this module.

    Returns:
        The dataset names, in file order, as a tuple.
    """
    config_path = Path(__file__).with_name("ogbench.toml")
    document = tomlkit.parse(config_path.read_text(encoding="utf-8"))
    return tuple(document["datasets"])  # type: ignore


# Directory holding OGBench data. It is handed to ogbench as ``dataset_dir`` and
# created (with parents) by ``register_ogbench_datasets``.
OGBENCH_DATA_ROOT = DATA_ROOT / "OGBench"
# Dataset names listed in ``ogbench.toml``; read once, at import time.
OGBENCH_DATASETS: tuple[str, ...] = load_ogbench_datasets()


def default_ogbench_make_env_and_load_data_dict(
    name: str, **kwargs
) -> tuple[Env, dict[str, np.ndarray], dict[str, Any]]:
    """Create an OGBench environment and load its training dataset.

    The dataset is requested in OGBench's compact format with extra state
    info (``compact_dataset=True``, ``add_info=True``) and is stored under
    ``OGBENCH_DATA_ROOT``. The validation split returned by OGBench is
    discarded.

    Args:
        name: OGBench dataset name, forwarded to
            ``ogbench.make_env_and_datasets``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``ogbench.make_env_and_datasets``.

    Returns:
        ``(env, data_dict, infos)``. ``data_dict`` is the training dataset,
        augmented with ``dones`` and ``terminals`` arrays. ``infos`` maps
        whichever of ``qpos``, ``qvel`` and ``button_states`` are present in
        ``data_dict`` to the same arrays; they are not removed from
        ``data_dict``.

    Raises:
        KeyError: if the loaded dataset lacks ``valids`` or ``masks``.
            NOTE(review): OGBench appears to add ``masks`` only for
            single-task datasets. Confirm that ``ogbench.toml`` lists only
            those.
    """
    env, data_dict, _ = ogbench.make_env_and_datasets(  # type: ignore
        name,
        add_info=True,
        compact_dataset=True,
        cur_env=None,
        dataset_dir=str(OGBENCH_DATA_ROOT),
        dataset_path=None,
        env_only=False,
        **kwargs,
    )
    # ``valids`` flags transitions whose next observation is usable.
    # Its complement therefore marks trajectory boundaries.
    data_dict["dones"] = 1 - data_dict["valids"]
    # OGBench's ``masks`` is a bootstrapping mask (1 - success): it is 0
    # exactly where the episode truly terminates. A terminal flag is its
    # complement. Copying ``masks`` directly would invert the signal.
    data_dict["terminals"] = 1 - data_dict["masks"]
    infos = {
        key: value
        for key, value in data_dict.items()
        if key in ("qpos", "qvel", "button_states")
    }
    return env, data_dict, infos  # type: ignore


def register_ogbench_datasets():
    """Register every dataset listed in ``OGBENCH_DATASETS``.

    First ensures ``OGBENCH_DATA_ROOT`` exists. Then, for each name, it
    registers a loader that calls
    ``default_ogbench_make_env_and_load_data_dict`` with that name bound.
    """
    OGBENCH_DATA_ROOT.mkdir(exist_ok=True, parents=True)
    for dataset_name in OGBENCH_DATASETS:
        # ``partial`` captures the current name now, so each loader keeps its own.
        loader = partial(
            default_ogbench_make_env_and_load_data_dict, name=dataset_name
        )
        register_dataset(name=dataset_name, make_env_and_load_data_dict_fn=loader)
