import math
from multiprocessing import Pool, cpu_count
from functools import partial
import faiss
import faiss.contrib.torch_utils
import torch
import torch.nn.functional as F
import einops
from torch import nn
from .utils import logging
import psutil

# Module-level logger from the package's own `.utils.logging` helpers (not the stdlib `logging`).
logger = logging.get_logger(__name__)

def get_cpu_memory():
    """Return the summed resident set size (RSS) of all visible processes, in GiB.

    Processes whose memory information cannot be read (permission denied,
    zombie, or already exited) are skipped.

    Returns:
        float: Total RSS of the readable processes divided by 1024**3.
    """
    total_bytes = 0
    for process in psutil.process_iter(["memory_info"]):
        try:
            # process_iter() pre-fetches the requested attributes and stores None
            # (its default ``ad_value``) when access is denied, instead of raising.
            memory_info = process.info.get("memory_info")
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            continue
        if memory_info is None:
            continue
        total_bytes += memory_info.rss

    return total_bytes / (1024 ** 3)


def get_topk_via_faiss(topk_k, query_states, key_databases, kv_heads, kv_groups):
    """Search one FAISS index per (batch*head) row and return scaled top-k scores.

    Row ``i`` of ``query_states`` is searched against ``key_databases[i]``.

    Args:
        topk_k (int): Number of neighbours requested from each index.
        query_states (torch.Tensor): Queries of shape (BH, N_q, D).
        key_databases (list): FAISS indexes, one per row searched.
        kv_heads (int): Accepted for interface compatibility; not used here.
        kv_groups (int): Accepted for interface compatibility; not used here.

    Returns:
        torch.Tensor: Scores of shape (BH, N_q, topk_k), divided by sqrt(D), on the CPU.
        torch.Tensor: int32 key indices of shape (BH, N_q, topk_k), on the CPU.
    """
    num_rows, num_queries, head_dim = query_states.shape
    out_shape = (num_rows, num_queries, topk_k)
    scores = torch.zeros(out_shape)
    labels = torch.zeros(out_shape, dtype=torch.int32)

    for row, index in enumerate(key_databases):
        # FAISS searches on the host in float32.
        row_queries = query_states[row].contiguous().to(torch.float32).cpu()
        row_scores, row_labels = index.search(row_queries, k=topk_k)
        scores[row] = row_scores
        labels[row] = row_labels

    # FAISS yields raw inner products; scale them like attention logits.
    return scores / math.sqrt(head_dim), labels



def search_db(args):
    """Run a single FAISS search and scale the returned scores.

    Designed to be mapped over by a process pool, hence the single tuple argument.

    Args:
        args (tuple): ``(key_db, Q, k, D)``, in that order:
            the FAISS index, the query block, the number of neighbours,
            and the head dimension used for scaling.

    Returns:
        tuple: ``(values, indices)`` from ``key_db.search``. ``values`` is divided by sqrt(D).
    """
    index, queries, num_neighbors, head_dim = args
    scores, labels = index.search(queries.contiguous().to(torch.float32), k=num_neighbors)
    return scores / math.sqrt(head_dim), labels


def get_topk_via_faiss_parallel(topk_k, query_states, key_databases):
    """Retrieve top-k values and indices using FAISS, one search per process.

    Each row ``i`` of ``query_states`` is searched against ``key_databases[i]`` in
    a ``multiprocessing.Pool`` worker (via ``search_db``). ``Pool.map`` pickles
    its arguments, so the FAISS indexes must be picklable.

    Args:
        topk_k (int): Number of top-k values to retrieve.
        query_states (torch.Tensor): Query states tensor of shape (BH, N_q, D).
        key_databases (list): FAISS search indexes; at least BH of them.

    Returns:
        torch.Tensor: Top-k scores divided by sqrt(D), shape (BH, N_q, topk_k), on the CPU.
        torch.Tensor: int32 top-k indices, shape (BH, N_q, topk_k), on the CPU.
    """
    BH, N_q, D = query_states.shape
    faiss_values_tensor = torch.zeros((BH, N_q, topk_k))
    faiss_indices_tensor = torch.zeros((BH, N_q, topk_k), dtype=torch.int32)
    if BH == 0:
        # Nothing to search; also avoids creating a pool with zero workers.
        return faiss_values_tensor, faiss_indices_tensor

    args = [(key_databases[i], query_states[i].detach().cpu(), topk_k, D) for i in range(BH)]
    # Never start more workers than there are searches to run.
    with Pool(min(cpu_count(), BH)) as pool:
        results = pool.map(search_db, args)

    for i, (faiss_v, faiss_i) in enumerate(results):
        # as_tensor accepts numpy arrays or tensors without the copy-construct
        # warning that torch.tensor(<tensor>) emits.
        faiss_values_tensor[i] = torch.as_tensor(faiss_v)
        faiss_indices_tensor[i] = torch.as_tensor(faiss_i)
    return faiss_values_tensor, faiss_indices_tensor

def get_topk_via_knn_cpu(topk_k, query_states, key_states):
    """Retrieve top-k inner-product scores and indices with FAISS brute-force KNN on the CPU.

    Args:
        topk_k (int): Number of top-k values to retrieve per query.
        query_states (torch.Tensor): Query states tensor of shape (BH, N_q, D).
        key_states (torch.Tensor): Key states tensor of shape (BH, N, D).

    Returns:
        torch.Tensor: Top-k scores divided by sqrt(D), shape (BH, N_q, topk_k), on the CPU.
        torch.Tensor: int64 top-k key indices, shape (BH, N_q, topk_k), on the CPU.

    Raises:
        AssertionError: If either input is not 3-dimensional.
    """
    assert len(query_states.shape) == 3, f"Query states should have shape (BH, N_q, D), got {query_states.shape}"
    assert len(key_states.shape) == 3,   f"Key states should have shape (BH, N, D), got {key_states.shape}"
    BH, N_q, D = query_states.shape
    faiss_values_tensor = torch.zeros((BH, N_q, topk_k))
    faiss_indices_tensor = torch.zeros((BH, N_q, topk_k), dtype=torch.int64)
    for i in range(BH):
        # The CPU knn() expects contiguous float32 host tensors, so both the
        # queries and the keys are moved to the CPU (a no-op if already there).
        faiss_values, faiss_indices = faiss.knn(
            query_states[i].contiguous().to(torch.float32).cpu(),
            key_states[i].contiguous().to(torch.float32).cpu(),
            k=topk_k,
            metric=faiss.METRIC_INNER_PRODUCT,
        )
        # as_tensor handles numpy or torch results without an extra copy-construct.
        faiss_values_tensor[i] = torch.as_tensor(faiss_values)
        faiss_indices_tensor[i] = torch.as_tensor(faiss_indices)
    return faiss_values_tensor / math.sqrt(D), faiss_indices_tensor


def get_topk_via_knn(topk_k, query_states, key_states, disable_scratch_memory=False):
    """Retrieve top-k inner-product scores and indices with FAISS GPU brute-force KNN.

    A fresh ``faiss.StandardGpuResources`` is created on every call. Each
    (batch*head) row is searched independently.

    Args:
        topk_k (int): Number of top-k values to retrieve per query.
        query_states (torch.Tensor): Query states tensor of shape (BH, N_q, D).
        key_states (torch.Tensor): Key states tensor of shape (BH, N, D).
        disable_scratch_memory (bool): If True, set FAISS GPU temp memory to 0 bytes
            (logged once as a possible slowdown).

    Returns:
        torch.Tensor: Top-k scores divided by sqrt(D), shape (BH, N_q, topk_k),
            on the queries' device.
        torch.Tensor: int64 top-k key indices, shape (BH, N_q, topk_k), on the queries' device.
    """
    assert len(query_states.shape) == 3, "Query states should have shape (BH, N_q, D)"
    assert len(key_states.shape) == 3,   "Key states should have shape (BH, N, D)"
    num_rows, num_queries, head_dim = query_states.shape

    gpu_resources = faiss.StandardGpuResources()
    if disable_scratch_memory:
        logger.warning_once("Calling Faiss GPU KNN with 0b scratch memory. May be significantly slow.")
        gpu_resources.setTempMemory(0)

    out_shape = (num_rows, num_queries, topk_k)
    out_device = query_states.device
    scores = torch.zeros(out_shape, device=out_device)
    labels = torch.zeros(out_shape, dtype=torch.int64, device=out_device)

    # TODO parallelize?
    for row in range(num_rows):
        row_queries = query_states[row].contiguous().to(torch.float32)
        row_keys = key_states[row].contiguous().to(torch.float32)
        scores[row], labels[row] = faiss.knn_gpu(
            gpu_resources,
            row_queries,
            row_keys,
            k=topk_k,
            metric=faiss.METRIC_INNER_PRODUCT,
        )

    return scores / math.sqrt(head_dim), labels


def create_sparse_matrix(topk_values, topk_indices, N_k, mask=True, mask_offset=None):
    """Scatter top-k scores into a sparse (BH, N_q, N_k) COO tensor.

    Entry ``[b, q, topk_indices[b, q, j]]`` receives ``topk_values[b, q, j]``.

    Args:
        topk_values (torch.Tensor): Scores of shape (BH, N_q, k).
        topk_indices (torch.Tensor): Key (column) index of every score, shape (BH, N_q, k).
        N_k (int): Number of keys, i.e. the width of the dense equivalent.
        mask (bool): If True, apply a causal mask. Entries whose key index exceeds
            ``query_row + mask_offset`` are dropped.
        mask_offset (int, optional): Shift added to the query row before the causal
            comparison. ``None`` means 0.

    Returns:
        torch.Tensor: Sparse COO tensor of shape (BH, N_q, N_k), not coalesced.
    """
    num_rows, num_queries, k = topk_indices.shape
    device = topk_indices.device
    grid = (num_rows, num_queries, k)

    # Build (batch, query, key) coordinates for every selected entry, in the same
    # row-major order as the flattened values.
    batch_idx = torch.arange(num_rows, device=device).view(-1, 1, 1).expand(grid)
    query_idx = torch.arange(num_queries, device=device).view(1, -1, 1).expand(grid)
    coords = torch.stack((batch_idx, query_idx, topk_indices.long())).reshape(3, -1)
    flat_values = topk_values.reshape(-1)

    if mask:
        offset = 0 if mask_offset is None else mask_offset
        # This hard-codes in the attention mask
        keep = coords[1] + offset >= coords[2]
        coords = coords[:, keep]
        flat_values = flat_values[keep]

    return torch.sparse_coo_tensor(coords, flat_values, (num_rows, num_queries, N_k))


def gather_and_get_output(topk_k, topk_values, topk_indices, value_states, B, H, N, D):
    """Compute causal top-k attention output by gathering the selected value rows.

    Selected keys whose index is greater than the query position are masked out.
    The remaining scores are softmax-normalised and used to weight the gathered
    value vectors.

    Args:
        topk_k (int): Number of selected keys per query (last dim of the top-k tensors).
        topk_values (torch.Tensor): Scores of the selected keys, shape (B*H, N, topk_k).
            Not modified.
        topk_indices (torch.Tensor): Key indices of the selected scores, shape (B*H, N, topk_k).
            Any integer dtype is accepted.
        value_states (torch.Tensor): Value states, shape (B*H, N_k, D).
        B (int): Batch size.
        H (int): Number of heads.
        N (int): Number of queries.
        D (int): Dimension of the embeddings.

    Returns:
        torch.Tensor: Output tensor of shape (B, H, N, D) in ``value_states``' dtype.
    """
    BH = B * H
    # Query n may only attend to keys at positions <= n. Later keys get -inf.
    # masked_fill returns a new tensor, so the caller's topk_values is untouched.
    query_pos = torch.arange(N, device=topk_indices.device).view(1, N, 1)
    scores = topk_values.masked_fill(topk_indices > query_pos, float("-inf"))
    # Rows whose selected keys all lie in the future softmax to NaN; zero them.
    weights = torch.nan_to_num(F.softmax(scores, dim=-1))

    # Gather the selected value rows directly, without repeating value_states k times:
    # (BH, N*k, D) -> (BH, N, k, D).
    gather_idx = (
        topk_indices.reshape(BH, N * topk_k, 1)
        .expand(-1, -1, D)
        .to(device=value_states.device, dtype=torch.int64)
    )
    gathered = torch.gather(value_states, 1, gather_idx).view(BH, N, topk_k, D)

    xhat = (gathered * weights.to(value_states.device).unsqueeze(-1)).sum(dim=2)
    return xhat.reshape(B, H, N, D).to(value_states.dtype)


def loop_and_get_output(topk_k, topk_values, topk_indices, value_states, B, H, N, D):
    """Loop-based counterpart of ``gather_and_get_output``. Not implemented.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


# ALWAYS returns a tensor of size (B, H, N_q, D)
def topk_output_with_database(
    topk_k,
    query_states,
    key_states,
    value_states,
    key_database,
    value_cache,
    B,
    H,
    head_dim,
    mode=None,
):
    """Compute attention output using exact top-k key selection (FAISS GPU KNN).

    The top-k keys per query are found among ``key_states``. The causally-masked
    sparse score matrix is then softmaxed and applied to ``value_states``.
    ``key_database``, ``value_cache``, ``head_dim`` and ``mode`` are accepted for
    interface compatibility but are not read.

    Args:
        topk_k (int): Number of top-k keys per query. Clamped to the number of keys.
        query_states (torch.Tensor): Queries, (B, H, N_q, D) or (B*H, N_q, D).
        key_states (torch.Tensor): Keys, same layout as ``query_states``.
        value_states (torch.Tensor): Values, same layout as ``query_states``.
        key_database (list): Unused.
        value_cache (torch.Tensor): Unused.
        B (int): Batch size.
        H (int): Number of heads.
        head_dim (int): Unused.
        mode (str, optional): Unused. Defaults to None.

    Returns:
        torch.Tensor: Output tensor of shape (B, H, N_q, D) in the queries' dtype.
    """
    if len(query_states.shape) == 4:
        query_states = einops.rearrange(query_states, "B H N D -> (B H) N D")
        value_states = einops.rearrange(value_states, "B H N D -> (B H) N D")
        key_states = einops.rearrange(key_states, "B H N D -> (B H) N D")

    N_k = key_states.shape[-2]
    # Asking FAISS for more neighbours than there are keys cannot produce valid
    # indices for the sparse matrix. Clamp as topk_attn / block_topk_attn do.
    topk_k = min(topk_k, N_k)

    # Get topk attention matrix
    topk_values, topk_indices = get_topk_via_knn(
        topk_k, query_states.detach(), key_states.detach()
    )
    score_sparse = create_sparse_matrix(
        topk_values, topk_indices, N_k
    )  # BH, N_q, N_k in sparse format

    # Keep the softmax output in float32. The bmm below needs float32 anyway:
    # torch.bmm does not support sparse cuda and float16/bfloat16,
    # see https://github.com/pytorch/pytorch/issues/80574
    attn_sparse = torch.sparse.softmax(score_sparse, dtype=torch.float32, dim=-1).to(
        value_states.device
    )
    # Potential for nans since we mask after selecting topk. These scores are set to 0.
    if torch.any(torch.isnan(attn_sparse)).item():
        attn_sparse = torch.nan_to_num(attn_sparse)

    xhat = torch.bmm(attn_sparse, value_states.to(torch.float32)).to(query_states.dtype)
    return einops.rearrange(xhat, "(b h) n d -> b h n d", b=B, h=H)


# def topk_construct_output(topk_k, query_states, key_database, value_states, B, H, kv_heads, kv_groups):
#     """Gets output of attention block in construct phase of model

#     Args:
#         topk_k (int): Number of top-k entries of score matrix
#         query_states (torch.Tensor): Query states tensor.
#         key_states (torch.Tensor): Key states tensor.
#         value_states (torch.Tensor): Value states tensor.
#         B (int): Batch size.
#         H (int): Number of heads.

#     Returns:
#         torch.Tensor: Output tensor of shape (B, H, N_q, D).
#     """
#     if len(query_states.shape) == 4:
#         _, _, N_q, D = query_states.shape
#         BH = B * H
#         query_states = einops.rearrange(query_states, "B H N D -> (B H) N D")
#         value_states = einops.rearrange(value_states, "B H N D -> (B H) N D")
#     else:
#         BH, N_q, D = query_states.shape

#     N_k = key_database[0].ntotal
#     assert (
#         topk_k <= N_k
#     ), f"Number of topk is larger than N. Expected topk_k <= {N_k}, got topk_k = {topk_k}"
#     # Get topk attention matrix
#     topk_values, topk_indices = get_topk_via_faiss(topk_k, query_states, key_database, kv_heads, kv_groups)
#     # return gather_and_get_output(topk_values.to(value_states.device), topk_indices.to(value_states.device), value_states, B, H, N_q, D, k)
#     score_sparse = create_sparse_matrix(
#         topk_values, topk_indices, N_k
#     )  # BH, N_q, N_k in sparse format
#     attn_sparse = (
#         torch.sparse.softmax(score_sparse, dtype=torch.float32, dim=-1)
#         .to(query_states.dtype)
#         .to(value_states.device)
#     )
#     # Potential for nans since we mask after selecting topk. These scores are set to 0.
#     if torch.any(torch.isnan(attn_sparse)).item():
#         attn_sparse = torch.nan_to_num(attn_sparse)
#     # no bmm_sparse_cuda kernel for bfloat16, so must cast to float32
#     # see https://github.com/pytorch/pytorch/issues/80574
#     xhat = (
#         torch.bmm(attn_sparse.to(torch.float32), value_states.to(torch.float32))
#         .to(query_states.dtype)
#         .cuda()
#     )
#     xhat = einops.rearrange(xhat, "(b h) n d -> b h n d", b=B, h=H)
#     return xhat


def topk_attn(
    topk_k, 
    query_states, 
    suffix_key_states, 
    suffix_value_states, 
    prefix_key_db, 
    prefix_value_states, 
    B, 
    H, 
    kv_heads, 
    kv_groups, 
    num_prev_seen_tokens=0,
    construct_mode=False
):
    """Compute attention output using FAISS top-k retrieval over a cached prefix.

    Construct mode: the queries attend only to their top-k prefix keys from
    ``prefix_key_db``. These are causally masked, with ``num_prev_seen_tokens``
    as the query offset.

    Query/generation mode: the queries attend jointly to their top-k prefix keys
    (unmasked) and, densely and causally, to the suffix keys. Both parts share a
    single softmax denominator.

    Args:
        topk_k (int): Prefix keys retrieved per query. Clamped to the database size.
        query_states (torch.Tensor): Queries, (B, H, N_q, D) or (B*H, N_q, D).
        suffix_key_states (torch.Tensor): Suffix keys, (B, H, N_s, D). Required
            unless ``construct_mode``.
        suffix_value_states (torch.Tensor): Suffix values, (B, H, N_s, D). Required
            unless ``construct_mode``.
        prefix_key_db (list): FAISS indexes over the prefix keys, one per (batch*head) row.
        prefix_value_states (torch.Tensor): Prefix values, same layout as ``query_states``.
        B (int): Batch size.
        H (int): Number of heads.
        kv_heads (int): Passed through to ``get_topk_via_faiss``.
        kv_groups (int): Passed through to ``get_topk_via_faiss``.
        num_prev_seen_tokens (int): Causal offset used in construct mode.
        construct_mode (bool): Attend to the prefix only (see above).

    Returns:
        torch.Tensor: Output of shape (B, H, N_q, D) on the queries' device and in their dtype.

    Raises:
        ValueError: If the suffix states cannot be reshaped in query/generation mode.
    """
    if len(query_states.shape) == 4:
        query_states = einops.rearrange(query_states, "B H N D -> (B H) N D")
        prefix_value_states = einops.rearrange(prefix_value_states, "B H N D -> (B H) N D")
    _, N_q, D = query_states.shape
    out_device, out_dtype = query_states.device, query_states.dtype

    N_k_prefix = prefix_key_db[0].ntotal
    topk_k = min(topk_k, N_k_prefix)

    # Scaled prefix scores and prefix-key indices, both (BH, N_q, topk_k), on the CPU.
    topk_values, topk_indices = get_topk_via_faiss(topk_k, query_states, prefix_key_db, kv_heads, kv_groups)

    suffix_out = None
    if construct_mode:
        score_sparse = create_sparse_matrix(
            topk_values, topk_indices, N_k_prefix, mask_offset=num_prev_seen_tokens
        )  # BH, N_q, N_k in sparse format
        attn_sparse = torch.sparse.softmax(score_sparse, dtype=torch.float32, dim=-1)
    else:
        try:
            suffix_value_states = einops.rearrange(suffix_value_states, "B H N D -> (B H) N D")
            suffix_key_states = einops.rearrange(suffix_key_states, "B H N D -> (B H) N D")
        except (einops.EinopsError, RuntimeError) as e:
            msg = f"Suffix key/value states must not be empty in query/generate mode. Got suffix_key_states: {suffix_key_states.shape if not suffix_key_states is None else type(suffix_key_states)}."
            raise ValueError(msg) from e

        # Dense scores of the queries against the suffix keys, computed in float32.
        score_dense = (query_states.to(torch.float32) @ suffix_key_states.to(torch.float32).mT) / math.sqrt(D)
        # Don't mask if there is only one token
        if N_q > 1:
            future = torch.ones(score_dense.shape[-2:], dtype=torch.bool, device=score_dense.device).triu(1)
            score_dense = score_dense.masked_fill(future, float("-inf"))

        prefix_scores = topk_values.to(device=score_dense.device, dtype=torch.float32)
        # Shift every row by its largest score (prefix or suffix) before exp() so the
        # joint softmax cannot overflow. The shift cancels in the normalisation.
        row_max = prefix_scores.amax(dim=-1)
        if score_dense.shape[-1] > 0:
            row_max = torch.maximum(row_max, score_dense.amax(dim=-1))
        row_max = row_max.unsqueeze(-1)
        prefix_exp = torch.exp(prefix_scores - row_max)
        dense_exp = torch.exp(score_dense - row_max)
        denominator = prefix_exp.sum(dim=-1, keepdim=True) + dense_exp.sum(dim=-1, keepdim=True)

        # The prefix is entirely before the suffix, so no causal mask is needed here.
        attn_sparse = create_sparse_matrix(
            (prefix_exp / denominator).to(topk_indices.device),
            topk_indices,
            N_k_prefix,
            mask=False,
        )
        attn_dense = dense_exp / denominator
        suffix_out = attn_dense.to(suffix_value_states.dtype) @ suffix_value_states

    attn_sparse = attn_sparse.to(prefix_value_states.device)
    # Potential for nans since we mask after selecting topk. These scores are set to 0.
    if torch.any(torch.isnan(attn_sparse)).item():
        attn_sparse = torch.nan_to_num(attn_sparse)
    # no bmm_sparse_cuda kernel for bfloat16, so must compute in float32
    # see https://github.com/pytorch/pytorch/issues/80574
    xhat = torch.bmm(attn_sparse, prefix_value_states.to(torch.float32))
    xhat = xhat.to(device=out_device, dtype=out_dtype)
    if suffix_out is not None:
        xhat = xhat + suffix_out.to(device=out_device, dtype=out_dtype)
    return einops.rearrange(xhat, "(b h) n d -> b h n d", b=B, h=H)


def naive_topk_attn(config, attn_weights, query_states, key_states, value_states):
    """Mask all but the top-k scores per row of dense attention weights, and log stats.

    Args:
        config: Object providing ``topk`` and the ``save_indices``, ``save_contexts``,
            ``save_attn_last``, ``save_attn_agg`` and ``save_hist`` flags. It must
            also provide a ``scratchpad`` dict that already contains ``"counter"``
            (and a ``"context_lengths"`` list when ``save_contexts`` is set).
        attn_weights (torch.Tensor): Attention scores; the last dim indexes keys.
            Must be 4-D (b, h, n, n) when ``save_attn_last``/``save_attn_agg`` are set.
        query_states: Unused.
        key_states (torch.Tensor): Only ``size(-2)`` is read, when ``save_contexts`` is set.
        value_states: Unused.

    Returns:
        torch.Tensor: ``attn_weights`` with every entry outside the per-row top-k
        set to -inf. Returned unchanged when ``config.topk == 0``.

    Raises:
        ValueError: If ``save_attn_last`` is set and the layer count is not on ``config``.
    """
    ### Actual attn stuff: ####
    k = config.topk
    indices = None  # stays None when top-k masking is disabled (k == 0)
    if k != 0:
        k = min(k, attn_weights.size(-1))
        _, indices = torch.topk(attn_weights, k, dim=-1)
        topk_mask = torch.full_like(attn_weights, float("-inf"))
        topk_mask.scatter_(dim=-1, index=indices, value=0)
        attn_weights = attn_weights + topk_mask
    ############################

    ### Logging stuff: #########
    scratchpad = config.scratchpad

    # Save top k indices (there are none to save when top-k is disabled)
    if config.save_indices and indices is not None:
        if "indices" not in scratchpad:
            scratchpad["indices"] = indices
        else:
            scratchpad["indices"] = torch.cat([scratchpad["indices"], indices], dim=-1)

    # Save a list of all context lengths
    if config.save_contexts:
        scratchpad["context_lengths"].append(key_states.size(-2))

    # Save the first attention map, and the one produced when the counter equals the layer count
    if config.save_attn_last:
        if hasattr(config, "n_layer"):
            num_layers = config.n_layer
            logger.info("The number of layers was in the 'n_layer' attribute")
        elif hasattr(config, "num_hidden_layers"):
            num_layers = config.num_hidden_layers
            logger.info("The number of layers was in the 'num_hidden_layers' attribute")
        else:
            raise ValueError(
                "Number of hidden layers couldn't be determined from the config. Try debugging to see where this info might be."
            )
        if scratchpad["counter"] == 0:
            scratchpad["attn_first"] = attn_weights[0, 0, :, :]
        if scratchpad["counter"] == num_layers:
            scratchpad["attn_last"] = attn_weights[0, 0, :, :]

    # Save an aggregate of all the attention maps that get produced during the run
    if config.save_attn_agg:
        # Sort the values so that when we average them it makes sense.
        new_attn, _ = torch.sort(attn_weights, dim=-1, descending=True)
        assert (
            new_attn.dim() == 4
        ), f"attn mx should have dimensions like (b, h, n, n) but this has {new_attn.shape}"
        new_attn = new_attn.mean(dim=(0, 1))
        if scratchpad.get("attn_agg") is None:
            scratchpad["attn_agg"] = new_attn
        else:
            # Add the new map to the aggregate, zero-padding the smaller of the two.
            old_attn = scratchpad["attn_agg"]
            assert (
                new_attn.shape[-1] == new_attn.shape[-2]
            ), f"Attn mx is not square: {new_attn.shape}"

            # Pad the top and right side of the smaller matrix with zeros
            # (see ordering in pad() call)
            max_dim = max(old_attn.shape[-1], new_attn.shape[-1])
            if old_attn.shape[-1] < max_dim:
                pad_amount = max_dim - old_attn.shape[-1]
                old_attn = nn.functional.pad(old_attn, pad=(0, pad_amount, pad_amount, 0))
            elif new_attn.shape[-1] < max_dim:
                pad_amount = max_dim - new_attn.shape[-1]
                new_attn = nn.functional.pad(new_attn, pad=(0, pad_amount, pad_amount, 0))

            scratchpad["attn_agg"] = old_attn + new_attn

    # Accumulate a 100-bin histogram of the weights over [0, 1] across calls
    if config.save_hist:
        hist = torch.histc(attn_weights.flatten(), bins=100, min=0, max=1)
        if scratchpad.get("hist") is None:
            scratchpad["hist"] = hist
        else:
            scratchpad["hist"] += hist

    scratchpad["counter"] += 1
    #########################
    return attn_weights

def block_topk_attn(
    topk_k, 
    query_states, 
    sparse_keys,
    keys,
    sparse_values,
    values,
    B, 
    H, 
    num_prev_seen_tokens=0,
    use_topk=True
):
    """Attention over a dense key block plus an optional sparse key cache.

    Each query attends densely (and causally) to ``keys``. When the sparse cache
    is non-empty, it also attends to ``sparse_keys``: either to its top-k keys
    found with FAISS on the CPU (``use_topk=True``) or to every sparse key. Both
    parts share one softmax normalisation.

    Args:
        topk_k (int): Sparse keys retrieved per query. Clamped to the cache size.
        query_states (torch.Tensor): Queries, (B, H, N_q, D) or (B*H, N_q, D).
        sparse_keys (torch.Tensor): Cached sparse keys, (B, H, N_sparse, D). May be empty.
        keys (torch.Tensor): Dense keys, same layout as ``query_states``.
        sparse_values (torch.Tensor): Values matching ``sparse_keys``.
        values (torch.Tensor): Values matching ``keys``.
        B (int): Batch size.
        H (int): Number of heads.
        num_prev_seen_tokens (int): Causal offset. With more than one query, query i
            may attend to dense key j only when j <= i + num_prev_seen_tokens.
            A single query attends to every dense key.
        use_topk (bool): Retrieve only the top-k sparse keys instead of using all of them.

    Returns:
        torch.Tensor: Output tensor of shape (B, H, N_q, D) in the queries' dtype and device.
    """
    def _merge_heads(x):
        # (B, H, N, D) -> (B*H, N, D); already-merged 3-D inputs pass through.
        return x.reshape(-1, *x.shape[-2:]) if x.dim() == 4 else x

    query_states = _merge_heads(query_states)
    keys = _merge_heads(keys)
    values = _merge_heads(values)
    BH, N_q, D = query_states.shape
    dtype, device = query_states.dtype, query_states.device
    N_k_dense = keys.shape[-2]
    scale = math.sqrt(D)

    # Dense scores, computed in float32.
    dense_score = (query_states.to(torch.float32) @ keys.to(torch.float32).mT) / scale
    if N_q > 1:
        future = torch.ones((N_q, N_k_dense), dtype=torch.bool, device=dense_score.device)
        dense_score = dense_score.masked_fill(future.triu(num_prev_seen_tokens + 1), float("-inf"))
    # A single query (one generation step) is not masked, matching topk_attn.

    N_k_sparse = sparse_keys.shape[-2] if sparse_keys.nelement() > 0 else 0
    sparse_score = None
    # Get sparse part if cache is nonempty
    if N_k_sparse > 0:
        topk_k = min(topk_k, N_k_sparse)
        sparse_keys = _merge_heads(sparse_keys)
        sparse_values = _merge_heads(sparse_values)
        if use_topk:
            # Already divided by sqrt(D); both are (BH, N_q, topk_k) on the CPU.
            sparse_score, topk_indices = get_topk_via_knn_cpu(topk_k, query_states, sparse_keys)
        else:
            # Scaled by 1/sqrt(D), consistent with the dense and top-k scores.
            sparse_score = (
                query_states.cpu().to(torch.float32) @ sparse_keys.cpu().to(torch.float32).mT
            ) / scale

    # Shift each row by its largest score before exp(). The shift cancels in the
    # normalisation but keeps exp() from overflowing.
    row_max = torch.full((BH, N_q), float("-inf"), device=dense_score.device)
    if N_k_dense > 0:
        row_max = dense_score.amax(dim=-1)
    if sparse_score is not None and sparse_score.shape[-1] > 0:
        row_max = torch.maximum(row_max, sparse_score.amax(dim=-1).to(row_max.device))
    # Avoid (-inf) - (-inf) = NaN for rows with no finite score at all.
    row_max = torch.where(torch.isfinite(row_max), row_max, torch.zeros_like(row_max)).unsqueeze(-1)

    dense_exp = torch.exp(dense_score - row_max)
    denominator = dense_exp.sum(dim=-1, keepdim=True)
    if sparse_score is not None:
        sparse_exp = torch.exp(sparse_score.to(torch.float32) - row_max.to(sparse_score.device))
        denominator = denominator + sparse_exp.sum(dim=-1, keepdim=True).to(denominator.device)

    xhat = (dense_exp / denominator) @ values.to(device=dense_exp.device, dtype=torch.float32)

    if sparse_score is not None:
        sparse_probs = sparse_exp / denominator.to(sparse_exp.device)
        sparse_values_f32 = sparse_values.to(device=sparse_probs.device, dtype=torch.float32)
        if use_topk:
            attn_sparse = create_sparse_matrix(sparse_probs, topk_indices, N_k_sparse, mask=False)
            # Potential for nans; these scores are set to 0.
            if torch.any(torch.isnan(attn_sparse)).item():
                attn_sparse = torch.nan_to_num(attn_sparse)
            # no bmm_sparse_cuda kernel for bfloat16, so compute in float32
            # see https://github.com/pytorch/pytorch/issues/80574
            sparse_out = torch.bmm(attn_sparse, sparse_values_f32)
        else:
            sparse_out = sparse_probs @ sparse_values_f32
        xhat = xhat + sparse_out.to(xhat.device)

    xhat = xhat.to(device=device, dtype=dtype)
    return xhat.reshape(B, H, N_q, D)
