from safetensors import safe_open
from safetensors.torch import save_file
import json
import os


# Base model the adapter was trained on; it is also recorded in adapter_config.json.
base_model_path = "Qwen/Qwen-Image"

# Fine-tuned checkpoint: the full saved state dict that contains the LoRA layers.
ft_model = "logs/unifiedreward/qwenimage_2/checkpoint-480/model.safetensors"
if not os.path.isfile(ft_model):
    # Fail before creating the output directory, so a wrong path does not
    # leave an empty 'lora/' folder behind.
    raise FileNotFoundError(f"Fine-tuned checkpoint not found: {ft_model}")

# The adapter is written to a 'lora' subdirectory next to the checkpoint,
# using PEFT's standard file name for adapter weights.
save_root = os.path.join(os.path.dirname(ft_model), "lora")
os.makedirs(save_root, exist_ok=True)
dst_path = os.path.join(save_root, "adapter_model.safetensors")

# LoRA adapter configuration that is written to adapter_config.json.
# "auto_mapping" records the diffusers transformer class the adapter targets.
# NOTE(review): r, lora_alpha and target_modules are hard-coded; they must match
# the LoRA settings used during training -- confirm against the training config.
adapter_config = {
    "alpha_pattern": {},
    "auto_mapping": {
        "base_model_class": "QwenImageTransformer2DModel",
        "parent_library": "diffusers.models.transformers.transformer_qwenimage"
    },
    # Reuse base_model_path so the recorded base model cannot drift from the one
    # configured above.
    "base_model_name_or_path": base_model_path,
    "bias": "none",
    "fan_in_fan_out": False,
    "inference_mode": True,
    "init_lora_weights": "gaussian",
    "lora_alpha": 128,
    "lora_bias": False,
    "lora_dropout": 0.0,
    "peft_type": "LORA",
    "r": 64,
    "target_modules": [
        "attn.to_k",
        "attn.to_q",
        "attn.to_v",
        "attn.to_out.0",
        "attn.add_k_proj",
        "attn.add_q_proj",
        "attn.add_v_proj",
        "attn.to_add_out",
        "img_mlp.net.0.proj",
        "img_mlp.net.2",
        "txt_mlp.net.0.proj",
        "txt_mlp.net.2"
    ],
    "task_type": None
}


# Extract the LoRA weights from the full fine-tuned checkpoint.
lora_tensors = {}
with safe_open(ft_model, framework="pt") as f:
    for key in f.keys():
        if "lora" in key.lower():  # keep only LoRA parameters (case-insensitive match)
            # Strip the ".default" adapter name so the keys read
            # "...lora_A.weight" / "...lora_B.weight".
            new_key = key
            new_key = new_key.replace(".lora_A.default.weight", ".lora_A.weight")
            new_key = new_key.replace(".lora_B.default.weight", ".lora_B.weight")
            lora_tensors[new_key] = f.get_tensor(key)

# NOTE(review): PEFT's saved-adapter format presumably expects keys prefixed like
# "base_model.model.<module path>"; confirm this checkpoint's key prefix matches
# what the loader expects, otherwise the LoRA weights may not be picked up.

if not lora_tensors:
    # Otherwise an adapter file with no weights would be written silently.
    raise ValueError(f"No LoRA parameters found in {ft_model}")

# adapter_config uses a single rank for all modules (alpha_pattern is empty and no
# rank_pattern is set). Each lora_A weight must therefore have r rows; a mismatch
# would produce an adapter that cannot be loaded.
expected_rank = adapter_config["r"]
rank_mismatches = sorted(
    name
    for name, tensor in lora_tensors.items()
    if ".lora_A." in name and tensor.shape[0] != expected_rank
)
if rank_mismatches:
    example = rank_mismatches[0]
    raise ValueError(
        f"LoRA rank mismatch: adapter_config['r'] = {expected_rank}, but "
        f"{len(rank_mismatches)} lora_A tensor(s) differ, e.g. {example} "
        f"with shape {tuple(lora_tensors[example].shape)}"
    )

print(f"找到 {len(lora_tensors)} 个 LoRA 参数，保存到 {dst_path}")
# PEFT writes {"format": "pt"} metadata on the adapter files it saves, and some
# safetensors consumers check this field.
save_file(lora_tensors, dst_path, metadata={"format": "pt"})


# Write the adapter configuration next to the adapter weights.
config_path = os.path.join(save_root, "adapter_config.json")
with open(config_path, "w") as config_file:
    config_file.write(json.dumps(adapter_config, indent=2))

print("adapter_config.json 已生成！")
