import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from peft import PeftModel
from transformers import AutoModelForCausalLM
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Base model onto which every checkpoint's LoRA adapter is loaded.
BASE_MODEL = "/model-weights/Meta-Llama-3.1-8B-Instruct"

# Domain sub-directory names under ROOT_DIR; fill in the dataset names here
# before running (the script cannot plot anything while this is empty).
DOMAINS = []

# Checkpoint folder names visited for each domain:
# checkpoint-50, checkpoint-100, ..., checkpoint-500.
CHECKPOINTS = ["checkpoint-" + str(step) for step in range(50, 501, 50)]

# Directory holding one sub-directory per domain, each containing CHECKPOINTS.
ROOT_DIR = "../../models"

def get_lora_weights_vector(peft_model):
    """Flatten every LoRA A/B weight of *peft_model* into one standardized 1-D array.

    Each parameter whose name contains ``"lora_A"`` or ``"lora_B"`` is
    standardized on its own (subtract its mean, divide by its standard
    deviation plus 1e-6) and flattened. The pieces are concatenated in
    ``named_parameters()`` order, so vectors taken from adapters with the same
    layout line up element for element.

    Args:
        peft_model: a ``torch.nn.Module`` (here a ``PeftModel``) whose LoRA
            parameters should be extracted.

    Returns:
        A 1-D float32 ``numpy.ndarray``.

    Raises:
        ValueError: if no parameter name contains ``lora_A`` or ``lora_B``.
    """
    vecs = []
    for name, param in peft_model.named_parameters():
        if "lora_A" not in name and "lora_B" not in name:
            continue
        # Detach *before* computing statistics: otherwise mean()/std() stay in
        # the autograd graph for trainable parameters and `.numpy()` raises.
        # Upcast to float32 so half-precision weights (the base model is
        # loaded in float16) are normalized without precision loss.
        weights = param.detach().cpu().float()
        normed = (weights - weights.mean()) / (weights.std() + 1e-6)
        vecs.append(normed.flatten())
    if not vecs:
        raise ValueError("No LoRA (lora_A / lora_B) parameters found in the model.")
    return torch.cat(vecs).numpy()

# Per-checkpoint results, kept index-aligned: weight vector, "<domain>_<step>"
# label, and domain name.
all_vectors = []
all_labels = []
all_domains = []

# File the final t-SNE plot is written to (also reported in the last message).
OUTPUT_PATH = "newest_lora_tsne_multidomain.pdf"

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="cpu")

for domain in DOMAINS:
    for ckpt in CHECKPOINTS:
        ckpt_path = os.path.join(ROOT_DIR, domain, ckpt)
        if not os.path.exists(ckpt_path):
            print(f"Warning: {ckpt_path} does not exist, skipping.")
            continue
        print(f"Loading {domain} → {ckpt}")
        try:
            peft_model = PeftModel.from_pretrained(base_model, ckpt_path)
        except Exception as e:
            print(f"Failed loading {ckpt_path}: {e}")
            continue
        try:
            peft_model.eval()
            vector = get_lora_weights_vector(peft_model)
        except Exception as e:
            print(f"Failed loading {ckpt_path}: {e}")
            continue
        finally:
            # PeftModel.from_pretrained patches base_model's layers in place;
            # unload() strips the LoRA modules again (without merging) so the
            # next checkpoint is loaded onto an unmodified base model.
            base_model = peft_model.unload()
        all_vectors.append(vector)
        all_labels.append(f"{domain}_{ckpt.rsplit('-', 1)[1]}")
        all_domains.append(domain)

print(f"Total checkpoints loaded: {len(all_vectors)}")

# np.stack, PCA and t-SNE all need real data; DOMAINS is empty until filled in.
if len(all_vectors) < 2:
    raise SystemExit(
        f"Need at least 2 loaded checkpoints to plot, got {len(all_vectors)}; "
        "check DOMAINS, CHECKPOINTS and ROOT_DIR."
    )

all_vectors = np.stack(all_vectors)
n_samples, n_features = all_vectors.shape

scaler = StandardScaler()
vectors_scaled = scaler.fit_transform(all_vectors)

print("Running PCA...")
# PCA cannot produce more components than min(n_samples, n_features).
pca = PCA(n_components=min(30, n_samples, n_features))
vectors_reduced = pca.fit_transform(vectors_scaled)

print("Running t-SNE...")
tsne = TSNE(
    n_components=2,
    perplexity=min(3, n_samples - 1),  # scikit-learn requires perplexity < n_samples
    learning_rate='auto',
    init="pca",
    random_state=42
)
tsne_results = tsne.fit_transform(vectors_reduced)

plt.figure(figsize=(8, 6))
# Only domains that contributed at least one checkpoint get a colour and a
# legend entry; DOMAINS order is preserved.
loaded_domains = set(all_domains)
unique_domains = [d for d in DOMAINS if d in loaded_domains]
palette = sns.color_palette("husl", len(unique_domains))

for color, domain in zip(palette, unique_domains):
    idx = [j for j, d in enumerate(all_domains) if d == domain]
    xs = tsne_results[idx, 0]
    ys = tsne_results[idx, 1]
    plt.scatter(xs, ys, label=domain, s=70, color=color)
    for j, x, y in zip(idx, xs, ys):
        # Labels are "<domain>_<step>"; rsplit keeps domain names that
        # themselves contain "_" from shifting which part is shown.
        plt.annotate(all_labels[j].rsplit("_", 1)[1], (x, y), fontsize=7, alpha=0.7)

plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.legend(title="Domain", fontsize=10, title_fontsize=11)
plt.tight_layout()

plt.savefig(OUTPUT_PATH, bbox_inches='tight')
plt.close()

print(f"Saved plot as {OUTPUT_PATH}")