from collections.abc import Callable
from typing import Any
import warnings

from gymnasium.core import Env
from gymnasium.spaces import Box
import numpy as np

from offline.types import FloatArray, OfflineData, OfflineDataWithInfos


# Maps a raw score (a float or an array of floats) to a normalized score array.
GetNormalizedScore = Callable[[FloatArray | float], FloatArray]
# Dataset loader: called with arbitrary arguments, returns
# ``(env, data_dict, infos)`` where ``data_dict`` maps array names to arrays.
MakeEnvAndLoadDataDict = Callable[
    ..., tuple[Env, dict[str, np.ndarray], dict[str, Any]]
]


# Dataset name -> loader, filled in by ``register_dataset``.
make_env_and_load_data_dict_registry: dict[str, MakeEnvAndLoadDataDict] = {}
# Dataset name -> ``(ref_max_score, ref_min_score)``, filled in by
# ``register_dataset`` when both reference scores are provided.
ref_scores_registry: dict[str, tuple[float, float]] = {}


def get_ref_scores(name: str) -> tuple[float, float]:
    """Look up the reference scores stored for dataset ``name``.

    Returns:
        The ``(ref_max_score, ref_min_score)`` pair recorded in
        ``ref_scores_registry``.

    Raises:
        KeyError: if no reference scores are stored for ``name``.
    """
    ref_scores = ref_scores_registry[name]
    return ref_scores


def get_registered_datasets():
    """Return the keys view of every dataset name currently registered."""
    registry = make_env_and_load_data_dict_registry
    return registry.keys()


def make_env_and_load_data(
    name: str, *args, **kwargs
) -> tuple[Env, OfflineDataWithInfos, GetNormalizedScore | None]:
    """Build the environment and offline dataset registered under ``name``.

    Extra positional and keyword arguments are forwarded to the registered
    loader. The returned data dict is validated against the environment's
    spaces, actions are rescaled with ``normalize_actions``, and every array
    is cast to a fixed dtype (``float32`` for actions, observations and
    rewards; ``bool`` for dones and terminals). Actions and observations are
    flattened to shape ``(num_samples, -1)``.

    Returns:
        A tuple ``(env, data, get_normalized_score)``.
        ``get_normalized_score`` maps a raw score ``x`` to
        ``100 * (x - min_score) / (max_score - min_score)`` using the
        reference scores stored for ``name``. It is ``None`` when no
        reference scores are stored or when the two scores are equal.

    Raises:
        KeyError: if ``name`` is not a registered dataset.
        AssertionError: if the action space is not a ``Box``, a required key
            is missing, the dataset is empty, or an array has an unexpected
            shape.
    """
    env, data_dict, infos = make_env_and_load_data_dict(name, *args, **kwargs)
    assert isinstance(env.action_space, Box)
    for key in ("actions", "dones", "observations", "rewards", "terminals"):
        assert key in data_dict, f"Dataset is missing key {key}"
    num_samples = data_dict["actions"].shape[0]
    # Fail with a clear message instead of an IndexError at ``dones[-1]``.
    assert num_samples > 0, "Dataset is empty"
    sanity_check_nd("actions", data_dict, num_samples, env.action_space.shape)
    sanity_check_1d("dones", data_dict, num_samples)
    sanity_check_nd(
        "observations", data_dict, num_samples, env.observation_space.shape
    )
    sanity_check_1d("rewards", data_dict, num_samples)
    sanity_check_1d("terminals", data_dict, num_samples)
    # Flatten BEFORE normalizing: ``normalize_actions`` reshapes the bounds to
    # ``(1, -1)``, which only lines up element-wise with actions of shape
    # ``(num_samples, action_dim)``. Normalizing un-flattened actions from a
    # multi-dimensional Box would fail to broadcast or broadcast incorrectly.
    flat_actions = data_dict["actions"].reshape(num_samples, -1)
    # ``ndarray.astype`` copies by default and, unlike the ``np.astype``
    # function and the ``np.bool`` alias, exists in every NumPy version.
    actions = normalize_actions(flat_actions, env.action_space).astype(
        np.float32
    )
    dones = data_dict["dones"].astype(np.bool_)
    # ``dones`` is a copy, so the loader's array is not modified. The final
    # sample is always flagged as done (presumably so that the last
    # trajectory is closed -- confirm with downstream consumers).
    dones[-1] = True
    observations = data_dict["observations"].astype(np.float32)
    rewards = data_dict["rewards"].astype(np.float32)
    terminals = data_dict["terminals"].astype(np.bool_)
    try:
        max_score, min_score = get_ref_scores(name)
        delta_score = max_score - min_score
    except KeyError:
        # No reference scores stored: no normalization function is returned.
        delta_score = min_score = 0
    return (
        env,
        OfflineDataWithInfos(
            data=OfflineData(
                actions=actions,
                dones=dones,
                observations=observations.reshape(num_samples, -1),
                rewards=rewards,
                terminals=terminals,
            ),
            infos=infos,
        ),
        (
            None
            if delta_score == 0
            else lambda x: np.asarray(100 * (x - min_score) / delta_score)
        ),
    )


def make_env_and_load_data_dict(
    name: str, *args, **kwargs
) -> tuple[Env, dict[str, np.ndarray], dict[str, Any]]:
    """Call the loader registered under ``name`` and return its result.

    All extra positional and keyword arguments are passed to the loader
    unchanged.

    Raises:
        KeyError: if ``name`` is not in the loader registry.
    """
    loader = make_env_and_load_data_dict_registry[name]
    return loader(*args, **kwargs)


def normalize_actions(actions: FloatArray, action_space: Box) -> FloatArray:
    """Rescale ``actions`` relative to the bounds of ``action_space``.

    Each action is shifted by the midpoint of ``[low, high]`` and divided by
    half the width of that interval, so ``low`` maps to -1 and ``high`` to
    +1. The bounds are flattened to shape ``(1, -1)`` before broadcasting
    against ``actions``.
    """
    lower = action_space.low.reshape(1, -1)
    upper = action_space.high.reshape(1, -1)
    half_range = (upper - lower) / 2
    midpoint = (upper + lower) / 2
    centered = actions - midpoint
    return centered / half_range


def register_dataset(
    name: str,
    make_env_and_load_data_dict_fn: MakeEnvAndLoadDataDict,
    ref_max_score: float | None = None,
    ref_min_score: float | None = None,
):
    """Register a dataset loader and, optionally, its reference scores.

    Args:
        name: Key under which the dataset is registered.
        make_env_and_load_data_dict_fn: Loader stored for ``name``.
        ref_max_score: Reference maximum score for ``name``.
        ref_min_score: Reference minimum score for ``name``.

    Reference scores are stored only when both are given. Re-registering an
    existing name emits a warning and replaces the loader; reference scores
    stored by an earlier registration stay in place unless both new scores
    are supplied.
    """
    if name in make_env_and_load_data_dict_registry:
        # stacklevel=2 attributes the warning to the code that performed the
        # duplicate registration rather than to this function.
        warnings.warn(
            f"Overriding dataset {name} already in registry.", stacklevel=2
        )
    make_env_and_load_data_dict_registry[name] = make_env_and_load_data_dict_fn
    if ref_max_score is not None and ref_min_score is not None:
        ref_scores_registry[name] = (ref_max_score, ref_min_score)


def sanity_check_1d(
    key: str, data_dict: dict[str, np.ndarray], num_samples: int
) -> dict[str, np.ndarray]:
    """Ensure ``data_dict[key]`` is a 1-D array with ``num_samples`` entries.

    A column array of shape ``(num_samples, 1)`` is flattened and written
    back into ``data_dict``; any other shape mismatch fails the assertion.

    Returns:
        The same ``data_dict`` object, possibly modified in place.

    Raises:
        AssertionError: if the array's shape is not ``(num_samples,)``.
    """
    expected_shape = (num_samples,)
    column_shape = (num_samples, 1)
    if data_dict[key].shape == column_shape:
        data_dict[key] = data_dict[key].ravel()
    actual_shape = data_dict[key].shape
    assert (
        actual_shape == expected_shape
    ), f"{key.title()} has wrong shape: {actual_shape}"
    return data_dict


def sanity_check_nd(
    key: str,
    data_dict: dict[str, np.ndarray],
    num_samples: int,
    true_shape: tuple[int, ...] | None,
) -> dict[str, np.ndarray]:
    shape = data_dict[key].shape
    if true_shape is None:
        check = shape[0] == num_samples
    else:
        true_shape = (num_samples,) + true_shape
        check = shape == true_shape
    key = key.title()
    assert check, f"{key} shape does not match env: {shape} vs {true_shape}"
    return data_dict
