#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Single-task-vector merge  (Prefix-robust + save-safe for LLaMAFactory)
θ_merge = θ_base + α·Δ_vl + (1−α)·Δ_coder
"""

import os, shutil, torch
from transformers import (
    AutoModelForCausalLM,
    Qwen2_5_VLForConditionalGeneration,
    AutoTokenizer,
)

# ========= Settings =========
alpha = 0.7                      # weight applied to the VL task vector
dtype = torch.float32            # checkpoints are loaded (and merged) in this dtype
device = torch.device("cpu")

base_ckpt = "Qwen2.5-7B"
vl_ckpt = "Qwen2.5-VL-7B-Instruct"
coder_ckpt = "OpenCodeReasoning-Nemotron-1.1-7B"
out_dir = f"./merged_ckpt_delta/Qwen2.5-7B_VL×nv_ocr1_1_alpha{alpha}"

# ========= 1. Load the models =========
print("▶️  Loading checkpoints …")
mbase = AutoModelForCausalLM.from_pretrained(
    base_ckpt,
    torch_dtype=dtype,
)
mbase = mbase.to(device)

mvl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    vl_ckpt,
    torch_dtype=dtype,
)
mvl = mvl.to(device)

mcoder = AutoModelForCausalLM.from_pretrained(
    coder_ckpt,
    torch_dtype=dtype,
)
mcoder = mcoder.to(device)

# State dicts of the inner `.model` modules only. The VL keys may look like
# "language_model.model.xxx"; map_to_vl resolves the actual prefix.
sd_base, sd_coder, sd_vl = (
    wrapper.model.state_dict() for wrapper in (mbase, mcoder, mvl)
)

# The coder task vector receives the remaining weight.
beta = 1.0 - alpha
print(f"ℹ️  alpha={alpha:.2f}, beta={beta:.2f}")

# ========= 2. 前缀匹配工具 =========
def map_to_vl(key: str) -> str | None:
    """
    给定 base-style key, 返回在 VL state_dict 中的等价 key。
    依次尝试：
        language_model.model.<key>
        language_model.<key>
        <key>
    找到即返回，否则 None
    """
    for pfx in ("language_model.model.", "language_model.", ""):
        candidate = pfx + key
        if candidate in sd_vl:
            return candidate
    return None

# ========= 3. Merge the backbone weights =========
# For every base tensor that has a same-shaped counterpart in both the VL and
# the coder state dict, compute
#     merged = base + alpha * (vl - base) + beta * (coder - base)
# and write it into the VL tensor in place. Only keys of `sd_base`
# (mbase.model) are visited, so anything outside that module is left as it
# is in the VL checkpoint.
merged_cnt = 0
skipped = []  # (base key, reason) for every tensor that was not merged

# No gradients are needed for the merge arithmetic; one scope covers the loop.
with torch.no_grad():
    for name, w_base in sd_base.items():
        vl_key = map_to_vl(name)
        if vl_key is None:
            skipped.append((name, "no matching key in the VL state dict"))
            continue
        if name not in sd_coder:
            skipped.append((name, "missing from the coder state dict"))
            continue

        w_vl = sd_vl[vl_key]
        w_coder = sd_coder[name]
        if w_base.shape != w_coder.shape or w_base.shape != w_vl.shape:
            skipped.append((
                name,
                f"shape mismatch: base={tuple(w_base.shape)}, "
                f"vl={tuple(w_vl.shape)}, coder={tuple(w_coder.shape)}",
            ))
            continue

        delta_vl = w_vl - w_base
        delta_coder = w_coder - w_base
        merged = w_base + alpha * delta_vl + beta * delta_coder
        # In-place write into the tensor held by sd_vl.
        w_vl.copy_(merged)
        merged_cnt += 1

# Report what was left untouched instead of dropping it silently.
if skipped:
    print(f"⚠️  Skipped {len(skipped)} base tensor(s) that could not be merged:")
    for skipped_name, reason in skipped:
        print(f"    - {skipped_name}: {reason}")

# Merging nothing means the key mapping failed; saving would only produce an
# unmodified copy of the VL model labelled as merged.
if merged_cnt == 0:
    raise RuntimeError(
        "No tensors were merged: none of the base keys matched both the VL "
        "and the coder state dict with identical shapes."
    )

print(f"✅  已成功融合 {merged_cnt} 个权重张量。")

# ========= 4. Safe save: reload the state dict, then save_pretrained =========
print("💾  Saving merged checkpoint …")
mvl.model.load_state_dict(sd_vl)   # make sure the model holds the merged tensors
mvl.to(torch.float16)              # the merge ran in `dtype`; weights are stored as float16
mvl.save_pretrained(out_dir)
AutoTokenizer.from_pretrained(vl_ckpt).save_pretrained(out_dir)

# Copy auxiliary files next to the weights. This only works when `vl_ckpt`
# is a local directory that contains them, so a missing file is reported
# rather than skipped silently.
for extra in ["chat_template.json", "preprocessor_config.json"]:
    src = os.path.join(vl_ckpt, extra)
    if os.path.exists(src):
        shutil.copy(src, out_dir)
    else:
        print(
            f"⚠️  {extra} not found at {src}; "
            f"copy it into {out_dir} manually if it is required."
        )

print(f"\n🎉  完成！模型已保存至：{out_dir}")
