import torch
import triton
import triton.language as tl
from einops import rearrange
@triton.jit
def recover_rope_kernel(
    key_cache, # [bsz, 1, seq_len, rank_dim]
    indices, # [bsz, token_num]
    recover_weight, # [rank_dim, head_dim * head_num]
    sin,
    cos,
    key_recover, # [bsz, head_num, token_num, head_dim]
    TOKEN_NUM: tl.constexpr,
    BSZ: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    RANK_DIM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    RANK_BLOCK: tl.constexpr,
    SEQ_BLOCK: tl.constexpr,
    SEQ_BLOCK_NUM: tl.constexpr
):
    """Gather top-k tokens, recover them from latent space, and apply interleaved RoPE.

    Each program handles SEQ_BLOCK selected tokens of one (batch, head) pair.
    The program id is decomposed as (bsz_idx, head_num_idx, seq_block_idx).
    All pointer offsets assume contiguous row-major tensors with the shapes
    noted next to each argument. `sin`/`cos` are indexed as
    [token_position, HEAD_DIM] rows.

    Rows past TOKEN_NUM, and rows whose index falls outside [0, SEQ_LEN),
    never read key/sin/cos memory. Only rows below TOKEN_NUM are stored.
    """
    pid = tl.program_id(axis=0)
    bsz_idx = pid // (HEAD_NUM * SEQ_BLOCK_NUM)
    head_num_idx = pid // SEQ_BLOCK_NUM % HEAD_NUM
    seq_block_idx = pid % SEQ_BLOCK_NUM

    # Position of each row inside this batch's list of selected tokens.
    seq_offset = seq_block_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
    # Bounds are checked on the in-batch position. The flattened offset
    # (bsz_idx * TOKEN_NUM + ...) would mask out every row of every batch
    # after the first.
    seq_mask = seq_offset < TOKEN_NUM
    token_indices = tl.load(indices + bsz_idx * TOKEN_NUM + seq_offset, mask=seq_mask, other=0)
    # Padding rows and out-of-range token indices must not touch key/sin/cos memory.
    row_mask = seq_mask & (token_indices >= 0) & (token_indices < SEQ_LEN)

    head_dim_offset = tl.arange(0, HEAD_DIM)
    # This head's column slice of recover_weight is the same for every rank block.
    weight_col_offset = head_num_idx * HEAD_DIM + head_dim_offset

    # Latent -> per-head projection: acc = key_cache[b, 0, idx, :] @ W[:, head cols]
    acc = tl.zeros([SEQ_BLOCK, HEAD_DIM], dtype=tl.float32)
    for rank_start in range(0, RANK_DIM, RANK_BLOCK):
        rank_dim_offset = rank_start + tl.arange(0, RANK_BLOCK)
        # Guards a trailing partial block when RANK_DIM % RANK_BLOCK != 0.
        rank_mask = rank_dim_offset < RANK_DIM
        key_offset = bsz_idx * (RANK_DIM * SEQ_LEN) + token_indices[:, None] * RANK_DIM + rank_dim_offset[None, :]
        key_block = tl.load(key_cache + key_offset, mask=row_mask[:, None] & rank_mask[None, :], other=0.0)
        weight_offset = rank_dim_offset[:, None] * (HEAD_DIM * HEAD_NUM) + weight_col_offset[None, :]
        weight = tl.load(recover_weight + weight_offset, mask=rank_mask[:, None], other=0.0)
        acc = tl.dot(key_block, weight, acc)

    # Interleaved rope: out = x * cos + rotate(x) * sin, where rotate maps
    # each pair (x[2i], x[2i+1]) to (-x[2i+1], x[2i]).
    acc = acc.to(dtype=tl.float16)
    emb_offset = token_indices[:, None] * HEAD_DIM + head_dim_offset[None, :]
    sin_select = tl.load(sin + emb_offset, mask=row_mask[:, None], other=0.0)
    cos_select = tl.load(cos + emb_offset, mask=row_mask[:, None], other=0.0)
    acc_pairs = tl.reshape(acc, (SEQ_BLOCK * HEAD_DIM // 2, 2))
    acc_even, acc_odd = tl.split(acc_pairs)
    acc_rotated = tl.join(-acc_odd, acc_even).reshape(SEQ_BLOCK, HEAD_DIM)
    result = acc * cos_select + acc_rotated * sin_select

    result_offset = bsz_idx * (HEAD_DIM * TOKEN_NUM * HEAD_NUM) + head_num_idx * (HEAD_DIM * TOKEN_NUM) + seq_offset[:, None] * HEAD_DIM + head_dim_offset[None, :]
    tl.store(key_recover + result_offset, result, mask=seq_mask[:, None])
def recover_rope(
    key_cache,
    indices,
    recover_weight,
    sin,
    cos,
    head_num,
    head_dim,
    result=None,
    offset=0
):
    """Launch `recover_rope_kernel` and write its output into `result[:, :, offset:offset + token_num]`.

    Args:
        key_cache: latent cache read as [bsz, ?, seq_len, rank_dim]. The kernel
            addresses it as if dim 1 has size 1 (assumption — confirm with callers).
        indices: [bsz, token_num] token positions to gather.
        recover_weight: [rank_dim, head_num * head_dim] up-projection.
        sin, cos: rotary tables indexed by token position, [*, head_dim].
        head_num, head_dim: output head count and width.
        result: optional [bsz, head_num, >= offset + token_num, head_dim] output
            buffer. When None, a float16 buffer of length `offset + token_num`
            is allocated on `key_cache.device`.
        offset: start position along dim 2 of `result`.

    Returns:
        The `[bsz, head_num, token_num, head_dim]` view of `result` that was written.
    """
    bsz = key_cache.shape[0]
    seq_len = key_cache.shape[2]
    rank_dim = key_cache.shape[3]
    token_num = indices.shape[1]
    if result is None:
        result = torch.empty(
            [bsz, head_num, offset + token_num, head_dim],
            dtype=torch.float16,
            device=key_cache.device,
        )
    out = result.narrow(2, offset, token_num)
    if out.numel() == 0:
        return out

    # The kernel computes output offsets as if the tensor were a contiguous
    # [bsz, head_num, token_num, head_dim] block. A narrowed slice of a longer
    # buffer has a larger dim-2 stride, so it is staged through a contiguous
    # temporary and copied back.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty((bsz, head_num, token_num, head_dim), dtype=out.dtype, device=out.device)

    SEQ_BLOCK = 128
    SEQ_BLOCK_NUM = (token_num + SEQ_BLOCK - 1) // SEQ_BLOCK
    RANK_BLOCK = 128
    # One program per (batch, head, block of SEQ_BLOCK selected tokens).
    TOTAL_BLOCK = SEQ_BLOCK_NUM * bsz * head_num
    recover_rope_kernel[(TOTAL_BLOCK,)](key_cache, indices, recover_weight, sin, cos, target, token_num, bsz, head_num, rank_dim, head_dim, seq_len, RANK_BLOCK, SEQ_BLOCK, SEQ_BLOCK_NUM)

    if target is not out:
        out.copy_(target)
    return out
def gen_embeds(head_dim, base=10000, max_seq_len=32768, device="cuda"):
    """Build float16 (sin, cos) tables for interleaved rotary embeddings.

    Row t, columns 2i and 2i+1 both hold the angle t * base**(-2i / head_dim).
    The pairwise rotation of (x[2i], x[2i+1]) used by `rotate_interleaved` and
    the Triton kernel needs both elements of a pair to share one angle. The
    earlier `cat(freqs, freqs)` layout instead suits a rotate-half scheme.

    Args:
        head_dim: even embedding width.
        base: RoPE frequency base.
        max_seq_len: number of table rows (positions 0 .. max_seq_len - 1).
        device: device for the returned tensors (defaults to "cuda", as before).

    Returns:
        (sin, cos), each of shape [max_seq_len, head_dim], float16.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() * (1 / head_dim)))
    # Use a float position axis so the outer product does not mix int64 and float operands.
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    # [f0, f0, f1, f1, ...]: one angle per adjacent (even, odd) pair.
    emb = torch.repeat_interleave(freqs, 2, dim=-1)
    return (
        emb.sin().to(dtype=torch.float16).to(device),
        emb.cos().to(dtype=torch.float16).to(device),
    )

# bsz = 1
# seq = 2048
# head_num = 8
# rank_dim = 512
# head_dim = 128
# token_num = seq // 8 # 1/8 sparsity
# key_cache = torch.randn([bsz, 1, seq, rank_dim], dtype=torch.float16).cuda()
# indices = torch.randint(0, seq, [bsz, token_num]).cuda()
# recover_weight = torch.randn([rank_dim, head_num * head_dim], dtype=torch.float16).cuda()
# sin, cos = gen_embeds(head_dim)
# key_recover = torch.empty([bsz, head_num, token_num, head_dim], dtype=torch.float16).cuda()
# recover_rope(key_cache, indices, recover_weight, sin, cos, head_num, head_dim, key_recover, 0)
# print(key_recover.shape)
# print(key_recover)
def rotate_interleaved(x):
    """Rotate adjacent pairs along the last dim: (x[2i], x[2i+1]) -> (-x[2i+1], x[2i]).

    This is the "rotate" term of interleaved RoPE: out = x * cos + rotate(x) * sin.
    The output has the same shape as `x`.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    # Stack to [..., d/2, 2] as (-odd, even) pairs, then merge the last two dims back.
    paired = torch.stack((-odds, evens), dim=-1)
    return paired.flatten(start_dim=-2)
def get_golden(
    key_cache,
    indices,
    recover_weight,
    sin,
    cos,
    head_num,
    head_dim,
):
    """Reference PyTorch implementation of the gather + recover + interleaved-RoPE kernel.

    For each batch b and head h, this:
      1. gathers `key_cache[b, h, indices[b], :]` (dim 1 is broadcast to
         `head_num` with `expand`),
      2. projects it with columns `[h * head_dim, (h + 1) * head_dim)` of
         `recover_weight`,
      3. applies `x * cos[idx] + rotate_interleaved(x) * sin[idx]`.

    Returns:
        Tensor stacked as [bsz, head_num, token_num, head_dim].
    """
    key_cache = key_cache.expand([-1, head_num, -1, -1])
    results = []
    # No debug printing here: this function runs inside the timed do_bench loop.
    for b in range(key_cache.shape[0]):
        idx = indices[b]
        # The rotary rows depend only on this batch's token positions, so gather them once.
        cos_b = cos[idx, :]
        sin_b = sin[idx, :]
        batch_results = []
        for h in range(head_num):
            weight_h = recover_weight[:, h * head_dim: h * head_dim + head_dim]
            recovered = torch.matmul(key_cache[b, h, idx, :], weight_h)
            batch_results.append(recovered * cos_b + rotate_interleaved(recovered) * sin_b)
        results.append(torch.stack(batch_results, dim=0))
    return torch.stack(results, dim=0)
# golden = get_golden(key_cache, indices, recover_weight, sin, cos, head_num, head_dim)
# # print(golden.shape)
# # print(golden)
# def is_similar(a, b):
#     return a.shape == b.shape and torch.allclose(a.cpu(), b.cpu(), 0.05, 0.05)
# if is_similar(key_recover, golden):
#     print("Pass!")
# else:
#     print("key_recover:", key_recover.shape, key_recover)
#     print("golden:", golden.shape, golden)

# Benchmark sweep: SEQ_LEN from 2**10 to 2**15, comparing the PyTorch
# reference (labelled by `ref_lib`) against the Triton kernel.
ref_lib = 'cuBLAS'
configs = [
    triton.testing.Benchmark(
        x_names=["SEQ_LEN"],  # swept argument (x-axis)
        x_vals=[2**i for i in range(10, 16)],  # 1024 ... 32768
        line_arg="provider",  # benchmark() argument that selects the implementation for each line
        line_vals=[ref_lib.lower(), "triton"],  # provider values passed to benchmark()
        line_names=[ref_lib, "Triton"],  # legend labels, one per provider
        styles=[("green", "-"), ("blue", "-")],  # (color, line style) per provider
        ylabel="Ms",  # y-axis label; values are do_bench latencies in milliseconds
        plot_name="matmul-performance-fp16",  # plot title, also used as the saved file name
        args={"H": 8, "BATCH": 8, "HEAD_DIM": 128},  # fixed arguments for every run
    )
]
@triton.testing.perf_report(configs)
def benchmark(BATCH, HEAD_DIM, H, SEQ_LEN, provider):
    """Time one provider for a single (BATCH, H, HEAD_DIM, SEQ_LEN) point.

    Selects 1/8 of the cached tokens at random, then times either the PyTorch
    reference (`get_golden`) or the Triton launcher (`recover_rope`).

    Returns:
        (median_ms, q20_ms, q80_ms) from `triton.testing.do_bench`, in the
        (y, y_min, y_max) order that `perf_report` expects for a latency metric.

    Raises:
        ValueError: if `provider` is not one of the configured line values.
    """
    bsz = BATCH
    seq = SEQ_LEN
    head_num = H
    rank_dim = 128
    head_dim = HEAD_DIM
    token_num = seq // 8  # 1/8 sparsity
    key_cache = torch.randn([bsz, 1, seq, rank_dim], dtype=torch.float16).cuda()
    indices = torch.randint(0, seq, [bsz, token_num]).cuda()
    recover_weight = torch.randn([rank_dim, head_num * head_dim], dtype=torch.float16).cuda()
    sin, cos = gen_embeds(head_dim)
    # recover_rope writes into a caller-provided buffer starting at `offset`
    # along the token axis. Allocate it once, outside the timed region.
    result = torch.empty([bsz, head_num, token_num, head_dim], dtype=torch.float16, device=key_cache.device)

    quantiles = [0.5, 0.2, 0.8]
    if provider == ref_lib.lower():
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: get_golden(key_cache, indices, recover_weight, sin, cos, head_num, head_dim),
            quantiles=quantiles,
        )
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: recover_rope(key_cache, indices, recover_weight, sin, cos, head_num, head_dim, result, 0),
            quantiles=quantiles,
        )
    else:
        raise ValueError(f"unknown provider: {provider!r}")
    # Latency: the lower quantile is the minimum and the upper is the maximum
    # (the reversed order only applies to inverted throughput metrics).
    return ms, min_ms, max_ms

# benchmark.run(show_plots=True, print_data=True, save_path="./result")