import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import matplotlib.pyplot as plt
import numpy as np
import gc

# ---- Benchmark configuration ----
# Alternative checkpoint (other Llama-2 variants work as well):
# MODEL_NAME = "/remote-home1/yklin/models/vicuna-7b-v1.5"
MODEL_NAME = "/inspire/hdd/project/embodied-multimodality/public/csli/models/Llama-2-7b-chat-hf"
MAX_TOKENS = 300  # max_new_tokens for every timed generate() call
# Other sweeps used previously: [1, 2, 4, 8, 16] and [8, 16].
BATCH_SIZES = [1, 2, 4]
NUM_TRIALS = 3  # runs per batch size; per-batch metrics are averaged over these

# ---- Tokenizer and model (the process needs read access to MODEL_NAME) ----
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Pad with the EOS token; the id is assigned explicitly as well.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
model = model.to('cuda:0')
model.config.use_cache = True  # keep the KV cache on during generation
model.config.torch_dtype = torch.float16
model.eval()
import json

# Load the benchmark data: one JSON object per line. An explicit UTF-8
# encoding keeps decoding independent of the platform's default locale,
# and blank lines (e.g. a trailing newline) are skipped instead of crashing
# json.loads.
with open('data/WaterBench/2-1_longform_qa.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

# Build the single fixed prompt used for every benchmark run.
item = data[10]
context = item['context']
question = item['input']  # named `question` to avoid shadowing builtin input()
template = "You are a helpful assistant, please answer the following question within 300 words:\n{context}\n{input}"
INPUT_TEXT = template.format(context=context, input=question)
print('Input text length: ', len(tokenizer.encode(INPUT_TEXT)))
print('=' * 50)
def _count_generated_tokens(generated, eos_token_id):
    """Return the number of real generated tokens in a (batch, new_tokens) id tensor.

    Each row is counted up to and including its first ``eos_token_id``; rows
    that never emit EOS are counted in full. In a batched ``generate`` call,
    rows that finish early are filled out with ``pad_token_id`` (the EOS id in
    this script), so counting every column would credit that padding as
    generated text and inflate the throughput numbers.
    """
    if generated.numel() == 0:
        return 0
    is_eos = generated == eos_token_id
    has_eos = is_eos.any(dim=1)
    # argmax returns the first maximal index, i.e. the position of the first EOS.
    first_eos = is_eos.to(torch.int64).argmax(dim=1)
    full_len = torch.full_like(first_eos, generated.shape[1])
    lengths = torch.where(has_eos, first_eos + 1, full_len)
    return int(lengths.sum().item())


def benchmark_generation(batch_size):
    """Time one sampled ``model.generate`` call on ``batch_size`` copies of INPUT_TEXT.

    Relies on the module-level ``tokenizer``, ``model``, ``INPUT_TEXT``,
    ``MAX_TOKENS`` and ``BATCH_SIZES``.

    Args:
        batch_size: number of identical prompts generated together.

    Returns:
        ``(elapsed, tokens_per_second, avg_tokens_per_second, peak_mem)``:
        wall-clock seconds of the timed ``generate`` call; generated tokens per
        second summed over the batch; that rate divided by ``batch_size``
        (per-sequence throughput); and peak allocated CUDA memory in MB.
        Generated tokens exclude the EOS padding appended to rows that stop
        before the longest row in the batch.
    """
    # Replicate the same prompt across the batch (identical rows -> no padding).
    input_texts = [INPUT_TEXT] * batch_size
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Warm-up: the first run can be slow, so do a short untimed generate for
    # the first configured batch size.
    if batch_size == BATCH_SIZES[0]:
        _ = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)

    # Begin peak-memory tracking once all previously queued GPU work is done.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            do_sample=True,
            top_k=50,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time
    peak_mem = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB

    # Count only the new tokens each row actually produced (not EOS padding).
    total_tokens = _count_generated_tokens(outputs[:, prompt_len:], tokenizer.eos_token_id)
    avg_tokens = total_tokens / batch_size
    avg_tokens_per_second = avg_tokens / elapsed
    tokens_per_second = total_tokens / elapsed

    # Drop the last references to the GPU tensors so empty_cache can release them.
    del inputs, outputs
    gc.collect()
    torch.cuda.empty_cache()

    return elapsed, tokens_per_second, avg_tokens_per_second, peak_mem  # MB

# Run every batch size NUM_TRIALS times, then aggregate mean / std per metric.
results = {}
for bs in BATCH_SIZES:
    print(f"Testing batch size: {bs}")
    # Per-trial samples, in the order benchmark_generation returns them:
    # elapsed seconds, total tokens/sec, per-sequence tokens/sec, peak MB.
    samples = ([], [], [], [])
    for trial in range(1, NUM_TRIALS + 1):
        print(f"Trial {trial}/{NUM_TRIALS} for batch size {bs}...")
        elapsed, tps_val, seq_tps_val, mem_val = benchmark_generation(bs)
        for bucket, value in zip(samples, (elapsed, tps_val, seq_tps_val, mem_val)):
            bucket.append(value)

    trial_times, trial_tps, trial_seq_tps, trial_peak_mem = samples
    results[bs] = {
        'avg_time': np.mean(trial_times),
        'std_time': np.std(trial_times),
        'avg_tps': np.mean(trial_tps),
        'std_tps': np.std(trial_tps),
        'avg_tokens': np.mean(trial_seq_tps),
        'avg_tokens_std': np.std(trial_seq_tps),
        'peak_mem': np.mean(trial_peak_mem),
        'avg_mem_std': np.std(trial_peak_mem)
    }
    print(results)
    print(f"Batch {bs}: {results[bs]['avg_tps']:.2f} tokens/sec")

# 可视化
# plt.figure(figsize=(12, 6))

# # 吞吐量图表
# plt.subplot(1, 4, 1)
# x = BATCH_SIZES
# y = [results[bs]['avg_tps'] for bs in BATCH_SIZES]
# y_err = [results[bs]['std_tps'] for bs in BATCH_SIZES]
# plt.bar(range(len(x)), y, yerr=y_err, tick_label=x, color='skyblue')
# plt.xlabel('Batch Size')
# plt.ylabel('Tokens per Second')
# plt.title('Throughput vs Batch Size')

# # 延迟图表
# plt.subplot(1, 4, 2)
# y_time = [results[bs]['avg_time'] for bs in BATCH_SIZES]
# y_time_err = [results[bs]['std_time'] for bs in BATCH_SIZES]
# plt.bar(range(len(x)), y_time, yerr=y_time_err, tick_label=x, color='lightgreen')
# plt.xlabel('Batch Size')
# plt.ylabel('Generation Time (s)')
# plt.title('Latency vs Batch Size')

# # 平均生成token数图表
# plt.subplot(1, 4, 3)
# y_tokens = [results[bs]['avg_tokens'] for bs in BATCH_SIZES]
# y_tokens_err = [results[bs]['avg_tokens_std'] for bs in BATCH_SIZES]
# plt.bar(range(len(x)), y_tokens, yerr=y_tokens_err, tick_label=x, color='salmon')
# plt.xlabel('Batch Size')
# plt.ylabel('Average Generated Tokens')
# plt.title('Average Generated Tokens vs Batch Size')

# # 平均显存使用图表
# plt.subplot(1, 4, 4)
# y_mem = [results[bs]['avg_mem'] for bs in BATCH_SIZES]
# y_mem_err = [results[bs]['avg_mem_std'] for bs in BATCH_SIZES]
# plt.bar(range(len(x)), y_mem, yerr=y_mem_err, tick_label=x, color='lightcoral')
# plt.xlabel('Batch Size')
# plt.ylabel('Average Memory Usage (MB)')
# plt.title('Average Memory Usage vs Batch Size')

# plt.tight_layout()
# plt.savefig('yukang/speed_stat/llama2_parallel_stat_3-2.png')
# plt.show()

# Print a summary table of the averaged results.
# Column meanings: Time = mean wall-clock seconds per generate call;
# Tps = tokens/sec summed over the batch; Tps/seq = per-sequence tokens/sec
# (stored under the 'avg_tokens' key); Peak Mem = mean peak allocated CUDA MB.
print("\nBenchmark Results:")
header = (
    f"{'Batch Size':>10} | {'Time (s)':>10} | {'Tps':>10} | "
    f"{'Tps/seq':>10} | {'Peak Mem (MB)':>13}"
)
print(header)
print("-" * len(header))  # separator spans the full row width
for bs in BATCH_SIZES:
    row = results[bs]
    print(
        f"{bs:>10} | {row['avg_time']:>10.2f} | {row['avg_tps']:>10.2f} | "
        f"{row['avg_tokens']:>10.2f} | {row['peak_mem']:>13.2f}"
    )