import torch
from compression.kernel.fused_recover_fa import fused_recover_fa, quant_and_pack_vcache, gen_embeds, _flash_attention_forward
from compression.kernel.loki_attention import loki_attention
from flash_attn import flash_attn_func

import math

from transformers.models.llama.modeling_llama import repeat_kv

def compute_mean_variance(seq):
    """Return the mean and an outlier-trimmed standard deviation of ``seq``.

    The mean is taken over every sample.  The spread is the root-mean-square
    deviation from that mean, computed only over the samples that are below
    twice the mean, so a handful of very large values do not dominate it.

    Despite the function's name, the second return value is a standard
    deviation (square root of the variance), not the variance itself.

    Args:
        seq: Sized, re-iterable collection of numbers (e.g. a list of timings).

    Returns:
        ``(mean, std)``, or ``(None, None)`` when ``seq`` is empty.  ``std`` is
        ``0.0`` when no sample passes the outlier filter.
    """
    n = len(seq)
    if n == 0:
        return None, None
    mean = sum(seq) / n
    # Outlier filter: only samples strictly below 2x the mean contribute.
    kept = [x for x in seq if x < mean * 2]
    if not kept:
        return mean, 0.0
    # Population variance over the retained samples: divide by the number of
    # terms actually summed (use len(kept) - 1 for the sample variance).
    var = sum((x - mean) ** 2 for x in kept) / len(kept)
    return mean, math.sqrt(var)
import time
def time_test_with_var(Bsz, Seq_len, low_rank_dim, vbits, method):
    """Time one attention kernel on random fp16 inputs and print the result.

    Random query/key/value tensors are built for a single-token query
    (query shape ``(Bsz, 1, 32, 128)``) against ``Seq_len`` cached positions
    with 8 KV heads.  The last ``Seq_len // 64`` positions form the "recent"
    (high-rank, unquantized) part; the rest is passed to
    ``quant_and_pack_vcache`` for the values and truncated to the last
    ``low_rank_dim`` features for the keys.

    Args:
        Bsz: Batch size.
        Seq_len: Number of cached positions.
        low_rank_dim: Number of trailing feature columns kept for the
            non-recent keys.
        vbits: Bit width forwarded to ``quant_and_pack_vcache`` and
            ``fused_recover_fa``.
        method: ``'triton'`` times ``fused_recover_fa`` call by call and
            prints mean / std from ``compute_mean_variance``; ``'loki'``
            times ``loki_attention`` over the whole loop and prints the
            average.  Any other value only prints the header line.

    Returns:
        None.  Results are printed in milliseconds.

    Note:
        ``torch.cuda.synchronize()`` is called unconditionally in both timed
        branches, so a CUDA device is expected.
    """
    bsz = Bsz
    seq = Seq_len
    num_kv_head = 8
    head_num = 32
    high_rank_dim = 512
    # NOTE(review): for seq < 64 this is 0, and the ``-recent_length`` slices
    # below then select everything / nothing -- confirm callers use seq >= 64.
    recent_length = seq // 64
    head_dim = 128
    sparsity = 8
    group_size = 32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    recover_weight = torch.randn(high_rank_dim, num_kv_head * head_dim, dtype=torch.float16, device=device)
    recover_bias = torch.randn(num_kv_head * head_dim, dtype=torch.float16, device=device)

    query_states = torch.randn(bsz, 1, head_num, head_dim, dtype=torch.float16, device=device)
    key_cache = torch.randn(bsz, num_kv_head, seq, head_dim, dtype=torch.float16, device=device)
    # Flatten the heads to (bsz, seq, num_kv_head * head_dim) and project with
    # the transposed recover weight.
    key_cache_transform = torch.matmul(
        key_cache.transpose(1, 2).reshape(bsz, seq, -1), recover_weight.transpose(0, 1)
    )
    key_cache_high = key_cache_transform[:, -recent_length:, :].unsqueeze(1)
    key_cache_low = key_cache_transform[:, :-recent_length, -low_rank_dim:].unsqueeze(1)

    value_cache = torch.randn(bsz, num_kv_head, seq, head_dim, dtype=torch.float16, device=device)
    value_cache_high = value_cache[:, :, -recent_length:, :]
    value_cache_low, value_scale, value_mn = quant_and_pack_vcache(
        value_cache[:, :, :-recent_length, :], group_size, vbits
    )

    sin, cos = gen_embeds(head_dim)
    attn_mask = torch.ones(bsz, seq, dtype=torch.bool, device=device)

    n_warmup = 10
    n_repeat = 1000

    print("Evaling on:", method)
    if method == 'triton':
        high_index = torch.arange(0, recent_length, device=key_cache_high.device).unsqueeze(0).expand(bsz, -1)

        def run_fused():
            fused_recover_fa(
                query_states, key_cache_high, key_cache_low, value_cache_high, value_cache_low,
                value_scale, value_mn, recover_weight, recover_bias, sin, cos, high_index,
                sparsity, num_kv_head, vbits, group_size, None,
            )

        # Untimed warm-up calls before measuring.
        for _ in range(n_warmup):
            run_fused()

        time_list = []
        for _ in range(n_repeat):
            # Synchronize on both sides so each sample covers exactly one call.
            torch.cuda.synchronize()
            start = time.perf_counter()
            run_fused()
            torch.cuda.synchronize()
            time_list.append((time.perf_counter() - start) * 1000)
        mean, std = compute_mean_variance(time_list)
        # Value-cache bit width relative to the 16-bit tensors used here,
        # expressed as a percentage (e.g. 2 bits -> 12.5%).
        print("Bsz", bsz, "seq_len", seq, "head_num", head_num, "num_kv_head", num_kv_head, f"compression: {vbits / 16 * 100}%")
        print("mean", mean, "std", std)
    elif method == 'loki':
        Out = torch.empty(bsz, head_num, head_dim, device=device, dtype=torch.float16)
        repeated_k = torch.randn(bsz * seq, head_num, head_dim, dtype=torch.float16, device=device)
        repeated_v = torch.randn(bsz * seq, head_num, head_dim, dtype=torch.float16, device=device)

        def run_loki():
            loki_attention(
                query_states, repeated_k, repeated_v, attn_mask, Out, head_dim // 8, Seq_len // 8
            )

        # Untimed warm-up calls before measuring.
        for _ in range(n_warmup):
            run_loki()

        # Only the whole loop is timed here, so just the average is reported.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_repeat):
            run_loki()
        torch.cuda.synchronize()
        end = time.perf_counter()
        print("Bsz", bsz, "seq_len", seq, "head_num", head_num, "num_kv_head", num_kv_head, "head_dim", head_dim, "vbits", vbits, "sparsity", sparsity, "group_size", group_size)
        print((end - start) / n_repeat * 1000)

    
def benchmark_timer(Bsz, Seq_len, low_rank_dim):
    """Profile a single ``fused_recover_fa`` call and print the operator table.

    Builds random fp16 inputs for a single-token query (shape
    ``(Bsz, 1, 32, 128)``) against ``Seq_len`` cached positions with 8 KV
    heads and a fixed 2-bit value setting, runs ten untimed warm-up calls,
    then records one call under ``torch.autograd.profiler.profile`` and
    prints its averaged events sorted by total CUDA time.

    Args:
        Bsz: Batch size.
        Seq_len: Number of cached positions; the last ``Seq_len // 64`` are
            treated as the recent (unquantized) part.
        low_rank_dim: Number of trailing feature columns kept for the
            non-recent keys.

    Returns:
        None.  The profiler table is printed.
    """
    bsz = Bsz
    seq = Seq_len
    num_kv_head = 8
    head_num = 32
    high_rank_dim = 512
    # NOTE(review): for seq < 64 this is 0, and the ``-recent_length`` slices
    # below then select everything / nothing -- confirm callers use seq >= 64.
    recent_length = seq // 64
    head_dim = 128
    vbits = 2
    sparsity = 8
    group_size = 32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    recover_weight = torch.randn(high_rank_dim, num_kv_head * head_dim, dtype=torch.float16, device=device)
    recover_bias = torch.randn(num_kv_head * head_dim, dtype=torch.float16, device=device)

    query_states = torch.randn(bsz, 1, head_num, head_dim, dtype=torch.float16, device=device)
    key_cache = torch.randn(bsz, num_kv_head, seq, head_dim, dtype=torch.float16, device=device)
    # Flatten the heads to (bsz, seq, num_kv_head * head_dim) and project with
    # the transposed recover weight.
    key_cache_transform = torch.matmul(
        key_cache.transpose(1, 2).reshape(bsz, seq, -1), recover_weight.transpose(0, 1)
    )
    key_cache_high = key_cache_transform[:, -recent_length:, :].unsqueeze(1)
    key_cache_low = key_cache_transform[:, :-recent_length, -low_rank_dim:].unsqueeze(1)

    value_cache = torch.randn(bsz, num_kv_head, seq, head_dim, dtype=torch.float16, device=device)
    value_cache_high = value_cache[:, :, -recent_length:, :]
    value_cache_low, value_scale, value_mn = quant_and_pack_vcache(
        value_cache[:, :, :-recent_length, :], group_size, vbits
    )

    sin, cos = gen_embeds(head_dim)
    high_index = torch.arange(0, recent_length, device=key_cache_high.device).unsqueeze(0).expand(bsz, -1)

    def run_fused():
        fused_recover_fa(
            query_states, key_cache_high, key_cache_low, value_cache_high, value_cache_low,
            value_scale, value_mn, recover_weight, recover_bias, sin, cos, high_index,
            sparsity, num_kv_head, vbits, group_size, None,
        )

    # Untimed warm-up calls so the profiled call is not the first invocation.
    for _ in range(10):
        run_fused()

    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        run_fused()

    print(prof.key_averages().table(sort_by="cuda_time_total"))


# benchmark_timer(8, 1024, low_rank_dim=128)
# exit()
# time_test(8, 2048, 128, 2, "triton")
        # time_test(bsz, seq_len, 128, 2, "triton")


# time_test(16, 4096, 128, "flash_cuda")
# time_test(16, 4096, 128, "flash_triton")
# seq_lens = [8192, 16384, 32768, 65536, 128 * 1024]
# time_test_with_var(8, 4096, 128, 2, "triton")
# exit()
# Benchmark sweep: every batch size is combined with every sequence length,
# and each combination is timed for two (low_rank_dim, vbits) settings on the
# "triton" path.  Pass method="loki" or call benchmark_timer(...) instead to
# exercise the other code paths defined above.
seq_lens = [1024, 2048, 4096]
bszs = [8, 16]
for bsz in bszs:
    for seq_len in seq_lens:
        time_test_with_var(Bsz=bsz, Seq_len=seq_len, low_rank_dim=128, vbits=2, method="triton")
        time_test_with_var(Bsz=bsz, Seq_len=seq_len, low_rank_dim=256, vbits=4, method="triton")
