import haiku as hk
import jax
import jax.numpy as jnp


class PositionalEncoding(hk.Module):
    """Add fixed sinusoidal position encodings to a sequence, then dropout.

    Even feature columns receive ``sin(pos * freq)`` and odd feature columns
    receive ``cos(pos * freq)`` with
    ``freq = exp(-log(10000) * i / embed_dim)`` for ``i = 0, 2, 4, ...``.
    The module has no learnable parameters.

    Args:
        embed_dim: Size of the last (feature) axis of the input.
        dropout_rate: Dropout probability applied when ``is_training`` is true.
        max_len: Largest sequence length (size of axis 1) that is accepted.
    """

    def __init__(self, embed_dim, dropout_rate=0.1, max_len=64):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout_rate = dropout_rate
        self.max_len = max_len

    def __call__(self, x, is_training=False):
        """Return ``x`` plus position encodings (with dropout when training).

        Args:
            x: Array whose axis 1 is the sequence axis and whose last axis has
                size ``embed_dim`` (the encoding is broadcast over axis 0).
            is_training: When true, dropout is applied and an RNG key is drawn
                via ``hk.next_rng_key()``.

        Returns:
            Array with the same shape as ``x``.

        Raises:
            ValueError: If the sequence length of ``x`` exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            # Without this guard the addition below either fails with an
            # opaque broadcasting error or (for max_len == 1) silently reuses
            # position 0 for every position.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={self.max_len} of PositionalEncoding."
            )

        positions = jnp.arange(seq_len)[:, None]
        div_term = jnp.exp(jnp.arange(0, self.embed_dim, 2) * (-jnp.log(10000.0) / self.embed_dim))
        angles = positions * div_term  # [T, ceil(embed_dim / 2)]

        pe = jnp.zeros((seq_len, self.embed_dim))
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        # For odd embed_dim there is one fewer odd column than even columns,
        # so drop the surplus frequency before filling the cosine slots.
        pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : self.embed_dim // 2]))

        x = x + pe[None]  # broadcast the [T, D] table over the leading axis
        return hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if is_training else x


class EncoderTransformer(hk.Module):
    """Transformer encoder over token ids followed by an MLP head.

    Pipeline: token embedding -> sinusoidal positional encoding ->
    ``num_layers`` post-LayerNorm blocks (self-attention + feed-forward, each
    with a residual connection) -> hidden state at sequence position
    ``cls_token_idx`` -> ReLU MLP -> linear projection to ``output_dim``.

    Args:
        num_tokens: Vocabulary size of the embedding table.
        embed_dim: Embedding width; also used as the attention ``key_size``
            and ``model_size``.
        hid_dim: Width of the feed-forward hidden layer inside each block.
            The MLP head uses two hidden layers of width ``4 * hid_dim``.
        output_dim: Size of the final output vector.
        num_layers: Number of encoder blocks. Also scales the attention
            weight initializer (``VarianceScaling(2 / num_layers)``).
        num_head: Number of attention heads.
        pad_token_idx: Token id treated as padding; padded key positions are
            masked out of attention.
        cls_token_idx: Index along the sequence axis whose final hidden state
            is fed to the MLP head.
        dropout: Dropout probability used when ``is_training`` is true.
    """

    def __init__(
        self,
        num_tokens: int,
        embed_dim: int,
        hid_dim: int,
        output_dim: int,
        num_layers: int,
        num_head: int,
        pad_token_idx: int,
        cls_token_idx: int = 0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.embed_dim = embed_dim
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.num_head = num_head
        # Hidden widths of the MLP head applied to the selected embedding.
        self.mlp_layers = [4 * hid_dim, 4 * hid_dim]

        self.pad_token_idx = pad_token_idx
        self.cls_token_idx = cls_token_idx
        self.dropout = dropout

        # Only the attention projections use this initializer.
        self.initializer = hk.initializers.VarianceScaling(2 / self.num_layers)

    def __call__(
        self, x: jax.Array, is_training: bool = False, get_embed: bool = False
    ) -> "jax.Array | tuple[jax.Array, jax.Array]":
        """Encode a batch of token ids and project to ``output_dim``.

        Args:
            x: Integer token ids of shape ``[B, T]``.
            is_training: Enables dropout (requires an RNG key in the haiku
                transform).
            get_embed: When true, also return the encoder hidden state that
                was fed into the MLP head.

        Returns:
            ``outputs`` of shape ``[B, output_dim]``, or the tuple
            ``(outputs, cls_embed)`` when ``get_embed`` is true, where
            ``cls_embed`` has shape ``[B, embed_dim]``.
        """

        def _sa_block(
            x,
            mask,
        ) -> jax.Array:
            # Self-attention sub-layer (queries, keys and values are all x).
            x = hk.MultiHeadAttention(
                num_heads=self.num_head,
                key_size=self.embed_dim,
                model_size=self.embed_dim,
                w_init=self.initializer,
                with_bias=True,
            )(x, x, x, mask=mask)
            return hk.dropout(hk.next_rng_key(), self.dropout, x) if is_training else x

        def _ff_block(x):
            # Position-wise feed-forward sub-layer: Linear -> ReLU -> dropout -> Linear.
            x = hk.Linear(self.hid_dim)(x)
            x = jax.nn.relu(x)
            x = hk.dropout(hk.next_rng_key(), self.dropout, x) if is_training else x
            return hk.Linear(self.embed_dim)(x)

        # Inspired by https://github.com/google-deepmind/dm-haiku/blob/main/examples/transformer/model.py
        seq_len = x.shape[-1]
        # True where the token is real, False where it is padding.
        pad_mask = x != self.pad_token_idx
        mask = pad_mask[:, None, None, :]  # [B, 1, 1, T]
        # Expand over the query axis so every query shares the same key mask.
        mask = mask * jnp.ones((1, 1, seq_len, seq_len))  # [B, 1, T, T]

        x = hk.Embed(self.num_tokens, self.embed_dim)(x)
        x = PositionalEncoding(self.embed_dim, self.dropout, seq_len)(x, is_training=is_training)

        # Post-LayerNorm encoder blocks: norm(x + sublayer(x)).
        for _ in range(self.num_layers):
            x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x + _sa_block(x, mask))
            x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x + _ff_block(x))

        # Hidden state at the configured sequence position summarizes the input.
        cls_embed = x[:, self.cls_token_idx]

        # MLP
        x = hk.Linear(self.mlp_layers[0])(cls_embed)
        x = jax.nn.relu(x)
        for hidden_dim in self.mlp_layers[1:]:
            x = hk.Linear(hidden_dim)(x)
            x = jax.nn.relu(x)

        outputs = hk.Linear(self.output_dim)(x)

        if get_embed:
            return outputs, cls_embed

        return outputs

    def get_kwargs(self):
        """Return the constructor keyword arguments that recreate this module."""
        return {
            "num_tokens": self.num_tokens,
            "embed_dim": self.embed_dim,
            "hid_dim": self.hid_dim,
            "output_dim": self.output_dim,
            "num_layers": self.num_layers,
            "num_head": self.num_head,
            "pad_token_idx": self.pad_token_idx,
            # Included so a non-default CLS position survives a rebuild.
            "cls_token_idx": self.cls_token_idx,
            "dropout": self.dropout,
        }
