import argparse
import torch
import torchaudio
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig


def preprocess_audio(wav_path: str, target_sr: int = 44100, max_length_seconds: int = 20) -> torch.Tensor:
    """Load an audio file and normalise it to a fixed-length, two-channel waveform.

    The audio is resampled to ``target_sr`` when its native rate differs,
    brought to two channels (a single-channel track is duplicated, channels
    beyond the second are dropped), and then truncated or right-padded with
    zeros to ``target_sr * max_length_seconds`` samples.

    Args:
        wav_path: Path of the audio file handed to ``torchaudio.load``.
        target_sr: Sample rate, in Hz, of the returned waveform.
        max_length_seconds: Duration of the returned waveform, in seconds.

    Returns:
        The processed waveform, channels on the first dimension and samples
        on the second.
    """
    samples, source_sr = torchaudio.load(wav_path)

    # Resample only when the file is not already at the requested rate.
    if source_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)
        samples = resampler(samples)

    # Force a stereo layout: duplicate mono, keep the first two channels otherwise.
    channel_count = samples.shape[0]
    if channel_count == 1:
        samples = samples.repeat(2, 1)
    elif channel_count > 2:
        samples = samples[:2, :]

    # Crop when too long; otherwise zero-pad on the right (a no-op pad when exact).
    wanted = target_sr * max_length_seconds
    shortfall = wanted - samples.shape[1]
    if shortfall < 0:
        return samples[:, :wanted]
    return torch.nn.functional.pad(samples, (0, shortfall))


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script."""
    parser = argparse.ArgumentParser(description="TWNM GRPO policy inference (SFT merge + policy attach)")
    parser.add_argument("--sft_checkpoint_path", type=str, required=True, help="SFT2 权重路径（未 merge 或已 merge）")
    parser.add_argument("--policy_adapter_path", type=str, required=True, help="GRPO policy LoRA 目录")
    parser.add_argument("--wav_path", type=str, required=True, help="输入 WAV 路径")
    parser.add_argument("--prompt", type=str, required=True, help="文本 prompt")
    parser.add_argument(
        "--sft_is_merged",
        action="store_true",
        help="SFT 权重已 merge 时打开，跳过占位 merge",
    )
    return parser.parse_args()


def _build_lora_config(inference_mode: bool) -> "LoraConfig":
    """Build the LoRA config used for both adapters; only ``inference_mode`` differs."""
    return LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        inference_mode=inference_mode,
    )


def _load_weights(model: torch.nn.Module, state_dict: dict) -> None:
    """Load ``state_dict`` non-strictly and print how many keys did not match."""
    result = model.load_state_dict(state_dict, strict=False)
    # strict=False never raises on key mismatches, so surface them: a checkpoint
    # whose keys do not line up would otherwise be skipped without any notice.
    # getattr keeps this working if the model's load_state_dict returns nothing.
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    print(f"--- Loaded SFT weights: {len(missing)} missing keys, {len(unexpected)} unexpected keys ---")


def main() -> None:
    """Run one inference pass of TWNM with SFT weights and a GRPO policy LoRA.

    Steps:
        1. Build a 4-bit (NF4) quantised TWNM model.
        2. Load the SFT checkpoint. Unless ``--sft_is_merged`` is given, a
           LoRA adapter named ``default`` is attached to the decoder first so
           the checkpoint's adapter weights have somewhere to land, and it is
           merged into the decoder afterwards.
        3. Attach the ``policy`` LoRA adapter, load it from
           ``--policy_adapter_path`` and make it the active adapter.
        4. Generate from the prompt plus the preprocessed audio and print the
           decoded first sequence.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- Using device: {device} ---")

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model = TWNM(config=TWNMConfig(), peft_config=None, quantization_config=quantization_config)

    # NOTE(security): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source (consider weights_only=True if the
    # checkpoint is a plain tensor state dict).
    state_dict = torch.load(args.sft_checkpoint_path, map_location="cpu")

    # Drop a leading "module." from key names (presumably left behind by a
    # DataParallel/DDP wrapper at save time -- confirm against the trainer).
    prefix = "module."
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    if args.sft_is_merged:
        _load_weights(model, cleaned)
    else:
        # Attach a placeholder SFT adapter (default adapter name "default"),
        # load the checkpoint into it, then merge it into the decoder.
        model.decoder = get_peft_model(
            model.decoder, _build_lora_config(inference_mode=True), adapter_name="default"
        )
        _load_weights(model, cleaned)
        model.decoder = model.decoder.merge_and_unload()

    # Attach the policy adapter and load its trained weights.
    model.decoder = get_peft_model(
        model.decoder, _build_lora_config(inference_mode=False), adapter_name="policy"
    )
    model.decoder.load_adapter(args.policy_adapter_path, adapter_name="policy")
    model.decoder.set_adapter("policy")

    model.to(device)
    model.eval()
    tokenizer = model.tokenizer

    # unsqueeze(0) adds a leading batch dimension to the single waveform.
    audio = preprocess_audio(args.wav_path).unsqueeze(0).to(device)
    input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device)

    # Tokenizers without a dedicated pad token report None; fall back to EOS
    # so generation always receives a concrete pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        gen_ids = model.generate(
            input_ids=input_ids,
            audio=audio,
            max_new_tokens=256,
            do_sample=False,
            num_beams=2,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )

    # `add_special_tokens` is an encoding option, not a documented decode one;
    # the decode-side switch is `skip_special_tokens`. False keeps special
    # tokens in the printed text, matching the previous effective behaviour.
    decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=False)
    print("\n" + "=" * 30)
    print("Inference Result")
    print("=" * 30)
    print(decoded[0])
    print("=" * 30 + "\n")

# Script entry point: run the inference CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
