from typing import Any

from flax import nnx
from jax import Array, numpy as jnp

from offline.modules.actor.base import GaussianActor
from offline.modules.policy import Policy
from offline.types import ArrayLike


class BCPolicy(Policy[None]):
    """Behaviour-cloning policy that acts through a ``GaussianActor``.

    The policy carries no recurrent state: the ``state`` argument of
    ``__call__`` is typed ``None`` and is returned untouched.

    Args:
        action_dim: Action dimensionality, forwarded to ``GaussianActor``.
        observation_dim: Observation dimensionality, forwarded to
            ``GaussianActor``.
        rngs: NNX RNG streams handed to the actor for initialisation.
        unsquash: When ``True``, the actor's actions are returned exactly
            as produced. When ``False``, they are passed through
            ``jnp.tanh`` first, which bounds each component to [-1, 1].
            NOTE(review): this presumably means the actor emits pre-tanh
            values; confirm against ``GaussianActor``.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``GaussianActor``.
    """

    def __init__(
        self,
        action_dim: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        unsquash: bool,
        **kwargs
    ):
        # The actor is created first so the attribute assignment order
        # matches the original implementation.
        actor = GaussianActor(
            rngs=rngs,
            observation_dim=observation_dim,
            action_dim=action_dim,
            **kwargs
        )
        self.actor = actor
        self.unsquash = unsquash

    def __call__(
        self, observations: ArrayLike, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Compute actions for ``observations``.

        Args:
            observations: Batch of observations fed straight to the actor.
            state: Unused policy state, returned as-is.

        Returns:
            A ``(actions, state, info)`` triple. ``actions`` is the actor
            output, squashed with ``tanh`` unless ``self.unsquash`` is set.
            ``state`` is the input state. ``info`` is always an empty dict.
        """
        # Only the first element of the actor's output pair is used here.
        raw_actions = self.actor(observations)[0]
        emitted = raw_actions if self.unsquash else jnp.tanh(raw_actions)
        return emitted, state, {}
