from offline.bppo.types import BPPOBatch
from offline.lbp.utils import compute_means_stds_values
from offline.modules.actor.base import GaussianActor
from offline.modules.critic import VCritic
from offline.types import FloatArray
from offline.utils.data import Dataset


def prepare_bppo_dataset(
    actor: GaussianActor,
    observations: FloatArray,
    vcritic: VCritic,
    batch_size: int = 256,
):
    """Build a BPPO dataset from a fixed set of observations.

    Evaluates the actor and the value critic on ``observations`` via
    ``compute_means_stds_values``. The resulting means, standard deviations
    and values are packed, together with the observations themselves, into
    a single ``BPPOBatch`` wrapped in a ``Dataset``.

    Args:
        actor: Gaussian policy whose means/stds are evaluated on
            ``observations``.
        observations: Observations to evaluate the actor and critic on. They
            are also stored unchanged in the returned batch.
        vcritic: Value critic evaluated on ``observations``. It is passed to
            the helper as ``critic``.
        batch_size: Forwarded to ``compute_means_stds_values``. Presumably
            the chunk size used when running the networks; confirm against
            the helper.

    Returns:
        A ``Dataset`` containing one ``BPPOBatch`` with the fields
        ``means``, ``observations``, ``stds`` and ``values``.
    """
    # The helper returns the three statistics in (means, stds, values) order.
    policy_and_value_stats = compute_means_stds_values(
        observations=observations,
        actor=actor,
        critic=vcritic,
        batch_size=batch_size,
    )
    means, stds, values = policy_and_value_stats

    batch = BPPOBatch(
        observations=observations,
        means=means,
        stds=stds,
        values=values,
    )
    return Dataset(batch)
