import os
import math
import torch
import time
from torch import nn
import numpy as np
from flash_attn import flash_attn_with_kvcache
from flash_attn import flash_attn_func



def retr_attn_v1_prefill_attn(query_states, key_states, value_states, causal):
    """Prefill attention over the full key/value tensors via flash-attention.

    Forwards the query and the key/value tensors to ``flash_attn_with_kvcache``
    (as ``q``, ``k_cache`` and ``v_cache``) and returns its output unchanged.

    Args:
        query_states: query tensor passed as ``q``.
        key_states: key tensor passed as ``k_cache``.
        value_states: value tensor passed as ``v_cache``.
        causal: whether flash-attention applies a causal mask.

    Returns:
        Whatever ``flash_attn_with_kvcache`` returns for these arguments.
    """
    return flash_attn_with_kvcache(
        q=query_states,
        k_cache=key_states,
        v_cache=value_states,
        causal=causal,
    )


def _retr_attn_v1_search_topk(query_np, layer_idx, cache):
    """Query the layer's retrieval index and return ``(distances, indices)``.

    Args:
        query_np: NumPy query array indexed as (batch, num_heads, q_len, head_dim);
            only batch element 0 is searched.
        layer_idx: index of the layer whose retrieval index is queried.
        cache: retrieval cache providing ``key_cache``, ``index_type``, ``topk``,
            ``group_size`` and (for "RAIndex") ``ra_search_params``.

    Returns:
        ``(distances, indices)`` NumPy arrays as produced by the selected backend;
        the caller reshapes them to (..., num_heads, q_len, topk).

    Raises:
        ValueError: if ``cache.index_type`` is not "Flat", "IVF", "RAIndex" or "HNSW".
    """
    layer_index = cache.key_cache[layer_idx]
    index_type = cache.index_type
    topk = cache.topk
    num_q_heads = query_np.shape[1]
    q_len = query_np.shape[2]
    # os.cpu_count() returns None when the CPU count cannot be determined.
    num_threads = os.cpu_count() or 1

    if index_type == "Flat":
        return layer_index.search(query_np[0], cache.group_size, topk, num_threads)

    if index_type == "IVF":
        return layer_index.search(query_np[0], layer_idx, cache.group_size, topk, num_threads)

    if index_type == "RAIndex":
        search_l = cache.ra_search_params[layer_idx]
        # Output buffers are filled in place by the binding.
        distances = np.zeros((num_q_heads, q_len, topk), dtype=np.float32)
        indices = np.zeros((num_q_heads, q_len, topk), dtype=np.uint32)
        cmps_arr = np.zeros((num_q_heads,), dtype=np.uint32)
        search_fn = (
            layer_index.searchRAIndexGetCmpsSQ
            if layer_index.is_use_sq()
            else layer_index.searchRAIndexGetCmps
        )
        # NOTE(review): the literal 12 is forwarded unchanged; its meaning is
        # defined by the RAIndex binding (presumably a thread count) - confirm.
        search_fn(query_np[0].astype(np.float32), topk, search_l, indices, distances, q_len, 12, cmps_arr)
        return distances, indices

    if index_type == "HNSW":
        distances = np.zeros((num_q_heads, q_len, topk), dtype=np.float32)
        indices = np.zeros((num_q_heads, q_len, topk), dtype=np.uint32)
        for head_idx in range(num_q_heads):
            # Query heads share a key index in groups of ``group_size``.
            key_head_index = layer_index[head_idx // cache.group_size]
            # Only query position 0 is searched; the result is broadcast over q_len.
            labels, dist = key_head_index.knn_query(data=query_np[0, head_idx, 0], k=topk)
            indices[head_idx] = labels
            # presumably an inner-product space where distance = 1 - <q, k>;
            # this converts back to a similarity score - confirm.
            distances[head_idx] = 1 - dist[0]
        return distances, indices

    raise ValueError(f"Unsupported index type: {cache.index_type}")


def retr_attn_v1_decode_attn(query_states, key_states, value_states, layer_idx, retr_attn_v1_cache):
    """Decode-step attention merging retrieved top-k keys with a static GPU window.

    The query attends to two key sets, and the partial results are merged by
    their log-sum-exp:

    (a) the ``topk`` keys returned by the layer's retrieval index, with values
        gathered through ``value_cache.gather_value``;
    (b) the first ``static_pattern_total`` keys (``+ 1`` for every layer except
        the last) of the GPU key/value caches, via ``flash_attn_func``.

    Args:
        query_states: query in flash-attention layout (batch, q_len, num_heads,
            head_dim). Only batch element 0 is searched and retrieval indices
            drop the batch axis, so batch size 1 is effectively assumed.
        key_states: unused here; kept for interface compatibility.
        value_states: unused here; kept for interface compatibility.
        layer_idx: index of the current layer.
        retr_attn_v1_cache: cache holding the retrieval indexes, value store,
            GPU key/value caches, device mapping and configuration.

    Returns:
        Contiguous float16 tensor of shape (batch, q_len, num_heads, head_dim).

    Raises:
        ValueError: if ``retr_attn_v1_cache.index_type`` is unsupported.
    """
    cache = retr_attn_v1_cache
    topk = cache.topk
    device = cache.layer_mapping[str(layer_idx)]

    # Defensive copy for flash-attention below: the NumPy array built next can
    # alias ``query_states`` (when it is already a contiguous CPU tensor) and is
    # handed to external search bindings.
    query_backup = query_states.clone()

    # (batch, q_len, heads, dim) -> (batch, heads, q_len, dim) on the host.
    # NOTE(review): Tensor.numpy() rejects bfloat16, so this path presumably
    # requires float16/float32 queries.
    query_np = query_states.transpose(1, 2).contiguous().detach().cpu().numpy()
    bsz, num_q_heads, q_len, head_dim = query_np.shape

    distances, indices = _retr_attn_v1_search_topk(query_np, layer_idx, cache)

    # Scaled dot-product scores of the retrieved keys: (batch, heads, q_len, topk).
    scores = torch.from_numpy(
        distances.reshape((bsz, num_q_heads, q_len, topk)) / math.sqrt(head_dim)
    )
    retrieval_lse = torch.logsumexp(scores, dim=-1).to(device)
    retrieval_probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)

    # Gather the retrieved value vectors into a (batch * kv-heads, topk, dim) buffer.
    prefetch_idx = torch.from_numpy(
        indices.reshape((num_q_heads, q_len, topk)).astype(np.int32)
    ).reshape(-1, topk)
    selected_v = torch.zeros(
        size=(cache.batch_size * cache.num_heads, topk, cache.head_dim),
        dtype=torch.float16,
    )
    # NOTE(review): the literal 16 is forwarded unchanged; its meaning is
    # defined by gather_value (presumably a thread count) - confirm.
    cache.value_cache.gather_value(layer_idx, 16, selected_v, prefetch_idx)
    selected_v = selected_v.reshape(cache.batch_size, cache.num_heads, topk, cache.head_dim)

    retrieval_out = torch.matmul(retrieval_probs, selected_v.to(torch.float))
    retrieval_out = retrieval_out.to(device).to(torch.float16)

    # Exact attention over the static key window kept on the GPU.
    static_pattern_total = (
        cache.static_pattern_total
        if layer_idx == cache.layer_num - 1
        else cache.static_pattern_total + 1
    )
    flash_out, flash_lse, _ = flash_attn_func(
        query_backup,
        cache.gpu_key_cache[layer_idx][:, :, :static_pattern_total, :],
        cache.gpu_value_cache[layer_idx][:, :, :static_pattern_total, :],
        dropout_p=0,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=True,
    )

    # Bring both LSEs from (batch, heads, q_len) to (batch, q_len, heads, 1)
    # so they broadcast against the (batch, q_len, heads, dim) outputs.
    flash_lse = flash_lse.transpose(-2, -1).unsqueeze(dim=-1)
    retrieval_lse = retrieval_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # log(exp(a) + exp(b)) computed stably. The previous form
    # a + log(1 + exp(b - a)) overflowed to inf once b - a exceeded ~88 in
    # float32, zeroing both merge weights (and gave NaN for a == -inf).
    new_lse = torch.logaddexp(retrieval_lse, flash_lse)

    final_out = (
        torch.exp(retrieval_lse - new_lse) * retrieval_out.transpose(1, 2)
        + torch.exp(flash_lse - new_lse) * flash_out
    )
    return final_out.to(torch.float16).contiguous()

