from flax import nnx
import numpy as np

from offline.svr.core import compute_batch_log_likelihoods
from offline.svr.types import SVRBatch
from offline.modules.actor.base import DeterministicActor
from offline.types import FloatArray, OfflineData
from offline.utils.data import Dataset, DataLoader
from offline.utils.dataset import prepare_sa_dataset, prepare_q_learning_dataset


def compute_log_likelihoods(
    actor: DeterministicActor,
    data: OfflineData,
    sample_std: float,
    batch_size: int = 256,
) -> FloatArray:
    """Score every state-action pair of ``data`` under ``actor``.

    The actor is split once into its NNX graph definition and state. That pair
    is reused for every mini-batch passed to ``compute_batch_log_likelihoods``.

    Args:
        actor: Actor whose split graph is forwarded to
            ``compute_batch_log_likelihoods``.
        data: Offline data, converted via ``prepare_sa_dataset``.
        sample_std: Standard deviation forwarded unchanged to
            ``compute_batch_log_likelihoods``.
        batch_size: Mini-batch size for the ``DataLoader``. The last partial
            batch is kept (``drop_last=False``).

    Returns:
        The per-batch outputs joined with ``np.concatenate``, in loader order.
        NOTE(review): presumably one log-likelihood per state-action pair;
        confirm against ``compute_batch_log_likelihoods``.

    Note:
        ``np.concatenate`` raises ``ValueError`` if the loader yields no
        batches (e.g. empty ``data``).
    """
    graphdef, graphstate = nnx.split(actor)
    sa_dataset = prepare_sa_dataset(data)
    loader = DataLoader(batch_size=batch_size, dataset=sa_dataset, drop_last=False)
    per_batch = [
        compute_batch_log_likelihoods(
            actions=minibatch.actions,
            graphdef=graphdef,
            graphstate=graphstate,
            observations=minibatch.observations,
            sample_std=sample_std,
        )
        for minibatch in loader
    ]
    return np.concatenate(per_batch)


def prepare_svr_dataset(
    actor: DeterministicActor,
    data: OfflineData,
    sample_std: float,
    batch_size: int = 256,
) -> Dataset:
    """Build an SVR dataset: Q-learning transitions plus actor log-likelihoods.

    Transitions come from ``prepare_q_learning_dataset(data)``. Log-likelihoods
    are computed by ``compute_log_likelihoods`` for every sample of ``data``.
    They are then restricted to the samples that are not done, or that are done
    because of a true terminal.

    Args:
        actor: Actor scored by ``compute_log_likelihoods``.
        data: Offline data to convert.
        sample_std: Standard deviation forwarded to ``compute_log_likelihoods``.
        batch_size: Mini-batch size used while computing log-likelihoods.

    Returns:
        A ``Dataset`` wrapping an ``SVRBatch`` built from the Q-learning
        transitions and the aligned log-likelihoods.

    Raises:
        ValueError: If the filtered log-likelihoods do not line up one-to-one
            with the Q-learning transitions. That means the done/terminal mask
            below no longer matches the filtering inside
            ``prepare_q_learning_dataset``.
    """
    # Boolean mask over all samples of ``data``: keep steps that did not end
    # the episode, or that ended it in a terminal state. Steps that are done
    # but not terminal (presumably time-limit truncations) are dropped.
    # NOTE(review): assumed to mirror prepare_q_learning_dataset's filter;
    # the length check below guards that assumption.
    keep_mask = np.logical_or(np.logical_not(data.dones), data.terminals)
    batch = prepare_q_learning_dataset(data=data).data
    log_likelihoods = compute_log_likelihoods(
        actor=actor, batch_size=batch_size, data=data, sample_std=sample_std
    )[keep_mask]

    # Without this check a mismatch would silently pair log-likelihoods with
    # the wrong transitions.
    if len(log_likelihoods) != len(batch.actions):
        raise ValueError(
            f"Filtered log-likelihoods ({len(log_likelihoods)}) do not align "
            f"with the Q-learning transitions ({len(batch.actions)}); the "
            "done/terminal mask must match prepare_q_learning_dataset's filter."
        )

    svr_batch = SVRBatch(
        actions=batch.actions,
        dones=batch.dones,
        log_likelihoods=log_likelihoods,
        next_observations=batch.next_observations,
        observations=batch.observations,
        rewards=batch.rewards,
    )
    return Dataset(svr_batch)
