from collections.abc import Callable
from typing import NamedTuple, TypeVar

import numpy as np
from numpy.testing import suppress_warnings
from numpy.typing import DTypeLike

from offline.types import (
    BoolArray,
    FloatArray,
    OfflineData,
    QLearningBatch,
    SaBatch,
    SarsaBatch,
    VLearningBatch,
)
from offline.utils.data import Dataset


# Type variable for the batch record built by ``prepare_dataset`` (the
# return type of its ``batch_type`` callable); bounded by NamedTuple.
T = TypeVar("T", bound=NamedTuple)


def compute_next(data: OfflineData) -> dict[str, np.ndarray]:
    """Pair every sample with its successor to form one-step transitions.

    Row ``i`` of ``next_observations`` / ``next_actions`` is row ``i + 1`` of
    ``data.observations`` / ``data.actions``. Two kinds of samples are
    dropped because their successor is not a real next step:

    * samples flagged in ``data.dones`` but not in ``data.terminals`` (the
      trajectory was cut short, so row ``i + 1`` starts another trajectory);
    * the last sample of the dataset, unless it is terminal (it has no
      successor; ``np.roll`` would wrap around to row 0).

    Terminal samples are kept even though their successor is only a
    placeholder; the returned ``dones`` flag is set for them.

    Args:
        data: offline dataset whose ``observations``, ``actions``,
            ``rewards``, ``terminals`` and ``dones`` arrays are aligned
            along axis 0.

    Returns:
        Dict with keys ``actions``, ``dones``, ``next_actions``,
        ``next_observations``, ``observations`` and ``rewards``, all filtered
        by the same mask; ``dones`` holds the filtered ``data.terminals``.
    """
    observations = data.observations
    actions = data.actions
    rewards = data.rewards
    terminals = data.terminals
    # np.roll(x, -1)[i] == x[i + 1] for every i but the last, which wraps
    # around to x[0]. This holds by construction, so no runtime equality
    # check is needed (one would also fail spuriously on NaN entries).
    next_observations = np.roll(observations, -1, axis=0)
    next_actions = np.roll(actions, -1, axis=0)
    # remove end-of-trajectory samples if the trajectory was truncated
    keep = np.logical_or(np.logical_not(data.dones), terminals)
    if keep.size:
        # The last row's "successor" is the wrapped-around first row, so the
        # sample is only usable when it is terminal.
        keep[-1] = bool(terminals[-1])
    return {
        "actions": actions[keep],
        "dones": terminals[keep],
        "next_actions": next_actions[keep],
        "next_observations": next_observations[keep],
        "observations": observations[keep],
        "rewards": rewards[keep],
    }


def compute_returns(rewards: FloatArray, dones: BoolArray):
    """Sum ``rewards`` over every completed episode.

    An episode ends at each index where ``dones`` is set. Rewards after the
    last set flag belong to an unfinished episode and are not included.

    Returns:
        Array with one return per completed episode (empty when ``dones``
        has no set flag).
    """
    boundaries = np.flatnonzero(dones) + 1
    # Splitting at the episode boundaries leaves one extra segment after the
    # final boundary (empty when the last step is done); discard it.
    *episodes, _unfinished = np.split(rewards, boundaries)
    return np.asarray([episode.sum() for episode in episodes])


def normalize_rewards(data: OfflineData, eps: float = 1e-5) -> float:
    """Return ``1 / (max_return - min_return + eps)`` over completed episodes.

    An episode ends wherever ``data.dones`` or ``data.terminals`` is set.
    Rewards after the last such flag form an unfinished episode and are not
    counted (see ``compute_returns``).

    Args:
        data: offline dataset providing ``rewards``, ``dones`` and
            ``terminals`` arrays.
        eps: added to the return range to avoid division by zero.

    Returns:
        The reward scale factor as a Python float.

    Raises:
        ValueError: if no episode in ``data`` is completed, so there is no
            return range to normalize by.
    """
    episode_ends = np.logical_or(data.dones, data.terminals)
    returns = compute_returns(data.rewards, episode_ends)
    if returns.size == 0:
        # returns.max() would otherwise fail with an opaque
        # "zero-size array to reduction operation" error.
        raise ValueError(
            "cannot compute a reward scale: the dataset contains no "
            "completed episode (no `dones` or `terminals` flag is set)"
        )
    max_return, min_return = returns.max(), returns.min()
    return 1 / (float(max_return - min_return) + eps)


def prepare_dataset(
    batch_type: Callable[..., T],
    data_dict: dict[str, np.ndarray],
    *keys: str,
    dtype: DTypeLike = np.float32,
) -> Dataset[T]:
    """Wrap selected arrays of ``data_dict`` in a batch record and a Dataset.

    Args:
        batch_type: constructor of the batch record; called with one keyword
            argument per entry of ``keys``.
        data_dict: source arrays, looked up by key.
        *keys: names of the arrays to include; each name is also the keyword
            passed to ``batch_type``.
        dtype: dtype every selected array is converted to with ``astype``.

    Returns:
        A ``Dataset`` built from the batch record.
    """
    converted = {}
    for name in keys:
        converted[name] = data_dict[name].astype(dtype)
    batch = batch_type(**converted)
    return Dataset(batch)


def prepare_q_learning_dataset(
    data: OfflineData, dtype: DTypeLike = np.float32
) -> Dataset[QLearningBatch]:
    """Build a ``QLearningBatch`` dataset from the transitions of ``data``.

    Transitions come from ``compute_next``. The fields ``actions``,
    ``dones``, ``next_observations``, ``observations`` and ``rewards`` are
    kept and cast to ``dtype``.
    """
    transitions = compute_next(data)
    fields = (
        "actions",
        "dones",
        "next_observations",
        "observations",
        "rewards",
    )
    return prepare_dataset(QLearningBatch, transitions, *fields, dtype=dtype)


def prepare_sa_dataset(
    data: OfflineData, dtype: DTypeLike = np.float32
) -> Dataset[SaBatch]:
    """Build an ``SaBatch`` dataset of (observation, action) pairs.

    Unlike the transition datasets, no successor is needed, so every sample
    of ``data`` is used without going through ``compute_next``. Both fields
    are cast to ``dtype``.
    """
    data_dict = {"actions": data.actions, "observations": data.observations}
    return prepare_dataset(
        SaBatch, data_dict, "actions", "observations", dtype=dtype
    )


def prepare_sarsa_dataset(
    data: OfflineData, dtype: DTypeLike = np.float32
) -> Dataset[SarsaBatch]:
    """Build a ``SarsaBatch`` dataset from the transitions of ``data``.

    Like the Q-learning dataset, but ``next_actions`` is also kept. All
    selected fields come from ``compute_next`` and are cast to ``dtype``.
    """
    transitions = compute_next(data)
    fields = (
        "actions",
        "dones",
        "next_actions",
        "next_observations",
        "observations",
        "rewards",
    )
    return prepare_dataset(SarsaBatch, transitions, *fields, dtype=dtype)


def prepare_v_learning_dataset(
    data: OfflineData, dtype: DTypeLike = np.float32
) -> Dataset[VLearningBatch]:
    """Build a ``VLearningBatch`` dataset from the transitions of ``data``.

    Actions are not included: only ``dones``, ``next_observations``,
    ``observations`` and ``rewards`` from ``compute_next`` are kept and
    cast to ``dtype``.
    """
    transitions = compute_next(data)
    fields = ("dones", "next_observations", "observations", "rewards")
    return prepare_dataset(VLearningBatch, transitions, *fields, dtype=dtype)


def unsquash_actions(actions: FloatArray, max_action: float = 10):
    """Invert a ``tanh`` squashing of ``actions`` and clip the result.

    ``arctanh(±1)`` is ``±inf``; clipping maps such saturated actions to
    ``±max_action`` instead.

    Args:
        actions: squashed actions, expected to lie in ``[-1, 1]``; values
            outside that range produce NaN, which clipping leaves as NaN.
        max_action: bound on the magnitude of the returned values.

    Returns:
        ``np.clip(np.arctanh(actions), -max_action, max_action)``.
    """
    # the default max_action is set to 10 because tanh(10)=1 when computed
    # using float32
    # np.errstate silences only the floating-point divide-by-zero raised by
    # arctanh(±1). Unlike numpy.testing.suppress_warnings (a test helper), it
    # does not patch the warnings module or depend on the warning's text.
    with np.errstate(divide="ignore"):
        return np.clip(np.arctanh(actions), -max_action, max_action)
