import jax
import jax.numpy as jnp
from flax import nnx
from . import *


class MultiHeadDecoder(nnx.Module):
    """Pre-norm multi-head self-attention decoder block with a SwiGLU feed-forward.

    input shape: (batch, context_len, feature)
    output shape: (batch, context_len, feature)

    The forward pass has two residual sub-layers. Each one receives an
    RMS-normalized copy of its input:
      1. multi-head attention (rotary embedding applied to queries/keys only),
      2. the SwiGLU feed-forward network.
    """

    def __init__(
        self,
        feature: int,
        attn_feature: int,
        ffn_feature: int,
        num_head: int,
        is_causal: bool,
        init_scalar: float,
        key,
        dtype=jnp.float32,
    ):
        """Build the attention, feed-forward, rotary and normalization sub-layers.

        Args:
            feature: model width of the residual stream; must be divisible by
                ``num_head``.
            attn_feature: attention width forwarded to ``MultiHeadAttention``;
                must be divisible by ``num_head``.
            ffn_feature: hidden width forwarded to ``SwiGluFFN``.
            num_head: number of attention heads.
            is_causal: forwarded to ``MultiHeadAttention``.
            init_scalar: forwarded to ``MultiHeadAttention`` and ``SwiGluFFN``.
            key: JAX PRNG key; split into four sub-keys, one per parameterized
                sub-layer.
            dtype: computation dtype of the sub-layers (norm parameters are
                always stored as float32).

        Raises:
            ValueError: if ``feature`` or ``attn_feature`` is not divisible by
                ``num_head``.
        """
        super().__init__()
        attn_key, ffn_key, norm1_key, norm2_key = jax.random.split(key, num=4)

        # Both widths are later split evenly across the heads.
        for label, size in (('Feature', feature), ('Attention Feature', attn_feature)):
            if size % num_head != 0:
                raise ValueError(
                    f"MultiHeadDecoder: '{label}'({size}) should be dividable by 'Head Count'({num_head})"
                )

        # Config
        self.num_head = num_head

        def rms_norm(rng_key):
            # Parameters stay float32 even when the compute dtype is lower precision.
            return nnx.RMSNorm(
                num_features=feature,
                epsilon=1e-6,
                dtype=dtype,
                param_dtype=jnp.float32,
                rngs=nnx.rnglib.Rngs(rng_key),
            )

        # Layers
        head_feature = feature // num_head
        self.rotary = RotaryEmbed(head_feature=head_feature, dtype=dtype)
        self.attention = MultiHeadAttention(
            attn_feature=attn_feature,
            feature=feature,
            num_head=num_head,
            init_scalar=init_scalar,
            is_causal=is_causal,
            key=attn_key,
            dtype=dtype,
        )
        self.ffn = SwiGluFFN(
            feature=feature,
            ffn_feature=ffn_feature,
            init_scalar=init_scalar,
            key=ffn_key,
            dtype=dtype,
        )
        self.norm1 = rms_norm(norm1_key)
        self.norm2 = rms_norm(norm2_key)

    def __call__(self, x, mask=None, q_len=None, kv_len=None):
        """Run the attention and feed-forward sub-layers with residual connections.

        Args:
            x: input activations, (batch, context_len, feature).
            mask, q_len, kv_len: forwarded unchanged to ``MultiHeadAttention``;
                their exact semantics are defined there.

        Returns:
            Activations with the same shape as ``x``.
        """
        # Attention sub-layer: the last axis is split into (num_head, feature // num_head).
        h = self.norm1(x)
        heads = h.reshape(h.shape[:-1] + (self.num_head, -1))
        # Rotary positions go on queries and keys; values use the un-rotated heads.
        qk = self.rotary(heads)
        attn_out = self.attention(qk, qk, heads, mask, q_len, kv_len)
        x = attn_out + x

        # Feed-forward sub-layer.
        ffn_out = self.ffn(self.norm2(x))
        return ffn_out + x


class CrossDecoder(nnx.Module):
    """Pre-norm decoder block: causal self-attention, cross-attention, SwiGLU FFN.

    input shape: (batch, context_len, feature)
    encoder_output shape: (batch, encoder_len, feature)
    output shape: (batch, context_len, feature)

    Each of the three sub-layers receives an RMS-normalized copy of the
    decoder stream and adds its output back as a residual connection.
    """

    def __init__(self, feature: int, attn_feature: int, ffn_feature: int, num_head: int, init_scalar: float, key,
                 dtype=jnp.float32):
        """Build the attention, feed-forward, rotary and normalization sub-layers.

        Args:
            feature: model width of the decoder stream (and of ``encoder_output``);
                must be divisible by ``num_head``.
            attn_feature: attention width forwarded to both attention layers;
                must be divisible by ``num_head``.
            ffn_feature: hidden width forwarded to ``SwiGluFFN``.
            num_head: number of attention heads.
            init_scalar: forwarded to both attention layers and ``SwiGluFFN``.
            key: JAX PRNG key from which the sub-layer keys are derived.
            dtype: computation dtype of the sub-layers (norm parameters are
                always stored as float32).

        Raises:
            ValueError: if ``feature`` or ``attn_feature`` is not divisible by
                ``num_head``.
        """
        super().__init__()
        # NOTE(review): 7 keys are split but only keys[0]..keys[5] are used. The
        # count is left as-is because changing it may alter the derived keys
        # (and thus initial weights) under some JAX PRNG implementations.
        keys = jax.random.split(key, num=7)

        # Both widths are later split evenly across the heads.
        if feature % num_head != 0:
            raise ValueError(
                f"CrossDecoder: 'Feature'({feature}) should be dividable by 'Head Count'({num_head})")
        if attn_feature % num_head != 0:
            raise ValueError(
                f"CrossDecoder: 'Attention Feature'({attn_feature}) should be dividable by 'Head Count'({num_head})")

        # Config
        self.num_head = num_head

        # Layers
        self.rotary = RotaryEmbed(head_feature=feature // num_head, dtype=dtype)

        # Self-attention
        self.self_attention = MultiHeadAttention(
            attn_feature=attn_feature,
            feature=feature,
            num_head=num_head,
            init_scalar=init_scalar,
            is_causal=True,  # Decoder uses causal self-attention
            key=keys[0],
            dtype=dtype
        )

        # Cross-attention
        self.cross_attention = MultiHeadAttention(
            attn_feature=attn_feature,
            feature=feature,
            num_head=num_head,
            init_scalar=init_scalar,
            is_causal=False,  # Cross-attention is non-causal
            key=keys[1],
            dtype=dtype
        )

        self.ffn = SwiGluFFN(
            feature=feature,
            ffn_feature=ffn_feature,
            init_scalar=init_scalar,
            key=keys[2],
            dtype=dtype
        )

        # Normalization layers (parameters kept in float32 regardless of dtype)
        self.norm1 = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[3])
        )
        self.norm2 = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[4])
        )
        self.norm3 = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[5])
        )

    def _split_heads(self, t):
        """Reshape the trailing axis of ``t`` into ``(num_head, -1)``."""
        return t.reshape(*t.shape[:-1], self.num_head, -1)

    def __call__(self, x, encoder_output, self_mask=None, cross_mask=None, q_len=None, kv_len=None, encoder_len=None):
        """Run self-attention, cross-attention and feed-forward with residuals.

        Args:
            x: decoder activations, (batch, context_len, feature).
            encoder_output: encoder activations, (batch, encoder_len, feature).
            self_mask, q_len, kv_len: forwarded to the self-attention layer.
            cross_mask, encoder_len: forwarded (together with ``q_len``) to the
                cross-attention layer. The exact semantics of these arguments
                are defined by ``MultiHeadAttention``.

        Returns:
            Activations with the same shape as ``x``.
        """
        # Self-attention: rotary on queries/keys, un-rotated heads as values.
        normed1 = self._split_heads(self.norm1(x))
        rotated1 = self.rotary(normed1)
        x = self.self_attention(rotated1, rotated1, normed1, self_mask, q_len, kv_len) + x

        # Cross-attention: queries from the decoder stream, keys/values from the encoder.
        normed2 = self._split_heads(self.norm2(x))
        rotated2 = self.rotary(normed2)

        # NOTE(review): encoder_output is NOT normalized in this block (keys[6]
        # is unused); presumably the encoder applies a final norm — confirm.
        # NOTE(review): rotary is also applied to the encoder keys, so logits may
        # depend on decoder/encoder position offsets — confirm this is intended.
        encoder_heads = self._split_heads(encoder_output)
        encoder_rotated = self.rotary(encoder_heads)
        x = self.cross_attention(rotated2, encoder_rotated, encoder_heads, cross_mask, q_len, encoder_len) + x

        # FFN
        return self.ffn(self.norm3(x)) + x
