import torch
import torch.optim as optim
import time
import os
from transformers import BitsAndBytesConfig, AutoTokenizer, GenerationConfig
from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig

# --- Configuration ---
DECODER_MODEL_PATH = "<PATH_TO_TWNM>/assets/checkpoints/qwen2-audio-llm-extracted"
SFT_LORA_CKPT_PATH = "<PATH_TO_TWNM>/exp/SFT2/checkpoint-1251/pytorch_model.bin"
# --- Parameters mirroring GRPOConfig ---
NUM_GENERATIONS = 1 # keep in sync with --num_generations in your grpo.sh
MAX_COMPLETION_LENGTH = 50 # simulated completion length (used as max_new_tokens during generation)
# --- End of configuration ---

def print_peak_memory(prefix: str):
    """Print the peak CUDA memory reserved so far, expressed in MB.

    Args:
        prefix: Label written at the start of the printed line.
    """
    bytes_per_mb = 1 << 20
    # torch.cuda.max_memory_reserved() reports bytes; convert before printing.
    reserved_mb = torch.cuda.max_memory_reserved() / bytes_per_mb
    print("{}: Peak VRAM Reserved: {:.2f} MB".format(prefix, reserved_mb))

def main():
    """Run one simulated GRPO-style training step and report peak VRAM per stage.

    Stages, each followed by a ``print_peak_memory`` call:

    1. Build the tokenizer, a policy ``TWNM`` and a reference ``TWNM`` (both
       constructed with a 4-bit NF4 BitsAndBytes config), load the SFT
       checkpoint into both, and create an AdamW optimizer over the policy's
       trainable parameters.
    2. Encode one random audio tensor once under ``torch.no_grad()``.
    3. Repeat the cached features ``NUM_GENERATIONS`` times and sample
       completions with ``model.generate``.
    4. Run a policy forward pass (with grad) and a reference forward pass
       (without grad) on the sampled completions.
    5. Back-propagate the policy loss and take one optimizer step.

    Only the policy loss is back-propagated; the reference outputs are computed
    for their memory footprint and are not combined into the loss.

    A CUDA out-of-memory error raised inside the step is caught and reported
    rather than re-raised. The overall peak is always printed at the end.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    print(f"--- 最终测试：模拟 GRPOTrainer 并行生成逻辑 (num_generations={NUM_GENERATIONS}) ---")

    # Fail fast with an explicit message instead of an opaque CUDA
    # initialization error from the first torch.cuda call below.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; this memory test requires a GPU.")

    device = "cuda"
    print(f"检测到 GPU: {torch.cuda.get_device_name(0)}")
    # Start from a clean slate so every reported peak belongs to this run.
    torch.cuda.reset_peak_memory_stats(device)

    try:
        # 1. Initialization.
        tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_PATH, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Generation below needs a pad token id; fall back to EOS.
            tokenizer.pad_token = tokenizer.eos_token

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        twnm_config = TWNMConfig(decoder_model_name=DECODER_MODEL_PATH, spatial_encoder_ckpt_path="none")

        # Policy and reference models are separate instances, so both are
        # resident on the GPU at the same time.
        model = TWNM(config=twnm_config, quantization_config=quantization_config).to(device)
        ref_model = TWNM(config=twnm_config, quantization_config=quantization_config).to(device)

        # NOTE(review): torch.load unpickles the file; only point
        # SFT_LORA_CKPT_PATH at a trusted checkpoint.
        lora_weights = torch.load(SFT_LORA_CKPT_PATH, map_location="cpu")
        # strict=False: key mismatches between checkpoint and model are not
        # raised as errors, so a wrong checkpoint will not fail here.
        model.load_state_dict(lora_weights, strict=False)
        ref_model.load_state_dict(lora_weights, strict=False)

        ref_model.eval()

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable_params, lr=1e-5)
        print("所有组件初始化成功。当前基线显存已建立。")
        print_peak_memory("After Initialization")

        # 2. Build a single synthetic sample.
        # Random tensor of shape (1, 2, 44100 * 5) — presumably 5 s of
        # 2-channel audio at 44.1 kHz; confirm against the encoder's contract.
        audio_tensor = torch.randn(1, 2, 44100 * 5, dtype=torch.float32).to(device)
        task_text = "<|user|>Describe the sound.<|assistant|>"
        prompt = [task_text.replace("<|assistant|>", "<|assistant|> <AcousticTokens>")]

        # 3. "Encode once, use many times".
        print("\n--- 开始高效的 compute_loss 模拟 ---")

        # 3.1 Encode once: the audio encoder runs a single time, without grad,
        # and its output is reused by every later stage.
        print("  - [1/5] 【编码一次】调用 forward_encoder 获取音频特征...")
        with torch.no_grad():
            audio_embeds, _ = model.forward_encoder(audio_tensor, prompt)
            audio_embeds = audio_embeds.to(model.decoder.dtype)
        print(f"    - 音频特征已缓存，形状: {audio_embeds.shape}")
        print_peak_memory("  -> After Audio Encoding")

        # 3.2 Parallel generation: reuse the cached features for
        # NUM_GENERATIONS completions in one batched generate() call.
        print(f"  - [2/5] 【并行生成】准备 {NUM_GENERATIONS} 个并行的输入...")
        task_tokens = tokenizer(task_text, return_tensors="pt").to(device)
        generation_config = GenerationConfig(
            max_new_tokens=MAX_COMPLETION_LENGTH,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

        # Expand the batch dimension of every input NUM_GENERATIONS times.
        input_ids_repeated = task_tokens["input_ids"].repeat_interleave(NUM_GENERATIONS, dim=0)
        attention_mask_repeated = task_tokens["attention_mask"].repeat_interleave(NUM_GENERATIONS, dim=0)
        audio_embeds_repeated = audio_embeds.repeat_interleave(NUM_GENERATIONS, dim=0)
        prompts_repeated = [p for p in prompt for _ in range(NUM_GENERATIONS)]

        print("  - [3/5] 调用 generate 生成多个回复...")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids_repeated,
                attention_mask=attention_mask_repeated,
                encoder_hidden_states=audio_embeds_repeated,
                prompt=prompts_repeated,
                generation_config=generation_config,
            )

        # Assumes generate() returns the prompt tokens followed by the
        # completion, so the prompt is sliced off before decoding — TODO confirm.
        prompt_length = task_tokens["input_ids"].shape[1]
        completions_text = tokenizer.batch_decode(generated_ids[:, prompt_length:], skip_special_tokens=True)
        print(f"    - 成功生成 {len(completions_text)} 条并行回复。")
        print_peak_memory("  -> After Parallel Generation")

        # 3.3 Forward passes over all sampled completions.
        print("  - [4/5] 【并行计算】进行前向传播计算 loss...")

        # One task string per completion, giving a batch of NUM_GENERATIONS.
        tasks_repeated = [task_text] * NUM_GENERATIONS
        policy_model_inputs = {"task": tasks_repeated, "text": completions_text}

        # NOTE(review): .forward() is called directly, which skips any hooks
        # that nn.Module.__call__ would run — confirm this is intended.
        policy_outputs = model.forward(policy_model_inputs, encoder_hidden_states=audio_embeds_repeated)
        policy_loss = policy_outputs["loss"]
        print("    - Policy model 的前向传播完成。")
        print_peak_memory("  -> After Policy Forward")

        # The result stays bound to `ref_outputs` for the rest of the function,
        # so whatever tensors it holds remain allocated during the backward
        # pass; do not drop the binding without re-checking the measurements.
        with torch.no_grad():
            ref_outputs = ref_model.forward(policy_model_inputs, encoder_hidden_states=audio_embeds_repeated)
        print("    - Reference model 的前向传播完成。")
        print_peak_memory("  -> After Reference Forward")

        # 3.4 Backward pass and optimizer step on the policy loss alone.
        print("  - [5/5] 执行反向传播和优化...")
        final_loss = policy_loss

        optimizer.zero_grad()
        final_loss.backward()
        print("    - 反向传播成功！")
        print_peak_memory("  -> After Backward Pass")

        optimizer.step()
        print("    - 优化器更新成功！")
        print_peak_memory("  -> After Optimizer Step")

        print("\n\n--- 最终测试成功 ---")

    except torch.cuda.OutOfMemoryError as e:
        # Running out of memory is an expected outcome of this test: report it
        # together with the peak reached instead of propagating a traceback.
        print("\n--- 最终测试失败: 捕获到 CUDA OutOfMemoryError ---")
        print("错误信息:", e)
        print_peak_memory("Peak memory before OOM")

    finally:
        # Always report the overall peak, whether the step succeeded or not.
        print("\n--- 脚本结束 ---")
        print_peak_memory("Overall Peak Memory")

# Run the memory test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()