import torch
import time
import random
from typing import List, Optional, Union
from vllm import attention_ops, LLM, SamplingParams

# Attention geometry shared by every benchmark in this file.
n_head = 32    # number of attention heads (query heads == KV heads here)
d_embed = 128  # size of each head (head_size in the reference code)

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Plain softmax attention used as ground truth for the paged kernel.

    Shapes follow the einsum subscripts: ``query`` is (q, h, d), ``key`` and
    ``value`` are (k, h, d), and the result is (q, h, d). Logits are formed and
    soft-maxed in float32, then cast back to ``value.dtype`` before the
    weighted sum over values.

    Args:
        query: query vectors, (q, h, d).
        key: key vectors, (k, h, d).
        value: value vectors, (k, h, d).
        scale: multiplier applied to the raw query/key dot products.
        attn_mask: optional additive bias, broadcast against the (h, q, k)
            logits; ``None`` disables it.

    Returns:
        Attention output of shape (q, h, d) in ``value.dtype``.
    """
    logits = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        # Out-of-place add so a mask that broadcasts to a larger shape still works.
        logits = logits + attn_mask.float()
    probs = logits.softmax(dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference (pure PyTorch) single-query attention over a paged KV cache.

    For each sequence ``i``, walks its block table to gather the first
    ``context_lens[i]`` cached key/value vectors, attends the sequence's single
    query against them with ``ref_masked_attention``, and writes the result
    into ``output[i]`` in place.

    Layouts, as implied by the indexing below:
      query:       (num_seqs, num_query_heads, head_size)
      key_cache:   (num_blocks, num_kv_heads, head_size // x, block_size, x)
      value_cache: (num_blocks, num_kv_heads, head_size, block_size)
      output:      same shape as query

    Args:
        output: destination tensor, overwritten row by row.
        query: one query vector per sequence and head.
        num_queries_per_kv: query heads sharing each KV head; values > 1 expand
            the KV heads with ``repeat_interleave`` (MQA/GQA).
        key_cache: paged key cache.
        value_cache: paged value cache.
        block_tables: per-sequence physical block numbers; entries past the
            blocks needed for ``context_lens[i]`` are never read.
        context_lens: number of cached tokens to attend to per sequence.
        scale: multiplier applied to query/key dot products.
        alibi_slopes: optional slopes (presumably one per query head — must
            broadcast against the (heads, 1, context_len) logits). When given,
            the bias ``slope * (pos - context_len + 1)`` is added, i.e. 0 for
            the newest token.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)  # (1, num_query_heads, head_size)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys = []
        values = []
        for j in range(context_len):
            # Logical token j sits in slot (j % block_size) of the physical
            # block listed at position (j // block_size) of this sequence's table.
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            # Fold the (head_size // x, x) split back into head_size.
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)  # (context_len, num_kv_heads, head_size)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA: give every query head its own copy of its KV head.
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel. Build
            # it on the query's device (not a hard-coded "cuda") so the
            # reference also works for tensors on other devices.
            position_ids = torch.arange(context_len, device=query.device).int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)

def create_block_tables(block_tables):
    """Right-pad ragged per-sequence block tables with -1 into one int tensor.

    Args:
        block_tables: list of per-sequence lists of physical block numbers.
            The input lists are not modified.

    Returns:
        A ``torch.int`` tensor of shape (len(block_tables), longest table) on
        the current default device; shorter rows are padded with -1.
    """
    width = max(len(row) for row in block_tables)
    padded = [row + [-1] * (width - len(row)) for row in block_tables]
    return torch.tensor(padded, dtype=torch.int)

def gen_dataset(n_prompt, n_completion, n_shared, batch_size, block_size):
    """Build a random paged KV cache, check the kernel on it, and return it.

    Physical layout: blocks ``0 .. n_shared // block_size - 1`` are listed in
    every sequence's block table (a shared prompt prefix); each sequence then
    receives its own fresh blocks for the rest of its prompt. Completion
    blocks are NOT assigned here — the next free block number is returned as
    ``n_used_blocks`` for the caller to continue from. Storage for
    ``batch_size * (n_prompt + n_completion)`` tokens is allocated up front.

    Before returning, the vLLM kernel is launched once (warm-up) and its
    output is compared with ``ref_single_query_cached_kv_attention``.

    Returns:
        (query, key_cache, value_cache, output, head_mapping, block_tables,
        context_lens, scale, n_used_blocks), where ``block_tables`` is a list
        of per-sequence Python lists and ``context_lens`` a Python list.
    """
    seq_len = n_prompt + n_completion
    n_blocks = seq_len * batch_size // block_size
    assert seq_len % block_size == 0
    assert n_shared % block_size == 0
    # context_lens below is the prompt length rounded down to whole blocks, so
    # a ragged prompt would silently lose its trailing tokens.
    assert n_prompt % block_size == 0
    # The shared prefix is part of the prompt; otherwise the per-sequence
    # block count below goes negative.
    assert n_shared <= n_prompt

    # make the last dimension always 16 bytes: x elements of the default dtype
    x = 16 // torch.tensor([]).element_size()
    assert d_embed % x == 0
    key_cache = torch.randn(size=(n_blocks, n_head, d_embed // x, block_size, x))
    value_cache = torch.randn(size=(n_blocks, n_head, d_embed, block_size))
    print(f'key_cache.shape:{key_cache.shape} value_cache.shape:{value_cache.shape}')

    # place shared prompt tokens: every sequence points at the same blocks
    n_shared_blocks = n_shared // block_size
    block_tables = [list(range(n_shared_blocks)) for _ in range(batch_size)]
    n_used_blocks = n_shared_blocks

    # place non-shared prompt tokens: consecutive fresh blocks per sequence
    n_private_blocks = n_prompt // block_size - n_shared_blocks
    context_lens = []
    for table in block_tables:
        context_lens.append((n_shared_blocks + n_private_blocks) * block_size)
        table.extend(range(n_used_blocks, n_used_blocks + n_private_blocks))
        n_used_blocks += n_private_blocks

    query = torch.randn(batch_size, n_head, d_embed)
    scale = float(1.0 / (d_embed ** 0.5))
    output = torch.empty(batch_size, n_head, d_embed)
    # Identity mapping: query head h reads KV head h (one KV head per query head).
    head_mapping = torch.arange(n_head, dtype=torch.int32)

    block_tables_tensor = create_block_tables(block_tables)
    context_lens_tensor = torch.tensor(context_lens, dtype=torch.int)

    # warm up
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables_tensor,
        context_lens_tensor,
        block_size,
        max(context_lens),
        None)
    # Run the reference implementation and check the kernel against it.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        1,
        key_cache,
        value_cache,
        block_tables_tensor,
        context_lens_tensor,
        scale,
        None,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3)

    return (query, key_cache, value_cache, output, head_mapping,
            block_tables, context_lens, scale, n_used_blocks)
    
@torch.inference_mode()
def run_paged_attention_tps(n_prompt, n_completion, n_shared, batch_size):
    """Benchmark ``n_completion`` decode steps of the paged-attention kernel.

    Each step grows every sequence's context by one token. A new physical
    block is allocated first if the sequence's current blocks are full. The
    step then times a single kernel launch bracketed by
    ``torch.cuda.synchronize()``. Query and output tensors are reused across
    steps; only the block tables and context lengths change.

    Returns:
        A list with one entry per decode step: the cumulative kernel time in
        seconds up to and including that step.
    """
    print('\n[PagedAttn]')
    print(f'{torch.randn(1).device} dtype: {torch.randn(1).dtype}')
    print(f'prompt:{n_prompt} completion:{n_completion} shared:{n_shared} batch_size:{batch_size}')
    block_size = 32  # tokens in each block
    (query, key_cache, value_cache, output, head_mapping, block_tables,
     context_lens, scale, n_used_blocks) = gen_dataset(
        n_prompt, n_completion, n_shared, batch_size, block_size)

    # start decoding
    ret = []
    latency = 0.0
    for _ in range(n_completion):
        for i in range(batch_size):
            # Decide from this sequence's own length, not the global step
            # index: if its cached tokens exactly fill its blocks, the next
            # token needs a fresh block before it can be stored.
            if context_lens[i] % block_size == 0:
                block_tables[i].append(n_used_blocks)
                n_used_blocks += 1
            context_lens[i] += 1
        block_tables_tensor = create_block_tables(block_tables)
        context_lens_tensor = torch.tensor(context_lens, dtype=torch.int)
        max_context_len = max(context_lens)

        # Synchronize on both sides so only this launch is measured.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables_tensor,
            context_lens_tensor,
            block_size,
            max_context_len,
            None)
        torch.cuda.synchronize()
        latency += time.perf_counter() - start_time
        ret.append(latency)
    return ret

@torch.inference_mode()
def run_paged_attention_latency(seq_len, n_shared, batch_size, *,
                                n_repeat=1000, block_size=32):
    """Measure the mean latency of one paged-attention kernel launch.

    The cache is built by ``gen_dataset`` with ``seq_len`` prompt tokens per
    sequence (``n_shared`` of them in shared blocks); the completion budget
    passed to it only enlarges the allocated cache. The same launch is then
    repeated ``n_repeat`` times between two ``torch.cuda.synchronize()`` calls.

    Args:
        seq_len: prompt tokens attended to by each sequence.
        n_shared: leading prompt tokens stored in blocks shared by all sequences.
        batch_size: number of sequences.
        n_repeat: kernel launches to average over (must be >= 1).
        block_size: tokens per physical cache block.

    Returns:
        Mean per-launch latency in milliseconds.

    Raises:
        ValueError: if ``n_repeat`` is less than 1.
    """
    if n_repeat < 1:
        raise ValueError(f'n_repeat must be >= 1, got {n_repeat}')
    print('\n[PagedAttn]')
    print(f'{torch.randn(1).device} dtype: {torch.randn(1).dtype}')
    print(f'seq_len:{seq_len} n_shared:{n_shared}')

    (query, key_cache, value_cache, output, head_mapping, block_tables,
     context_lens, scale, _) = gen_dataset(
        seq_len, seq_len, n_shared, batch_size, block_size)

    # Inputs are fixed across repeats, so build them once outside the timer.
    block_tables_tensor = create_block_tables(block_tables)
    context_lens_tensor = torch.tensor(context_lens, dtype=torch.int)
    max_context_len = max(context_lens)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(n_repeat):
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables_tensor,
            context_lens_tensor,
            block_size,
            max_context_len,
            None)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    t = (end_time - start_time) / n_repeat * 1e3
    print(f"PagedAttn: {t:.2f} ms")
    return t

def gen_dataset_tps(n_seqs, seq_len, n_shared):
    """Generate ``n_seqs`` token-id lists that share a common prefix.

    Each list starts with ``n_shared`` zeros, followed by
    ``seq_len - n_shared`` ids drawn with ``random.randint(10, 100)`` (both
    ends inclusive). Every sequence is an independent list object.
    """
    prefix = [0] * n_shared
    n_random = seq_len - n_shared
    return [
        prefix + [random.randint(10, 100) for _ in range(n_random)]
        for _ in range(n_seqs)
    ]

if __name__ == '__main__':
    # Every tensor this file creates without an explicit device/dtype picks up
    # these defaults: the benchmark runs on the GPU in half precision.
    torch.set_default_device('cuda')
    torch.set_default_dtype(torch.float16)
    latency = run_paged_attention_latency(seq_len=512, n_shared=256, batch_size=32)
    print(latency)

