import jax
import jax.numpy as jnp
from flax import nnx
import distrax


class LinearScheduler(nnx.Module):
    """Stateful sampler whose odds shift linearly with the number of calls.

    Each call draws one sample from a two-outcome categorical distribution
    over ``{0, 1}`` and then advances an internal step counter. After ``t``
    previous calls, the probability of drawing ``1`` is
    ``clip(1 - t / steps, min_prob, max_prob)``. It therefore starts at
    ``max_prob``, falls linearly, and stays at ``min_prob`` once
    ``t >= (1 - min_prob) * steps``.

    Args:
        steps: Number of calls over which the unclipped probability of ``1``
            would fall from 1 to 0. Must be positive.
        min_prob: Lower bound on the probability of drawing ``1``.
        max_prob: Upper bound on the probability of drawing ``1``.
        rngs: Source of PRNG keys; one key is drawn from it per call.

    Raises:
        ValueError: If ``steps`` is not positive, or the bounds do not satisfy
            ``0 <= min_prob <= max_prob <= 1``.
    """

    def __init__(self,
                 steps: int,
                 min_prob: float = 0.05,
                 max_prob: float = 1.,
                 *, rngs: nnx.Rngs):
        # Reject configurations that would otherwise yield NaN (steps == 0),
        # inverted clip bounds, or negative "probabilities" fed to Categorical.
        if steps <= 0:
            raise ValueError(f"steps must be positive, got {steps}")
        if not 0.0 <= min_prob <= max_prob <= 1.0:
            raise ValueError(
                "expected 0 <= min_prob <= max_prob <= 1, "
                f"got min_prob={min_prob}, max_prob={max_prob}"
            )
        self.steps = steps
        # Bounds on the probability of drawing 0, i.e. one minus the bounds on
        # the probability of drawing 1.
        self.clip_min = 1 - max_prob
        self.clip_max = 1 - min_prob
        # Number of completed calls. It is held in an nnx.Variable so the
        # in-place increment in __call__ is tracked as module state.
        self.step_counter = nnx.Variable(jnp.zeros(shape=(), dtype=jnp.int32))
        self.rngs = rngs

    def __call__(self):
        """Draw one sample and advance the schedule by one step.

        Returns:
            A ``(sample, prob)`` tuple. ``sample`` is a scalar integer array
            equal to ``0`` or ``1``. ``prob`` is the probability of drawing
            ``0`` at this step, so ``1 - prob`` is the probability of drawing
            ``1``. Both are computed *before* the step counter is incremented.
        """
        step = self.step_counter.value
        # Probability of index 0 grows linearly with the step count; the clip
        # keeps the probability of index 1 within [min_prob, max_prob].
        prob = (step / self.steps).clip(self.clip_min, self.clip_max)
        dist = distrax.Categorical(probs=jnp.asarray([prob, 1 - prob]))
        result = dist.sample(seed=self.rngs())
        self.step_counter.value += 1
        return result, prob



if __name__ == '__main__':
    # Demo: step through one full schedule and print each (sample, prob) pair.
    num_steps = 1000
    demo = LinearScheduler(num_steps, rngs=nnx.Rngs(32))
    for step_output in (demo() for _ in range(num_steps)):
        print(step_output)



