from flax import nnx
import jax
import jax.numpy as jnp


class PreferenceSampler(nnx.Module):
    """Sample Dirichlet preference weights from running reward statistics.

    A parameter-free ``nnx.BatchNorm`` (no learnable scale or bias) tracks the
    per-feature running mean and variance of the values passed to
    ``__call__`` (feature axis is the BatchNorm default, the last axis, of
    size ``reward_dims``). ``sample_weight`` turns those running statistics
    into a Dirichlet concentration and draws one weight vector per batch row.
    """

    def __init__(self, reward_dims, *, rngs):
        """Build the sampler.

        Args:
            reward_dims: number of reward components. Used as the BatchNorm
                feature count and as the total Dirichlet concentration.
            rngs: ``nnx.Rngs`` passed to the BatchNorm and kept so that
                ``sample_weight`` can draw sampling keys from it.
        """
        self.reward_dim = reward_dims
        # Only the running statistics are read later, so the affine
        # scale/bias of the normalisation layer are disabled.
        self.estimator = nnx.BatchNorm(
            num_features=reward_dims,
            use_scale=False,
            use_bias=False,
            momentum=0.9,
            epsilon=1e-5,
            rngs=rngs,
        )
        self.rngs = rngs

    def __call__(self, values):
        """Normalise ``values`` with the internal BatchNorm and return the result.

        Any running-statistics update is the BatchNorm's own behaviour. In
        its default (non running-average) mode it refreshes the running
        mean/var that ``sample_weight`` reads.
        """
        normalised = self.estimator(values)
        return normalised

    def sample_weight(self, placeholder):
        """Draw one Dirichlet preference vector per row of ``placeholder``.

        Only ``placeholder.shape[0]`` is used; the contents are ignored.
        The concentration is ``reward_dim * softmax(|mean| - var / 2)``,
        built from the estimator's running statistics with gradients
        stopped. It therefore sums to ``reward_dim`` up to float rounding.
        Each call consumes one key from ``self.rngs``.

        Returns:
            Array of shape ``(placeholder.shape[0], reward_dim)``.
        """
        running_mean = jax.lax.stop_gradient(self.estimator.mean.value)
        running_var = jax.lax.stop_gradient(self.estimator.var.value)
        score = jnp.abs(running_mean) - 0.5 * running_var
        concentration = self.reward_dim * jax.nn.softmax(score, axis=-1)
        n_samples = placeholder.shape[0]
        key = self.rngs()
        return jax.random.dirichlet(key, concentration, shape=(n_samples,))





