import torch
from sparse import fwd_sparse
import pca_topk as G

def loki_attention(query, key, value, attn_mask, Out, top_r, top_k):
    """Select the ``top_k`` highest-scoring keys per head and run sparse attention.

    Scores are obtained from ``G.topr_bmv_optimized`` (called with ``r=top_r``),
    the ``top_k`` best key positions per (batch, head) are picked with
    ``torch.topk``, and those indices are handed to ``fwd_sparse`` together
    with the untouched ``query``/``key``/``value`` tensors.

    Args:
        query: 4-D tensor unpacked as ``(bsz, q_len, num_head, head_dim)``.
            The reshape to ``(bsz * num_head, 1, head_dim)`` only succeeds
            when ``q_len == 1`` (single-token decoding).
        key: Tensor whose leading dimension is ``bsz * kv_len``; it is viewed
            as ``(bsz, kv_len, num_head, head_dim)``.
        value: Forwarded unchanged to ``fwd_sparse``.
            NOTE(review): presumably laid out like ``key`` -- confirm.
        attn_mask: Forwarded unchanged to ``fwd_sparse``.
        Out: Forwarded to ``fwd_sparse``.
            NOTE(review): presumably the output buffer written in place --
            confirm against ``fwd_sparse``.
        top_r: Passed as ``r`` to ``G.topr_bmv_optimized``.
            NOTE(review): presumably the number of leading feature
            dimensions used for the approximate scores -- confirm.
        top_k: Number of key positions kept per (batch, head); must not
            exceed ``kv_len`` or ``torch.topk`` raises.

    Returns:
        None. Nothing is returned by this function itself.
    """
    # NOTE(review): no PCA projection is applied here; query/key are
    # presumably already expressed in the PCA basis -- confirm with callers.
    bsz, _, num_head, head_dim = query.shape
    kv_len = key.shape[0] // bsz

    # Regroup keys per head, then transpose the last two axes so that each
    # (batch, head) slice has shape (head_dim, kv_len) for the batched
    # matrix-vector product below.
    keys_by_head = key.view(bsz, -1, num_head, head_dim).transpose(1, 2)
    keys_t = keys_by_head.reshape(bsz * num_head, kv_len, head_dim).transpose(-1, -2)

    # One query row per (batch, head).
    query_rows = query.view(bsz * num_head, 1, head_dim)
    approx_scores = G.topr_bmv_optimized(A=query_rows, B=keys_t, r=top_r)

    # Only the positions of the best-scoring keys are needed, not the scores.
    per_head_scores = approx_scores.view(bsz, num_head, kv_len)
    selected_index = torch.topk(per_head_scores, top_k, dim=-1).indices

    fwd_sparse(query.squeeze(1), key, value, Out, selected_index, attn_mask)

