from typing import Callable, Sequence, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from src.models.common import MLP


class Value(nn.Module):
    """State-value network that maps observations to one scalar per example.

    The trunk is the project ``MLP`` built with ``hidden_dims`` plus a final
    entry of width 1. That trailing singleton axis is squeezed off, so the
    output has one fewer dimension than the MLP's result. This assumes
    ``MLP`` treats the last tuple entry as its output width; confirm in
    ``src.models.common``.
    """

    # Widths of the hidden layers handed to MLP (a width-1 output is appended).
    hidden_dims: Sequence[int]
    # Activation callable forwarded unchanged to MLP.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to MLP's `layernorm` option; off by default.
    layernorm: Optional[bool] = False
    # Forwarded to MLP's `dropout_rate`; None by default.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        training: bool = True,
    ) -> jax.Array:
        """Evaluate the value head on ``observations``.

        Args:
            observations: Input array passed straight to the MLP. Its shape
                is presumably (batch, obs_dim); verify against callers.
            training: Forwarded to the MLP call. It presumably toggles
                dropout; confirm in MLP. Note that it defaults to True.

        Returns:
            The MLP output with its last (size-1) axis removed.
        """
        # Hidden widths followed by a single output unit.
        layer_widths = tuple(self.hidden_dims) + (1,)
        trunk = MLP(
            layer_widths,
            activations=self.activations,
            layernorm=self.layernorm,
            dropout_rate=self.dropout_rate,
        )
        scores = trunk(observations, training=training)
        # squeeze on axis -1 requires that axis to have size 1.
        return jnp.squeeze(scores, axis=-1)
