import torch
from utils import Timer
from transformers.models.llama.modeling_llama import rotate_half, repeat_kv
timer = Timer()

def standard_attention_decode(query, key, value, cos, sin):
    """Run one decode step of plain scaled dot-product attention, timing each stage.

    Stages are recorded on the module-level ``timer`` under the labels
    ``'RoPE'``, ``'p=qk'``, ``'Softmax'`` and ``'y=pv'``.

    Args:
        query: 4-D tensor unpacked as (bsz, num_head, q_len, head_dim).
        key: key tensor; its head axis is expanded with ``repeat_kv`` so it
            can be multiplied against ``query``.
        value: 4-D tensor unpacked as (bsz, num_kv_head, kv_len, head_dim).
        cos, sin: rotary tables; only row 0 of each is used here.

    Returns:
        ``softmax(q @ k^T / sqrt(head_dim)) @ v``.
    """
    _, num_head, _, _ = query.shape
    _, num_kv_head, _, _ = value.shape
    n_rep = num_head // num_kv_head
    q_cos, q_sin = cos[0, :], sin[0, :]

    timer.start('RoPE')
    query = query * q_cos + rotate_half(query) * q_sin
    # NOTE(review): the query is rotated a second time with the same row.
    # Possibly a cost stand-in for rotating the new key token; confirm
    # whether the attention output is meant to be numerically meaningful.
    query = query * q_cos + rotate_half(query) * q_sin
    timer.end('RoPE')

    # Expand KV heads so every query head has a matching key/value head.
    key = repeat_kv(key, n_rep)
    value = repeat_kv(value, n_rep)

    timer.start('p=qk')
    scores = torch.matmul(query, key.transpose(-2, -1))
    scale = torch.sqrt(
        torch.tensor(query.size(-1), dtype=torch.float32, device=query.device)
    )
    scores = scores / scale
    timer.end('p=qk')

    timer.start('Softmax')
    probs = torch.nn.functional.softmax(scores, dim=-1)
    timer.end('Softmax')

    timer.start('y=pv')
    output = torch.matmul(probs, value)
    timer.end('y=pv')
    return output


def reconstruct_attention_decode(query, lowrank_key, value, cos, sin, recover):
    """Decode-step attention where keys are rebuilt from a low-rank cache.

    The keys are recovered as ``lowrank_key @ recover``, reshaped to
    (bsz, -1, num_kv_head, head_dim) and moved to head-major layout, then
    rotated with the full ``cos``/``sin`` tables before ordinary attention.
    Stages are recorded on the module-level ``timer`` under ``'RoPE_q'``,
    ``'Reconstruct'``, ``'RoPE_k'``, ``'p=qk'``, ``'Softmax'`` and ``'y=pv'``.

    Args:
        query: 4-D tensor unpacked as (bsz, num_head, q_len, head_dim).
        lowrank_key: compressed keys, right-multiplied by ``recover``.
        value: 4-D tensor unpacked as (bsz, num_kv_head, kv_len, head_dim).
        cos, sin: rotary tables; row 0 rotates the query, the whole table is
            broadcast against the reconstructed keys.
        recover: matrix mapping the low-rank space back to key space.

    Returns:
        ``softmax(q @ k^T / sqrt(head_dim)) @ v``.
    """
    _, num_head, _, _ = query.shape
    bsz, num_kv_head, _, head_dim = value.shape
    n_rep = num_head // num_kv_head
    q_cos, q_sin = cos[0, :], sin[0, :]

    timer.start('RoPE_q')
    query = query * q_cos + rotate_half(query) * q_sin
    timer.end('RoPE_q')

    timer.start('Reconstruct')
    recovered = torch.matmul(lowrank_key, recover)
    key = recovered.view(bsz, -1, num_kv_head, head_dim).transpose(1, 2)
    timer.end('Reconstruct')

    timer.start('RoPE_k')
    key = key * cos + rotate_half(key) * sin
    timer.end('RoPE_k')

    # Expand KV heads so every query head has a matching key/value head.
    key = repeat_kv(key, n_rep)
    value = repeat_kv(value, n_rep)

    timer.start('p=qk')
    scores = torch.matmul(query, key.transpose(-2, -1))
    scale = torch.sqrt(
        torch.tensor(query.size(-1), dtype=torch.float32, device=query.device)
    )
    scores = scores / scale
    timer.end('p=qk')

    timer.start('Softmax')
    probs = torch.nn.functional.softmax(scores, dim=-1)
    timer.end('Softmax')

    timer.start('y=pv')
    output = torch.matmul(probs, value)
    timer.end('y=pv')
    return output

def reconstruct_topk_attention_decode(query, lowrank_key, value, cos, sin, recover, sparsity):
    """Decode-step attention over the top-k keys chosen in the low-rank space.

    Approximate scores are computed between a low-rank projection of the
    (group-pooled) query and the low-rank key cache. The ``kv_len // sparsity``
    best positions are kept *per batch element*; only those keys are
    reconstructed, rotated with the rotary rows of their own positions, and
    attended to. Stages are recorded on the module-level ``timer`` under
    ``'TopK'``, ``'RoPE_q'``, ``'Reconstruct'``, ``'RoPE_k'``, ``'p=qk'``,
    ``'Softmax'`` and ``'y=pv'``; the gathers are left untimed.

    Args:
        query: 4-D tensor unpacked as (bsz, num_head, q_len, head_dim). The
            score reshape below requires ``q_len == 1``.
        lowrank_key: compressed keys, indexed as (bsz, kv_len, rank).
        value: 4-D tensor unpacked as (bsz, num_kv_head, kv_len, head_dim).
        cos, sin: rotary tables indexed by key position along dim 0; row 0 is
            used for the query.
        recover: (rank, num_kv_head * head_dim) matrix mapping low-rank keys
            back to key space.
        sparsity: keep one key out of every ``sparsity``.

    Returns:
        Attention output restricted to the selected keys/values.
    """
    bsz, num_head, q_len, head_dim = query.shape
    bsz, num_kv_head, kv_len, head_dim = value.shape
    n_rep = num_head // num_kv_head
    topk = kv_len // sparsity

    # Pool the query heads that share a KV head. repeat_kv maps query head h
    # to KV head h // n_rep, so the heads must be grouped as
    # (num_kv_head, n_rep) for the pooled query to line up with `recover`.
    pooled_query = query.reshape(bsz, num_kv_head, n_rep, q_len, head_dim).mean(2)
    pooled_query = pooled_query.transpose(1, 2).reshape(bsz, q_len, num_kv_head * head_dim)
    lowrank_query = torch.matmul(pooled_query, recover.transpose(0, 1))

    query_cos = cos[0, :]
    query_sin = sin[0, :]

    timer.start('TopK')
    approx_scores = torch.matmul(lowrank_query, lowrank_key.transpose(1, 2)).view(bsz, kv_len)
    _, topk_indices = torch.topk(approx_scores, topk, dim=-1)
    timer.end('TopK')

    # Gather with each batch element's own indices (previously batch 0's
    # selection was applied to the whole batch).
    key_index = topk_indices.unsqueeze(-1).expand(-1, -1, lowrank_key.size(-1))
    select_lowrank_key = torch.gather(lowrank_key, 1, key_index)
    value_index = topk_indices[:, None, :, None].expand(-1, num_kv_head, -1, head_dim)
    value = torch.gather(value, 2, value_index)

    timer.start('RoPE_q')
    query = query * query_cos + rotate_half(query) * query_sin
    timer.end('RoPE_q')

    timer.start('Reconstruct')
    key = torch.matmul(select_lowrank_key, recover).view(bsz, -1, num_kv_head, head_dim).transpose(1, 2)
    timer.end('Reconstruct')

    # Rotary rows of the *selected* positions, broadcast over KV heads:
    # (bsz, topk, head_dim) -> (bsz, 1, topk, head_dim).
    select_cos = cos[topk_indices].unsqueeze(1)
    select_sin = sin[topk_indices].unsqueeze(1)

    timer.start('RoPE_k')
    key = key * select_cos + rotate_half(key) * select_sin
    timer.end('RoPE_k')

    # Expand KV heads so every query head has a matching key/value head.
    key = repeat_kv(key, n_rep)
    value = repeat_kv(value, n_rep)

    timer.start('p=qk')
    attn = torch.matmul(query, key.transpose(-2, -1))
    attn = attn / torch.sqrt(torch.tensor(query.size(-1), dtype=torch.float32, device=query.device))
    timer.end('p=qk')

    timer.start('Softmax')
    attn = torch.nn.functional.softmax(attn, dim=-1)
    timer.end('Softmax')

    timer.start('y=pv')
    output = torch.matmul(attn, value)
    timer.end('y=pv')
    return output


def time_test(seq_len):
    """Benchmark the three decode-attention variants at one KV length.

    Random fp16 CUDA tensors are created for a fixed configuration
    (batch 8, 32 query heads, 8 KV heads, head_dim 128, rank 128,
    sparsity 8). Each variant is run 100 times and the module-level
    ``timer`` is summarised and printed after each run.

    Args:
        seq_len: number of cached key/value positions.

    Returns:
        Dict with keys ``'standard_time'``, ``'reconstruct_time'`` and
        ``'reconstruct_topk_time'``, each holding the ``timer.summary()``
        taken after the corresponding variant.
    """
    bsz = 8
    num_head = 32
    num_kv_head = 8
    kv_len = seq_len
    low_rank_dim = 128
    head_dim = 128
    sparsity = 8
    n_warmup = 10
    n_repeat = 100

    query = torch.randn(bsz, num_head, 1, head_dim, device='cuda', dtype=torch.float16)
    key = torch.randn(bsz, num_kv_head, kv_len, head_dim, device='cuda', dtype=torch.float16)
    lowrank_key = torch.randn(bsz, kv_len, low_rank_dim, device='cuda', dtype=torch.float16)
    value = torch.randn(bsz, num_kv_head, kv_len, head_dim, device='cuda', dtype=torch.float16)

    cos = torch.randn(kv_len, head_dim, device='cuda', dtype=torch.float16)
    sin = torch.randn(kv_len, head_dim, device='cuda', dtype=torch.float16)

    recover = torch.randn(low_rank_dim, num_kv_head * head_dim, device='cuda', dtype=torch.float16)

    # Warm up every variant, not just the top-k one, so that first-call
    # overhead (kernel loading, allocator growth for the full-length
    # tensors) does not land in the timed loops of the other two.
    for _ in range(n_warmup):
        standard_attention_decode(query, key, value, cos, sin)
        reconstruct_attention_decode(query, lowrank_key, value, cos, sin, recover)
        reconstruct_topk_attention_decode(query, lowrank_key, value, cos, sin, recover, sparsity)
    # Drain queued warm-up kernels before discarding their timings.
    torch.cuda.synchronize()
    timer.reset()

    for _ in range(n_repeat):
        standard_attention_decode(query, key, value, cos, sin)
    standard_time = timer.summary()
    print("Standard attention")
    timer.structure_print()
    timer.reset()

    for _ in range(n_repeat):
        reconstruct_attention_decode(query, lowrank_key, value, cos, sin, recover)
    # Report query and key rotation as a single 'RoPE' entry.
    timer.time_merge(['RoPE_q', 'RoPE_k'], 'RoPE')
    reconstruct_time = timer.summary()
    print("Reconstruct Attention")
    timer.structure_print()
    timer.reset()

    for _ in range(n_repeat):
        reconstruct_topk_attention_decode(query, lowrank_key, value, cos, sin, recover, sparsity)
    timer.time_merge(['RoPE_q', 'RoPE_k'], 'RoPE')
    reconstruct_topk_time = timer.summary()
    print("Select Reconstruct Attention")
    timer.structure_print()
    timer.reset()

    return {
        'standard_time': standard_time,
        'reconstruct_time': reconstruct_time,
        'reconstruct_topk_time': reconstruct_topk_time,
    }

def plot(seq_len_list, full_data, pre_data, sals_data):
    """Draw a grouped, stacked bar chart of per-stage timings and save it as SVG.

    For every entry of ``seq_len_list`` three adjacent bars are drawn
    (standard attention, reconstruction without sparsity, SALS), each stacked
    from the per-stage series. RoPE time is folded into the 'Recover' segment
    for the two reconstruction variants and is not drawn for standard
    attention. The figure is written to
    ``attention_timing_bar_stacked_square.svg`` and then shown.

    Args:
        seq_len_list: sequence lengths, one bar group each; also used to
            build the x tick labels.
        full_data: series under keys ``'p=qk'``, ``'Softmax'``, ``'y=pv'``.
        pre_data: as ``full_data`` plus ``'Reconstruct'`` and ``'RoPE'``.
        sals_data: as ``pre_data`` plus ``'TopK'``.
            Each series must have one value per entry of ``seq_len_list``.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    x = np.arange(len(seq_len_list))
    bar_width = 0.25

    # One colour per stage, shared by all three methods.
    COLOR_PV = '#ef9a9a'
    COLOR_SOFTMAX = '#a5d6a7'
    COLOR_PQK = '#ffb74d'
    COLOR_RECOVER = '#d1c4e9'
    COLOR_TOPK = '#b39ddb'

    def series(data, name):
        return np.array(data[name])

    def common_segments(data):
        # Bottom-to-top order shared by every method.
        return [
            (series(data, 'y=pv'), COLOR_PV),
            (series(data, 'Softmax'), COLOR_SOFTMAX),
            (series(data, 'p=qk'), COLOR_PQK),
        ]

    def recover_segment(data):
        # Reconstruction and RoPE are reported together as 'Recover'.
        return (series(data, 'Reconstruct') + series(data, 'RoPE'), COLOR_RECOVER)

    fig, ax = plt.subplots(figsize=(12, 8))

    def draw_stack(positions, segments):
        bottom = np.zeros(len(seq_len_list), dtype=float)
        for heights, color in segments:
            ax.bar(positions, heights, width=bar_width, bottom=bottom, color=color)
            bottom = bottom + heights

    # Standard attention (no recover segment).
    draw_stack(x - bar_width, common_segments(full_data))
    # Pre-RoPE compression without sparsity.
    draw_stack(x, common_segments(pre_data) + [recover_segment(pre_data)])
    # SALS: additionally pays for the top-k selection.
    draw_stack(
        x + bar_width,
        common_segments(sals_data)
        + [recover_segment(sals_data), (series(sals_data, 'TopK'), COLOR_TOPK)],
    )

    # Derive tick labels from the data instead of a fixed five-entry list, so
    # any number of sequence lengths can be plotted.
    tick_labels = [
        f'{n // 1024}k' if n % 1024 == 0 else str(n) for n in seq_len_list
    ]
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, fontsize=16)
    # tick_params also applies to ticks created after this call.
    ax.tick_params(axis='y', labelsize=16)
    ax.set_ylabel('Inference Time (ms)', fontsize=16, fontweight='bold')
    ax.set_title('Inference Time per Module Across Methods', fontsize=20, fontweight='bold')
    ax.grid(True, linestyle='--', alpha=0.7)

    # Single legend keyed by colour (stage), not by method.
    handles = [
        plt.Rectangle((0, 0), 1, 1, color=COLOR_TOPK, label='TopK (only SALS)'),
        plt.Rectangle((0, 0), 1, 1, color=COLOR_RECOVER, label='Recover'),
        plt.Rectangle((0, 0), 1, 1, color=COLOR_PQK, label='p=qk'),
        plt.Rectangle((0, 0), 1, 1, color=COLOR_SOFTMAX, label='Softmax'),
        plt.Rectangle((0, 0), 1, 1, color=COLOR_PV, label='y=pv'),
    ]
    ax.legend(handles=handles, ncol=2, loc='upper left', bbox_to_anchor=(0, 1), fontsize=16)

    plt.tight_layout(rect=[0, 0, 1, 0.9])

    plt.savefig("attention_timing_bar_stacked_square.svg", format='svg', bbox_inches='tight')

    plt.show()

def post_process(seq_len, summary:dict, target_data:dict):
    """Append each entry's ``'mean'`` (scaled by 1000) to ``target_data``.

    Args:
        seq_len: unused; kept for the callers' signature.
        summary: mapping of label -> mapping containing a ``'mean'`` value.
        target_data: mapping of label -> list, updated in place. Missing
            labels are created with an empty list first.
    """
    for label, stats in summary.items():
        target_data.setdefault(label, []).append(stats['mean'] * 1000)



from tqdm import tqdm
import os
import json


if __name__ == '__main__':
    seq_list = [8192, 16384, 32768, 65536, 131072]
    data_file = 'attention_time_data.json'
    full_data, pre_data, sals_data = {}, {}, {}

    if not os.path.exists(data_file):
        # No cached results: run the benchmark for every sequence length.
        for seq in seq_list:
            print(seq)
            summary = time_test(seq)
            post_process(seq, summary['standard_time'], full_data)
            post_process(seq, summary['reconstruct_time'], pre_data)
            post_process(seq, summary['reconstruct_topk_time'], sals_data)

        # Persist the results so later runs can skip the benchmark.
        data = {
            'full_data': full_data,
            'pre_data': pre_data,
            'sals_data': sals_data
        }
        with open(data_file, 'w') as f:
            json.dump(data, f)
    else:
        # Cached results exist: load them instead of re-running.
        # NOTE(review): the cache does not record which sequence lengths it
        # was produced for; delete the file after changing seq_list.
        with open(data_file, 'r') as f:
            data = json.load(f)
        full_data = data['full_data']
        pre_data = data['pre_data']
        sals_data = data['sals_data']

    plot(seq_list, full_data, pre_data, sals_data)