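# File: save_model.py — load the TWNM base model, apply and merge the SFT LoRA adapter, then save the merged model and its tokenizer.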
import torch
from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 1. Configuration.
#
# Input checkpoint produced by SFT training and the directory the merged
# model is written to.
sft_lora_checkpoint = "<PATH_TO_TWNM>/exp/SFT2/checkpoint-1251/pytorch_model.bin"
output_path = "assets/checkpoints/sft_merged_model_complete"

# Base model quantization: 4-bit NF4 weights, bfloat16 compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# TWNM model configuration: decoder weights location and spatial-encoder checkpoint.
twnm_config = TWNMConfig(
    decoder_model_name="<PATH_TO_TWNM>/assets/checkpoints/qwen2-audio-llm-extracted",
    spatial_encoder_ckpt_path="assets/checkpoints/spatial_encoder/loss=0.4612.ckpt",
)

# 2. Build the base TWNM model, quantized according to `quantization_config`.
print("Loading 4-bit base model...")
model = TWNM(
    config=twnm_config,
    quantization_config=quantization_config,
)

# 3. LoRA hyper-parameters.
# NOTE(review): these must match the LoRA configuration used during SFT
# training; otherwise the checkpoint tensors will not line up with the
# adapter modules created below.
print("Defining LoRA configuration...")
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

# 4. Wrap the decoder in a PeftModel. This adds freshly initialized LoRA
#    adapters on the target modules; the trained weights are loaded next.
print("Applying LoRA structure to the decoder...")
model.decoder = get_peft_model(
    model.decoder,
    lora_config,
)
print("Decoder is now a PeftModel.")

# 5. Load the SFT LoRA weights into the PEFT-wrapped model.
print("Loading SFT LoRA weights...")
lora_weights = torch.load(sft_lora_checkpoint, map_location="cpu")
# strict=False tolerates keys that are absent on either side, but it also hides
# a wrong key prefix. In that case no LoRA tensor would be loaded and the merge
# below would silently produce the base model. Inspect the result explicitly.
load_result = model.load_state_dict(lora_weights, strict=False)
missing_lora_keys = [key for key in load_result.missing_keys if "lora_" in key]
if missing_lora_keys:
    raise RuntimeError(
        f"{len(missing_lora_keys)} LoRA parameters were not found in "
        f"{sft_lora_checkpoint} (e.g. {missing_lora_keys[0]!r}); "
        "check that the checkpoint key prefixes match the model."
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint keys did not "
        f"match any model parameter (e.g. {load_result.unexpected_keys[0]!r})."
    )
# The values have been copied into the model; release the CPU copy before merging.
del lora_weights
print("LoRA weights loaded successfully.")

# 6. Merge the LoRA deltas into the base weights and drop the PEFT wrappers.
# NOTE: the base decoder was loaded in 4-bit. For bitsandbytes 4-bit layers PEFT
# dequantizes each weight, adds the LoRA delta and re-quantizes the result. The
# merged decoder therefore stays 4-bit and may differ slightly from the unmerged
# model because of rounding.
# For a high-precision merged model, load the base model without quantization
# (presumably TWNM without quantization_config — confirm TWNM supports that).
print("Merging LoRA weights into the 4-bit base (merged weights are re-quantized to 4-bit)...")
model.decoder = model.decoder.merge_and_unload()
print("Merge complete. Decoder is now a plain (non-PEFT) model with 4-bit base layers.")

# 7. Persist the complete merged model, then its tokenizer, to `output_path`.
print(f"Saving the COMPLETE merged model to {output_path}...")

# safe_serialization=False makes save_pretrained write a PyTorch pickle
# checkpoint instead of safetensors.
# NOTE(review): presumably required because safetensors refuses some tensor in
# this model (e.g. shared/tied weights) — confirm and record the actual error.
model.save_pretrained(
    output_path,
    safe_serialization=False,
)

# Write the tokenizer into the same directory as the model weights.
model.tokenizer.save_pretrained(output_path)

print("Complete merged model saved successfully.")