import time

from src.rllib.models.action_dist import ActionDistribution
from src.rllib.models.modelv2 import ModelV2
from src.rllib.utils.annotations import override
from src.rllib.utils.framework import try_import_jax, try_import_tfp
from src.rllib.utils.typing import TensorType, List

jax, flax = try_import_jax()
tfp = try_import_tfp()


class JAXDistribution(ActionDistribution):
    """Wrapper class for JAX distributions."""

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: ModelV2):
        super().__init__(inputs, model)
        # Store the last sample here.
        self.last_sample = None
        # Use current time as pseudo-random number generator's seed.
        self.prng_key = jax.random.PRNGKey(seed=int(time.time()))

    @override(ActionDistribution)
    def logp(self, actions: TensorType) -> TensorType:
        return self.dist.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        return self.dist.kl_divergence(other.dist)

    @override(ActionDistribution)
    def sample(self) -> TensorType:
        # Update the state of our PRNG.
        _, self.prng_key = jax.random.split(self.prng_key)
        self.last_sample = jax.random.categorical(self.prng_key, self.inputs)
        return self.last_sample

    @override(ActionDistribution)
    def sampled_action_logp(self) -> TensorType:
        assert self.last_sample is not None
        return self.logp(self.last_sample)


class JAXCategorical(JAXDistribution):
    """Wrapper class for a JAX Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model=None, temperature=1.0):
        if temperature != 1.0:
            assert temperature > 0.0, \
                "Categorical `temperature` must be > 0.0!"
            inputs /= temperature
        super().__init__(inputs, model)
        self.dist = tfp.experimental.substrates.jax.distributions.Categorical(
            logits=self.inputs)

    @override(ActionDistribution)
    def deterministic_sample(self):
        self.last_sample = self.inputs.argmax(axis=1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return action_space.n
