from flax import nnx
from flax.nnx.nn import initializers
from flax.typing import Dtype
from jax import numpy as jnp
from jax.nn import softplus

from offline.modules.mlp import MLP
from offline.types import ArrayLike


# Default for the `eps` argument of GaussianActor, where it is added on top of
# the softplus-transformed standard deviations (presumably as a small positive
# lower bound on the returned stds).
EPS = 1e-5


class DeterministicActor(nnx.Module):
    """Actor that maps observations to actions through a single MLP.

    The MLP takes ``observation_dim`` input features and emits ``action_dim``
    output features. When ``squash`` is set, the MLP output is passed through
    ``tanh`` before being returned.
    """

    def __init__(
        self,
        action_dim: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        squash: bool,
        param_dtype: Dtype = jnp.float32,
        **kwargs
    ):
        """Build the underlying MLP and remember the squashing flag.

        Args:
            action_dim: Number of output features of the MLP.
            observation_dim: Number of input features of the MLP.
            rngs: Random number generator collection handed to the MLP.
            squash: Whether ``__call__`` applies ``tanh`` to the MLP output.
            param_dtype: Parameter dtype forwarded to the MLP.
            **kwargs: Any further keyword arguments, forwarded to ``MLP``
                unchanged.
        """
        network = MLP(
            in_features=observation_dim,
            out_features=action_dim,
            param_dtype=param_dtype,
            rngs=rngs,
            **kwargs
        )
        self.model = network
        self.squash = squash

    def __call__(self, observations: ArrayLike):
        """Compute actions for ``observations``.

        Returns:
            A 2-tuple ``(actions, None)``. The second slot is always ``None``
            (NOTE(review): presumably it mirrors the ``(means, stds)`` pair
            returned by ``GaussianActor`` -- confirm against callers).
        """
        raw_actions = self.model(observations)
        result = jnp.tanh(raw_actions) if self.squash else raw_actions
        return result, None


class GaussianActor(nnx.Module):
    """Actor that produces per-dimension means and standard deviations.

    Two modes are supported, selected at construction time:

    * State-dependent stds: the MLP emits ``2 * action_dim`` features, which
      are split into two halves along the last axis (means, then raw stds).
    * State-independent stds: the MLP emits ``action_dim`` features (the
      means) and the raw stds live in a separate ``nnx.Param`` of shape
      ``(action_dim,)`` initialised to zeros.

    In both modes the raw stds go through ``softplus`` and have ``eps`` added.
    """

    def __init__(
        self,
        action_dim: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        eps: float = EPS,
        param_dtype: Dtype = jnp.float32,
        state_dependent_stds: bool = True,
        **kwargs
    ):
        """Build the MLP and, if requested, the standalone std parameter.

        Args:
            action_dim: Number of action dimensions.
            observation_dim: Number of input features of the MLP.
            rngs: Random number generator collection; handed to the MLP and,
                in the state-independent mode, also used to draw a key from
                its ``params`` stream for the std parameter initializer.
            eps: Constant added to the softplus-transformed stds.
            param_dtype: Dtype for the MLP parameters and the std parameter.
            state_dependent_stds: If true, the stds are produced by the MLP;
                otherwise they come from a standalone parameter.
            **kwargs: Any further keyword arguments, forwarded to ``MLP``
                unchanged.
        """
        self.eps = eps

        # The MLP head is twice as wide when it also has to emit the raw stds.
        if state_dependent_stds:
            head_width = 2 * action_dim
        else:
            head_width = action_dim
        self.model = MLP(
            in_features=observation_dim,
            out_features=head_width,
            param_dtype=param_dtype,
            rngs=rngs,
            **kwargs
        )

        # The MLP is built before the key below is drawn, so the order in
        # which `rngs` is consumed stays fixed.
        if state_dependent_stds:
            standalone_stds = None
        else:
            zeros = initializers.zeros(rngs.params(), (action_dim,), param_dtype)
            standalone_stds = nnx.Param(zeros)
        self.stds: nnx.Param | None = standalone_stds

    def __call__(self, observations: ArrayLike):
        """Compute the means and standard deviations for ``observations``.

        Returns:
            A 2-tuple ``(means, stds)``. In the state-independent mode the
            stds are broadcast to the shape of the means.
        """
        outputs = self.model(observations)

        if self.stds is not None:
            # State-independent mode: the whole MLP output is the means.
            shared_stds = softplus(self.stds.value) + self.eps
            return outputs, jnp.broadcast_to(shared_stds, outputs.shape)

        # State-dependent mode: first half is the means, second the raw stds.
        means, raw_stds = jnp.split(outputs, 2, axis=-1)
        return means, softplus(raw_stds) + self.eps
