from typing import Type, Any

import haiku as hk
import jax
import jax.numpy as jnp

from hashes_tester import LCGHash, SimHash
class MinatarEpistemicAZNet(hk.Module):
    """Epistemic AlphaZero NN architecture for Minatar. Linear layers hidden sizes are all the same, defaults to 64.

    Body:
        One Conv2D (num_channels channels, 3x3 kernel) followed by two linear layers, Relu activations.
    Policy heads (main and exploration):
        Two linear layers each, Relu in between, no activation on the last layer.
    Value head:
        Two linear layers, Relu in between, no activation on the last layer.
    UBE head:
        Two linear layers, Relu in between, exp2 activation on the last layer.
    """

    def __init__(
        self,
        num_actions: int,
        num_channels: int = 16,
        hidden_layers_size: int = 64,
        max_u: float = jnp.inf,
        max_epistemic_variance_reward: float = 1.0,
        discount: float = 0.9997,
        hash_class: Type = SimHash,
        hash_args: dict[str, Any] | None = None,
        name: str = "minatar_az_net",
    ):
        """
        num_actions = env action space size
        num_channels = num_channels for the conv2d layer
        hidden_layers_size = num of units in each hidden layer
        max_u = if passed, clips the UBE prediction <= max_u. In board games for example, max_u = 1
        max_epistemic_variance_reward = used to scale the hash to max_reward ** 2 := max V[R]
        discount = the bellman discount, used to scale the reward uncertainty for novel states.
            Must satisfy 0 <= discount < 1, otherwise ValueError is raised.
        hash_class = SimHash or LCGHash
        hash_args = keyword arguments forwarded to hash_class when it is instantiated (None means no arguments)
        name = haiku module name
        """
        super().__init__(name=name)
        if not 0.0 <= discount < 1.0:
            raise ValueError(f"discount must be in [0, 1), got {discount}")
        self.num_actions = num_actions
        self.num_channels = num_channels
        self.hidden_layers_size = hidden_layers_size
        self.hash_class = hash_class
        self.hash_args = hash_args if hash_args is not None else {}
        self.max_u = max_u
        self.discount = discount
        # Geometric sum of discount ** (2 * t): turns a per-step reward variance bound
        # into a bound on the variance of the discounted return.
        self.local_unc_to_max_value_unc_scale = 1.0 / (1.0 - discount**2)
        self.max_reward_epistemic_variance = max_epistemic_variance_reward

    def __call__(
        self, x, is_training, test_local_stats, update_hash: bool = False
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
        """
        x = batch of observations, cast to float32 before the conv layer
        is_training = when falsy, the UBE output is lower bounded by the novelty based estimate
            and clipped to [0, max_u]; when truthy, the raw UBE head output is returned
        test_local_stats = not used by this network
        update_hash = when True, hash_obj.update is called on the hashed features

        Returns (main_policy_logits, exploration_policy_logits, value, ube, scaled_state_novelty).
        """
        # Minatar arch. body:
        x = x.astype(jnp.float32)
        x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
        x = jax.nn.relu(x)
        x = hk.Flatten()(x)
        x = hk.Linear(self.hidden_layers_size)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(self.hidden_layers_size)(x)
        x = jax.nn.relu(x)

        # policy head
        main_policy_logits = hk.Linear(self.hidden_layers_size)(x)
        main_policy_logits = jax.nn.relu(main_policy_logits)
        main_policy_logits = hk.Linear(self.num_actions)(main_policy_logits)

        # exploration policy head
        exploration_policy_logits = hk.Linear(self.hidden_layers_size)(x)
        exploration_policy_logits = jax.nn.relu(exploration_policy_logits)
        exploration_policy_logits = hk.Linear(self.num_actions)(exploration_policy_logits)

        # value head
        v = hk.Linear(self.hidden_layers_size)(x)
        v = jax.nn.relu(v)
        v = hk.Linear(1)(v)
        # No tanh here: squashing the value is needed only in zero-sum games, and this is a minatar net.
        v = v.reshape((-1,))

        # ube head
        u = hk.Linear(self.hidden_layers_size)(x)
        u = jax.nn.relu(u)
        u = hk.Linear(1)(u)
        u = jnp.exp2(u)  # keeps the predicted uncertainty positive
        u = u.reshape((-1,))

        # local uncertainty
        # NOTE(review): x at this point is the body's last hidden activation, not the raw
        # observation, so the hash is computed over learned features -- confirm this is intended.
        hash_obj = self.hash_class(**self.hash_args)
        scaled_state_novelty = (~hash_obj(x)) * self.max_reward_epistemic_variance
        if not is_training:
            # The UBE prediction for AZ is max(attainable sum of reward_unc speculated from local reward_unc, ube)
            u = jnp.maximum(scaled_state_novelty * self.local_unc_to_max_value_unc_scale, u)
            # JAX arrays are immutable: the clipped result has to be assigned back.
            u = jnp.clip(u, 0, self.max_u)

        if update_hash:
            hash_obj.update(x)

        return main_policy_logits, exploration_policy_logits, v, u, scaled_state_novelty

