from typing import Any
import distrax
import jax
import jax.numpy as jnp


# Loose annotation alias for array values (e.g. the samples returned by
# `_sample_n`). It is just `typing.Any`, so type checkers do not enforce it.
Array = Any


class OneHotCategorical(distrax.OneHotCategorical):
    """One-hot categorical distribution with straight-through samples.

    Differences from ``distrax.OneHotCategorical``:

    * Every sample carries a straight-through gradient. Its value is the
      one-hot draw, and gradients flow back through ``self.probs``.
    * With ``argmax=True`` sampling is deterministic. All ``n`` samples are
      the one-hot encoding of ``argmax(logits)``, and the PRNG key is unused.
    """

    def __init__(self, logits=None, probs=None, dtype=int, argmax=False):
        """Create the distribution.

        Args:
          logits: Unnormalised log-probabilities, forwarded to distrax.
          probs: Category probabilities, forwarded to distrax.
          dtype: Dtype the one-hot encoding is cast to in ``_sample_n``,
            before the straight-through term is added.
          argmax: If true, ``_sample_n`` returns the mode instead of a
            random draw.
        """
        self.argmax = argmax
        super().__init__(logits=logits, probs=probs, dtype=dtype)

    def _sample_n(self, key: jax.random.PRNGKey, n: int) -> Array:
        """See `Distribution._sample_n`."""
        sample_shape = (n,) + self.logits.shape[:-1]
        if not self.argmax:
            indices = jax.random.categorical(
                key=key, logits=self.logits, axis=-1, shape=sample_shape)
        else:
            # Deterministic mode: repeat the most likely category for each of
            # the n samples (jnp.argmax picks the first index on ties).
            mode = jnp.argmax(self.logits, axis=-1)
            indices = jnp.broadcast_to(mode, sample_shape)
        one_hot = jax.nn.one_hot(indices, num_classes=self.num_categories)
        one_hot = one_hot.astype(self._dtype)
        # Straight-through estimator: `p - stop_gradient(p)` is zero in value
        # but has an identity Jacobian w.r.t. p, so gradients reach the probs.
        # NOTE(review): adding this (presumably floating) term means the
        # result is likely floating point even when `dtype` is an integer
        # type -- confirm callers do not rely on `dtype` here.
        padded = self._pad(self.probs, one_hot.shape)
        return one_hot + (padded - jax.lax.stop_gradient(padded))

    def _pad(self, tensor, shape):
        """Prepend singleton axes to ``tensor`` until its rank is ``len(shape)``.

        Only the rank of ``shape`` is used; broadcasting handles the sizes.
        A tensor that already has at least that rank is returned unchanged.
        """
        missing = len(shape) - len(tensor.shape)
        if missing <= 0:
            return tensor
        return tensor[(None,) * missing]
