"""Modules used to build ACEv2 model."""

import math
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.distributions import Categorical, Normal
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention

from src.models.utils import (
    build_mlp_with_linear_skipcon,
    expand_kv_heads,
    positional_encoding_init,
)
from src.utils import DataAttr, LossAttr


class Embedder(torch.nn.Module):
    """
    Project context pairs (xc, yc), buffer pairs (xb, yb) and target inputs xt
    into one shared ``out_dim``-dimensional space.

    x values and y values go through two separate MLPs (``x_embed`` and
    ``y_embed``). Every embedding additionally receives one of three learned
    marker vectors that encodes its role: target (0), context (1), buffer (2).
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        hidden_dim: int,
        out_dim: int,
        depth: int,
        mlp_builder: Callable[..., torch.nn.Module] = build_mlp_with_linear_skipcon,
        pos_emb_init: bool = False,
    ):
        """
        Build the two value MLPs and the role-marker table.

        Args:
            dim_x: Feature dimension of the x inputs.
            dim_y: Feature dimension of the y inputs.
            hidden_dim: Hidden width handed to ``mlp_builder``.
            out_dim: Dimension of the produced embeddings.
            depth: Depth handed to ``mlp_builder``.
            mlp_builder: Factory called as
                ``mlp_builder(in_dim, hidden_dim, out_dim, depth)``.
            pos_emb_init: If True, replace the marker table's weight with the
                result of ``positional_encoding_init``.
        """
        super().__init__()
        self.marker_lookup = {"target": 0, "context": 1, "buffer": 2}
        num_markers = len(self.marker_lookup)

        self.x_embed = mlp_builder(dim_x, hidden_dim, out_dim, depth)
        self.y_embed = mlp_builder(dim_y, hidden_dim, out_dim, depth)
        self.marker_embed = torch.nn.Embedding(num_markers, out_dim)

        if pos_emb_init:
            # NOTE(review): re-assigning a registered parameter only accepts an
            # nn.Parameter (or None), so this presumes positional_encoding_init
            # returns an nn.Parameter -- confirm against src.models.utils.
            self.marker_embed.weight = positional_encoding_init(
                num_markers, out_dim, 2
            )

    def _get_marker_embedding(
        self,
        batch_size: int,
        marker_type: str,
        device: torch.device,
    ) -> torch.Tensor:
        """Return the marker vector for ``marker_type``, one row per batch item.

        The index tensor has shape ``(batch_size, 1)``, so the lookup result
        carries a singleton axis right after the batch axis.
        """
        index = torch.full(
            (batch_size, 1),
            self.marker_lookup[marker_type],
            dtype=torch.long,
            device=device,
        )
        return self.marker_embed(index)

    def _with_marker(self, tokens: torch.Tensor, marker_type: str) -> torch.Tensor:
        """Add the marker vector for ``marker_type`` to ``tokens``."""
        marker = self._get_marker_embedding(
            tokens.size(0), marker_type, tokens.device
        )
        return tokens + marker

    def embed_context(self, batch: DataAttr) -> torch.Tensor:
        """Embed ``batch.xc`` and ``batch.yc`` and tag the sum as context."""
        pair_emb = self.x_embed(batch.xc) + self.y_embed(batch.yc)
        return self._with_marker(pair_emb, "context")

    def embed_buffer(self, batch: DataAttr) -> torch.Tensor:
        """Embed ``batch.xb`` and ``batch.yb`` and tag the sum as buffer."""
        pair_emb = self.x_embed(batch.xb) + self.y_embed(batch.yb)
        return self._with_marker(pair_emb, "buffer")

    def embed_target(self, batch: DataAttr) -> torch.Tensor:
        """Embed ``batch.xt`` (no y value) and tag it as target."""
        return self._with_marker(self.x_embed(batch.xt), "target")


class Head(torch.nn.Module):
    """K independent two-layer MLPs evaluated in parallel on a shared input.

    The K networks are stored as stacked parameters (leading axis = component)
    and applied with ``einsum``. Each network computes
    ``Linear(dim_model -> dim_feedforward) -> GELU -> Linear(-> dim_out)``.
    """

    def __init__(
        self, dim_model: int, dim_feedforward: int, dim_out: int, K: int
    ) -> None:
        """
        Args:
            dim_model: Size of the input features.
            dim_feedforward: Hidden width of every component MLP.
            dim_out: Output size of every component MLP.
            K: Number of independent MLPs.
        """
        super().__init__()
        self.K = K

        # weight shapes -> [K, out, in]
        self.w1 = torch.nn.Parameter(torch.empty(K, dim_feedforward, dim_model))
        self.b1 = torch.nn.Parameter(torch.empty(K, dim_feedforward))
        self.w2 = torch.nn.Parameter(torch.empty(K, dim_out, dim_feedforward))
        self.b2 = torch.nn.Parameter(torch.empty(K, dim_out))

        # Kaiming init for each [out, in] block separately. Initialising the
        # stacked 3-D tensor in one call would make PyTorch compute
        # fan_in = out * in (dim 1 times the trailing dims), which shrinks the
        # weights by sqrt(out) compared with an equivalent nn.Linear.
        with torch.no_grad():
            for stacked in (self.w1, self.w2):
                for k in range(K):
                    torch.nn.init.kaiming_uniform_(stacked[k], a=math.sqrt(5))
        for bias, fan_in in ((self.b1, dim_model), (self.b2, dim_feedforward)):
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all K MLPs to the same input.

        Args:
            x: Tensor of shape ``[..., dim_model]`` (typically ``[B, T, dim_model]``).

        Returns:
            Tensor of shape ``[..., K, dim_out]``.
        """
        # einsum broadcasts x against the component axis, so no explicit
        # expansion to [..., K, dim_model] is required.
        h = torch.einsum("...d,khd->...kh", x, self.w1) + self.b1
        h = F.gelu(h)
        return torch.einsum("...kh,koh->...ko", h, self.w2) + self.b2


class MixtureGaussian(torch.nn.Module):
    """Mixture-of-Gaussians output head for probabilistic predictions.

    A ``Head`` with one MLP per mixture component produces, for every output
    dimension, a raw mean, raw std and raw weight logit. Learnable global
    per-component biases are added before the std is made positive and the
    weights are normalised over the component axis.
    """

    def __init__(
        self,
        dim_y: int,
        dim_model: int,
        dim_feedforward: int,
        num_components: int,
        name: Optional[str] = None,
        trange: Tuple[float, float] = (-1.0, 1.0),
        std_min: float = 1e-3,
        logit_clip: float = 20.0,
    ) -> None:
        """
        Args:
            dim_y: Number of output dimensions.
            dim_model: Size of the input features ``zt``.
            dim_feedforward: Hidden width of the per-component MLPs.
            num_components: Number of mixture components K (must be >= 1).
            name: Optional label stored on the module.
            trange: ``(min, max)`` range over which the initial component
                means are spread evenly.
            std_min: Constant added to every std after the softplus/cap.
            logit_clip: Weight logits are clamped to ``[-logit_clip, logit_clip]``.

        Raises:
            ValueError: If ``num_components`` is smaller than 1.
        """
        super().__init__()
        if num_components < 1:
            raise ValueError(
                f"num_components must be >= 1, got {num_components}"
            )
        self.name = name
        self.dim_model = dim_model
        self.dim_feedforward = dim_feedforward
        self.dim_y = dim_y

        # initialize head
        self.head = Head(
            dim_model=dim_model,
            dim_feedforward=dim_feedforward,
            dim_out=3 * dim_y,
            K=num_components,
        )

        self.num_components = num_components
        min_range, max_range = tuple(trange)

        self.mean_global_bias = torch.nn.Parameter(
            torch.linspace(min_range, max_range, num_components)
        )
        # Half the spacing between neighbouring initial means. With a single
        # component there is no spacing, so fall back to half the full range
        # instead of dividing by zero.
        delta = 0.5 * (max_range - min_range) / max(num_components - 1, 1)
        self.std_global_bias = torch.nn.Parameter(
            torch.ones_like(self.mean_global_bias)
            * self._inverse_softplus(torch.tensor(delta))
        )
        self.weights_global_bias = torch.nn.Parameter(torch.zeros(num_components))

        self.std_min = std_min
        self.logit_clip = logit_clip

    def forward(
        self,
        zt: torch.Tensor,
        yt: Optional[torch.Tensor] = None,
        *,
        loss_mask: Optional[torch.Tensor] = None,
        num_samples: int = 0,
    ) -> "LossAttr":
        """Compute mixture parameters and, optionally, loss and samples.

        The whole computation runs in float32 with autocast disabled,
        regardless of any surrounding mixed-precision context.

        Args:
            zt: Input features, ``[B, T, dim_model]``.
            yt: Optional targets, ``[B, T, dim_y]``. When given, the
                per-dimension log-likelihood and the mean negative
                log-likelihood loss are computed.
            loss_mask: Optional ``[B, T]`` mask multiplied into the
                log-likelihood; the loss is then normalised by the number of
                unmasked elements (at least 1).
            num_samples: If > 0, draw this many samples per position.

        Returns:
            ``LossAttr`` with ``log_likelihood`` and ``loss`` (both ``None``
            without ``yt``), ``means``/``sds``/``weights`` of shape
            ``[B, T, K, dim_y]`` and ``samples`` of shape
            ``[B, T, num_samples, dim_y]`` (``None`` when ``num_samples == 0``).
            With a mask and ``dim_y == 1`` the returned log-likelihood is the
            masked ``[B, T]`` tensor; otherwise it keeps its ``dim_y`` axis.
        """
        # Ensure computations in the head are carried out in float32
        device_type = "cuda" if zt.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            zt32 = zt.float()
            mean, std, weights = self._parameterize(zt32)

            # During inference (no yt) neither quantity is computed.
            log_likelihood = None
            loss = None
            if yt is not None:
                log_likelihood = self._loglikelihood(yt.float(), mean, std, weights)
                if loss_mask is None:
                    denom = log_likelihood.numel()
                elif self.dim_y == 1:
                    log_likelihood = log_likelihood.squeeze(-1) * loss_mask
                    denom = loss_mask.sum().clamp(min=1)
                else:
                    # Broadcast the [B, T] mask over the output dimensions and
                    # normalise per element, matching the unmasked branch.
                    log_likelihood = log_likelihood * loss_mask.unsqueeze(-1)
                    denom = (loss_mask.sum() * self.dim_y).clamp(min=1)
                loss = -log_likelihood.sum() / denom

            samples = None
            if num_samples > 0:
                samples = self._draw_samples(mean, std, weights, num_samples)

        return LossAttr(
            log_likelihood=log_likelihood,
            loss=loss,
            means=mean,
            sds=std,
            weights=weights,
            samples=samples,
        )

    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Draw samples from the predicted mixture.

        Runs in float32 with autocast disabled.

        Args:
            zt: Input features, ``[B, T, dim_model]``.
            num_samples: Number of samples per position.

        Returns:
            Tensor of shape ``[B, T, num_samples, dim_y]``.
        """
        device_type = "cuda" if zt.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            means, stds, weights = self._parameterize(zt.float())
            samples = self._draw_samples(means, stds, weights, num_samples)

        return samples

    def _draw_samples(
        self,
        means: torch.Tensor,
        stds: torch.Tensor,
        weights: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Sample from ``[B, T, K, dim_y]`` parameters -> ``[B, T, S, dim_y]``."""
        B, T = means.shape[:2]
        flat_samples = self._sample_mixture(
            self._flat(means),
            self._flat(stds),
            self._flat(weights),
            num_sample=num_samples,
        )
        # [S, B*T, dim_y] -> [B*T, S, dim_y] -> [B, T, S, dim_y]
        return flat_samples.permute(1, 0, 2).reshape(B, T, num_samples, self.dim_y)

    def _flat(self, x: torch.Tensor) -> torch.Tensor:
        """Merge the leading axes: ``[B, T, K, dim_y]`` -> ``[B*T, K, dim_y]``."""
        return x.reshape(-1, self.num_components, self.dim_y)

    def add_bias(
        self, raw_mean: torch.Tensor, raw_std: torch.Tensor, raw_weights: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Turn raw head outputs into valid mixture parameters.

        Adds the global per-component biases, maps the std through softplus
        (capped at 2.0, then shifted by ``std_min``) and normalises the clamped
        weight logits over the component axis.

        Args:
            raw_mean, raw_std, raw_weights: Tensors of shape ``[B, T, K, dim_y]``.

        Returns:
            ``(mean, std, weights)``, each ``[B, T, K, dim_y]``.
        """
        # Reshape biases to [1, 1, K, 1] for correct broadcasting with [B, T, K, dim_y]
        mean_bias_reshaped = self.mean_global_bias[None, None, :, None]
        std_bias_reshaped = self.std_global_bias[None, None, :, None]
        weights_bias_reshaped = self.weights_global_bias[None, None, :, None]

        mean = raw_mean + mean_bias_reshaped
        # Cap std for stability without creating device-mismatched tensors
        std = F.softplus(raw_std + std_bias_reshaped).clamp_max(2.0) + self.std_min
        # Clamp logits before normalization to avoid extreme softmax
        weights_logits = raw_weights + weights_bias_reshaped
        weights_logits = weights_logits.clamp(-self.logit_clip, self.logit_clip)
        # Use log_softmax for stability, then exponentiate when probabilities are needed
        log_weights = F.log_softmax(weights_logits, dim=2)
        weights = log_weights.exp()
        return mean, std, weights

    def _parameterize(
        self, zt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the head on ``zt`` and convert its output to mixture parameters."""
        B, T, _ = zt.shape
        # Head output [B, T, K, 3*dim_y] -> [B, T, K, 3, dim_y] -> [B, T, 3, K, dim_y]
        outputs = self.head(zt).view(B, T, self.num_components, 3, self.dim_y)
        outputs = outputs.permute(0, 1, 3, 2, 4)
        means, stds, weights = outputs.unbind(dim=2)
        return self.add_bias(means, stds, weights)

    @staticmethod
    def _loglikelihood(
        value: torch.Tensor,
        means: torch.Tensor,
        stds: torch.Tensor,
        weights: torch.Tensor,
    ) -> torch.Tensor:
        """Log-likelihood of ``value`` under the mixture, per output dimension.

        Args:
            value: ``[B, T, dim_y]``.
            means, stds, weights: ``[B, T, K, dim_y]``.

        Returns:
            Tensor of shape ``[B, T, dim_y]``.
        """
        # Unsqueeze value at dim 2 (component dimension) for broadcasting
        value = value.unsqueeze(2)  # Shape: [B, T, 1, dim_y]
        # [B, T, 1, dim_y] can broadcast with means/stds [B, T, K, dim_y]
        components = Normal(means, stds, validate_args=False)
        # log_prob result shape: [B, T, K, dim_y]; weights are floored so the
        # log never sees an exact zero.
        log_probs = components.log_prob(value) + torch.log(weights.clamp(min=1e-12))
        # logsumexp over K dimension (dim=2) -> [B, T, dim_y]
        return torch.logsumexp(log_probs, dim=2)

    @staticmethod
    def _inverse_softplus(y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` such that ``softplus(x) == y`` (``log(exp(y) - 1)``)."""
        return torch.log(torch.special.expm1(y))

    @staticmethod
    def _sample_mixture(
        means: torch.Tensor,
        stds: torch.Tensor,
        weights: torch.Tensor,
        num_sample: int = 1000,
    ) -> torch.Tensor:
        """Sample from the mixture, independently per output dimension.

        Args:
            means, stds, weights: ``[N, K, dim_y]`` with weights normalised
                over K for every output dimension.
            num_sample: Number of draws S.

        Returns:
            Tensor of shape ``[S, N, dim_y]``.
        """
        # Categorical treats the LAST axis as the event axis, so move K there:
        # [N, K, dim_y] -> [N, dim_y, K]. (A plain squeeze(-1) is only valid
        # for dim_y == 1, where both forms coincide.)
        mixture = Categorical(weights.transpose(1, 2))

        # component_indices shape: [S, N, dim_y]
        component_indices = mixture.sample((num_sample,))

        # Insert the component axis so the indices can gather along dim 2.
        gather_index = component_indices.unsqueeze(2)  # [S, N, 1, dim_y]

        # means/stds shape: [N, K, dim_y] -> [S, N, K, dim_y] (broadcast view)
        means_expanded = means.unsqueeze(0).expand(num_sample, -1, -1, -1)
        stds_expanded = stds.unsqueeze(0).expand(num_sample, -1, -1, -1)

        chosen_means = torch.gather(means_expanded, 2, gather_index).squeeze(2)
        chosen_stds = torch.gather(stds_expanded, 2, gather_index).squeeze(2)

        # Sample from the chosen components. Shape: [S, N, dim_y]
        return Normal(chosen_means, chosen_stds).sample()


class MultiChannelMixtureGaussian(torch.nn.Module):
    """
    Gaussian mixture model for multi-channel outputs.

    Each output channel has its own mixture parameters, but they share
    the same input features. The log-likelihood is computed independently
    per channel then summed.

    Args:
        dim_y: Number of output channels (e.g., 7 for EEG)
        dim_model: Input feature dimension
        dim_feedforward: Hidden dimension for the head
        num_components: Number of mixture components per channel
        std_min: Minimum standard deviation for numerical stability
        logit_clip: Weight logits are clamped to ``[-logit_clip, logit_clip]``
            before normalisation
    """

    def __init__(
        self,
        dim_y: int,
        dim_model: int,
        dim_feedforward: int,
        num_components: int,
        std_min: float = 1e-3,
        logit_clip: float = 20.0,
    ):
        super().__init__()
        self.dim_y = dim_y
        self.dim_model = dim_model
        self.dim_feedforward = dim_feedforward
        self.num_components = num_components
        self.std_min = std_min
        self.logit_clip = logit_clip

        # One shared MLP emits (mean, raw std, weight logit) for every
        # component and channel: K * dim_y * 3 values per position.
        self.head = torch.nn.Sequential(
            torch.nn.Linear(dim_model, dim_feedforward),
            torch.nn.GELU(),
            torch.nn.Linear(dim_feedforward, num_components * dim_y * 3),
        )

        # Initialize the final layer with small values
        torch.nn.init.normal_(self.head[-1].weight, mean=0.0, std=0.02)
        torch.nn.init.zeros_(self.head[-1].bias)

    def _parameterize(
        self, zt32: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Map float32 features to mixture parameters.

        Args:
            zt32: ``(B, T, dim_model)`` float32 features.

        Returns:
            ``(means, stds, log_weights, weights)``, each ``(B, T, K, D)``.
        """
        B, T, _ = zt32.shape
        raw_params = self.head(zt32).view(B, T, self.num_components, self.dim_y, 3)

        means = raw_params[..., 0]  # used as-is
        # softplus keeps the std positive; cap at 2, then add the floor.
        stds = F.softplus(raw_params[..., 1]).clamp_max(2.0) + self.std_min
        # Clamp logits, then normalise over the component axis (dim=2).
        logits = raw_params[..., 2].clamp(-self.logit_clip, self.logit_clip)
        log_weights = F.log_softmax(logits, dim=2)
        return means, stds, log_weights, log_weights.exp()

    def forward(
        self,
        zt: torch.Tensor,  # (B, T, dim_model)
        yt: Optional[torch.Tensor] = None,  # (B, T, dim_y)
        loss_mask: Optional[torch.Tensor] = None,  # (B, T)
        num_samples: int = 0,
    ) -> "LossAttr":
        """Compute mixture parameters and, optionally, loss and samples.

        Forces the computation to float32 (autocast disabled) for stability.

        Args:
            zt: Input features, ``(B, T, dim_model)``.
            yt: Optional targets, ``(B, T, dim_y)``.
            loss_mask: Optional ``(B, T)`` mask applied to the per-position
                log-probability when computing the loss.
            num_samples: If > 0, draw this many samples per position.

        Returns:
            ``LossAttr`` with ``log_likelihood`` (unmasked, ``(B, T)``, summed
            over channels), ``loss`` (negative masked sum divided by the mask
            count, or by ``B*T`` without mask), ``means``/``sds``/``weights``
            of shape ``(B, T, K, D)`` and ``samples`` of shape
            ``(B, T, num_samples, D)``. ``log_likelihood``/``loss`` are ``None``
            without ``yt``; ``samples`` is ``None`` when ``num_samples == 0``.
        """
        device_type = "cuda" if zt.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            means, stds, log_weights, weights = self._parameterize(zt.float())

            loss = None
            log_likelihood = None
            if yt is not None:
                # (B, T, 1, D) broadcasts against the (B, T, K, D) parameters.
                diff = yt.float().unsqueeze(2) - means
                variance = stds**2
                log_probs = -0.5 * (
                    torch.log(2 * torch.pi * variance) + (diff**2) / variance
                )  # (B, T, K, D)

                # Mix the components per channel, then sum the channels.
                channel_log_probs = torch.logsumexp(
                    log_probs + log_weights, dim=2
                )  # (B, T, D)
                total_log_prob = channel_log_probs.sum(dim=-1)  # (B, T)

                if loss_mask is not None:
                    masked_log_prob = total_log_prob * loss_mask
                    denom = loss_mask.sum().clamp(min=1)
                else:
                    masked_log_prob = total_log_prob
                    denom = total_log_prob.numel()

                loss = -masked_log_prob.sum() / denom
                log_likelihood = total_log_prob

            samples = None
            if num_samples > 0:
                samples = self._sample(means, stds, weights, num_samples)

        return LossAttr(
            log_likelihood=log_likelihood,
            loss=loss,
            means=means,  # (B, T, K, D)
            sds=stds,  # (B, T, K, D)
            weights=weights,  # (B, T, K, D)
            samples=samples,
        )

    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Generate samples from the multi-channel mixture Gaussian.

        Computes sampling in float32 (autocast disabled) for stability.

        Args:
            zt: Input features, ``(B, T, dim_model)``.
            num_samples: Number of samples per position.

        Returns:
            Tensor of shape ``(B, T, num_samples, D)``.
        """
        device_type = "cuda" if zt.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            means, stds, _, weights = self._parameterize(zt.float())
            return self._sample(means, stds, weights, num_samples)

    def _sample(
        self,
        means: torch.Tensor,  # (B, T, K, D)
        stds: torch.Tensor,  # (B, T, K, D)
        weights: torch.Tensor,  # (B, T, K, D)
        num_samples: int,
    ) -> torch.Tensor:
        """Sample every channel independently from its own mixture.

        Returns:
            Tensor of shape ``(B, T, num_samples, D)``.
        """
        B, T, K, D = means.shape
        rows = B * T * D

        # Bring the component axis last BEFORE flattening so that every row
        # holds the K values of one (batch, position, channel) triple. All
        # three tensors must use the same layout; flattening (B, T, K, D)
        # directly would interleave components and channels.
        weights_flat = weights.permute(0, 1, 3, 2).reshape(rows, K)
        means_flat = means.permute(0, 1, 3, 2).reshape(rows, K)
        stds_flat = stds.permute(0, 1, 3, 2).reshape(rows, K)

        # Component index for every draw: (B*T*D, num_samples)
        component_idx = torch.multinomial(
            weights_flat, num_samples, replacement=True
        )

        selected_means = torch.gather(means_flat, 1, component_idx).view(
            B, T, D, num_samples
        )
        selected_stds = torch.gather(stds_flat, 1, component_idx).view(
            B, T, D, num_samples
        )

        # Draw from the selected Gaussians.
        eps = torch.randn_like(selected_means)
        samples = selected_means + selected_stds * eps

        # (B, T, D, S) -> (B, T, S, D)
        return samples.permute(0, 1, 3, 2)


class MultiheadAttention(torch.nn.Module):
    """Multi-head attention with flexible key-value head configuration."""

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        out_features: Optional[int] = None,
        key_features: Optional[int] = None,
        value_features: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features or in_features
        self.num_heads = num_heads
        self.head_dim = head_dim or in_features // num_heads
        self.num_kv_heads = num_kv_heads or num_heads

        self.q_proj = torch.nn.Linear(
            in_features, num_heads * self.head_dim, bias=False
        )
        self.k_proj = torch.nn.Linear(
            key_features or in_features, self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj = torch.nn.Linear(
            value_features or in_features, self.num_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = torch.nn.Linear(
            num_heads * self.head_dim, self.out_features, bias=False
        )

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_mask: BlockMask
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass with flex attention."""
        B = q.size(0)
        qh = (
            self.q_proj(q)
            .view(B, q.size(1), self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        kh = (
            self.k_proj(k)
            .view(B, k.size(1), self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )
        vh = (
            self.v_proj(v)
            .view(B, v.size(1), self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )
        kh = expand_kv_heads(kh, self.num_heads // self.num_kv_heads)
        vh = expand_kv_heads(vh, self.num_heads // self.num_kv_heads)

        out = flex_attention(qh, kh, vh, block_mask=block_mask)
        out = out.transpose(1, 2).reshape(B, q.size(1), self.num_heads * self.head_dim)
        return self.o_proj(out), (kh, vh)


class TransformerLayer(torch.nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU MLP.

    Both sub-blocks normalise their input with LayerNorm first and add their
    result back onto the residual stream.
    """

    def __init__(
        self,
        dim_model: int,
        num_head: int,
        *,
        dim_feedforward: int = 512,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        **mha_kw: Any,
    ) -> None:
        """
        Args:
            dim_model: Width of the residual stream.
            num_head: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward MLP.
            dropout: Dropout probability for the attention output and for the
                MLP's hidden activations.
            layer_norm_eps: Epsilon of both LayerNorms.
            **mha_kw: Extra keyword arguments forwarded to ``MultiheadAttention``.
        """
        super().__init__()
        self.attn = MultiheadAttention(dim_model, num_head, **mha_kw)
        self.drop_attn = torch.nn.Dropout(dropout)

        self.ff1 = torch.nn.Linear(dim_model, dim_feedforward)
        self.ff2 = torch.nn.Linear(dim_feedforward, dim_model)
        self.drop_ff = torch.nn.Dropout(dropout)

        self.norm1 = torch.nn.LayerNorm(dim_model, eps=layer_norm_eps)
        self.norm2 = torch.nn.LayerNorm(dim_model, eps=layer_norm_eps)

    def forward(
        self, x: torch.Tensor, block_mask: BlockMask
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the attention and feed-forward sub-blocks.

        Args:
            x: Input sequence on the residual stream.
            block_mask: Mask forwarded to the attention module.

        Returns:
            ``(output, (keys, values))`` with the keys/values reported by the
            attention module.
        """
        # Self-attention: queries, keys and values are the same normalised input.
        normed = self.norm1(x)
        attn_out, (keys, values) = self.attn(normed, normed, normed, block_mask)
        x = x + self.drop_attn(attn_out)

        # Feed-forward: Linear -> GELU -> dropout -> Linear, plus residual.
        hidden = F.gelu(self.ff1(self.norm2(x)))
        x = x + self.ff2(self.drop_ff(hidden))
        return x, (keys, values)


class Transformer(torch.nn.Module):
    """Multi-layer transformer encoder with optional gradient checkpointing."""

    def __init__(
        self,
        num_layers: int,
        dim_model: int,
        num_head: int,
        *,
        dim_feedforward: int = 512,
        dropout: float = 0.0,
        gradient_checkpointing: bool = False,
        **mha_kw: Any,
    ) -> None:
        super().__init__()
        self.dim_model = dim_model
        self.layers = torch.nn.ModuleList(
            TransformerLayer(
                dim_model,
                num_head,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                **mha_kw,
            )
            for _ in range(num_layers)
        )
        self.norm = torch.nn.LayerNorm(dim_model)
        self.grad_ckpt = gradient_checkpointing

    def forward(
        self, x: torch.Tensor, block_mask: BlockMask
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass through all layers with optional gradient checkpointing."""
        kv_cache = []
        for lyr in self.layers:
            if self.grad_ckpt and self.training:
                x = torch.utils.checkpoint.checkpoint(lyr, x, block_mask)
            else:
                x, kv = lyr(x, block_mask)
                kv_cache.append(kv)
        return self.norm(x), kv_cache
