import argparse
import json
import os
import re
import sys

import torch
from safetensors.torch import load_file, save_file





# Command-line entry: locate the trained adapter and the initial LoRA A/B
# weights, back up the trained adapter, and load everything into memory.
parser = argparse.ArgumentParser(
    description="Fold initial LoRA A/B weights into a trained adapter, doubling its rank."
)
parser.add_argument(
    "--model_lora_path",
    type=str,
    required=True,
    help="Directory containing adapter_model.safetensors and adapter_config.json.",
)
parser.add_argument(
    "--init_lora_dir",
    type=str,
    required=True,
    help="Directory containing lora_A_init.safetensors and lora_B_init.safetensors.",
)
args = parser.parse_args()

trained_adapter_path = os.path.join(args.model_lora_path, "adapter_model.safetensors")
model_dir = os.path.dirname(trained_adapter_path)
init_lora_dir = args.init_lora_dir

init_lora_A_path = os.path.join(init_lora_dir, "lora_A_init.safetensors")
init_lora_B_path = os.path.join(init_lora_dir, "lora_B_init.safetensors")

# "beifen" is pinyin for "backup": the untouched trained adapter is moved here.
backup_path = os.path.join(model_dir, "beifen-adp-md.safetensors")

# An existing backup means a previous run already rewrote the adapter; running
# again would fold the init weights in a second time.
if os.path.exists(backup_path):
    print(
        f"{backup_path} already exists; adapter appears to be converted already, skipping.",
        file=sys.stderr,
    )
    sys.exit(0)

# Fail before modifying anything if a required input file is absent, so the
# model directory is never left half-converted.
for _required_path in (trained_adapter_path, os.path.join(model_dir, "adapter_config.json")):
    if not os.path.isfile(_required_path):
        raise FileNotFoundError(f"Required file not found: {_required_path}")

# Load the init weights BEFORE moving the trained adapter: if they are missing
# or unreadable, the script fails with the adapter still in place instead of
# leaving only the backup (which would make every re-run exit early above).
init_A = load_file(init_lora_A_path)
init_B = load_file(init_lora_B_path)

os.rename(trained_adapter_path, backup_path)
trained_weights = load_file(backup_path)


def remap_keys(d, suffix_old, suffix_new):
    """Return a copy of ``d`` whose keys ending in ``suffix_old`` end in ``suffix_new`` instead.

    Only the trailing suffix is rewritten; any earlier occurrence of
    ``suffix_old`` inside a key is left untouched (``str.replace`` would
    rewrite every occurrence). Keys without the suffix, and all values, are
    carried over unchanged. The input mapping is not modified.

    Args:
        d: Mapping from string keys to values.
        suffix_old: Suffix to strip from matching keys.
        suffix_new: Suffix to append in its place.

    Returns:
        A new dict with the remapped keys.
    """
    cut = len(suffix_old)
    return {
        (k[: len(k) - cut] + suffix_new if k.endswith(suffix_old) else k): v
        for k, v in d.items()
    }

# Bring the init-weight keys into the trained adapter's naming scheme
# ("<module>.lora_A.weight" / "<module>.lora_B.weight").
init_A = remap_keys(init_A, ".weight.lora_A", ".lora_A.weight")
init_B = remap_keys(init_B, ".weight.lora_B", ".lora_B.weight")

_LORA_A_SUFFIX = ".lora_A.weight"
_LORA_B_SUFFIX = ".lora_B.weight"
# dtype written for the rebuilt LoRA tensors; every other tensor keeps its dtype.
_MERGED_DTYPE = torch.bfloat16


def _lora_prefix(key):
    """Return the module prefix of a ``lora_A``/``lora_B`` weight key, or None for other keys."""
    for suffix in (_LORA_A_SUFFIX, _LORA_B_SUFFIX):
        if key.endswith(suffix):
            return key[: -len(suffix)]
    return None


# For every module with both a trained and an initial LoRA pair, build a pair of
# twice the rank:
#     B' = [B_trained | B_init]      (concatenated along dim=1)
#     A' = [A_trained ; -A_init]     (concatenated along dim=0)
# so that B' @ A' == B_trained @ A_trained - B_init @ A_init.
# NOTE(review): presumably this re-expresses an adapter trained on a base model
# shifted by the init pair as a delta on the original base model - confirm.
new_weights = {}
unmatched_prefixes = set()
for key, tensor in trained_weights.items():
    prefix = _lora_prefix(key)
    if prefix is None:
        # Not a LoRA A/B weight: copy through untouched.
        new_weights[key] = tensor
        continue
    if key in new_weights:
        # Already written while handling the other half of this pair.
        continue

    key_A = prefix + _LORA_A_SUFFIX
    key_B = prefix + _LORA_B_SUFFIX
    if not (
        key_A in trained_weights
        and key_B in trained_weights
        and key_A in init_A
        and key_B in init_B
    ):
        unmatched_prefixes.add(prefix)
        new_weights[key] = tensor
        continue

    # Upcast so trained and init tensors share a dtype for torch.cat.
    A_trained = trained_weights[key_A].to(torch.float32)
    B_trained = trained_weights[key_B].to(torch.float32)
    A_init = init_A[key_A].to(torch.float32)
    B_init = init_B[key_B].to(torch.float32)

    new_weights[key_A] = torch.cat([A_trained, -A_init], dim=0).to(_MERGED_DTYPE).contiguous()
    new_weights[key_B] = torch.cat([B_trained, B_init], dim=1).to(_MERGED_DTYPE).contiguous()

if unmatched_prefixes:
    # Pairs left at the original rank cannot be loaded once the adapter config's
    # rank is doubled, so refuse to write a broken adapter. Put the original
    # file back so a re-run is not short-circuited by the backup check.
    os.rename(backup_path, trained_adapter_path)
    sample = ", ".join(sorted(unmatched_prefixes)[:5])
    raise ValueError(
        f"{len(unmatched_prefixes)} LoRA module(s) lack a complete trained/init "
        f"A/B pair (e.g. {sample}); the adapter was left unchanged."
    )

save_file(new_weights, trained_adapter_path)










# Update adapter_config.json so it describes the rank-doubled LoRA tensors.
config_path = os.path.join(model_dir, "adapter_config.json")
# "beifen" is pinyin for "backup": the original config is kept here.
backup_config_path = os.path.join(model_dir, "beifen-cg.json")


if not os.path.exists(config_path):
    raise FileNotFoundError(f"Adapter config not found: {config_path}")

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)


def _is_json_number(value):
    """Return True for JSON numbers (int or float); booleans are excluded."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


# The LoRA scaling must stay the same when the rank doubles. PEFT scales by
# lora_alpha / r, or by lora_alpha / sqrt(r) when use_rslora is set, so alpha
# is multiplied by 2 or by sqrt(2) respectively.
alpha_factor = 2 ** 0.5 if config.get("use_rslora") else 2

if "r" in config and isinstance(config["r"], int):
    config["r"] *= 2
# Float alphas (e.g. 16.0) must be rescaled too; skipping them would silently
# halve the effective scaling of the converted adapter.
if _is_json_number(config.get("lora_alpha")):
    config["lora_alpha"] *= alpha_factor

# Per-module overrides must follow the global values, otherwise those modules
# get a configured rank that does not match their (doubled) tensors.
if isinstance(config.get("rank_pattern"), dict):
    config["rank_pattern"] = {
        name: rank * 2 if isinstance(rank, int) else rank
        for name, rank in config["rank_pattern"].items()
    }
if isinstance(config.get("alpha_pattern"), dict):
    config["alpha_pattern"] = {
        name: alpha * alpha_factor if _is_json_number(alpha) else alpha
        for name, alpha in config["alpha_pattern"].items()
    }

# Keep the untouched config beside the adapter, then write the updated one.
os.rename(config_path, backup_config_path)
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

