import torch
import argparse
import torchaudio
from transformers import BitsAndBytesConfig
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
import os

# 关键：导入 TWNM 模型定义
from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig

def preprocess_audio(wav_path: str, target_sr: int = 44100, max_length_seconds: int = 20) -> torch.Tensor:
    """Load an audio file and normalise it to a fixed-length two-channel tensor.

    Steps, in order: load with ``torchaudio.load``, resample to ``target_sr``
    when the file's rate differs, force two channels, then truncate or
    right-pad with zeros along dimension 1.

    Args:
        wav_path: Path of the audio file to load.
        target_sr: Sample rate the waveform is resampled to.
        max_length_seconds: Output duration; the output has
            ``target_sr * max_length_seconds`` samples along dimension 1.

    Returns:
        The processed waveform. Dimension 0 is treated as the channel axis
        and dimension 1 as the sample axis.
    """
    audio, source_sr = torchaudio.load(wav_path)

    # Resample only when the file is not already at the requested rate.
    if source_sr != target_sr:
        audio = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)(audio)

    # Mono is duplicated into two channels; anything beyond two channels is dropped.
    channel_count = audio.shape[0]
    if channel_count == 1:
        audio = audio.repeat(2, 1)
    elif channel_count > 2:
        audio = audio[:2, :]

    # Fix the length: cut long clips, zero-pad short ones on the right.
    wanted_samples = target_sr * max_length_seconds
    sample_count = audio.shape[1]
    if sample_count > wanted_samples:
        return audio[:, :wanted_samples]
    return torch.nn.functional.pad(audio, (0, wanted_samples - sample_count))

def _generate_answer(model, tokenizer, input_ids, audio_tensor):
    """Sample one answer from ``model`` for the given prompt/audio and decode it to text."""
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            audio=audio_tensor,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
        )
    # NOTE(review): no prompt stripping is done here; whether the decoded text
    # still contains the question depends on what TWNM.generate returns — confirm.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


def main():
    """Answer a fixed list of questions about one audio file with two models.

    For every question the script generates one answer with the GRPO "policy"
    LoRA adapter active and one with the adapter disabled (the merged SFT2
    base model), and writes both answers to ``--output_file``.

    Command-line arguments:
        --wav_path: Path to the input WAV audio file (required).
        --output_file: Path of the text file that receives the results
            (default ``inference_results.txt``).
    """
    # 1. Command-line arguments.
    parser = argparse.ArgumentParser(description="Batch inference script for comparing SFT and GRPO models.")
    parser.add_argument(
        "--wav_path",
        type=str,
        required=True,
        help="Path to the input WAV audio file."
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="inference_results.txt",
        help="Path to save the inference results."
    )
    args = parser.parse_args()

    # Questions asked about the audio; add or edit entries here.
    question_list = [
        "音频中有哪些乐器？它们分别大概在什么位置？",
        "除了乐队的演奏，还能听到其他类型的声音吗？",
        "能分辨出有几个吉他声源吗？它们分别大概在什么位置？",
        "你听到了几种打击乐器？它们分别大概在什么位置？",
        "根据混响和声音的反射特性，你认为这是一个小录音室、一个大型音乐厅，还是一个露天体育场？",
        "这段音频听起来是现场录音还是录音室版本？",
        "吉他弹奏的和弦进行大概是怎样的？",
        "哪两件乐器在空间上听起来离得最近？",
        "请根据你听到的所有声音，描述一下整个乐队在舞台上的空间布局。",
        "总结一下观众和乐队之间的空间关系。",
        "为什么观众的掌声听起来如此宽广和具有包围感？",
        "如果把主唱的位置从中间移动到最右侧，听感上会有什么不同？",
        "如果这场表演是在一个吸音很好的小房间里，而不是现在这个开阔的空间，声音最大的区别会是什么？",
        "描述一下你从这段音乐中听到了什么？",
        "相比于主唱，吉他的位置是更靠哪边？"
    ]

    # 2. Device and 4-bit quantization settings.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- Using device: {device} ---")

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # 3. Build the model, load the SFT2 weights, then attach the GRPO adapter.
    grpo_lora_config = LoraConfig(
        target_modules=["q_proj", "v_proj"], task_type=TaskType.CAUSAL_LM,
        r=8, lora_alpha=32, lora_dropout=0.1,
    )

    config = TWNMConfig()
    model = TWNM(config=config, peft_config=grpo_lora_config, quantization_config=quantization_config)

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a trusted source.
    state_dict = torch.load("assets/checkpoints/sft2_checkpoint-2502/pytorch_model.bin", map_location="cpu")

    # Drop the 'module.' prefix from parameter names that carry it.
    new_state_dict = {
        (key[7:] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict, strict=False)
    # Fold the SFT2 LoRA weights into the decoder so it can act as the base model.
    model.decoder = model.decoder.merge_and_unload()
    print("--- Base SFT2 model loaded successfully ---")

    model.decoder = PeftModel.from_pretrained(
        model.decoder,
        "assets/checkpoints/grpo_checkpoint-1/policy/policy",
        adapter_name="policy"
    )
    print("--- GRPO policy adapter loaded successfully ---")

    model.to(device)
    model.eval()

    # 4. Preprocess the audio once; it is reused for every question.
    audio_tensor = preprocess_audio(args.wav_path).unsqueeze(0).to(device)
    tokenizer = model.tokenizer

    # Make sure the output directory exists (dirname is empty for a bare file name).
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # 5. Answer every question with both models and save the results.
    with open(args.output_file, 'w', encoding='utf-8') as f:
        print(f"\n--- Starting batch inference. Results will be saved to {args.output_file} ---")

        for i, question in enumerate(question_list):
            print(f"\n[{i+1}/{len(question_list)}] Processing question: '{question}'")

            # Tokenize the current question.
            input_ids = tokenizer(question, return_tensors="pt").input_ids.to(device)

            # 5.1 GRPO: generate with the "policy" adapter active.
            model.decoder.set_adapter("policy")
            print("   > Generating with GRPO model...")
            decoded_output_grpo = _generate_answer(model, tokenizer, input_ids, audio_tensor)

            # 5.2 SFT2: generate with the adapter disabled. disable_adapter()
            # is a context manager; calling it as a plain statement has no
            # effect, so the generation must run inside the `with` block.
            with model.decoder.disable_adapter():
                print("   > Generating with SFT2 model...")
                decoded_output_sft2 = _generate_answer(model, tokenizer, input_ids, audio_tensor)

            # 5.3 Write both answers to the results file.
            f.write(f"--- Question {i+1} ---\n")
            f.write(f"Q: {question}\n\n")
            f.write(f"A (GRPO Model):\n{decoded_output_grpo}\n\n")
            f.write(f"A (SFT2 Model):\n{decoded_output_sft2}\n\n")
            f.write("="*50 + "\n\n")
            # Flush so finished answers survive an interrupted run.
            f.flush()
            print("   > Results for this question have been saved.")

    print(f"\n--- All questions processed. Final results are in {args.output_file} ---")


# Run the batch comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

"""
CUDA_VISIBLE_DEVICES=4 python inference_grpo_list.py     --wav_path "<PATH_TO_TWNM>/cal_hotel_audio/cut6.wav"     --output_file "comparison_results_cut6.txt"
"""