import torch
import time
from chunk_attn import Attention, Trace
from benchmark_attn_pytorch import gen_dataset

# Attention geometry used by every benchmark in this file.
n_head, d_embed = 32, 128

# Prefer the GPU in half precision when one is available; otherwise run on the
# CPU in single precision. Both choices become the process-wide torch defaults,
# so tensors created below without explicit device/dtype land there.
is_cuda = torch.cuda.is_available()
if is_cuda:
    # Allocating a tensor on 'cuda' yields the concrete device it was placed on
    # (including its index) rather than the bare 'cuda' device.
    device = torch.randn(1, device='cuda').device
    dtype = torch.float16
else:
    device = torch.device('cpu')
    dtype = torch.float32
torch.set_default_device(device)
torch.set_default_dtype(dtype)

def run_chunk_attn_tps(n_prompt, n_completion, n_shared, batch_size, chunk_size=64):
    """Benchmark ChunkAttn decode steps and return cumulative kernel time.

    Builds a batch with ``gen_dataset`` and loads every prompt into a
    ChunkAttn ``Attention`` instance. It runs one untimed warm-up forward
    pass, then ``n_completion`` decode steps. Each step runs one forward pass
    timed through ``Trace`` and then appends one new token per sequence with
    ``append_completions``.

    Args:
        n_prompt: prompt length passed to ``gen_dataset``.
        n_completion: number of timed decode steps.
        n_shared: shared-prefix argument passed to ``gen_dataset``.
        batch_size: number of sequences. Also the number of tokens appended
            per step.
        chunk_size: chunk size passed to ``Attention``. Defaults to 64, the
            value this benchmark previously hard-coded.

    Returns:
        List of ``n_completion`` floats. Entry ``k`` is the sum of
        ``chunk_kernel_t + seq_kernel_t`` over the first ``k + 1`` timed
        passes, divided by 1e6. This is seconds if ``Trace`` reports
        microseconds (TODO confirm against ChunkAttn).
    """
    print(f'\n[ChunkAttn]\nnum_of_threads:{torch.get_num_threads()} chunk_size:{chunk_size}')
    print(f'{device} {dtype}')
    print(f'prompt:{n_prompt} completion:{n_completion} shared:{n_shared} batch_size:{batch_size}')

    keys, values, qs, seqs = gen_dataset(n_head, d_embed, batch_size, n_prompt, n_shared)
    # One query tensor per sequence, concatenated along dim 1
    # (presumably the batch axis -- confirm against gen_dataset).
    q = torch.cat(qs, dim=1)
    attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=chunk_size, memory_mb=8192,
                     dtype=dtype, device=device)
    for i in range(batch_size):
        attn.add_prompt(tokens=seqs[i], k=keys[i], v=values[i])

    # The same token ids and K/V tensors are appended at every decode step.
    new_tokens = list(range(batch_size))
    new_k = torch.randn((n_head, batch_size, d_embed))
    new_v = torch.randn((n_head, batch_size, d_embed))

    # Warm-up pass; not included in the returned timings.
    attn.forward(q=q)

    ret = []
    latency = 0.0
    for _ in range(n_completion):
        trace = Trace(record_kernel_t=True)
        attn.forward(q=q, trace=trace)
        latency += trace.chunk_kernel_t + trace.seq_kernel_t
        attn.append_completions(tokens=new_tokens, k=new_k, v=new_v)
        # Cumulative, not per-step, kernel time.
        ret.append(latency / 1e6)
    return ret

def run_chunk_attn_latency(seq_len, n_shared, batch_size, chunk_size, partition=0,
                           n_repeat=1, warmup=False):
    """Measure the mean ChunkAttn kernel time of a forward pass over a batch.

    Builds a batch with ``gen_dataset`` and loads every sequence into a
    ChunkAttn ``Attention`` instance. It then times ``n_repeat`` forward
    passes through ``Trace``.

    Args:
        seq_len: sequence length passed to ``gen_dataset``.
        n_shared: shared-prefix argument passed to ``gen_dataset``.
        batch_size: batch size passed to ``gen_dataset``.
        chunk_size: chunk size passed to ``Attention``.
        partition: forwarded unchanged to ``Attention.forward``.
        n_repeat: number of timed forward passes to average over. Defaults
            to 1, the value this benchmark previously hard-coded.
        warmup: if True, run one untimed forward pass with the same
            ``partition`` first. Defaults to False, matching the previous
            (commented-out) behaviour.

    Returns:
        Mean of ``chunk_kernel_t + seq_kernel_t`` over the timed passes,
        divided by 1e3. It is printed as milliseconds, which holds if
        ``Trace`` reports microseconds (TODO confirm against ChunkAttn).

    Raises:
        ValueError: if ``n_repeat`` is less than 1, since there is nothing
            to average.
    """
    if n_repeat < 1:
        raise ValueError(f'n_repeat must be >= 1, got {n_repeat}')

    print(f'\n[ChunkAttn]\ninterop_threads:{torch.get_num_interop_threads()} intraop_threads:{torch.get_num_threads()}')
    print(f'{device} {dtype}')
    print(f'seq_len:{seq_len} n_shared:{n_shared} chunk_size:{chunk_size} partition:{partition}')

    keys, values, qs, seqs = gen_dataset(n_head, d_embed, batch_size, seq_len, n_shared)
    q = torch.cat(qs, dim=1)
    attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=chunk_size, memory_mb=8192,
                     dtype=dtype, device=device)
    for i in range(len(qs)):
        attn.add_prompt(tokens=seqs[i], k=keys[i], v=values[i])

    if warmup:
        attn.forward(q=q, partition=partition)

    # Drain pending GPU work so setup kernels do not overlap the timed passes.
    if is_cuda:
        torch.cuda.synchronize()

    t_total = 0.0
    for _ in range(n_repeat):
        trace = Trace(record_kernel_t=True)
        attn.forward(q=q, partition=partition, trace=trace)
        t_total += trace.chunk_kernel_t + trace.seq_kernel_t
    t = t_total / n_repeat / 1e3  # milliseconds, assuming Trace reports microseconds
    print(f'all: {t:.2f} ms')
    return t

def _mean_kernel_ms(attn, q, n_repeat, **forward_kwargs):
    """Return the mean Trace kernel time of ``n_repeat`` forward passes, divided by 1e3.

    Extra keyword arguments (e.g. ``partition``) are forwarded to ``attn.forward``.
    """
    t_total = 0.0
    for _ in range(n_repeat):
        trace = Trace(record_kernel_t=True)
        attn.forward(q=q, trace=trace, **forward_kwargs)
        t_total += trace.chunk_kernel_t + trace.seq_kernel_t
    return t_total / n_repeat / 1e3


def run_chunk_attn_cmp_chunk_seq(seq_len, n_shared, batch_size, chunk_size):
    """Compare ChunkAttn kernel time for the "chunk first" and "sequence first" paths.

    Builds one ``Attention`` instance from ``gen_dataset`` output. It then
    averages the ``Trace`` kernel time over 100 forward passes twice. The
    first run uses ``forward``'s default partition ("chunk first"). The
    second run uses ``partition=2`` ("sequence first").

    Args:
        seq_len: sequence length passed to ``gen_dataset``.
        n_shared: shared-prefix argument passed to ``gen_dataset``.
        batch_size: batch size passed to ``gen_dataset``.
        chunk_size: chunk size passed to ``Attention``.

    Returns:
        ``(t_chunk, t_seq)``: mean kernel times divided by 1e3. They are
        printed as milliseconds, assuming ``Trace`` reports microseconds
        (TODO confirm against ChunkAttn).
    """
    print(f'\n[ChunkAttn]\ninterop_threads:{torch.get_num_interop_threads()} intraop_threads:{torch.get_num_threads()}')
    print(f'{device} {dtype}')
    print(f'seq_len:{seq_len}, n_shared:{n_shared} chunk_size:{chunk_size}')

    keys, values, qs, seqs = gen_dataset(n_head, d_embed, batch_size, seq_len, n_shared)
    q = torch.cat(qs, dim=1)
    attn = Attention(n_head=n_head, d_embed=d_embed, chunk_size=chunk_size, memory_mb=8192,
                     dtype=dtype, device=device)
    for i in range(len(qs)):
        attn.add_prompt(tokens=seqs[i], k=keys[i], v=values[i])

    # Warm up BOTH code paths. Previously only the default path was warmed
    # up, so the first timed "sequence first" pass was a cold call.
    attn.forward(q=q)
    attn.forward(q=q, partition=2)
    n_repeat = 100

    t_chunk = _mean_kernel_ms(attn, q, n_repeat)               # chunk first
    t_seq = _mean_kernel_ms(attn, q, n_repeat, partition=2)    # sequence first

    print(f'Chunk: {t_chunk:.2f} ms, Seq: {t_seq:.2f} ms')
    return (t_chunk, t_seq)

if __name__ == "__main__":
    # Optional experiments, disabled by default:
    #   torch.set_num_threads(32)
    #   run_chunk_attn_cmp_chunk_seq(256, 256, 32, 64)
    #   run_chunk_attn_cmp_chunk_seq(320, 256, 32, 64)

    # Latency with seq_len=256, n_shared=256, batch_size=32, chunk_size=64,
    # measured first with partition 0 and then with partition 2.
    for part in (0, 2):
        run_chunk_attn_latency(256, 256, 32, 64, part)
