# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


import logging
import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function

from torchtitan.models.attention import build_attention, init_attention_mask
from torchtitan.protocols.train_spec import ModelProtocol

from .args import TransformerModelArgs
from .triton_attention import triton_attention_with_bias

logger = logging.getLogger(__name__)


@torch.no_grad()
def qk_clip_rescale_(attention_module, gamma_h: torch.Tensor, alpha: float = 0.5):
    """
    Rescale the query/key projection weights of an attention module in place.

    The per-head factor ``gamma_h[h]`` is split between the two projections:
    the rows of ``wq`` belonging to query head ``h`` are multiplied by
    ``gamma_h[h] ** alpha`` and the rows of ``wk`` by
    ``gamma_h[...] ** (1 - alpha)``.

    With grouped-query attention (``n_kv_heads < n_heads``) one key head
    serves ``n_heads // n_kv_heads`` query heads; its rows are scaled with the
    key factor of the *first* query head of that group.

    Args:
        attention_module: Object exposing ``n_heads``, ``n_kv_heads``,
            ``head_dim`` and the ``wq`` / ``wk`` linear layers.
        gamma_h: Per-head scaling factors, indexed by query head (n_heads,).
        alpha: Share of the exponent given to the query side
            (0.5 = equal split between queries and keys).
    """
    head_dim = attention_module.head_dim
    num_q_heads = attention_module.n_heads
    num_kv_heads = attention_module.n_kv_heads

    query_scale = gamma_h.pow(alpha)
    key_scale = gamma_h.pow(1.0 - alpha)

    # wq.weight is [n_heads * head_dim, hidden_dim]: each head owns one
    # contiguous band of head_dim rows.
    wq_rows = attention_module.wq.weight.data
    for q_head in range(num_q_heads):
        band = slice(q_head * head_dim, (q_head + 1) * head_dim)
        wq_rows[band, :] *= query_scale[q_head]

    # wk.weight is [n_kv_heads * head_dim, hidden_dim], laid out the same way.
    wk_rows = attention_module.wk.weight.data
    for kv_head in range(num_kv_heads):
        # First query head of the group that shares this key head.
        first_q_head = kv_head * (num_q_heads // num_kv_heads)
        band = slice(kv_head * head_dim, (kv_head + 1) * head_dim)
        wk_rows[band, :] *= key_scale[first_q_head]


class CompressionEnergyTracker(Function):
    """
    Custom autograd function to track energy loss during gradient compression.

    In the forward pass, this performs decompression.
    In the backward pass, it projects the incoming gradient back to the
    compressed space and records, on ``transformer_block``, the fraction of
    squared gradient norm that the projection discards.
    """

    @staticmethod
    def forward(
        ctx,
        input_tensor,
        rcv,
        fixed_tok_embeddings,
        transformer_block,
        parent_model=None,
    ):
        """
        Forward pass: perform decompression and store context for backward pass.

        Args:
            input_tensor: Compressed input of shape [batch, seq, compressed_dim + 1];
                the last channel carries the token indices as values.
            rcv: Right compression vector (matrix) of shape
                [hidden_dim, compressed_dim].
            fixed_tok_embeddings: Callable mapping integer token indices to
                embeddings of size hidden_dim.
            transformer_block: Object whose ``current_compression_energy_loss``
                and ``compression_measurements_count`` are updated in backward.
            parent_model: Optional parent Transformer model, used in backward to
                store the uncompressed upstream gradient.

        Returns:
            Tensor of shape [batch, seq, hidden_dim + 1]: the decompressed
            activations with the token-index channel appended unchanged.
        """
        # Only non-differentiable context is needed in backward; no tensors are
        # saved, so neither the input nor the output is kept alive by this node.
        ctx.rcv = rcv
        ctx.transformer_block = transformer_block
        ctx.parent_model = parent_model

        rcv_expanded = rcv.unsqueeze(0)

        # Split the compressed representation from the token-index channel.
        token_channel = input_tensor[:, :, -1:]
        x = input_tensor[:, :, :-1].transpose(2, 1)
        tokens = token_channel.to(torch.int).squeeze(2)

        fixed_embed = fixed_tok_embeddings(tokens)

        # Decompress: h ≈ rcv @ compressed + fixed_embeddings
        decompressed_output = (
            rcv_expanded @ x + fixed_embed.transpose(2, 1)
        ).transpose(2, 1)

        # Re-attach the token indices so downstream consumers can read them.
        return torch.cat([decompressed_output, token_channel], dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: measure energy loss during gradient compression.

        The gradient w.r.t. the compressed representation is
        ``grad_decompressed @ rcv`` (the fixed embeddings contribute no
        gradient to the input). The energy lost is measured by reconstructing
        the gradient as ``rcv @ rcv.T @ grad`` and comparing squared Frobenius
        norms.
        NOTE(review): the ratio is a true "discarded fraction" only if the
        columns of ``rcv`` are orthonormal — confirm against the caller.

        Returns:
            One gradient per ``forward`` argument; only ``input_tensor``
            receives a gradient.
        """
        if grad_output is None:
            # Autograd expects exactly one entry per forward() input (five).
            return None, None, None, None, None

        rcv = ctx.rcv
        transformer_block = ctx.transformer_block
        parent_model = ctx.parent_model

        # Gradient for the decompressed part, and for the token-index channel
        # (the latter is passed through untouched; it should be zero).
        grad_decompressed = grad_output[:, :, :-1]  # [batch, seq, hidden_dim]
        grad_tokens = grad_output[:, :, -1:]

        # Energy before compression. `.item()` copies the value to the host.
        original_grad_energy = (
            torch.norm(grad_decompressed, p="fro").item() ** 2  # noqa: TOR101
        )

        rcv_expanded = rcv.unsqueeze(0)

        # grad_compressed = rcv.T @ grad_decompressed (per batch element)
        grad_compressed = (
            rcv_expanded.transpose(2, 1) @ grad_decompressed.transpose(2, 1)
        ).transpose(2, 1)

        # Energy after compression, measured on the reconstructed gradient.
        reconstructed_grad = (rcv_expanded @ grad_compressed.transpose(2, 1)).transpose(
            2, 1
        )
        compressed_grad_energy = (
            torch.norm(reconstructed_grad, p="fro").item() ** 2  # noqa: TOR101
        )

        energy_loss = original_grad_energy - compressed_grad_energy
        energy_loss_ratio = energy_loss / (
            original_grad_energy + 1e-8
        )  # Avoid division by zero

        # Accumulate; the consumer averages over the measurement count.
        transformer_block.current_compression_energy_loss += energy_loss_ratio
        transformer_block.compression_measurements_count += 1

        # Store the no-compression upstream grad for the *previous* block
        if parent_model is not None and hasattr(transformer_block, "layer_id"):
            # Handle wrapped models (DDP, FSDP) by accessing the underlying module
            actual_model = getattr(parent_model, "module", parent_model)
            prev_layer = transformer_block.layer_id - 1
            if prev_layer >= 0 and getattr(
                actual_model.model_args, "track_uncompressed_w2_stable_rank", False
            ):
                actual_model.store_uncompressed_grad(prev_layer, grad_decompressed)

        # Gradient for the input (compressed representation + token indices)
        grad_input = torch.cat([grad_compressed, grad_tokens], dim=-1)

        return grad_input, None, None, None, None


def build_norm(norm_type: str, dim: int, eps: float = 1e-6, trainable: bool = True):
    """
    Builds the specified normalization layer based on the norm_type.

    Args:
        norm_type (str): The type of normalization layer to build.
            Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
        dim (int): The dimension of the normalization layer.
        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

    Returns:
        The built normalization layer.

    Raises:
        NotImplementedError: If an unknown norm_type is provided.
    """
    norm_type = norm_type.lower()  # Normalize to lowercase

    if norm_type == "layernorm":
        return nn.LayerNorm(dim, eps=eps, bias=False)
    elif norm_type == "np_layernorm":
        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
    elif norm_type == "rmsnorm":
        return RMSNorm(dim, eps=eps, trainable=trainable)
    else:
        raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")


class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension.

    The input is divided by the square root of the mean of its squared
    entries (plus ``eps``), computed in float32 and cast back to the input
    dtype. When ``trainable`` is set, the result is multiplied by a learnable
    per-feature scale.

    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        trainable (bool, optional): Whether to create the learnable scale. Default is True.

    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        trainable (bool): Whether the layer owns a learnable scale.
        weight (nn.Parameter): Learnable scaling parameter; only present when
            ``trainable`` is True.

    """

    def __init__(self, dim: int, eps: float = 1e-6, trainable=True):
        super().__init__()
        self.eps = eps
        self.trainable = trainable
        if self.trainable:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # x / sqrt(mean(x^2) + eps), reduced over the last dimension.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        # Normalize in float32, then return to the caller's dtype.
        normed = self._norm(x.float()).type_as(x)
        if not self.trainable:
            return normed
        return normed * self.weight

    def reset_parameters(self):
        if not self.trainable:
            return
        torch.nn.init.ones_(self.weight)  # type: ignore


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Build the table of unit-modulus complex rotations used by rotary embeddings.

    Entry ``[p, k]`` of the result is ``exp(i * p * theta ** (-2k / dim))`` for
    position ``p`` in ``[0, end)`` and frequency index ``k`` in ``[0, dim // 2)``.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float): Base of the geometric frequency progression. Defaults to 10000.0.

    Returns:
        torch.Tensor: complex64 tensor of shape (end, dim // 2).
    """
    # One exponent per pair of channels: 0/dim, 2/dim, 4/dim, ...
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)

    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    # Modulus 1, phase = position * frequency.
    return torch.polar(torch.ones_like(angles), angles)  # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
    and the first seqlen elements will be sliced, but dim must match x.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    seqlen = x.shape[1]
    freqs_cis = freqs_cis[0:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors by the precomputed rotary frequencies.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are read as
    (real, imaginary) parts of complex numbers in float32, multiplied by
    ``freqs_cis`` (reshaped to broadcast against the queries), converted back
    to real pairs and cast to the dtype of the respective input.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # (..., head_dim) -> (..., head_dim // 2) complex, computed in float32.
    q_complex, k_complex = (
        torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        for t in (xq, xk)
    )

    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    # Complex multiply applies the rotation; flatten(3) merges the
    # (pairs, 2) trailing axes back into a single head_dim axis.
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for an
    input of shape (bs, slen, n_kv_heads, head_dim). When ``n_rep == 1`` the
    input tensor itself is returned.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x

    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # into the head axis so copies of the same head end up adjacent.
    with_repeat_axis = x.unsqueeze(3)
    broadcast = with_repeat_axis.expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return broadcast.reshape(bs, slen, n_kv_heads * n_rep, head_dim)


def learnable_bias_softmax(
    scores: torch.Tensor, bias: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """
    Apply softmax with learnable bias in the denominator.

    Formula: softmax_bias(x)_i = exp(x_i) / (bias + sum_j(exp(x_j)))

    Args:
        scores (torch.Tensor): Input scores tensor
        bias (torch.Tensor): Learnable bias parameter for each head
        dim (int): Dimension along which to apply softmax

    Returns:
        torch.Tensor: Softmax probabilities with learnable bias
    """
    # Compute exponentials
    exp_scores = torch.exp(scores)

    # Sum along the specified dimension, keeping dimensions for broadcasting
    sum_exp = torch.sum(exp_scores, dim=dim, keepdim=True)

    # Add bias to denominator - bias should be broadcastable with sum_exp
    # bias shape: (bs, n_heads, 1, 1) to broadcast with sum_exp: (bs, n_heads, seqlen, 1)
    denominator = bias.unsqueeze(-1).unsqueeze(-1) + sum_exp

    # Apply softmax with bias
    return exp_scores / denominator


class Attention(nn.Module):
    """
    Multi-head attention module with optional learnable bias softmax.

    The attention kernel used by ``forward`` is selected from the model
    configuration, in this order of precedence:

    1. The Triton kernel when ``use_triton_attention`` is set (with the
       learnable softmax bias if ``use_learnable_bias_softmax`` is also set).
    2. Manual attention with learnable bias softmax.
    3. The module returned by ``build_attention`` (only built when
       ``use_flex_attn`` is set and learnable bias softmax is not).
    4. Manual attention with tanh soft-capping when ``use_softcap`` is set.
    5. ``F.scaled_dot_product_attention`` with causal masking.

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.
        layer_id (int): Layer identifier; also offsets the seed of the random
            projection matrices created in ``init_weights``.

    Attributes:
        n_kv_heads (int): Number of key and value heads.
        n_heads (int): Number of query heads.
        n_rep (int): Number of repetitions for local heads.
        head_dim (int): Dimension size of each attention head.
        wq (Linear): Linear transformation for queries.
        wk (Linear): Linear transformation for keys.
        wv (Linear): Linear transformation for values.
        wo (Linear): Linear transformation for output.
        q_norm (nn.Module): Query normalization layer (only when ``qk_norm``).
        k_norm (nn.Module): Key normalization layer (only when ``qk_norm``).
        use_learnable_bias_softmax (bool): Whether to use learnable bias softmax.
        softmax_bias (nn.Parameter): Learnable bias parameters for softmax
            denominator (only when ``use_learnable_bias_softmax``).
        sdpa: Attention module from ``build_attention``, or None.
        Rk, Rq (torch.Tensor): Plain (non-buffer) float32 matrices with
            orthonormal columns, created by ``init_weights`` when ``attn_proj``
            is set.

    """

    def __init__(self, model_args: TransformerModelArgs, layer_id: int):
        super().__init__()
        self.n_heads = model_args.n_heads
        self.n_kv_heads = (
            model_args.n_heads
            if model_args.n_kv_heads is None
            else model_args.n_kv_heads
        )
        self.layer_id = layer_id
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = model_args.hidden_dim // model_args.n_heads

        self.wq = nn.Linear(
            model_args.hidden_dim, model_args.n_heads * self.head_dim, bias=False
        )
        self.wk = nn.Linear(
            model_args.hidden_dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.wv = nn.Linear(
            model_args.hidden_dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.wo = nn.Linear(
            model_args.n_heads * self.head_dim, model_args.hidden_dim, bias=False
        )

        self.hidden_dim = model_args.hidden_dim
        self.attn_proj = model_args.attn_proj
        self.qk_norm = model_args.qk_norm
        self.use_triton_attention = model_args.use_triton_attention

        self.use_softcap = model_args.use_softcap
        self.cap_threshold = model_args.cap_threshold

        self.trainable_rmsnorm = model_args.trainable_rmsnorm

        # QK-norm: separate per-head norms for queries and keys.
        if model_args.qk_norm:
            self.q_norm = build_norm(
                norm_type=model_args.norm_type,
                dim=self.head_dim,
                eps=model_args.norm_eps,
                trainable=self.trainable_rmsnorm,
            )
            self.k_norm = build_norm(
                norm_type=model_args.norm_type,
                dim=self.head_dim,
                eps=model_args.norm_eps,
                trainable=self.trainable_rmsnorm,
            )

        # Learnable bias softmax
        self.use_learnable_bias_softmax = model_args.use_learnable_bias_softmax
        if self.use_learnable_bias_softmax:
            # One bias per attention head.
            # For Triton: Shape (n_heads,), for manual: Shape (1, n_heads)
            if self.use_triton_attention:
                self.softmax_bias = nn.Parameter(torch.ones(model_args.n_heads))
            else:
                self.softmax_bias = nn.Parameter(torch.ones(1, model_args.n_heads))

        if model_args.use_flex_attn and not self.use_learnable_bias_softmax:
            # Keep original FlexAttention functionality for backward compatibility
            # but disable if using learnable bias softmax
            self.sdpa = build_attention(
                use_flex_attn=model_args.use_flex_attn,
                attn_mask_type=model_args.attn_mask_type,
                fixed_block_size=None,
                use_softcap=self.use_softcap,
                cap_threshold=self.cap_threshold,
            )
        else:
            self.sdpa = None

    def init_weights(self, init_std: float):
        """
        Initialize projection weights, the softmax bias and, when ``attn_proj``
        is enabled, the random orthonormal projection matrices.

        Args:
            init_std (float): Standard deviation for the output projection
                ``wo``; the q/k/v projections always use 0.02.
        """
        for linear in (self.wq, self.wk, self.wv):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

        # Bias starts at 1.0, i.e. the softmax denominator is 1 + sum(exp).
        # (A bias of 0 would recover the standard softmax.)
        if self.use_learnable_bias_softmax:
            nn.init.ones_(self.softmax_bias)

        if self.attn_proj:
            # Dedicated generator: the matrices depend only on the layer id,
            # not on the global RNG state.
            rng = torch.Generator()
            rng.manual_seed(1337 + self.layer_id)

            Rk = torch.randn(
                (
                    self.n_kv_heads * self.head_dim,
                    int(self.n_kv_heads * self.head_dim / 4),
                ),
                dtype=torch.float32,
                generator=rng,
            )
            # Reduced QR keeps the Q factor: orthonormal columns.
            self.Rk, _ = torch.linalg.qr(Rk, mode="reduced")

            Rq = torch.randn(
                (self.hidden_dim, int(self.hidden_dim / 4)),
                dtype=torch.float32,
                generator=rng,
            )
            self.Rq, _ = torch.linalg.qr(Rq, mode="reduced")

    def _manual_attention_softcap(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
        is_causal: bool = True,
        cap_threshold: float | None = None,  # e.g. 20.0 to enable; None to disable
        softcap_per_head: torch.Tensor | None = None,  # optional, shape [n_heads]
        dropout_p: float = 0.0,
    ) -> torch.Tensor:
        """
        Manual attention with optional tanh-softcap on logits.

        Args:
            xq, xk, xv: (bs, n_heads, seqlen, head_dim)
            is_causal: apply causal mask
            cap_threshold: scalar cap applied to all heads
                (``cap * tanh(scores / cap)``); None disables capping
            softcap_per_head: per-head caps, tensor of shape [n_heads];
                overrides ``cap_threshold`` if given. A cap of 0 zeroes that
                head's logits.
            dropout_p: dropout probability on the attention weights; when
                non-zero it is applied regardless of the module's training
                mode.

        Returns:
            Attention output of shape (bs, n_heads, seqlen, head_dim).
        """
        bs, n_heads, seqlen, head_dim = xq.shape

        scores = torch.matmul(xq, xk.transpose(-2, -1)) * (
            head_dim**-0.5
        )  # (bs, n_heads, seqlen, seqlen)

        if softcap_per_head is not None:

            cap = softcap_per_head.view(1, n_heads, 1, 1).to(
                device=scores.device, dtype=scores.dtype
            )

            # Avoid dividing by a zero cap; the outer multiply by `cap`
            # still maps such heads to all-zero logits.
            cap_safe = torch.where(cap == 0, torch.ones_like(cap), cap)
            scores = cap * torch.tanh(scores / cap_safe)
        elif cap_threshold is not None:
            cap = torch.as_tensor(
                cap_threshold, device=scores.device, dtype=scores.dtype
            ).view(1, 1, 1, 1)
            scores = cap * torch.tanh(scores / cap)

        if is_causal:
            # True above the diagonal = future positions to hide.
            causal = torch.triu(
                torch.ones(seqlen, seqlen, device=scores.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal, float("-inf"))

        if getattr(self, "use_learnable_bias_softmax", False):
            bias_expanded = self.softmax_bias.expand(bs, -1)  # (bs, n_heads)
            attn_weights = learnable_bias_softmax(scores, bias_expanded, dim=-1)
        else:
            attn_weights = F.softmax(scores, dim=-1)

        if dropout_p:
            attn_weights = F.dropout(attn_weights, p=dropout_p, training=True)

        output = torch.matmul(attn_weights, xv)  # (bs, n_heads, seqlen, head_dim)
        return output

    def _manual_attention(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
        is_causal: bool = True,
    ) -> torch.Tensor:
        """
        Manual attention computation with learnable bias softmax.

        Falls back to the standard softmax when
        ``use_learnable_bias_softmax`` is False.

        Args:
            xq: Query tensor (bs, n_heads, seqlen, head_dim)
            xk: Key tensor (bs, n_heads, seqlen, head_dim)
            xv: Value tensor (bs, n_heads, seqlen, head_dim)
            is_causal: Whether to apply causal masking

        Returns:
            Attention output tensor
        """
        bs, n_heads, seqlen, head_dim = xq.shape

        # Compute attention scores: Q @ K^T / sqrt(head_dim)
        scores = torch.matmul(xq, xk.transpose(-2, -1)) / (head_dim**0.5)

        # Apply causal mask if needed
        if is_causal:
            mask = torch.triu(
                torch.ones(seqlen, seqlen, device=scores.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(mask, float("-inf"))

        if self.use_learnable_bias_softmax:
            # scores: (bs, n_heads, seqlen, seqlen)
            # softmax_bias: (1, n_heads) -> (bs, n_heads); the trailing
            # singleton axes are added inside learnable_bias_softmax.
            bias_expanded = self.softmax_bias.expand(bs, -1)  # (bs, n_heads)
            attn_weights = learnable_bias_softmax(scores, bias_expanded, dim=-1)
        else:
            attn_weights = F.softmax(scores, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(attn_weights, xv)

        return output

    def _triton_attention(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Run the Triton attention kernel, optionally with a softmax bias.

        When the attribute ``_return_attention_metrics`` is set and truthy, the
        kernel is asked for metrics, which are stored on
        ``self._last_attention_metrics`` (used for QK-clip).

        Args:
            xq, xk, xv: (bs, n_heads, seqlen, head_dim)
            bias: Per-head softmax bias, or None for no bias.

        Returns:
            Attention output viewed as (bs, n_heads, seqlen, head_dim).
        """
        bs, _, seqlen, _ = xq.shape

        # The kernel is fed contiguous (bs, seqlen, n_heads, head_dim) tensors.
        q = xq.transpose(1, 2).contiguous()
        k = xk.transpose(1, 2).contiguous()
        v = xv.transpose(1, 2).contiguous()

        sm_scale = 1.0 / (self.head_dim**0.5)

        if getattr(self, "_return_attention_metrics", False):
            out, metrics = triton_attention_with_bias(
                q, k, v, bias, sm_scale, return_metrics=True
            )
            # Store metrics for later use
            self._last_attention_metrics = metrics
        else:
            out = triton_attention_with_bias(q, k, v, bias, sm_scale)

        # Bring the kernel output to (bs, n_heads, seqlen, head_dim) for
        # consistency with the other attention paths.
        return out.view(bs, seqlen, -1, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor of shape (bs, seqlen, hidden_dim).
            freqs_cis (torch.Tensor): Precomputed frequency tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bs, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        if self.attn_proj:
            # Rk / Rq are plain attributes (not buffers), so they do not follow
            # module.to(); move them once and cache the moved copy.
            self.Rk = self.Rk.to(xq.device)
            self.Rq = self.Rq.to(xq.device)
            # The matrices are created in float32; matmul requires matching
            # dtypes, so cast a local copy when activations are bf16/fp16.
            rq = self.Rq.to(xq.dtype)[None, :, :]
            rk = self.Rk.to(xk.dtype)[None, :, :]
            # Project onto the column space of R: x @ R @ R^T.
            xq = xq @ rq @ rq.transpose(1, 2)
            xk = xk @ rk @ rk.transpose(1, 2)
            xv = xv @ rk @ rk.transpose(1, 2)

        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
        # local heads from sizes of xq, xk, and xv as TP may have sharded them
        # after the above linear ops.
        xq = xq.view(bs, seqlen, -1, self.head_dim)
        xk = xk.view(bs, seqlen, -1, self.head_dim)
        xv = xv.view(bs, seqlen, -1, self.head_dim)

        # Normalize across the head dimension (last dimension)
        if self.qk_norm:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

        # Choose attention implementation (see class docstring for precedence).
        if self.use_triton_attention:
            bias = self.softmax_bias if self.use_learnable_bias_softmax else None
            output = self._triton_attention(xq, xk, xv, bias)
        elif self.use_learnable_bias_softmax:
            # Custom attention with learnable bias softmax
            output = self._manual_attention(xq, xk, xv, is_causal=True)
        elif self.sdpa is not None:
            # Module from build_attention (soft-capping, if any, was
            # configured there at construction time).
            output = self.sdpa(xq, xk, xv)
        elif self.use_softcap:
            output = self._manual_attention_softcap(
                xq, xk, xv, cap_threshold=self.cap_threshold, dropout_p=0.0
            )
        else:
            # Standard PyTorch scaled dot product attention
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

        output = output.transpose(
            1, 2
        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
        output = output.view(bs, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    """
    Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    The width of the inner layer is derived from ``hidden_dim``: it is scaled
    by 2/3, optionally by ``ffn_dim_multiplier``, and then rounded up to the
    next multiple of ``multiple_of``.

    Args:
        dim (int): Input dimension.
        hidden_dim (int): Hidden dimension of the feedforward layer.
        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension;
            None skips this scaling step.

    Attributes:
        w1 (Linear): Gate projection, dim -> inner width.
        w2 (Linear): Output projection, inner width -> dim.
        w3 (Linear): Value projection, dim -> inner width.
        _last_z (torch.Tensor): Slot for cached z activations (stable rank
            computation); never written by ``forward``, so it stays None.

    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()

        inner_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            # custom dim factor multiplier
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Round up to the next multiple of `multiple_of`.
        num_chunks = (inner_dim + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)
        self._last_z: Optional[torch.Tensor] = None

    def forward(self, x):
        gate = F.silu(self.w1(x))
        # Caching the gated activations in _last_z is intentionally not done
        # here, for torch.compile compatibility.
        gated = gate * self.w3(x)
        return self.w2(gated)

    def init_weights(self, init_std: float):
        # w1 uses a fixed std; w2 and w3 use the caller-provided one.
        layer_stds = ((self.w1, 0.02), (self.w2, init_std), (self.w3, init_std))
        for layer, std in layer_stds:
            nn.init.trunc_normal_(layer.weight, mean=0.0, std=std)


class TransformerBlock(nn.Module):
    """
    TransformerBlock Module

    One transformer layer: attention and feed-forward sub-layers, each wrapped in
    a residual connection, with optional compression of the activations that are
    passed between layers.

    When ``model_args.use_compression`` is set, a block may receive its input in
    compressed form (``hidden_dim // compression_rate`` channels plus one trailing
    channel that carries the token index) and may emit its output in that form.

    Args:
        layer_id (int): Identifier for the layer.
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        n_heads (int): Number of attention heads.
        dim (int): Dimension size of the model (``model_args.hidden_dim``).
        norm_reorder (bool): If set, each norm is applied to the *output* of its
            sub-layer instead of its input.
        trainable_rmsnorm: Forwarded to ``build_norm`` as ``trainable``.
        attention (Attention): Attention module.
        feed_forward (FeedForward): FeedForward module.
        layer_id (int): Identifier for the layer.
        num_layers (int): Total number of layers (``model_args.n_layers``).
        model_args (TransformerModelArgs): The configuration object.
        attention_norm: Norm used on the attention branch.
        ffn_norm: Norm used on the feed-forward branch.
        needs_decompression: Truthy when compression is on and this is not layer 0.
        needs_compression: Truthy when compression is on and this is not the last layer.
        weight_init_std (float): Std passed to the sub-modules' ``init_weights``.
        current_compression_energy_loss (float): Accumulated energy-loss ratio.
            Not updated inside this class (presumably by the energy tracker used
            in ``decompress_input`` -- TODO confirm).
        compression_measurements_count (int): Number of accumulated measurements.

    """

    def __init__(self, layer_id: int, model_args: TransformerModelArgs):
        super().__init__()
        self.n_heads = model_args.n_heads
        self.dim = model_args.hidden_dim

        self.norm_reorder = model_args.norm_reorder
        self.trainable_rmsnorm = model_args.trainable_rmsnorm

        self.attention = Attention(model_args, layer_id)
        self.feed_forward = FeedForward(
            dim=model_args.hidden_dim,
            hidden_dim=4 * model_args.hidden_dim,
            multiple_of=model_args.multiple_of,
            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.num_layers = model_args.n_layers
        self.model_args = model_args

        self.attention_norm = build_norm(
            norm_type=model_args.norm_type,
            dim=model_args.hidden_dim,
            eps=model_args.norm_eps,
            trainable=self.trainable_rmsnorm,
        )
        self.ffn_norm = build_norm(
            norm_type=model_args.norm_type,
            dim=model_args.hidden_dim,
            eps=model_args.norm_eps,
            trainable=self.trainable_rmsnorm,
        )

        # Compression roles, as implemented by the two flags below:
        #   first layer (0):       compresses its output, never decompresses
        #   middle layers:         decompress their input and compress their output
        #   last layer (n-1):      decompresses its input, never compresses
        self.needs_decompression = model_args.use_compression and not (layer_id == 0)
        self.needs_compression = model_args.use_compression and not (
            layer_id == model_args.n_layers - 1
        )

        if model_args.depth_init:
            # Std shrinks with the depth of this particular layer.
            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
        else:
            # Same std for every layer, based on the total depth.
            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

        self.current_compression_energy_loss = 0.0
        self.compression_measurements_count = 0

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        rcv: torch.Tensor,
        fixed_tok_embeddings: nn.Embedding,
        tokens_tensor: Optional[torch.Tensor] = None,
        parent_model: Optional["Transformer"] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor, either full width or compressed
                (compressed width plus one trailing token-index channel).
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            rcv (torch.Tensor): Right compression vector, or ``None`` to disable
                compression handling in this call.
            fixed_tok_embeddings (nn.Embedding): Fixed token embeddings for compression.
            tokens_tensor (torch.Tensor, optional): Token indices for compression.
                Replaced by the indices carried in ``x`` when ``x`` is decompressed.
            parent_model (Transformer, optional): Owning model. Used to publish
                attention metrics and to drive the optional w2 stable-rank hook.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward
            layers; compressed (with the token-index channel appended) when this
            block compresses and ``tokens_tensor`` is available.

        """
        # Decompress only when the input really has the compressed width
        # (compressed channels + 1 token-index channel).
        if self.needs_decompression and rcv is not None:
            compressed_width = (
                int(self.model_args.hidden_dim // self.model_args.compression_rate)
                + 1
            )
            if x.shape[-1] == compressed_width:
                x = self.decompress_input(x, rcv, fixed_tok_embeddings, parent_model)
                # Split the decompressed tensor back into hidden states and token indices
                tokens_tensor = x[:, :, -1:].clone()
                x = x[:, :, :-1]

        # Attention branch (norm before or after the sub-layer, per norm_reorder)
        if self.norm_reorder:
            pre_attention = self.attention_norm(self.attention(x, freqs_cis))
        else:
            pre_attention = self.attention(self.attention_norm(x), freqs_cis)

        # Publish attention metrics on the parent model if collection is enabled
        if (
            parent_model is not None
            and getattr(parent_model, "_collect_attention_metrics", False)
            and hasattr(self.attention, "_last_attention_metrics")
        ):
            parent_model._attention_metrics[
                self.layer_id
            ] = self.attention._last_attention_metrics

        h = x + pre_attention

        # Feed-forward branch
        if self.norm_reorder:
            ffn_out = self.ffn_norm(self.feed_forward(h))
        else:
            ffn_out = self.feed_forward(self.ffn_norm(h))
        out = h + ffn_out

        self._maybe_register_w2_stable_rank_hook(ffn_out, parent_model)

        # Compress at the end if needed (every layer except the last one)
        if self.needs_compression and rcv is not None and tokens_tensor is not None:
            out = self.compress_output(out, tokens_tensor, rcv, fixed_tok_embeddings)

        return out

    def _maybe_register_w2_stable_rank_hook(
        self, ffn_out: torch.Tensor, parent_model
    ) -> None:
        """Attach the backward hook that records the "no-compression" w2 grad stable rank."""
        # A hook can only be registered on a tensor that requires grad; skipping
        # here keeps inference / torch.no_grad() forward passes from raising.
        if parent_model is None or not ffn_out.requires_grad:
            return
        # Handle wrapped models by accessing the underlying module
        actual_model = getattr(parent_model, "module", parent_model)
        if not getattr(
            actual_model.model_args, "track_uncompressed_w2_stable_rank", False
        ):
            return

        def _w2_sr_hook(grad_actual):
            # 'grad_actual' is ignored on purpose (it is affected by compression);
            # the stored no-compression upstream gradient is used instead.
            try:
                Gy = actual_model.uncompressed_upstream_grads.get(self.layer_id, None)
                # NOTE: FeedForward.forward does not populate `_last_z`, so unless
                # something else assigns it this hook returns here every time.
                Z = getattr(self.feed_forward, "_last_z", None)
                if Gy is None or Z is None:
                    return
                sr = self._w2_grad_stable_rank(Gy, Z)
                if sr is not None:
                    actual_model.store_w2_stable_rank(self.layer_id, sr)
            except Exception:
                # Best effort: this diagnostic must never break training, but
                # leave a trace so failures can be found when debugging.
                logging.getLogger(__name__).debug(
                    "w2 stable-rank hook failed for layer %s",
                    self.layer_id,
                    exc_info=True,
                )

        ffn_out.register_hook(_w2_sr_hook)

    def _w2_grad_stable_rank(
        self, Gy: torch.Tensor, Z: torch.Tensor
    ) -> Optional[float]:
        """
        Return ``||dW||_F^2 / s_max(dW)^2`` for ``dW = Gy^T @ Z``, or ``None`` if degenerate.

        Both inputs are unpacked as 3-D tensors and flattened over their first two
        dimensions: ``Gy`` to ``[B*S, D]`` and ``Z`` to ``[B*S, H]``.
        """
        with torch.no_grad():
            B, S, D = Gy.shape
            _, _, H = Z.shape
            Gy_f = Gy.reshape(B * S, D).float()
            Z_f = Z.reshape(B * S, H).float()

            # dW2 (no-compression): [D, H] = (Gy_f)^T @ Z_f
            dW = Gy_f.transpose(0, 1) @ Z_f

            # Frobenius norm squared
            fro2 = (dW * dW).sum().item()

            if self.model_args.compute_exact_svd:
                s_max = torch.linalg.svdvals(dW).max()
                s_max_sq = float((s_max * s_max).item())
            else:
                # Spectral norm via 5 power-method iterations on (dW^T dW);
                # avoids the memory cost of a full SVD.
                v = torch.randn((dW.shape[1],), device=dW.device, dtype=dW.dtype)
                v = v / (v.norm() + 1e-12)
                for _ in range(5):
                    v = dW.t().matmul(dW.matmul(v))
                    v_norm = v.norm()
                    if v_norm.item() == 0.0 or torch.isnan(v_norm):
                        break
                    v = v / (v_norm + 1e-12)

                # Rayleigh quotient: s_max^2 ~= ||dW v||^2 for unit-norm v
                Av = dW.matmul(v)
                s_max_sq = float((Av * Av).sum().item())

            if s_max_sq <= 0.0 or not torch.isfinite(torch.tensor(s_max_sq)):
                return None
            return float(fro2 / (s_max_sq + 1e-12))

    def init_weights(self):
        """Reset both norms and initialize attention / feed-forward with ``weight_init_std``."""
        for norm in (self.attention_norm, self.ffn_norm):
            norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        self.feed_forward.init_weights(self.weight_init_std)

    def decompress_input(
        self,
        input: torch.Tensor,
        rcv: torch.Tensor,
        fixed_tok_embeddings: nn.Embedding,
        parent_model: Optional["Transformer"] = None,
    ) -> torch.Tensor:
        """
        Decompress ``input`` using the compression vector from the parent model.

        ``input`` carries the compressed channels followed by one token-index
        channel. The result is ``compressed @ rcv^T + fixed_embedding(token)`` with
        the token-index channel appended again. Returns ``input`` unchanged when
        this block does not decompress or a required argument is ``None``.
        """
        if not self.needs_decompression or rcv is None or fixed_tok_embeddings is None:
            return input

        if self.model_args.track_compression_energy:
            # Delegates to the autograd Function defined elsewhere in this file.
            return CompressionEnergyTracker.apply(
                input, rcv, fixed_tok_embeddings, self, parent_model
            )

        rcv = rcv.unsqueeze(0).clone()

        # Extract compressed representation and token indices
        x = input[:, :, :-1].transpose(2, 1)
        idx = input[:, :, -1:]
        tokens = idx.to(torch.int).squeeze(2).clone()

        # Get fixed embeddings
        fixed_embed = fixed_tok_embeddings(tokens)

        # Decompress: h ~= rcv @ compressed + fixed_embeddings
        decompressed_output = (rcv @ x + fixed_embed.transpose(2, 1)).transpose(2, 1)

        # Concatenate the tokens to the decompressed output
        return torch.cat([decompressed_output, input[:, :, -1:]], dim=-1)

    def compress_output(
        self,
        output: torch.Tensor,
        tokens_tensor: torch.Tensor,
        rcv: torch.Tensor,
        fixed_tok_embeddings: nn.Embedding,
    ) -> torch.Tensor:
        """
        Compress ``output`` using the compression vector from the parent model.

        The result is ``(output - fixed_embedding(token)) @ rcv`` with the
        token-index channel appended. Returns ``output`` unchanged when this block
        does not compress or a required argument is ``None``.
        """
        if not self.needs_compression or rcv is None or fixed_tok_embeddings is None:
            return output

        rcv = rcv.unsqueeze(0).clone()

        # Attach token indices to output (torch.cat also aligns the dtypes)
        output_with_tokens = torch.cat([output, tokens_tensor], dim=-1)

        # Extract output and token indices
        x = output_with_tokens[:, :, :-1]
        idx = output_with_tokens[:, :, -1:]
        tokens = idx.to(torch.int).squeeze(2)

        # Get fixed embeddings
        fixed_embed = fixed_tok_embeddings(tokens)

        # Compress: compressed ~= rcv.T @ (output - fixed_embeddings)
        compressed_output = (
            rcv.transpose(2, 1) @ (x - fixed_embed).transpose(2, 1)
        ).transpose(2, 1)

        # Concatenate token indices back to the compressed output
        return torch.cat([compressed_output, output_with_tokens[:, :, -1:]], dim=-1)

    def get_compression_energy_loss(self) -> float:
        """
        Get the current step's average energy loss ratio for gradient compression in this layer.

        Returns:
            float: Accumulated loss divided by the number of measurements, or 0.0
            when nothing has been measured since the last reset.
        """
        if self.compression_measurements_count == 0:
            return 0.0
        return (
            self.current_compression_energy_loss / self.compression_measurements_count
        )

    def reset_compression_energy_stats(self) -> None:
        """Reset the compression energy loss statistics for the next optimization step."""
        self.current_compression_energy_loss = 0.0
        self.compression_measurements_count = 0


class Transformer(nn.Module, ModelProtocol):
    """
    Transformer Module

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        model_args (TransformerModelArgs): Model configuration arguments.
        vocab_size (int): Vocabulary size.
        n_layers (int): Number of layers in the model.
        tok_embeddings (nn.Embedding): Token embeddings.
        layers (torch.nn.ModuleDict): Dictionary of Transformer blocks.
        norm (RMSNorm): Layer normalization for the model output.
        output (nn.Linear): Linear layer for final output.
        freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
        rcv (torch.Tensor): Right compression vector
        fixed_tok_embeddings (nn.Embedding): Fixed token embeddings for compression
    """

    def __init__(self, model_args: TransformerModelArgs):
        """
        Build the model from ``model_args``.

        Creates, in this order: bookkeeping dicts for the w2 stable-rank
        diagnostic, the token embedding, the optional learnable sink token, the
        optional compression state (fixed embeddings and ``rcv``), the
        ``freqs_cis`` buffer, one ``TransformerBlock`` per layer, the final norm
        and the output projection. ``init_weights`` is called at the end.

        Args:
            model_args (TransformerModelArgs): Model configuration arguments.
        """
        super().__init__()
        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers

        # Storage for uncompressed w2 stable rank tracking
        self.uncompressed_upstream_grads: dict[
            int, torch.Tensor
        ] = {}  # layer_id -> [B,S,dim]
        self.w2_stable_rank_sum: dict[int, float] = {}
        self.w2_stable_rank_count: dict[int, int] = {}
        self.stable_rank_w2: dict[int, float] = {}  # running avg per layer

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_dim)

        if model_args.use_learnable_sink_token:
            # Allocated as zeros here; init_weights() (called below) draws the
            # actual values from a normal distribution with std=0.02.
            self.learnable_sink_token = nn.Parameter(torch.zeros(model_args.hidden_dim))
        else:
            self.learnable_sink_token = None

        # Initialize compression components
        if model_args.use_compression:
            # Width of the compressed activations exchanged between layers
            self.compression_length = int(
                model_args.hidden_dim // model_args.compression_rate
            )
            # If using a learnable sink token, allocate one extra fixed embedding row for it
            fixed_vocab_size = model_args.vocab_size + (
                1 if model_args.use_learnable_sink_token else 0
            )
            self.fixed_tok_embeddings = nn.Embedding(
                fixed_vocab_size, model_args.hidden_dim
            )
            # Define the sink token index used for compression bookkeeping
            # (the extra row appended after the regular vocabulary, if any)
            self.sink_token_index = (
                model_args.vocab_size if model_args.use_learnable_sink_token else None
            )
            # Right compression vector as a non-trainable parameter so it is saved in checkpoints.
            # NOTE(review): allocated with torch.empty, i.e. uninitialized; neither this
            # constructor nor init_weights() fills it -- presumably it is populated
            # elsewhere (e.g. from get_rcv) before use; confirm.
            self.rcv = nn.Parameter(
                torch.empty(model_args.hidden_dim, self.compression_length),
                requires_grad=False,
            )
        else:
            # NOTE: compression_length and sink_token_index are not defined in this case.
            self.fixed_tok_embeddings = None
            self.rcv = None

        # Precompute frequency tensor for positional embeddings
        # TODO persistent should be set to false, since this buffer can be recomputed.
        # however, we set it to true for 2 reasons.  (1) due to pytorch/pytorch#123411,
        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
        # so we need to fix that.  (2) if we initialize pipeline-parallel models from
        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
        # initialized by the checkpoint, or we need to add a separate initializer for
        # just the non-persistent buffers that is called after loading checkpoints.
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

        # Layers dictionary, keyed by the layer index as a string
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

        self.norm = build_norm(
            norm_type=model_args.norm_type,
            dim=model_args.hidden_dim,
            eps=model_args.norm_eps,
        )

        self.output = nn.Linear(
            model_args.hidden_dim, model_args.vocab_size, bias=False
        )

        # Custom initialization of everything created above
        self.init_weights()

        # QK-clip state
        self._collect_attention_metrics = False
        self._attention_metrics = {}  # layer_id -> metrics dict

    def init_weights(
        self,
        buffer_device: Optional[torch.device] = None,
    ):
        """
        [Note: On ``init_weights`` vs. ``reset_parameters``]
        Modules may define ``reset_parameters`` to initialize parameter values.
        ``reset_parameters`` is meant to only initialize directly owned
        parameters/buffers, not those of their child modules, and it can be
        used to give the initial values for these tensors.
        Separately, users may want custom initialization for their modules,
        different from that in ``reset_parameters``. For this, we define
        ``init_weights``. We only call it in the constructor of this
        ``Transformer`` root module to avoid reinitializing tensors.
        """
        buffer_device = buffer_device or self.freqs_cis.device
        with torch.device(buffer_device):
            self.freqs_cis = self._precompute_freqs_cis()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)

        # Initialize learnable sink token with proper scaling
        if self.learnable_sink_token is not None:
            # Use same initialization as token embeddings but with smaller variance
            nn.init.normal_(self.learnable_sink_token, mean=0.0, std=0.02)

        for layer in self.layers.values():
            if layer is not None:
                layer.init_weights()
        if self.norm is not None:
            self.norm.reset_parameters()

        final_out_std = self.model_args.hidden_dim**-0.5
        cutoff_factor = 3
        if self.output is not None:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=final_out_std,
                a=-cutoff_factor * final_out_std,
                b=cutoff_factor * final_out_std,
            )

        # Initialize fixed token embeddings
        if self.model_args.use_compression and self.fixed_tok_embeddings is not None:
            nn.init.normal_(self.fixed_tok_embeddings.weight)
            self.fixed_tok_embeddings.weight.requires_grad = False

    def disable_trainable_weights(self, layer_ids=None) -> None:
        if layer_ids is None:
            return

        with torch.no_grad():
            names = []
            for layer_module in self.layers.values():
                if layer_module.layer_id not in layer_ids:
                    continue
                
                # Walk every parameter in this layer (recursively)
                for name, param in layer_module.named_parameters(recurse=True):
                    if param.requires_grad:
                        names.append(name)
                        if "norm" not in name:
                            param.data.zero_()
                        param.requires_grad = False
                        # Zero the parameter value
                       # param.data.zero_()
                        # Also clear its gradient if present
                        if param.grad is not None:
                            param.grad.zero_()
                print(f"diabled all trainable params in layer {layer_module.layer_id}")
            
            return names



    def enable_all_layer_weights(self, name_list, layer_ids=None) -> None:
        if layer_ids is None:
            return

        with torch.no_grad():
            for layer_module in self.layers.values():
                if layer_module.layer_id not in layer_ids:
                    continue

                # Walk every parameter in this layer (recursively)
                for name, param in layer_module.named_parameters(recurse=True):
                    if name in name_list:
                        if param.requires_grad == False:
                            # if "norm" not in name:
                            #     param.data.zero_()
                            # Also clear its gradient if present
                            if param.grad is not None:
                                param.grad.zero_()
                        #if param.requires_grad:
                            # Zero the parameter value
                            param.requires_grad = True
                        # Also clear its gradient if present
                    # if param.grad is not None:
                    #     param.grad.zero_()

                print(f"enabled grads for params in layer {layer_module.layer_id}")
    def _precompute_freqs_cis(self) -> torch.Tensor:
        """
        Build the rotary-embedding frequency tensor for this model.

        Delegates to ``precompute_freqs_cis`` with the per-head dimension
        (``hidden_dim // n_heads``), the number of positions and ``rope_theta``.
        One extra position is requested when the learnable sink token is enabled,
        because that token is prepended to every sequence.
        """
        args = self.model_args
        head_dim = args.hidden_dim // args.n_heads
        # Need to compute until at least the max token limit for generation
        # TODO: explain in docs/composability.md why we removed the 2x
        # relaxing in our CP enablement PR
        if args.use_learnable_sink_token:
            num_positions = args.max_seq_len + 1  # +1 for learnable token
        else:
            num_positions = args.max_seq_len
        return precompute_freqs_cis(head_dim, num_positions, args.rope_theta)

    def forward(
        self,
        tokens: torch.Tensor,
        eos_id: Optional[int] = None,
        input_batch: Optional[torch.Tensor] = None,
        return_entropy: bool = False,
    ):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled.
                If pipeline parallelism is enabled, this will be the input token indices
                for the ranks on the first pipeline stage. This will be the activation of the
                previous pipeline stage if the current rank is not on the first stage.
            input_batch (torch.Tensor): The input batch read from the dataloader.
                This will always be the input batch regardless of the pipeline stage.
                This field is required for non-first PP stages to perform document
                masking attention (to analyze the boundary of the document).
            return_entropy (bool): Whether to compute and return entropy of final logits.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.
            If return_entropy is True, returns tuple of (logits, entropy).

        """
        # FlexAttention initialization if enabled
        if (
            self.model_args.use_flex_attn
            and not self.model_args.use_learnable_bias_softmax
        ):
            init_attention_mask(
                input_batch if input_batch is not None else tokens, eos_id=eos_id
            )

        # Token embeddings and fixed embeddings (if compression)
        if self.tok_embeddings is not None:
            token_embeds = self.tok_embeddings(tokens)
        else:
            token_embeds = tokens

        # Start with the learnable sink token if enabled
        if self.model_args.use_learnable_sink_token:
            batch_size = token_embeds.shape[0]
            sink_learnable = (
                self.learnable_sink_token.unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
            h = torch.cat([sink_learnable, token_embeds], dim=1)
        else:
            h = token_embeds

        # Add fixed token embeddings when compression is enabled, including sink's fixed embedding
        if self.model_args.use_compression and self.fixed_tok_embeddings is not None:
            fixed_for_tokens = self.fixed_tok_embeddings(tokens)
            if self.model_args.use_learnable_sink_token:
                # Prepend the fixed embedding for the sink token index
                sink_fixed = self.fixed_tok_embeddings.weight[self.sink_token_index]
                sink_fixed = (
                    sink_fixed.unsqueeze(0).unsqueeze(0).expand(h.shape[0], 1, -1)
                )
                fixed_total = torch.cat([sink_fixed, fixed_for_tokens], dim=1)
            else:
                fixed_total = fixed_for_tokens
            h = h + fixed_total

        # Initialize tokens tensor for compression if needed (prepend sink token index when applicable)
        if self.model_args.use_compression:
            if self.model_args.use_learnable_sink_token:
                sink_idx = torch.full(
                    (tokens.shape[0], 1),
                    int(self.sink_token_index),
                    dtype=tokens.dtype,
                    device=tokens.device,
                )
                tokens_with_sink = torch.cat([sink_idx, tokens], dim=1)
                tokens_tensor = tokens_with_sink.unsqueeze(2).float()
            else:
                tokens_tensor = tokens.unsqueeze(2).float()
        else:
            tokens_tensor = None

        # Process through transformer layers
        for layer in self.layers.values():
            # Apply transformer layer with compression handling
            h = layer(
                h,
                self.freqs_cis,
                rcv=self.rcv,
                fixed_tok_embeddings=self.fixed_tok_embeddings,
                tokens_tensor=tokens_tensor,
                parent_model=self,
            )

        # Apply final normalization and output projection
        h = self.norm(h) if self.norm else h
        output = self.output(h) if self.output else h

        if self.model_args.use_learnable_sink_token:
            # Remove learnable token from the output
            output = output[:, 1:, :]

        if return_entropy:
            with torch.no_grad():
                # Calculate entropy of the final logits
                # Detach from computational graph to avoid affecting gradients
                output_detached = output.detach()

                # Convert logits to probabilities using softmax
                probs = F.softmax(
                    output_detached, dim=-1
                )  # Shape: (batch_size, seq_len, vocab_size)

                # Calculate entropy: H = -sum(p * log(p))
                log_probs = F.log_softmax(output_detached, dim=-1)
                entropy_per_token = -(probs * log_probs).sum(
                    dim=-1
                )  # Shape: (batch_size, seq_len)

                # Average entropy across sequence and batch
                mean_entropy = entropy_per_token.mean()

            return output, mean_entropy

        return output

    def enable_attention_metrics_collection(self, enable: bool = True):
        """Enable/disable attention metrics collection for QK-clip."""
        self._collect_attention_metrics = enable
        for layer in self.layers.values():
            layer.attention._return_attention_metrics = enable

    def get_attention_metrics(self):
        """Get collected attention metrics from all layers."""
        return self._attention_metrics.copy()

    def clear_attention_metrics(self):
        """
        Clear stored attention metrics.

        The metrics dict is emptied in place rather than replaced, so every
        reference to ``self._attention_metrics`` sees the cleared state. Copies
        obtained earlier from ``get_attention_metrics`` are not affected.
        """
        self._attention_metrics.clear()

    def apply_qk_clip_rescaling(self, threshold: float = None, alpha: float = None):
        """Apply QK-clip rescaling based on collected metrics."""
        if threshold is None:
            threshold = self.model_args.qk_clip_threshold
        if alpha is None:
            alpha = self.model_args.qk_clip_alpha

        if not self._attention_metrics:
            return {}

        rescaling_stats = {}

        for layer_id, metrics in self._attention_metrics.items():
            if "qk_row_max" not in metrics:
                continue

            # Compute per-head max across batch and tokens: (B, H, T) -> (H,)
            qk_row_max = metrics["qk_row_max"]  # (B, H, T)
            smax_per_head = qk_row_max.amax(dim=2).amax(dim=0)  # (H,)

            # Compute scaling factors
            gamma_h = torch.minimum(
                torch.ones_like(smax_per_head), threshold / (smax_per_head + 1e-12)
            )

            # Only apply rescaling if any head exceeds threshold
            needs_rescaling = (smax_per_head > threshold).any()

            if needs_rescaling:
                layer = self.layers[str(layer_id)]
                qk_clip_rescale_(layer.attention, gamma_h, alpha)

                rescaling_stats[layer_id] = {
                    "smax_per_head": smax_per_head.cpu().tolist(),
                    "gamma_h": gamma_h.cpu().tolist(),
                    "max_exceeded": smax_per_head.max().item(),
                    "heads_clipped": (gamma_h < 1.0).sum().item(),
                }

        return rescaling_stats

    def regularize_weights(self) -> None:
        """Apply compression to model weights to maintain them in a compressed form."""
        if (
            not self.model_args.use_compression
            or self.model_args.compression_rate <= 1
            or self.rcv is None
        ):
            return

        with torch.no_grad():
            # Apply compression to attention and feed-forward weights in each layer
            for layer_module in list(self.layers.values())[:-1]:
                if not isinstance(layer_module, TransformerBlock):
                    continue

                # Compress attention weights
                if hasattr(layer_module.attention, "wo"):
                    weight = layer_module.attention.wo.weight
                    device = weight.data.device
                    rcv_device = self.rcv.to(device).contiguous()
                    compressed = rcv_device @ (rcv_device.T @ weight.data)
                    weight.data = compressed.contiguous()

                # Compress feed-forward weights
                if hasattr(layer_module.feed_forward, "w2"):
                    weight = layer_module.feed_forward.w2.weight
                    device = weight.data.device
                    rcv_device = self.rcv.to(device).contiguous()
                    compressed = rcv_device @ (rcv_device.T @ weight.data)
                    weight.data = compressed.contiguous()

    def get_weights(self) -> torch.Tensor:
        """
        Get a concatenated matrix of weights from the first layer's attention and feedforward weights.
        This is used to compute a single compression vector for all operations.

        Returns:
            torch.Tensor: Concatenated weight matrix
        """
        w1 = list(self.layers.values())[0].attention.wo.weight.data.clone()
        w2 = list(self.layers.values())[0].feed_forward.w2.weight.data.clone()
        weights = [w1, w2]

        weight_mat = torch.cat(weights, dim=1)

        return weight_mat

    def get_rcv(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Derive the shared compression basis from an SVD of ``self.get_weights()``.

        The left singular vectors ``U`` are split column-wise at
        ``rank = int(matrix.shape[0] // compression_rate)``: the leading ``rank``
        columns form the retained basis and the remaining columns the rest.
        Both halves are run through repeated QR factorizations and cast to
        float32 before being returned.

        Value-dependent sanity checks (NaN/Inf, norm, singular value spectrum,
        condition number) are skipped when the weights are meta tensors.

        Setting the environment variable ``TT_PLOT_SINGULAR_VALUES=1`` saves a
        semilog plot of the singular values as ``singular_values.png`` inside
        ``TT_SINGULAR_VALUES_DIR`` (default: current directory). Plotting is
        best-effort; failures are logged and never raised.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(U_r, U_c)`` where ``U_r`` holds
            the first ``rank`` left singular vectors and ``U_c`` the remaining ones

        Raises:
            ValueError: If the weight matrix contains NaN/Inf values or has a
                zero or infinite norm
            RuntimeError: If the SVD computation fails
        """
        matrix = self.get_weights()
        rank = int(matrix.shape[0] // self.model_args.compression_rate)

        # Meta tensors carry no values, so every value-dependent check below is
        # skipped for them (announced once rather than once per check).
        has_values = not matrix.is_meta
        if has_values:
            if torch.isnan(matrix).any() or torch.isinf(matrix).any():
                raise ValueError("Input matrix contains NaNs or Infs.")
            norm = torch.linalg.norm(matrix)
            if norm == 0 or torch.isinf(norm):
                logger.error(f"Input matrix norm is {float(norm)}")
                raise ValueError("Input matrix has zero or infinite norm.")
        else:
            logger.info("Skipping NaN/Inf checks for meta tensors")

        try:
            U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)
        except RuntimeError as e:
            raise RuntimeError(f"SVD computation failed: {e}") from e

        if has_values and S.numel() > 0:
            # Expected number of non-negligible singular values. This is only a
            # diagnostic, so fall back to the truncation rank when model_args
            # cannot provide a usable hidden_dim / compression_rate.
            try:
                k = int(
                    self.model_args.hidden_dim
                    // max(1, int(self.model_args.compression_rate))
                )
            except (AttributeError, TypeError, ValueError, OverflowError):
                k = rank

            # Singular values at or below this tolerance are counted as zero.
            eps = torch.finfo(S.dtype).eps
            tol = max(matrix.shape) * S.max() * eps
            num_gt_tol = int((S > tol).sum().item())
            if num_gt_tol == k:
                logger.info(
                    f"SVD spectrum check passed: top-{k} singular values > tol ({float(tol):.3e}); rest ~0"
                )
            else:
                logger.warning(
                    f"SVD spectrum check mismatch: {num_gt_tol} singular values > tol ({float(tol):.3e}); expected {k}"
                )

            # Optional: dump/plot singular values when requested via env vars
            if os.environ.get("TT_PLOT_SINGULAR_VALUES", "0") == "1":
                try:
                    import matplotlib.pyplot as plt  # type: ignore

                    s_cpu = S.detach().float().cpu()
                    out_dir = os.environ.get("TT_SINGULAR_VALUES_DIR", ".")
                    os.makedirs(out_dir, exist_ok=True)
                    out_path = os.path.join(out_dir, "singular_values.png")

                    fig = plt.figure(figsize=(8, 4))
                    try:
                        plt.semilogy(
                            s_cpu.numpy(), marker="o", linestyle="none", markersize=2
                        )
                        # Vertical marker between index k-1 and k
                        plt.axvline(x=k - 0.5, color="r", linestyle="--", linewidth=1)
                        plt.title("Singular values (semilogy)")
                        plt.xlabel("Index")
                        plt.ylabel("Sigma")
                        plt.tight_layout()
                        plt.savefig(out_path)
                    finally:
                        # Release the figure even if drawing or saving fails
                        plt.close(fig)
                    logger.info(f"Saved singular value plot to {out_path}")
                except Exception as plot_ex:
                    # Deliberately broad: plotting must never abort the computation
                    logger.warning(f"Failed to plot singular values: {plot_ex}")

            # A zero smallest singular value yields inf here, which also warns.
            condition_number = S.max() / S.min()
            if condition_number > 1e12:
                logger.warning("High condition number, results may be inaccurate.")

        if has_values and (
            torch.isnan(U).any() or torch.isnan(S).any() or torch.isnan(Vh).any()
        ):
            logger.warning("NaNs detected in SVD output.")

        U_r = U[:, :rank]
        U_c = U[:, rank:]

        # Each half is replaced by the Q factor of its own QR factorization,
        # 100 times over.
        # NOTE(review): presumably a re-orthonormalization safeguard; the pass
        # count is kept unchanged - confirm whether a single pass would suffice.
        qr_passes = 100
        for _ in range(qr_passes):
            U_r = torch.linalg.qr(U_r)[0]
        for _ in range(qr_passes):
            U_c = torch.linalg.qr(U_c)[0]

        return U_r.float(), U_c.float()

    def copy_embedding_weights(self, data: torch.Tensor) -> None:
        """
        Load ``data`` into ``fixed_tok_embeddings`` and freeze that table.

        Nothing happens when compression is disabled or there is no fixed
        embedding table. Row counts must either match, or differ by exactly one
        row relative to ``self.vocab_size``; in the latter case only the first
        ``vocab_size`` rows are transferred.

        Args:
            data (torch.Tensor): Embedding weights to copy

        Raises:
            RuntimeError: If the source/target row counts are not one of the
                supported combinations
        """
        if not self.model_args.use_compression or self.fixed_tok_embeddings is None:
            return

        source = data.to(self.fixed_tok_embeddings.weight.data.device).contiguous()
        destination = self.fixed_tok_embeddings.weight.data
        target_rows = destination.shape[0]
        src_rows = source.shape[0]

        if src_rows != target_rows:
            vocab = self.vocab_size
            if target_rows == vocab + 1 and src_rows == vocab:
                # Target has one extra row (the "sink" row per the original
                # note): write only the vocab rows and leave the last one as is.
                destination = destination[:vocab]
            elif target_rows == vocab and src_rows == vocab + 1:
                # Source has one extra row: drop it.
                source = source[:vocab]
            else:
                raise RuntimeError(
                    "copy_embedding_weights shape mismatch: "
                    f"src_rows={src_rows}, target_rows={target_rows}, "
                    f"vocab_size={self.vocab_size}"
                )

        # `destination` is either the full table or a view into it, so this
        # writes into the embedding weights in place.
        destination.copy_(source)
        self.fixed_tok_embeddings.weight.requires_grad = False

    def regularize_embeddings(self) -> None:
        """
        Re-apply the compression basis to the token embedding weights.

        The weights ``W`` of ``tok_embeddings`` are replaced by
        ``(rcv @ (rcv.T @ W.T)).T``, made contiguous. Per the original note this
        is intended to run after each optimization step.

        The call is a no-op when compression is disabled, the compression rate
        is at most 1, ``tok_embeddings`` or ``rcv`` is missing, or the first
        entry of ``self.layers`` is not a ``TransformerBlock``.
        """
        args = self.model_args
        if not args.use_compression or args.compression_rate <= 1:
            return
        if self.tok_embeddings is None or self.rcv is None:
            return

        with torch.no_grad():
            if not isinstance(list(self.layers.values())[0], TransformerBlock):
                return

            embedding_weight = self.tok_embeddings.weight
            # The basis has to live on the same device as the embeddings
            basis = self.rcv.to(embedding_weight.data.device).contiguous()

            # Embeddings are transposed going in and coming out, unlike the
            # layer-weight compression which multiplies the weight directly.
            compressed_t = basis @ (basis.T @ embedding_weight.data.T)
            embedding_weight.data = compressed_t.T.contiguous()

    def get_compression_energy_loss_stats(self) -> dict[str, float]:
        """
        Collect the compression energy loss reported by each decompressing layer.

        A layer is reported when its ``needs_decompression`` attribute is truthy
        and it exposes a callable ``get_compression_energy_loss``. Only layers
        whose ``compression_measurements_count`` is positive contribute to the
        average.

        Returns:
            dict: Empty when ``track_compression_energy`` is off. Otherwise one
            ``layer_<id>_energy_loss`` entry per reported layer, plus
            ``average_energy_loss`` (0.0 when no layer contributed) and
            ``active_compression_layers`` (number of contributing layers)
        """
        # With tracking off, return nothing so callers do not log zeros
        if not self.model_args.track_compression_energy:
            return {}

        stats = {}
        counted_losses = []

        for layer_id, layer in self.layers.items():
            # Attribute probing instead of isinstance, so wrapped layers work too
            if not getattr(layer, "needs_decompression", False):
                continue
            loss_getter = getattr(layer, "get_compression_energy_loss", None)
            if not callable(loss_getter):
                continue

            loss_ratio = loss_getter()
            stats[f"layer_{layer_id}_energy_loss"] = loss_ratio
            if getattr(layer, "compression_measurements_count", 0) > 0:
                counted_losses.append(loss_ratio)

        active_layers = len(counted_losses)
        if active_layers > 0:
            stats["average_energy_loss"] = sum(counted_losses, 0.0) / active_layers
        else:
            stats["average_energy_loss"] = 0.0
        stats["active_compression_layers"] = active_layers

        return stats

    def reset_compression_energy_stats(self) -> None:
        """
        Reset compression energy loss statistics on every layer that supports it.

        Layers exposing ``reset_compression_energy_stats`` have it called; the
        others are skipped. Every step is traced at debug level through the
        torchtitan logger.
        """
        from torchtitan.tools.logging import logger

        logger.debug(
            f"Resetting compression energy stats for {len(self.layers)} layers"
        )
        for layer_id, layer in self.layers.items():
            logger.debug(f"Layer {layer_id} type is {type(layer).__name__}")

            if not hasattr(layer, "reset_compression_energy_stats"):
                logger.debug(
                    f"Layer {layer_id} does not have reset_compression_energy_stats method"
                )
                continue

            logger.debug(f"Calling reset on layer {layer_id}")
            layer.reset_compression_energy_stats()

    def store_uncompressed_grad(self, layer_id: int, grad: torch.Tensor) -> None:
        """Remember the uncompressed gradient of a layer, detached from autograd.

        Any gradient previously stored under ``layer_id`` is overwritten.

        Args:
            layer_id: Layer identifier used as the storage key
            grad: Gradient tensor; per the original note it is shaped
                [B, S, hidden_dim] and taken wrt block layer_id's output
                (no compression)
        """
        detached_grad = grad.detach()
        self.uncompressed_upstream_grads[layer_id] = detached_grad

    def store_w2_stable_rank(self, layer_id: int, value: float) -> None:
        """Fold one w2 stable rank sample into the layer's running average.

        Updates the per-layer sum and sample count, then refreshes
        ``stable_rank_w2[layer_id]`` with their ratio.

        Args:
            layer_id: Layer identifier used as the key in all three mappings
            value: New stable rank sample; converted with ``float()``
        """
        sums = self.w2_stable_rank_sum
        running_sum = sums.get(layer_id, 0.0) + float(value)
        sums[layer_id] = running_sum

        counts = self.w2_stable_rank_count
        sample_count = counts.get(layer_id, 0) + 1
        counts[layer_id] = sample_count

        self.stable_rank_w2[layer_id] = running_sum / sample_count

    def reset_w2_stable_rank_stats(self) -> None:
        """Empty the w2 stable rank accumulators and the stored uncompressed gradients.

        The containers are cleared in place, so references held elsewhere see
        the emptied state as well.
        """
        for store_name in (
            "uncompressed_upstream_grads",
            "w2_stable_rank_sum",
            "w2_stable_rank_count",
            "stable_rank_w2",
        ):
            getattr(self, store_name).clear()

    @classmethod
    def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
        """
        Alternate constructor that builds the model straight from its arguments.

        Args:
            model_args (TransformerModelArgs): Model configuration arguments.

        Returns:
            Transformer: A new instance of ``cls`` constructed with ``model_args``.
        """
        model = cls(model_args)
        return model
