import os
from safetensors.torch import load_file, save_file
import torch


def refine_and_save(model_save_dir, weight_path, target_modules) -> None:
    """Fold the saved ``*_lora_C`` matrices into an adapter's ``lora_B`` weights.

    Loads ``adapter_model.safetensors`` from ``model_save_dir``, left-multiplies
    each matching ``lora_B`` weight by its C matrix (``C @ B``), and writes the
    result back to the same file.

    The adapter file is overwritten in place, so calling this function a
    second time on the same directory multiplies by C again.

    Args:
        model_save_dir: Directory containing ``adapter_model.safetensors``.
        weight_path: Path of a ``torch.load``-able mapping whose keys look
            like ``"<module>_<layer_idx>_lora_C"``.
        target_modules: Container of module names (e.g. ``"q_proj"``); keys
            for other modules are ignored.

    Raises:
        ValueError: If a key for a target module has a non-integer layer
            index.
    """
    output_name = "adapter_model.safetensors"
    suffix = "_lora_C"
    adapter_path = os.path.join(model_save_dir, output_name)

    adapter = load_file(adapter_path, device="cpu")
    print(f"Loaded original adapter from {adapter_path}")

    # NOTE(security): torch.load unpickles the file; depending on the torch
    # version's default for `weights_only`, this can execute arbitrary code.
    # Only pass checkpoints from a trusted source.
    weight_list = torch.load(weight_path, map_location="cpu")
    print(f"Loaded C weights from {weight_path}")

    for c_key, c_weight in weight_list.items():
        if not c_key.endswith(suffix):
            continue

        # "<module>_<layer_idx>_lora_C": the module name may itself contain
        # underscores, so split on the last one before the suffix.
        mod, _, idx_text = c_key[: -len(suffix)].rpartition("_")
        if mod not in target_modules:
            continue

        # Parse the index only after the module filter so that unrelated,
        # oddly named keys cannot abort the whole run.
        try:
            layer_idx = int(idx_text)
        except ValueError:
            raise ValueError(
                f"Cannot parse layer index from key {c_key!r}"
            ) from None

        # NOTE(review): only self_attn sub-modules are addressed here; target
        # modules living elsewhere in the layer are reported as skipped.
        b_key = f"base_model.model.model.layers.{layer_idx}.self_attn.{mod}.lora_B.weight"
        if b_key not in adapter:
            print(f"Skipped {c_key}: {b_key} not found in adapter")
            continue

        b_weight = adapter[b_key]
        if c_weight.dtype == b_weight.dtype:
            merged = c_weight @ b_weight
        else:
            # matmul rejects mixed dtypes; compute in the wider dtype and
            # store the result in the adapter's original dtype.
            compute_dtype = torch.promote_types(c_weight.dtype, b_weight.dtype)
            merged = c_weight.to(compute_dtype) @ b_weight.to(compute_dtype)
            merged = merged.to(b_weight.dtype)

        # safetensors refuses non-contiguous tensors; no-op when already so.
        adapter[b_key] = merged.contiguous()
        print(f"Absorbed {c_key} into {b_key}")

    save_file(adapter, adapter_path)
    print(f"Saved new adapter to {adapter_path}")
