# -*- coding: utf-8 -*-
import argparse
import yaml
import torch
import torchaudio
from twnm.models.twnm_sft2 import TWNM
import transformers
import logging

def preprocess_audio(wav_path: str, target_sr: int = 16000, max_length_seconds: float = 5) -> torch.Tensor:
    """
    Load an audio file and bring it to the fixed shape the model expects.

    Steps: resample to ``target_sr``, force exactly two channels, then truncate
    or zero-pad at the end to ``max_length_seconds`` worth of samples.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.
        target_sr: Target sample rate in Hz.
        max_length_seconds: Output duration in seconds. May be fractional
            (e.g. a float read from a YAML config); the per-channel sample
            count is ``int(target_sr * max_length_seconds)``.

    Returns:
        Tensor of shape ``(2, int(target_sr * max_length_seconds))``
        (channels-first, as returned by ``torchaudio.load``). Mono input is
        duplicated to two channels; input with more than two channels keeps
        only the first two.

    Raises:
        ValueError: If ``target_sr * max_length_seconds`` is less than one sample.
    """
    # Compute and validate the target length before touching the file. The int()
    # cast matters: a fractional duration would otherwise give a float length,
    # which tensor slicing and F.pad both reject. A negative length would make
    # slicing/padding silently crop instead of failing.
    max_length = int(target_sr * max_length_seconds)
    if max_length <= 0:
        raise ValueError(
            f"max_length_seconds={max_length_seconds} at target_sr={target_sr} "
            "gives a non-positive number of samples."
        )

    waveform, original_sr = torchaudio.load(wav_path)

    # 1. Resample to the target sample rate.
    if original_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
        waveform = resampler(waveform)

    # 2. Force exactly two channels: duplicate mono, keep only the first two otherwise.
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    elif waveform.shape[0] > 2:
        waveform = waveform[:2, :]

    # 3. Truncate, or right-pad with zeros, to exactly max_length samples.
    num_samples = waveform.shape[1]
    if num_samples > max_length:
        waveform = waveform[:, :max_length]
    elif num_samples < max_length:
        waveform = torch.nn.functional.pad(waveform, (0, max_length - num_samples))

    return waveform

def _strip_module_prefix(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` whose keys have a leading ``'module.'`` removed.

    Checkpoints saved from a model wrapped in DataParallel/DDP prefix every key
    with ``'module.'``; the unwrapped model expects the bare names. Keys without
    the prefix are kept unchanged.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_model(config: dict, checkpoint_path: str, device: str) -> "TWNM":
    """Build a TWNM model in inference mode and load its weights from ``checkpoint_path``.

    Weights are loaded with ``strict=False``. Any missing or unexpected keys are
    printed, so that a mismatched checkpoint is visible. Otherwise it would
    silently leave parameters at their initial values.
    """
    model = TWNM(config, is_inference=True)

    print(f"Loading checkpoint from {checkpoint_path}...")
    # Load onto the CPU first so deserialising does not exhaust GPU memory.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    load_result = model.load_state_dict(_strip_module_prefix(state_dict), strict=False)

    # nn.Module.load_state_dict returns a named tuple (missing_keys, unexpected_keys).
    # getattr keeps this safe in case TWNM overrides load_state_dict with another return type.
    for label in ("missing_keys", "unexpected_keys"):
        keys = getattr(load_result, label, None) or []
        if keys:
            preview = ", ".join(keys[:10])
            more = f" (+{len(keys) - 10} more)" if len(keys) > 10 else ""
            print(f"Warning: {len(keys)} {label.replace('_', ' ')} when loading checkpoint: {preview}{more}")

    model.to(device)
    model.eval()  # Evaluation mode for inference (disables dropout etc.).
    print("Model loaded successfully.")
    return model


def main():
    """Command-line entry point: run TWNM inference on one audio file with one text prompt.

    Parses arguments, loads the YAML config, builds the model and loads the
    checkpoint, preprocesses the audio, calls ``model.generate`` and prints the
    first generated text.
    """
    transformers.logging.set_verbosity(logging.INFO)

    parser = argparse.ArgumentParser(description="Test script for the TWNM model.")
    parser.add_argument("--config_path", type=str, required=True, help="Path to the training config YAML file.")
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint file (e.g., pytorch_model.bin).")
    parser.add_argument("--wav_path", type=str, required=True, help="Path to the input WAV audio file.")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for the model.")
    args = parser.parse_args()

    # --- 1. Select device ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # --- 2. Load config ---
    print("Loading config and initializing model...")
    # Explicit UTF-8: the config may contain non-ASCII text and the platform default may differ.
    with open(args.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail with a clear message instead of a later AttributeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {args.config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}."
        )

    # --- 3. Build model and load weights ---
    model = _load_model(config, args.checkpoint_path, device)

    # --- 4. Preprocess audio ---
    print(f"Preprocessing audio file: {args.wav_path}...")
    # A bare `dataset_conf:` key parses to None, hence `or {}` instead of a .get() default.
    dataset_conf = config.get("dataset_conf") or {}
    sample_rate = dataset_conf.get("sample_rate", 16000)
    max_len = dataset_conf.get("max_len", 30)

    audio_tensor = preprocess_audio(args.wav_path, target_sr=sample_rate, max_length_seconds=max_len)

    # Add a batch dimension and move to the selected device. No dtype cast is applied:
    # the waveform keeps the dtype returned by torchaudio.load.
    # NOTE(review): an earlier comment mentioned bfloat16 -- confirm whether TWNM casts internally.
    audio_tensor = audio_tensor.unsqueeze(0).to(device)

    # --- 5. Build model inputs ---
    # Per the original author's note, the model turns 'task' into "<prompt> <AcousticTokens>" internally.
    # NOTE(review): shell arguments keep "\n" as a literal backslash + 'n'; it is not converted to a
    # newline here -- confirm this matches the prompt format used in training.
    samples = {
        "audios": audio_tensor,
        "task": [args.prompt],
    }

    # --- 6. Run inference ---
    print("Running inference...")
    tokenizer = model.tokenizer
    # Debug output: special tokens relevant to generation/termination.
    print(f"Tokenizer eos/pad/bos tokens: {tokenizer.eos_token} {tokenizer.pad_token} {tokenizer.bos_token}")
    print(f"Decoder pad_token_id: {model.decoder.config.pad_token_id}")
    with torch.no_grad():  # Gradients are not needed for inference.
        generated_captions = model.generate(
            samples,
            # use_nucleus_sampling=True,
            num_beams=3,
            max_length=1024,  # Maximum generation length; adjust as needed.
            min_length=2,
            repetition_penalty=1.1,
            # eos_token_id=151643
        )

    # --- 7. Print results ---
    print("\n" + "="*30)
    print("Inference Result")
    print("="*30)
    print(f"Audio File: {args.wav_path}")
    print(f"Prompt: {args.prompt}")
    print("-" * 30)
    print(f"Generated Text: {generated_captions[0]}")
    print("="*30)


if __name__ == "__main__":
    # Run inference only when this file is executed as a script, not when it is imported.
    main()

"""
CUDA_VISIBLE_DEVICES=7 python inference.py \
    --config_path <PATH_TO_TWNM>/configs/inference.yaml \
    --checkpoint_path <PATH_TO_TWNM>/exp/SFT2/checkpoint-1251/pytorch_model.bin \
    --wav_path /data2/wl/RL_sample/audio/scene_000018.wav \
    --prompt "请仔细聆听，音频中总共可以分辨出多少个独立的声源？. Please choose the answer from the following options: A: 1个\nB: 2个\nC: 3个\nD: 4个"
"""

"""
{"audio_path": "/data2/wl/SFT/audio/scene_042047.wav", "task_type": "pure_acoustics", "instruction": "这个房间的混响效果是怎样的？", "answer": "整个空间带有中等的混响感，混响时间大约为0.12秒。", "router_label": [0, 0, 1, 0, 0], "metadata": {"audio_path": "/data2/wl/SFT/audio/scene 042047.wav",
"""

"""
|<think>| 问题的核心是判断音频中可以分辨出的独立声源数量。\n\n首先，我仔细分析了整个音频。音频的背景非常干净，混响很低，这使得声音的细节和来源都非常清晰。在这样的环境中，我只听到了一个明显的声音。这个声音听起来像是家庭日常活动中发出的，它的位置非常稳定，一直保持在我的右前方（方位角约65度）。根据音频信息，“在这个环境中，可以听到一个独立的声源”，这直接指明了声源的数量为1。\n\n因此，选项A是正确的。\n\n其他选项是错误的。选项B、C和D分别假设存在2个、3个或4个声源。然而，在整个音频中，我没有感知到任何其他独立的声音事件或声源。声音的来源单一且明确，没有多个声音同时或先后出现。因此，这些选项与我从音频中感知到的信息不符。 |</think>| 通过对音频的仔细分析，我识别出环境中只有一个清晰、位置固定的声源。 |<answer>| A |</answer>|
"""

"""
CUDA_VISIBLE_DEVICES=2 python inference_grpo.py \
    --model_path "assets/checkpoints/sft2_checkpoint-2502" \
    --wav_path "/data2/wl/RL_sample/audio/scene_000018.wav" \
    --prompt "请仔细聆听，音频中总共可以分辨出多少个独立的声源？. Please choose the answer from the following options: A: 1个\nB: 2个\nC: 3个\nD: 4个"
"""
