import jax
import jax.numpy as jnp
from flax import nnx
from . import *


class MultiHeadEncoder(nnx.Module):
    """ Multi-head attention encoder for encoder-decoder pair
    input shape: (batch, context_len, feature)
    output shape: (batch, context_len, feature)

    The block applies two residual sub-layers, each preceded by an RMSNorm:
    non-causal multi-head attention (queries and keys pass through the rotary
    embedding, values do not), followed by a SwiGLU feed-forward network.
    """
    def __init__(self, feature: int, attn_feature: int, ffn_feature: int, num_head: int, init_scalar: float, key, dtype=jnp.float32):
        """ Build the encoder block.

        Args:
            feature: size of the last axis of the input; must be a multiple
                of ``num_head``.
            attn_feature: forwarded to ``MultiHeadAttention``; must be a
                multiple of ``num_head``.
            ffn_feature: forwarded to ``SwiGluFFN``.
            num_head: number of attention heads; must be positive.
            init_scalar: forwarded unchanged to ``MultiHeadAttention`` and
                ``SwiGluFFN`` (its meaning is defined by those layers).
            key: JAX PRNG key; split into four sub-keys, one per sub-layer.
            dtype: forwarded as ``dtype`` to every sub-layer. The RMSNorm
                parameters are always stored as float32.

        Raises:
            ValueError: if ``num_head`` is not positive, or if ``feature`` or
                ``attn_feature`` is not a multiple of ``num_head``.
        """
        super().__init__()

        # Validate before doing any work. num_head is used as a divisor below,
        # so reject zero/negative values explicitly instead of surfacing a
        # ZeroDivisionError or passing a negative head size to RotaryEmbed.
        if num_head <= 0:
            raise ValueError(f'MHAEncoder: \'Head Count\'({num_head}) should be a positive integer')
        if feature % num_head != 0:
            raise ValueError(f'MHAEncoder: \'Feature\'({feature}) should be dividable by \'Head Count\'({num_head})')
        if attn_feature % num_head != 0:
            raise ValueError(f'MHAEncoder: \'Attention Feature\'({attn_feature}) should be dividable by \'Head Count\'({num_head})')

        # One independent sub-key per sub-layer:
        # [0] attention, [1] ffn, [2] norm1, [3] norm2.
        keys = jax.random.split(key, num=4)

        # Config
        self.num_head = num_head

        # Layers
        self.rotary = RotaryEmbed(head_feature=feature // num_head, dtype=dtype)
        self.attention = MultiHeadAttention(
            attn_feature=attn_feature,
            feature=feature,
            num_head=num_head,
            init_scalar=init_scalar,
            is_causal=False,  # Encoder uses non-causal attention
            key=keys[0],
            dtype=dtype
        )
        self.ffn = SwiGluFFN(
            feature=feature,
            ffn_feature=ffn_feature,
            init_scalar=init_scalar,
            key=keys[1],
            dtype=dtype
        )
        self.norm1 = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.Rngs(keys[2])
        )
        self.norm2 = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.Rngs(keys[3])
        )

    def __call__(self, x, mask=None, q_len=None, kv_len=None):
        """ Apply the encoder block to ``x``.

        Args:
            x: input array; the class contract is (batch, context_len, feature).
            mask: forwarded unchanged to the attention layer.
            q_len: forwarded unchanged to the attention layer.
            kv_len: forwarded unchanged to the attention layer.
                NOTE(review): the exact semantics of mask / q_len / kv_len are
                defined by MultiHeadAttention - confirm there.

        Returns:
            ``x`` after the attention and feed-forward residual sub-layers.
        """
        normed = self.norm1(x)
        # Split the last axis into (num_head, head_feature) so the rotary
        # embedding and the attention layer operate per head.
        normed = normed.reshape(*normed.shape[:-1], self.num_head, -1)
        # Rotary embedding is applied to queries and keys only; the values
        # are the normalised, un-rotated activations.
        rotated = self.rotary(normed)
        x = self.attention(rotated, rotated, normed, mask, q_len, kv_len) + x
        return self.ffn(self.norm2(x)) + x
