from functools import partial
from pathlib import Path
from warnings import catch_warnings, filterwarnings

import gymnasium
from gymnasium.envs.registration import register
import numpy as np

from offline.envs.d4rl.infos import DATASET_URLS, REF_MAX_SCORE, REF_MIN_SCORE
from offline.envs.d4rl.wrapper import GymEnvWrapper
from offline.envs.registration import register_dataset
from offline.envs.utils import DATA_ROOT, download, load_data_dict


# Local directory where downloaded D4RL data files are stored; the directory
# itself is created on demand by get_d4rl_data_path.
D4RL_DATA_ROOT = DATA_ROOT / "D4RL"


def d4rl_make_env_and_load_data_dict(dataset_name: str, env_id: str, **kwargs):
    """Build the environment and load the D4RL data for one dataset.

    Args:
        dataset_name: Key identifying the dataset file to load (resolved to a
            local path by ``get_d4rl_data_path``).
        env_id: Gymnasium environment id handed to ``gymnasium.make``.
        **kwargs: Forwarded verbatim to ``gymnasium.make``.

    Returns:
        ``(env, data_dict, infos)``:
        ``env`` is the created environment.
        ``data_dict`` is the loaded data with an extra ``"dones"`` entry equal
        to ``terminals OR timeouts``.
        ``infos`` maps every ``"infos/<name>"`` key of ``data_dict`` to
        ``<name>`` with the same value.
    """
    # Suppress the DeprecationWarning / UserWarning noise raised by these
    # modules while the environment is being constructed.
    with catch_warnings():
        filterwarnings(
            "ignore",
            category=DeprecationWarning,
            module="gymnasium.envs.registration",
        )
        filterwarnings("ignore", category=UserWarning, module="gym.spaces.box")
        env = gymnasium.make(env_id, **kwargs)

    dataset_path = get_d4rl_data_path(dataset_name)
    data_dict = load_data_dict(dataset_path)

    # A transition ends an episode if it is terminal or hit the time limit.
    terminal_flags = data_dict["terminals"]
    timeout_flags = data_dict["timeouts"]
    data_dict["dones"] = np.logical_or(terminal_flags, timeout_flags)

    info_prefix = "infos/"
    infos = {}
    for full_key, column in data_dict.items():
        if full_key.startswith(info_prefix):
            infos[full_key[len(info_prefix):]] = column
    return env, data_dict, infos


def get_d4rl_data_path(dataset_name: str) -> Path:
    """Return the local path of the D4RL data file for ``dataset_name``.

    On first use the file is downloaded from ``DATASET_URLS[dataset_name]``
    into ``D4RL_DATA_ROOT``. Later calls reuse the file already on disk.

    The download is streamed into a sibling ``<name>.part`` file. That file is
    renamed onto the final path only after every chunk has been written.
    An interrupted download, even one killed without any chance to run
    cleanup code, therefore never leaves a truncated file under the final
    name that a later call would mistake for a complete one.

    Raises:
        KeyError: if ``dataset_name`` has no entry in ``DATASET_URLS``.
    """
    D4RL_DATA_ROOT.mkdir(exist_ok=True, parents=True)
    url = DATASET_URLS[dataset_name]
    data_name = url.split("/")[-1]
    data_path = D4RL_DATA_ROOT / data_name
    if data_path.is_file():
        return data_path

    # Stale .part files from an earlier killed run are simply truncated by "wb".
    part_path = data_path.with_name(data_name + ".part")
    try:
        with open(part_path, "wb") as file:
            for chunk in download(url, desc=data_name, leave=False):
                file.write(chunk)
        # Same directory, hence same filesystem: the rename is atomic on POSIX,
        # so data_path either does not exist or holds the complete file.
        part_path.replace(data_path)
    except BaseException:
        # BaseException so that KeyboardInterrupt also discards the partial file.
        part_path.unlink(missing_ok=True)
        raise
    return data_path


def register_d4rl_dataset(
    dataset_name: str,
    env_id: str,
    ref_max_score: float | None = None,
    ref_min_score: float | None = None,
):
    """Register a D4RL dataset whose environment is created from ``env_id``.

    The registered loader is ``d4rl_make_env_and_load_data_dict`` with
    ``dataset_name`` and ``env_id`` bound. A reference score left as ``None``
    is looked up by ``dataset_name`` in ``REF_MAX_SCORE`` or ``REF_MIN_SCORE``
    (which raises ``KeyError`` if it is missing).
    """
    max_score = (
        REF_MAX_SCORE[dataset_name] if ref_max_score is None else ref_max_score
    )
    min_score = (
        REF_MIN_SCORE[dataset_name] if ref_min_score is None else ref_min_score
    )
    loader_fn = partial(
        d4rl_make_env_and_load_data_dict,
        dataset_name=dataset_name,
        env_id=env_id,
    )
    register_dataset(
        name=dataset_name,
        make_env_and_load_data_dict_fn=loader_fn,
        ref_max_score=max_score,
        ref_min_score=min_score,
    )


def register_d4rl_env_and_dataset(
    dataset_name: str,
    max_episode_steps: int,
    ref_max_score: float | None = None,
    ref_min_score: float | None = None,
):
    """Register a Gymnasium environment and a dataset under one shared name.

    The environment id is ``dataset_name`` itself. Its entry point is
    ``GymEnvWrapper``, constructed with ``gym_env_id=dataset_name``, and
    ``max_episode_steps`` is passed through to gymnasium's ``register``.
    The dataset is then registered via ``register_d4rl_dataset`` against that
    same environment id, forwarding the optional reference scores.
    """
    env_id = dataset_name
    wrapper_kwargs = {"gym_env_id": dataset_name}
    register(
        id=env_id,
        entry_point=GymEnvWrapper,
        max_episode_steps=max_episode_steps,
        kwargs=wrapper_kwargs,
    )
    register_d4rl_dataset(
        dataset_name=dataset_name,
        env_id=env_id,
        ref_max_score=ref_max_score,
        ref_min_score=ref_min_score,
    )
