import jax
import jax.numpy as jnp
from flax import nnx
import functools


def _select_implementation(head_dim: int, dtype: jnp.dtype):
    """ Select scaled dot-product attention implementation

    Returns 'cudnn' when both conditions checked here hold, otherwise 'xla':
        - dtype: float16 or bfloat16
        - head dimension: multiple of 8 (16, 32, 64, ...)

    The CuDNN implementation additionally requires:
        - mask size: no bigger than 2^31 - 1
    This function never sees the mask, so that requirement is NOT checked here.

    Args:
        head_dim: size of a single attention head's feature dimension.
        dtype: anything `jnp.dtype` accepts (scalar type such as `jnp.bfloat16`,
            a `numpy.dtype` instance such as `x.dtype`, or a dtype name).

    Returns:
        The string 'cudnn' or 'xla'.
    """
    # NOTE(review): device/backend availability is not checked here -- confirm
    # callers only reach the 'cudnn' path on a cuDNN-capable GPU.

    # Normalise before comparing: scalar type objects (jnp.float16) and dtype
    # instances (jnp.dtype('float16')) hash differently, so a raw set-membership
    # test would silently miss dtype instances and fall back to 'xla'.
    resolved = jnp.dtype(dtype)
    cudnn_supported = {jnp.dtype(jnp.float16), jnp.dtype(jnp.bfloat16)}

    if resolved in cudnn_supported and head_dim % 8 == 0:
        return 'cudnn'
    return 'xla'


class MultiHeadAttention(nnx.Module):
    """ Multi-head attention layer with QK norm

    q, k and v are each projected per head; the projected q and k are passed
    through a LayerNorm (no scale, no bias) before dot-product attention, and
    the concatenated heads are mapped by a final linear weight.

    input shape: (batch, context_len, num_head, feature // num_head)
    output shape: (batch, context_len, feature)
    """
    def __init__(self, feature: int, attn_feature: int, num_head: int, is_causal: bool, init_scalar: float, key, dtype=jnp.float32):
        """ Build parameters, the attention callable and the two QK-norm layers

        feature: total input/output feature size (must be a multiple of num_head)
        attn_feature: total projected feature size (must be a multiple of num_head)
        num_head: number of attention heads
        is_causal: forwarded to jax.nn.dot_product_attention
        init_scalar: scale passed to the variance-scaling ('fan_in', 'normal') initializer
        key: PRNG key, split into six sub-keys
        dtype: dtype of all parameters and of the LayerNorm computation

        Raises ValueError when feature or attn_feature is not a multiple of num_head.
        """
        super().__init__()
        # One sub-key per consumer: output weight, q/k/v projections, two norm layers.
        out_key, q_key, k_key, v_key, q_norm_key, k_norm_key = jax.random.split(key, num=6)
        initializer = jax.nn.initializers.variance_scaling(init_scalar, 'fan_in', 'normal')

        if feature % num_head:
            raise ValueError(f'MultiHeadAttention: \'Feature\'({feature}) should be dividable by \'Head Count\'({num_head})')
        if attn_feature % num_head:
            raise ValueError(f'MultiHeadAttention: \'Attention Feature\'({attn_feature}) should be dividable by \'Head Count\'({num_head})')

        head_in = feature // num_head
        head_out = attn_feature // num_head

        def new_param(rng, shape, name):
            # Every weight shares the same initializer and dtype.
            return nnx.Param(initializer(rng, shape=shape, dtype=dtype), name=name)

        # Final output projection: concatenated heads -> feature
        self.weight = new_param(out_key, (attn_feature, feature), 'MHA_W')

        # Per-head input projections for query / key / value
        per_head_shape = (num_head, head_in, head_out)
        self.w_q = new_param(q_key, per_head_shape, 'MHA_Wq')
        self.w_k = new_param(k_key, per_head_shape, 'MHA_Wk')
        self.w_v = new_param(v_key, per_head_shape, 'MHA_Wv')

        # Attention callable with causality and backend fixed at construction time
        self.dot_product_attention = functools.partial(
            jax.nn.dot_product_attention,
            is_causal=is_causal,
            implementation=_select_implementation(head_out, dtype),
        )

        # QK-Norm layers: identical configuration, separate RNG streams
        shared_norm_args = dict(
            num_features=head_out,
            use_scale=False,
            use_bias=False,
            dtype=dtype,
            param_dtype=dtype,
        )
        self.query_ln = nnx.LayerNorm(rngs=nnx.rnglib.Rngs(q_norm_key), **shared_norm_args)
        self.key_ln = nnx.LayerNorm(rngs=nnx.rnglib.Rngs(k_norm_key), **shared_norm_args)

    def __call__(self, q, k, v, mask=None, q_len=None, kv_len=None):
        """ Apply attention

        q, k, v: (batch, context_len, num_head, feature // num_head)
        mask, q_len, kv_len: forwarded to jax.nn.dot_product_attention as
            mask, query_seq_lengths and key_value_seq_lengths
        returns: (batch, context_len, feature)
        """
        # Independent linear map per head: (b, t, h, d) x (h, d, o) -> (b, t, h, o)
        per_head = 'bthd,hdo->btho'
        queries = self.query_ln(jnp.einsum(per_head, q, self.w_q))
        keys = self.key_ln(jnp.einsum(per_head, k, self.w_k))
        values = jnp.einsum(per_head, v, self.w_v)

        heads = self.dot_product_attention(
            query=queries,
            key=keys,
            value=values,
            mask=mask,
            query_seq_lengths=q_len,
            key_value_seq_lengths=kv_len,
        )

        # Concatenate the head axis into the feature axis, then project out.
        merged = heads.reshape(heads.shape[:-2] + (-1,))
        return merged @ self.weight
