from flax import nnx
import numpy as np

from offline.lbp.core import compute_batch_means_stds_values
from offline.lbp.modules import BehaviorState
from offline.lbp.types import ActorBatch, RegularizerBatch, QLearningBatch
from offline.modules.actor.base import GaussianActor
from offline.modules.critic import VCritic
from offline.types import BoolArray, FloatArray, OfflineData, RegressionBatch
from offline.utils.data import ArrayDataLoader, Dataset


def compute_means_stds_values(
    actor: GaussianActor,
    critic: VCritic,
    observations: FloatArray,
    batch_size: int = 256,
) -> tuple[FloatArray, FloatArray, FloatArray]:
    """Run ``compute_batch_means_stds_values`` over ``observations`` in batches.

    The actor and critic are bundled into a ``BehaviorState`` and split with
    ``nnx.split`` once, so every batch call reuses the same graph definition
    and graph state.

    Args:
        actor: Actor module placed in the behavior state.
        critic: Critic module placed in the behavior state.
        observations: Array that ``ArrayDataLoader`` cuts into batches.
        batch_size: Number of observations per batch. The final, possibly
            smaller, batch is kept (``drop_last=False``).

    Returns:
        Tuple ``(means, stds, values)``; each entry is the concatenation
        along the first axis of the per-batch outputs, in loader order.
    """
    graphdef, graphstate = nnx.split(BehaviorState(actor=actor, critic=critic))
    chunks = ArrayDataLoader(
        batch_size=batch_size, data=observations, drop_last=False
    )
    # One (means, stds, values) triple per batch.
    per_batch = [
        compute_batch_means_stds_values(
            graphdef=graphdef, graphstate=graphstate, observations=chunk
        )
        for chunk in chunks
    ]
    all_means = np.concatenate([means for means, _, _ in per_batch])
    all_stds = np.concatenate([stds for _, stds, _ in per_batch])
    all_values = np.concatenate([values for _, _, values in per_batch])
    return all_means, all_stds, all_values


def compute_discounted_returns(
    dones: BoolArray, gamma: float, rewards: FloatArray
):
    """Compute the discounted return of every step by a backward recursion.

    The last step's return is its own reward. For every earlier step ``t``
    the return is ``rewards[t] + gamma * (1 - dones[t]) * return[t + 1]``,
    so a set ``dones[t]`` stops the return from leaking across that step.

    Args:
        dones: Per-step flags; a truthy value cuts the recursion at that step.
        gamma: Discount factor.
        rewards: Per-step rewards; must be non-empty.

    Returns:
        Array of returns, ordered like ``rewards``.
    """
    running = rewards[-1]
    backward = [running]
    # Walk from the second-to-last step down to the first one.
    earlier_steps = zip(rewards[-2::-1], dones[-2::-1])
    for reward, done in earlier_steps:
        discount = gamma * (1 - done)
        running = reward + discount * running
        backward.append(running)
    # Returns were collected last-to-first; restore chronological order.
    backward.reverse()
    return np.asarray(backward)


def get_max_min_reward(
    dataset_name: str,
) -> tuple[float, float] | tuple[None, None]:
    if dataset_name.endswith("-v2"):
        if dataset_name.startswith("halfcheetah-"):
            return 13.854624, -3.6640236
        if dataset_name.startswith("hopper-"):
            return 6.628322, -1.1588151
        if dataset_name.startswith("walker2d-"):
            return 8.469034, -2.5572553
    if dataset_name.endswith("-v0"):
        if dataset_name.startswith("pen-"):
            return 60.998226, -7.179708
    return None, None


def normalize_rewards(data: OfflineData, gamma: float, eps: float = 1e-5):
    """Return a reward scale derived from the mean discounted return.

    Despite its name, this function does not modify ``data``; it only
    returns the factor ``1 / (|mean discounted return| + eps)``.

    Args:
        data: Offline data providing ``dones``, ``terminals`` and ``rewards``.
        gamma: Discount factor used for the returns.
        eps: Small constant added to the denominator to avoid division by
            zero.

    Returns:
        The scale factor as a Python float.
    """
    # A step ends its return when either flag is set.
    episode_ends = np.logical_or(data.dones, data.terminals)
    returns = compute_discounted_returns(
        dones=episode_ends, gamma=gamma, rewards=data.rewards
    )
    magnitude = float(np.abs(np.mean(returns)))
    denominator = magnitude + eps
    return 1 / denominator


def prepare_actor_regularizer_qlearning_dataset(
    actor: GaussianActor,
    critic: VCritic,
    data: OfflineData,
    offset: float,
    batch_size: int = 256,
):
    """Build the actor, regularizer and Q-learning datasets from offline data.

    ``compute_means_stds_values`` is evaluated once on all observations and
    its outputs are shared between the three batches.

    Args:
        actor: Actor passed to ``compute_means_stds_values``.
        critic: Critic passed to ``compute_means_stds_values``.
        data: Offline data providing ``observations``, ``actions``,
            ``rewards``, ``dones`` and ``terminals``; must be non-empty.
        offset: Constant subtracted from the computed values to form the
            regularizer baseline.
        batch_size: Batch size used when evaluating the actor and critic.

    Returns:
        Tuple ``(actor_dataset, regularizer_dataset, q_dataset)``. The first
        two cover every sample. The Q-learning dataset omits samples that
        have no valid successor: steps that end a trajectory without being
        terminal, and the final sample unless it is terminal.
    """
    dones = np.logical_or(data.dones, data.terminals)
    # Index of the following sample; the final sample has none, so its index
    # is clamped to point at itself.
    next_indices = np.arange(1, dones.size + 1)
    next_indices[-1] -= 1
    next_observations = data.observations[next_indices]
    # remove end-of-trajectory samples if the trajectory was truncated
    indices = np.logical_or(np.logical_not(dones), data.terminals)
    # The final sample's "next" entry is the sample itself (see the clamp
    # above). If the data stops mid-trajectory, keeping it would add a
    # fabricated self-transition, so it is kept only when it is terminal.
    indices[-1] = data.terminals[-1]
    means, stds, values = compute_means_stds_values(
        actor=actor,
        batch_size=batch_size,
        critic=critic,
        observations=data.observations,
    )
    actor_batch = ActorBatch(means=means, observations=data.observations)
    regularizer_batch = RegularizerBatch(
        baseline=values - offset,
        means=means,
        observations=data.observations,
        stds=stds,
    )
    next_means = means[next_indices]
    q_batch = QLearningBatch(
        actions=data.actions[indices],
        dones=dones[indices],
        next_means=next_means[indices],
        next_observations=next_observations[indices],
        observations=data.observations[indices],
        rewards=data.rewards[indices],
    )
    return Dataset(actor_batch), Dataset(regularizer_batch), Dataset(q_batch)


def prepare_v_learning_dataset(data: OfflineData, gamma: float):
    """Build a regression dataset mapping observations to discounted returns.

    Args:
        data: Offline data providing ``observations``, ``rewards``, ``dones``
            and ``terminals``.
        gamma: Discount factor used for the returns.

    Returns:
        ``Dataset`` wrapping a ``RegressionBatch`` whose features are the
        observations and whose targets are the discounted returns.
    """
    # A step ends its return when either flag is set.
    episode_ends = np.logical_or(data.dones, data.terminals)
    targets = compute_discounted_returns(
        dones=episode_ends, gamma=gamma, rewards=data.rewards
    )
    return Dataset(
        RegressionBatch(features=data.observations, targets=targets)
    )
