import torch
from torch import nn
import torch.nn.functional as F
from transformers.utils import logging
from einops import rearrange
from typing import Callable, Optional, Tuple, Union
from transformers.cache_utils import Cache
import math

logger = logging.get_logger(__name__)

import triton
import triton.language as tl
import triton.language.core as core
from triton.language.standard import _log2, sum, zeros_like

import sys
import os
# Make the out-of-package Triton ops importable: they live in
# ``<parent of this file's directory>/nsa_lib/nsa_lib/ops``.
current_dir = os.path.dirname(__file__)

# One level above the directory that contains this file.
root_dir = os.path.abspath(os.path.join(current_dir, '..'))

# Folder holding the ops modules (``hir_kernel`` is imported from it below).
ops_dir = os.path.join(root_dir, 'nsa_lib', 'nsa_lib', 'ops')

# Register the folder once; re-importing/reloading this module must not keep
# growing sys.path.  (A leftover ``breakpoint()`` used to stop every import
# here and has been removed.)
if ops_dir not in sys.path:
    sys.path.append(ops_dir)
from hir_kernel import parallel_hir

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Run one compare-and-swap round of the bitonic network on `x`, moving `ids` in lockstep.

    The tensor is viewed as [n_outer * 2**i, 2, 2**(n_dims - i - 1)]; the two
    entries along the middle axis are the partners that get compared.

    Args:
        x: values being sorted.
        ids: integer payload with the same shape as `x`; it receives exactly
            the same swaps as `x`.
        flip: scalar or tensor XOR-ed into the comparison result; it inverts
            the swap decision where it is 1.
        i: compile-time index of the round (selects the partner distance).
        n_dims: compile-time log2 of the length of the sorted axis.

    Returns:
        (x after the round, ids after the round).

    Note: `sum` and `zeros_like` are the Triton helpers imported at the top of
    the file, not Python builtins.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    y = core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # mask is 0 for the first partner and 1 for the second; multiplying by
    # (1 - mask) / mask and reducing over the pair axis extracts one partner,
    # which is then broadcast back so every position sees both values.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # idx: same partner extraction for the payload indices
    y_idx = core.reshape(ids, shape)
    left_idx = core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = core.reshape(left_idx, x.shape)
    right_idx = core.reshape(right_idx, x.shape)

    # actual compare-and-swap
    # Values are bit-cast to a signed integer of the same width so the swap
    # can be expressed with XOR.
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth,
                                signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # Swap when the pair is out of order; `flip` inverts the wanted order.
    cond = (left > right) ^ flip

    # Each element equals either `left` or `right`, so XOR-ing it with
    # (left ^ right) turns it into its partner; XOR with 0 leaves it as is.
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    new_ids = ids ^ core.where(cond, left_idx ^ right_idx, zeros_like(ids))

    return ret.to(x.dtype, bitcast=True), new_ids


@triton.jit
def _bitonic_merge(x, ids, stage: core.constexpr, order: core.constexpr,
                   n_dims: core.constexpr):
    '''
    Apply `stage` compare-and-swap rounds to `x`, permuting `ids` identically.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating

    Args:
        x: values being sorted.
        ids: payload indices carried along with `x`.
        stage: compile-time number of rounds to run; must be <= n_dims.
        order: compile-time order selector (see the table above).
        n_dims: compile-time log2 of the length of the sorted axis.

    Returns:
        (x, ids) after the merge.
    '''
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # Build a 0/1 pattern that is constant over runs of 2**stage elements
        # and alternates between consecutive runs.
        shape: core.constexpr = [
            n_outer * 2**(n_dims - 1 - stage), 2, 2**stage
        ]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape),
            x.shape)
    else:
        # Uniform order: the scalar 0/1 is used directly as the flip bit.
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids


@triton.jit
def argsort(x,
            ids,
            dim: core.constexpr = None,
            descending: core.constexpr = core.constexpr(1)):  # core.CONSTEXPR_0 -> from small to large
    """Bitonic sort of `x` along its last axis that also returns the permuted `ids`.

    Args:
        x: values to sort.
        ids: payload with the same shape as `x` (typically the original
            positions); it is permuted exactly like `x`.
        dim: axis to sort; only the last axis (or None, meaning the last axis)
            is accepted, enforced by a static assert.
        descending: compile-time flag used as the order of the final merge;
            the default of 1 sorts from large to small, 0 from small to large.

    Returns:
        (sorted x, ids rearranged with the same permutation).

    NOTE(review): the number of stages is `_log2` of the axis length, so the
    axis length is presumably expected to be a power of two - confirm.
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1,
                       "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    n_dims: core.constexpr = _log2(x.shape[_dim])

    # Every stage but the last uses the alternating order (2) to build bitonic
    # runs; the last stage applies the requested global order.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending,
                                n_dims)
    return x, ids
# -----------------------------------------------------------------------------
# 1) Triton kernel: blocked, batched hierarchical beam search across all levels
# -----------------------------------------------------------------------------
@triton.jit
def hierarchical_beam_search_blocked(
    # pointers
    q_ptr,              # float32* [B_S, D]
    route_ptr,          # float32* [sum_P, D, C]
    offsets_ptr,        # int32*   [L+1]
    counts_ptr,         # int32*   [L]
    scores_ptr,         # float32* [B_S, beam*C]  (scratch: sorted candidate scores)
    idxs_ptr,           # int32*   [B_S, beam*C]  (scratch: sorted candidate ids)
    scores_beam_ptr,         # float32* [B_S, beam]  (output: beam scores)
    idxs_beam_ptr,           # [B_S, beam] (output: beam node ids; the launcher in this file allocates it as int32)
    # scalars
    B_S, D, C, L, K:tl.constexpr,
    stride_qD:tl.constexpr,
    stride_route_L:tl.constexpr, stride_route_D: tl.constexpr, stride_route_C: tl.constexpr,
    beam: tl.constexpr,
    # compile-time tiles
    BLOCK_TOKENS: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_C: tl.constexpr
):
    """Beam search through the routing tree for a block of BLOCK_TOKENS query rows.

    Each program handles the rows `program_id(0) * BLOCK_TOKENS + [0, BLOCK_TOKENS)`.
    For every level it scores the children of the current beam nodes, sorts
    all candidates with the bitonic `argsort`, keeps the first K and writes
    them to `scores_beam_ptr` / `idxs_beam_ptr`.  Nothing is returned; after
    the last level those two buffers hold the result.

    NOTE(review): several places use `beam * stride_route_D` as the candidate
    count while `MAX_CAND = beam * C` is used as the scratch row stride; this
    is only consistent because the launcher in this file passes
    stride_route_D == C - confirm before changing the route layout.
    NOTE(review): the top-K tiles ([T, K]) are fed back as the [T, beam] beam
    state, so K is presumably expected to equal `beam` - confirm.
    """

    # token block
    block_id  = tl.program_id(0)
    offs_t    = tl.arange(0, BLOCK_TOKENS)
    token_ids = block_id * BLOCK_TOKENS + offs_t      # [T]
    valid_t   = token_ids < B_S                       # [T] rows past the end are masked

    # 1) load query tile [T, BLOCK_D]
    offs_d = tl.arange(0, BLOCK_D)
    mask_q = (offs_d[None, :] < D) & (valid_t[:, None])
    ptrs_q = q_ptr + token_ids[:, None]*stride_qD + offs_d[None, :]
    q_tile = tl.load(ptrs_q, mask=mask_q, other=0.0)  # [T, BLOCK_D]

    # 2) init beam state [T, beam]
    offs_beam = tl.arange(0,  beam)
    offs_bucket = tl.arange(0, BLOCK_C)
    offs_cand = tl.arange(0, beam * stride_route_D)
    beam_probs = tl.zeros([BLOCK_TOKENS, beam], dtype=tl.float32)
    beam_probs = beam_probs + 1
    beam_parents = tl.zeros([BLOCK_TOKENS, beam], dtype=tl.int32)

    # Slot 0 of every valid row keeps prob 1.0; all other slots (and all slots
    # of invalid rows) are set to -1, which makes their candidates' scores
    # negative after the multiplication by prev_p below.
    zero_b = tl.arange(0, beam) == 0                   # [beam]
    init_mask = valid_t[:, None] & zero_b[None, :]
    beam_probs = tl.where(init_mask, beam_probs, (tl.zeros((beam,), dtype=tl.float32)-1)[None, :])

    # row stride of the scores_ptr / idxs_ptr scratch buffers
    MAX_CAND = beam * C 

    # scores_ptr = scores_ptr + block_id * BLOCK_TOKENS* beam * stride_route_D      # float32* [B_S, beam*C]
    # idxs_ptr = idxs_ptr + block_id *BLOCK_TOKENS* beam * stride_route_D         # int32*   [B_S, beam*C]
    # scores_beam_ptr = scores_beam_ptr + block_id *BLOCK_TOKENS* beam     # float32* [B_S, beam]
    # idxs_beam_ptr = idxs_beam_ptr + block_id *BLOCK_TOKENS* beam          # float32* [B_S, beam]

    # score index dont have to store at all
    # 3) loop levels
    for lvl in range(L):
        P_l    = tl.load(counts_ptr  + lvl)   # NOTE(review): loaded but never used below
        offset = tl.load(offsets_ptr + lvl)   # first route-table row of this level

        # clear scratch
        # for i in range(MAX_CAND):

        # compute scores for all beam slots at once (the per-beam / per-bucket
        # loops of an earlier version are kept as comments)
        # for b in range(beam):
        prev_p   = beam_probs    #[:, b]               # [T, beam]
        parent_b = beam_parents  #[:, b]               # [T, beam]
        baseW    = route_ptr + (offset + parent_b)*stride_route_L  # [T, beam] start of each parent's [D, C] matrix
        # for c in range(C):
        offs_d = tl.arange(0, BLOCK_D)
        # NOTE(review): the mask covers the feature axis and the token axis
        # only; the bucket axis (BLOCK_C vs C) is not masked.
        mask_w = (offs_d[None, None, :, None] < D) & valid_t[:, None, None, None]
        # The bucket offset is added without stride_route_C, i.e. a unit
        # stride along the bucket axis is assumed.
        ptrs_w = baseW[:, :, None, None] \
                    + offs_bucket[None, None, None, :]+(offs_d[None, None, :, None]*stride_route_D) 
                    # + c*stride_route_C
        w_tile = tl.load(ptrs_w, mask=mask_w, other=0.0)      # [T, beam, D, C]
        dot    = tl.sum(q_tile[:, None, :, None] * w_tile, axis=2)              # [T, beam, C]
        # NOTE(review): exp is taken without subtracting the row maximum, so
        # large logits may overflow - confirm the expected logit range.
        sc     = tl.exp(tl.cast(dot, tl.float32))

        # normalise over the bucket axis (softmax with a 1e-6 guard in the
        # denominator) and scale by the parent's score
        sc     = sc / (tl.sum(sc, axis=2, keep_dims=True)+1e-6) * prev_p [:, :, None] # [T, beam, C] combined_probs
        sc     = tl.reshape(sc, (BLOCK_TOKENS, beam * stride_route_D) )  # flatten candidates to one axis


            # --- Bitonic-sort top-K (vectorized) ---
        # scores_ptr, idxs_ptr: global [B_S, N]
        N = beam * stride_route_D   # NOTE(review): unused
        # NOTE(review): this first `ids` value is overwritten on the next line.
        ids =  tl.zeros([BLOCK_TOKENS, beam * stride_route_D], dtype=tl.int64) + offs_cand[None, :] # tl.arange
        ids = tl.broadcast_to(offs_cand[None, :], (BLOCK_TOKENS, beam * stride_route_D))
        # argsort's default order is descending, so column 0 is the best candidate
        new_sc, new_ids = argsort(sc, ids)

        # Spill the sorted candidates to the global scratch buffers ...
        ptr_s = scores_ptr + token_ids[:, None]*(MAX_CAND)+ offs_cand[None, :]
        tl.store(ptr_s, new_sc, mask=valid_t[:, None])
        ptr_i = idxs_ptr + token_ids[:, None]*(MAX_CAND) + offs_cand[None, :]
        tl.store(ptr_i, new_ids, mask=valid_t[:, None])  # tl.broadcast(tl.int32(i), BLOCK_TOKENS)

        # extract top-K
        # new_probs   = tl.zeros([BLOCK_TOKENS, beam], dtype=tl.float32)
        # new_parents = tl.zeros([BLOCK_TOKENS, beam], dtype=tl.int32)
        # for r in range(K):
        # store and load from a pointer!
        # ... and read back only the first K columns of each row.
        # NOTE(review): the reload follows the store to the same addresses
        # with no barrier in between - confirm the stored values are visible
        # to the threads that load them.

        offs_topk = tl.arange(0, K)
        ptr_s = scores_ptr + token_ids[:, None]*MAX_CAND + offs_topk
        ptr_i = idxs_ptr   + token_ids[:, None]*MAX_CAND + offs_topk
        top_sc    = tl.load(ptr_s, mask=valid_t[:, None], other=0.0)
        top_ix    = tl.load(ptr_i, mask=valid_t[:, None], other=0)
        # candidate id = beam_slot * C + bucket
        p     = top_ix // C
        c     = top_ix % C
        # idxr  = offs_t[:, None]*beam + r
        
        ptr_s_beam = scores_beam_ptr + token_ids[:, None]*beam + offs_topk
        tl.store(ptr_s_beam,   top_sc, mask=valid_t[:, None])

        # Translate the parent beam slot into the node id stored there by the
        # previous level (on the first level this reads the buffer's initial
        # contents; the launcher in this file zero-fills it), then derive the
        # child's node id.
        p = tl.load(idxs_beam_ptr + token_ids[:, None]*beam + p, mask=valid_t[:, None])
        ptr_i_beam = idxs_beam_ptr   + token_ids[:, None]*beam + offs_topk
        id = p*C+c
        tl.store(ptr_i_beam, id, mask=valid_t[:, None])

        # The surviving candidates become the beam for the next level.
        beam_probs   = top_sc
        beam_parents = id

    # write outputs
    # (disabled: the results are already in scores_beam_ptr / idxs_beam_ptr)
    # for t in range(BLOCK_TOKENS):
    #     tid = token_ids[t]
    #     if not valid_t[t]: continue
    #     base_o = tid * K
    #     for r in range(K):
    #         pc = beam_parents[t, r]
    #         p  = pc // C
    #         c  = pc % C
    #         tl.store(out_p_ptr + base_o + r, p)
    #         tl.store(out_c_ptr + base_o + r, c)


# -----------------------------------------------------------------------------
# 2) Python wrapper + test
# -----------------------------------------------------------------------------
def hierarchical_search_triton(
    q, route_flat, level_offsets, parent_counts,
    beam_width, num_levels
):
    """Launch `hierarchical_beam_search_blocked` and return the final beam node ids.

    Args:
        q: query tensor of shape [B, S, D].
        route_flat: routing matrices of all levels concatenated along dim 0,
            shape [sum_P, D, C].
        level_offsets: per-level start row into `route_flat` (converted to int32).
        parent_counts: per-level number of parent nodes (converted to int32).
        beam_width: number of beams kept per level; also passed as the kernel's K.
        num_levels: number of tree levels to descend.

    Returns:
        int32 tensor of shape [B * S, beam_width] holding the node ids written
        by the kernel for every query row.  For an empty input (B * S == 0)
        the empty tensor is returned without launching the kernel.
    """
    B, S, D = q.shape
    beam = beam_width
    C = route_flat.shape[-1]
    BLOCK_D = D
    BLOCK_C = C
    L, K = num_levels, beam_width
    B_S = B * S

    # global scratch (sorted candidates) and outputs (surviving beams)
    scores = torch.zeros((B_S, beam * C), dtype=torch.float32, device=q.device)
    idxs = torch.zeros((B_S, beam * C), dtype=torch.int32, device=q.device)
    scores_beam = torch.zeros((B_S, beam), dtype=torch.float32, device=q.device)
    idxs_beam = torch.zeros((B_S, beam), dtype=torch.int32, device=q.device)

    # Nothing to search: avoid launching a kernel on a zero-sized grid.
    if B_S == 0:
        return idxs_beam

    # flatten + contig
    q_flat = q.contiguous().view(B_S, D)
    # The strides below describe a densely packed [sum_P, D, C] tensor, so the
    # routing table must actually be contiguous for the kernel's pointer
    # arithmetic to be valid.
    route_flat = route_flat.contiguous()
    off_ptr = level_offsets.to(torch.int32).contiguous()
    cnt_ptr = parent_counts.to(torch.int32).contiguous()

    # strides (in elements)
    s_qD = D
    s_rL, s_rD, s_rC = D * C, C, 1

    # launch: one program per block of BLOCK_TOKENS query rows
    BLOCK_TOKENS = 64
    num_blocks = (B_S + BLOCK_TOKENS - 1) // BLOCK_TOKENS
    grid = (num_blocks,)
    hierarchical_beam_search_blocked[grid](
        q_flat,
        route_flat,
        off_ptr,
        cnt_ptr,
        scores,
        idxs,
        scores_beam,
        idxs_beam,
        B_S, D, C, L, K,
        s_qD, s_rL, s_rD, s_rC,
        beam,
        BLOCK_TOKENS,
        BLOCK_D=BLOCK_D,
        BLOCK_C=BLOCK_C
    )
    return idxs_beam



class HierarchicalRouter(nn.Module):
    """Learned routing tree that assigns tokens to equally sized buckets, level by level.

    Level ``l`` owns ``num_buckets_per_level**l`` routing matrices of shape
    ``[input_dim, num_buckets_per_level]`` (one per parent node).  At each
    level a token picks one child of its current node; the tokens are then
    stably sorted by that choice and split into equally sized chunks, so the
    sequence length must be divisible by the number of buckets of every level
    that is used.

    Args:
        input_dim: feature size of the routed vectors.
        hidden_dim: accepted for interface compatibility; not used.
        num_levels: depth of the tree.
        num_buckets_per_level: branching factor of every node.
        beam_width: stored on the module; `test` picks its own beam width from
            the sequence length.
    """

    def __init__(
        self, input_dim, hidden_dim, num_levels=3, num_buckets_per_level=4, beam_width=4
    ):
        super().__init__()
        self.num_levels = num_levels
        self.num_buckets_per_level = num_buckets_per_level
        self.route = nn.ParameterList(
            [
                nn.Parameter(
                    torch.rand(num_buckets_per_level**l, input_dim, num_buckets_per_level)
                )
                for l in range(self.num_levels)
            ]
        )
        # number of parent nodes per level
        self.Ps = [num_buckets_per_level**l for l in range(num_levels)]
        self.beam_width = beam_width
        # Running totals of self.Ps; `test` prepends 0 to obtain the start row
        # of every level inside the concatenated routing table.
        running_total = 0
        self.offsets = []
        for p in self.Ps:
            running_total += p
            self.offsets.append(running_total)
        # Filled lazily (as device tensors) on the first call to `test`.
        self.counts = None
        # number of forward() calls so far
        self.epoch = 0

    @staticmethod
    def _regroup_by_bucket(flat, original_ind, bucket_choice, num_buckets):
        """Stable-sort tokens by their chosen child and split them into equal chunks.

        Args:
            flat: [batch, seq_len, dim] vectors in their original order.
            original_ind: current grouping of original positions,
                [batch, groups, seq_len // groups].
            bucket_choice: chosen child per token, same shape as `original_ind`.
            num_buckets: number of groups after this level.

        Returns:
            (original_ind, vectors) regrouped to [batch, num_buckets, seq_len // num_buckets(, dim)].
        """
        batch_size, seq_len, inp_dim = flat.shape
        # Sorting happens inside each existing group (last axis), and the
        # stable sort keeps the previous relative order within a child.
        _, shift_ind = torch.sort(bucket_choice, dim=-1, stable=True)
        original_ind = torch.gather(original_ind, -1, shift_ind)
        gathered = torch.gather(
            flat, 1, original_ind.view(batch_size, seq_len, 1).expand(-1, -1, inp_dim)
        )
        # chunkwise: every bucket receives the same number of tokens
        per_bucket = seq_len // num_buckets
        original_ind = original_ind.view(batch_size, num_buckets, per_bucket)
        gathered = gathered.view(batch_size, num_buckets, per_bucket, -1)
        return original_ind, gathered

    def get_key_assignment(self, k, level, tau):
        """Route keys down `level` levels and return their bucket layout.

        Args:
            k: keys of shape [batch, seq_len, dim].
            level: number of tree levels to descend.
            tau: Gumbel-softmax temperature used for the hard child choice.

        Returns:
            Long tensor [batch, num_buckets_per_level**level, seq_len // that]
            whose entries are original key positions grouped by bucket.
        """
        batch_size, seq_len, inp_dim = k.size()
        k_sorted = k.view(batch_size, 1, seq_len, inp_dim)
        original_ind = (
            torch.arange(seq_len, device=k.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
            .unsqueeze(1)
        )
        for l in range(level):
            logits = k_sorted @ self.route[l]
            one_hot = F.gumbel_softmax(logits, dim=-1, tau=tau, hard=True)
            bucket_choice = one_hot.argmax(dim=-1)
            num_buckets = self.num_buckets_per_level ** (l + 1)
            original_ind, k_sorted = self._regroup_by_bucket(
                k, original_ind, bucket_choice, num_buckets
            )
        return original_ind

    def test(self, q, k, tau=0.01):
        """Assign queries (Triton beam search) and keys (greedy routing) to buckets.

        The number of levels and the beam width are chosen from the key
        sequence length: <=128 -> (1, 2); 256 or 512 -> (2, 4);
        1024 or >=2048 -> (3, 8).

        Args:
            q: queries [batch, seq_len, dim].
            k: keys [batch, seq_len, dim].
            tau: Gumbel-softmax temperature for the key routing.

        Returns:
            (query_assignment [batch, seq_len, beam_width], key_assignment as
            returned by `get_key_assignment`).

        Raises:
            ValueError: if the key sequence length is not one of the supported
                sizes listed above.
        """
        seq_len = k.shape[1]
        if seq_len <= 128:
            level, beam_width = 1, 2
        elif seq_len in (256, 512):
            level, beam_width = 2, 4
        elif seq_len == 1024 or seq_len >= 2048:
            level, beam_width = 3, 8
        else:
            raise ValueError(
                f"Unsupported key sequence length {seq_len}: expected <=128, 256, 512, 1024 or >=2048"
            )

        key_assignment = self.get_key_assignment(k, level, tau)
        route_flat = torch.cat(list(self.route), dim=0)
        if self.counts is None:
            self.offsets = torch.tensor([0] + self.offsets, dtype=torch.int32, device=q.device)
            self.counts = torch.tensor(self.Ps, dtype=torch.int32, device=q.device)
        elif self.counts.device != q.device:
            # The cached tables must live on the same device as the queries.
            self.offsets = self.offsets.to(q.device)
            self.counts = self.counts.to(q.device)
        query_assignment_tri = hierarchical_search_triton(
            q, route_flat, self.offsets, self.counts, beam_width, level
        )
        query_assignment = query_assignment_tri.view(q.shape[0], q.shape[1], -1)
        return query_assignment, key_assignment

    def forward(self, q, k, tau=0.01, eps=1e-12):
        """Compute the routing regulariser over the concatenation of `q` and `k`.

        For every level the loss adds
        ``mean(avg * log(avg + eps)) - mean(p * log(p + eps))`` where ``p`` are
        the soft routing probabilities and ``avg`` is the hard (Gumbel) choice
        averaged over the tokens of each group.

        Args:
            q: [batch_q, seq_len, dim].
            k: [batch_k, seq_len, dim]; concatenated with `q` along dim 0.
            tau: Gumbel-softmax temperature.
            eps: constant added inside the logarithms.

        Returns:
            Scalar loss tensor (the int 0 if the router has no levels).
        """
        self.epoch += 1
        x = torch.cat([q, k], dim=0)
        batch_size, seq_len, inp_dim = x.size()
        loss = 0
        x_sorted = x.view(batch_size, 1, seq_len, inp_dim)
        original_ind = (
            torch.arange(seq_len, device=x.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
            .unsqueeze(1)
        )
        for l in range(self.num_levels):
            logits = x_sorted @ self.route[l]
            route_prob_l = torch.softmax(logits, dim=-1)  # p(l_i | l_{i-1}, ..., l_1, x)
            route_prob_l_hard = F.gumbel_softmax(logits, dim=-1, tau=tau, hard=True)
            bucket_choice = route_prob_l_hard.argmax(dim=-1)

            route_prob_avg = torch.mean(route_prob_l_hard, -2)
            term1 = torch.mean(route_prob_avg * torch.log(route_prob_avg + eps))
            term2 = -torch.mean(route_prob_l * torch.log(route_prob_l + eps))
            loss += term1 + term2

            num_buckets = self.num_buckets_per_level ** (l + 1)
            original_ind, x_sorted = self._regroup_by_bucket(
                x, original_ind, bucket_choice, num_buckets
            )

        return loss

class RoutingTopKAttention(nn.Module):
    """Attention layer that routes queries/keys through a HierarchicalRouter and calls `parallel_hir`.

    The layer is built from an existing GPT-NeoX attention module: it reuses
    that module's `rotary_emb` and `layer_idx` and copies the weights of its
    `dense` output projection.  The fused QKV projection, the gate projection
    `g_proj`, the `sink` parameter and the router are newly initialised.
    """

    def __init__(self, config, routing_config, GPTNeoXAttention):
        """Build the layer.

        Args:
            config: model config; the attributes read here are
                num_attention_heads, hidden_size, rotary_pct, rotary_emb_base,
                max_position_embeddings, attention_dropout and attention_bias.
            routing_config: provides local_window_size, beam_width and num_levels.
            GPTNeoXAttention: an existing attention module *instance* (despite
                the class-like name) providing rotary_emb, layer_idx and dense.

        Raises:
            ValueError: if hidden_size is not divisible by num_attention_heads.
        """
        super().__init__()
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                "The hidden size is not divisble by the number of attention heads! Make sure to update them"
            )
        self.head_size = self.hidden_size // self.num_attention_heads
        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        self.rope_theta = config.rotary_emb_base
        self._init_bias(config.max_position_embeddings)

        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
        self.rotary_emb = GPTNeoXAttention.rotary_emb

        self.norm_factor = self.head_size**-0.5
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.is_causal = True
        self.layer_idx = GPTNeoXAttention.layer_idx

        # NOTE(review): the group size is hard-coded, so num_kv_heads becomes 0
        # when the model has fewer than 16 attention heads - confirm the
        # supported configurations.
        self.G = 16  # default GQA group size = 16 
        self.num_kv_heads = self.num_attention_heads // self.G
        # Fused projection: num_attention_heads query heads followed by
        # num_kv_heads key heads and num_kv_heads value heads.
        self.query_key_value = nn.Linear(config.hidden_size, self.head_size * self.num_attention_heads+self.num_kv_heads * self.head_size*2, bias=config.attention_bias)

        # Two gate logits per query head (softmax-normalised in forward).
        self.g_proj = nn.Linear(self.hidden_size, self.num_attention_heads*2, bias=False)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        # One learned vector per KV head, added to the attention output in forward.
        self.sink = nn.Parameter(
            torch.rand(self.num_kv_heads, self.head_size)
        )

        self.dense.load_state_dict(GPTNeoXAttention.dense.state_dict())
        # hierarchical router for approximate-nearest-neighbors search
        self.local_window_size = routing_config.local_window_size
        self.router = HierarchicalRouter(
            self.head_size,
            self.head_size,
            beam_width=routing_config.beam_width,
            num_levels=routing_config.num_levels,
        )

    def _init_bias(self, max_positions, device=None):
        """Register the lower-triangular boolean causal mask buffer `bias` ([1, 1, P, P])."""
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )
        if device is not None:
            self.bias = self.bias.to(device)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        position_ids: torch.LongTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        padding_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
    ):
        """Run routed attention over `hidden_states` ([batch, seq_len, hidden_size]).

        The sequence is first padded with Gaussian noise up to the next power
        of two that is also at least the router's total bucket count; the
        padding is cut off again before returning.

        Returns:
            Tuple ``(attn_output[:, :seq_len, :], reg_loss)`` where `reg_loss`
            is the router loss in training mode and 0 otherwise.

        NOTE(review): attention_mask, head_mask, output_attentions,
        padding_mask and cache_position are accepted but not used in this
        method.
        """

        # Apply attention-specific projections and rope
        padding = True
        if padding:
            b, L, d = hidden_states.shape
            # ------- compute next power of 2 >= L -------
            next_pow2 = 1 << (L - 1).bit_length()  # e.g. 513 -> 1024, 512 -> 512

            total_num_buckets = (
                self.router.num_buckets_per_level**self.router.num_levels
            )

            # keep doubling until there is at least one token per leaf bucket
            while next_pow2 < total_num_buckets:

                next_pow2 = 1 << (next_pow2).bit_length()

            if next_pow2 != L:  # skipped when L already equals the target length
                pad_len = next_pow2 - L
                # ------- generate Gaussian noise on same device / dtype -------
                noise = torch.randn(
                    b,
                    pad_len,
                    d,
                    dtype=hidden_states.dtype,
                    device=hidden_states.device,
                )
                # ------- concatenate along seq_len dimension (axis 1) -------
                hidden_states = torch.cat([hidden_states, noise], dim=1)
                # The caller's position ids are replaced by 0..next_pow2-1.
                # NOTE(review): externally supplied `position_embeddings` are
                # not extended here - confirm they already match the padded
                # length when they are passed in.
                position_ids = torch.arange(0, next_pow2).reshape(1, -1).to(hidden_states.device)

        query, key, value, present = self._attn_projections_and_rope(
            hidden_states=hidden_states,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )

        # query: [B, HQ, seq_len, hdim]; key/value come from the KV-head slice
        # of the fused projection.
        # attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 
        B, HQ, seq_len, hdim = query.shape
        b_h = B * HQ
        H = self.num_kv_heads
        # merge batch and head axes so the router sees [batch*heads, seq_len, hdim]
        query = query.flatten(0,1)
        key = key.flatten(0,1)
        value = value.flatten(0,1)
        reg_loss = 0

        if self.training:
            reg_loss = self.router(query, key)
            # loss.backward(retain_graph=True)
        with torch.no_grad():
            # Queries are averaged over the heads of each GQA group before
            # routing, so there is one assignment per KV head.
            query_idx, key_idx = self.router.test(query.view(B, H, HQ//H, seq_len, hdim).mean(dim=2).flatten(0, 1), key)  # self.router.test(query, key)

        beam_width = query_idx.size(-1)
        _, num_bucket, N_sample_per_bucket = key_idx.shape

        # Reorder keys/values into bucket order (bucket after bucket).
        bucket_keys = torch.gather(key, dim=1, index=key_idx.flatten(1,2).unsqueeze(-1).expand(-1, -1, hdim)).view(b_h//self.G, num_bucket* N_sample_per_bucket, hdim)
        bucket_values = torch.gather(value, dim=1, index=key_idx.flatten(1,2).unsqueeze(-1).expand(-1, -1, hdim)).view(b_h//self.G, num_bucket* N_sample_per_bucket, hdim)

        # switch to [B, seq_len, heads, hdim] before calling parallel_hir
        query = query.view(B, HQ, seq_len, hdim).transpose(1, 2).to(value.dtype)
        key = key.view(B, H, seq_len, hdim).transpose(1, 2).to(value.dtype)
        value = value.view(B, H, seq_len, hdim).transpose(1, 2)

        # per-head gates, softmax over the last axis of size 2
        g_topk =  F.softmax(self.g_proj(hidden_states).view(B, seq_len, HQ, 2),dim=-1) 
        
        bucket_keys = bucket_keys.view(B, H, seq_len, hdim).transpose(1, 2).to(value.dtype)
        bucket_values = bucket_values.view(B, H, seq_len, hdim).transpose(1, 2)
        query_idx = query_idx.view(B, H, seq_len, -1).transpose(1, 2)  # HQ -> H, GQA
        key_idx = key_idx.view(B, H, key_idx.shape[1]* key_idx.shape[2]).transpose(1, 2)

        attn_output = parallel_hir(query, key, value, g_topk, bucket_keys, bucket_values, q_indices=query_idx, k_indices=key_idx, block_size=N_sample_per_bucket, window_size=self.local_window_size)  # block_size=64
        # NOTE(review): g_topk was viewed with a last axis of size 2 above, so
        # index 2 looks out of range here - confirm whether a third (sink)
        # gate was intended.
        attn_output = attn_output + g_topk[..., [2]] * self.sink.repeat_interleave(self.G, dim=0)[None, None, ...] 
        attn_output = attn_output.transpose(1, 2)  # back to [B, heads, seq_len, hdim] for _merge_heads
        attn_weights = None
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_size
        )
        attn_output = self.dense(attn_output)

        # drop the padded positions again
        outputs = (attn_output[:, :L, :], reg_loss)

        return outputs

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # tensor: [bs, seq_len, hidden_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(new_shape)
        # -> [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(
            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
        )
        # -> [bs, seq_len, hidden_size]
        return tensor

    def _attn_projections_and_rope(
        self,
        hidden_states: torch.FloatTensor,
        position_ids: torch.LongTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
    ):
        """Project to Q/K/V, apply rotary embeddings to Q and K, and update the cache.

        Returns:
            (query, key, value, layer_past) with query shaped
            [batch, num_attention_heads, seq_len, head_size] and key/value
            shaped [batch, num_kv_heads, seq_len, head_size] (before any cache
            update).  `use_cache` is accepted but not read here.
        """
        # Compute QKV
        # [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (num_attention_heads + 2 * num_kv_heads) * head_size]
        qkv = self.query_key_value(hidden_states)

        #   --> [batch, seq_len, num_attention_heads + 2 * num_kv_heads, head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads+self.num_kv_heads*2, self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # Slice the head axis into query / key / value groups and move the
        # head axis in front of the sequence axis.
        query = qkv[..., : self.num_attention_heads, :].permute(0, 2, 1, 3)
        key = qkv[..., self.num_attention_heads : self.num_attention_heads+self.num_kv_heads, :].permute(0, 2, 1, 3)
        value = qkv[..., self.num_attention_heads+self.num_kv_heads:, :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
        key_rot = key[..., : self.rotary_ndims]
        key_pass = key[..., self.rotary_ndims :]

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value, position_ids)
        else:
            cos, sin = position_embeddings
 
        # rotate only the first rotary_ndims features, then re-attach the rest
        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if layer_past is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_ndims,
                "cache_position": cache_position,
            }
            key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)

        return query, key, value, layer_past

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Dense causal softmax attention (not called by `forward`, where its call is commented out).

        Returns:
            (attn_output, attn_weights).
        """
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        # compute causal mask from causal mask buffer
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        # dynamically increase the causal mask with the key length, if needed.
        if key_length > self.bias.shape[-1]:
            self._init_bias(key_length, device=key.device)
        causal_mask = self.bias[
            :, :, key_length - query_length : key_length, :key_length
        ]

        query = query.view(
            batch_size * num_attention_heads, query_length, attn_head_size
        )
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        # zero-initialised accumulator for baddbmm: scores = norm_factor * Q @ K^T
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=self.norm_factor,
        )
        attn_scores = attn_scores.view(
            batch_size, num_attention_heads, query_length, key_length
        )

        mask_value = torch.finfo(attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(
            attn_scores.device
        )
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
            attn_scores = attn_scores + causal_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat([-back, front], dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table of the rotary embedding.
        sin (`torch.Tensor`): sine table of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): deprecated; ignored.
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis at which a
            singleton dimension is inserted into `cos` and `sin` so that they
            broadcast against `q` and `k` (1 when the heads axis of q/k is
            axis 1, 2 when it is axis 2).

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # x * cos + rotate_half(x) * sin
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
