import os
import torch
from peft import PeftModel, PeftConfig
from safetensors.torch import load_file
import numpy as np

# Adapter checkpoint directories to analyse. Each entry is a path to a
# directory that contains an "adapter_model.safetensors" file. Deltas are
# reported between consecutive entries, so list them in the order in which
# they should be compared.
adapter_ckpts = [
    # add checkpoints here
]

def compute_adapter_norm(adapter_path):
    """Return the combined Frobenius norm of every tensor in a saved adapter.

    All tensors stored in ``adapter_model.safetensors`` under
    ``adapter_path`` are treated as one flat vector: the result is the square
    root of the sum of the squared per-tensor norms.

    Only the weights file is read; ``adapter_config.json`` is not required.

    Args:
        adapter_path: Directory containing ``adapter_model.safetensors``.

    Returns:
        The combined norm as a ``numpy.float64`` (``0.0`` when the file
        holds no tensors).

    Raises:
        Whatever ``safetensors.torch.load_file`` raises when the weights
        file is missing or cannot be parsed.
    """
    safetensor_path = os.path.join(adapter_path, "adapter_model.safetensors")
    adapter_weights = load_file(safetensor_path)

    total_norm_sq = 0.0
    # Tensor names are irrelevant to the norm, so only the values are visited.
    for tensor in adapter_weights.values():
        # .float() converts whatever dtype is stored to float32 before the
        # reduction. vector_norm flattens the tensor and takes the 2-norm,
        # which is the Frobenius norm for a tensor of any rank.
        norm = torch.linalg.vector_norm(tensor.float())
        total_norm_sq += norm.item() ** 2

    return np.sqrt(total_norm_sq)

# Report the Frobenius norm of every checkpoint listed in ``adapter_ckpts``,
# then how that norm changes between consecutive entries.
print("Computing adapter Frobenius norms...\n")
adapter_norms = []

for i, path in enumerate(adapter_ckpts):
    norm = compute_adapter_norm(path)
    adapter_norms.append(norm)
    # Normalise first: basename("run/ckpt-100/") is "" because of the
    # trailing separator, which would print the label as empty parentheses.
    ckpt_name = os.path.basename(os.path.normpath(path))
    print(f"  - ckpt_{i} ({ckpt_name}): Frobenius norm = {norm:.8f}")

# NOTE: each delta is the signed difference between two checkpoint norms
# (norm_i - norm_{i-1}); it is not the norm of the weight difference.
print("\nDelta Frobenius Norms (ckpt-to-ckpt):")
for i, (prev_norm, curr_norm) in enumerate(
    zip(adapter_norms, adapter_norms[1:]), start=1
):
    delta = curr_norm - prev_norm
    print(f"  - ckpt_{i-1} to ckpt_{i}: delta norm = {delta:+.8f}")