import os
import re
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
from tqdm import tqdm

# Sample CKA script: fill adapter_ckpts with dataset-finetuned adapter checkpoints.
base_model_path = "/model-weights/Meta-Llama-3.1-8B-Instruct"

# Paths of the PEFT adapter checkpoints to compare (add checkpoints here).
adapter_ckpts = []

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

layer_idx = -1      # index into outputs.hidden_states; -1 selects the last entry
num_samples = 1000  # stop collecting prompts once this many have been gathered
max_length = 256    # tokenizer truncation length per prompt

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:2000]")

# Build up to num_samples prompts from instruction + input.
texts = []
for item in dataset:
    # Join only the non-empty fields. Prompts without an input then carry no
    # trailing space (which may tokenize to an extra token and shift the
    # mean-pooled representation). An item whose fields are both empty
    # yields "" and is skipped below.
    parts = (item["instruction"].strip(), item["input"].strip())
    text = " ".join(part for part in parts if part)
    if text:
        texts.append(text)
    if len(texts) >= num_samples:
        break
print(f"Loaded {len(texts)} samples")

def get_hidden_reps(model, texts):
    """Return one mean-pooled hidden-state vector per text.

    Each text is tokenized on its own (truncated to ``max_length`` tokens) and
    run through ``model``. The hidden states at index ``layer_idx`` are then
    averaged over the token dimension. Uses the module-level ``tokenizer``,
    ``device``, ``layer_idx`` and ``max_length``.

    Args:
        model: causal LM (optionally PEFT-wrapped) whose forward pass accepts
            ``output_hidden_states`` and returns ``hidden_states``.
        texts: sequence of input strings.

    Returns:
        CPU float32 tensor of shape (len(texts), hidden_size).
    """
    reps = []
    with torch.no_grad():
        for text in tqdm(texts, desc="Extracting hidden reps"):
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True, max_length=max_length
            ).to(device)
            # Ask for hidden states explicitly instead of relying on the model
            # config having been created with output_hidden_states=True.
            outputs = model(**inputs, output_hidden_states=True)
            hidden = outputs.hidden_states[layer_idx]  # (1, seq_len, hidden_size)
            # Batch of one unpadded sequence, so a plain mean over tokens is
            # exact. Upcast so half-precision models still give float32 reps.
            reps.append(hidden.mean(dim=1).squeeze(0).float().cpu())
            # Drop references before the next forward pass to keep peak memory low.
            del inputs, outputs, hidden
    # Release cached allocator blocks once per call rather than after every
    # sample, which stalls the GPU on each iteration.
    torch.cuda.empty_cache()
    return torch.stack(reps)

# === CKA Functions ===
def center_gram(G):
    """Double-center a square Gram matrix.

    Equivalent to ``H @ G @ H`` with the centering matrix
    ``H = I - 1 1^T / n``. It is computed by subtracting row and column means
    and adding back the grand mean. That is O(n^2) instead of O(n^3), and
    the result keeps ``G``'s dtype and device (an explicit float32 CPU ``H``
    would fail for bf16/fp16/float64 or CUDA inputs).

    Args:
        G: (n, n) floating-point tensor.

    Returns:
        (n, n) tensor whose rows and columns each have zero mean.
    """
    return G - G.mean(dim=0, keepdim=True) - G.mean(dim=1, keepdim=True) + G.mean()


def compute_linear_CKA(X, Y):
    """Linear centered kernel alignment between two representation matrices.

    CKA = <Kc, Lc>_F / (||Kc||_F * ||Lc||_F), where K = X X^T, L = Y Y^T,
    and Kc, Lc are their double-centered versions.

    Args:
        X: (n, d1) tensor, one row per example.
        Y: (n, d2) tensor whose rows are aligned with ``X``.

    Returns:
        0-d float64 tensor. It is ~1.0 for identical inputs and 0 when either
        centered Gram matrix is all zeros (the 1e-8 guard avoids 0/0).

    Raises:
        ValueError: if ``X`` and ``Y`` have different numbers of rows.
    """
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have the same number of rows, got {X.shape[0]} and {Y.shape[0]}"
        )
    # Work in float64 so half-precision reps keep precision in the Gram
    # products and both Gram matrices share a single dtype.
    X = X.to(torch.float64)
    Y = Y.to(torch.float64)
    Kc = center_gram(X @ X.T)
    Lc = center_gram(Y @ Y.T)
    hsic = (Kc * Lc).sum()
    norm_x = (Kc * Kc).sum().sqrt()  # Frobenius norm of Kc
    norm_y = (Lc * Lc).sum().sqrt()  # Frobenius norm of Lc
    return hsic / (norm_x * norm_y + 1e-8)

if not adapter_ckpts:
    raise SystemExit("adapter_ckpts is empty: add adapter checkpoint paths to compare.")

# Extract representations for each adapter checkpoint, one model at a time.
all_reps = []
for path in adapter_ckpts:
    # normpath strips a trailing separator so basename does not return "".
    print(f"\nLoading {os.path.basename(os.path.normpath(path))}...")
    # A fresh base model is loaded per checkpoint. Presumably this is because
    # PEFT attaches adapter layers to the base model in place; confirm before
    # reusing a single base model across adapters.
    model_base = AutoModelForCausalLM.from_pretrained(base_model_path, output_hidden_states=True).to(device)
    model = PeftModel.from_pretrained(model_base, path)
    model.eval()
    all_reps.append(get_hidden_reps(model, texts))
    # Free GPU memory before the next full copy of the base model is loaded.
    del model, model_base
    torch.cuda.empty_cache()

print("\nComputing CKA similarity matrix...")
n = len(all_reps)
cka_matrix = np.zeros((n, n))
# Linear CKA is symmetric in its arguments, so compute only the upper
# triangle (diagonal included) and mirror it. That halves the Gram work.
for i in range(n):
    for j in range(i, n):
        cka_matrix[i, j] = cka_matrix[j, i] = compute_linear_CKA(all_reps[i], all_reps[j]).item()
# Report every entry in row-major order.
for i in range(n):
    for j in range(n):
        print(f"CKA[{i},{j}] = {cka_matrix[i, j]:.4f}")

print("\nCKA Similarity Matrix:")
# normpath strips a trailing separator so basename does not return "".
labels = [os.path.basename(os.path.normpath(p)) for p in adapter_ckpts]
# Each column is wide enough for the longest label (at least 12 characters).
col_w = max([12, *map(len, labels)])
# Row lines start with "<label>: " (col_w + 2 chars); pad the header to match.
print(" " * (col_w + 2) + " ".join(f"{label:>{col_w}}" for label in labels))
for label, row_vals in zip(labels, cka_matrix):
    row = " ".join(f"{v:>{col_w}.4f}" for v in row_vals)
    print(f"{label:>{col_w}}: {row}")
