from typing import Any

import h5py
import numpy as np
import tomlkit

from offline.envs.custom import DATASETS_CONFIG_ROOT
from offline.envs.utils import DATA_ROOT


def label_to_config_and_data_path(force: bool, label: str):
    """Return the ``(config_file_path, data_file_path)`` pair for ``label``.

    Both root directories (``DATA_ROOT`` and ``DATASETS_CONFIG_ROOT``) are
    created if missing. The file stem is ``label`` with every ``-`` replaced
    by ``_``. The config path is ``DATASETS_CONFIG_ROOT / "<stem>.toml"`` and
    the data path is ``DATA_ROOT / "<stem>.hdf5"``.

    Args:
        force: When true, skip the existence checks so callers may overwrite
            previously saved files.
        label: Dataset label used to derive both file names.

    Returns:
        A tuple ``(config_file_path, data_file_path)``.

    Raises:
        ValueError: If ``force`` is false and either target file already
            exists.
    """
    # parents=True on both roots so a missing intermediate directory does not
    # make mkdir raise FileNotFoundError.
    DATA_ROOT.mkdir(parents=True, exist_ok=True)
    DATASETS_CONFIG_ROOT.mkdir(parents=True, exist_ok=True)

    stem = label.replace("-", "_")
    config_file_path = DATASETS_CONFIG_ROOT / f"{stem}.toml"
    data_file_path = DATA_ROOT / f"{stem}.hdf5"
    if not force:
        if config_file_path.is_file():
            raise ValueError(f"Config file already exists: {config_file_path}")
        if data_file_path.is_file():
            raise ValueError(f"Data file already exists: {data_file_path}")
    return config_file_path, data_file_path


def save_dataset(
    *,
    env_id: str,
    force: bool,
    label: str,
    results: dict[str, np.ndarray],
    max_score: float | None = None,
    min_score: float | None = None,
):
    """Write a dataset's arrays to HDF5 and its metadata to a TOML config.

    Output locations come from ``label_to_config_and_data_path``.

    Args:
        env_id: Environment id, recorded in the config under ``env_id``.
        force: Overwrite existing output files instead of raising.
        label: Dataset label. It is stored verbatim as ``dataset_name``; with
            ``-`` replaced by ``_`` it also names both output files.
        results: Mapping of HDF5 dataset name to array. Each entry is written
            as a top-level dataset in the HDF5 file.
        max_score: Optional reference maximum score, stored as
            ``ref_max_score``.
        min_score: Optional reference minimum score, stored as
            ``ref_min_score``.

    Raises:
        ValueError: If ``force`` is false and either output file already
            exists, or if a score cannot be converted to ``float``.
    """
    config_file_path, data_file_path = label_to_config_and_data_path(
        force=force, label=label
    )

    # Build and serialize the config before writing anything. A bad score then
    # fails fast instead of leaving an orphaned HDF5 file or a half-written
    # TOML file, either of which would block the next non-forced run.
    dataset_config: dict[str, Any] = {"dataset_name": label, "env_id": env_id}
    if max_score is not None:
        # float() converts numpy scalars (e.g. np.float32), which tomlkit
        # cannot serialize, into plain Python floats.
        dataset_config["ref_max_score"] = float(max_score)
    if min_score is not None:
        dataset_config["ref_min_score"] = float(min_score)
    config_text = tomlkit.dumps(dataset_config)

    with h5py.File(data_file_path, "w") as h5_file:
        for key, value in results.items():
            h5_file[key] = value

    with open(config_file_path, "w", encoding="utf-8") as config_file:
        config_file.write(config_text)
