from typing import Literal
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

# Nonlinearity applied between the two linear layers of FeedForward.
Activation = Literal["tanh", "relu"]
# Positional-encoding schemes understood by PositionalEncoding:
#   "lin"    - linear ramp in [0, 1] over positions, same value in every dim
#   "sin"    - sinusoidal encoding (sin on even dims, cos on odd dims)
#   "rnd"    - uniform noise drawn from a fixed PRNG seed
#   "d-asym" - one linspace(-1, 1) ramp laid out dimension-major
Encoding = Literal["lin", "sin", "rnd", "d-asym"]


class PositionalEncoding(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        max_seq_len: int,
        encoding: Encoding = "lin",
    ):
        """
        Initialize positional encoding with different encoding schemes.

        The encoding table is built once here and stored on ``self.pe`` with
        shape ``[1, max_seq_len, embed_dim]`` so it broadcasts over the batch
        axis when added to the input.

        Args:
            embed_dim: Dimension of the embedding
            max_seq_len: Maximum sequence length
            encoding: Type of positional encoding to use; one of
                "lin", "sin", "rnd" or "d-asym"

        Raises:
            ValueError: If ``encoding`` is not a supported scheme.
        """
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.encoding = encoding

        if encoding == "sin":
            pe = self._sinusoidal_positioning(embed_dim, max_seq_len)
        elif encoding == "lin":
            pe = self._linear_encoding(embed_dim, max_seq_len)
        elif encoding == "d-asym":
            pe = self._d_asym_encoding(embed_dim, max_seq_len)
        elif encoding == "rnd":
            # Fixed seed: the "random" table is identical on every construction.
            key = jax.random.PRNGKey(0)
            pe = jax.random.uniform(key, (max_seq_len, embed_dim))
        else:
            # Without this branch `pe` would be unbound and fail with an
            # UnboundLocalError far from the real cause.
            raise ValueError(
                f"Unknown positional encoding {encoding!r}; "
                'expected one of "lin", "sin", "rnd", "d-asym"'
            )

        # Add batch dimension: [1, max_seq_len, embed_dim]
        # NOTE(review): stored as a plain array attribute rather than an
        # nnx.Variable; confirm the installed Flax version accepts this and that
        # the table is intended to be non-trainable.
        self.pe = pe[None, :, :]

    def _sinusoidal_positioning(self, embed_dim: int, max_len: int):
        """Return a [max_len, embed_dim] table with sin on even and cos on odd columns."""
        positions = jnp.arange(max_len, dtype=jnp.float32)[:, None]

        # One frequency per even column index 0, 2, 4, ...
        div_term = jnp.exp(
            jnp.arange(0, embed_dim, 2, dtype=jnp.float32)
            * (-np.log(10000.0) / embed_dim)
        )
        angles = positions * div_term  # [max_len, ceil(embed_dim / 2)]

        pe = jnp.zeros((max_len, embed_dim))
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        # For odd embed_dim there is one fewer odd column than even column, so
        # the cosine half uses only the first embed_dim // 2 frequencies.
        pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : embed_dim // 2]))

        return pe

    def _d_asym_encoding(self, embed_dim: int, max_len: int):
        """Return a [max_len, embed_dim] table from a single linspace(-1, 1) ramp.

        Entry ``[p, d]`` equals element ``d * max_len + p`` of the ramp, so values
        increase along positions within each dimension and each dimension is
        offset from the previous one.
        """
        seq = jnp.linspace(-1, 1, num=embed_dim * max_len)
        pos = seq.reshape([embed_dim, max_len]).T
        return pos

    def _linear_encoding(self, embed_dim: int, max_len: int):
        """Return a [max_len, embed_dim] table whose row p is linspace(0, 1)[p] in every column."""
        seq = jnp.linspace(0, 1, num=max_len)
        pos = jnp.broadcast_to(seq[:, None], (max_len, embed_dim))
        return pos

    def __call__(self, x):
        """
        Add the positional table to ``x``.

        Args:
            x: Array whose axis 1 is the sequence axis and whose last axis has
                size ``embed_dim`` (expected [batch, seq_len, embed_dim]).

        Returns:
            ``x`` plus the first ``seq_len`` rows of the positional table.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        return x + self.pe[:, :seq_len, :]


class FeedForward(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        ff_dim: int,
        dropout: float = 0.1,
        bias: bool = False,
        *,
        rngs: nnx.Rngs,
        activation: Activation = "tanh",
    ):
        """
        Feed-forward network used in Transformer blocks:
        Linear(embed_dim -> ff_dim) -> activation -> dropout -> Linear(ff_dim -> embed_dim).

        Args:
            embed_dim: Input (and output) embedding dimension
            ff_dim: Hidden dimension of feed-forward network
            dropout: Dropout rate applied after the activation
            bias: Whether to use bias in linear layers
            rngs: ``nnx.Rngs`` used to initialize the layers and drive dropout
            activation: Hidden-layer nonlinearity, "tanh" or "relu"

        Raises:
            ValueError: If ``activation`` is not "tanh" or "relu".
        """
        # Previously any unrecognized string silently fell back to relu.
        if activation not in ("tanh", "relu"):
            raise ValueError(
                f'Unknown activation {activation!r}; expected "tanh" or "relu"'
            )
        self.activation = activation
        self.linear1 = nnx.Linear(
            in_features=embed_dim,
            out_features=ff_dim,
            use_bias=bias,
            rngs=rngs,
        )
        self.linear2 = nnx.Linear(
            in_features=ff_dim,
            out_features=embed_dim,
            use_bias=bias,
            rngs=rngs,
        )
        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

    def __call__(self, x, deterministic=False):
        """
        Forward pass through feed-forward network.

        Args:
            x: Input tensor whose last axis has size ``embed_dim``
            deterministic: If True, dropout is skipped; if False, dropout is
                applied (training mode). Note the default here is False, unlike
                MultiHeadAttention / TransformerBlock, which default to True.

        Returns:
            Tensor with the same leading axes as ``x`` and last axis ``embed_dim``
        """
        x = self.linear1(x)
        if self.activation == "tanh":
            x = jnp.tanh(x)
        else:
            x = nnx.relu(x)

        if not deterministic:
            x = self.dropout(x)
        return self.linear2(x)


class MultiHeadAttention(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        bias: bool = False,
        *,
        rngs: nnx.Rngs,
    ):
        """
        Multi-head scaled dot-product attention.

        Args:
            embed_dim: Input embedding dimension; split evenly across heads
            num_heads: Number of attention heads
            dropout: Dropout rate applied to the attention weights
            bias: Whether to use bias in the Q/K/V and output projections
            rngs: ``nnx.Rngs`` used to initialize the projections and drive dropout

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for Q, K, V
        self.query = nnx.Linear(
            in_features=embed_dim,
            out_features=embed_dim,
            use_bias=bias,
            rngs=rngs,
        )
        self.key = nnx.Linear(
            in_features=embed_dim,
            out_features=embed_dim,
            use_bias=bias,
            rngs=rngs,
        )
        self.value = nnx.Linear(
            in_features=embed_dim,
            out_features=embed_dim,
            use_bias=bias,
            rngs=rngs,
        )
        self.out_proj = nnx.Linear(
            in_features=embed_dim,
            out_features=embed_dim,
            use_bias=bias,
            rngs=rngs,
        )

        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

    def __call__(self, query, key, value, mask=None, deterministic: bool = True):
        """
        Forward pass through multi-head attention.

        Args:
            query: Query tensor, expected [batch, q_len, embed_dim]
            key: Key tensor, expected [batch, kv_len, embed_dim]
            value: Value tensor, expected [batch, kv_len, embed_dim]
            mask: Optional mask broadcastable to the score shape
                [batch, num_heads, q_len, kv_len]; entries equal to 0 are
                masked out before the softmax
            deterministic: If True (default), dropout on the attention weights
                is skipped; pass False during training

        Returns:
            Output after attention, [batch, q_len, embed_dim]
        """
        batch_size = query.shape[0]

        # Linear projections, then split features into heads:
        # [batch, len, embed_dim] -> [batch, num_heads, len, head_dim]
        shape_dyn = (batch_size, -1, self.num_heads, self.head_dim)
        trans_indices = (0, 2, 1, 3)
        q = self.query(query).reshape(*shape_dyn)
        q = jnp.transpose(q, trans_indices)

        k = self.key(key).reshape(*shape_dyn)
        k = jnp.transpose(k, trans_indices)

        v = self.value(value).reshape(*shape_dyn)
        v = jnp.transpose(v, trans_indices)

        # Divide by a Python float (weakly typed in JAX) so the scores keep the
        # dtype of q/k; the float64 NumPy scalar returned by np.sqrt would
        # otherwise take part in type promotion.
        scale = float(np.sqrt(self.head_dim))
        scores = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / scale

        if mask is not None:
            # Most negative finite value of the score dtype: a literal -1e9
            # overflows to -inf in float16, and a fully masked row would then
            # yield NaN from the softmax.
            fill = jnp.finfo(scores.dtype).min
            scores = jnp.where(mask == 0, fill, scores)

        attn_weights = jax.nn.softmax(scores, axis=-1)

        if not deterministic:
            attn_weights = self.dropout(attn_weights)

        attn_output = jnp.matmul(attn_weights, v)

        # (0, 2, 1, 3) is its own inverse: back to [batch, q_len, num_heads, head_dim]
        attn_output = jnp.transpose(attn_output, trans_indices)
        attn_output = attn_output.reshape(batch_size, -1, self.embed_dim)

        return self.out_proj(attn_output)


class TransformerBlock(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout: float = 0.1,
        bias: bool = False,
        *,
        rngs: nnx.Rngs,
    ):
        """
        Post-LN Transformer layer: self-attention and a feed-forward network,
        each wrapped in dropout + residual connection + LayerNorm.

        Args:
            embed_dim: Width of the token representations
            num_heads: Number of attention heads
            ff_dim: Hidden width of the feed-forward sub-layer
            dropout: Dropout rate used by every sub-layer
            bias: Whether the linear layers carry a bias term
            rngs: ``nnx.Rngs`` shared by all sub-modules
        """
        # Keyword arguments common to the attention and feed-forward sub-layers.
        shared = dict(embed_dim=embed_dim, dropout=dropout, bias=bias, rngs=rngs)

        # All sub-modules share `rngs`; keep this construction order, since it
        # presumably determines which keys each one receives.
        self.attention = MultiHeadAttention(num_heads=num_heads, **shared)
        self.norm1 = nnx.LayerNorm(num_features=embed_dim, rngs=rngs)
        self.feed_forward = FeedForward(ff_dim=ff_dim, **shared)
        self.norm2 = nnx.LayerNorm(num_features=embed_dim, rngs=rngs)
        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

    def __call__(self, x, mask=None, deterministic: bool = True):
        """
        Run one post-LN Transformer layer.

        Args:
            x: Input representations
            mask: Optional attention mask forwarded to self-attention
            deterministic: If True (default), every dropout is skipped; pass
                False during training

        Returns:
            Tensor with the same shape as ``x``
        """
        # Sub-layer 1: self-attention -> dropout -> residual add -> LayerNorm.
        attended = self.attention(x, x, x, mask, deterministic=deterministic)
        attended = attended if deterministic else self.dropout(attended)
        hidden = self.norm1(x + attended)

        # Sub-layer 2: feed-forward -> dropout -> residual add -> LayerNorm.
        projected = self.feed_forward(hidden, deterministic=deterministic)
        projected = projected if deterministic else self.dropout(projected)
        return self.norm2(hidden + projected)


class TransformerModel(nnx.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        num_layers: int,
        dropout: float = 0.1,
        max_seq_len: int = 5000,
        bias: bool = False,
        regression: bool = False,
        encoding: Encoding = "lin",
        *,
        rngs: nnx.Rngs,
    ):
        """
        Complete Transformer model: embedding, positional encoding, a stack of
        post-LN Transformer blocks and a final projection to ``vocab_size``.

        Args:
            vocab_size: Size of vocabulary; also the size of the output
                feature axis
            embed_dim: Embedding dimension
            num_heads: Number of attention heads
            ff_dim: Hidden dimension of feed-forward network
            num_layers: Number of transformer blocks
            dropout: Dropout rate
            max_seq_len: Maximum sequence length for positional encoding
            bias: Whether to use bias in linear layers
            regression: If True, ``__call__`` expects dense inputs of shape
                [batch_size, vocab_size, seq_len] instead of token indices
            encoding: Type of positional encoding
            rngs: ``nnx.Rngs`` used to initialize all sub-modules and drive dropout
        """
        self.regression = regression

        # Token embedding; its table is also reused for regression inputs.
        self.embedding = nnx.Embed(
            num_embeddings=vocab_size, features=embed_dim, rngs=rngs
        )

        # Positional encoding
        self.positional_encoding = PositionalEncoding(
            embed_dim=embed_dim, max_seq_len=max_seq_len, encoding=encoding
        )

        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

        # Create stack of transformer blocks
        # NOTE(review): a plain Python list of modules; newer Flax NNX releases
        # may expect nnx.List here — confirm against the installed Flax version.
        self.transformer_blocks = [
            TransformerBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                dropout=dropout,
                bias=bias,
                rngs=rngs,
            )
            for _ in range(num_layers)
        ]

        # Output projection
        self.output_proj = nnx.Linear(
            in_features=embed_dim, out_features=vocab_size, use_bias=bias, rngs=rngs
        )

    def __call__(self, x, mask=None, deterministic: bool = True):
        """
        Forward pass through transformer model.

        Args:
            x: Token indices of shape [batch_size, seq_len], or, when the model
                was built with ``regression=True``, a dense array of shape
                [batch_size, vocab_size, seq_len] whose per-position
                ``vocab_size`` features are projected through the embedding table
            mask: Optional attention mask forwarded to every transformer block
                (entries equal to 0 are masked out in attention)
            deterministic: If True (default), dropout is skipped; pass False
                during training

        Returns:
            Output of shape [batch_size, seq_len, vocab_size]

        Raises:
            ValueError: In regression mode, if ``x`` is not 3-D or its axis 1
                does not equal ``vocab_size``.
        """
        if self.regression:
            if x.ndim != 3:
                raise ValueError("Input expected to have 3 dimensions for regression")
            table = self.embedding.embedding.value  # [vocab_size, embed_dim]
            if x.shape[1] != table.shape[0]:
                raise ValueError(
                    f"Regression input axis 1 has size {x.shape[1]}, "
                    f"expected vocab_size={table.shape[0]}"
                )
            # [1, embed_dim, vocab_size] @ [batch, vocab_size, seq_len]
            #   -> [batch, embed_dim, seq_len] -> [batch, seq_len, embed_dim]
            x = jnp.matmul(table.T[None], x)
            x = jnp.transpose(x, (0, 2, 1))
        else:
            # Token indices [batch_size, seq_len] -> [batch_size, seq_len, embed_dim]
            x = self.embedding(x)

        # Add positional encoding
        x = self.positional_encoding(x)

        # Apply dropout if not in deterministic mode
        if not deterministic:
            x = self.dropout(x)

        # Pass through transformer blocks
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, mask, deterministic=deterministic)

        # Final projection to vocabulary
        return self.output_proj(x)  # [batch_size, seq_len, vocab_size]
