import flax.linen as nn
import jax
import jax.numpy as jnp


class Lagrangian(nn.Module):
    """Trainable scalar coefficient kept in log space.

    The module owns a single scalar parameter, ``log_ent_coef``, which is
    initialized to ``log(ent_coef_init)``. Calling the module returns that
    parameter, still in log space; exponentiating it is left to the caller.
    NOTE(review): the class name suggests the value is used as a Lagrange
    multiplier for an entropy constraint. Confirm this against the training loop.

    Attributes:
        ent_coef_init: Initial coefficient (not its log). It must be positive;
            ``jnp.log`` yields ``-inf`` for 0 and ``nan`` for negative values.
    """

    ent_coef_init: float = 1e-1

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        def _init_log_coef(rng):
            # Initialization is deterministic, so the PRNG key is not needed.
            del rng
            return jnp.full((), jnp.log(self.ent_coef_init))

        log_coef = self.param("log_ent_coef", _init_log_coef)
        # Materialize as a 0-d array (the stored parameter is already scalar).
        return jnp.full((), log_coef)


class Constant(nn.Module):
    """Fixed scalar coefficient, returned in log space.

    A scalar parameter named ``log_ent_coef`` is still registered, with the same
    name and shape as in ``Lagrangian``, so the parameter tree looks the same
    whichever module is used. Its value is wrapped in ``jax.lax.stop_gradient``
    so that gradient-based updates cannot move it away from ``log(val)``.

    Attributes:
        val: The constant coefficient (not its log). A value of 0 gives
            ``-inf``, and a negative value gives ``nan``, because of ``jnp.log``.
    """

    val: float = 1.0

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.val)))
        # Without stop_gradient this "constant" would receive gradients and be
        # updated by any optimizer stepping over the module's params. The
        # forward value is unchanged; only its derivative is zeroed.
        return jax.lax.stop_gradient(jnp.full((), log_ent_coef))

