import haiku as hk
import jax
import jax.numpy as jnp


# Inspired by https://github.com/google-deepmind/dm-haiku/blob/main/examples/transformer/model.py
# Follows the architecture of https://gist.github.com/MJ10/59bfcc8bce4b5fce9c1c38a81b1105ae
class MultiTransformer(hk.Module):
    """A bundle of independent Transformer-encoder networks over the same tokens.

    One complete stack is built per entry of ``networks``:
    token embedding -> sinusoidal positional encoding -> ``num_layers``
    post-norm (self-attention, feed-forward) layers -> ReLU MLP head.
    The stacks share no parameters. Each maps ``tokens`` of shape [B, T] to
    per-position outputs of shape [B, T, O_i].
    """

    def __init__(
        self,
        networks: list[int],
        num_tokens: int,
        embed_dim: int,
        hid_dim: int,  # Dimension of hidden layer in feed-forward block
        num_layers: int,
        num_head: int,  # Number of attention heads
        pad_token: int,  # Token used in padding of sequences
        dropout: float = 0.1,  # used for all parts that use dropout
        causal: bool = True,
    ):
        """
        Encompasses multiple individual networks (specified by the `networks` argument).

        `networks` is a list of ints, each of which is the output size of a
        separate network. For example, if your algorithm requires a policy
        network and a value network, you would use networks=[|A|, 1].

        Args:
            networks: output size of each separate network. One output array is
                produced per entry, in the same order.
            num_tokens: vocabulary size of the token embedding.
            embed_dim: model (embedding) width D.
            hid_dim: hidden width of the feed-forward block in each layer. The
                output MLP head uses two hidden layers of width ``4 * hid_dim``.
            num_layers: number of encoder layers per network (0 is allowed and
                yields embedding + positional encoding + MLP head only).
            num_head: number of attention heads.
            pad_token: token id treated as padding. Key positions holding it are
                never attended to.
            dropout: rate used everywhere dropout is applied (training only).
            causal: if True, each position attends only to itself and earlier
                positions.
        """
        super().__init__()
        self.num_tokens = num_tokens
        self.embed_dim = embed_dim
        self.hid_dim = hid_dim
        self.num_layers = num_layers
        self.num_head = num_head
        self.mlp_layers = [4 * hid_dim, 4 * hid_dim]

        self.pad_token = pad_token
        self.dropout = dropout
        self.causal = causal

        # Depth-scaled init for the attention projections. Guard against
        # num_layers == 0, which previously raised ZeroDivisionError here.
        self.initializer = hk.initializers.VarianceScaling(2 / max(self.num_layers, 1))
        self.networks = networks

    def __call__(
        self,
        tokens: jax.Array,  # [B, T]
        is_training: bool = False,
    ) -> list[jax.Array]:  # one [B, T, O_i] array per entry of `networks`
        """Run every network on ``tokens``.

        Args:
            tokens: integer token ids of shape [B, T].
            is_training: enables dropout. Dropout draws keys via
                ``hk.next_rng_key()``, so an RNG must be supplied to ``apply``.

        Returns:
            A list with one array of shape [B, T, O_i] for each entry ``O_i`` of
            ``self.networks``, in the same order.
        """

        def _sa_block(
            x,  # [B, T, D]
            mask,  # [B, 1, T, T]
        ) -> jax.Array:  # [B, T, D]
            # NOTE(review): key_size is the *per-head* width in Haiku, so each head
            # here is embed_dim wide (total projection width num_head * embed_dim).
            # Haiku's transformer example and PyTorch use embed_dim // num_head.
            # Kept as-is because changing it alters parameter shapes; confirm intent.
            x = hk.MultiHeadAttention(
                num_heads=self.num_head,
                key_size=self.embed_dim,
                model_size=self.embed_dim,
                w_init=self.initializer,
                with_bias=True,
            )(x, x, x, mask=mask)
            return hk.dropout(hk.next_rng_key(), self.dropout, x) if is_training else x

        def _ff_block(x):
            x = hk.Linear(self.hid_dim)(x)  # [B, T, H]
            x = jax.nn.relu(x)
            x = hk.dropout(hk.next_rng_key(), self.dropout, x) if is_training else x
            # NOTE(review): PyTorch's TransformerEncoderLayer also applies dropout
            # after this second linear. Confirm whether that omission is intended.
            return hk.Linear(self.embed_dim)(x)  # [B, T, D]

        seq_len = tokens.shape[-1]

        # Build a boolean attention mask of shape [B, 1, T, T]. True means
        # "query may attend to this key".
        key_mask = (tokens != self.pad_token)[:, None, None, :]  # [B, 1, 1, T]
        if self.causal:
            # Query i may only attend to keys j <= i.
            causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # [T, T]
            mask = jnp.logical_and(key_mask, causal_mask)  # [B, 1, T, T]
        else:
            mask = jnp.broadcast_to(key_mask, key_mask.shape[:2] + (seq_len, seq_len))

        output = []
        for network_output_shape in self.networks:
            # Each iteration creates fresh modules, so every network gets its own
            # parameters.
            x = hk.Embed(self.num_tokens, self.embed_dim)(tokens)  # [B, T, D]
            x = PositionalEncoding(self.embed_dim, self.dropout, seq_len)(x, is_training=is_training)

            # Post-norm residual layers. Haiku names modules in construction order,
            # and the LayerNorm is constructed before the sublayer it wraps.
            # Reordering these expressions would rename parameters.
            for _ in range(self.num_layers):
                x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x + _sa_block(x, mask))
                x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x + _ff_block(x))

            # MLP head.
            x = hk.Linear(self.mlp_layers[0])(x)
            x = jax.nn.relu(x)
            for hidden_dim in self.mlp_layers[1:]:
                x = hk.Linear(hidden_dim)(x)
                x = jax.nn.relu(x)

            logits = hk.Linear(network_output_shape)(x)  # [B, T, O]
            output.append(logits)

        return output


class PositionalEncoding(hk.Module):
    """Adds fixed sinusoidal positional encodings to a [B, T, D] input.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``, with ``freq = 10000 ** (-2k / D)``. The module has no
    trainable parameters. Dropout is applied to the sum when ``is_training``.
    """

    def __init__(self, embed_dim: int, dropout_rate: float = 0.1, max_len: int = 32):
        """
        Args:
            embed_dim: feature dimension D of the input (odd values are supported).
            dropout_rate: dropout rate applied to the output during training.
            max_len: longest sequence length the encoding table covers.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout_rate = dropout_rate
        self.max_len = max_len

    def __call__(self, x: jax.Array, is_training: bool = False) -> jax.Array:
        """Return ``x`` plus positional encodings (and dropout when training).

        Args:
            x: input of shape [B, T, D] with T <= max_len and D == embed_dim.
            is_training: enables dropout (draws a key via ``hk.next_rng_key()``).

        Raises:
            ValueError: if the sequence length T exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={self.max_len}"
            )

        # The table depends only on the static config and is recomputed each call.
        positions = jnp.arange(self.max_len)[:, None]  # [L, 1]
        div_term = jnp.exp(jnp.arange(0, self.embed_dim, 2) * (-jnp.log(10000.0) / self.embed_dim))  # [ceil(D/2)]
        angles = positions * div_term  # [L, ceil(D/2)]

        pe = jnp.zeros((self.max_len, self.embed_dim))  # [L, D]
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        # For odd D there are floor(D/2) odd slots but ceil(D/2) frequencies;
        # drop the last frequency so the cos half fits.
        pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : self.embed_dim // 2]))

        x = x + pe[None, :seq_len]  # broadcast [1, T, D] over the batch
        return hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if is_training else x
