import flax.linen as nn
import jax
import jax.numpy as jnp

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with dropout.

    Hyperparameters are read from ``config``: ``embedding_dim`` (model width),
    ``num_heads`` and ``dropout_rate``. The Q/K/V and output projections are
    bias-free ``nn.Dense`` layers of width ``embedding_dim``.
    """

    config: dict

    @nn.compact
    def __call__(self, q, k, v, mask=None, deterministic=False, rngs=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q, k, v: Query, key and value inputs. The head split transposes
                with ``(0, 2, 1, 3)``, so they must be rank 3 — presumably
                ``(batch, seq_len, features)``; confirm against callers.
            mask: Optional binary mask indexed as ``mask[:, None, None, :]``,
                i.e. one entry per (batch, key position). Nonzero / True
                entries are attended to; zero / False entries are masked out.
            deterministic: If True, both dropout layers are disabled.
            rngs: Optional dict. If it holds a ``'dropout'`` key, a distinct
                key is derived from it for each dropout site in this module.

        Returns:
            Array of shape ``(q.shape[0], q_len, embedding_dim)``.

        Raises:
            ValueError: If ``embedding_dim`` is not divisible by ``num_heads``.
        """
        d_model = self.config['embedding_dim']
        num_heads = self.config['num_heads']
        if d_model % num_heads != 0:
            # Otherwise split_heads fails later with an opaque reshape error.
            raise ValueError(
                f"embedding_dim ({d_model}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        d_head = d_model // num_heads
        dropout_rate = self.config['dropout_rate']

        base_key = rngs.get('dropout') if rngs else None

        def dropout_rng(site):
            # Reusing one key for several dropout calls yields correlated (or,
            # for equal shapes, identical) masks. Fold in a per-site constant so
            # each site gets its own key, distinct from the raw key that other
            # modules receiving the same `rngs` may use.
            if base_key is None:
                return None  # nn.Dropout uses its default RNG handling.
            return jax.random.fold_in(base_key, site)

        # Linear projection for Q, K, V
        q = nn.Dense(d_model, use_bias=False, name='q_proj')(q)
        k = nn.Dense(d_model, use_bias=False, name='k_proj')(k)
        v = nn.Dense(d_model, use_bias=False, name='v_proj')(v)

        # Split heads: (b, s, d_model) -> (b, h, s, d_head)
        def split_heads(x):
            return x.reshape(x.shape[:-1] + (num_heads, d_head)).transpose((0, 2, 1, 3))

        q, k, v = map(split_heads, (q, k, v))

        # Scaled dot-product attention scores: (b, h, q_len, k_len)
        scale = d_head ** -0.5
        attention = jnp.einsum('bhqd,bhkd->bhqk', q, k) * scale

        # Mask with the dtype's most negative finite value. This avoids the
        # half-precision overflow of a hard-coded 1e9 (which would become inf,
        # and 0 * inf = NaN on unmasked positions). Fully masked rows then
        # softmax to uniform rather than NaN.
        if mask is not None:
            keep = mask[:, None, None, :].astype(bool)
            attention = jnp.where(keep, attention, jnp.finfo(attention.dtype).min)

        attention = nn.softmax(attention, axis=-1)
        attention = nn.Dropout(rate=dropout_rate)(
            attention, deterministic=deterministic, rng=dropout_rng(0)
        )

        # Combine heads: (b, h, q_len, d_head) -> (b, q_len, d_model)
        output = jnp.einsum('bhqk,bhkd->bhqd', attention, v)
        output = output.transpose((0, 2, 1, 3)).reshape(output.shape[0], -1, d_model)

        # Final linear projection
        output = nn.Dense(d_model, use_bias=False, name='output_proj')(output)
        return nn.Dropout(rate=dropout_rate)(
            output, deterministic=deterministic, rng=dropout_rng(1)
        )

class FeedForward(nn.Module):
    """Two-layer MLP: Dense(ff_dim) -> SiLU -> Dense(embedding_dim) -> Dropout.

    Hyperparameters are read from ``config``: ``embedding_dim`` (input/output
    width), ``ff_dim`` (hidden width) and ``dropout_rate``.
    """

    config: dict

    @nn.compact
    def __call__(self, x, deterministic=False, rngs=None):
        """Apply the MLP to ``x``; dropout is skipped when ``deterministic``.

        If ``rngs`` is a non-empty dict, its ``'dropout'`` entry (possibly
        None) is passed to ``nn.Dropout`` as the explicit RNG key.
        """
        cfg = self.config
        out_width, hidden_width, p_drop = (
            cfg['embedding_dim'],
            cfg['ff_dim'],
            cfg['dropout_rate'],
        )

        hidden = nn.silu(nn.Dense(hidden_width)(x))
        projected = nn.Dense(out_width)(hidden)

        explicit_key = None
        if rngs:
            explicit_key = rngs.get('dropout')
        return nn.Dropout(rate=p_drop)(
            projected, deterministic=deterministic, rng=explicit_key
        )

class AttentionBlock(nn.Module):
    """Post-norm transformer block: self-attention, then feed-forward.

    Each sub-layer is applied as ``LayerNorm(inputs + sublayer(inputs))``.
    """

    config: dict

    @nn.compact
    def __call__(self, x, mask=None, deterministic=False, rngs=None):
        """Run self-attention and the MLP, each with a residual + LayerNorm.

        ``mask`` goes to the attention sub-layer only. ``deterministic`` and
        ``rngs`` are forwarded unchanged to both sub-layers.
        """
        def add_and_norm(residual, update):
            # Unnamed Flax submodules are auto-named by creation order; this
            # keeps the order attention, LayerNorm, feed-forward, LayerNorm.
            return nn.LayerNorm(epsilon=1e-6)(residual + update)

        attended = MultiHeadAttention(self.config)(
            x, x, x, mask=mask, deterministic=deterministic, rngs=rngs
        )
        hidden = add_and_norm(x, attended)

        mlp_out = FeedForward(self.config)(
            hidden, deterministic=deterministic, rngs=rngs
        )
        return add_and_norm(hidden, mlp_out)