from typing import Any

from jax import Array, lax, numpy as jnp
from jax.random import fold_in, normal

from offline.modules.actor.base import DeterministicActor, GaussianActor
from offline.modules.actor.ensemble import (
    DeterministicActorEnsemble,
    GaussianActorEnsemble,
)
from offline.types import ArrayLike

# Type alias for any actor module accepted by ``compute_actions``: the
# deterministic and Gaussian actor classes, in both single and ensemble form.
Actor = (
    DeterministicActor | DeterministicActorEnsemble
    | GaussianActor | GaussianActorEnsemble
)


def compute_actions(
    actor: Actor,
    key: Array,
    key_data: Any,
    observations: ArrayLike,
    squash: bool,
):
    """Compute (optionally sampled) actions from ``actor``.

    ``actor(observations)`` must return a ``(means, stds)`` pair, where
    ``stds`` may be ``None``. When ``squash`` is true, ``tanh`` is applied to
    ``means``. Note that squashing happens *before* noise is added, so
    sampled actions are not confined to ``(-1, 1)``; callers that need
    bounded actions must clip them.

    Args:
        actor: Actor module, called as ``actor(observations)``.
        key: Base PRNG key.
        key_data: Data folded into ``key`` with ``jax.random.fold_in`` to
            derive the sampling key.
        observations: Inputs passed to ``actor``.
        squash: Whether to apply ``tanh`` to the means.

    Returns:
        ``means`` (possibly squashed) when ``stds`` is ``None``. Otherwise a
        sample ``means + stds * eps`` with ``eps`` drawn from a standard
        normal at the broadcast shape of ``means`` and ``stds``.
    """
    means, stds = actor(observations)
    # ``lax.cond`` (rather than a Python ``if``) keeps this valid even when
    # ``squash`` is a traced value under ``jit``.
    means = lax.cond(squash, jnp.tanh, lambda x: x, means)
    if stds is None:
        # No standard deviations: return the (possibly squashed) means as-is.
        return means
    key = fold_in(key, key_data)
    # Draw noise at the broadcast shape of ``means`` and ``stds``. Using
    # ``stds.shape`` alone would reuse a single noise vector for every batch
    # row whenever ``stds`` is broadcast (e.g. shape ``(action_dim,)``).
    noise_shape = jnp.broadcast_shapes(means.shape, stds.shape)
    actions = means + stds * normal(key, noise_shape)
    return actions


def gaussian_log_likelihood(
    means: ArrayLike | float,
    samples: ArrayLike,
    stds: ArrayLike | float,
    reduce: bool = True,
) -> Array:
    log_stds = jnp.log(stds)
    log_likelihood = -0.5 * jnp.square((means - samples) / stds) - log_stds
    if reduce:
        return jnp.sum(log_likelihood, axis=-1)
    return log_likelihood
