# Specification
#
# Inputs to the roofline latency model below. Sizes are element counts; the
# latency functions in this file count 2 bytes per weight / KV element (fp16).

# NVIDIA-H100-80GB
# NOTE(review): 2 TB/s and 756 dense FP16 TFLOPS match the H100 *PCIe*
# datasheet figures (the SXM part is faster) -- confirm which variant is meant.
HW_CONFIG = {
    # Datasheet figures are decimal: 1 TB/s = 10**12 B/s, 1 TFLOPS = 10**12
    # FLOP/s. Scaling by 1024**4 instead would overstate both by ~10%.
    "bandwidth": 2 * 10**12,  # 2000 GB / sec
    "peak_fp16_FLOPS": 756 * 10**12,  # 756 TFLOPS
}

# Per-model experiment settings.
EXP_CONFIG = {
    # Batch size; multiplies every FLOP and KV-byte term.
    "B": {
        "llama3.1-8b-instruct": 1,
        "llama3.2-1b-instruct": 1,
    },
    # Prompt length in tokens (also the prefill step size, MODEL_CONFIG["S"]).
    "T": {
        "llama3.1-8b-instruct": 32768,
        "llama3.2-1b-instruct": 32768,
    },
    # KV-cache budget: tokens kept per KV head after compression.
    "C": {
        "llama3.1-8b-instruct": 128,
        "llama3.2-1b-instruct": 128,
    },
    # Number of lookahead / pseudo / draft tokens (and the decode steps
    # spent generating them, where a method generates them).
    "W": {
        "llama3.1-8b-instruct": 32,
        "llama3.2-1b-instruct": 32,
    },
}

# Architecture constants.
MODEL_CONFIG = {
    # Number of transformer layers.
    "L": {
        "llama3.1-8b-instruct": 32,
        "llama3.2-1b-instruct": 16,
    },
    # New tokens processed per forward pass, by phase.
    "S": {
        "llama3.1-8b-instruct": {
            "prefill": EXP_CONFIG["T"]["llama3.1-8b-instruct"],
            "decode": 1,
        },
        "llama3.2-1b-instruct": {
            "prefill": EXP_CONFIG["T"]["llama3.2-1b-instruct"],
            "decode": 1,
        },
    },
    # Hidden (model) dimension.
    "D": {
        "llama3.1-8b-instruct": 4096,
        "llama3.2-1b-instruct": 2048,
    },
    # Number of query heads.
    "H": {
        "llama3.1-8b-instruct": 32,
        "llama3.2-1b-instruct": 32,
    },
    # Number of key/value heads (K < H here, i.e. grouped-query attention).
    "K": {
        "llama3.1-8b-instruct": 8,
        "llama3.2-1b-instruct": 8,
    },
    # Per-head dimension (H * E == D for both models).
    "E": {
        "llama3.1-8b-instruct": 128,
        "llama3.2-1b-instruct": 64,
    },
    # FFN intermediate dimension (gate / up / down projections).
    "V": {
        "llama3.1-8b-instruct": 14336,
        "llama3.2-1b-instruct": 8192,
    },
    # LoRA rank of the LookaheadKV adapters.
    "R": {
        "llama3.1-8b-instruct": 8,
        "llama3.2-1b-instruct": 8,
    },
}


def cal_latency(
    flops: int,
    bytes: int,
    *,
    bandwidth: "float | None" = None,
    peak_flops: "float | None" = None,
    mem_efficiency: float = 0.9,
    compute_efficiency: float = 0.7,
) -> float:
    """Return the roofline latency (seconds) of a single operator.

    Memory traffic and arithmetic are assumed to overlap fully, so the
    operator takes the larger of its memory time and its compute time. Each
    is derated by an achievable-efficiency factor.

    Args:
        flops: Floating-point operations performed by the operator.
        bytes: Bytes moved to/from device memory. (This shadows the builtin;
            the name is kept so keyword callers keep working.)
        bandwidth: Memory bandwidth in bytes/s. Defaults to
            ``HW_CONFIG["bandwidth"]``.
        peak_flops: Peak arithmetic throughput in FLOP/s. Defaults to
            ``HW_CONFIG["peak_fp16_FLOPS"]``.
        mem_efficiency: Fraction of ``bandwidth`` actually achieved.
        compute_efficiency: Fraction of ``peak_flops`` actually achieved.

    Returns:
        ``max(bytes / (bandwidth * mem_efficiency),
        flops / (peak_flops * compute_efficiency))``.
    """
    # Resolve the hardware lazily so callers can model other devices without
    # touching the module-level HW_CONFIG.
    if bandwidth is None:
        bandwidth = HW_CONFIG["bandwidth"]
    if peak_flops is None:
        peak_flops = HW_CONFIG["peak_fp16_FLOPS"]

    t_mem = bytes / (bandwidth * mem_efficiency)
    t_compute = flops / (peak_flops * compute_efficiency)
    return max(t_mem, t_compute)


def forward_latency(k: str, t: str, kv_compressed: bool, kv_caching: bool):
    """Estimate the latency (seconds) of one full-model forward pass of ``k``.

    Every projection and attention product is costed with ``cal_latency``
    (roofline max of memory time and compute time). Per-layer costs are
    summed and multiplied by the layer count.

    Memory traffic counts only two things: fp16 weight matrices (2 bytes per
    element) and, if enabled, KV-cache reads/writes. Activations, norms,
    embeddings and the LM head are not modelled. Attention FLOPs assume every
    query attends to the whole context, with no causal-mask saving.

    Args:
        k: Model key into ``EXP_CONFIG`` / ``MODEL_CONFIG``.
        t: Phase, ``"prefill"`` (``S = T`` new tokens) or ``"decode"``
            (``S = 1`` new token).
        kv_compressed: If True, the KV cache holds only the ``C``-token
            budget. Prefill writes ``C`` tokens instead of ``S``. Decode reads
            and attends over ``C`` cached tokens instead of the ``T``-token
            prompt.
        kv_caching: If True, add the KV-cache read/write time to the
            attention cost.

    Returns:
        Estimated latency in seconds for a batch of ``B`` sequences.

    Raises:
        NotImplementedError: ``t`` is not ``"prefill"`` or ``"decode"``.
        KeyError: ``k`` is not a configured model.
    """
    # Validate before any config lookup; otherwise an unknown phase surfaces
    # as a KeyError from MODEL_CONFIG["S"][k][t].
    if t not in ("prefill", "decode"):
        raise NotImplementedError(f"unsupported phase {t!r}; expected 'prefill' or 'decode'")

    batch = EXP_CONFIG["B"][k]
    prompt_len = EXP_CONFIG["T"][k]
    budget = EXP_CONFIG["C"][k]
    seq = MODEL_CONFIG["S"][k][t]  # new tokens processed in this pass
    hidden = MODEL_CONFIG["D"][k]
    n_heads = MODEL_CONFIG["H"][k]
    n_kv_heads = MODEL_CONFIG["K"][k]
    head_dim = MODEL_CONFIG["E"][k]
    ffn_dim = MODEL_CONFIG["V"][k]

    # Q/K/V projections: [batch*seq, hidden] x [hidden, heads*head_dim].
    # Memory traffic is the fp16 weight matrix.
    q_lat = cal_latency(batch * n_heads * seq * hidden * head_dim * 2,
                        hidden * head_dim * n_heads * 2)
    k_lat = cal_latency(batch * n_kv_heads * seq * hidden * head_dim * 2,
                        hidden * head_dim * n_kv_heads * 2)
    v_lat = cal_latency(batch * n_kv_heads * seq * hidden * head_dim * 2,
                        hidden * head_dim * n_kv_heads * 2)
    qkv_lat = q_lat + k_lat + v_lat

    # fp16 bytes of one token's K *and* V across the batch and all KV heads.
    kv_bytes_per_token = batch * n_kv_heads * head_dim * 2 * 2
    if t == "prefill":
        # Write the new tokens' KV (only the retained budget if compressed).
        # Attention itself still runs over the whole prompt.
        kv_io = kv_bytes_per_token * (budget if kv_compressed else seq)
        attn_ctx = prompt_len
    else:
        # Read the cached context, then append the new token(s). A compressed
        # cache holds only `budget` tokens, so attention spans just those,
        # matching the bytes read.
        cached = budget if kv_compressed else prompt_len
        kv_io = kv_bytes_per_token * cached + kv_bytes_per_token * seq
        attn_ctx = cached
    kv_fetch_lat = cal_latency(0, kv_io)

    # QK^T and AV are costed as pure compute (operands are not counted as
    # memory traffic).
    qk_lat = cal_latency(batch * n_heads * attn_ctx * head_dim * seq * 2, 0)
    av_lat = cal_latency(batch * n_heads * attn_ctx * head_dim * seq * 2, 0)
    out_lat = cal_latency(batch * seq * (n_heads * head_dim) * hidden * 2,
                          hidden * head_dim * n_heads * 2)

    attn_lat = qk_lat + av_lat + out_lat
    if kv_caching:
        attn_lat = attn_lat + kv_fetch_lat

    # Gated MLP: gate and up map hidden -> ffn_dim, down maps ffn_dim -> hidden.
    # All three have the same FLOP and weight-byte counts.
    mlp_flops = batch * seq * hidden * ffn_dim * 2
    mlp_bytes = hidden * ffn_dim * 2
    gate_lat = cal_latency(mlp_flops, mlp_bytes)
    up_lat = cal_latency(mlp_flops, mlp_bytes)
    down_lat = cal_latency(mlp_flops, mlp_bytes)
    ffn_lat = gate_lat + up_lat + down_lat

    layer_lat = qkv_lat + attn_lat + ffn_lat
    return layer_lat * MODEL_CONFIG["L"][k]

def snapkv_ttft(k: str):
    """Estimate SnapKV time-to-first-token (seconds) for model ``k``.

    SnapKV is costed as one prefill over the ``T``-token prompt with full
    attention. Its KV-cache write is limited to the ``C`` retained tokens,
    and that write time is included.

    This is exactly ``forward_latency(k, "prefill", kv_compressed=True,
    kv_caching=True)``. Delegating keeps the two cost models from drifting
    apart.

    The cost of scoring and selecting which tokens to keep is not modelled.

    Args:
        k: Model key into ``EXP_CONFIG`` / ``MODEL_CONFIG``.

    Returns:
        Estimated TTFT in seconds for a batch of ``B`` sequences.
    """
    return forward_latency(k=k, t="prefill", kv_compressed=True, kv_caching=True)


def lookaheadkv_ttft(k: str, t: str):
    """Estimate LookaheadKV time-to-first-token (seconds) for model ``k``.

    The ``S`` tokens of phase ``t`` are processed together with ``W`` extra
    lookahead tokens. Each linear layer has two costs:

    * the base fp16 projection over all ``S + W`` tokens;
    * a rank-``R`` LoRA branch over only the ``W`` lookahead tokens, costed
      as a separate kernel that loads its A (``d_in x R``) and B
      (``R x d_out``) factors.

    In attention, the ``S`` prompt queries attend over ``T`` keys, and the
    ``W`` lookahead queries attend over ``T + W`` keys. KV-cache traffic is
    sized by the compressed budget ``C``.

    NOTE(review): for ``t == "decode"`` the lookahead terms are still added,
    and attention still spans ``T`` keys although only ``C`` KV tokens are
    read. Confirm the decode path is meant to be used.

    Args:
        k: Model key into ``EXP_CONFIG`` / ``MODEL_CONFIG``.
        t: Phase, ``"prefill"`` or ``"decode"``.

    Returns:
        Estimated latency in seconds for a batch of ``B`` sequences.

    Raises:
        NotImplementedError: ``t`` is not ``"prefill"`` or ``"decode"``.
    """
    # Validate before any config lookup; otherwise an unknown phase surfaces
    # as a KeyError from MODEL_CONFIG["S"][k][t].
    if t not in ("prefill", "decode"):
        raise NotImplementedError(f"unsupported phase {t!r}; expected 'prefill' or 'decode'")

    batch = EXP_CONFIG["B"][k]
    prompt_len = EXP_CONFIG["T"][k]
    budget = EXP_CONFIG["C"][k]
    window = EXP_CONFIG["W"][k]
    seq = MODEL_CONFIG["S"][k][t]
    hidden = MODEL_CONFIG["D"][k]
    n_heads = MODEL_CONFIG["H"][k]
    n_kv_heads = MODEL_CONFIG["K"][k]
    head_dim = MODEL_CONFIG["E"][k]
    ffn_dim = MODEL_CONFIG["V"][k]
    rank = MODEL_CONFIG["R"][k]

    def adapted_linear_lat(d_in, d_out):
        # Base projection over prompt + lookahead tokens (fp16 d_in x d_out weights).
        base = cal_latency(batch * (seq + window) * d_in * d_out * 2, d_in * d_out * 2)
        # LoRA branch on the lookahead tokens only: A is d_in x R, B is R x d_out.
        lora = cal_latency(batch * window * rank * (d_in + d_out) * 2,
                           (d_in * rank + rank * d_out) * 2)
        return base + lora

    q_lat = adapted_linear_lat(hidden, n_heads * head_dim)
    k_lat = adapted_linear_lat(hidden, n_kv_heads * head_dim)
    v_lat = adapted_linear_lat(hidden, n_kv_heads * head_dim)
    qkv_lat = q_lat + k_lat + v_lat

    # fp16 bytes of one token's K *and* V across the batch and all KV heads.
    kv_bytes_per_token = batch * n_kv_heads * head_dim * 2 * 2
    if t == "prefill":
        # Write only the C retained tokens.
        kv_fetch = kv_bytes_per_token * budget
    else:
        # Read the C-token cache, append the new token(s).
        kv_fetch = kv_bytes_per_token * budget + kv_bytes_per_token * seq
    kv_fetch_lat = cal_latency(0, kv_fetch)

    attn_flops = (batch * n_heads * prompt_len * head_dim * seq * 2
                  + batch * n_heads * (prompt_len + window) * head_dim * window * 2)
    qk_lat = cal_latency(attn_flops, 0)
    av_lat = cal_latency(attn_flops, 0)
    out_lat = adapted_linear_lat(n_heads * head_dim, hidden)
    attn_lat = qk_lat + av_lat + out_lat + kv_fetch_lat

    # Gated MLP. The down projection maps ffn_dim -> hidden, so its LoRA A
    # factor is ffn_dim x R and its B factor is R x hidden.
    gate_lat = adapted_linear_lat(hidden, ffn_dim)
    up_lat = adapted_linear_lat(hidden, ffn_dim)
    down_lat = adapted_linear_lat(ffn_dim, hidden)
    ffn_lat = gate_lat + up_lat + down_lat

    layer_lat = qkv_lat + attn_lat + ffn_lat
    return layer_lat * MODEL_CONFIG["L"][k]


def laq_ttft(k: str):
    """Estimate LAQ time-to-first-token (seconds) for model ``k``.

    The estimate is the sum of three stages:

    1. First-round eviction: a full prefill over the ``T``-token prompt that
       writes the *full* KV cache, which the third stage re-reads.
    2. Low-cost generation: ``W`` decode steps against the ``C``-token
       compressed cache.
    3. Second-round eviction: the ``W`` generated tokens are re-run through
       every layer as pseudo queries. They attend over ``T + W`` keys; the
       full prompt KV is read and the ``C`` retained tokens are written back.

    Args:
        k: Model key into ``EXP_CONFIG`` / ``MODEL_CONFIG``.

    Returns:
        Estimated TTFT in seconds for a batch of ``B`` sequences.
    """
    # 1) First round: identical to an uncompressed prefill with KV traffic.
    # NOTE(review): the C-token cache used by stage 2 is never separately
    # written here -- confirm that omission is intended.
    first_round_lat = forward_latency(k=k, t="prefill", kv_compressed=False, kv_caching=True)

    # 2) Low-cost generation with the compressed cache.
    generation_lat = forward_latency(k=k, t="decode", kv_compressed=True, kv_caching=True) * EXP_CONFIG["W"][k]

    # 3) Second round: forward pass of the W pseudo-query tokens only.
    batch = EXP_CONFIG["B"][k]
    prompt_len = EXP_CONFIG["T"][k]
    budget = EXP_CONFIG["C"][k]
    window = EXP_CONFIG["W"][k]
    full_len = MODEL_CONFIG["S"][k]["prefill"]
    hidden = MODEL_CONFIG["D"][k]
    n_heads = MODEL_CONFIG["H"][k]
    n_kv_heads = MODEL_CONFIG["K"][k]
    head_dim = MODEL_CONFIG["E"][k]
    ffn_dim = MODEL_CONFIG["V"][k]

    q_lat = cal_latency(batch * n_heads * window * hidden * head_dim * 2,
                        hidden * head_dim * n_heads * 2)
    k_lat = cal_latency(batch * n_kv_heads * window * hidden * head_dim * 2,
                        hidden * head_dim * n_kv_heads * 2)
    v_lat = cal_latency(batch * n_kv_heads * window * hidden * head_dim * 2,
                        hidden * head_dim * n_kv_heads * 2)
    qkv_lat = q_lat + k_lat + v_lat

    # fp16 bytes of one token's K *and* V across the batch and all KV heads.
    kv_bytes_per_token = batch * n_kv_heads * head_dim * 2 * 2
    # Read the full prompt KV, write back only the C retained tokens.
    kv_fetch_lat = cal_latency(0, kv_bytes_per_token * full_len + kv_bytes_per_token * budget)

    pseudo_attn_flops = batch * n_heads * (prompt_len + window) * head_dim * window * 2
    qk_lat = cal_latency(pseudo_attn_flops, 0)
    av_lat = cal_latency(pseudo_attn_flops, 0)
    out_lat = cal_latency(batch * window * (n_heads * head_dim) * hidden * 2,
                          hidden * head_dim * n_heads * 2)
    attn_lat = qk_lat + av_lat + out_lat + kv_fetch_lat

    # Gate, up and down projections share FLOP and weight-byte counts.
    mlp_flops = batch * window * hidden * ffn_dim * 2
    mlp_bytes = hidden * ffn_dim * 2
    ffn_lat = (cal_latency(mlp_flops, mlp_bytes)
               + cal_latency(mlp_flops, mlp_bytes)
               + cal_latency(mlp_flops, mlp_bytes))

    second_round_lat = (qkv_lat + attn_lat + ffn_lat) * MODEL_CONFIG["L"][k]

    return first_round_lat + generation_lat + second_round_lat

def speckv_ttft(k: str, *, draft: str = "llama3.2-1b-instruct"):
    """Estimate SpecKV time-to-first-token (seconds) for target model ``k``.

    The estimate is the sum of three stages:

    1. The draft model prefills its prompt, writing its full KV cache.
    2. The draft model decodes ``W`` tokens (``W`` taken from the target's
       ``EXP_CONFIG`` entry), with full KV-cache traffic.
    3. The target model prefills its ``S``-token prompt plus the ``W`` draft
       tokens. In attention, the prompt queries attend over ``T`` keys and
       the draft queries attend over ``T + W`` keys.

    Args:
        k: Target model key into ``EXP_CONFIG`` / ``MODEL_CONFIG``.
        draft: Draft model key; must also be present in both configs.

    Returns:
        Estimated TTFT in seconds for a batch of ``B`` sequences.
    """
    # Draft prefill + generation (uncompressed KV, traffic included).
    draft_prefill_lat = forward_latency(k=draft, t="prefill", kv_compressed=False, kv_caching=True)
    draft_decode_lat = forward_latency(k=draft, t="decode", kv_compressed=False, kv_caching=True) * EXP_CONFIG["W"][k]

    # Target prefill over prompt + draft tokens.
    batch = EXP_CONFIG["B"][k]
    prompt_len = EXP_CONFIG["T"][k]
    window = EXP_CONFIG["W"][k]
    seq = MODEL_CONFIG["S"][k]["prefill"]
    hidden = MODEL_CONFIG["D"][k]
    n_heads = MODEL_CONFIG["H"][k]
    n_kv_heads = MODEL_CONFIG["K"][k]
    head_dim = MODEL_CONFIG["E"][k]
    ffn_dim = MODEL_CONFIG["V"][k]

    def dense_lat(d_in, d_out):
        # Projection over the prompt and draft tokens with fp16 d_in x d_out weights.
        return cal_latency(batch * (seq + window) * d_in * d_out * 2, d_in * d_out * 2)

    q_lat = dense_lat(hidden, n_heads * head_dim)
    k_lat = dense_lat(hidden, n_kv_heads * head_dim)
    v_lat = dense_lat(hidden, n_kv_heads * head_dim)
    qkv_lat = q_lat + k_lat + v_lat

    # NOTE(review): this KV write is sized by W (the draft length), whereas
    # snapkv_ttft / lookaheadkv_ttft / laq_ttft write the C-token budget --
    # confirm which is intended.
    kv_fetch_lat = cal_latency(0, batch * n_kv_heads * head_dim * 2 * window * 2)

    attn_flops = (batch * n_heads * prompt_len * head_dim * seq * 2
                  + batch * n_heads * (prompt_len + window) * head_dim * window * 2)
    qk_lat = cal_latency(attn_flops, 0)
    av_lat = cal_latency(attn_flops, 0)
    out_lat = dense_lat(n_heads * head_dim, hidden)
    attn_lat = qk_lat + av_lat + out_lat + kv_fetch_lat

    # Gated MLP: gate and up are hidden -> ffn_dim, down is ffn_dim -> hidden.
    ffn_lat = dense_lat(hidden, ffn_dim) + dense_lat(hidden, ffn_dim) + dense_lat(ffn_dim, hidden)

    target_prefill_lat = (qkv_lat + attn_lat + ffn_lat) * MODEL_CONFIG["L"][k]

    return draft_prefill_lat + draft_decode_lat + target_prefill_lat


if __name__ == "__main__":
    # Report each method's estimated TTFT for the 8B target. Then report its
    # overhead relative to a plain prefill pass without KV-cache traffic.
    target = "llama3.1-8b-instruct"

    baseline_ms = forward_latency(k=target, t="prefill", kv_compressed=False, kv_caching=False) * 1000
    print(f"Prefill forward pass 8B: {baseline_ms:,} ms")

    # (label, estimator, keyword arguments); printed in this order.
    estimators = (
        ("SnapKV", snapkv_ttft, {"k": target}),
        ("LookaheadKV", lookaheadkv_ttft, {"k": target, "t": "prefill"}),
        ("LAQ", laq_ttft, {"k": target}),
        ("SpecKV", speckv_ttft, {"k": target}),
    )

    ttft_ms = {}
    for label, estimate, kwargs in estimators:
        ttft_ms[label] = estimate(**kwargs) * 1000
        print(f"{label} TTFT 8B: {ttft_ms[label]:,} ms")

    print()
    for label, value in ttft_ms.items():
        overhead = value - baseline_ms
        print(f"{label} 8B additional overhead: {overhead} ms / +{overhead / baseline_ms * 100:,}%")
