import numpy as np

from offline.hdr.types import ActorBatch, QLearningBatch
from offline.lbp.utils import compute_means_stds_values
from offline.modules.actor.base import GaussianActor
from offline.modules.critic import VCritic
from offline.types import OfflineData
from offline.utils.data import Dataset


def prepare_actor_q_learning_dataset(
    actor: GaussianActor,
    critic: VCritic,
    data: OfflineData,
    batch_size: int = 256,
):
    """Build the actor dataset and the Q-learning dataset from offline data.

    Means, stds and values are obtained for ``data.observations`` from
    ``compute_means_stds_values`` and arranged into two batches:

    * an ``ActorBatch`` holding every observation with its mean and std;
    * a ``QLearningBatch`` of transitions in which the "next" fields of
      sample ``i`` are read from sample ``i + 1``.

    Samples flagged in ``data.timeouts`` are left out of the Q-learning
    batch; the actor batch keeps every sample.

    Args:
        actor: Forwarded to ``compute_means_stds_values``.
        critic: Forwarded to ``compute_means_stds_values``.
        data: Offline data providing ``observations``, ``actions``,
            ``rewards``, ``terminals`` and ``timeouts``.
        batch_size: Forwarded to ``compute_means_stds_values``.

    Returns:
        The pair ``(Dataset(actor_batch), Dataset(q_learning_batch))``.
    """
    observations = data.observations
    means, stds, values = compute_means_stds_values(
        actor=actor,
        batch_size=batch_size,
        critic=critic,
        observations=observations,
    )

    # A sample ends its trajectory when it is a timeout or a terminal.
    episode_ends = np.logical_or(data.timeouts, data.terminals)
    num_samples = episode_ends.size

    # Successor of sample i is i + 1. The final sample has no successor in
    # the data, so it is pointed back at itself to stay within bounds.
    # NOTE(review): if the final sample is neither a timeout nor a terminal
    # it is retained with itself as its "next" sample -- confirm intended.
    successor = np.arange(1, num_samples + 1)
    successor[-1] = num_samples - 1

    # Drop end-of-trajectory samples whose trajectory was truncated.
    retained = np.logical_not(data.timeouts)
    successor = successor[retained]

    def at_successor(array):
        # Entries of ``array`` at the successor of every retained sample.
        return array[successor]

    transitions = QLearningBatch(
        actions=data.actions[retained],
        dones=episode_ends[retained],
        next_means=at_successor(means),
        next_observations=at_successor(observations),
        next_stds=at_successor(stds),
        next_values=at_successor(values),
        observations=observations[retained],
        rewards=data.rewards[retained],
    )
    policy_outputs = ActorBatch(
        means=means,
        observations=observations,
        stds=stds,
    )
    return Dataset(policy_outputs), Dataset(transitions)
