# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union, TYPE_CHECKING, Any, Tuple

import math
import warnings
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin

# Handle use_kernel_forward_from_hub import with fallback
try:
    from transformers.integrations import use_kernel_forward_from_hub
except Exception:
    try:
        from kernels import use_kernel_forward_from_hub
    except Exception:
        use_kernel_forward_from_hub = None

from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from models.hillama.configuration_hillama import HiLlamaConfig

from einops import rearrange

# Handle fla imports with fallback to native-sparse-attention directory
import os
import sys

try:
    # Preferred path: fla is importable from the current Python path
    # (installed package or an existing sys.path entry).
    from fla.layers.utils import pad_input, unpad_input
    from fla.ops.utils.index import prepare_lens_from_mask
    from fla.modules import RMSNorm, RotaryEmbedding
except (ImportError, ValueError):
    # Fallback: add the native-sparse-attention checkout (located relative to this
    # file) to sys.path and import each helper separately, substituting local
    # implementations where that is possible.
    # Note: ValueError can occur when fla models conflict with transformers model registration.
    native_sparse_attn_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "..", "native-sparse-attention")
    if os.path.exists(native_sparse_attn_dir) and native_sparse_attn_dir not in sys.path:
        sys.path.insert(0, native_sparse_attn_dir)

    # In native-sparse-attention, pad_input and unpad_input come from flash_attn.bert_padding.
    # NOTE(review): the fla.layers.utils and flash_attn.bert_padding versions of
    # unpad_input may not share a signature / return arity — confirm that callers
    # work with whichever one ends up bound here.
    try:
        from flash_attn.bert_padding import pad_input, unpad_input
    except ImportError:
        # flash_attn is not available either: provide local implementations.
        def unpad_input(hidden_states, attention_mask, unused_mask=None):
            """Fallback implementation of unpad_input.

            Keeps only the positions where ``attention_mask`` (plus ``unused_mask``
            when given) is non-zero, flattening the first two axes.

            Returns:
                (tokens, indices, cu_seqlens, max_seqlen_in_batch) where ``indices``
                are the kept positions in the flattened first two axes and
                ``cu_seqlens`` is the zero-prefixed int32 cumulative sum of the
                per-row kept counts.
            """
            all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
            seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        def pad_input(hidden_states, indices, batch, seqlen):
            """Fallback implementation of pad_input.

            Scatters the rows of ``hidden_states`` to ``indices`` of a zero-filled
            ``(batch * seqlen, ...)`` buffer and reshapes it to ``(batch, seqlen, ...)``.
            """
            dim = hidden_states.shape[1:]
            output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
            output[indices] = hidden_states
            return rearrange(output, "(b s) ... -> b s ...", b=batch)

    # prepare_lens_from_mask might not exist in native-sparse-attention; fall back
    # to a plain per-row count of the mask.
    try:
        from fla.ops.utils.index import prepare_lens_from_mask
    except (ImportError, ValueError):
        def prepare_lens_from_mask(attention_mask):
            """Prepare sequence lengths from attention mask (int32 sum over the last axis)."""
            return attention_mask.sum(dim=-1, dtype=torch.int32)

    # Import RMSNorm and RotaryEmbedding from fla.modules (may trigger model registration conflicts).
    try:
        from fla.modules import RMSNorm, RotaryEmbedding
    except (ImportError, ValueError):
        warnings.warn(
            "fla.modules import failed (possibly due to model registration conflict). "
            "Using fallback implementations where possible.",
            category=ImportWarning,
        )
        # nn.LayerNorm is not an RMS norm: it also subtracts the mean and carries a
        # bias. Prefer torch's native nn.RMSNorm (PyTorch >= 2.4) and keep LayerNorm
        # only as a last resort. Default eps values may differ from fla's RMSNorm,
        # so callers should pass eps explicitly.
        RMSNorm = getattr(nn, "RMSNorm", nn.LayerNorm)
        # RotaryEmbedding has no local fallback. Retry once (in case the first
        # failure was transient), but if it fails again raise a clear error instead
        # of a second, unexplained traceback.
        try:
            from fla.modules import RotaryEmbedding
        except (ImportError, ValueError) as err:
            raise ImportError(
                "RotaryEmbedding could not be imported from fla.modules; install "
                "flash-linear-attention (fla) or resolve the model registration conflict."
            ) from err

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning,
    )
    flash_attn_func = None
    flash_attn_varlen_func = None

logger = logging.get_logger(__name__)

###########################
# Output
###########################
@dataclass
class BaseModelArmOutputWithPast(BaseModelOutputWithPast):
    """
    Base-model output: [`BaseModelOutputWithPast`] extended with an optional `reg_losses` field.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reg_losses (`tuple(torch.FloatTensor)`, *optional*, returned when `output_reg_losses=True` is passed or when `config.output_reg_losses=True`):
            Auxiliary regularization losses produced during the forward pass, returned alongside the standard
            outputs above. NOTE(review): presumably one entry per layer (e.g. routing losses) — confirm against
            the model's `forward`, which is not part of this class.
    """

    # Redeclared parent fields keep their inherited positions, so the new
    # `reg_losses` field comes last in the output's field / tuple order.
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reg_losses: Optional[tuple[torch.FloatTensor, ...]] = None

@dataclass
class CausalLMArmOutputWithPast(CausalLMOutputWithPast):
    """
    Causal language model (autoregressive) output: [`CausalLMOutputWithPast`] extended with an optional
    `reg_losses` field.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reg_losses (`tuple(torch.FloatTensor)`, *optional*, returned when `output_reg_losses=True` is passed or when `config.output_reg_losses=True`):
            Auxiliary regularization losses produced during the forward pass, returned alongside `loss`.
            NOTE(review): presumably one entry per layer (e.g. routing losses), and whether they are already folded
            into `loss` is not visible here — confirm against the model's `forward`.
    """

    # Redeclared parent fields keep their inherited positions, so the new
    # `reg_losses` field comes last in the output's field / tuple order.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reg_losses: Optional[tuple[torch.FloatTensor, ...]] = None
###########################

###########################
# Hierarchical Router
###########################

import triton
import triton.language as tl

from ops.arm_kernel import parallel_arm
from ops.utils import argsort

# Try to import fused cache update kernel
try:
    from ops.fused_cache_update import fused_cache_update
    _has_fused_cache_update = True
except ImportError:
    _has_fused_cache_update = False

try:
    from accelerated_scan.warp import scan as accelerated_scan_fn
except ImportError:
    warnings.warn(
        "accelerated_scan is not installed. Parallel scan will be unavailable.",
        category=ImportWarning,
    )
    accelerated_scan_fn = None


# -----------------------------------------------------------------------------
# 1) Triton kernel: blocked, batched hierarchical beam search across all levels
# -----------------------------------------------------------------------------
"""
@triton.autotune(
    configs=[
        # triton.Config(kwargs={'BLOCK_TOKENS': 16}, num_warps=16, num_stages=1),     # 3.46s
        # triton.Config(kwargs={'BLOCK_TOKENS': 32}, num_warps=32, num_stages=1),      # 2.83s
        # triton.Config(kwargs={'BLOCK_TOKENS': 32}, num_warps=16, num_stages=2),       # 1.815
        # triton.Config(kwargs={'BLOCK_TOKENS': 64}, num_warps=32, num_stages=2),     # 2.02s
        triton.Config(kwargs={'BLOCK_TOKENS': 64}, num_warps=32, num_stages=3),     # 1.80s  64
        # triton.Config(kwargs={'BLOCK_TOKENS': 128}, num_warps=32, num_stages=3),  # 1.96s
        # triton.Config(kwargs={'BLOCK_TOKENS': 128}, num_warps=32, num_stages=1)  # 5.630046367645264
        # triton.Config(kwargs={'BLOCK_TOKENS': 256}, num_warps=32, num_stages=5), # 2.78s
        # triton.Config(kwargs={'BLOCK_TOKENS': 256}, num_warps=32, num_stages=3),  # 2.31s
    ],
    key=['B_S', 'beam']
)
"""


# Blocked, batched hierarchical beam search: one program handles BLOCK_TOKENS query rows.
#
# For each query row and each level `lvl`, the kernel:
#   1. scores the children of every live beam as
#      softmax(q @ W[offset + parent_code]) * parent_prob,
#   2. sorts all beam*C candidates with `argsort` (from ops.utils),
#   3. keeps the first K sorted candidates as the new beams.
# A beam's path code is parent_code * C + child, so after the last level
# idxs_beam_ptr holds one path code per surviving beam.
#
# Assumptions visible in the code below:
#   - beam, K, BLOCK_TOKENS, BLOCK_D, BLOCK_C and beam * stride_route_D are used as
#     tl.arange extents, so they must be powers of two.
#   - stride_route_D doubles as the child count C (offs_cand, reshape), and the
#     weight column stride is taken to be 1 (stride_route_C is never read).
#   - The child axis of the weight load is not masked, so BLOCK_C must equal C.
@triton.jit
def hierarchical_beam_search_blocked(
    # pointers
    q_ptr,  # [B_S, D] query rows, row stride stride_qD (dtype as passed by the caller)
    route_ptr,  # [sum_P, D, C] one [D, C] routing matrix per parent node
    offsets_ptr,  # int32* per-level index of the level's first parent matrix (read for lvl < L)
    counts_ptr,  # int32* [L] parents per level (loaded below but not used)
    scores_ptr,  # float32* [B_S, beam*C] scratch: sorted candidate scores
    idxs_ptr,  # int32*   [B_S, beam*C] scratch: sorted candidate ids
    scores_beam_ptr,  # float32* [B_S, beam] beam scores, rewritten every level
    idxs_beam_ptr,  # int32*   [B_S, beam] beam path codes, rewritten every level
    # scalars
    B_S,  # number of query rows
    D,  # feature dim; lanes of BLOCK_D beyond D are masked to 0
    C,  # children per node
    L,  # number of levels
    K: tl.constexpr,  # beams kept per level (the wrapper passes K == beam)
    stride_qD: tl.constexpr,  # row stride of q
    stride_route_L: tl.constexpr,  # elements per parent matrix (D*C in the wrapper)
    stride_route_D: tl.constexpr,  # row stride inside a parent matrix; also used as C
    stride_route_C: tl.constexpr,  # not read: column stride is assumed to be 1
    beam: tl.constexpr,  # beam width
    # compile-time tiles
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    # token block: query rows handled by this program
    block_id = tl.program_id(0)
    offs_t = tl.arange(0, BLOCK_TOKENS)
    token_ids = block_id * BLOCK_TOKENS + offs_t  # [T]
    valid_t = token_ids < B_S  # [T]

    # 1) load query tile [T, BLOCK_D]; padded D lanes and rows past B_S read as 0
    offs_d = tl.arange(0, BLOCK_D)
    mask_q = (offs_d[None, :] < D) & (valid_t[:, None])
    ptrs_q = q_ptr + token_ids[:, None] * stride_qD + offs_d[None, :]
    q_tile = tl.load(ptrs_q, mask=mask_q, other=0.0)  # [T, BLOCK_D]

    # 2) init beam state [T, beam]
    offs_beam = tl.arange(0, beam)
    offs_bucket = tl.arange(0, BLOCK_C)
    offs_cand = tl.arange(0, beam * stride_route_D)  # candidate id = slot * C + child
    # every beam starts from parent 0
    beam_parents = tl.zeros([BLOCK_TOKENS, beam], dtype=tl.int32)
    beam_probs = tl.where(valid_t[:, None] & (offs_beam[None, :] == 0), 1.0, 0.0)

    # Slot 0 of a valid row keeps prob 1.0; every other slot (and every slot of a
    # row past B_S) becomes -1.0. All slots share parent 0 at the first level, so
    # the negative prior pushes their duplicate candidates below slot 0's,
    # assuming `argsort` orders scores descending.
    # NOTE(review): if beam > C, those duplicate paths still fill the remaining
    # first-level slots — confirm this is acceptable.
    zero_b = tl.arange(0, beam) == 0  # [beam]
    init_mask = valid_t[:, None] & zero_b[None, :]
    beam_probs = tl.where(
        init_mask, beam_probs, (tl.zeros((beam,), dtype=tl.float32) - 1)[None, :]
    )

    # row stride of the candidate scratch buffers (matches offs_cand only when
    # stride_route_D == C)
    MAX_CAND = beam * C

    # 3) loop levels
    for lvl in range(L):
        P_l = tl.load(counts_ptr + lvl)  # not used below
        offset = tl.load(offsets_ptr + lvl)

        # compute scores
        prev_p = beam_probs  # [T, beam] path probability of each live beam
        parent_b = beam_parents  # [T, beam] path code = parent index within this level
        baseW = route_ptr + (offset + parent_b) * stride_route_L  # [T, beam]
        offs_d = tl.arange(0, BLOCK_D)
        mask_w = (offs_d[None, None, :, None] < D) & valid_t[:, None, None, None]
        ptrs_w = (
            baseW[:, :, None, None]
            + offs_bucket[None, None, None, :]
            + (offs_d[None, None, :, None] * stride_route_D)
        )
        w_tile = tl.load(
            ptrs_w, mask=mask_w, other=0.0
        )  # [T, beam, BLOCK_D, BLOCK_C]

        # child logits: reduce q * W over the feature axis
        dot = tl.sum(q_tile[:, None, :, None] * w_tile, axis=2)  # [T, beam, C]

        # numerically stable softmax over children, scaled by the parent's probability
        dot = tl.cast(dot, tl.float32)
        m = tl.max(dot, axis=2, keep_dims=True)
        ex = tl.exp(dot - m)
        prob = ex / (tl.sum(ex, axis=2, keep_dims=True) + 1e-9)
        sc = prob * prev_p[:, :, None]

        # flatten to [T, beam*C]: candidate id = slot * C + child
        sc = tl.reshape(sc, (BLOCK_TOKENS, beam * stride_route_D))

        # ——— sort all beam*C candidates (vectorized) ———
        N = beam * stride_route_D  # not used below
        ids = (
            tl.zeros([BLOCK_TOKENS, beam * stride_route_D], dtype=tl.int64)
            + offs_cand[None, :]
        )  # immediately overwritten by the next statement
        ids = tl.broadcast_to(offs_cand[None, :], (BLOCK_TOKENS, beam * stride_route_D))
        # tiny id-dependent offset so equal scores get a deterministic order
        sc = sc - tl.cast(ids, tl.float32) * 1e-12
        # rows past B_S are forced to -inf
        sc = tl.where(valid_t[:, None], sc, float("-inf"))
        new_sc, new_ids = argsort(sc, ids)

        # Spill the sorted candidates to global scratch so only the first K
        # columns can be read back (Triton tensors cannot be sliced directly).
        # NOTE(review): the [T, beam*C] store and the [T, K] reload below use
        # differently shaped tiles with no barrier in between; verify whether a
        # tl.debug_barrier() is needed for cross-thread visibility.
        ptr_s = (
            scores_ptr + token_ids[:, None] * (MAX_CAND) + offs_cand[None, :]
        )
        tl.store(ptr_s, new_sc, mask=valid_t[:, None])
        ptr_i = (
            idxs_ptr + token_ids[:, None] * (MAX_CAND) + offs_cand[None, :]
        )
        tl.store(
            ptr_i, new_ids, mask=valid_t[:, None]
        )

        # extract top-K
        offs_topk = tl.arange(0, K)
        ptr_s = scores_ptr + token_ids[:, None] * MAX_CAND + offs_topk
        ptr_i = idxs_ptr + token_ids[:, None] * MAX_CAND + offs_topk
        top_sc = tl.load(ptr_s, mask=valid_t[:, None], other=0.0)
        top_ix = tl.load(ptr_i, mask=valid_t[:, None], other=0)

        p = top_ix // C   # [T, K]  (which previous beam slot)
        c = top_ix % C    # [T, K]  (child id)

        # store beam scores (rewritten each level; the last level's values remain)
        ptr_s_beam = scores_beam_ptr + token_ids[:, None] * beam + offs_topk
        tl.store(ptr_s_beam, top_sc, mask=valid_t[:, None])

        # gather previous path codes from the in-register beam_parents
        # (not from idxs_beam_ptr)
        offs_beam = tl.arange(0, beam)  # [beam]

        # parent_code[t,k] = beam_parents[t, p[t,k]], via one-hot select + sum
        parent_code = tl.sum(
            tl.where(
                offs_beam[None, None, :] == p[:, :, None],
                beam_parents[:, None, :],
                0
            ),
            axis=2
        ).to(tl.int32)  # [T, K]

        # extend each selected path by its child
        id = parent_code * C + c        # [T, K]

        # write the updated paths to global (after the last level this is the output)
        ptr_i_beam = idxs_beam_ptr + token_ids[:, None] * beam + offs_topk
        tl.store(ptr_i_beam, id, mask=valid_t[:, None])

        # update the register state for next level
        beam_probs = top_sc
        beam_parents = id


# -----------------------------------------------------------------------------
# 2) Python wrapper (launches the kernel above)
# -----------------------------------------------------------------------------
def hierarchical_search_triton(
    q, route_flat, level_offsets, parent_counts, beam_width0, num_levels
):
    """Run the hierarchical beam-search Triton kernel over every query row.

    Each of the ``B * S`` query rows descends ``num_levels`` levels of a routing
    tree. At every level the kernel scores the children of each live beam with a
    softmax over ``q @ W_parent`` times the parent's probability, sorts the
    candidates, and keeps the best ``beam`` of them.

    Args:
        q: ``[B, S, D]`` query tensor on the kernel's device. ``D`` may be any size:
            the feature tile is padded to a power of two and masked.
        route_flat: ``[sum_P, D, C]`` routing weights, one ``[D, C]`` matrix per
            parent node. ``C`` must be a power of two. Made contiguous here,
            because the kernel hard-codes strides ``(D*C, C, 1)``.
        level_offsets: per-level index, in units of parent matrices, of the
            level's first parent in ``route_flat`` (read for levels
            ``0..num_levels-1``). Cast to int32 and moved to ``q.device``.
        parent_counts: parents per level. Forwarded to the kernel, which
            currently does not use it.
        beam_width0: requested beam width (>= 1). The kernel runs with the next
            power of two, and the result is trimmed back to this width.
        num_levels: number of tree levels to descend.

    Returns:
        ``int32`` tensor of shape ``[B * S, beam_width0]``. Each row holds the
        path codes of the surviving beams (``code = parent_code * C + child`` at
        every level), in the order produced by the kernel's sort.

    Raises:
        ValueError: if ``C = route_flat.shape[-1]`` is not a power of two (the
            kernel tiles the child axis with ``tl.arange``).
    """

    def next_pow2(x):
        # Smallest power of two >= x (1 for x <= 1), using exact integer math.
        return 1 if x <= 1 else 1 << (math.ceil(x) - 1).bit_length()

    B, S, D = q.shape
    C = route_flat.shape[-1]
    if C < 1 or C & (C - 1):
        raise ValueError(
            f"route_flat.shape[-1] (C={C}) must be a power of two for the Triton kernel"
        )

    beam = next_pow2(beam_width0)
    L, K = num_levels, beam
    B_S = B * S
    device = q.device

    if B_S == 0:
        # No rows to route: skip the (empty-grid) launch, keep the usual result shape.
        return torch.zeros((0, beam), dtype=torch.int32, device=device)[..., :beam_width0]

    # The kernel masks the feature axis (offs_d < D), so the D tile may be padded
    # to the next power of two. The child axis is unmasked: BLOCK_C must equal C.
    BLOCK_D = next_pow2(D)
    BLOCK_C = C

    # flatten + contiguous: the kernel relies on the strides passed below
    q_flat = q.contiguous().view(B_S, D)
    route_flat = route_flat.contiguous()
    off_ptr = level_offsets.to(device=device, dtype=torch.int32).contiguous()
    cnt_ptr = parent_counts.to(device=device, dtype=torch.int32).contiguous()

    # global scratch for the sorted candidates of each level, plus the beam outputs;
    # zero-filled so num_levels == 0 yields all-zero path codes
    scores = torch.zeros((B_S, beam * C), dtype=torch.float32, device=device)
    idxs = torch.zeros((B_S, beam * C), dtype=torch.int32, device=device)
    scores_beam = torch.zeros((B_S, beam), dtype=torch.float32, device=device)
    idxs_beam = torch.zeros((B_S, beam), dtype=torch.int32, device=device)

    # element strides of the contiguous q [B_S, D] and route_flat [sum_P, D, C]
    s_qD = D
    s_rL, s_rD, s_rC = D * C, C, 1

    # Fixed tile size; earlier timing notes showed 128 and 512 tokens per
    # program to be slower.
    BLOCK_TOKENS = 32
    grid = ((B_S + BLOCK_TOKENS - 1) // BLOCK_TOKENS,)
    hierarchical_beam_search_blocked[grid](
        q_flat,
        route_flat,
        off_ptr,
        cnt_ptr,
        scores,
        idxs,
        scores_beam,
        idxs_beam,
        B_S,
        D,
        C,
        L,
        K,
        s_qD,
        s_rL,
        s_rD,
        s_rC,
        beam,
        BLOCK_TOKENS,
        BLOCK_D=BLOCK_D,
        BLOCK_C=BLOCK_C,
    )
    # The kernel runs a power-of-two beam; keep only the requested width.
    return idxs_beam[..., :beam_width0]


"""
class HierarchicalRouter(nn.Module):
    # aim to learn generate hierachical trees with more layers #
    def __init__(self, input_dim, hidden_dim, num_levels=3, num_buckets_per_level=4, beam_width=4, dtype=torch.float32):
        super(HierarchicalRouter, self).__init__()
        self.num_levels = num_levels
        self.num_buckets_per_level = num_buckets_per_level # *self.num_buckets_per_level
        # Use single nn.Linear with all levels stored in one weight matrix
        # Total output features = sum(num_buckets_per_level**(l+1) for l in range(num_levels))
        # Precompute offsets for each level in the weight matrix
        self.Ps = [num_buckets_per_level**l for l in range(num_levels)]
        total_output_features = sum(num_buckets_per_level**(l+1) for l in range(num_levels))
        self.route = nn.Linear(input_dim, total_output_features, bias=False, dtype=dtype)
        # self.route_test = nn.ParameterList([nn.Parameter(torch.rand(num_buckets_per_level**l,input_dim, num_buckets_per_level)) for l in range(self.num_levels)])
        
        # Precompute offsets for each level in the weight matrix
        self.route_offsets = []
        offset = 0
        for l in range(num_levels):
            self.route_offsets.append(offset)
            offset += num_buckets_per_level**(l+1)
        self.beam_width = beam_width
        self.num_sample_per_bucket = 32
        self.epoch = 0
        self.counts = None
    
    def _get_route_weight(self, l):
        
        # Get route weight for level l in shape [num_buckets_per_level**l, inp_dim, num_buckets_per_level].
        # This slices and reshapes the single Linear weight matrix.
        
        # Get the slice for this level
        offset = self.route_offsets[l]
        size = self.num_buckets_per_level**(l+1)
        # self.route.weight shape: [total_output_features, input_dim]
        # Slice to get this level's weights: [size, input_dim]
        weight_slice = self.route.weight[offset:offset+size, :]  # [num_buckets_per_level**(l+1), input_dim]
        # Reshape to [num_buckets_per_level**l, num_buckets_per_level, input_dim]
        num_parents = self.num_buckets_per_level**l
        weight = weight_slice.view(num_parents, self.num_buckets_per_level, -1)
        # Transpose to [num_buckets_per_level**l, input_dim, num_buckets_per_level]
        weight = weight.transpose(1, 2)
        return weight
    

    def get_key_assignment(self, k, tau):
        batch_size, seq_len, inp_dim = k.size()
        b_s = batch_size * seq_len
        k = k.view(b_s, 1, inp_dim)
        route_prob_all = torch.ones((b_s,1,1), device=k.device)
        k_sorted = k.view(batch_size, 1, seq_len, inp_dim)
        original_ind = torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        for l in range(self.num_levels):
            route_weight = self._get_route_weight(l)  # [num_buckets_per_level**l, inp_dim, num_buckets_per_level]
            route_weight = route_weight.to(k_sorted.dtype)  # Ensure dtype matches
            logits = k_sorted @ route_weight  # [batch_size, 1, seq_len, num_buckets_per_level]
            route_prob_l_hard = F.gumbel_softmax(logits, dim=-1, tau=tau, hard=True) 
            _, route_max_ind = torch.max(route_prob_l_hard, dim=-1,keepdim=True) 

            num_buckets = self.num_buckets_per_level**(l+1)  
            
            ind_sorted, shift_ind= torch.sort(route_max_ind.squeeze(-1), dim=-1, stable=True)
            original_ind = torch.gather(original_ind, -1,shift_ind)
            k_sorted = torch.gather(k.view(batch_size, seq_len,-1), 1, original_ind.view(batch_size,seq_len,1).expand(-1, -1, k.shape[-1]))

            ind_sorted = ind_sorted.view(batch_size, num_buckets, seq_len // num_buckets)
            original_ind = original_ind.view(batch_size, num_buckets, seq_len // num_buckets)

            k_sorted = k_sorted.view(batch_size, num_buckets, seq_len // num_buckets, -1) #.detach()

        return original_ind
    
    def test(self, q, k, beam_width=4,tau=0.01):
        D = q.shape[-1]
        key_assignment = self.get_key_assignment(k, tau)
        # route_flat_test = torch.cat([param for param in self.route_test ], dim=0)
        # Concatenate route weights from all levels
        route_flat = self.route.weight.view(-1, self.num_buckets_per_level,D).transpose(1, 2).contiguous()
        if self.counts == None:
            self.route_offsets = torch.tensor([0] + self.route_offsets, dtype=torch.int32,device=q.device)
            self.counts  = torch.tensor(self.Ps, dtype=torch.int32,device=q.device)
        query_assignment_tri = hierarchical_search_triton(q, route_flat, self.route_offsets, self.counts, beam_width, self.num_levels)
        query_assignment = query_assignment_tri.view(q.shape[0],q.shape[1],-1)
        return query_assignment, key_assignment  
    
    def forward(self, q, k, tau=0.01,eps=1e-12):
        self.epoch+=1
        x = torch.cat([q, k], dim=0)
        batch_size, seq_len, inp_dim = x.size()
        b_s = batch_size * seq_len
        x = x.view(b_s, 1, inp_dim)
        loss = 0
        x_sorted = x.view(batch_size, 1, seq_len, inp_dim)
        original_ind = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        for l in range(self.num_levels):
            logits = x_sorted @ self.route[l]
            route_prob_l = torch.softmax(logits, dim=-1) # p(li | l_i-1, ..., l_1, x)
            route_prob_l_hard = F.gumbel_softmax(logits, dim=-1, tau=tau, hard=True) 
            _, route_max_ind = torch.max(route_prob_l_hard, dim=-1,keepdim=True) # TODO: error here!!!

            route_prob_avg = torch.mean(route_prob_l_hard, -2)
            term1 = torch.mean(route_prob_avg* torch.log(route_prob_avg+eps))
            term2 = -torch.mean(route_prob_l* torch.log(route_prob_l+eps))
            mi_loss = term1 + term2 
            loss += mi_loss

            num_buckets = self.num_buckets_per_level**(l+1)  
            # Arm probabilities at each level
            ind_sorted, shift_ind= torch.sort(route_max_ind.squeeze(-1), dim=-1, stable=True)
            original_ind = torch.gather(original_ind, -1,shift_ind)
            # No need to calculate prob: prob_sorted = torch.gather(prob, -1, original_ind)
            x_sorted = torch.gather(x.view(batch_size, seq_len,-1), 1, original_ind.view(batch_size,seq_len,1).expand(-1, -1, x.shape[-1]))
            # chunkwise
            # ind_sorted: sorted assignment results; original_ind: sorted results corresponding to original index 
            ind_sorted = ind_sorted.view(batch_size, num_buckets, seq_len // num_buckets)
            original_ind = original_ind.view(batch_size, num_buckets, seq_len // num_buckets)
            # q k all been sorted, so 
            x_sorted = x_sorted.view(batch_size, num_buckets, seq_len // num_buckets, -1)  # .detach()
            
        return loss

"""
class HierarchicalRouter(nn.Module):
    """Learned hierarchical (tree-structured) router for queries and keys.

    The tree has ``num_levels`` levels. At level ``l`` there are
    ``num_buckets_per_level**l`` parent nodes, each with ``num_buckets_per_level``
    children, so there are ``num_buckets_per_level**num_levels`` leaves. All levels
    share one ``nn.Linear`` weight matrix; level ``l`` owns the row slice starting
    at ``route_offsets[l]``.

    * :meth:`assign_keys` sorts keys into equally sized leaf buckets.
    * :meth:`assign_queries` runs a beam search over the tree and returns the
      most probable leaf indices for every query.
    * :meth:`forward` returns a routing regularisation loss on ``cat([q, k])``.

    Args:
        input_dim: Feature size of the routed vectors (last dim of ``q`` / ``k``).
        hidden_dim: Not used by this class.
        num_levels: Depth of the tree.
        num_buckets_per_level: Branching factor of every node.
        beam_width: Default beam width for :meth:`assign_queries`.
        dtype: dtype of the routing weight.
    """

    def __init__(self, input_dim, hidden_dim, num_levels=3, num_buckets_per_level=4, beam_width=4, dtype=torch.float32):
        super().__init__()
        self.num_levels = num_levels
        self.num_buckets_per_level = num_buckets_per_level
        # One nn.Linear holds every level: level l contributes
        # num_buckets_per_level**(l+1) output rows (one per (parent, child) pair).
        total_output_features = sum(num_buckets_per_level ** (l + 1) for l in range(num_levels))
        self.route = nn.Linear(input_dim, total_output_features, bias=False, dtype=dtype)
        # Starting row of each level's slice inside self.route.weight.
        self.route_offsets = []
        offset = 0
        for l in range(num_levels):
            self.route_offsets.append(offset)
            offset += num_buckets_per_level ** (l + 1)
        self.beam_width = beam_width
        self.num_sample_per_bucket = 32
        self.epoch = 0

        # Number of parent nodes at each level.
        self.Ps = [num_buckets_per_level ** l for l in range(num_levels)]
        self.counts = None

    def _get_route_weight(self, l):
        """Return level ``l``'s routing weight as ``[P_l, input_dim, num_buckets_per_level]``.

        ``P_l = num_buckets_per_level**l`` is the number of parents at level ``l``.
        The result is a reshaped view of the shared ``self.route.weight`` slice, so
        gradients flow back into the single Linear layer.
        """
        offset = self.route_offsets[l]
        size = self.num_buckets_per_level ** (l + 1)
        # self.route.weight: [total_output_features, input_dim] -> this level's rows.
        weight_slice = self.route.weight[offset:offset + size, :]  # [P_l * C, input_dim]
        num_parents = self.num_buckets_per_level ** l
        weight = weight_slice.view(num_parents, self.num_buckets_per_level, -1)  # [P_l, C, input_dim]
        return weight.transpose(1, 2)  # [P_l, input_dim, C]

    def assign_keys(self, k, tau=0.01):
        """Sort keys into equally sized leaf buckets, level by level.

        At each level every key picks a child of its current bucket via
        ``F.gumbel_softmax(..., hard=True)``. Keys are then stably sorted by that
        child index within their bucket, and each bucket is split into
        ``num_buckets_per_level`` equal chunks.

        Args:
            k: Tensor ``[batch, seq_len, input_dim]``. ``seq_len`` must be divisible
                by ``num_buckets_per_level ** num_levels``, because the chunking uses
                exact ``view`` calls.
            tau: Gumbel-softmax temperature.

        Returns:
            LongTensor ``[batch, num_leaves, seq_len // num_leaves]``. For each leaf
            bucket, it holds the original sequence positions of that bucket's keys.
        """
        batch_size, seq_len, inp_dim = k.size()
        b_s = batch_size * seq_len
        # reshape (not view) so non-contiguous inputs are accepted; the result is
        # then safe to .view below.
        k = k.reshape(b_s, 1, inp_dim)
        k_sorted = k.reshape(batch_size, 1, seq_len, inp_dim)
        original_ind = torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        for l in range(self.num_levels):
            route_weight = self._get_route_weight(l).to(k_sorted.dtype)  # [P_l, inp_dim, C]
            logits = k_sorted @ route_weight  # [batch, P_l, seq_len // P_l, C]
            route_prob_l_hard = F.gumbel_softmax(logits, dim=-1, tau=tau, hard=True)
            _, route_max_ind = torch.max(route_prob_l_hard, dim=-1, keepdim=True)

            num_buckets = self.num_buckets_per_level ** (l + 1)
            chunk = seq_len // num_buckets

            # Stable sort keeps equal-child keys in their current relative order.
            _, shift_ind = torch.sort(route_max_ind.squeeze(-1), dim=-1, stable=True)
            original_ind = torch.gather(original_ind, -1, shift_ind).view(batch_size, num_buckets, chunk)
            k_sorted = torch.gather(
                k.view(batch_size, seq_len, -1),
                1,
                original_ind.view(batch_size, seq_len, 1).expand(-1, -1, inp_dim),
            ).view(batch_size, num_buckets, chunk, -1).detach()

        return original_ind  # key_assignment

    def assign_queries(self, q, beam_width=None, tau=0.01):
        """Beam-search every query down the tree and return its top leaf indices.

        Args:
            q: ``[B, S, D]`` or ``[B, H, S, D]``; the 4-D form is flattened to
                ``[B*H, S, D]``.
            beam_width: Number of paths kept per query. Defaults to ``self.beam_width``.
            tau: Unused; kept for API symmetry with :meth:`assign_keys`.

        Returns:
            LongTensor ``[batch, seq_len, k]`` of leaf indices in
            ``[0, num_buckets_per_level**num_levels)``, ordered by decreasing path
            probability. ``k`` is at most ``beam_width``. When ``beam_width`` equals the
            number of leaves, every leaf index is returned in ascending order as an
            expanded (non-contiguous) view.
        """
        beam_width = self.beam_width if beam_width is None else beam_width
        if q.dim() == 4:
            q = q.flatten(0, 1)  # [B*H, S, D]
        batch_size, seq_len, inp_dim = q.size()

        # Shortcut: a beam that covers every leaf needs no search.
        total_buckets = self.num_buckets_per_level ** self.num_levels
        if beam_width == total_buckets:
            bucket_indices = torch.arange(total_buckets, device=q.device, dtype=torch.long)
            return bucket_indices.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1)

        if False:  # batch_size * seq_len >= 32768: empirically where Triton beats PyTorch
            # Disabled Triton path. NOTE(review): `hierarchical_search_triton` is not
            # defined in the visible part of this module.
            route_flat = torch.cat(
                [self._get_route_weight(l) for l in range(self.num_levels)], dim=0
            ).contiguous()  # [sum_P, D, C]
            offsets = [0]
            for p in self.Ps:
                offsets.append(offsets[-1] + p)
            route_offsets_triton = torch.tensor(offsets, dtype=torch.int32, device=q.device)
            counts = torch.tensor(self.Ps, dtype=torch.int32, device=q.device)
            query_assignment_tri = hierarchical_search_triton(
                q, route_flat, route_offsets_triton, counts, beam_width, self.num_levels
            )
            return query_assignment_tri.view(q.shape[0], q.shape[1], -1)

        q = q.reshape(batch_size, seq_len, inp_dim)

        # Beam state: cumulative path probability [B, S, beam] and the path index
        # reached so far [B, S, beam] (an empty placeholder before level 0).
        beam_probs = torch.ones((batch_size, seq_len, 1), device=q.device)
        beam_parents = torch.zeros((batch_size, seq_len, 1, 0), dtype=torch.long, device=q.device)
        for l in range(self.num_levels):
            if l == 0:
                # Level 0 has a single (virtual) root parent.
                active_parents = torch.zeros((batch_size, seq_len, 1), dtype=torch.long, device=q.device)
                num_active = 1
            else:
                # Paths found so far are the parent node ids at this level.
                active_parents = beam_parents
                num_active = min(beam_width, beam_parents.shape[-1])

            # Routing weights of each active parent: [B, S, beam, D, C].
            route_weight = self._get_route_weight(l).to(q.dtype)  # [P_l, D, C]
            route_weights = torch.index_select(
                route_weight, dim=0, index=active_parents.reshape(-1)
            ).view(batch_size, seq_len, num_active, inp_dim, -1)

            q_expanded = q.unsqueeze(2).unsqueeze(-1)  # [B, S, 1, D, 1]
            logits = torch.sum(q_expanded * route_weights, dim=-2)  # [B, S, beam, C]

            # Joint probability of (path so far, next child), flattened to [B, S, beam*C].
            combined_probs = beam_probs.unsqueeze(-1) * torch.softmax(logits, dim=-1)
            combined_probs = combined_probs.view(batch_size, seq_len, -1)

            topk_probs, topk_indices = torch.topk(
                combined_probs, k=min(beam_width, combined_probs.size(-1)), dim=-1
            )
            beam_probs = topk_probs

            # Flat index -> (position in previous beam, child id).
            new_parents = topk_indices // self.num_buckets_per_level
            new_children = topk_indices % self.num_buckets_per_level

            if l == 0:
                beam_parents = new_children
            else:
                beam_parents = torch.gather(beam_parents, 2, new_parents)
                beam_parents = beam_parents * self.num_buckets_per_level + new_children

        return beam_parents  # query_assignment

    def test(self, q, k, beam_width=None, tau=0.01):
        """Return ``(assign_queries(q), assign_keys(k))``."""
        key_assignment = self.assign_keys(k, tau=tau)
        query_assignment = self.assign_queries(q, beam_width=beam_width, tau=tau)
        return query_assignment, key_assignment

    def forward(self, q, k, tau=0.01, eps=1e-12):
        """Compute the routing regularisation loss on ``cat([q, k], dim=0)``.

        For each level, two terms are added:

        * ``term1 = mean(p_avg * log p_avg)``. Here ``p_avg`` is the hard (Gumbel,
          straight-through) assignment averaged over the tokens of each bucket, so
          ``term1`` is the negative entropy of bucket usage.
        * ``term2 = -mean(p * log p)``. This is the entropy of the soft per-token
          routing distribution.

        If minimised, the loss favours balanced bucket usage with confident
        per-token routes. Tokens are then regrouped into equal chunks exactly as in
        :meth:`assign_keys`, using the same stable sort.

        Args:
            q, k: ``[*, seq_len, input_dim]`` tensors with matching ``seq_len``/``input_dim``.
            tau: Gumbel-softmax temperature.
            eps: Numerical floor inside the logs.

        Returns:
            Scalar loss tensor summed over levels.
        """
        self.epoch += 1
        x = torch.cat([q, k], dim=0)
        batch_size, seq_len, inp_dim = x.size()
        loss = 0
        x_sorted = x.reshape(batch_size, 1, seq_len, inp_dim)  # not detached: gradients reach the router
        original_ind = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        for l in range(self.num_levels):
            route_weight = self._get_route_weight(l).to(x_sorted.dtype)  # [P_l, inp_dim, C]
            logits = x_sorted @ route_weight  # [batch, P_l, seq_len // P_l, C]
            route_prob_l = torch.softmax(logits, dim=-1)  # p(l_i | l_{i-1}, ..., l_1, x)
            route_prob_l_hard = F.gumbel_softmax(logits, dim=-1, tau=tau, hard=True)
            _, route_max_ind = torch.max(route_prob_l_hard, dim=-1, keepdim=True)  # TODO: error here!!!

            route_prob_avg = torch.mean(route_prob_l_hard, -2)
            term1 = torch.mean(route_prob_avg * torch.log(route_prob_avg + eps))
            term2 = -torch.mean(route_prob_l * torch.log(route_prob_l + eps))
            mi_loss = term1 + term2
            loss += mi_loss

            num_buckets = self.num_buckets_per_level ** (l + 1)
            chunk = seq_len // num_buckets
            # stable=True so tie order (and therefore chunk membership) is
            # deterministic and matches assign_keys.
            _, shift_ind = torch.sort(route_max_ind.squeeze(-1), dim=-1, stable=True)
            original_ind = torch.gather(original_ind, -1, shift_ind).view(batch_size, num_buckets, chunk)
            x_sorted = torch.gather(
                x, 1, original_ind.view(batch_size, seq_len, 1).expand(-1, -1, inp_dim)
            ).view(batch_size, num_buckets, chunk, -1)

        return loss

###########################

class HiLlamaRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-feature scale.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``. The normalisation is
    done in float32 and cast back to the input dtype before the scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        HiLlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


# The import fallback at the top of this file sets `use_kernel_forward_from_hub` to
# None when neither `transformers.integrations` nor `kernels` provides it. Using it
# unconditionally as a decorator would then call `None("RMSNorm")` and fail at import
# time, so the hub kernel is only attached when the helper is available.
if use_kernel_forward_from_hub is not None:
    HiLlamaRMSNorm = use_kernel_forward_from_hub("RMSNorm")(HiLlamaRMSNorm)


class HiLlamaRotaryEmbedding(nn.Module):
    """Build RoPE cosine / sine tables for a batch of position ids.

    ``rope_type`` is taken from ``config.rope_scaling`` when that attribute is a dict.
    The ``"rope_type"`` key is preferred, then the legacy ``"type"`` key; otherwise
    ``"default"`` is used. The matching ``ROPE_INIT_FUNCTIONS`` entry supplies the
    inverse frequencies and the attention scaling factor.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: HiLlamaConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if isinstance(scaling_cfg, dict):
            # BC: "rope_type" was originally "type"
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        frequencies, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", frequencies, persistent=False)
        # Kept so dynamic RoPE updates can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        n_batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(n_batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()

        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = torch.matmul(freq_col.float(), pos_row.float()).transpose(1, 2)
            full_angles = torch.cat((angles, angles), dim=-1)
            cos = full_angles.cos() * self.attention_scaling
            sin = full_angles.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``, so for odd sizes the second half is
    one element longer.
    """
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((upper.neg(), lower), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``, after ``cos`` and
    ``sin`` gain a singleton axis at ``unsqueeze_dim`` so they broadcast against it.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin`. With
            `cos`/`sin` of shape [batch_size, seq_len, head_dim], use 1 for q/k laid out
            as [batch_size, heads, seq_len, head_dim] and 2 for
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


class HiLlamaMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        use_bias = config.mlp_bias
        # Creation order (gate, up, down) is kept so parameter init order is unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.up_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.down_proj = nn.Linear(inner, hidden, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1, like
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Maps (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). The input is returned
    unchanged when ``n_rep == 1``.
    """
    bsz, kv_heads, seq, dim = hidden_states.shape  # also enforces a 4-D input
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq, dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    Key/value heads are repeated ``module.num_key_value_groups`` times, the additive
    mask is sliced to the key length, softmax is taken in float32, and dropout is
    applied only while ``module.training`` is set.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has its axes 1 and 2 swapped
        relative to ``query`` and is made contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Keep only as many mask columns as there are keys.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    output = torch.matmul(probs, values).transpose(1, 2).contiguous()

    return output, probs


class HiLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with hierarchical routing.

    Queries and keys are routed by a :class:`HierarchicalRouter` (layer-specific, or
    shared and passed to :meth:`forward`). The routed key/value buckets, the full
    keys/values and a softmax gate from ``g_proj`` are then passed to
    ``parallel_arm`` (defined elsewhere).
    """

    def __init__(self, config: HiLlamaConfig, layer_idx: int, shared_params=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.causal = config.causal
        self.num_levels = config.num_levels
        self.beam_width = config.beam_width
        self.local_window_size = config.local_window_size

        # Optional learned sink; its weight is [num_key_value_heads, head_dim].
        if config.use_sink:
            self.sink = nn.Linear(self.head_dim, config.num_key_value_heads, bias=False)
        else:
            self.sink = None

        # Shared g_proj / router are passed to forward() as arguments rather than being
        # stored as module attributes, which avoids save_pretrained seeing shared tensors.
        self.use_shared_g_proj = shared_params is not None and hasattr(shared_params, 'shared_g_proj') and shared_params.shared_g_proj is not None

        # Only create layer-specific parameters if not using shared ones.
        if not self.use_shared_g_proj:
            # g_proj output per head: 3 gate logits with a sink, 2 without.
            g_proj_output_dim = config.num_attention_heads * (3 if config.use_sink else 2)
            self.g_proj = nn.Linear(config.hidden_size, g_proj_output_dim, bias=False)
            self.router = HierarchicalRouter(
                self.head_dim,
                self.head_dim,
                beam_width=config.beam_width,
                num_levels=config.num_levels,
                num_buckets_per_level=getattr(config, 'num_buckets_per_level', 4),
            )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        shared_g_proj: Optional[nn.Module] = None,
        shared_router: Optional[nn.Module] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, None, torch.Tensor]:
        """Run routed attention.

        Args:
            hidden_states: ``[batch, seq_len, hidden_size]``.
            position_embeddings: ``(cos, sin)`` RoPE tables.
            attention_mask: Not used by the routed path below.
            past_key_values: Optional cache, updated with this layer's keys/values.
            cache_position: Forwarded to the cache update.
            shared_g_proj / shared_router: Required when this layer was built with
                shared parameters (``self.use_shared_g_proj``).

        Returns:
            ``(attn_output, None, reg_loss)``. ``attn_output`` is
            ``[batch, padded_len, hidden_size]``, where ``padded_len`` is the power-of-two
            length chosen below. ``reg_loss`` is the router's regularisation loss.
        """
        # Resolve the router once: it determines the padded length and does the routing.
        if self.use_shared_g_proj:
            assert shared_router is not None, "use_shared_g_proj is True but shared_router is None"
            router = shared_router
        else:
            router = self.router

        ########################################################################################
        # Pad sequences to the nearest power of 2 based on longest sequence
        ########################################################################################
        batch_size, sequence_len, hidden_dimension = hidden_states.size()

        # Smallest power of two >= sequence_len (e.g. 513 -> 1024, 512 -> 512) ...
        next_pow2 = 1 << (sequence_len - 1).bit_length()
        total_num_buckets = router.num_buckets_per_level ** router.num_levels
        # ... raised (by doubling) until it is at least the number of leaf buckets.
        # NOTE(review): if total_num_buckets is not a power of two, next_pow2 may not be
        # divisible by it and the router's bucket views would fail.
        while next_pow2 < total_num_buckets:
            next_pow2 <<= 1

        if next_pow2 != sequence_len:
            pad_len = next_pow2 - sequence_len
            # Pad the sequence axis with Gaussian noise on the same device / dtype.
            noise = torch.randn(
                batch_size,
                pad_len,
                hidden_dimension,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            )
            hidden_states = torch.cat([hidden_states, noise], dim=1)
            # NOTE(review): RoPE below requires `position_embeddings` to cover next_pow2
            # positions, and the output is not truncated back to the unpadded length;
            # confirm the caller handles both.

        ########################################################################################
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            # NOTE(review): the key/value views below assume the key length equals the
            # (padded) query length; a cache holding earlier tokens would break that.

        ########################################################################################
        # ARM logic
        ########################################################################################
        batch_size, num_heads, sequence_len, _ = query_states.shape
        b_h = batch_size * num_heads

        # [batch, head, length, head_dim] -> [batch*head, length, head_dim] for the router.
        query_states = query_states.flatten(0, 1)
        key_states = key_states.flatten(0, 1)
        value_states = value_states.flatten(0, 1)

        # Router regularisation loss (computed with gradients).
        reg_loss = router(query_states, key_states)

        # Discrete query / key assignments (no gradients).
        with torch.no_grad():
            q_idx, k_idx = router.test(query_states, key_states)

        _, num_bucket, N_sample_per_bucket = k_idx.shape

        # The same bucket-ordered index gathers both keys and values.
        gather_index = k_idx.flatten(1, 2).unsqueeze(-1).expand(-1, -1, self.head_dim)
        bucket_keys = torch.gather(key_states, dim=1, index=gather_index).view(
            b_h // self.num_key_value_groups,
            num_bucket * N_sample_per_bucket,
            self.head_dim,
        )
        bucket_values = torch.gather(value_states, dim=1, index=gather_index).view(
            b_h // self.num_key_value_groups,
            num_bucket * N_sample_per_bucket,
            self.head_dim,
        )

        # Back to [batch, length, head, head_dim].
        query_states = query_states.view(batch_size, num_heads, sequence_len, self.head_dim).transpose(1, 2).to(value_states.dtype)
        key_states = key_states.view(batch_size, self.config.num_key_value_heads, sequence_len, self.head_dim).transpose(1, 2).to(value_states.dtype)
        value_states = value_states.view(batch_size, self.config.num_key_value_heads, sequence_len, self.head_dim).transpose(1, 2)

        # Use shared g_proj if provided, otherwise use layer-specific one.
        if self.use_shared_g_proj:
            assert shared_g_proj is not None, "use_shared_g_proj is True but shared_g_proj is None"
            g_proj = shared_g_proj
        else:
            g_proj = self.g_proj
        # Gate: softmax over 3 entries when a sink is used, otherwise 2 -> [B, S, H, 2|3].
        g_proj_output_dim = 3 if self.config.use_sink else 2
        g_topk = F.softmax(
            g_proj(hidden_states).view(batch_size, sequence_len, num_heads, g_proj_output_dim),
            dim=-1,
        )

        bucket_keys = bucket_keys.view(batch_size, self.config.num_key_value_heads, sequence_len, self.head_dim).transpose(1, 2).to(value_states.dtype)
        bucket_values = bucket_values.view(batch_size, self.config.num_key_value_heads, sequence_len, self.head_dim).transpose(1, 2)

        q_idx = q_idx.view(batch_size, num_heads, sequence_len, -1).transpose(1, 2)
        k_idx = k_idx.view(batch_size, self.config.num_key_value_heads, k_idx.shape[1] * k_idx.shape[2]).transpose(1, 2)

        attn_output = parallel_arm(query_states, key_states, value_states, g_topk, bucket_keys, bucket_values, q_indices=q_idx, k_indices=k_idx, block_size=N_sample_per_bucket, window_size=self.local_window_size)
        if self.sink is not None:
            # Third gate entry weights the sink vector (repeated per query-head group).
            attn_output = attn_output + g_topk[..., [2]] * self.sink.weight.repeat_interleave(self.num_key_value_groups, dim=0)[None, None, ...]

        ########################################################################################

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, None, reg_loss


def _get_arm_cache(past_key_values: Cache, layer_idx: int) -> dict:
    """
    Return the per-layer arm cache dict stored on ``past_key_values``.

    Both ``past_key_values._arm_cache_dict`` (keyed by layer index) and the entry for
    ``layer_idx`` are created as empty dicts on first access. The stored dict itself
    is returned, so mutating it updates the cache in place.

    Args:
        past_key_values: Cache object
        layer_idx: Layer index

    Returns:
        Dictionary containing arm cache states for the layer
    """
    try:
        per_layer = past_key_values._arm_cache_dict
    except AttributeError:
        per_layer = past_key_values._arm_cache_dict = {}
    return per_layer.setdefault(layer_idx, {})


def _set_arm_cache(
    past_key_values: Cache,
    layer_idx: int,
    cache_k_window: torch.Tensor,
    cache_v_window: torch.Tensor,
    cache_key_idx: torch.Tensor,
    cache_bucket_keys: torch.Tensor,
    cache_bucket_values: torch.Tensor,
    beam_width_M: int,
    beam_width_M_plus: int,
    num_bucket: int,
    N_sample_per_bucket: int,
    beam_width_score: torch.Tensor,
) -> None:
    """
    Replace the arm cache entry for ``layer_idx`` on ``past_key_values``.

    ``past_key_values._arm_cache_dict`` is created if missing; the layer's previous
    entry, if any, is overwritten with a new dict holding every argument below
    under a key of the same name.

    Args:
        past_key_values: Cache object
        layer_idx: Layer index
        cache_k_window: [B, window_size, H, D]
        cache_v_window: [B, window_size, H, D]
        cache_key_idx: [B*H, num_bucket, N_sample_per_bucket]
        cache_bucket_keys: [B*H, num_bucket*N_sample, D]
        cache_bucket_values: [B*H, num_bucket*N_sample, D]
        beam_width_M: int
        beam_width_M_plus: int
        num_bucket: int
        N_sample_per_bucket: int
        beam_width_score: torch.Tensor
    """
    try:
        per_layer = past_key_values._arm_cache_dict
    except AttributeError:
        per_layer = past_key_values._arm_cache_dict = {}

    per_layer[layer_idx] = dict(
        cache_k_window=cache_k_window,
        cache_v_window=cache_v_window,
        cache_key_idx=cache_key_idx,
        cache_bucket_keys=cache_bucket_keys,
        cache_bucket_values=cache_bucket_values,
        beam_width_M=beam_width_M,
        beam_width_M_plus=beam_width_M_plus,
        num_bucket=num_bucket,
        N_sample_per_bucket=N_sample_per_bucket,
        beam_width_score=beam_width_score,
    )


class HiLlamaAttentionWithMemory(HiLlamaAttention):
    """Wrapper adapter for ArmTopKAttentionWithMemory to match HiLlamaAttention interface."""

    def __init__(
        self,
        config: HiLlamaConfig, 
        layer_idx: int,
        shared_params=None,
        shared_g_router=None,
    ):
        """Build the memory-augmented attention layer.

        Args:
            config: Model config. Besides what the parent class reads, this reads ``cache_len``,
                ``gen_len``, ``gumbel_tau_bucket``, ``gumbel_tau_slot``, ``context_tokens`` and
                ``hidden_size``, plus the optional ``beam_width_lstm_num_layers`` (default 1),
                ``num_buckets_per_level`` (default 2) and ``num_levels`` (default 1).
            layer_idx: Index of this decoder layer.
            shared_params: When not None, this layer creates none of its own key-gate, slot,
                beam-width LSTM, step-embedding or beam-width projection modules; shared
                modules are instead passed to the methods that need them
                (e.g. ``shared_step_embedding``, ``shared_slot_query_proj``).
            shared_g_router: Forwarded to the parent constructor as its ``shared_params``.
        """
        # NOTE(review): the parent receives `shared_g_router` under the name `shared_params`;
        # presumably the parent's `shared_params` means the shared g_proj/router — confirm.
        super().__init__(config, layer_idx, shared_params=shared_g_router)

        # Memory configuration
        self.cache_len = config.cache_len
        self.gen_len = config.gen_len
        self.gumbel_tau_bucket = config.gumbel_tau_bucket
        self.gumbel_tau_slot = config.gumbel_tau_slot
        self.context_tokens = config.context_tokens

        # Shared modules are NOT stored as attributes (they arrive as method arguments) so that
        # save_pretrained does not see the same tensors registered under several layers.
        self.use_shared_params = shared_params is not None
        if not self.use_shared_params:
            # Sigmoid-gated update head: alpha = sigmoid(W [old_k ⊕ new_k] + b)
            self.key_gate = nn.Linear(2 * self.head_dim, self.head_dim, bias=True)

            # Learned projections whose dot product scores slots; each output is later split in
            # half into a selection part and a gate part.
            slot_hidden_dim = self.head_dim // 2
            self.slot_query_proj = nn.Linear(self.head_dim, slot_hidden_dim, bias=True)
            self.slot_key_proj = nn.Linear(self.head_dim, slot_hidden_dim, bias=True)

            # Beam-width controller: a hand-rolled LSTM built from Linear layers.
            hidden_size_half = config.hidden_size // 2
            beam_width_lstm_num_layers = getattr(config, 'beam_width_lstm_num_layers', 1)
            # Input-to-gate transformation per layer: [input_size] -> [4 * hidden_size_half],
            # gates ordered input, forget, cell, output.
            self.beam_width_lstm_input = nn.ModuleList([
                nn.Linear(config.hidden_size if i == 0 else hidden_size_half, 4 * hidden_size_half, bias=True)
                for i in range(beam_width_lstm_num_layers)
            ])
            # Hidden-to-gate transformation per layer: [hidden_size_half] -> [4 * hidden_size_half]
            self.beam_width_lstm_hidden = nn.ModuleList([
                nn.Linear(hidden_size_half, 4 * hidden_size_half, bias=True)
                for _ in range(beam_width_lstm_num_layers)
            ])
            self.beam_width_lstm_num_layers = beam_width_lstm_num_layers

            # Step embedding, indexed by decision step (0 .. num_leaf_buckets - 1), so it needs
            # at least num_leaf_buckets rows. The +100 padding is kept so existing checkpoints
            # keep the same parameter shape.
            num_buckets_per_level = getattr(config, 'num_buckets_per_level', 2)
            num_levels = getattr(config, 'num_levels', 1)
            num_leaf_buckets = int(num_buckets_per_level ** num_levels)
            # Use the defaulted local (not `config.num_levels`) so a config without
            # `num_levels` does not raise AttributeError here.
            embedding_size = max(num_levels, num_leaf_buckets) + 100
            self.step_embedding = nn.Embedding(embedding_size, config.hidden_size)

            # Projection: LSTM hidden state -> continuation score in (0, 1)
            self.beam_width_proj = nn.Sequential(
                nn.Linear(hidden_size_half, 1, bias=True), nn.Sigmoid()
            )

    @property
    def num_leaf_buckets(self) -> int:
        """Leaf-bucket count of the hierarchical router: ``num_buckets_per_level ** num_levels``.

        Reads the ``router`` attribute (set by the parent class, possibly shared) when present;
        otherwise falls back to the config with defaults of 4 buckets per level and 3 levels.
        """
        # NOTE(review): these fallback defaults (4, 3) differ from the ones __init__ uses when
        # sizing step_embedding (2, 1) — confirm which are intended.
        router = getattr(self, 'router', None)
        if router is not None:
            per_level, depth = router.num_buckets_per_level, router.num_levels
        else:
            per_level = getattr(self.config, 'num_buckets_per_level', 4)
            depth = getattr(self.config, 'num_levels', 3)
        return int(per_level ** depth)

    def _lstm_cell_forward(
        self,
        input_tensor: torch.Tensor,  # [batch, seq_len, input_size]
        hidden_state: torch.Tensor,  # [1, 1, hidden_size//2]
        cell_state: torch.Tensor,  # [1, 1, hidden_size//2]
        beam_width_lstm_input: Optional[nn.Module] = None,
        beam_width_lstm_hidden: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run a hand-written LSTM (built from Linear layers) over ``input_tensor``.

        The recurrent state is a single vector shared by the whole batch: each step computes
        per-batch-row gates, then collapses the new state back to one vector (row 0 when
        batch_size == 1, otherwise the mean over the batch).

        Args:
            input_tensor: Input tensor [batch, seq_len, input_size]; seq_len must be >= 1
                (torch.stack on an empty list raises).
            hidden_state: Previous hidden state [1, 1, hidden_size//2]; only [0, 0, :] is read.
            cell_state: Previous cell state [1, 1, hidden_size//2]; only [0, 0, :] is read.
            beam_width_lstm_input: Input-to-gate Linear, or a ModuleList with one per layer.
                Required when ``self.use_shared_params`` is True, ignored otherwise.
            beam_width_lstm_hidden: Hidden-to-gate Linear(s); same convention as above.

        Returns:
            output: Output tensor [batch, seq_len, hidden_size//2] — the last layer's per-row
                hidden output at each step.
            (new_hidden, new_cell): Tuple of new hidden and cell states, each [1, 1, hidden_size//2]

        NOTE(review): with more than one layer, every layer reads and overwrites the SAME
        (h, c) pair rather than keeping per-layer state as a stacked LSTM (e.g. nn.LSTM) would,
        and only the last layer's state is returned. This is harmless for the default single
        layer; confirm it is intended before using beam_width_lstm_num_layers > 1.
        """
        # input_size / hidden_size are unpacked for readability only; they are not used below.
        batch_size, seq_len, input_size = input_tensor.shape
        hidden_size = hidden_state.shape[-1]
        
        # Extract hidden and cell states: [1, 1, hidden_size] -> [hidden_size]
        h_prev = hidden_state[0, 0, :]  # [hidden_size]
        c_prev = cell_state[0, 0, :]  # [hidden_size]
        
        # Use shared parameters if provided, otherwise use layer-specific ones
        if self.use_shared_params:
            assert beam_width_lstm_input is not None, "use_shared_params is True but beam_width_lstm_input is None"
            assert beam_width_lstm_hidden is not None, "use_shared_params is True but beam_width_lstm_hidden is None"
            lstm_input_layers = beam_width_lstm_input
            lstm_hidden_layers = beam_width_lstm_hidden
            # A bare (non-ModuleList) shared module is treated as a single layer.
            num_layers = len(beam_width_lstm_input) if isinstance(beam_width_lstm_input, nn.ModuleList) else 1
        else:
            lstm_input_layers = self.beam_width_lstm_input
            lstm_hidden_layers = self.beam_width_lstm_hidden
            num_layers = self.beam_width_lstm_num_layers
        
        # Process sequence
        outputs = []
        h = h_prev  # [hidden_size]
        c = c_prev  # [hidden_size]
        
        for t in range(seq_len):
            x_t = input_tensor[:, t, :]  # [batch, input_size]
            
            # Process through multiple LSTM layers (all sharing the single h/c — see NOTE above)
            layer_input = x_t
            for layer_idx in range(num_layers):
                # Get layer-specific Linear modules
                if isinstance(lstm_input_layers, nn.ModuleList):
                    lstm_input = lstm_input_layers[layer_idx]
                    lstm_hidden = lstm_hidden_layers[layer_idx]
                else:
                    # Backward compatibility: single layer
                    lstm_input = lstm_input_layers
                    lstm_hidden = lstm_hidden_layers
                
                # Compute gates: input, forget, cell, output
                # Input transformation: [batch, input_size] -> [batch, 4 * hidden_size]
                gates_input = lstm_input(layer_input)  # [batch, 4 * hidden_size]
                
                # Hidden transformation: [hidden_size] -> [4 * hidden_size]
                # Expand h to [1, hidden_size] for Linear layer, then expand result to [batch, 4 * hidden_size]
                h_expanded = h.unsqueeze(0)  # [1, hidden_size]
                gates_hidden_single = lstm_hidden(h_expanded)  # [1, 4 * hidden_size]
                gates_hidden = gates_hidden_single.expand(batch_size, -1)  # [batch, 4 * hidden_size]
                
                # Combine gates: [batch, 4 * hidden_size]
                gates = gates_input + gates_hidden  # [batch, 4 * hidden_size]
                
                # Split into 4 gates: input, forget, cell, output
                i_t, f_t, g_t, o_t = torch.chunk(gates, 4, dim=-1)  # Each: [batch, hidden_size]
                
                # Apply activations
                i_t = torch.sigmoid(i_t)  # Input gate
                f_t = torch.sigmoid(f_t)  # Forget gate
                g_t = torch.tanh(g_t)  # Cell gate
                o_t = torch.sigmoid(o_t)  # Output gate
                
                # Expand c to [batch, hidden_size] for element-wise operations
                c_batch = c.unsqueeze(0).expand(batch_size, -1)  # [batch, hidden_size]
                
                # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
                c_new = f_t * c_batch + i_t * g_t  # [batch, hidden_size]
                
                # Update hidden state: h_t = o_t * tanh(c_t)
                h_new = o_t * torch.tanh(c_new)  # [batch, hidden_size]
                
                # Collapse the per-row state back to one vector: row 0 if batch == 1, else batch mean
                if batch_size == 1:
                    h = h_new[0, :]  # [hidden_size]
                    c = c_new[0, :]  # [hidden_size]
                else:
                    # Use mean across batch for state update
                    h = h_new.mean(dim=0)  # [hidden_size]
                    c = c_new.mean(dim=0)  # [hidden_size]
                
                # The next layer's input is the collapsed h broadcast to every row, not the per-row h_new
                if layer_idx < num_layers - 1:
                    layer_input = h.unsqueeze(0).expand(batch_size, -1)  # [batch, hidden_size]
            
            # Per-row output of the last layer for this step
            outputs.append(h_new)
        
        # Stack outputs: [batch, seq_len, hidden_size]
        output = torch.stack(outputs, dim=1)  # [batch, seq_len, hidden_size]
        
        # Reshape hidden and cell states to match LSTM output format: [1, 1, hidden_size]
        new_hidden = h.unsqueeze(0).unsqueeze(0)  # [1, 1, hidden_size]
        new_cell = c.unsqueeze(0).unsqueeze(0)  # [1, 1, hidden_size]
        
        return output, (new_hidden, new_cell)

    def _compute_beam_width_scores(
        self, 
        hidden_states: torch.Tensor,
        shared_router: Optional[nn.Module] = None,
        shared_step_embedding: Optional[nn.Module] = None,
        shared_beam_width_proj: Optional[nn.Module] = None,
        shared_beam_width_lstm_input: Optional[nn.Module] = None,
        shared_beam_width_lstm_hidden: Optional[nn.Module] = None,
    ) -> Tuple[int, torch.Tensor]:
        """Choose an effective beam width with an LSTM-driven sequence of grow/stop decisions.

        The first ``self.context_tokens`` positions of ``hidden_states`` are averaged (over batch,
        then sequence) and fed to the LSTM to seed its state. Then, starting from a width of 1,
        each step feeds a learned step embedding and scores the new LSTM state with
        ``beam_width_proj``. The width grows by one while the score is >= 0.5 and stops at the
        first score < 0.5. In training, with probability 0.1 per step, the decision is replaced by
        a fair coin flip (exploration).

        Width limits: at most ``total_buckets`` in eval and ``total_buckets - 1`` in training,
        where the caller routes with one extra bucket.

        Args:
            hidden_states: ``[B, T, hidden_size]`` layer input.
            shared_*: Shared modules, required when the layer is configured to share them.

        Returns:
            ``(effective_beam_width, last_score)``, where ``last_score`` is the last controller
            score as a scalar tensor, or ``tensor(0.5)`` if no step ran.
        """
        device = hidden_states.device

        # Single context vector: mean over batch, then over the first `l` positions.
        l = min(self.context_tokens, hidden_states.size(1))
        context_mean = hidden_states[:, :l, :].mean(dim=0).mean(dim=0)  # [hidden_size]
        context_input = context_mean.unsqueeze(0).unsqueeze(0)  # [1, 1, hidden_size]

        # Seed the LSTM from zero state with the context vector.
        hidden_size_half = self.config.hidden_size // 2
        h = torch.zeros(1, 1, hidden_size_half, device=device, dtype=hidden_states.dtype)
        c = torch.zeros(1, 1, hidden_size_half, device=device, dtype=hidden_states.dtype)
        _, (h, c) = self._lstm_cell_forward(
            context_input, h, c,
            beam_width_lstm_input=shared_beam_width_lstm_input,
            beam_width_lstm_hidden=shared_beam_width_lstm_hidden
        )

        if hasattr(self, 'use_shared_g_proj') and self.use_shared_g_proj:
            assert shared_router is not None, "use_shared_g_proj is True but shared_router is None"
            router = shared_router
        else:
            router = self.router
        total_buckets = int(router.num_buckets_per_level**router.num_levels)

        if self.use_shared_params:
            assert shared_step_embedding is not None, "use_shared_params is True but shared_step_embedding is None"
            assert shared_beam_width_proj is not None, "use_shared_params is True but shared_beam_width_proj is None"
            step_embedding = shared_step_embedding
            beam_width_proj = shared_beam_width_proj
        else:
            step_embedding = self.step_embedding
            beam_width_proj = self.beam_width_proj

        exploration_rate = 0.1  # Per-step probability of a random decision during training
        # Training keeps room for beam_width + 1 (the caller's M_plus width).
        headroom = 2 if self.training else 1
        effective_beam_width = 1
        last_score = torch.tensor(0.5, device=device)

        for step in range(total_buckets):
            if effective_beam_width + headroom > total_buckets:
                break

            step_emb = step_embedding(torch.tensor(step, device=device))  # [hidden_size]
            _, (h, c) = self._lstm_cell_forward(
                step_emb.unsqueeze(0).unsqueeze(0), h, c,
                beam_width_lstm_input=shared_beam_width_lstm_input,
                beam_width_lstm_hidden=shared_beam_width_lstm_hidden
            )
            last_score = beam_width_proj(h[0, 0, :]).squeeze()  # scalar

            # Random draws happen only in training, so eval does not consume torch's global RNG
            # or pay extra host-device syncs. The training draw order is unchanged.
            if self.training and torch.rand(1, device=device).item() < exploration_rate:
                grow = torch.rand(1, device=device).item() >= 0.5
            else:
                grow = last_score.item() >= 0.5

            if not grow:
                break
            effective_beam_width += 1

        return effective_beam_width, last_score

    def _compute_slot_selection_logits(
        self, 
        k_t: torch.Tensor, 
        selected_bucket_keys: torch.Tensor,
        shared_slot_query_proj: Optional[nn.Module] = None,
        shared_slot_key_proj: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score every candidate slot of the target bucket against the incoming key.

        Both projections emit a vector that is split in half along the last axis. The first
        halves are dotted together to give slot-selection logits; the second halves give gate
        logits. The indexing assumes the projected ``k_t`` is 3-D ``[B, H, P]`` and the projected
        ``selected_bucket_keys`` is 4-D ``[B, H, N_sample_per_bucket, P]``.

        Returns:
            ``(selection_logits, gate_logits)``, each ``[B, H, N_sample_per_bucket]``.
        """
        if not self.use_shared_params:
            query_head = self.slot_query_proj
            key_head = self.slot_key_proj
        else:
            assert shared_slot_query_proj is not None, "use_shared_params is True but shared_slot_query_proj is None"
            assert shared_slot_key_proj is not None, "use_shared_params is True but shared_slot_key_proj is None"
            query_head = shared_slot_query_proj
            key_head = shared_slot_key_proj

        projected_query = query_head(k_t)  # [B, H, P]
        projected_keys = key_head(selected_bucket_keys)  # [B, H, N_sample_per_bucket, P]
        half = projected_query.shape[-1] // 2

        # Singleton query axis so matmul gives one row of logits per (batch, head).
        sel_query = projected_query[:, :, :half].unsqueeze(2)  # [B, H, 1, P/2]
        gate_query = projected_query[:, :, half:].unsqueeze(2)  # [B, H, 1, P/2]
        sel_keys_t = projected_keys[:, :, :, :half].transpose(-2, -1)  # [B, H, P/2, N]
        gate_keys_t = projected_keys[:, :, :, half:].transpose(-2, -1)  # [B, H, P/2, N]

        selection_logits = torch.matmul(sel_query, sel_keys_t).squeeze(2)  # [B, H, N]
        gate_logits = torch.matmul(gate_query, gate_keys_t).squeeze(2)  # [B, H, N]
        return selection_logits, gate_logits

    def _route_query_and_key_for_token(
        self,
        q_t: torch.Tensor,
        k_t: torch.Tensor,
        B: int,
        HQ: int,
        H: int,
        beam_width_M: int,
        beam_width_M_plus: int,
        shared_router: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route the current token's query and key through the bucket router.

        In training the query is routed with ``beam_width_M_plus`` and the M version is that
        result minus its last column; in eval it is routed with ``beam_width_M`` and no M_plus
        version is produced. The key is always routed with beam width 1.

        Returns:
            query_idx_t_M_stored: [B, HQ, 1, width]
            query_idx_t_M_plus_stored: [B, HQ, 1, beam_width_M_plus], or None in eval mode
            bucket_id: [B, H] - target bucket for the cache update of each head
        """
        use_shared_router = hasattr(self, 'use_shared_g_proj') and self.use_shared_g_proj
        if use_shared_router:
            assert shared_router is not None, "use_shared_g_proj is True but shared_router is None"
            router = shared_router
        else:
            router = self.router

        # Flatten batch and heads into rows (views; no copy).
        query_rows = q_t.view(B * HQ, 1, -1)  # [B*HQ, 1, D]
        key_rows = k_t.view(B * H, 1, -1)  # [B*H, 1, D]

        if not self.training:
            idx_M = router.assign_queries(query_rows, beam_width=beam_width_M)
            idx_M_stored = idx_M.view(B, HQ, 1, -1)
            idx_M_plus_stored = None
        else:
            idx_M_plus = router.assign_queries(query_rows, beam_width=beam_width_M_plus)
            idx_M_stored = idx_M_plus[..., :-1].view(B, HQ, 1, -1)
            idx_M_plus_stored = idx_M_plus.view(B, HQ, 1, -1)

        key_buckets = router.assign_queries(key_rows, beam_width=1).view(B, H, 1, -1)
        return idx_M_stored, idx_M_plus_stored, key_buckets[:, :, 0, 0]

    def _gather_bucket_keys_values(
        self,
        key_idx: torch.Tensor,
        key_flat: torch.Tensor,
        value_flat: torch.Tensor,
        B: int,
        H: int,
        hdim: int,
        expand_for_heads: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Collect the cached keys/values that the router assigned to each bucket.

        Args:
            key_idx: Positions into the sequence axis of ``key_flat`` / ``value_flat``, one row per
                flattened (batch, head); the visible caller passes ``[B*H, num_bucket,
                N_sample_per_bucket]`` router indices.
            key_flat / value_flat: ``[rows, seq_len, hdim]`` keys and values.
            B, hdim: Not used here; kept for call-site compatibility.
            H: Repeat factor for each row of ``key_idx`` when ``expand_for_heads`` is True.
            expand_for_heads: Repeat every row of ``key_idx`` ``H`` times consecutively before
                gathering (this materialises a copy of the indices).

        Returns:
            ``(bucket_keys, bucket_values)``, each ``[rows, num_bucket * N_sample_per_bucket, hdim]``,
            laid out bucket by bucket.
        """
        index = key_idx.repeat_interleave(H, dim=0) if expand_for_heads else key_idx
        index = index.view(index.shape[0], -1)  # [rows, num_bucket * N_sample_per_bucket]

        # Pair every index row with its own row of key_flat / value_flat.
        row_ids = torch.arange(index.shape[0], device=index.device).unsqueeze(1)  # [rows, 1]
        return key_flat[row_ids, index, :], value_flat[row_ids, index, :]

    def _select_slot_in_bucket(
        self,
        k_t: torch.Tensor,
        bucket_id: torch.Tensor,
        cache_bucket_keys: torch.Tensor,
        cache_bucket_values: torch.Tensor,
        cache_key_idx: torch.Tensor,
        B: int,
        H: int,
        N_sample_per_bucket: int,
        D: int,
        shared_slot_query_proj: Optional[nn.Module] = None,
        shared_slot_key_proj: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Softly choose which slot of each head's target bucket to overwrite.

        Args:
            k_t: Current token key, passed to ``_compute_slot_selection_logits``.
            bucket_id: ``[B, H]`` target bucket per head.
            cache_bucket_keys / cache_bucket_values: ``[B*H, num_bucket * N_sample_per_bucket, D]``
                bucketed cache; bucket ``b`` occupies positions ``[b*N, (b+1)*N)``.
            cache_key_idx: Unused here; kept for call-site compatibility.
            B, H, N_sample_per_bucket, D: Batch, KV heads, slots per bucket, head dim.
            shared_slot_query_proj / shared_slot_key_proj: Shared projections, used when
                ``self.use_shared_params`` is True.

        Returns:
            ``(slot_probs, alpha, selected_bucket_keys, selected_bucket_values)``:
            Gumbel-softmax slot weights ``[B, H, N]`` (soft, ``tau=self.gumbel_tau_slot``),
            sigmoid gates ``[B, H, N]``, and the bucket's current keys/values ``[B, H, N, D]``.
        """
        offsets = torch.arange(N_sample_per_bucket, device=bucket_id.device, dtype=bucket_id.dtype)
        slot_indices = bucket_id.flatten()[:, None] * N_sample_per_bucket + offsets[None, :]  # [B*H, N]

        # torch.gather requires int64 indices; cast explicitly (as the scatter_ path in
        # _update_cache_and_sliding_window does) so int32 router outputs don't raise.
        gather_index = slot_indices.unsqueeze(-1).expand(-1, -1, D).to(torch.long)  # [B*H, N, D]

        selected_bucket_keys = torch.gather(cache_bucket_keys, dim=1, index=gather_index).view(
            B, H, N_sample_per_bucket, D
        )
        selected_bucket_values = torch.gather(cache_bucket_values, dim=1, index=gather_index).view(
            B, H, N_sample_per_bucket, D
        )

        selection_logits, gate_logits = self._compute_slot_selection_logits(
            k_t, selected_bucket_keys,
            shared_slot_query_proj=shared_slot_query_proj,
            shared_slot_key_proj=shared_slot_key_proj
        )
        slot_probs = F.gumbel_softmax(selection_logits, tau=self.gumbel_tau_slot, hard=False, dim=-1)
        alpha = torch.sigmoid(gate_logits)  # [B, H, N_sample_per_bucket]

        return slot_probs, alpha, selected_bucket_keys, selected_bucket_values

    def _update_cache_and_sliding_window(
        self,
        cache_k_window: torch.Tensor,
        cache_v_window: torch.Tensor,
        cache_bucket_keys: torch.Tensor,
        cache_bucket_values: torch.Tensor,
        k_t: torch.Tensor,
        v_t: torch.Tensor,
        slot_probs: torch.Tensor,
        alpha: torch.Tensor,
        selected_bucket_keys: torch.Tensor,
        selected_bucket_values: torch.Tensor,
        bucket_id: torch.Tensor,
        B: int,
        H: int,
        N_sample_per_bucket: int,
        D: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Blend the new token into its bucket and push it onto the sliding window.

        For every slot of the target bucket:
            KV_slot <- KV_slot + slot_probs * alpha * (KV_token - KV_slot)
        and the blended slots are written back into the bucketed cache with ``scatter_``.

        Args:
            cache_k_window / cache_v_window: ``[B, window, H, D]`` recent keys/values. When the
                window is already full they are shifted and updated in place.
            cache_bucket_keys / cache_bucket_values: ``[B*H, num_bucket * N_sample_per_bucket, D]``;
                modified in place.
            k_t / v_t: Current token key/value, ``[B, H, D]`` or ``[B, H, T, D]`` (position 0 is
                used); any other rank is viewed as ``[B, H, D]``.
            slot_probs / alpha: ``[B, H, N_sample_per_bucket]`` slot weights and gates.
            selected_bucket_keys / selected_bucket_values: ``[B, H, N_sample_per_bucket, D]`` current
                contents of the target bucket (not modified).
            bucket_id: ``[B, H]`` target bucket per head.
            B, H, N_sample_per_bucket, D: Batch, KV heads, slots per bucket, head dim.

        Returns:
            ``(cache_k_window, cache_v_window, cache_bucket_keys, cache_bucket_values)``.
        """
        def _as_bhd(x: torch.Tensor) -> torch.Tensor:
            # Normalise a per-token tensor to [B, H, D] without copying.
            if x.dim() == 3:
                return x
            if x.dim() == 4:
                return x[:, :, 0, :]
            return x.view(B, H, D)

        k_new = _as_bhd(k_t)
        v_new = _as_bhd(v_t)

        # Per-slot blend weight broadcast over D: [B, H, N_sample_per_bucket, 1]
        weight = (slot_probs * alpha).unsqueeze(-1)
        # Out-of-place on purpose. _select_slot_in_bucket also feeds selected_bucket_keys through
        # slot_key_proj, so autograd may hold it for backward. An in-place update would
        # invalidate that saved tensor and silently mutate the caller's copy.
        new_k = selected_bucket_keys + weight * (k_new.unsqueeze(2) - selected_bucket_keys)
        new_v = selected_bucket_values + weight * (v_new.unsqueeze(2) - selected_bucket_values)

        # Bucket b owns positions [b * N, (b + 1) * N) of the flattened bucket axis.
        offsets = torch.arange(N_sample_per_bucket, device=bucket_id.device, dtype=bucket_id.dtype)
        slot_indices = bucket_id.flatten()[:, None] * N_sample_per_bucket + offsets[None, :]  # [B*H, N]
        # scatter_ along dim=1 needs an int64 index of the same shape as the source.
        scatter_index = slot_indices.unsqueeze(-1).expand(-1, -1, D).to(torch.long)  # [B*H, N, D]
        cache_bucket_keys.scatter_(
            1, scatter_index, new_k.reshape(B * H, N_sample_per_bucket, D).to(cache_bucket_keys.dtype)
        )
        cache_bucket_values.scatter_(
            1, scatter_index, new_v.reshape(B * H, N_sample_per_bucket, D).to(cache_bucket_values.dtype)
        )

        # Sliding window: cache_*_window is [B, window, H, D]; the new token is [B, H, D].
        if cache_k_window.size(1) < self.local_window_size:
            # Not yet full: append the token as a new last position.
            cache_k_window = torch.cat([cache_k_window, k_new.unsqueeze(1)], dim=1)
            cache_v_window = torch.cat([cache_v_window, v_new.unsqueeze(1)], dim=1)
        else:
            # Full: shift everything one position left in place (clone avoids reading memory that
            # is being overwritten), then write the newest token into the last position.
            cache_k_window[:, :-1, :, :] = cache_k_window[:, 1:, :, :].clone()
            cache_v_window[:, :-1, :, :] = cache_v_window[:, 1:, :, :].clone()
            cache_k_window[:, -1, :, :] = k_new
            cache_v_window[:, -1, :, :] = v_new

        return cache_k_window, cache_v_window, cache_bucket_keys, cache_bucket_values

    def _initialize_cache_and_arm(
        self,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_q: torch.Tensor,
        hidden_states: torch.Tensor,
        B: int,
        H: int,
        HQ: int,
        hdim: int,
        shared_router: Optional[nn.Module] = None,
        shared_step_embedding: Optional[nn.Module] = None,
        shared_beam_width_proj: Optional[nn.Module] = None,
        shared_beam_width_lstm_input: Optional[nn.Module] = None,
        shared_beam_width_lstm_hidden: Optional[nn.Module] = None,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int
    ]:
        """Pick beam widths, route the cached queries/keys into buckets and gather bucket KV.

        Args:
            cache_k / cache_v: ``[B, H, cache_len, D]`` cached keys and values.
            cache_q: Cached queries; flattened over its first two dims to ``[B*HQ, cache_len, D]``.
            hidden_states: Input to the beam-width controller (``_compute_beam_width_scores``).
            B, H, HQ, hdim: Batch size, KV heads, query heads, head dim.
            shared_*: Shared modules, used when the layer is configured to share them.

        Returns:
            ``(cache_bucket_keys, cache_bucket_values, cache_key_idx, cache_query_idx_M_stored,
            cache_query_idx_M_plus_stored, reg_loss, beam_width_score, beam_width_M,
            beam_width_M_plus, num_bucket, N_sample_per_bucket)``.

            In eval mode, ``cache_query_idx_M_plus_stored`` is None and ``reg_loss`` is the int 0.
            In training, ``cache_query_idx_M_stored`` has ``beam_width_M_plus - 1`` columns.
            This matches the per-token slicing in ``_route_query_and_key_for_token``.
        """
        effective_beam_width, beam_width_score = self._compute_beam_width_scores(
            hidden_states,
            shared_router=shared_router,
            shared_step_embedding=shared_step_embedding,
            shared_beam_width_proj=shared_beam_width_proj,
            shared_beam_width_lstm_input=shared_beam_width_lstm_input,
            shared_beam_width_lstm_hidden=shared_beam_width_lstm_hidden,
        )
        # TODO: heuristic eval routing width = a quarter of the controller's width, falling back
        # to 4 when that rounds down to 0.
        # NOTE(review): the fallback of 4 can exceed the router's leaf-bucket count for small
        # routers — confirm the router tolerates that.
        times = 4
        beam_width_M = effective_beam_width // times if effective_beam_width // times > 0 else 4  # TODO
        beam_width_M_plus = effective_beam_width + 1
        cache_len = cache_k.size(2)  # cache_k is [B, H, cache_len, D]
        cache_k_flat = cache_k.flatten(0, 1)  # [B*H, cache_len, D]
        cache_v_flat = cache_v.flatten(0, 1)  # [B*H, cache_len, D]
        cache_q_flat = cache_q.flatten(0, 1)  # [B*HQ, cache_len, D]

        if hasattr(self, 'use_shared_g_proj') and self.use_shared_g_proj:
            assert shared_router is not None, "use_shared_g_proj is True but shared_router is None"
            router = shared_router
        else:
            router = self.router

        reg_loss = 0
        if self.training:
            reg_loss = router(cache_q_flat, cache_k_flat)
        with torch.no_grad():
            if self.training:
                cache_query_idx_M_plus, cache_key_idx = router.test(cache_q_flat, cache_k_flat, beam_width=beam_width_M_plus)
                # cache_query_idx_M_plus: [B*HQ, cache_len, beam_width_M_plus]
                cache_query_idx_M = cache_query_idx_M_plus[..., :-1]
                # The slice has beam_width_M_plus - 1 == effective_beam_width columns, which is never
                # equal to beam_width_M as computed above. Viewing with an explicit beam_width_M raised
                # a size-mismatch error, so infer the last dim instead.
                cache_query_idx_M_stored = cache_query_idx_M.view(B, HQ, cache_len, -1)
                cache_query_idx_M_plus_stored = cache_query_idx_M_plus.view(B, HQ, cache_len, beam_width_M_plus)
            else:
                cache_query_idx, cache_key_idx = router.test(cache_q_flat, cache_k_flat, beam_width=beam_width_M)
                # cache_query_idx: [B*HQ, cache_len, beam_width_M]
                cache_query_idx_M_stored = cache_query_idx.view(B, HQ, cache_len, beam_width_M)
                cache_query_idx_M_plus_stored = None

        _, num_bucket, N_sample_per_bucket = cache_key_idx.shape
        # Materialise the cache keys/values the router placed in each bucket as a bucket-major
        # [B*H, num_bucket * N_sample_per_bucket, D] layout. This is the layout the per-token
        # slot selection/update reads and writes.
        cache_bucket_keys, cache_bucket_values = self._gather_bucket_keys_values(
            cache_key_idx, cache_k_flat, cache_v_flat, B, H, hdim, expand_for_heads=False
        )

        return (
            cache_bucket_keys,
            cache_bucket_values,
            cache_key_idx,
            cache_query_idx_M_stored,
            cache_query_idx_M_plus_stored,
            reg_loss,
            beam_width_score,
            beam_width_M,
            beam_width_M_plus,
            num_bucket,
            N_sample_per_bucket,
        )

    def _compute_attention_for_token(
        self,
        q_t_attn: torch.Tensor,
        cache_k_for_attn: torch.Tensor,
        cache_v_for_attn: torch.Tensor,
        cache_bucket_keys_reshaped: torch.Tensor,
        cache_bucket_values_reshaped: torch.Tensor,
        cache_key_idx_reshaped: torch.Tensor,
        query_idx_t_M_stored: torch.Tensor,
        query_idx_t_M_plus_stored: Optional[torch.Tensor],
        g_topk_t: torch.Tensor,
        beam_width_score: torch.Tensor,
        N_sample_per_bucket: int,
        HQ: int,
        use_causal_swa: bool = True,
    ) -> torch.Tensor:
        """Compute attention output for current token.

        Runs ``parallel_arm`` with the beam-width-M query bucket indices. In training
        mode it also runs with the M+1 indices and linearly blends the two outputs:
        ``(1 - beam_width_score) * out_M + beam_width_score * out_M_plus``. In that
        case ``query_idx_t_M_plus_stored`` must not be None.

        When ``self.sink`` is set, the term
        ``g_topk_t[..., [2]] * sink.weight.repeat_interleave(num_key_value_groups, dim=0)``
        is added to the output.

        ``HQ`` is accepted for interface compatibility but is not used here.
        Per the original author's note, the result is already laid out as
        [B, 1, HQ, D] and needs no transpose (shape not checked here).
        """

        def _run_arm(stored_query_idx: torch.Tensor) -> torch.Tensor:
            # Stored indices are permuted (swap dims 1 and 2) into the layout
            # parallel_arm expects for q_indices.
            return parallel_arm(
                q_t_attn,
                cache_k_for_attn,
                cache_v_for_attn,
                g_topk_t,
                cache_bucket_keys_reshaped,
                cache_bucket_values_reshaped,
                q_indices=stored_query_idx.permute(0, 2, 1, 3),
                k_indices=cache_key_idx_reshaped,
                block_size=N_sample_per_bucket,
                window_size=self.local_window_size,
                use_causal_swa=use_causal_swa,
            )

        attn_output_t = _run_arm(query_idx_t_M_stored)
        if self.training:
            # Differentiable interpolation between beam widths M and M+1.
            attn_output_plus = _run_arm(query_idx_t_M_plus_stored)
            attn_output_t = (1.0 - beam_width_score) * attn_output_t + beam_width_score * attn_output_plus

        if self.sink is not None:
            # Third gate column weights the learned sink vector (one per KV head,
            # repeated for every query head in its group).
            sink_term = self.sink.weight.repeat_interleave(self.num_key_value_groups, dim=0)[None, None, ...]
            attn_output_t = attn_output_t + g_topk_t[..., [2]] * sink_term

        return attn_output_t

    def _route_queries_and_keys_batched(
        self,
        q_gen: torch.Tensor,
        k_gen: torch.Tensor,
        B: int,
        HQ: int,
        H: int,
        gen_len: int,
        beam_width_M: int,
        beam_width_M_plus: int,
        shared_router: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Assign every generation query and key to router buckets in one batch.

        Queries are flattened to ``[B*HQ*gen_len, 1, D]`` and keys to
        ``[B*H*gen_len, 1, D]`` before being passed to ``router.assign_queries``.
        The router is ``shared_router`` when ``self.use_shared_g_proj`` is set and
        truthy; otherwise it is ``self.router``.

        Returns:
            query_idx_gen_M: ``[B, HQ, gen_len, -1]``. In eval mode the last dim
                comes from ``beam_width_M``. In training mode it is the M+1 result
                with its last column dropped, so its width is
                ``beam_width_M_plus - 1``.
            query_idx_gen_M_plus: ``[B, HQ, gen_len, beam_width_M_plus]`` in
                training mode, otherwise None.
            key_bucket_idx_gen: ``[B, H, gen_len]`` single (beam_width=1) bucket
                per key.
        """
        # One "token" per row: [N, D] -> [N, 1, D].
        flat_queries = q_gen.reshape(B * HQ * gen_len, -1)[:, None, :]
        flat_keys = k_gen.reshape(B * H * gen_len, -1)[:, None, :]

        use_shared = hasattr(self, 'use_shared_g_proj') and self.use_shared_g_proj
        if use_shared:
            assert shared_router is not None, "use_shared_g_proj is True but shared_router is None"
        router = shared_router if use_shared else self.router

        if not self.training:
            idx_M = router.assign_queries(flat_queries, beam_width=beam_width_M).view(B, HQ, gen_len, -1)
            idx_M_plus = None
        else:
            # Route once with M+1 and derive the M version by dropping the last beam.
            raw_plus = router.assign_queries(flat_queries, beam_width=beam_width_M_plus)
            idx_M = raw_plus[..., :-1].view(B, HQ, gen_len, -1)
            idx_M_plus = raw_plus.view(B, HQ, gen_len, -1)

        key_buckets = router.assign_queries(flat_keys, beam_width=1).view(B, H, gen_len, -1).squeeze(-1)
        return idx_M, idx_M_plus, key_buckets

    def _prepare_generation_attention_inputs(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_k_window: torch.Tensor,
        cache_v_window: torch.Tensor,
        B: int,
        HQ: int,
        H: int,
        gen_len: int,
        beam_width_M: int,
        beam_width_M_plus: int,
        shared_router: Optional[nn.Module] = None,
        shared_g_proj: Optional[nn.Module] = None,
    ) -> Tuple[
        torch.Tensor,  # query_idx_gen_M_stored
        Optional[torch.Tensor],  # query_idx_gen_M_plus_stored
        torch.Tensor,  # bucket_id_gen
        torch.dtype,  # target_dtype_gen
        torch.Tensor,  # g_topk_gen
        torch.Tensor,  # q_gen_attn
        Optional[torch.Tensor],  # query_idx_gen_M_stored_padded
        Optional[torch.Tensor],  # query_idx_gen_M_plus_stored_padded
        torch.Tensor,  # cache_k_window (updated)
        torch.Tensor,  # cache_v_window (updated)
        torch.Tensor,  # cache_k_for_attn
        torch.Tensor,  # cache_v_for_attn
        torch.Tensor,  # k_gen
        torch.Tensor,  # v_gen
    ]:
        """Build every tensor the generation-phase attention step consumes.

        Steps:
        1. Slice the tokens at positions ``>= self.cache_len`` out of q/k/v.
        2. Route them to buckets via ``self._route_queries_and_keys_batched``.
        3. Pick the compute dtype.
        4. Compute the softmaxed gate ``g_topk_gen`` from ``g_proj`` (the shared
           module when ``self.use_shared_g_proj`` is truthy).
        5. Left-pad the query bucket indices with ``self.local_window_size`` zero rows.
        6. Append the new keys/values to the sliding window, keeping only the last
           ``local_window_size + gen_len`` entries.

        Args:
            q: [B, HQ, L, D] query tensor.
            k, v: [B, H, L, D] key / value tensors.
            hidden_states: [B, L, hidden_size] hidden states.
            cache_k_window, cache_v_window: [B, window_size, H, D] sliding-window caches.
            B, HQ, H: batch size, query heads, key/value heads.
            gen_len: number of tokens beyond ``self.cache_len``.
            beam_width_M, beam_width_M_plus: beam widths forwarded to routing.
            shared_router, shared_g_proj: shared modules, used when
                ``use_shared_g_proj`` is enabled.

        Returns:
            A 14-tuple, in the order of the annotation above:

            - ``query_idx_gen_M_stored``: [B, HQ, gen_len, beam_width_M] from routing.
            - ``query_idx_gen_M_plus_stored``: [B, HQ, gen_len, beam_width_M_plus]
              from routing, or None.
            - ``bucket_id_gen``: [B, H, gen_len] from routing.
            - ``target_dtype_gen``: v's dtype if it is fp16/bf16; otherwise bf16 when
              ``torch.cuda.is_bf16_supported()``, else fp16.
            - ``g_topk_gen``: [B, gen_len, HQ, 3 if config.use_sink else 2], softmaxed
              over the last dim.
            - ``q_gen_attn``: [B, gen_len, HQ, D], cast to ``target_dtype_gen``.
            - The two padded index tensors:
              [B, HQ, local_window_size + gen_len, beam], or None.
            - The updated windows, plus ``target_dtype_gen`` copies used for attention.
            - ``k_gen`` and ``v_gen``: [B, H, gen_len, D].
        """
        start = self.cache_len
        window = self.local_window_size

        # 1) Only the positions past the cached prefix are generated this step.
        q_gen, k_gen, v_gen = (t[:, :, start:, :] for t in (q, k, v))

        # 2) Batched bucket routing for all generated queries and keys.
        query_idx_gen_M_stored, query_idx_gen_M_plus_stored, bucket_id_gen = self._route_queries_and_keys_batched(
            q_gen,
            k_gen,
            B,
            HQ,
            H,
            gen_len,
            beam_width_M,
            beam_width_M_plus,
            shared_router=shared_router,
        )

        # 3) Keep half-precision inputs as-is; otherwise choose a half-precision type.
        if v_gen.dtype in (torch.float16, torch.bfloat16):
            target_dtype_gen = v_gen.dtype
        elif torch.cuda.is_bf16_supported():
            target_dtype_gen = torch.bfloat16
        else:
            target_dtype_gen = torch.float16

        # 4) Gate weights over {local, routed[, sink]} per query head.
        use_shared = hasattr(self, 'use_shared_g_proj') and self.use_shared_g_proj
        if use_shared:
            assert shared_g_proj is not None, "use_shared_g_proj is True but shared_g_proj is None"
        gate_proj = shared_g_proj if use_shared else self.g_proj
        n_gate = 3 if self.config.use_sink else 2
        gate_logits = gate_proj(hidden_states[:, start:, :]).view(B, gen_len, HQ, n_gate)
        g_topk_gen = F.softmax(gate_logits, dim=-1)
        q_gen_attn = q_gen.transpose(1, 2).to(target_dtype_gen)

        # 5) Prepend `window` zero rows so indices line up with the window+gen sequence.
        def _left_pad(idx: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            if idx is None:
                return None
            zeros = torch.zeros((B, HQ, window, idx.shape[-1]), dtype=idx.dtype, device=idx.device)
            return torch.cat([zeros, idx], dim=2)

        query_idx_gen_M_stored_padded = _left_pad(query_idx_gen_M_stored)
        query_idx_gen_M_plus_stored_padded = _left_pad(query_idx_gen_M_plus_stored)

        # 6) Slide the window: append the new tokens ([B, gen_len, H, D]) and trim.
        keep = window + gen_len
        cache_k_window = torch.cat([cache_k_window, k_gen.transpose(1, 2)], dim=1)[:, -keep:, :, :]
        cache_v_window = torch.cat([cache_v_window, v_gen.transpose(1, 2)], dim=1)[:, -keep:, :, :]

        cache_k_for_attn = cache_k_window.to(target_dtype_gen)
        cache_v_for_attn = cache_v_window.to(target_dtype_gen)

        return (
            query_idx_gen_M_stored,
            query_idx_gen_M_plus_stored,
            bucket_id_gen,
            target_dtype_gen,
            g_topk_gen,
            q_gen_attn,
            query_idx_gen_M_stored_padded,
            query_idx_gen_M_plus_stored_padded,
            cache_k_window,
            cache_v_window,
            cache_k_for_attn,
            cache_v_for_attn,
            k_gen,
            v_gen,
        )

    def _rebalance_bucket_assignments(
        self,
        bucket_id_gen: torch.Tensor,
        num_bucket: int,
        threshold_ratio: float = 0.25,
    ) -> torch.Tensor:
        """Rebalance bucket assignments if most tokens assign to a single bucket.

        For every (batch, head) pair, count how many of its ``gen_len`` tokens fall
        into each bucket. If the largest count is strictly greater than
        ``gen_len * threshold_ratio``, every token of that pair is re-drawn
        uniformly at random from ``[0, num_bucket)``. Pairs at or under the
        threshold keep their assignments.

        The input tensor is never modified. When nothing needs rebalancing, the
        input object itself is returned. Otherwise a new tensor is returned.

        Counting uses ``scatter_add_`` into an ``[B*H, num_bucket]`` table rather
        than a one-hot tensor. The early exit (``.any()``) and ``.item()`` force a
        host/device sync.

        Note: with ``num_bucket < 1 / threshold_ratio``, the largest bucket always
        holds more than ``gen_len * threshold_ratio`` tokens. Every non-empty pair
        is then reassigned on every call.

        Args:
            bucket_id_gen: [B, H, gen_len] integer tensor of bucket assignments, with
                values in ``[0, num_bucket)``. Any integer dtype and memory layout
                is accepted.
            num_bucket: Total number of buckets.
            threshold_ratio: Fraction threshold (default 0.25 = 1/4).

        Returns:
            [B, H, gen_len] tensor (same dtype as the input) with potentially
            rebalanced assignments.
        """
        B, H, gen_len = bucket_id_gen.shape
        threshold = gen_len * threshold_ratio

        # reshape (not view) so non-contiguous inputs are accepted.
        bucket_id_flat = bucket_id_gen.reshape(B * H, gen_len)  # [B*H, gen_len]

        # Per-row histogram of bucket ids; scatter_add_ requires an int64 index.
        bucket_counts = torch.zeros(B * H, num_bucket, device=bucket_id_gen.device, dtype=torch.long)
        bucket_counts.scatter_add_(
            1,
            bucket_id_flat.long(),
            torch.ones_like(bucket_id_flat, dtype=torch.long),
        )  # [B*H, num_bucket]

        needs_reassignment = bucket_counts.max(dim=1).values > threshold  # [B*H]
        if not bool(needs_reassignment.any()):
            return bucket_id_gen

        # Draw random ids only for the rows that need them.
        num_needs_reassign = int(needs_reassignment.sum().item())
        random_assignments = torch.randint(
            0, num_bucket, size=(num_needs_reassign, gen_len),
            device=bucket_id_gen.device, dtype=bucket_id_gen.dtype,
        )
        # Clone so the caller's tensor (possibly a view of router output) is untouched.
        bucket_id_out = bucket_id_flat.clone()
        bucket_id_out[needs_reassignment] = random_assignments
        return bucket_id_out.view(B, H, gen_len)

    def _sort_and_pad_bucket_id_and_k(
        self,
        bucket_id_gen: torch.Tensor,
        k_gen: torch.Tensor,
        cache_bucket_keys: torch.Tensor,
        B: int,
        H: int,
        gen_len: int,
        N_sample_per_bucket: int,
        D: int,
        v_gen: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, dict, torch.Tensor]:
        """Sort tokens by bucket_id and pad buckets to same size for efficient processing.

        For each (batch, head) pair the ``gen_len`` bucket ids are stably sorted,
        each token's rank inside its bucket is computed, and tokens (keys, plus
        values when given) are scattered into a dense
        ``[B, H, num_bucket, max_tokens_per_bucket, ...]`` layout. Unfilled slots
        hold -1 in ``bucket_id_padded`` and zeros in ``k_gen_padded`` /
        ``v_gen_padded``.

        Cost: one sort along ``gen_len`` plus O(B*H*gen_len) elementwise work.

        Args:
            bucket_id_gen: [B, H, gen_len] bucket id per token. Values are assumed to
                lie in ``[0, num_bucket)``; this is not validated here.
            k_gen: [B, H, gen_len, D] keys.
            cache_bucket_keys: only used to derive ``num_bucket`` as
                ``size(2) // N_sample_per_bucket`` after viewing it as [B, H, -1, D].
            B, H, gen_len, D: sizes of the tensors above.
            N_sample_per_bucket: number of cached keys per bucket.
            v_gen: optional [B, H, gen_len, D] values, laid out like ``k_gen``.

        Returns:
            bucket_id_padded: [B, H, num_bucket, max_tokens_per_bucket], -1 = padding.
            k_gen_padded: [B, H, num_bucket, max_tokens_per_bucket, D].
            num_bucket: int.
            max_tokens_per_bucket: int, size of the largest bucket. It is 0 when
                there are no tokens (e.g. ``gen_len == 0``).
            unsort_info: dict with ``sort_indices``, ``unsort_indices`` (argsort of
                ``sort_indices``), ``position_in_group`` (rank within bucket, in
                sorted order), and the original ``bucket_id_gen``. All are
                [B, H, gen_len].
            v_gen_padded: same layout as ``k_gen_padded``. All zeros (k's
                dtype/device) when ``v_gen`` is None.
        """
        device = bucket_id_gen.device
        rows = B * H
        num_bucket = cache_bucket_keys.view(B, H, -1, D).size(2) // N_sample_per_bucket

        # Stable sort keeps original token order inside each bucket.
        bucket_id_sorted, sort_indices = torch.sort(bucket_id_gen, dim=2, stable=True)  # [B, H, gen_len] x2
        sorted_rows = bucket_id_sorted.reshape(rows, gen_len)

        # A bucket group starts at column 0 and wherever the sorted id changes. Built
        # by slice assignment so gen_len == 0 yields an empty (not 1-wide) mask.
        group_start = torch.ones(rows, gen_len, dtype=torch.bool, device=device)
        group_start[:, 1:] = sorted_rows[:, 1:] != sorted_rows[:, :-1]

        # Forward-fill the column of the most recent group start with cummax; a
        # token's rank in its bucket is its distance from that start. Column 0 is
        # always a start, so every forward-filled value is >= 0.
        columns = torch.arange(gen_len, device=device, dtype=torch.long).expand(rows, gen_len)
        group_starts = columns.masked_fill(~group_start, -1)
        if group_starts.numel() > 0:
            group_starts = group_starts.cummax(dim=1).values
        position_in_group = (columns - group_starts).view(B, H, gen_len)  # [B, H, gen_len]

        # max() of an empty tensor raises, so handle the no-token case explicitly.
        if position_in_group.numel() > 0:
            max_tokens_per_bucket = int(position_in_group.max().item()) + 1
        else:
            max_tokens_per_bucket = 0

        padded_shape = (B, H, num_bucket, max_tokens_per_bucket)
        bucket_id_padded = torch.full(padded_shape, -1, dtype=bucket_id_gen.dtype, device=device)
        k_gen_padded = k_gen.new_zeros(padded_shape + (D,))
        if v_gen is not None:
            v_gen_padded = v_gen.new_zeros(padded_shape + (D,))
        else:
            v_gen_padded = torch.zeros_like(k_gen_padded)

        # Linear indices: row = b*H + h for every sorted token.
        row_flat = torch.arange(rows, device=device).repeat_interleave(gen_len)  # [B*H*gen_len]
        bucket_flat = bucket_id_sorted.flatten()
        source_linear_idx = row_flat * gen_len + sort_indices.flatten()
        dest_linear_idx = (row_flat * num_bucket + bucket_flat) * max_tokens_per_bucket + position_in_group.flatten()

        n_slots = rows * num_bucket * max_tokens_per_bucket
        # The padded tensors are freshly allocated (contiguous), so these views alias them.
        bucket_id_padded.view(n_slots)[dest_linear_idx] = bucket_flat
        k_gen_padded.view(n_slots, D)[dest_linear_idx] = k_gen.reshape(rows * gen_len, D)[source_linear_idx]
        if v_gen is not None:
            v_gen_padded.view(n_slots, D)[dest_linear_idx] = v_gen.reshape(rows * gen_len, D)[source_linear_idx]

        unsort_indices = torch.argsort(sort_indices, dim=2, stable=True)  # [B, H, gen_len]
        unsort_info = {
            'sort_indices': sort_indices,  # [B, H, gen_len] - indices used for sorting
            'unsort_indices': unsort_indices,  # [B, H, gen_len] - indices to reverse the sort
            'position_in_group': position_in_group,  # [B, H, gen_len] - rank within bucket (sorted order)
            'bucket_id_gen': bucket_id_gen,  # [B, H, gen_len] - original bucket assignments
        }

        return bucket_id_padded, k_gen_padded, num_bucket, max_tokens_per_bucket, unsort_info, v_gen_padded


    def _select_slots_batched(
        self,
        k_gen_padded: torch.Tensor,
        bucket_id_padded: torch.Tensor,
        cache_bucket_keys: torch.Tensor,
        cache_key_idx: torch.Tensor,
        B: int,
        H: int,
        num_bucket: int,
        max_tokens_per_bucket: int,
        N_sample_per_bucket: int,
        D: int,
        shared_slot_query_proj: Optional[nn.Module] = None,
        shared_slot_key_proj: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select slots for all generation tokens (batched, working with padded inputs).

        Both inputs go through slot projections:
        - ``slot_query_proj`` on the padded generation keys;
        - ``slot_key_proj`` on the bucket's cached keys.

        Each projection output is split along its last dim into a "selection" half
        (the first ``out_dim // 2`` features) and a "gate" half (the rest). Per
        bucket, dot products between the two sides give:
        - selection logits, turned into soft slot probabilities with
          ``F.gumbel_softmax(tau=self.gumbel_tau_slot, hard=False)``;
        - gate logits, turned into ``alpha`` with a sigmoid.

        The shared projections are used when ``self.use_shared_params`` is truthy.
        A missing attribute is treated as False.

        Args:
            k_gen_padded: [B, H, num_bucket, max_tokens_per_bucket, D] - padded and
                sorted by bucket.
            bucket_id_padded: unused here; kept for interface compatibility.
            cache_bucket_keys: [B*H, num_bucket*N_sample_per_bucket, D].
            cache_key_idx: unused here; kept for interface compatibility.
            B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket, D: sizes.
            shared_slot_query_proj, shared_slot_key_proj: shared projection modules.

        Returns:
            slot_probs_gen_padded: [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket]
            alpha_gen_padded: [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket]
        """
        # [B*H, nb*N, D] -> [B, H, nb, N, D] so cached keys line up per bucket.
        cache_bucket_keys_reshaped = cache_bucket_keys.view(B, H, num_bucket, N_sample_per_bucket, D)

        if getattr(self, 'use_shared_params', False):
            assert shared_slot_query_proj is not None, "use_shared_params is True but shared_slot_query_proj is None"
            assert shared_slot_key_proj is not None, "use_shared_params is True but shared_slot_key_proj is None"
            slot_query_proj = shared_slot_query_proj
            slot_key_proj = shared_slot_key_proj
        else:
            slot_query_proj = self.slot_query_proj
            slot_key_proj = self.slot_key_proj

        q_slot_gen = slot_query_proj(k_gen_padded)  # [B, H, nb, T, 2*slot_hidden_dim]
        k_slot_gen = slot_key_proj(cache_bucket_keys_reshaped)  # [B, H, nb, N, 2*slot_hidden_dim]

        slot_hidden_dim = q_slot_gen.shape[-1] // 2
        q_slot_selection = q_slot_gen[..., :slot_hidden_dim]
        q_slot_gate = q_slot_gen[..., slot_hidden_dim:]
        k_slot_selection = k_slot_gen[..., :slot_hidden_dim]
        k_slot_gate = k_slot_gen[..., slot_hidden_dim:]

        # Batched matmul over the leading [B, H, nb] dims:
        # [.., T, s] @ [.., s, N] -> [.., T, N]. No flatten/view of the
        # (non-contiguous) sliced halves is needed.
        selection_logits_gen = torch.matmul(q_slot_selection, k_slot_selection.transpose(-2, -1))
        gate_logits_gen = torch.matmul(q_slot_gate, k_slot_gate.transpose(-2, -1))

        # NOTE(review): gumbel noise is applied regardless of self.training; confirm
        # this stochastic selection is intended at inference time as well.
        slot_probs_gen_padded = F.gumbel_softmax(selection_logits_gen, tau=self.gumbel_tau_slot, hard=False, dim=-1)
        alpha_gen_padded = torch.sigmoid(gate_logits_gen)

        return slot_probs_gen_padded, alpha_gen_padded


    def _parallel_scan_linear_recurrence(
        self,
        a: torch.Tensor,  # [..., seq_len] - coefficients (1 - alpha)
        b: torch.Tensor,  # [..., seq_len, D] - terms (alpha * k)
        initial: torch.Tensor,  # [..., D] - state before the first position
    ) -> torch.Tensor:
        """Parallel scan for the linear recurrence ``K_i = a_i * K_{i-1} + b_i``.

        The recurrence starts from ``K_{-1} = initial`` (the state *before*
        position 0). Every coefficient, including ``a_0``, therefore decays the
        initial state:

            K_i = (prod_{j=0}^{i} a_j) * initial
                  + sum_{j=0}^{i} (prod_{k=j+1}^{i} a_k) * b_j

        ``accelerated_scan_fn`` computes the zero-initialised part (the sum).
        The initial-state term is added afterwards using the *inclusive*
        cumulative product of ``a``.

        Args:
            a: [..., seq_len] coefficients (1 - alpha_i).
            b: [..., seq_len, D] input terms (alpha_i * k_i).
            initial: [..., D] state before the first position.

        Returns:
            K: [..., seq_len, D] states K_i for every position i.

        Raises:
            ValueError: if seq_len would need padding beyond the scan's
                maximum supported length (65536).
        """
        *leading_dims, seq_len = a.shape
        D = b.shape[-1]
        batch_size = math.prod(leading_dims)

        a_flat = a.reshape(batch_size, seq_len)  # [batch, seq_len]
        b_flat = b.reshape(batch_size, seq_len, D)  # [batch, seq_len, D]
        initial_flat = initial.reshape(batch_size, D)  # [batch, D]

        # accelerated_scan expects (batch, channels, time): broadcast each step's
        # coefficient across the D channels and move time to the last axis.
        gates = a_flat.unsqueeze(1).expand(-1, D, -1).contiguous()  # [batch, D, seq_len]
        tokens = b_flat.transpose(1, 2).contiguous()  # [batch, D, seq_len]
        del b_flat

        # The warp scan needs a power-of-two length in [32, 65536].
        min_seq_len = 32
        max_seq_len = 65536
        padded_len = max(min_seq_len, 1 << max(seq_len - 1, 0).bit_length())
        if padded_len > max_seq_len:
            raise ValueError(f"Sequence length {seq_len} requires padding to {padded_len}, which exceeds maximum {max_seq_len}")

        if padded_len != seq_len:
            pad_len = padded_len - seq_len
            # gate=1 / token=0 leaves the state unchanged on padded steps, and
            # padding only at the end cannot affect earlier positions.
            gates = F.pad(gates, (0, pad_len), value=1.0)
            tokens = F.pad(tokens, (0, pad_len), value=0.0)

        # Solves x_t = gates_t * x_{t-1} + tokens_t with x_{-1} = 0; drop padding.
        K_scan = accelerated_scan_fn(gates, tokens)[:, :, :seq_len]  # [batch, D, seq_len]
        del gates, tokens

        # Initial-state term. The coefficient of `initial` at position i is the
        # INCLUSIVE prefix product prod_{j=0}^{i} a_j, because `initial` is
        # decayed by a_0 before b_0 is added (K_0 = a_0 * initial + b_0).
        decay = torch.cumprod(a_flat, dim=1)  # [batch, seq_len]
        initial_contrib = decay.unsqueeze(1) * initial_flat.unsqueeze(-1)  # [batch, D, seq_len]
        K_scan = K_scan + initial_contrib
        del decay, initial_contrib, a_flat, initial_flat

        # [batch, D, seq_len] -> [..., seq_len, D]
        return K_scan.transpose(1, 2).reshape(*leading_dims, seq_len, D)


    def _update_cache_batched(
        self,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_gen_padded: torch.Tensor,
        v_gen_padded: torch.Tensor,
        slot_probs_padded: torch.Tensor,
        alpha_padded: torch.Tensor,
        cache_key_idx: torch.Tensor,
        bucket_id_padded: torch.Tensor,
        padded_gen_len: int,
        B: int,
        H: int,
        N_sample_per_bucket: int,
        D: int,
        num_bucket: int,
        use_fused_kernel: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fold bucketed generation tokens into the slot cache.

        Each cache slot follows the gated recurrence
        ``K_i = (1 - g_i) * K_{i-1} + g_i * k_i`` (and likewise for V). The
        effective gate is ``g_i = alpha_i * slot_prob_i``. Positions whose
        ``bucket_id_padded`` is negative (padding) get decay 1 and input 0, so
        they leave the slot unchanged. Only the final state of each slot is
        kept.

        When not training, with ``use_fused_kernel`` set and
        ``_has_fused_cache_update`` truthy, the work is delegated to
        ``fused_cache_update``. Otherwise the recurrence is solved with
        ``self._parallel_scan_linear_recurrence``.

        Args:
            cache_k, cache_v: [B, H, cache_len, D] current slot states, where
                cache_len is expected to equal num_bucket * N_sample_per_bucket.
            k_gen_padded, v_gen_padded: new keys/values, viewable as
                [B, H, num_bucket, max_tokens_per_bucket, D].
            slot_probs_padded, alpha_padded: viewable as
                [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket].
            cache_key_idx: unused by this method.
            bucket_id_padded: bucket id per padded position; negative marks padding.
            padded_gen_len: num_bucket * max_tokens_per_bucket.
            use_fused_kernel: allow the fused kernel path (default: True).

        Returns:
            (cache_k, cache_v), each of shape [B, H, cache_len, D] with
            cache_len = ``cache_k.size(2)``.
        """
        # cache_k is laid out as [B, H, cache_len, D], so cache_len is dim 2.
        slots_total = cache_k.size(2)
        tokens_per_bucket = padded_gen_len // num_bucket if num_bucket > 0 else padded_gen_len

        use_fused = (not self.training) and use_fused_kernel and _has_fused_cache_update
        if use_fused:
            return fused_cache_update(
                cache_k, cache_v,
                k_gen_padded, v_gen_padded,
                slot_probs_padded, alpha_padded,
                bucket_id_padded,
                padded_gen_len, B, H, N_sample_per_bucket, D,
                num_bucket,
            )

        bucket_shape = (B, H, num_bucket, tokens_per_bucket)
        scan_rows = B * H * num_bucket * N_sample_per_bucket

        # [B, H, num_bucket, 1, T]: True for real tokens, broadcast over slots.
        is_real = (bucket_id_padded >= 0).view(*bucket_shape).unsqueeze(3)

        # Effective gate per (slot, token): [B, H, num_bucket, N_sample_per_bucket, T]
        gate = (
            alpha_padded.view(*bucket_shape, N_sample_per_bucket).transpose(3, 4)
            * slot_probs_padded.view(*bucket_shape, N_sample_per_bucket).transpose(3, 4)
        )

        # Decay coefficient (1 - g); padding gets 1 so the state passes through.
        one_minus_gate = 1.0 - gate
        decay = torch.where(is_real, one_minus_gate, torch.ones_like(one_minus_gate))
        del one_minus_gate

        # Input terms g * k and g * v: [B, H, num_bucket, N_sample_per_bucket, T, D];
        # padding contributes nothing.
        gate_col = gate.unsqueeze(-1)
        is_real_col = is_real.unsqueeze(-1)
        k_tok = k_gen_padded.view(*bucket_shape, D).unsqueeze(3)
        v_tok = v_gen_padded.view(*bucket_shape, D).unsqueeze(3)
        k_in = torch.where(is_real_col, gate_col * k_tok, torch.zeros_like(k_tok))
        v_in = torch.where(is_real_col, gate_col * v_tok, torch.zeros_like(v_tok))

        # One scan row per (b, h, bucket, slot); the same decay is shared by K and V.
        decay_rows = decay.reshape(scan_rows, tokens_per_bucket)
        k_rows = k_in.reshape(scan_rows, tokens_per_bucket, D)
        v_rows = v_in.reshape(scan_rows, tokens_per_bucket, D)
        k_init = cache_k.view(scan_rows, D)
        v_init = cache_v.view(scan_rows, D)
        del decay, k_in, v_in, gate, gate_col, is_real, is_real_col, k_tok, v_tok

        # Keep only the state after the last token of each row.
        k_last = self._parallel_scan_linear_recurrence(decay_rows, k_rows, k_init)[:, -1, :]
        v_last = self._parallel_scan_linear_recurrence(decay_rows, v_rows, v_init)[:, -1, :]
        del decay_rows, k_rows, v_rows

        slot_shape = (B, H, num_bucket, N_sample_per_bucket, D)
        out_k = k_last.view(*slot_shape).to(k_init.dtype).view(B, H, slots_total, D)
        out_v = v_last.view(*slot_shape).to(v_init.dtype).view(B, H, slots_total, D)
        del k_last, v_last

        return out_k, out_v

    def _update_cache_key_idx(
        self,
        unsort_info: dict,
        bucket_id_gen: torch.Tensor,
        max_tokens_per_bucket: int,
        slot_probs_padded: torch.Tensor,
        cache_key_idx: torch.Tensor,
        gen_len: int,
        B: int,
        H: int,
        num_bucket: int,
        N_sample_per_bucket: int,
        start_idx: int,
        chunk_len: int,
    ) -> torch.Tensor:
        """Update cache_key_idx to reflect which tokens were assigned to which slots.
        
        This function tracks the maximum token position assigned to each slot in the cache.
        It's used during training to maintain accurate cache indexing.
        
        Args:
            unsort_info: Dictionary containing:
                - sort_indices: [B, H, gen_len] - maps original -> sorted
                - unsort_indices: [B, H, gen_len] - maps sorted -> original
                - position_in_group: [B, H, gen_len] - position in sorted order
            bucket_id_gen: [B, H, gen_len] bucket assignments for each token
            max_tokens_per_bucket: Maximum number of tokens per bucket
            slot_probs_padded: [B, H, padded_gen_len, N_sample_per_bucket] slot probabilities
            cache_key_idx: [B*H, num_bucket, N_sample_per_bucket] cache key index (will be updated)
            gen_len: Generation length
            B: Batch size
            H: Number of heads
            num_bucket: Number of buckets
            N_sample_per_bucket: Number of samples per bucket
            
        Returns:
            cache_key_idx: [B*H, num_bucket, N_sample_per_bucket] updated cache key index
        """
        # Extract unsort information
        sort_indices = unsort_info['sort_indices']  # [B, H, gen_len] - maps original -> sorted
        unsort_indices = unsort_info['unsort_indices']  # [B, H, gen_len] - maps sorted -> original
        position_in_group_sorted = unsort_info['position_in_group']  # [B, H, gen_len] - in sorted order
        
        # Gather position_in_group back to original order
        position_in_group = torch.gather(
            position_in_group_sorted, 
            dim=2, 
            index=unsort_indices
        )  # [B, H, gen_len] - now in original order
        
        # STEP 2: Compute padded index for each token in slot_probs_padded
        # slot_probs_padded is organized as: [B, H, padded_gen_len, N_sample_per_bucket]
        # where padded_gen_len = num_bucket * max_tokens_per_bucket
        # Index calculation: bucket_id * max_tokens_per_bucket + position_in_bucket
        padded_idx = bucket_id_gen * max_tokens_per_bucket + position_in_group  # [B, H, gen_len]
        # STEP 3: Gather slot probabilities and find the assigned slot for each token
        # Expand padded_idx to match slot_probs_padded's last dimension
        padded_idx_expanded = padded_idx.unsqueeze(-1).expand(-1, -1, -1, N_sample_per_bucket)  # [B, H, gen_len, N_sample_per_bucket]
        # Gather slot probabilities: slot_probs_padded[b, h, padded_idx[b,h,t], :]
        slot_probs_gathered = torch.gather(slot_probs_padded, 2, padded_idx_expanded)  # [B, H, gen_len, N_sample_per_bucket]
        # Find the slot with maximum probability for each token
        slot_assignments = slot_probs_gathered.argmax(dim=-1)  # [B, H, gen_len] - slot ID for each token
        del padded_idx_expanded, slot_probs_gathered  # Free memory immediately
        
        # STEP 4: Prepare indices for scatter_reduce to update cache_key_idx
        # cache_key_idx shape: [B*H, num_bucket, N_sample_per_bucket]
        # We need to compute linear indices: bh_idx * (num_bucket * N_sample_per_bucket) + bucket_id * N_sample_per_bucket + slot_id
        cache_key_idx_flat = cache_key_idx.view(-1).clone()  # [B*H*num_bucket*N_sample_per_bucket]

        # Token positions in the original sequence (starting from cache_len)
        token_pos_base = start_idx + torch.arange(chunk_len, device=bucket_id_gen.device, dtype=cache_key_idx.dtype)  # [gen_len]
        token_positions_flat = token_pos_base.repeat(B * H)  # [B*H*gen_len] - repeat for each (batch, head) pair
        
        # Flatten all indices
        # Use reshape() instead of view() to handle non-contiguous tensors
        bucket_id_flat = bucket_id_gen.reshape(-1)  # [B*H*gen_len]
        slot_assignments_flat = slot_assignments.reshape(-1)  # [B*H*gen_len]
        bh_indices = torch.arange(B * H, device=bucket_id_gen.device, dtype=torch.long).repeat_interleave(gen_len)  # [B*H*gen_len]
        
        # Compute linear indices in cache_key_idx_flat
        # Layout: [batch_0_head_0, batch_0_head_1, ..., batch_0_head_{H-1}, batch_1_head_0, ...]
        # For each (b, h) pair: indices start at (b*H + h) * num_bucket * N_sample_per_bucket
        linear_indices = (
            bh_indices * num_bucket * N_sample_per_bucket +  # offset for (batch, head) pair
            bucket_id_flat * N_sample_per_bucket +           # offset for bucket
            slot_assignments_flat                            # offset for slot within bucket
        )  # [B*H*gen_len]
        
        del padded_idx, bh_indices, bucket_id_flat, slot_assignments_flat  # Free memory immediately
        
        # STEP 5: Update cache_key_idx using scatter_reduce
        # For each slot, store the maximum token position that was assigned to it
        # reduce='amax': if multiple tokens are assigned to the same slot, keep the maximum position
        # include_self=True: include existing values in the reduction
        cache_key_idx_flat.scatter_reduce_(
            0, 
            linear_indices, 
            token_positions_flat, 
            reduce='amax', 
            include_self=True
        )
        del linear_indices, token_positions_flat  # Free memory immediately
        
        # STEP 6: Reshape back to original shape
        cache_key_idx = cache_key_idx_flat.view(B * H, num_bucket, N_sample_per_bucket)
        del cache_key_idx_flat  # Free memory immediately
        
        return cache_key_idx

    def _compute_attention_batched(
        self,
        q_gen_attn: torch.Tensor,
        cache_k_for_attn: torch.Tensor,
        cache_v_for_attn: torch.Tensor,
        cache_bucket_keys_reshaped: torch.Tensor,
        cache_bucket_values_reshaped: torch.Tensor,
        cache_key_idx_reshaped: torch.Tensor,
        query_idx_gen_M_stored: torch.Tensor,
        query_idx_gen_M_plus_stored: Optional[torch.Tensor],
        g_topk_gen: torch.Tensor,
        beam_width_score: torch.Tensor,
        N_sample_per_bucket: int,
        HQ: int,
        gen_len: int,
    ) -> torch.Tensor:
        """Run ARM attention for all generation tokens in one batched call.

        The bucket keys and values are transposed from [B, H, cache_len, D] to
        [B, cache_len, H, D]. The slot indices are viewed as [B, H, cache_len]
        and transposed to [B, cache_len, H]. All three are then passed to
        ``parallel_arm``.

        In training mode, ``parallel_arm`` runs twice: once with the beam-width-M
        query indices and once with the M+ indices. The two outputs are blended
        as ``(1 - beam_width_score) * out_M + beam_width_score * out_M_plus``.
        In eval mode only the M indices are used.

        When ``self.sink`` is set, gate channel 2 of ``g_topk_gen`` scales the
        sink weight, repeated per query group, and the result is added to the
        output.

        Returns:
            attn_output_gen: [B, gen_len, HQ, D], as returned by parallel_arm.
        """
        B, H, cache_len, hdim = cache_bucket_keys_reshaped.shape
        bucket_k = cache_bucket_keys_reshaped.view(B, H, cache_len, hdim).transpose(1, 2)
        bucket_v = cache_bucket_values_reshaped.view(B, H, cache_len, hdim).transpose(1, 2)
        slot_idx = cache_key_idx_reshaped.view(B, H, cache_len).transpose(1, 2)  # [B, num_bucket*N_sample, H]

        def _arm(q_idx: torch.Tensor) -> torch.Tensor:
            # Every call shares the same queries, window, gates and bucket cache.
            return parallel_arm(
                q_gen_attn, cache_k_for_attn, cache_v_for_attn,
                g_topk_gen, bucket_k, bucket_v,
                q_indices=q_idx,
                k_indices=slot_idx,
                block_size=N_sample_per_bucket,
                window_size=self.local_window_size,
            )  # [B, gen_len, HQ, D]

        if not self.training:
            # [B, HQ, gen_len, beam_width_M] -> [B, gen_len, HQ, beam_width_M]
            out = _arm(query_idx_gen_M_stored.permute(0, 2, 1, 3).contiguous())
        else:
            out_m = _arm(query_idx_gen_M_stored.permute(0, 2, 1, 3))
            out_m_plus = _arm(query_idx_gen_M_plus_stored.permute(0, 2, 1, 3))
            out = (1.0 - beam_width_score) * out_m + beam_width_score * out_m_plus

        if self.sink is not None:
            sink_gate = g_topk_gen[..., [2]]
            sink_w = self.sink.weight.repeat_interleave(self.num_key_value_groups, dim=0)
            out = out + sink_gate * sink_w[None, None, ...]

        return out


    def _compute_attention_for_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden_states: torch.Tensor,
        query_idx: torch.Tensor,
        key_idx: torch.Tensor,
        bucket_keys: torch.Tensor,
        bucket_values: torch.Tensor,
        num_bucket: int,
        N_sample_per_bucket: int,
        B: int,
        HQ: int,
        H: int,
        cache_len: int,
        hdim: int,
        shared_g_proj: Optional[nn.Module] = None,
        cache_query_idx_M_stored: Optional[torch.Tensor] = None,
        cache_query_idx_M_plus_stored: Optional[torch.Tensor] = None,
        beam_width_score: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the attention output for the cache portion ``[0:cache_len]``.

        Before calling ``parallel_arm``, inputs are transposed along dims 1 and 2
        and cast to a 16-bit dtype:
          - ``value.dtype`` if it is already fp16/bf16;
          - otherwise bf16 when ``torch.cuda.is_bf16_supported()`` returns True;
          - otherwise fp16.

        The mixing gates are a softmax over
        ``g_proj(hidden_states[:, :cache_len])``. They have 3 channels when
        ``config.use_sink`` is set and 2 otherwise. If ``self.use_shared_g_proj``
        is truthy, ``g_proj`` is ``shared_g_proj``; otherwise it is ``self.g_proj``.

        In training mode, ``parallel_arm`` runs with both the beam-width-M and M+
        query indices, and the two outputs are blended by ``beam_width_score``.
        In eval mode it runs once, with ``query_idx``.

        Returns:
            [B, cache_len, hidden_size] output of ``self.o_proj``.

        Raises:
            ValueError: if ``use_shared_g_proj`` is set but ``shared_g_proj`` is
                None, or if training mode is missing any of
                ``cache_query_idx_M_stored``, ``cache_query_idx_M_plus_stored``
                or ``beam_width_score``.
        """
        # Explicit checks instead of `assert`: asserts vanish under `python -O`,
        # and None inputs would otherwise fail deep inside with opaque errors.
        if self.training and (
            cache_query_idx_M_stored is None
            or cache_query_idx_M_plus_stored is None
            or beam_width_score is None
        ):
            raise ValueError(
                "Training mode requires cache_query_idx_M_stored, "
                "cache_query_idx_M_plus_stored and beam_width_score"
            )
        if getattr(self, 'use_shared_g_proj', False):
            if shared_g_proj is None:
                raise ValueError("use_shared_g_proj is True but shared_g_proj is None")
            g_proj = shared_g_proj
        else:
            g_proj = self.g_proj

        if value.dtype in (torch.float16, torch.bfloat16):
            target_dtype = value.dtype
        else:
            target_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

        query_attn = query.transpose(1, 2).to(target_dtype)  # [B, cache_len, HQ, D]
        key_attn = key.transpose(1, 2).to(target_dtype)  # [B, cache_len, H, D]
        value_attn = value.transpose(1, 2).to(target_dtype)  # [B, cache_len, H, D]

        num_slots = num_bucket * N_sample_per_bucket
        bucket_keys_reshaped = bucket_keys.view(B, H, num_slots, hdim).transpose(1, 2).to(target_dtype)
        bucket_values_reshaped = bucket_values.view(B, H, num_slots, hdim).transpose(1, 2).to(target_dtype)
        key_idx_reshaped = key_idx.view(B, H, num_slots).transpose(1, 2)  # [B, num_bucket*N_sample, H]

        hidden_cache = hidden_states[:, :cache_len, :]  # [B, cache_len, hidden_size]
        g_proj_output_dim = 3 if self.config.use_sink else 2
        g_topk = F.softmax(
            g_proj(hidden_cache).view(B, cache_len, HQ, g_proj_output_dim), dim=-1
        )  # [B, cache_len, HQ, g_proj_output_dim]

        def _arm(q_idx: torch.Tensor) -> torch.Tensor:
            # Every call shares the same queries, keys, values, gates and bucket cache.
            return parallel_arm(
                query_attn,
                key_attn,
                value_attn,
                g_topk,
                bucket_keys_reshaped,
                bucket_values_reshaped,
                q_indices=q_idx,
                k_indices=key_idx_reshaped,
                block_size=N_sample_per_bucket,
                window_size=self.local_window_size,
            )  # [B, cache_len, HQ, D]

        if self.training:
            # [B, HQ, cache_len, beam_width] -> [B, cache_len, HQ, beam_width]
            attn_output_cache_M = _arm(cache_query_idx_M_stored.permute(0, 2, 1, 3))
            attn_output_cache_M_plus = _arm(cache_query_idx_M_plus_stored.permute(0, 2, 1, 3))
            attn_output = (1.0 - beam_width_score) * attn_output_cache_M + beam_width_score * attn_output_cache_M_plus
        else:
            attn_output = _arm(query_idx.view(B, HQ, cache_len, -1).transpose(1, 2))

        if self.sink is not None:
            attn_output = (
                attn_output
                + g_topk[..., [2]]
                * self.sink.weight.repeat_interleave(self.num_key_value_groups, dim=0)[None, None, ...]
            )

        # parallel_arm already returns [B, cache_len, HQ, D]; merge the heads.
        attn_output = attn_output.reshape(B, cache_len, self.config.hidden_size)
        return self.o_proj(attn_output)  # [B, cache_len, hidden_size]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        return_updated_cache: Optional[bool] = False,
        shared_g_proj: Optional[nn.Module] = None,
        shared_router: Optional[nn.Module] = None,
        shared_key_gate: Optional[nn.Module] = None,
        shared_slot_query_proj: Optional[nn.Module] = None,
        shared_slot_key_proj: Optional[nn.Module] = None,
        shared_beam_width_lstm_input: Optional[nn.Module] = None,
        shared_beam_width_lstm_hidden: Optional[nn.Module] = None,
        shared_step_embedding: Optional[nn.Module] = None,
        shared_beam_width_proj: Optional[nn.Module] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
        """Forward pass - adapts ArmTopKAttentionWithMemory to HiLlamaAttention interface.
        
        Note: position_embeddings and position_ids are ignored since ArmTopKAttentionWithMemory
        handles RoPE internally via its own RotaryEmbedding.
        """
        # Call ArmTopKAttentionWithMemory
        # Note: ArmTopKAttentionWithMemory doesn't use position_embeddings - it has its own rotary
        
        """Forward pass with memory caching support.
        
        Supports two modes:
        1. Initialization mode (L > 1): Initialize cache and store in past_key_values
        2. Generation mode (L == 1): Load cache from past_key_values, update it, and store back
        """
        B, L, _ = hidden_states.shape
        # Set cache_len to the largest power of 2 <= L, with lower bound 64 and upper bound 2048

        HQ = self.num_attention_heads
        H = self.num_key_value_heads
        D = self.head_dim
        hdim = D
        # Generation mode: L == 1 and past_key_values exists
        is_generation_mode = L == 1 and past_key_values is not None
        # 
        # self.local_window_size = 256 # TODO
        if is_generation_mode:
            # Retrieve cache states from past_key_values for this specific layer
            arm_cache = _get_arm_cache(past_key_values, self.layer_idx)
            if not arm_cache:
                raise ValueError(
                    f"past_key_values does not contain arm cache states for layer {self.layer_idx}. "
                    "Initialize cache first with L > 1."
                )
            
            cache_k_window = arm_cache['cache_k_window']  # [B, window_size, H, D]
            cache_v_window = arm_cache['cache_v_window']  # [B, window_size, H, D]
            cache_bucket_keys = arm_cache['cache_bucket_keys']  # [B*H, num_bucket*N_sample, D]
            cache_bucket_values = arm_cache['cache_bucket_values']  # [B*H, num_bucket*N_sample, D]
            beam_width_M = arm_cache['beam_width_M']
            num_bucket = arm_cache['num_bucket']
            N_sample_per_bucket = arm_cache['N_sample_per_bucket']
            beam_width_M_plus = arm_cache['beam_width_M_plus']
            beam_width_score = arm_cache['beam_width_score']
            cache_key_idx = arm_cache['cache_key_idx']  # [B*H, num_bucket, N_sample_per_bucket]


            # # Rearrange to [B, num_heads, seq_len, head_dim] to match cache shape [B, H, cache_len, D]
            # q = rearrange(self.q_proj(hidden_states), 'b s (h d) -> b h s d', d=self.head_dim)
            # k = rearrange(self.k_proj(hidden_states), 'b s (h d) -> b h s d', d=self.head_dim)
            # v = rearrange(self.v_proj(hidden_states), 'b s (h d) -> b h s d', d=self.head_dim)

            # # if self.qk_norm:
            # #     q, k = self.q_norm(q), self.k_norm(k)

            # # Get sequence offset from cache
            # seqlen_offset = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
            # max_seqlen = q.shape[1] + seqlen_offset
            # if self.config.max_position_embeddings is not None:
            #     max_seqlen = max(max_seqlen, self.config.max_position_embeddings)
            # q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=kwargs.get('cu_seqlens'))

            # Compute QKV - use view instead of reshape to avoid copy
            input_shape = hidden_states.shape[:-1]
            hidden_shape = (*input_shape, -1, self.head_dim)

            q = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            k = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            v = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

            # Extract single token - use views/slices, no copy
            q_t = q  # [B, HQ, 1, D] - already correct shape
            k_t = k[:, :, 0, :]  # [B, H, D] - direct indexing, no copy
            v_t = v[:, :, 0, :]  # [B, H, D] - direct indexing, no copy
            # Route query and key for the new token
            query_idx_t_M_stored, query_idx_t_M_plus_stored, bucket_id = self._route_query_and_key_for_token(
                q_t, k_t, B, HQ, H, beam_width_M, beam_width_M_plus,
                shared_router=shared_router
            )
            # Select slot in bucket to update
            slot_probs, alpha, selected_bucket_keys, selected_bucket_values = self._select_slot_in_bucket(
                k_t, bucket_id, cache_bucket_keys, cache_bucket_values, cache_key_idx, B, H, N_sample_per_bucket, D,
                shared_slot_query_proj=shared_slot_query_proj,
                shared_slot_key_proj=shared_slot_key_proj
            )

            # Update cache and sliding window
            # Note: This updates cache_bucket_keys and cache_bucket_values directly.
            # No need to re-gather since we're updating the bucket cache directly, not cache_k/cache_v.
            cache_k_window, cache_v_window, cache_bucket_keys, cache_bucket_values = self._update_cache_and_sliding_window(
                cache_k_window, cache_v_window, cache_bucket_keys, cache_bucket_values, k_t, v_t,
                slot_probs, alpha, selected_bucket_keys, selected_bucket_values, bucket_id, B, H, N_sample_per_bucket, D
            )

            # Prepare inputs for attention computation
            if cache_v_window.dtype == torch.float16 or cache_v_window.dtype == torch.bfloat16:
                target_dtype_gen = cache_v_window.dtype
            else:
                target_dtype_gen = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

            cache_k_for_attn = cache_k_window.to(target_dtype_gen)
            cache_v_for_attn = cache_v_window.to(target_dtype_gen)
            cache_bucket_keys_reshaped = (
                cache_bucket_keys.view(B, H, num_bucket * N_sample_per_bucket, hdim)
                .transpose(1, 2)
                .contiguous() #to(target_dtype_gen)
            )
            cache_bucket_values_reshaped = (
                cache_bucket_values.view(B, H, num_bucket * N_sample_per_bucket, hdim)
                .transpose(1, 2)
                .contiguous() #.to(target_dtype_gen)
            )

            cache_key_idx_reshaped = torch.zeros(
                (B, num_bucket * N_sample_per_bucket, H), dtype=torch.long, device=cache_k_window.device
            )

            hidden_t = hidden_states
            # Use shared g_proj if provided, otherwise use layer-specific one
            if hasattr(self, 'use_shared_g_proj') and self.use_shared_g_proj:
                assert shared_g_proj is not None, "use_shared_g_proj is True but shared_g_proj is None"
                g_proj = shared_g_proj
            else:
                g_proj = self.g_proj
            # Output dimension depends on use_sink: 3 if sink token is used, 2 otherwise
            g_proj_output_dim = 3 if self.config.use_sink else 2
            g_topk_t = F.softmax(g_proj(hidden_t).view(B, 1, HQ, g_proj_output_dim), dim=-1)
            q_t_attn = q_t.transpose(1, 2).to(target_dtype_gen)

            # Compute attention for this token
            attn_output_t = self._compute_attention_for_token(
                q_t_attn, cache_k_for_attn, cache_v_for_attn,
                cache_bucket_keys_reshaped, cache_bucket_values_reshaped, cache_key_idx_reshaped,
                query_idx_t_M_stored, query_idx_t_M_plus_stored, g_topk_t, beam_width_score,
                N_sample_per_bucket, HQ #, use_causal_swa=False,
            )

            # Project output
            attn_output = attn_output_t.contiguous().view(B, 1, self.config.hidden_size)  # .permute(0, 2, 1, 3)
            attn_output = self.o_proj(attn_output)

            # Store updated cache states back to past_key_values for this specific layer
            _set_arm_cache(
                past_key_values,
                self.layer_idx,
                cache_k_window,
                cache_v_window,
                cache_key_idx,
                cache_bucket_keys,
                cache_bucket_values,
                beam_width_M,
                beam_width_M_plus,
                num_bucket,
                N_sample_per_bucket,
                beam_width_score,
            )

            reg_loss = 0  # No reg loss in generation mode
            output = attn_output, None, past_key_values, reg_loss

        else:
            # Initialization mode: L > cache_len
            # Compute gen_len from actual input length
            if L < 16:
                self.cache_len = 8
                shared_router.num_levels = 1
            elif L < 32 and L >= 16:
                self.cache_len = 16
                shared_router.num_levels = 1
            elif L < 64 and L >= 32:
                self.cache_len = 32
                shared_router.num_levels = 2
            elif L < 128 and L >= 64:   
                self.cache_len = 64
                shared_router.num_levels = 2
            elif L < 256 and L >= 128:
                self.cache_len = 128
                shared_router.num_levels = 2
            elif L < 512 and L >= 256:
                self.cache_len = 256
                shared_router.num_levels = 2
            elif L < 1024 and L >= 512:
                self.cache_len = 512
                shared_router.num_levels = 2
            elif L < 2048 and L >= 1024:
                self.cache_len = 1024
                shared_router.num_levels = 3
            # elif L < 2048:
            #     shared_router.num_levels = 3
            #     # Find the largest power of 2 <= L
            #     # Check if L is a power of 2
            #     if (L & (L - 1)) == 0:
            #         # L is a power of 2, use it directly
            #         self.cache_len = L
            #     else:
            #         # L is not a power of 2, find the largest power of 2 < L
            #         # L.bit_length() gives the number of bits needed to represent L
            #         # Subtract 1 to get the power of 2 that is < L
            #         power = L.bit_length() - 1
            #         self.cache_len = 1 << power  # 2^power
            elif L >= 2048:
                shared_router.num_levels = 3
                self.cache_len = 2048
            # breakpoint()
            # self.cache_len = min(4, L)
            # self.local_window_size = min(4, L)
            # shared_router.num_levels = 1
            gen_len = max(0, L - self.cache_len) if L > self.cache_len else 0
            if L < self.cache_len and self.training:
                # Pad hidden_states and position_embeddings to cache_len
                pad_length = self.cache_len - L
                
                # Pad hidden_states with zeros at the beginning: [B, L, hidden_size] -> [B, cache_len, hidden_size]
                hidden_states = torch.cat([
                    torch.randn(B, pad_length, hidden_states.shape[-1], 
                               dtype=hidden_states.dtype, device=hidden_states.device),
                    hidden_states
                ], dim=1)
                
                # Pad position_embeddings (cos, sin) with 0-th position embedding
                cos, sin = position_embeddings
                
                # Handle different dimensions: cos/sin can be [seq_len, head_dim] or [B_or_1, seq_len, head_dim]
                if cos.dim() == 2:
                    # [seq_len, head_dim] -> pad to [cache_len, head_dim]
                    cos_0 = cos[0:1, :]  # [1, head_dim] - 0-th position embedding
                    sin_0 = sin[0:1, :]  # [1, head_dim] - 0-th position embedding
                    cos_padded = torch.cat([
                        cos_0.expand(pad_length, -1),  # [pad_length, head_dim]
                        cos
                    ], dim=0)
                    sin_padded = torch.cat([
                        sin_0.expand(pad_length, -1),  # [pad_length, head_dim]
                        sin
                    ], dim=0)
                elif cos.dim() == 3:
                    # [B_or_1, seq_len, head_dim] -> pad to [B_or_1, cache_len, head_dim]
                    cos_0 = cos[:, 0:1, :]  # [B_or_1, 1, head_dim] - 0-th position embedding
                    sin_0 = sin[:, 0:1, :]  # [B_or_1, 1, head_dim] - 0-th position embedding
                    cos_padded = torch.cat([
                        cos_0.expand(-1, pad_length, -1),  # [B_or_1, pad_length, head_dim]
                        cos
                    ], dim=1)
                    sin_padded = torch.cat([
                        sin_0.expand(-1, pad_length, -1),  # [B_or_1, pad_length, head_dim]
                        sin
                    ], dim=1)
                else:
                    raise ValueError(f"Unexpected position_embeddings dimension: {cos.dim()}")
                
                position_embeddings = (cos_padded, sin_padded)
                
                # Pad attention_mask if provided (pad with 0s at the beginning, meaning masked positions)
                attention_mask = attention_mask
                if attention_mask is not None:
                    # attention_mask shape is typically [B, L] or [B, 1, L, L]
                    if attention_mask.dim() == 2:
                        # [B, L] -> [B, cache_len]
                        attention_mask_padded = torch.cat([
                            torch.zeros(B, pad_length, dtype=attention_mask.dtype, device=attention_mask.device),
                            attention_mask
                        ], dim=1)
                    elif attention_mask.dim() == 4:
                        # [B, 1, L, L] -> [B, 1, cache_len, cache_len]
                        # Pad both sequence dimensions
                        mask_pad_1 = torch.zeros(B, 1, pad_length, L, dtype=attention_mask.dtype, device=attention_mask.device)
                        mask_pad_2 = torch.zeros(B, 1, self.cache_len, pad_length, dtype=attention_mask.dtype, device=attention_mask.device)
                        attention_mask_padded = torch.cat([
                            mask_pad_1,
                            attention_mask
                        ], dim=2)
                        attention_mask_padded = torch.cat([
                            mask_pad_2,
                            attention_mask_padded
                        ], dim=3)
                    else:
                        # For other shapes, try to pad along the last dimension
                        pad_shape = list(attention_mask.shape)
                        pad_shape[-1] = pad_length
                        mask_pad = torch.zeros(pad_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                        attention_mask_padded = torch.cat([mask_pad, attention_mask], dim=-1)
                
            # ============================================================================
            
            # Compute QKV (+ RoPE)
            
            # Pad sequences to the nearest power of 2
            # next_pow2 = 1 << (q_len - 1).bit_length()
            # total_num_buckets = self.router.num_buckets_per_level**self.router.num_levels
            # while next_pow2 < total_num_buckets:
            #     next_pow2 = 1 << (next_pow2).bit_length()

            # if next_pow2 != q_len:
            #     pad_len = next_pow2 - q_len
            #     noise = torch.randn(
            #         batch_size, pad_len, d, dtype=hidden_states.dtype, device=hidden_states.device
            #     )
            #     hidden_states = torch.cat([hidden_states, noise], dim=1)

            # # Rearrange to [B, num_heads, seq_len, head_dim] to match cache shape [B, H, cache_len, D]
            # q = rearrange(self.q_proj(hidden_states), 'b s (h d) -> b h s d', d=self.head_dim)
            # k = rearrange(self.k_proj(hidden_states), 'b s (h d) -> b h s d', d=self.head_dim)
            # v = rearrange(self.v_proj(hidden_states), 'b s (h d) -> b h s d', d=self.head_dim)

            # # if self.qk_norm:
            # #     q, k = self.q_norm(q), self.k_norm(k)

            # cu_seqlens = kwargs.get('cu_seqlens')
            # # After rearrange, q has shape [B, num_heads, seq_len, head_dim], so seq_len is at index 2
            # seqlen_offset, max_seqlen = 0, q.shape[2]
            # if self.config.max_position_embeddings is not None:
            #     max_seqlen = max(max_seqlen, self.config.max_position_embeddings)
            # q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

            # Compute QKV - use view instead of reshape to avoid copy
            # In training mode: Take last self.gen_len as gen_part, rest padded with Gaussian noise to cache_len as cache_part
            original_L = L
            original_cache_len = None
            if self.training:
                # Gen portion: last self.gen_len tokens
                gen_portion = hidden_states[:, -self.gen_len:, :]  # [B, self.gen_len, hidden_size]
                
                # Cache portion: rest of tokens (first L - self.gen_len), pad with Gaussian noise to cache_len
                cache_portion = hidden_states[:, :L - self.gen_len, :]  # [B, L - self.gen_len, hidden_size]
                original_cache_len = cache_portion.shape[1]  # Store for later removal
                cache_pad_len = self.cache_len - cache_portion.shape[1]
                if cache_pad_len > 0:
                    # Pad with Gaussian noise
                    cache_padding = torch.randn(
                        B, cache_pad_len, hidden_states.shape[-1],
                        dtype=hidden_states.dtype,
                        device=hidden_states.device
                    )
                    cache_portion = torch.cat([cache_portion, cache_padding], dim=1)  # [B, self.cache_len, hidden_size]
                
                # Combine: cache portion (self.cache_len) + gen portion (self.gen_len)
                hidden_states = torch.cat([cache_portion, gen_portion], dim=1)  # [B, self.cache_len + self.gen_len, hidden_size]
                L = self.cache_len + self.gen_len
                
                # Pad position embeddings to match new length
                # IMPORTANT: apply_rotary_pos_emb expects cos/sin in the shape [B, seq_len, head_dim]
                # (it then unsqueezes at dim=1 to make them broadcastable to q/k: [B, heads, seq_len, head_dim]).
                # In practice, upstream position_embeddings can be:
                # - [seq_len, head_dim]
                # - [1, seq_len, head_dim] or [B, seq_len, head_dim]
                # - sometimes other layouts; we normalize to [B_or_1, seq_len, head_dim] here.
                cos, sin = position_embeddings
                new_seq_len = self.cache_len + self.gen_len

                # Normalize to 3D: [B_or_1, seq_len, head_dim]
                if cos.dim() == 2:
                    cos = cos.unsqueeze(0)
                    sin = sin.unsqueeze(0)
                elif cos.dim() == 3 and cos.shape[0] not in (1, B):
                    # Likely [seq_len, num_heads, head_dim] -> take first "head" and add batch dim
                    cos = cos[:, 0, :].unsqueeze(0)
                    sin = sin[:, 0, :].unsqueeze(0)
                # Now pad/truncate along seq_len dimension (dim=1)
                current_seq_len = cos.shape[1]
                if current_seq_len < new_seq_len:
                    pad_len = new_seq_len - current_seq_len
                    cos_last = cos[:, -1:, :]  # [B_or_1, 1, D]
                    sin_last = sin[:, -1:, :]  # [B_or_1, 1, D]
                    cos_padding = cos_last.repeat(1, pad_len, 1)
                    sin_padding = sin_last.repeat(1, pad_len, 1)
                    cos = torch.cat([cos, cos_padding], dim=1)
                    sin = torch.cat([sin, sin_padding], dim=1)
                elif current_seq_len > new_seq_len:
                    cos = cos[:, :new_seq_len, :]
                    sin = sin[:, :new_seq_len, :]

                position_embeddings = (cos, sin)

            input_shape = hidden_states.shape[:-1]
            hidden_shape = (*input_shape, -1, self.head_dim)

            # MEMORY: q: [B, HQ, L, D] = B * HQ * L * D elements
            # MEMORY: k: [B, H, L, D] = B * H * L * D elements  
            # MEMORY: v: [B, H, L, D] = B * H * L * D elements
            # These are typically LARGE if L (sequence length) is large
            q = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            k = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            v = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

            cos, sin = position_embeddings
            # Normalize to [B_or_1, seq_len, head_dim] so apply_rotary_pos_emb broadcasts correctly.
            if cos.dim() == 2:
                cos = cos.unsqueeze(0)
                sin = sin.unsqueeze(0)
            elif cos.dim() == 3 and cos.shape[0] not in (1, B):
                cos = cos[:, 0, :].unsqueeze(0)
                sin = sin[:, 0, :].unsqueeze(0)

            q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)

            # Init fixed cache with the first cache_len tokens - use slice, contiguous only if needed
            # MEMORY: cache_k: [B, H, cache_len, D] = B * H * cache_len * D
            # MEMORY: cache_v: [B, H, cache_len, D] = B * H * cache_len * D
            # MEMORY: cache_q: [B, HQ, cache_len, D] = B * HQ * cache_len * D
            cache_k = k[:, :, :self.cache_len, :].contiguous()  # [B, H, cache_len, D]
            cache_v = v[:, :, :self.cache_len, :].contiguous()  # [B, H, cache_len, D]
            cache_q = q[:, :, :self.cache_len, :].contiguous()  # [B, HQ, cache_len, D]
            # Initialize cache arm and get bucket organization
            # MEMORY: cache_bucket_keys: [B*H, num_bucket*N_sample_per_bucket, D] = B * H * num_bucket * N_sample_per_bucket * D
            # MEMORY: cache_bucket_values: [B*H, num_bucket*N_sample_per_bucket, D] = B * H * num_bucket * N_sample_per_bucket * D
            # These can be VERY LARGE if num_bucket * N_sample_per_bucket is large
            (
                cache_bucket_keys,
                cache_bucket_values,
                cache_key_idx,
                cache_query_idx_M_stored,
                cache_query_idx_M_plus_stored,
                reg_loss,
                beam_width_score,
                beam_width_M,
                beam_width_M_plus,
                num_bucket,
                N_sample_per_bucket,
            ) = self._initialize_cache_and_arm(
                cache_k, cache_v, cache_q, hidden_states, B, H, HQ, hdim,
                shared_router=shared_router,
                shared_step_embedding=shared_step_embedding,
                shared_beam_width_proj=shared_beam_width_proj,
                shared_beam_width_lstm_input=shared_beam_width_lstm_input,
                shared_beam_width_lstm_hidden=shared_beam_width_lstm_hidden
            )
            look_back_window_size = self.local_window_size if self.local_window_size < self.cache_len else self.cache_len
            # Initialize sliding window for attention
            cache_k_window = cache_k.transpose(1, 2)[:, -look_back_window_size :, :, :]  # [B, window_size, H, D]
            cache_v_window = cache_v.transpose(1, 2)[:, -look_back_window_size :, :, :]  # [B, window_size, H, D]

            # cache_key_idx_reshaped = torch.zeros(
            #     (B, num_bucket * N_sample_per_bucket, H), dtype=torch.long, device=cache_k.device
            # )

            # Compute attention for cache portion
            # cache_q = q[:, :, : self.cache_len, :]
            # cache_k = k[:, :, : self.cache_len, :]
            # cache_v = v[:, :, : self.cache_len, :]
            
            # Reshape cache_query_idx_M_stored back to [B*HQ, cache_len, beam_width_M] format for backward compatibility
            cache_query_idx_for_attn = cache_query_idx_M_stored.view(B * HQ, self.cache_len, beam_width_M) if cache_query_idx_M_stored is not None else None
            cache_attn_output = self._compute_attention_for_cache(
                cache_q, cache_k, cache_v, hidden_states, # torch.Size([4, 32, 2048, 128]) torch.Size([4, 8, 2048, 128])
                cache_query_idx_for_attn, cache_key_idx, cache_bucket_keys, cache_bucket_values,  # cache_query_idx.shape torch.Size([128, 2048, 64]) cache_key_idx.shape torch.Size([32, 64, 32]); 
                                                                                         # cache_bucket_keys.shape torch.Size([32, 2048, 128])
                num_bucket, N_sample_per_bucket, B, HQ, H, self.cache_len, hdim,
                shared_g_proj=shared_g_proj,
                cache_query_idx_M_stored=cache_query_idx_M_stored,
                cache_query_idx_M_plus_stored=cache_query_idx_M_plus_stored,
                beam_width_score=beam_width_score
            )

            # Process generation tokens
            # Condition branch: use batched processing for prefilling if there are tokens beyond cache_len
            # Compute gen_len from actual input length (tokens beyond cache_len)
            gen_len = max(0, L - self.cache_len) if L > self.cache_len else 0
            is_prefilling = gen_len > 0
            attn_output_gen = None        
            if is_prefilling:  ### Prefilling the cache
                # Prepare all inputs needed for generation attention computation
                (
                    query_idx_gen_M_stored,
                    query_idx_gen_M_plus_stored,
                    bucket_id_gen,
                    target_dtype_gen,
                    g_topk_gen,
                    q_gen_attn,
                    query_idx_gen_M_stored_padded,
                    query_idx_gen_M_plus_stored_padded,
                    cache_k_window,
                    cache_v_window,
                    cache_k_for_attn,
                    cache_v_for_attn,
                    k_gen,
                    v_gen,
                ) = self._prepare_generation_attention_inputs(
                    q, k, v, hidden_states, cache_k_window, cache_v_window,
                    B, HQ, H, gen_len, beam_width_M, beam_width_M_plus,
                    shared_router=shared_router, shared_g_proj=shared_g_proj
                )

                # Rebalance bucket assignments if most tokens (> 1/4) are assigned to one bucket
                if self.training:
                    bucket_id_gen = self._rebalance_bucket_assignments(
                        bucket_id_gen, num_bucket, threshold_ratio=1/4
                    )
                
                # cache_k_window and cache_v_window are [B, window_size + gen_len, H, D] (already updated in _prepare_generation_attention_inputs)
                k_window = cache_k_window  # [B, window_size + gen_len, H, D]
                v_window = cache_v_window  # [B, window_size + gen_len, H, D]
                k_for_attn = k_window.to(target_dtype_gen)  # [B, window_size + gen_len, H, D]
                v_for_attn = v_window.to(target_dtype_gen)  # [B, window_size + gen_len, H, D]

                q_for_attn = q_gen_attn  # [B, gen_len, HQ, D]
                g_topk_for_attn = g_topk_gen  # [B, gen_len, HQ, g_proj_output_dim]
                
                # Use query_idx for gen_len only
                query_idx_M_for_attn = query_idx_gen_M_stored  # [B, HQ, gen_len, beam_width_M] or None
                query_idx_M_plus_for_attn = query_idx_gen_M_plus_stored  # [B, HQ, gen_len, beam_width_M_plus] or None
                
                # Process in chunks of 4096 to reduce memory usage
                chunk_size = 4096 # self.cache_len * 2  # 
                num_chunks = (gen_len + chunk_size - 1) // chunk_size  # Number of chunks for gen_len
                attn_output_gen_chunks = []

                for chunk_idx in range(num_chunks):
                    # chunk_start and chunk_end are in gen index space (starting from 0)
                    chunk_start_attn = chunk_idx * chunk_size
                    chunk_end_attn = min(chunk_start_attn + chunk_size, gen_len)
                    chunk_len = chunk_end_attn - chunk_start_attn

                    # Slice tensors for this chunk (using gen indices)
                    bucket_id_gen_chunk = bucket_id_gen[:, :, chunk_start_attn:chunk_end_attn]  # [B, H, chunk_len]

                    q_for_attn_chunk = q_for_attn[:, chunk_start_attn:chunk_end_attn, :, :]  # [B, chunk_len, HQ, D]
                    # k_for_attn and v_for_attn contain window_size, so add self.local_window_size to access gen portion
                    k_for_attn_chunk = k_for_attn[:, chunk_start_attn:look_back_window_size + chunk_end_attn, :, :].transpose(1, 2)  # [B, chunk_len+window_size, H, D] -> [B, H, chunk_len+window_size, D]
                    v_for_attn_chunk = v_for_attn[:, chunk_start_attn:look_back_window_size + chunk_end_attn, :, :].transpose(1, 2)  # [B, chunk_len+window_size, H, D] -> [B, H, chunk_len+window_size, D]

                    # Slice from the concatenated tensors (k_for_attn and v_for_attn contain window_size)
                    k_gen_chunk = k_for_attn[:, look_back_window_size + chunk_start_attn : look_back_window_size + chunk_end_attn, :, :].transpose(1, 2)  # [B, chunk_len, H, D] -> [B, H, chunk_len, D]
                    v_gen_chunk = v_for_attn[:, look_back_window_size + chunk_start_attn : look_back_window_size + chunk_end_attn, :, :].transpose(1, 2)  # [B, chunk_len, H, D]
                    
                    g_topk_chunk = g_topk_for_attn[:, chunk_start_attn : chunk_end_attn, :, :]  # [B, chunk_len, HQ, g_proj_output_dim]
                    
                    # Slice query_idx from gen tensors (using gen indices)
                    if query_idx_M_for_attn is not None:
                        query_idx_M_chunk = query_idx_M_for_attn[:, :, chunk_start_attn : chunk_end_attn, :]  # [B, HQ, chunk_len, beam_width_M]
                    else:
                        query_idx_M_chunk = None
                    
                    if query_idx_M_plus_for_attn is not None:
                        query_idx_M_plus_chunk = query_idx_M_plus_for_attn[:, :, chunk_start_attn : chunk_end_attn, :]  # [B, HQ, chunk_len, beam_width_M_plus]
                    else:
                        query_idx_M_plus_chunk = None
                        
                    # 3) Sort and pad bucket_id_gen, k_gen, and v_gen before slot selection 
                    bucket_id_padded, k_gen_padded, num_bucket, max_tokens_per_bucket, unsort_info, v_gen_padded = \
                        self._sort_and_pad_bucket_id_and_k(
                            bucket_id_gen_chunk, k_gen_chunk, cache_bucket_keys, 
                            # bucket_id_gen_chunk.shape: torch.Size([B, H, chunk_len])
                            # k_gen_chunk.shape: torch.Size([B, H, chunk_len, D])
                            B, H, chunk_len, N_sample_per_bucket, D, v_gen=v_gen_chunk
                        )
                    # 4) Select slots for all tokens (batched, using padded inputs)
                    slot_probs_padded, alpha_padded = \
                        self._select_slots_batched(
                            k_gen_padded, bucket_id_padded, cache_bucket_keys, cache_key_idx,
                            B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket, D,
                            shared_slot_query_proj=shared_slot_query_proj,
                            shared_slot_key_proj=shared_slot_key_proj
                        )
                    
                    # 5) Optimized: Use view() for contiguous tensors, reshape() for non-contiguous ones
                    padded_gen_len = num_bucket * max_tokens_per_bucket
                    k_gen_padded = k_gen_padded.view(B, H, padded_gen_len, D)
                    v_gen_padded = v_gen_padded.view(B, H, padded_gen_len, D)
                    slot_probs_padded = slot_probs_padded.view(B, H, padded_gen_len, N_sample_per_bucket)
                    alpha_padded = alpha_padded.view(B, H, padded_gen_len, N_sample_per_bucket)
                    bucket_id_padded = bucket_id_padded.view(B, H, padded_gen_len)
                    
                    # cache_k [B, H, cache_len, D]
                    # 6) Update cache in batched form using tensor operations
                    cache_bucket_keys, cache_bucket_values = self._update_cache_batched(
                        cache_bucket_keys.view(B, H, num_bucket* N_sample_per_bucket, D), cache_bucket_values.view(B, H, num_bucket* N_sample_per_bucket, D),
                        k_gen_padded, v_gen_padded,
                        slot_probs_padded, alpha_padded, cache_key_idx, bucket_id_padded,
                        padded_gen_len, B, H, N_sample_per_bucket, D, num_bucket
                    )
                    # cache_v [B, H, cache_len + chunk_len, D] [B*H, num_bucket, N_sample_per_bucket]
                    # 6.5) Update cache_key_idx to reflect which tokens were assigned to which slots (training only)
                    cache_key_idx = self._update_cache_key_idx(
                        unsort_info, bucket_id_gen_chunk, max_tokens_per_bucket, slot_probs_padded,
                        cache_key_idx, chunk_len, B, H, num_bucket, N_sample_per_bucket, self.cache_len+chunk_idx * chunk_size, chunk_len
                    )

                    # Memory optimization: Delete large intermediate tensors after cache update
                    del k_gen_padded, v_gen_padded, slot_probs_padded, alpha_padded, bucket_id_padded

                    # k_for_attn_chunk = torch.cat((k_for_attn[:, max(0, chunk_start-self.local_window_size): chunk_start, :, :],k_for_attn_chunk), dim=1)  # [B, window_size + chunk_end, H, D]
                    # v_for_attn_chunk = torch.cat((v_for_attn[:, max(0, chunk_start-self.local_window_size): chunk_start, :, :],v_for_attn_chunk), dim=1)  # [B, window_size + chunk_end, H, D]
                    # 7) Compute attention for this chunk (batched)
                    # Prepare cache_key_idx_reshaped for attention computation
                    attn_output_chunk_full = self._compute_attention_batched(
                        q_for_attn_chunk,  # [B, chunk_len, HQ, D]
                        k_for_attn_chunk.transpose(1, 2),  # [B, chunk_len, H, D]
                        v_for_attn_chunk.transpose(1, 2),  # [B, chunk_len, H, D]
                        cache_bucket_keys, cache_bucket_values,  # # [B, H, num_bucket*N_sample, D]
                        cache_key_idx,  # [B, num_bucket*N_sample, H] - placeholder
                        query_idx_M_chunk,  # [B, HQ, chunk_len, beam_width_M]
                        query_idx_M_plus_chunk,  # [B, HQ, chunk_len, beam_width_M_plus] or None
                        g_topk_chunk,  # [B, chunk_len, HQ, g_proj_output_dim]
                        beam_width_score, N_sample_per_bucket, HQ, chunk_len
                    )  # [B, chunk_len, HQ, D]
                    
                    # Append the chunk output (no need to slice since we already skipped window_size)
                    attn_output_gen_chunks.append(attn_output_chunk_full)
                cache_k_window = k_for_attn_chunk.transpose(1, 2)
                cache_v_window = v_for_attn_chunk.transpose(1, 2)
                # Concatenate all chunks (no need to slice window_size since we already skipped it)
                attn_output_gen = torch.cat(attn_output_gen_chunks, dim=1)  # [B, gen_len, HQ, D]
                
                # Project output
                # attn_output_gen is [B, gen_len, HQ, D], can directly view to [B, gen_len, hidden_size]
                # No transpose or permute needed
                attn_output_gen = attn_output_gen.contiguous().view(B, gen_len, self.config.hidden_size)
                attn_output_gen = self.o_proj(attn_output_gen)

            # Concatenate cache attention output with generation output
            # In training mode: remove padding from cache output if it was added
            if self.training and original_cache_len is not None:
                # Remove the padding portion from cache_attn_output
                # cache_attn_output currently has shape [B, self.cache_len, hidden_size] but includes padding
                # Keep only the original cache portion (without padding)
                if original_cache_len > 0:
                    cache_attn_output = cache_attn_output[:, :original_cache_len, :]
                else:
                    # If all was padding, create empty tensor
                    cache_attn_output = torch.zeros(B, 0, self.config.hidden_size, device=cache_attn_output.device, dtype=cache_attn_output.dtype)
            
            if attn_output_gen is not None:
                full_output = torch.cat([cache_attn_output, attn_output_gen], dim=1)
            else:
                full_output = cache_attn_output[:, -L:, :]
            # # Pad to original length L if needed
            # if full_output.size(1) < L:
            #     padding = hidden_states.new_zeros(B, L - full_output.size(1), self.config.hidden_size)
            #     full_output = torch.cat([full_output, padding], dim=1)

            # Store cache states in past_key_values whenever past_key_values exists
            # Note: We store arm cache attributes even if use_cache=False because
            # they are essential for subsequent generation steps. The use_cache flag
            # only controls whether past_key_values is returned in the output.
            # During prefilling (L > 1), use_cache may be False, but we still need
                # to store arm cache attributes for generation steps to work.
            if past_key_values is not None and not self.training:
                # Store arm cache for this specific layer (prevents overwriting between layers)
                if L > self.local_window_size:
                    look_back_window_size = self.local_window_size
                else:
                    look_back_window_size = L
                _set_arm_cache(
                    past_key_values,
                    self.layer_idx,
                    cache_k_window[:, -look_back_window_size:, :, :],
                    cache_v_window[:, -look_back_window_size:, :, :],
                    cache_key_idx,
                    cache_bucket_keys.view(B*H, self.cache_len, D),
                    cache_bucket_values.view(B*H, self.cache_len, D),
                    beam_width_M,
                    beam_width_M_plus,
                    num_bucket,
                    N_sample_per_bucket,
                    beam_width_score,
                )


            if return_updated_cache:
                output = (full_output, None, None, reg_loss), (cache_k, cache_v)
            else:
                output = (full_output, None, past_key_values if use_cache else None, reg_loss)
        
        # ArmTopKAttentionWithMemory returns:
        #   (full_out, None, past_key_values if use_cache else None, reg_loss)
        # We need to return: (attn_output, attn_weights, reg_loss) for HiLlamaAttention interface
        if isinstance(output, tuple) and len(output) >= 4:
            attn_output = output[0]  # First element is attention output (full_out)
            attn_weights = None  # ArmTopKAttentionWithMemory doesn't return attention weights (output[1] is None)
            # output[2] is past_key_values (already updated in-place)
            reg_loss = output[3]  # Last element is reg_loss
        elif isinstance(output, tuple) and len(output) >= 2:
            attn_output = output[0]
            reg_loss = output[-1] if len(output) > 1 else None
            attn_weights = None
        else:
            attn_output = output if not isinstance(output, tuple) else output[0]
            attn_weights = None
            reg_loss = None
        return attn_output, attn_weights, reg_loss


class HiLlamaDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: RMSNorm -> self-attention -> residual, then RMSNorm -> MLP -> residual.

    The attention module is ``HiLlamaAttentionWithMemory`` when ``config.use_arm_memory`` is
    truthy, otherwise ``HiLlamaAttention``. Besides the hidden states, ``forward`` also returns
    the regularization value produced by the attention module.
    """

    def __init__(self, config: HiLlamaConfig, layer_idx: int, shared_params=None, shared_g_router=None):
        super().__init__()
        self.hidden_size = config.hidden_size

        if getattr(config, "use_arm_memory", False):
            # Memory attention receives both the arm-memory owner (shared_params) and the
            # owner of the shared g_proj/router (shared_g_router).
            self.self_attn = HiLlamaAttentionWithMemory(
                config=config,
                layer_idx=layer_idx,
                shared_params=shared_params,
                shared_g_router=shared_g_router,
            )
        else:
            # HiLlamaAttention takes the g_proj/router owner through its `shared_params`
            # argument; the arm-memory `shared_params` is intentionally not used here.
            self.self_attn = HiLlamaAttention(
                config=config,
                layer_idx=layer_idx,
                shared_params=shared_g_router,
            )

        self.mlp = HiLlamaMLP(config)
        self.input_layernorm = HiLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = HiLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        shared_g_proj: Optional[nn.Module] = None,
        shared_router: Optional[nn.Module] = None,
        shared_key_gate: Optional[nn.Module] = None,
        shared_slot_query_proj: Optional[nn.Module] = None,
        shared_slot_key_proj: Optional[nn.Module] = None,
        shared_beam_width_lstm_input: Optional[nn.Module] = None,
        shared_beam_width_lstm_hidden: Optional[nn.Module] = None,
        shared_step_embedding: Optional[nn.Module] = None,
        shared_beam_width_proj: Optional[nn.Module] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one decoder layer.

        All ``shared_*`` modules are forwarded unchanged to the attention module; they are
        ``None`` unless the parent model shares them across layers.

        Returns:
            A ``(hidden_states, reg_loss)`` tuple. ``hidden_states`` is the layer output
            (same shape as the input); ``reg_loss`` is whatever the attention module
            returned as its third element and may be ``None``.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention; attention weights are not propagated by this layer.
        hidden_states, _attn_weights, reg_loss = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            shared_g_proj=shared_g_proj,
            shared_router=shared_router,
            shared_key_gate=shared_key_gate,
            shared_slot_query_proj=shared_slot_query_proj,
            shared_slot_key_proj=shared_slot_key_proj,
            shared_beam_width_lstm_input=shared_beam_width_lstm_input,
            shared_beam_width_lstm_hidden=shared_beam_width_lstm_hidden,
            shared_step_embedding=shared_step_embedding,
            shared_beam_width_proj=shared_beam_width_proj,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Position-wise feed-forward block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, reg_loss


@auto_docstring
class HiLlamaPreTrainedModel(PreTrainedModel):
    # Config class and the attribute name under which the base model is stored.
    config: HiLlamaConfig
    base_model_prefix = "model"

    # Checkpointing / device placement hints.
    supports_gradient_checkpointing = True
    _no_split_modules = ["HiLlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    _can_compile_fullgraph = True

    # Module types whose outputs are mapped to `hidden_states` / `attentions`
    # (presumably consumed by transformers' output-recording hooks).
    _can_record_outputs = {
        "hidden_states": HiLlamaDecoderLayer,
        "attentions": HiLlamaAttention,
    }

    # TODO: override `_init_weights(self, module)` so HierarchicalRouter modules are skipped:
    # their `route` parameters already get a custom Xavier/Glorot initialization, so mark each
    # of them with `param._is_hf_initialized = True` and return early; delegate every other
    # module to `super()._init_weights(module)`.


@auto_docstring
class HiLlamaModel(HiLlamaPreTrainedModel):
    def __init__(self, config: HiLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # Parameter sharing for the arm components is controlled by config.share_arm_params
        # (default False). When True, a single set of modules is created here and every layer
        # reuses it, so gradients from all layers accumulate into the same parameters.
        # When False, each layer builds its own independent copies.
        use_arm_memory = getattr(config, "use_arm_memory", False)
        share_arm_params = getattr(config, "share_arm_params", False)

        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

        # g_proj and router are shared whenever share_arm_params is set, regardless of
        # use_arm_memory, because the plain HiLlamaAttention path uses them as well.
        if share_arm_params:
            # Per-head g_proj outputs: 3 when a sink token is used, 2 otherwise.
            g_proj_output_dim = config.num_attention_heads * (3 if config.use_sink else 2)
            self.shared_g_proj = nn.Linear(config.hidden_size, g_proj_output_dim, bias=False)
            self.shared_router = HierarchicalRouter(
                head_dim,
                head_dim,
                beam_width=config.beam_width,
                num_levels=config.num_levels,
                num_buckets_per_level=getattr(config, 'num_buckets_per_level', 4),
            )
        else:
            self.shared_g_proj = None
            self.shared_router = None

        # Arm-memory-only shared modules (consumed by HiLlamaAttentionWithMemory).
        if use_arm_memory and share_arm_params:
            slot_hidden_dim = head_dim // 2
            hidden_size_half = config.hidden_size // 2

            # The step embedding table is sized from the number of leaf buckets of the router
            # hierarchy (num_buckets_per_level ** num_levels) plus a fixed margin.
            num_buckets_per_level = getattr(config, "num_buckets_per_level", 4)
            num_levels = getattr(config, "num_levels", 3)
            num_leaf_buckets = int(num_buckets_per_level ** num_levels)
            embedding_size = num_leaf_buckets + 100  #  TODO

            self.shared_key_gate = nn.Linear(2 * head_dim, head_dim, bias=True)
            self.shared_slot_query_proj = nn.Linear(head_dim, slot_hidden_dim, bias=True)
            self.shared_slot_key_proj = nn.Linear(head_dim, slot_hidden_dim, bias=True)
            # Single-layer LSTM projections are plain Linear modules; multi-layer ones are
            # ModuleLists with one projection pair per LSTM layer.
            beam_width_lstm_num_layers = getattr(config, 'beam_width_lstm_num_layers', 1)
            if beam_width_lstm_num_layers == 1:
                self.shared_beam_width_lstm_input = nn.Linear(config.hidden_size, 4 * hidden_size_half, bias=True)
                self.shared_beam_width_lstm_hidden = nn.Linear(hidden_size_half, 4 * hidden_size_half, bias=True)
            else:
                self.shared_beam_width_lstm_input = nn.ModuleList([
                    nn.Linear(config.hidden_size if i == 0 else hidden_size_half, 4 * hidden_size_half, bias=True)
                    for i in range(beam_width_lstm_num_layers)
                ])
                self.shared_beam_width_lstm_hidden = nn.ModuleList([
                    nn.Linear(hidden_size_half, 4 * hidden_size_half, bias=True)
                    for _ in range(beam_width_lstm_num_layers)
                ])
            self.shared_step_embedding = nn.Embedding(embedding_size, config.hidden_size)
            self.shared_beam_width_proj = nn.Sequential(
                nn.Linear(hidden_size_half, 1, bias=True), nn.Sigmoid()
            )
        else:
            self.shared_key_gate = None
            self.shared_slot_query_proj = None
            self.shared_slot_key_proj = None
            self.shared_beam_width_lstm_input = None
            self.shared_beam_width_lstm_hidden = None
            self.shared_step_embedding = None
            self.shared_beam_width_proj = None

        # Layers receive this model as the owner of the shared modules (or None when the
        # corresponding sharing mode is off).
        self.layers = nn.ModuleList(
            [HiLlamaDecoderLayer(
                config,
                layer_idx,
                shared_params=self if (use_arm_memory and share_arm_params) else None,
                shared_g_router=self if share_arm_params else None,
            )
             for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = HiLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = HiLlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache is None:
            use_cache = self.config.use_cache

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Shared modules are passed to every layer only when the matching sharing mode is
        # enabled in the (possibly updated) config; otherwise layers use their own modules.
        share_arm_params = getattr(self.config, "share_arm_params", False)
        share_memory_params = getattr(self.config, "use_arm_memory", False) and share_arm_params
        shared_modules = {
            "shared_g_proj": self.shared_g_proj if share_arm_params else None,
            "shared_router": self.shared_router if share_arm_params else None,
            "shared_key_gate": self.shared_key_gate if share_memory_params else None,
            "shared_slot_query_proj": self.shared_slot_query_proj if share_memory_params else None,
            "shared_slot_key_proj": self.shared_slot_key_proj if share_memory_params else None,
            "shared_beam_width_lstm_input": self.shared_beam_width_lstm_input if share_memory_params else None,
            "shared_beam_width_lstm_hidden": self.shared_beam_width_lstm_hidden if share_memory_params else None,
            "shared_step_embedding": self.shared_step_embedding if share_memory_params else None,
            "shared_beam_width_proj": self.shared_beam_width_proj if share_memory_params else None,
        }

        active_layers = self.layers[: self.config.num_hidden_layers]
        num_active_layers = len(active_layers)

        # Mean of the per-layer regularization losses over the layers actually run.
        # Layers that report no regularization loss (None) contribute nothing; the
        # accumulator stays the int 0 if no layer reports one.
        all_reg_losses = 0
        for decoder_layer in active_layers:
            hidden_states, reg_loss = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **shared_modules,
                **kwargs,
            )
            if reg_loss is not None:
                all_reg_losses = all_reg_losses + reg_loss / num_active_layers

        hidden_states = self.norm(hidden_states)
        return BaseModelArmOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            reg_losses=all_reg_losses
        )


@auto_docstring
class HiLlamaForCausalLM(HiLlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = HiLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelArmOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # Read the regularizer by field name rather than by tuple position: ModelOutput
        # indexing skips None fields, so `outputs[-1]` is not guaranteed to be reg_losses.
        reg_loss = outputs.reg_losses
        reg_loss_alpha = getattr(self.config, "reg_loss_alpha", 0.0)
        # Only a tensor regularizer can carry gradients; the model reports the int 0 when no
        # layer produced one. Without labels there is no loss to attach the term to.
        if (
            self.training
            and reg_loss_alpha > 0
            and loss is not None
            and isinstance(reg_loss, torch.Tensor)
        ):
            if torch.isfinite(reg_loss):
                # (reg_loss - reg_loss.detach()) evaluates to zero, so the loss value is
                # unchanged, while backprop still receives reg_loss_alpha * grad(reg_loss).
                loss = loss + reg_loss_alpha * (reg_loss - reg_loss.detach())
            else:
                warnings.warn(f"reg_loss is {reg_loss.item()} (NaN/Inf), skipping addition to loss")
        return CausalLMArmOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reg_losses=outputs.reg_losses
        )


class HiLlamaForSequenceClassification(GenericForSequenceClassification, HiLlamaPreTrainedModel):
    # No overrides: behavior comes entirely from the generic sequence-classification head.
    pass


class HiLlamaForQuestionAnswering(GenericForQuestionAnswering, HiLlamaPreTrainedModel):
    # Backward compatibility: this head historically stored its base model under the
    # attribute `transformer` instead of `model`.
    base_model_prefix = "transformer"


class HiLlamaForTokenClassification(GenericForTokenClassification, HiLlamaPreTrainedModel):
    # No overrides: behavior comes entirely from the generic token-classification head.
    pass


# Public names exported by `from <module> import *`.
__all__ = [
    # Causal language model and backbone.
    "HiLlamaForCausalLM",
    "HiLlamaModel",
    "HiLlamaPreTrainedModel",
    # Task-specific heads.
    "HiLlamaForSequenceClassification",
    "HiLlamaForQuestionAnswering",
    "HiLlamaForTokenClassification",
]