import gc

import numpy as np
import torch
from datasets import load_dataset
from peft import PeftModel
from scipy.spatial.distance import cosine
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---- Run configuration -------------------------------------------------------
# Base model that every adapter checkpoint is applied on top of.
base_model_path = "/model-weights/Meta-Llama-3.1-8B-Instruct"

# Adapter checkpoint paths, in the order drift should be measured
# (ckpt_0 -> ckpt_1 -> ...). Fill this in before running.
adapter_ckpts = []

# Prefer the GPU when one is visible, otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Index into the model's `hidden_states` tuple that gets probed (-1 = last entry).
layer_idx = -1

# Probe-set size and per-prompt token truncation limit.
num_samples = 1000
max_length = 256

# Tokenizer of the base model; every adapter checkpoint is probed with it.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Build the probe set from the first Alpaca examples: the instruction plus the
# optional input, joined by a single space. Empty parts are dropped before
# joining. Otherwise prompts without an input would end in a stray space (which
# gets tokenized), and the emptiness check below could never trigger because
# the separator alone made every string non-empty.
dataset = load_dataset("tatsu-lab/alpaca", split="train[:2000]")
texts = []
for item in dataset:
    parts = (item["instruction"].strip(), item["input"].strip())
    text = " ".join(part for part in parts if part)
    if not text:
        continue  # skip rows with neither an instruction nor an input
    texts.append(text)
    if len(texts) >= num_samples:
        break
print(f"Loaded {len(texts)} samples from Alpaca for probing.")

def get_hidden_reps(model, texts):
    """Return one mean-pooled hidden-state vector per input text.

    Each text is tokenized on its own (truncated to the module-level
    ``max_length``) and run through ``model`` without gradient tracking. The
    entry ``hidden_states[layer_idx]`` is then averaged over the token axis.

    Args:
        model: Causal LM (here a ``PeftModel``) whose forward pass accepts
            ``output_hidden_states``. This function neither moves it to
            ``device`` nor calls ``eval()``; the caller is responsible for both.
        texts: Sequence of strings to encode.

    Returns:
        A CPU float32 tensor built with ``torch.stack``, holding one row per
        text. ``torch.stack`` raises if ``texts`` is empty.
    """
    all_reps = []
    for text in tqdm(texts, desc="Extracting hidden reps"):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        with torch.no_grad():
            # Request hidden states explicitly rather than depending on the
            # model having been loaded with output_hidden_states=True.
            outputs = model(**inputs, output_hidden_states=True)
            hidden = outputs.hidden_states[layer_idx]
            # One text per forward pass means no padding, so a plain mean over
            # dim=1 (the token axis in HF hidden states) is an exact mean-pool.
            # .float() keeps the vector NumPy/SciPy-compatible even if the model
            # runs in bf16/fp16 (NumPy has no bfloat16); no-op for fp32.
            rep = hidden.mean(dim=1).squeeze(0).float().cpu()
        all_reps.append(rep)
        del inputs, outputs, hidden
    # Return cached allocator blocks once per call instead of once per sample;
    # doing it inside the loop frees and re-allocates device memory every
    # iteration for no benefit.
    torch.cuda.empty_cache()
    return torch.stack(all_reps)

# Average cosine distance between ckpt_{i-1} and ckpt_i representations, one
# entry per consecutive checkpoint pair.
drift_values = []


def _checkpoint_reps(idx):
    """Load adapter ``adapter_ckpts[idx]`` on a fresh base model and return its reps.

    The base model is reloaded for every checkpoint, so each adapter is applied
    to an untouched copy of the base weights. The model is released before
    returning so that the next 8B model can be loaded into the freed memory.
    """
    print(f"\nLoading checkpoint {idx}...")
    base = AutoModelForCausalLM.from_pretrained(base_model_path, output_hidden_states=True).to(device)
    model = PeftModel.from_pretrained(base, adapter_ckpts[idx])
    model.eval()
    try:
        return get_hidden_reps(model, texts)
    finally:
        del model, base
        # Large HF/PEFT models can sit in reference cycles; collect them
        # before asking CUDA to release the cached blocks.
        gc.collect()
        torch.cuda.empty_cache()


if not adapter_ckpts:
    raise ValueError("adapter_ckpts is empty: list at least one adapter checkpoint path.")

rep_prev = _checkpoint_reps(0)

for i in range(1, len(adapter_ckpts)):
    rep_curr = _checkpoint_reps(i)

    # Per-text cosine distance (1 - cosine similarity) between the two
    # checkpoints' representations of the same prompt, then averaged.
    drift_scores = [cosine(a, b) for a, b in zip(rep_curr, rep_prev)]
    avg_drift = np.mean(drift_scores)
    print(f"Avg rep drift from ckpt_{i-1} → ckpt_{i}: {avg_drift:.4f}")
    drift_values.append(avg_drift)

    rep_prev = rep_curr

# Summary table: one line per consecutive checkpoint pair, higher precision
# than the per-step progress lines.
print("\nFinal Consecutive Representation Drift Results:")
for step, value in enumerate(drift_values, start=1):
    print(f"  - ckpt_{step - 1} → ckpt_{step}: {value:.8f}")