import torch
from transformers import Qwen2AudioForConditionalGeneration, AutoTokenizer
import os

# --- 配置 ---
# 原始的Qwen2-Audio模型ID
source_model_id = "Qwen/Qwen2-Audio-7B"

# 你希望将提取出的纯LLM模型保存到的本地路径
output_dir = "./qwen2-audio-llm-extracted"

print(f"准备从 '{source_model_id}' 中提取LLM...")
print(f"模型将被保存到: '{output_dir}'")

# --- 加载完整的Qwen2-Audio模型 ---
# 提示：这一步会比较慢，且需要较多内存/显存，但我们只需要执行这一次。
print("正在加载完整的Qwen2-Audio模型，请耐心等待...")
full_model = Qwen2AudioForConditionalGeneration.from_pretrained(
    source_model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto" # 自动分配到GPU/CPU
)

# --- 提取LLM组件和Tokenizer ---
# LLM组件通常存储在 .language_model 属性中
llm_part = full_model.language_model

# 加载对应的Tokenizer
tokenizer = AutoTokenizer.from_pretrained(source_model_id, trust_remote_code=True)

# --- 保存提取出的LLM和Tokenizer ---
print("\n提取完成，正在将纯LLM模型和Tokenizer保存到本地...")

# 使用 .save_pretrained() 方法，它会自动保存正确的配置文件(config.json)和权重
llm_part.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"✅ 成功！纯LLM模型已保存至 '{output_dir}' 文件夹。")
print("现在你可以在其他脚本中直接从这个本地路径加载模型了。")

# 释放内存
del full_model
torch.cuda.empty_cache()