"""
Minimal Language Model for mad-lab framework.

A streamlined implementation with:
- Configurable embed_dim (arbitrary embedding dimension)
- Simple MLP encoder (embed_dim → dim) with ReLU
- GaussMixer for direct layer stacking (no norm/residuals)
- Simple decoder MLP (dim → vocab_size) with ReLU

Architecture:
    Input IDs → Embedding(embed_dim) → Encoder MLP → GaussMixer →
    Decoder MLP → Linear(vocab_size) → Logits
"""

import typing as tp
import torch
import torch.nn as nn

from mad.model.layers.mlp_in_out import MLPInOut
from mad.model.layers.ops.norm.rmsnorm import RMSNorm


class LanguageModelMinimal(nn.Module):
    """
    Minimal language model with a configurable embedding dimension.

    Pipeline: token embedding (embed_dim) → encoder MLP (embed_dim → dim) →
    stacked layers → decoder MLP (dim → dim) → linear unembedding
    (dim → vocab_size).

    Each stacked layer's output is added back to its input (residual
    connection); when ``use_norm`` is True the layer is preceded by
    ``RMSNorm(dim)``. Both MLPs are ``MLPInOut`` modules configured with
    ``activation='gelu'``, zero dropout and bias enabled.
    """

    def __init__(
        self,
        vocab_size: int,
        layers: list,
        layer_cfgs: list,
        dim: int = 128,
        embed_dim: int = 16,
        encoder_hidden_units: tp.Optional[tp.List[int]] = None,
        decoder_hidden_units: tp.Optional[tp.List[int]] = None,
        max_length: int = 1280,
        norm: tp.Optional[nn.Module] = None,
        position_embeds: tp.Optional[tp.Callable] = None,
        embed_drop_rate: float = 0.0,
        use_norm: bool = True,
        **kwargs
    ):
        """
        Initialize minimal language model.

        Args:
            vocab_size: Vocabulary size
            layers: List of layer constructors (e.g. the GaussBlock class);
                each one is called as ``layer(**cfg)`` to build the layer
            layer_cfgs: List of layer config dicts, paired with ``layers``
                by position. If the first config contains 'embed_dim',
                'encoder_hidden_units' or 'decoder_hidden_units', those
                values override the corresponding constructor arguments.
            dim: Working dimension; written into every layer config as 'dim'
            embed_dim: Token embedding dimension (default: 16)
            encoder_hidden_units: Hidden layer sizes for encoder MLP (default: [120])
            decoder_hidden_units: Hidden layer sizes for decoder MLP (default: [120])
            max_length: Maximum sequence length (stored on the instance; it
                is stripped from the layer configs, not forwarded to layers)
            norm: Normalization module (unused, kept for compatibility)
            position_embeds: Optional callable applied to the embedded tokens;
                its result is added to them in ``forward``
            embed_drop_rate: Dropout rate after embeddings (no dropout module
                is created unless this is > 0)
            use_norm: If True, apply RMSNorm before each layer (default: True)
            **kwargs: Additional arguments (ignored)
        """
        super().__init__()

        self.use_norm = use_norm

        # Model-level settings may ride along in the first layer config; when
        # present there they take precedence over the constructor arguments.
        first_cfg = layer_cfgs[0] if layer_cfgs else {}
        if 'embed_dim' in first_cfg:
            embed_dim = first_cfg['embed_dim']
        if 'encoder_hidden_units' in first_cfg:
            encoder_hidden_units = first_cfg['encoder_hidden_units']
        if 'decoder_hidden_units' in first_cfg:
            decoder_hidden_units = first_cfg['decoder_hidden_units']

        # Default hidden units
        if encoder_hidden_units is None:
            encoder_hidden_units = [120]
        if decoder_hidden_units is None:
            decoder_hidden_units = [120]

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.dim = dim
        self.max_length = max_length
        self.position_embeds = position_embeds

        # 1. Token embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # 2. Optional embedding dropout
        self.embed_dropout = nn.Dropout(embed_drop_rate) if embed_drop_rate > 0 else None

        # 3. Encoder MLP: embed_dim → [hidden_units] → dim
        self.encoder_mlp = MLPInOut(
            dim_in=embed_dim,
            dim_out=dim,
            hidden_units=encoder_hidden_units,
            activation='gelu',
            dropout=0.0,
            bias=True
        )

        # 4. Layer stack: optional RMSNorm in front of each layer; the
        #    residual connection itself is applied in forward().
        # Keys consumed by this model rather than by the individual layers.
        model_level_keys = (
            'embed_dim', 'encoder_hidden_units', 'decoder_hidden_units', 'max_length'
        )
        self.model = nn.ModuleList([])
        # NOTE: zip() stops at the shorter of the two lists, so surplus
        # layers or configs are silently ignored.
        for layer, layer_cfg in zip(layers, layer_cfgs):
            # Work on a (shallow) copy so the caller's config is not mutated.
            cfg = layer_cfg.copy()

            # Every layer receives the working dimension under the 'dim' key.
            cfg['dim'] = dim

            # Remove MLP-related parameters (not needed by blocks)
            for key in model_level_keys:
                cfg.pop(key, None)

            # Create layer: optionally RMSNorm → Layer
            if self.use_norm:
                self.model.append(nn.Sequential(RMSNorm(dim), layer(**cfg)))
            else:
                self.model.append(layer(**cfg))

        # 5. Decoder MLP: dim → [hidden_units] → dim
        self.decoder_mlp = MLPInOut(
            dim_in=dim,
            dim_out=dim,
            hidden_units=decoder_hidden_units,
            activation='gelu',
            dropout=0.0,
            bias=True
        )

        # 6. Unembedding: dim → vocab_size
        self.unembed = nn.Linear(dim, vocab_size)

        # Run the custom initializer over every submodule (and self).
        self.apply(self._init_weights)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through minimal language model.

        Args:
            input_ids: Input token IDs, shape (batch_size, seq_len)

        Returns:
            logits: Output logits, shape (batch_size, seq_len, vocab_size)
        """
        # 1. Embed tokens
        x = self.embedding(input_ids)  # (B, L, embed_dim)

        # 2. Optional positional embeddings
        # NOTE(review): the callable receives the embedded tensor itself (not
        # position ids) and its result is added to x — confirm this matches
        # what callers pass as position_embeds.
        if self.position_embeds is not None:
            x = x + self.position_embeds(x)

        # 3. Optional embedding dropout
        if self.embed_dropout is not None:
            x = self.embed_dropout(x)

        # 4. Encode to working dimension
        x = self.encoder_mlp(x)  # (B, L, dim)

        # 5. Process through layers with residual connections
        for layer in self.model:
            x = x + layer(x)

        # 6. Decode
        x = self.decoder_mlp(x)  # (B, L, dim)

        # 7. Project to vocabulary
        logits = self.unembed(x)  # (B, L, vocab_size)

        return logits

    def _init_weights(self, m, initializer_range=0.02) -> None:
        """
        Per-module initializer, intended for use with ``nn.Module.apply``.

        - ``nn.Linear``: zero the bias, unless the module has no bias or the
          bias carries a truthy ``_no_reinit`` attribute. Weights are left
          untouched.
        - ``nn.Embedding``: redraw the weights from a normal distribution
          with mean 0 and std ``initializer_range``.
        - Any other module type is left unchanged.

        Args:
            m: Module to initialize
            initializer_range: Standard deviation for embedding weights
        """
        if isinstance(m, nn.Linear):
            # Biases flagged with `_no_reinit` keep whatever init they have.
            if m.bias is not None and not getattr(m.bias, "_no_reinit", False):
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0, std=initializer_range)
