import triton
import triton.language as tl
from typing import Callable, Optional, Tuple, Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

############################################ triton codes ######################################################

@triton.autotune(configs=[
    triton.Config(kwargs={'block_size': 32, 'block_size_topk': 16}, num_warps=32, num_stages=3),
  ],
  key=['seq_len']  # re-benchmark the listed config(s) whenever seq_len changes
)
@triton.jit
def sparse_topk_attention_kernel(
    q_ptr, k_ptr, v_ptr, gamma_sq_prt, topk_indices_ptr, mask_ptr, output_ptr, z_ptr, delta_ptr, diff_ptr, delta_avg_ptr, diff_avg_ptr,
    b_h:tl.constexpr , num_head:tl.constexpr, seq_len:tl.constexpr , topk:tl.constexpr,
    qk_dim: tl.constexpr, v_dim: tl.constexpr,
    stride_q_bs, stride_q_sl, stride_q_hdim,
    stride_k_bs, stride_k_sl, stride_k_hdim,
    stride_v_bs, stride_v_sl, stride_v_hdim,
    stride_idx_bs, stride_idx_sl, stride_idx_k,
    stride_mask_bs, stride_mask_sl, stride_mask_k,
    stride_out_bs, stride_out_sl, stride_out_hdim,
    stride_z_bs, stride_z_sl,
    stride_delta_bs, stride_delta_sl, stride_delta_k,
    stride_diff_bs, stride_diff_sl, stride_diff_k, stride_diff_hdim,
    stride_delta_avg_bs, stride_delta_avg_sl,
    stride_diff_avg_bs, stride_diff_avg_sl, stride_diff_avg_k,
    qkdim_next_power_of_2: tl.constexpr, topk_next_power_of_2:tl.constexpr, causal,
    # Meta-parameters
    block_size: tl.constexpr, block_size_topk: tl.constexpr,
):
    """Forward pass of sparse top-k inverse-distance attention for one (b_h, query-block) tile.

    For each query row i of the block:
        delta_ij  = |q_i - k_{idx_ij}|^2 + gamma_sq[head] + 1e5 * mask_ij   (j < topk)
        delta_avg = |q_i - kavg_i|^2 + gamma_sq[head]
        z_i       = sum_j 1/delta_ij + 1/delta_avg
        o_i       = (sum_j v_{idx_ij} / delta_ij + vavg_i / delta_avg) / z_i
    where kavg_i / vavg_i are running means of k / v over rows 0..i of the same block.

    Besides ``output`` it stores ``z``, ``delta``, ``diff`` (= q_i - k_j),
    ``delta_avg`` and ``diff_avg`` for the backward kernel.

    Notes:
        * Every last-dimension access uses unit stride; the ``*_hdim`` / ``*_k``
          strides are accepted but unused, so inputs must be contiguous in their last dim.
        * ``v_dim`` must be a power of two (``tl.arange(0, v_dim)``).
        * ``causal`` is accepted but unused; causality comes entirely from ``mask``.
        * ``block_size_topk`` is unused here (it is used by the backward kernel).
    """
    b_h_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block_start = block_idx * block_size
    # NOTE(review): b_h // num_head is the batch size, so this assumes b_h is flattened
    # head-major, i.e. (head, batch). For a (batch, head) layout it would be
    # b_h_idx % num_head -- confirm against the caller.
    head_idx = b_h_idx // (b_h // num_head)

    rows = tl.arange(0, block_size)
    # Only the last block can be partial; rows at or beyond seq_len are masked out.
    block_mask = (block_start + rows) < seq_len
    qk_cols = tl.arange(0, qkdim_next_power_of_2)
    qk_mask = qk_cols < qk_dim
    v_cols = tl.arange(0, v_dim)
    row_qk_mask = block_mask[:, None] & qk_mask[None, :]

    # (block_size, dim) tiles of q, k, v for this block; masked lanes read as 0.
    q_offset = b_h_idx * stride_q_bs + block_start * stride_q_sl
    q = tl.load(q_ptr + q_offset + rows[:, None] * stride_q_sl + qk_cols[None, :], mask=row_qk_mask, other=0.0)
    k_offset = b_h_idx * stride_k_bs + block_start * stride_k_sl
    k = tl.load(k_ptr + k_offset + rows[:, None] * stride_k_sl + qk_cols[None, :], mask=row_qk_mask, other=0.0)
    v_offset = b_h_idx * stride_v_bs + block_start * stride_v_sl
    v = tl.load(v_ptr + v_offset + rows[:, None] * stride_v_sl + v_cols[None, :], mask=block_mask[:, None], other=0.0)

    # Running (prefix) means inside the block: row i averages rows 0..i.
    v_block_avg = tl.cumsum(v, axis=0) / (rows[:, None] + 1.)
    k_block_avg = tl.cumsum(k, axis=0) / (rows[:, None] + 1.)

    # Top-k neighbour indices and their additive mask, padded to a power of two.
    topk_cols = tl.arange(0, topk_next_power_of_2)
    topk_mask = topk_cols < topk
    row_topk_mask = block_mask[:, None] & topk_mask[None, :]
    topk_offset = b_h_idx * stride_idx_bs + block_start * stride_idx_sl
    indices = tl.load(topk_indices_ptr + topk_offset + rows[:, None] * stride_idx_sl + topk_cols[None, :],
                      mask=row_topk_mask, other=0)
    mask_offset = b_h_idx * stride_mask_bs + block_start * stride_mask_sl
    causal_mask = tl.load(mask_ptr + mask_offset + rows[:, None] * stride_mask_sl + topk_cols[None, :],
                          mask=row_topk_mask)

    # Gather the selected keys / values: (block_size, topk_pad, dim). Padded lanes read 0
    # so they cannot inject NaN/Inf into the weighted sum below.
    k_offsets = b_h_idx * stride_k_bs + indices * stride_k_sl
    v_offsets = b_h_idx * stride_v_bs + indices * stride_v_sl
    k_gathered = tl.load(k_ptr + k_offsets[:, :, None] + qk_cols[None, None, :],
                         mask=row_topk_mask[:, :, None] & qk_mask[None, None, :], other=0.0)
    v_gathered = tl.load(v_ptr + v_offsets[:, :, None] + v_cols[None, None, :],
                         mask=row_topk_mask[:, :, None], other=0.0)

    # Squared Euclidean distances to the gathered keys and to the running-mean key.
    diff = q[:, None, :] - k_gathered  # (block_size, topk_pad, qk_dim_pad)
    dist_sq = tl.sum(diff * diff, axis=2)
    diff_avg = q - k_block_avg
    dist_sq_avg = tl.sum(diff_avg * diff_avg, axis=1)

    gamma_sq = tl.load(gamma_sq_prt + head_idx)
    # A non-zero mask entry pushes the distance up by 1e5, making that neighbour negligible.
    delta_ij = dist_sq + gamma_sq + 100000.0 * causal_mask
    delta_avg = dist_sq_avg + gamma_sq

    # Inverse-distance scores. Lanes beyond topk (present whenever topk is not a power of
    # two) must carry zero weight; otherwise they would be counted in z_i and o_i.
    s_ij = tl.where(topk_mask[None, :], 1.0 / delta_ij, 0.0)
    s_avg = 1.0 / delta_avg
    z_i = tl.sum(s_ij, axis=1) + s_avg

    a_ij = s_ij / z_i[:, None]
    a_avg = s_avg / z_i
    o_i = tl.sum(a_ij[:, :, None] * v_gathered, axis=1) + a_avg[:, None] * v_block_avg

    out_offset = b_h_idx * stride_out_bs + block_start * stride_out_sl
    tl.store(output_ptr + out_offset + rows[:, None] * stride_out_sl + v_cols[None, :], o_i, mask=block_mask[:, None])

    # Intermediates consumed by sparse_topk_attention_backward_kernel.
    z_offset = b_h_idx * stride_z_bs + block_start * stride_z_sl
    tl.store(z_ptr + z_offset + rows, z_i, mask=block_mask)
    delta_offset = b_h_idx * stride_delta_bs + block_start * stride_delta_sl
    tl.store(delta_ptr + delta_offset + rows[:, None] * stride_delta_sl + topk_cols[None, :], delta_ij, mask=row_topk_mask)
    diff_offset = b_h_idx * stride_diff_bs + block_start * stride_diff_sl
    tl.store(diff_ptr + diff_offset + rows[:, None, None] * stride_diff_sl + topk_cols[None, :, None] * stride_diff_k + qk_cols[None, None, :],
             diff, mask=row_topk_mask[:, :, None] & qk_mask[None, None, :])
    diff_avg_offset = b_h_idx * stride_diff_avg_bs + block_start * stride_diff_avg_sl
    tl.store(diff_avg_ptr + diff_avg_offset + rows[:, None] * stride_diff_avg_sl + qk_cols[None, :], diff_avg, mask=row_qk_mask)
    delta_avg_offset = b_h_idx * stride_delta_avg_bs + block_start * stride_delta_avg_sl
    tl.store(delta_avg_ptr + delta_avg_offset + rows, delta_avg, mask=block_mask)


@triton.autotune(configs=[
    triton.Config(kwargs={'block_size': 32, 'block_size_topk': 16}, num_warps=32, num_stages=3),
  ],
  key=['seq_len']  # re-benchmark the listed config(s) whenever seq_len changes
)
@triton.jit
def sparse_topk_attention_backward_kernel(
    v_ptr, gamma_sq_ptr, topk_indices_ptr, grad_output_ptr, output_prt, z_ptr, delta_ptr, diff_ptr, delta_avg_ptr, diff_avg_ptr,
    grad_q_ptr, grad_k_ptr, grad_v_ptr, grad_gamma_sq_ptr,
    qk_dim: tl.constexpr, v_dim: tl.constexpr,
    b_h: tl.constexpr, num_head:tl.constexpr, seq_len: tl.constexpr, topk: tl.constexpr,
    stride_q_bs, stride_q_sl, stride_q_hdim,
    stride_k_bs, stride_k_sl, stride_k_hdim,
    stride_v_bs, stride_v_sl, stride_v_hdim,
    stride_idx_bs, stride_idx_sl, stride_idx_k,
    stride_out_bs, stride_out_sl, stride_out_hdim,
    stride_z_bs, stride_z_sl,
    stride_delta_bs, stride_delta_sl, stride_delta_k,
    stride_diff_bs, stride_diff_sl, stride_diff_k, stride_diff_hdim,
    stride_delta_avg_bs, stride_delta_avg_sl,
    stride_diff_avg_bs, stride_diff_avg_sl, stride_diff_avg_k,
    qkdim_next_power_of_2: tl.constexpr, topk_next_power_of_2:tl.constexpr,
    # Meta-parameters
    block_size: tl.constexpr, block_size_topk: tl.constexpr
):
    """Backward pass of ``sparse_topk_attention_kernel`` for one (b_h, query-block) tile.

    With s = 1/delta, delta = |q - k|^2 + gamma_sq, z_i = sum_j s_ij + s_avg_i and
    o_i = (sum_j s_ij v_j + s_avg_i vavg_i) / z_i, the chain rule gives
        do_i/ds_ij      = (v_j - o_i) / z_i
        ds_ij/dq_i      = -2 (q_i - k_j) / delta_ij^2 = -ds_ij/dk_j
        ds_ij/dgamma_sq = -1 / delta_ij^2
    and the same for the in-block running means (kavg, vavg). Their gradients are
    spread back over rows j <= i with weights [i >= j] / (i + 1).

    Gradients are accumulated with ``tl.atomic_add`` into ``grad_q`` / ``grad_k`` /
    ``grad_v`` (which the caller must zero-initialise) and into
    ``grad_gamma_sq[head_idx]``.

    Rows past ``seq_len`` and top-k columns past ``topk`` are loaded with neutral
    values (delta = z = 1, everything else 0) and masked out of every reduction and
    store. Last-dimension accesses assume unit stride (``*_hdim`` / ``*_k`` strides are
    unused). ``grad_output`` is addressed with ``output``'s strides. A 1e-6 epsilon
    is added to every denominator. The ``tl.dot`` calls need ``v_dim``,
    ``qkdim_next_power_of_2`` and ``block_size`` to be >= 16.
    """
    b_h_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block_start = block_idx * block_size
    # NOTE(review): same (head, batch) layout assumption as the forward kernel -- confirm.
    head_idx = b_h_idx // (b_h // num_head)

    rows = tl.arange(0, block_size)
    block_mask = (block_start + rows) < seq_len
    qk_cols = tl.arange(0, qkdim_next_power_of_2)
    qk_mask = qk_cols < qk_dim
    v_cols = tl.arange(0, v_dim)
    row_qk_mask = block_mask[:, None] & qk_mask[None, :]

    q_offset = b_h_idx * stride_q_bs + block_start * stride_q_sl
    q_grad_ptrs = grad_q_ptr + q_offset + rows[:, None] * stride_q_sl + qk_cols[None, :]

    # Forward intermediates. Padded rows get neutral values so that their contribution
    # to the cross-row reductions (tl.dot over rows, the gamma sum) is exactly zero.
    delta_avg_offset = b_h_idx * stride_delta_avg_bs + block_start * stride_delta_avg_sl
    delta_avg = tl.load(delta_avg_ptr + delta_avg_offset + rows * stride_delta_avg_sl, mask=block_mask, other=1.0)
    diff_avg_offset = b_h_idx * stride_diff_avg_bs + block_start * stride_diff_avg_sl
    diff_avg = tl.load(diff_avg_ptr + diff_avg_offset + rows[:, None] * stride_diff_avg_sl + qk_cols[None, :],
                       mask=row_qk_mask, other=0.0)
    z_offset = b_h_idx * stride_z_bs + block_start * stride_z_sl
    z_i = tl.load(z_ptr + z_offset + rows, mask=block_mask, other=1.0)

    grad_out_offset = b_h_idx * stride_out_bs + block_start * stride_out_sl
    out_ptrs = grad_out_offset + rows[:, None] * stride_out_sl + v_cols[None, :]
    grad_o_i = tl.load(grad_output_ptr + out_ptrs, mask=block_mask[:, None], other=0.0)
    o_i = tl.load(output_prt + out_ptrs, mask=block_mask[:, None], other=0.0)

    # avg_weights[i, j] = [i >= j] / (i + 1): Jacobian of the running mean, row i w.r.t. row j.
    causal_avg_mask = rows[:, None] >= rows[None, :]
    avg_weights = causal_avg_mask * (1 / (1.0 + rows[:, None]))

    # ---- running-mean (avg) branch -------------------------------------------------
    s_avg = 1.0 / (delta_avg + 1e-6)
    a_avg = s_avg / (z_i + 1e-6)

    # dL/dv_j = sum_{i >= j} a_avg_i * g_i / (i + 1)
    dl_div_dv_avg = tl.dot(tl.trans(a_avg[:, None] * grad_o_i, 1, 0), avg_weights)
    dl_div_dv_avg = tl.trans(dl_div_dv_avg, 1, 0)
    v_offset = b_h_idx * stride_v_bs + block_start * stride_v_sl
    v_ptrs = v_offset + rows[:, None] * stride_v_sl + v_cols[None, :]
    tl.atomic_add(grad_v_ptr + v_ptrs, dl_div_dv_avg, mask=block_mask[:, None])

    v = tl.load(v_ptr + v_ptrs, mask=block_mask[:, None], other=0.0)
    v_block_avg = tl.cumsum(v, axis=0) / (rows[:, None] + 1.)

    v_avg_minus_o_div_z = (v_block_avg - o_i) / (z_i[:, None] + 1e-6)   # (block_size, v_dim)
    del_avg_sq = delta_avg * delta_avg
    q_minus_k_avg_div_delta_2 = diff_avg / (del_avg_sq[:, None] + 1e-6)  # (block_size, qk_dim)
    # g_i . (vavg_i - o_i) / z_i : (block_size,)
    g_dot_v_avg = tl.sum(grad_o_i * v_avg_minus_o_div_z, axis=1)

    # dL/dq_i (avg) = -2 * g_dot_v_avg_i * (q_i - kavg_i) / delta_avg_i^2
    dl_div_dq_avg = -2 * g_dot_v_avg[:, None] * q_minus_k_avg_div_delta_2
    tl.atomic_add(q_grad_ptrs, dl_div_dq_avg, mask=row_qk_mask)

    # dL/dkavg_i = +2 * g_dot_v_avg_i * (q_i - kavg_i) / delta_avg_i^2, spread over rows j <= i.
    dl_div_dk_avg = 2 * tl.dot(tl.trans(g_dot_v_avg[:, None] * q_minus_k_avg_div_delta_2, 1, 0), avg_weights)
    dl_div_dk_avg = tl.trans(dl_div_dk_avg, 1, 0)
    k_offset = b_h_idx * stride_k_bs + block_start * stride_k_sl
    tl.atomic_add(grad_k_ptr + k_offset + rows[:, None] * stride_k_sl + qk_cols[None, :], dl_div_dk_avg, mask=row_qk_mask)

    # dL/dgamma_sq (avg) = -sum_i g_dot_v_avg_i / delta_avg_i^2  (note the minus sign).
    dl_div_epsi_avg = -tl.sum(g_dot_v_avg / (del_avg_sq + 1e-6))
    tl.atomic_add(grad_gamma_sq_ptr + head_idx, dl_div_epsi_avg)

    # ---- gathered top-k branch, processed block_size_topk columns at a time ----------
    topk_offset = b_h_idx * stride_idx_bs + block_start * stride_idx_sl
    delta_offset = b_h_idx * stride_delta_bs + block_start * stride_delta_sl
    diff_offset = b_h_idx * stride_diff_bs + block_start * stride_diff_sl
    chunk_cols = tl.arange(0, block_size_topk)
    topk_chunk_offset = 0
    while topk_chunk_offset < topk:
        cols = topk_chunk_offset + chunk_cols
        # The last chunk may run past topk; those columns belong to the next row.
        chunk_mask = cols < topk
        row_chunk_mask = block_mask[:, None] & chunk_mask[None, :]
        row_chunk_qk_mask = row_chunk_mask[:, :, None] & qk_mask[None, None, :]

        indices = tl.load(topk_indices_ptr + topk_offset + rows[:, None] * stride_idx_sl + cols[None, :],
                          mask=row_chunk_mask, other=0)
        k_offsets = b_h_idx * stride_k_bs + indices * stride_k_sl
        v_offsets = b_h_idx * stride_v_bs + indices * stride_v_sl
        v_gathered = tl.load(v_ptr + v_offsets[:, :, None] + v_cols[None, None, :],
                             mask=row_chunk_mask[:, :, None], other=0.0)
        delta_ij = tl.load(delta_ptr + delta_offset + rows[:, None] * stride_delta_sl + cols[None, :],
                           mask=row_chunk_mask, other=1.0)
        diff = tl.load(diff_ptr + diff_offset + rows[:, None, None] * stride_diff_sl + cols[None, :, None] * stride_diff_k + qk_cols[None, None, :],
                       mask=row_chunk_qk_mask, other=0.0)

        s_ij = tl.where(row_chunk_mask, 1.0 / (delta_ij + 1e-6), 0.0)
        a_ij = s_ij / (z_i[:, None] + 1e-6)                       # (block_size, chunk)
        # dL/dv_j = a_ij * g_i
        dl_div_dv = a_ij[:, :, None] * grad_o_i[:, None, :]       # (block_size, chunk, v_dim)

        v_minus_o_div_z = (v_gathered - o_i[:, None, :]) / (z_i[:, None, None] + 1e-6)
        del_ij_sq = delta_ij * delta_ij
        q_minus_k_div_delta_2 = diff / (del_ij_sq[:, :, None] + 1e-6)  # (block_size, chunk, qk_dim)
        # g_i . (v_j - o_i) / z_i for every (i, j) in the chunk; zero on padded lanes.
        g_dot_v = tl.sum(grad_o_i[:, None, :] * v_minus_o_div_z, axis=2)
        g_dot_v = tl.where(row_chunk_mask, g_dot_v, 0.0)

        # dL/dq_i = -2 sum_j g_dot_v_ij (q_i - k_j) / delta_ij^2 ; dL/dk_j = -(per-pair term)
        dl_div_dq = -2 * tl.sum(g_dot_v[:, :, None] * q_minus_k_div_delta_2, axis=1)
        dl_div_dk = 2 * g_dot_v[:, :, None] * q_minus_k_div_delta_2
        # dL/dgamma_sq = -sum_ij g_dot_v_ij / delta_ij^2
        dl_div_epsi = -tl.sum(g_dot_v / (del_ij_sq + 1e-6))

        tl.atomic_add(q_grad_ptrs, dl_div_dq, mask=row_qk_mask)
        tl.atomic_add(grad_k_ptr + k_offsets[:, :, None] + qk_cols[None, None, :], dl_div_dk, mask=row_chunk_qk_mask)
        tl.atomic_add(grad_v_ptr + v_offsets[:, :, None] + v_cols[None, None, :], dl_div_dv, mask=row_chunk_mask[:, :, None])
        tl.atomic_add(grad_gamma_sq_ptr + head_idx, dl_div_epsi)
        topk_chunk_offset += block_size_topk



class _sparse_topk_attention(torch.autograd.Function):
    """Autograd bridge to the sparse top-k inverse-distance attention Triton kernels.

    Each query attends to ``topk`` gathered keys/values plus an in-block running
    mean of keys/values. The scores are ``1 / (|q - k|^2 + gamma_sq[head] + 1e5 * mask)``,
    normalised to sum to one.
    """

    @staticmethod
    def forward(ctx, q, k, v, topk_indices, gamma_sq, num_head, block_size, mask, causal=True):
        """Run the forward kernel and save its intermediates for ``backward``.

        Args:
            q, k: ``(b_h, seq_len, qk_dim)`` queries / keys.
            v: ``(b_h, seq_len, v_dim)`` values. ``v_dim`` must be a power of two
                (and >= 16 for the backward kernel's ``tl.dot``).
            topk_indices: ``(b_h, seq_len, topk)`` integer key positions per query.
            gamma_sq: additive smoothing term, read at index ``head_idx`` by the kernel.
            num_head: number of heads folded into ``b_h``.
            block_size: kept for API compatibility; the kernel block size is picked
                by ``triton.autotune``.
            mask: ``(b_h, seq_len, topk)``; non-zero entries add 1e5 to the distance.
            causal: forwarded to the kernel, which currently ignores it.

        Returns:
            ``(b_h, seq_len, v_dim)`` attention output.
        """
        # The kernels walk the last dimension of every tensor with unit stride (the
        # *_hdim / *_k strides are passed but unused), so force a dense layout.
        # .contiguous() is a no-op for tensors that already are.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        topk_indices = topk_indices.contiguous()
        mask = mask.contiguous()

        b_h, seq_len, qk_dim = q.shape
        v_dim = v.shape[-1]
        topk = topk_indices.shape[-1]

        output = torch.empty((b_h, seq_len, v_dim), device=q.device, dtype=q.dtype)
        z = torch.empty((b_h, seq_len), device=q.device, dtype=q.dtype)
        delta = torch.empty((b_h, seq_len, topk), device=q.device, dtype=q.dtype)
        diff = torch.empty((b_h, seq_len, topk, qk_dim), device=q.device, dtype=q.dtype)
        delta_avg = torch.empty((b_h, seq_len), device=q.device, dtype=q.dtype)
        diff_avg = torch.empty((b_h, seq_len, qk_dim), device=q.device, dtype=q.dtype)

        # One program per (b_h row, block of queries); the last block may be partial.
        grid = lambda META: (b_h, triton.cdiv(seq_len, META['block_size']))
        sparse_topk_attention_kernel[grid](
            q, k, v, gamma_sq, topk_indices, mask, output, z, delta, diff, delta_avg, diff_avg,
            b_h=b_h, num_head=num_head, seq_len=seq_len, topk=topk,
            qk_dim=qk_dim, v_dim=v_dim,
            stride_q_bs=q.stride(0), stride_q_sl=q.stride(1), stride_q_hdim=q.stride(2),
            stride_k_bs=k.stride(0), stride_k_sl=k.stride(1), stride_k_hdim=k.stride(2),
            stride_v_bs=v.stride(0), stride_v_sl=v.stride(1), stride_v_hdim=v.stride(2),
            stride_idx_bs=topk_indices.stride(0), stride_idx_sl=topk_indices.stride(1), stride_idx_k=topk_indices.stride(2),
            stride_mask_bs=mask.stride(0), stride_mask_sl=mask.stride(1), stride_mask_k=mask.stride(2),
            stride_out_bs=output.stride(0), stride_out_sl=output.stride(1), stride_out_hdim=output.stride(2),
            stride_z_bs=z.stride(0), stride_z_sl=z.stride(1),
            stride_delta_bs=delta.stride(0), stride_delta_sl=delta.stride(1), stride_delta_k=delta.stride(2),
            stride_diff_bs=diff.stride(0), stride_diff_sl=diff.stride(1), stride_diff_k=diff.stride(2), stride_diff_hdim=diff.stride(3),
            stride_delta_avg_bs=delta_avg.stride(0), stride_delta_avg_sl=delta_avg.stride(1),
            stride_diff_avg_bs=diff_avg.stride(0), stride_diff_avg_sl=diff_avg.stride(1), stride_diff_avg_k=diff_avg.stride(2),
            # Derived from qk_dim (was hard-coded to 4, silently truncating larger qk_dim).
            qkdim_next_power_of_2=triton.next_power_of_2(qk_dim),
            topk_next_power_of_2=triton.next_power_of_2(topk),
            causal=causal,
        )
        ctx.save_for_backward(q, k, v, topk_indices, output, z, delta, diff, delta_avg, diff_avg, gamma_sq)
        ctx.grid = grid
        ctx.num_head = num_head
        ctx.block_size = block_size
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (q, k, v, topk_indices, gamma_sq, num_head, block_size, mask, causal)."""
        q, k, v, topk_indices, output, z, delta, diff, delta_avg, diff_avg, gamma_sq = ctx.saved_tensors
        num_head = ctx.num_head
        b_h, seq_len, qk_dim = q.shape
        v_dim = v.shape[-1]
        topk = topk_indices.shape[-1]

        # The kernel addresses grad_output with *output*'s strides. Autograd may hand us
        # an expanded (stride-0) or otherwise non-contiguous gradient, e.g. after .sum().
        grad_output = grad_output.contiguous()

        # The kernel accumulates with atomic_add, so every gradient buffer starts at zero.
        grad_q = torch.zeros_like(q)
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)
        grad_gamma_sq = torch.zeros_like(gamma_sq)

        sparse_topk_attention_backward_kernel[ctx.grid](
            v, gamma_sq, topk_indices, grad_output, output, z, delta, diff, delta_avg, diff_avg,
            grad_q, grad_k, grad_v, grad_gamma_sq,
            qk_dim=qk_dim, v_dim=v_dim,
            b_h=b_h, num_head=num_head, seq_len=seq_len, topk=topk,
            stride_q_bs=q.stride(0), stride_q_sl=q.stride(1), stride_q_hdim=q.stride(2),
            stride_k_bs=k.stride(0), stride_k_sl=k.stride(1), stride_k_hdim=k.stride(2),
            stride_v_bs=v.stride(0), stride_v_sl=v.stride(1), stride_v_hdim=v.stride(2),
            stride_idx_bs=topk_indices.stride(0), stride_idx_sl=topk_indices.stride(1), stride_idx_k=topk_indices.stride(2),
            stride_out_bs=output.stride(0), stride_out_sl=output.stride(1), stride_out_hdim=output.stride(2),
            stride_z_bs=z.stride(0), stride_z_sl=z.stride(1),
            stride_delta_bs=delta.stride(0), stride_delta_sl=delta.stride(1), stride_delta_k=delta.stride(2),
            stride_diff_bs=diff.stride(0), stride_diff_sl=diff.stride(1), stride_diff_k=diff.stride(2), stride_diff_hdim=diff.stride(3),
            stride_delta_avg_bs=delta_avg.stride(0), stride_delta_avg_sl=delta_avg.stride(1),
            stride_diff_avg_bs=diff_avg.stride(0), stride_diff_avg_sl=diff_avg.stride(1), stride_diff_avg_k=diff_avg.stride(2),
            # tl.dot in the kernel needs dimensions >= 16; larger qk_dim is no longer truncated.
            qkdim_next_power_of_2=max(16, triton.next_power_of_2(qk_dim)),
            topk_next_power_of_2=triton.next_power_of_2(topk),
        )
        return grad_q, grad_k, grad_v, None, grad_gamma_sq, None, None, None, None

############################################ triton codes ######################################################

# Define the z_order function
def z_order_optimized(x):
    """Encode each point of ``x`` as a z-order (Morton-style) int64 key.

    Each coordinate is min-max normalised over dim 1, quantised onto a 4096-level
    grid, bit-spread with :func:`interleave_bits`, shifted left by its coordinate
    index and OR-ed together. Only the first three coordinates are combined.

    Args:
        x: tensor normalised over dim 1 whose last dim holds the coordinates;
            presumably ``(batch, seq_len, 3)`` (the shift vector is viewed as
            ``(1, 1, d)``, so a 3-D input is expected).

    Returns:
        int64 tensor of shape ``x.shape[:-1]``.
    """
    min_vals = x.min(dim=1, keepdim=True)[0]
    max_vals = x.max(dim=1, keepdim=True)[0]
    span = max_vals - min_vals
    # A coordinate that is constant along dim 1 has zero span: 0/0 would give NaN and
    # casting NaN to int64 is undefined. Map such coordinates to 0 instead.
    span = torch.where(span > 0, span, torch.ones_like(span))
    normalized_data = (x - min_vals) / span

    # Quantise to integers in [0, grid_size - 1] (12 bits).
    grid_size = 4096
    discretized_data = (normalized_data * (grid_size - 1)).to(torch.int64)

    # Coordinate c is spread and then shifted left by c bits.
    shifts = torch.arange(x.size(-1), device=x.device, dtype=torch.int64)
    z = interleave_bits(discretized_data) << shifts.view(1, 1, -1)

    return z[..., 0] | z[..., 1] | z[..., 2]


def interleave_bits(x):
    """Spread the bits of each non-negative int64 in ``x`` (< 2**32) so bit b moves to bit 2*b.

    This is the classic "part1by1" step of a 2-D Morton code: one zero bit is inserted
    between consecutive input bits, e.g. ``0b111 -> 0b10101``.

    NOTE(review): ``z_order_optimized`` combines three coordinates with shifts 0, 1, 2,
    so coordinates 0 and 2 both land on even bit positions and overlap. A true 3-D
    Morton code would need a "part1by2" spread. Left unchanged because changing it
    would alter the downstream token ordering.
    """
    x = (x | (x << 16)) & 0x0000FFFF0000FFFF
    x = (x | (x << 8)) & 0x00FF00FF00FF00FF
    x = (x | (x << 4)) & 0x0F0F0F0F0F0F0F0F
    x = (x | (x << 2)) & 0x3333333333333333
    x = (x | (x << 1)) & 0x5555555555555555
    return x

class MLP(nn.Module):
    """Small projection head: ``fc2(activation(fc1(x)))``, or just ``fc2(x)`` for narrow inputs.

    When the input's last dimension is at most ``_SKIP_HIDDEN_MAX_DIM`` (32), ``fc1``
    and the activation are bypassed and only ``fc2`` is applied.

    Args:
        d_model: input feature size.
        out_dim: output feature size.
        hidden_mult: hidden width is ``d_model * hidden_mult`` (only used when the
            hidden layer is active).
        activation: non-linearity after ``fc1``. Author's experiment notes: ELU
            converged in ~20 epochs versus ~32 for GELU.
        return_residual: if True, ``forward`` returns ``(y, x)`` instead of ``y``.
        **kwargs: accepted and ignored, for config compatibility.
    """

    _SKIP_HIDDEN_MAX_DIM = 32

    def __init__(
        self,
        d_model: int,
        out_dim: int,
        hidden_mult: int = 1,
        activation: Callable = F.elu,
        return_residual: bool = False,
        **kwargs
    ):
        super().__init__()
        in_features, out_features = d_model, out_dim
        hidden_features = d_model * hidden_mult
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.activation = activation
        # forward() feeds narrow inputs straight into fc2, so in that case fc2 must accept
        # d_model features (previously any hidden_mult != 1 crashed there). For
        # hidden_mult == 1 the shapes, and hence checkpoints, are unchanged.
        uses_hidden = in_features > self._SKIP_HIDDEN_MAX_DIM
        self.fc2 = nn.Linear(hidden_features if uses_hidden else in_features, out_features)

    def forward(self, x):
        """Project ``x``; returns ``y`` or ``(y, x)`` depending on ``return_residual``."""
        y = x
        if x.shape[-1] > self._SKIP_HIDDEN_MAX_DIM:
            y = self.activation(self.fc1(y))
        y = self.fc2(y)
        return (y, x) if self.return_residual else y

def mask_value(hash_vals_causal: torch.Tensor, mask2: torch.Tensor,hash_vals_sorted: torch.Tensor,seq_len:int,chunk_size:int,mask: torch.Tensor,hash_vals_causal_original_index: torch.Tensor,sorted_indices: torch.Tensor  ):
    """Scatter sorted hashes and their source indices into the per-chunk causal buffers.

    ``hash_vals_sorted`` / ``sorted_indices`` (batch, seq_len) are tiled once per chunk
    and the entries selected by ``mask`` are written, in row-major order, into the
    slots selected by ``mask2``. Both output buffers are modified in place and also
    returned.
    """
    n_chunks = seq_len // chunk_size

    def _tile(src: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) -> (batch, n_chunks, seq_len): one materialised copy per chunk.
        return src.unsqueeze(1).repeat(1, n_chunks, 1)

    hash_vals_causal[mask2] = _tile(hash_vals_sorted)[mask]
    hash_vals_causal_original_index[mask2] = _tile(sorted_indices)[mask]
    return hash_vals_causal, hash_vals_causal_original_index



@torch.jit.script
def gather_indices(input: torch.Tensor, mask: torch.Tensor):
    # Select entries of `input` along its last axis at the positions in `mask`,
    # then collapse the result to a single 1-D tensor.
    picked = input.gather(-1, mask)
    return picked.reshape(-1)

def chunk_causal_sort(hash_vals: torch.Tensor, chunk_size: int):
    """For every chunk boundary, list the hashes of the tokens before it in sorted order.

    Prefix c (0-based) covers tokens ``0 .. (c + 1) * chunk_size - 1``. Its row in the
    outputs holds those tokens' hash values in ascending order, followed by padding.

    Args:
        hash_vals: ``(batch, seq_len)`` integer hash per token.
        chunk_size: number of tokens per chunk; trailing tokens beyond the last full
            chunk never appear in any prefix.

    Returns:
        A tuple ``(causal_hashes, causal_src_idx, chunk_ends)``:
        ``causal_hashes`` is ``(batch, n_chunks, seq_len)`` sorted prefix hashes, padded
        with the dtype's max value; ``causal_src_idx`` has the same shape and holds the
        original token positions, padded with ``-seq_len``; ``chunk_ends`` is
        ``(1, n_chunks, 1)`` and holds each prefix's exclusive end position.
    """
    bsz, n = hash_vals.shape[:2]
    n_chunks = n // chunk_size
    dev = hash_vals.device

    sorted_vals, order = torch.sort(hash_vals, dim=1)

    pad_hash = torch.iinfo(hash_vals.dtype).max
    causal_hashes = torch.full((bsz, n_chunks, n), pad_hash, dtype=hash_vals.dtype, device=dev)
    causal_src_idx = torch.full((bsz, n_chunks, n), -n, dtype=order.dtype, device=dev)

    chunk_ends = torch.arange(1, n_chunks + 1, device=dev) * chunk_size
    ends_view = chunk_ends.unsqueeze(0).unsqueeze(-1)  # (1, n_chunks, 1)

    # keep[b, c, p]: the p-th token in sorted order (batch b) lies inside prefix c.
    keep = order.unsqueeze(1).expand(-1, n_chunks, -1) < ends_view
    # slot[b, c, p]: output slot p of prefix c is filled (prefix c holds (c+1)*chunk_size tokens).
    positions = torch.arange(n, device=dev)
    slot = positions.unsqueeze(0).unsqueeze(0).expand(bsz, n_chunks, -1) < ends_view

    # Sorted positions of the kept tokens, row-major, turned into flat offsets into (bsz, n).
    sorted_pos = keep.nonzero(as_tuple=True)[-1].view(bsz, -1)
    batch_base = (n * torch.arange(bsz, device=sorted_pos.device)).unsqueeze(1)
    flat_pos = (sorted_pos + batch_base).flatten()

    causal_src_idx[slot] = torch.take(order, flat_pos)
    causal_hashes[slot] = torch.take(sorted_vals, flat_pos)
    return causal_hashes, causal_src_idx, ends_view


# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoder(nn.Module):
    r"""Add absolute position encodings (fixed sinusoidal or learned) to a sequence.

    The encodings have the same width as the embeddings, so the two can be summed.
    The fixed variant uses sines and cosines of geometrically spaced frequencies:

    .. math::
        \text{PE}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PE}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the token position and :math:`i` indexes the embedding dimension.

    Args:
        d_model: the embedding dimension (required). Odd values are supported; the
            last column then holds a sine only.
        dropout: dropout probability applied after the addition (default=0.1).
        max_len: the maximum supported sequence length (default=16384).
        pe_init: if given, use a learned table initialised from N(0, pe_init)
            instead of the fixed sinusoids (default=None).

    Examples:
        >>> pos_encoder = PositionalEncoder(512)
    """

    def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        if pe_init is not None:
            # Learned table. The (max_len, 1, d_model) shape is kept for checkpoint
            # compatibility; forward() flattens it to (max_len, d_model).
            self.pe = nn.Parameter(torch.empty(max_len, 1, d_model))
            nn.init.normal_(self.pe, 0, pe_init)
        else:
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0.0, max_len).unsqueeze(1)
            div_term = torch.exp(
                -math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model
            )
            angles = position * div_term
            pe[:, 0::2] = torch.sin(angles)
            # For odd d_model there is one more frequency than there are cosine columns.
            pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
            self.register_buffer("pe", pe)

        self.attn_mask = None

    def forward(self, x):
        r"""Add position encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder (required).

        Shape:
            x: ``[..., seq_len, d_model]``; the second-to-last dimension is the
            sequence axis and ``seq_len <= max_len``.
            Returns: same shape as ``x``.
        """
        seq_len = x.size(-2)
        # Flatten to (seq_len, d_model) so the learned (max_len, 1, d_model) table
        # broadcasts against x exactly like the fixed buffer does.
        pe = self.pe[:seq_len].reshape(-1, self.pe.size(-1))
        return self.dropout(x + pe)


class OneDAttention(nn.Module):
    """Sparse attention over ``2 * k`` neighbours selected through z-order codes of low-dim q/k.

    Every head projects tokens to ``qk_dim``-dimensional queries and keys (``self.Wqk``)
    and ``head_dim``-dimensional values (``self.Wv``). ``_attn`` pads the sequence to a
    multiple of ``num_chunk``, passes q/k through ``z_order_optimized``, sorts keys with
    ``chunk_causal_sort`` and lets each query gather a window of ``2 * k`` sorted keys
    around its ``searchsorted`` position. The gathered indices and a validity mask are
    handed to ``_sparse_topk_attention``.
    """

    def __init__(self, config, onedattn_config):
        super(OneDAttention, self).__init__()
        self.config = config
        self.n_heads = config.num_attention_heads
        self.d_model = config.hidden_size
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                "The hidden size is not divisble by the number of attention heads! Make sure to update them"
            )
        self.head_dim = self.d_model // self.n_heads
        self.rotary_ndims = 2 # int(self.head_size * config.rotary_pct)
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
        # NOTE(review): self.rotary_emb is created here but never applied in this class.
        self._init_rope()

        # Per-head dimension of the query and key projections.
        self.qk_dim = 3
        self.Wv = nn.Linear(
            self.d_model, self.d_model, bias=config.attention_bias
        )
        self.Wqk = MLP(self.d_model,2*self.qk_dim*self.n_heads)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.is_causal = True

        self.causal = self.is_causal
        self.eps = onedattn_config.eps
        self.k = onedattn_config.k
        if self.causal:
            self.num_chunk = onedattn_config.num_chunks  # aan(retrieval) = 16  # 16 for AAN(retrieval); 4 for listops, image, imdb(text), chatgpt={64, 128, 512} me={1024,2048} associative-recall=8
        else:
            self.num_chunk = 1 # 32? for pathfinder # 1
        self.num_heads = self.n_heads
        self.gamma_sq = nn.Parameter(torch.rand(1))
        self.max_len = 1024
        self.position_enc = PositionalEncoder(config.hidden_size,max_len=self.max_len)

    def _init_bias(self, max_positions, device=None):
        """Register a non-persistent lower-triangular boolean ``bias`` of shape [1, 1, P, P]."""
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        if device is not None:
            self.bias = self.bias.to(device)

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling`` (None, "linear" or "dynamic")."""
        if self.config.rope_scaling is None:
            self.rotary_emb = GPTNeoXRotaryEmbedding(
                self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
                    self.rotary_ndims,
                    self.config.max_position_embeddings,
                    base=self.config.rotary_emb_base,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
                    self.rotary_ndims,
                    self.config.max_position_embeddings,
                    base=self.config.rotary_emb_base,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: torch.FloatTensor,
            position_ids: torch.LongTensor,
            head_mask: Optional[torch.FloatTensor] = None,
            layer_past: Optional[Tuple[torch.Tensor]] = None,
            use_cache: Optional[bool] = False,
            output_attentions: Optional[bool] = False
    ):
        """Run sparse attention over ``hidden_states``.

        Args:
            hidden_states: [batch, seq_len, hidden_size].
            attention_mask, head_mask: forwarded to ``_attn``, which does not use them.
            position_ids: forwarded to ``_attn_projections_and_rope``, which does not use it.
            layer_past: optional ``(past_key, past_value)`` concatenated on the
                sequence axis of the [batch, heads, seq_len, dim] layout.
                NOTE(review): with a non-empty past, q and k lengths differ and
                ``_attn`` presumably cannot handle it — confirm before relying on caching.
            use_cache: if True, ``present`` is ``(key, value)``; otherwise None.
            output_attentions: if True, a third element (always None here) is appended.

        Returns:
            ``(attn_output, present)`` or ``(attn_output, present, attn_weights)`` with
            ``attn_output`` of shape [batch, seq_len, hidden_size].
        """
        # Unpacking also enforces a 3-D input.
        batch_size, original_seq_len, feat_dim = hidden_states.shape
        # Apply attention-specific projections (returned as [batch, heads, seq_len, dim]).
        q, k, v, present = self._attn_projections_and_rope(
            hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
        )
        # _attn expects [batch, seq_len, heads, dim].
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # Compute attention
        attn_output, attn_weights = self._attn(q, k, v, attention_mask, head_mask)
        attn_output = self.dense(rearrange(attn_output, "... h d -> ... (h d)"))

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def _attn_projections_and_rope(
        self,
        hidden_states: torch.FloatTensor,
        position_ids: torch.LongTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
    ):
        """Add positional encodings and project to per-head query, key and value.

        Despite the name, no rotary embedding is applied and ``position_ids`` is
        unused; positions enter only through ``self.position_enc``.

        Returns:
            ``(query, key, value, present)``: query/key are [batch, heads, seq_len, qk_dim]
            (assuming ``Wqk`` outputs ``num_heads * 2 * qk_dim`` features), value is
            [batch, heads, seq_len, head_dim]; ``present`` is ``(key, value)`` when
            ``use_cache`` is True, else None.
        """
        has_layer_past = layer_past is not None

        # [batch, seq_len, hidden_size] -> qk: [batch, seq_len, num_heads * 2 * qk_dim]
        #                                  v:  [batch, seq_len, hidden_size]
        hidden_states = self.position_enc(hidden_states)
        qk, v = self.Wqk(hidden_states), self.Wv(hidden_states)

        # Split heads: v -> [batch, seq_len, num_heads, head_dim],
        # qk -> [batch, seq_len, num_heads, 2 * qk_dim]; q is the first qk_dim features, k the rest.
        v = rearrange(
            v, "... (h d) -> ... h d", d=self.head_dim
        )
        qk = rearrange(
            qk, "... (h d) -> ... h d", h=self.num_heads
        )
        q = qk[..., :self.qk_dim]
        k = qk[..., self.qk_dim:]

        # [batch, seq_len, num_heads, dim] --> [batch, num_heads, seq_len, dim]
        query = q.permute(0, 2, 1, 3)
        key = k.permute(0, 2, 1, 3)
        value = v.permute(0, 2, 1, 3)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        return query, key, value, present

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Sparse z-order neighbourhood attention.

        Args:
            q, k: [bs, seq_len, num_heads, qk_dim] (``forward`` permutes them to this layout).
            v: [bs, seq_len, num_heads, head_dim].
            attention_mask, head_mask: accepted for API compatibility; not used.

        Returns:
            ``(output, None)``: ``output`` is [bs, seq_len, num_heads, ...] with the
            chunk padding stripped; no attention weights are materialised.
        """
        batch_size, seq_len, num_head, feat_dim = v.shape
        num_attn = 2*self.k; self.chunk_size = math.ceil(seq_len / self.num_chunk)
        original_seq_len = seq_len; seq_len = self.num_chunk * self.chunk_size
        pad_len = seq_len - original_seq_len
        qk = torch.cat([q, k], dim=-1)
        # Right-pad the sequence axis (dim 1) with zeros so it splits into num_chunk equal
        # chunks. F.pad keeps the input dtype and device; concatenating default-dtype
        # zeros would promote half/bfloat16 activations to float32.
        qk = F.pad(qk, (0, 0, 0, 0, 0, pad_len))
        v = F.pad(v, (0, 0, 0, 0, 0, pad_len))
        qk = rearrange(
            qk, "b s h d -> (b h) s d", h=num_head
        )
        qk_dim=qk.shape[-1]//2
        q=qk[:,:,:qk_dim]; k=qk[:,:,qk_dim:]

        with torch.no_grad():
            # Each token contributes two consecutive qk_dim-sized rows (its q, then its k).
            # Assuming z_order_optimized yields one code per row, column 0 of the view
            # below holds the query codes and column 1 the key codes.
            qk_z_order = qk.detach().reshape(batch_size*num_head, -1, qk_dim)

            ############ multiple z-order mapping, each feature can be important
            num_z_order = 1

            qk_z_order = z_order_optimized(qk_z_order).view(batch_size*num_head*num_z_order, seq_len,-1)
            ##############################
            q_z_order = qk_z_order[:,:,0]; k_z_order = qk_z_order[:,:,1]

            hash_vals_causal, hash_vals_causal_original_index, chunk_positions = chunk_causal_sort(k_z_order, self.chunk_size)
            q_z_order = q_z_order.view(batch_size*num_head*num_z_order,  seq_len//self.chunk_size, self.chunk_size)
            # Find the queried indices for each element  batch*num_head, num_chunk, chunk_size
            indices_to_keys_val = torch.searchsorted(hash_vals_causal, q_z_order.contiguous())
            # Window of 2*k sorted-key slots around each query's insertion position: offsets -k .. k-1.
            indices_to_keys_range_idx = indices_to_keys_val.view(batch_size*num_head*num_z_order, seq_len, 1) + torch.arange(-self.k, self.k,device=v.device).reshape(1, 1, -1)  # pos to insert query
            chunk_range = chunk_positions.expand(-1,-1, self.chunk_size).reshape(1, seq_len, 1).detach()

            # Slots that fall outside [0, chunk_range - 1] are invalid.
            out_range_mask = (indices_to_keys_range_idx <0) + (indices_to_keys_range_idx>chunk_range-1)
            # Clamp the start and end indices to ensure they are within bounds
            indices_to_keys_range_idx = torch.clamp(indices_to_keys_range_idx, min=0)  # attend to position i and its history
            indices_to_keys_range_idx = torch.clamp(indices_to_keys_range_idx, max=seq_len-1)
            # Map sorted-key slots back to original token positions.
            attended_indices = torch.gather(hash_vals_causal_original_index, -1, indices_to_keys_range_idx.view(batch_size*num_head*num_z_order, self.num_chunk,self.chunk_size*num_attn)).view(batch_size*num_head*num_z_order, self.num_chunk*self.chunk_size, num_attn)
            causal_range = torch.arange(seq_len, device=v.device).reshape(1, -1, 1)
            if self.causal:
                # Also invalidate keys that lie in the future of the query position.
                mask = (attended_indices<0) +  (attended_indices>causal_range)
            else:
                mask = (attended_indices<0)
            # Point invalid slots at index 0 so they remain safe to gather; mask=True marks them.
            attended_indices[mask] = 0
            mask = mask + out_range_mask
        # collect the 2k nearest neighbor using a window to index
        v = rearrange(
            v, "b s h d -> (b h) s d"
        )

        v = v.unsqueeze(1).expand(-1, num_z_order, -1, -1).reshape(batch_size*num_head*num_z_order, seq_len, feat_dim)
        k = k.unsqueeze(1).expand(-1, num_z_order, -1, -1).reshape(batch_size*num_head*num_z_order, seq_len, qk_dim) / ( (8.*qk_dim)**0.5 )

        epsilon = self.eps * torch.sigmoid(self.gamma_sq) #(q)) # learnable converge in 22 epochs! # 0.008 converge in 31 epochs # 0.01
        epsilon = epsilon.repeat(num_head)
        q = q.unsqueeze(1).expand(-1, num_z_order, -1, -1).reshape(batch_size*num_head*num_z_order, seq_len, qk_dim)/ ( (8.*qk_dim)**0.5 )
        output =  _sparse_topk_attention.apply(q, k, v, attended_indices, epsilon, num_head, 32, mask, True)
        output = rearrange(
            output, "(b h) l ... -> b l h ...", b=batch_size, h=self.num_heads, l=seq_len
        )

        # delete all the index and mask tensor to free gpu memory
        del qk_z_order; del q_z_order; del k_z_order
        del hash_vals_causal; del hash_vals_causal_original_index; del chunk_positions
        del indices_to_keys_val; del indices_to_keys_range_idx; del chunk_range; del out_range_mask
        del attended_indices; del mask

        # Strip the chunk padding added above.
        return output[:, :original_seq_len], None

        






class GPTNeoXRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin table provider.

    Precomputes ``cos_cached``/``sin_cached`` of shape ``[max_seq_len_cached, dim]``
    (angles duplicated across both halves) and rebuilds them on demand when a longer
    sequence is requested.
    """

    # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``.

        ``dtype`` is accepted for interface compatibility; the tables take the dtype
        of ``inv_freq``.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: tensor whose device and dtype are used if the cache must grow
                (presumably [bs, num_attention_heads, seq_len, head_size]).
            seq_len: number of positions requested. If None, the whole cached
                table (``max_seq_len_cached`` rows) is returned.
        """
        # Comparing None with an int raises TypeError, so only grow for an explicit length.
        if seq_len is not None and seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len],
        )


# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
# TODO @gante bring compatibility back
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
    """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev

    Positions are divided by ``scaling_factor`` before the rotary angles are computed.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Stored first: the parent constructor immediately calls the overridden
        # _set_cos_sin_cache below, which reads it.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables for ``seq_len`` linearly-scaled positions (``dtype`` is unused)."""
        self.max_seq_len_cached = seq_len
        scaled_positions = (
            torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
            / self.scaling_factor
        )
        angles = torch.outer(scaled_positions, self.inv_freq)
        # Duplicate the angles across both halves of the rotary dimension.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)


class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
    """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    When a sequence longer than ``max_position_embeddings`` is requested, the rotary
    base is enlarged and ``inv_freq`` recomputed before the tables are rebuilt.
    """

    # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
    # TODO @gante no longer copied from
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before the parent constructor, which builds the cache via the
        # overridden _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for ``seq_len`` positions (``dtype`` is unused)."""
        self.max_seq_len_cached = seq_len

        # For dim <= 2, arange(0, dim, 2) is just [0], so inv_freq == [1.0] whatever the
        # base: rescaling is a no-op, and the exponent dim / (dim - 2) would raise
        # ZeroDivisionError for dim == 2.
        if seq_len > self.max_position_embeddings and self.dim > 2:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)


def rotate_half(x):
    """Split the last axis into halves ``(a, b)`` and return ``(-b, a)`` concatenated.

    For an odd size the first half is the shorter one (``size // 2`` elements).
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)


# Adapted from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply the rotary position embedding selected by ``position_ids`` to ``q`` and ``k``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by position on its first axis.
        sin (`torch.Tensor`): Sine table, indexed by position on its first axis.
        position_ids (`torch.Tensor`): Positions of the tokens in ``q``/``k``; can be
            offset, e.g. when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Axis inserted into the
            gathered ``cos``/``sin`` so they broadcast against ``q``/``k``. Use 1 for
            ``[batch, heads, seq, dim]`` inputs and 2 for ``[batch, seq, heads, dim]``.

    Returns:
        `tuple(torch.Tensor)`: ``(q_rotated, k_rotated)``.
    """
    cos_at = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_at) + (rotate_half(t) * sin_at)

    return _rotate(q), _rotate(k)

