# -*- coding: utf-8 -*-
"""MLSFT - Vector Drift.ipynb

Automatically generated by Colab.

# Overview

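This script measures how far the safety direction of each model fine-tuned for the Multilingual Safety Fine-Tuning (MLSFT) project drifts from that of its base model. For every model it computes a vector at ten relative layer depths (10%–100%): the difference between the mean last-token activations on 44 adversarial prompts (from SORRY-Bench) and on 44 non-adversarial prompts (from AlpacaEval). It then records the cosine similarity between each fine-tuned vector and the corresponding base-model vector. Similarity profiles are written to a Google Sheet, and the raw vectors are saved to Drive as .npz files.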
"""

!pip install -q -U bitsandbytes accelerate transformers

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import gc
import sys
import json
import os
import shutil
from tqdm.auto import tqdm
from google.colab import auth, drive
from google.auth import default
import gspread
import re
import base64

# Config
INPUT_SHEET_ID = "YOUR INPUT SHEET ID"  # Update with the ID of your input sheet (model IDs or Hugging Face URLs in column 1, header in row 1)
INPUT_TAB_NAME = "Sheet1"
OUTPUT_SHEET_NAME = "YOUR OUTPUT SHEET"  # Update with the title of your output sheet (created if it does not exist)
DRIVE_FOLDER = "YOUR DRIVE LOCATION"  # Update with the Drive folder for the .npz vector files (Drive is mounted at /content/drive)
# Hugging Face token: taken from the HF_TOKEN environment variable when it is
# set and non-empty; otherwise it stays None and main() prompts for it.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Relative layer depths 0.1, 0.2, ..., 1.0 at which activations are captured.
# Rounded to two decimals so that str(ratio) gives clean keys such as "0.6",
# which are used for the saved .npz arrays and the similarity-profile JSON.
SWEEP_RATIOS = np.round(np.linspace(0.1, 1.0, 10), 2)

# Lowercase substring of a fine-tuned model's name -> Hugging Face ID of the
# base model to compare it against. Matched by substring in find_base_model.
BASE_MODEL_MAPPING = {
    "llama-3.2-3b-instruct": "unsloth/Llama-3.2-3B-Instruct",
    "llama-3.2-1b-instruct": "unsloth/Llama-3.2-1B-Instruct",
    "gemma-3-1b-it": "unsloth/gemma-3-1b-it",
    "gemma-3-4b-it": "unsloth/gemma-3-4b-it",
    "qwen3-4b": "unsloth/Qwen3-4B",
    "qwen3-0.6b": "unsloth/Qwen3-0.6B",
}


# Setup
print("Authenticating with Google...")
auth.authenticate_user()
creds, _ = default()
gc_client = gspread.authorize(creds)

print("Loading Prompts from Sheets...")


def _load_prompt_column(sheet_url, column, label):
    """Return the non-blank prompts in one column of a sheet's first worksheet.

    Args:
        sheet_url: URL of the Google Sheet holding the prompts.
        column: 1-based column index passed to ``col_values``.
        label: group name used in the log messages ("Harmful" / "Harmless").

    Returns:
        The list of prompts below the header row, with blank cells removed, or
        an empty list if the sheet could not be read (the error is printed so
        the check after both loads can report the failure).
    """
    try:
        worksheet = gc_client.open_by_url(sheet_url).get_worksheet(0)
        # Skip the header row, and drop blank cells. col_values() can return ""
        # (or None) for empty cells above the last filled one, and an empty
        # prompt would give the model a degenerate, possibly zero-length input.
        prompts = [p for p in worksheet.col_values(column)[1:] if p and p.strip()]
    except Exception as e:
        print(f"Error loading {label} Prompts: {e}")
        return []
    print(f"Loaded {len(prompts)} {label} Prompts")
    return prompts


# Harmful Prompts (from SORRY-Bench)
harmful_sheet_url = "INSERT"  # Your sheet of harmful prompts (column 2); we take the first prompt of each of the 44 SORRY-Bench topics
HARMFUL_PROMPTS = _load_prompt_column(harmful_sheet_url, 2, "Harmful")

# Harmless Prompts (from Alpaca)
harmless_sheet_url = "YOUR PROMPTS"  # Your sheet of harmless prompts (column 1); we select 44 non-adversarial Alpaca prompts
HARMLESS_PROMPTS = _load_prompt_column(harmless_sheet_url, 1, "Harmless")

if len(HARMFUL_PROMPTS) == 0 or len(HARMLESS_PROMPTS) == 0:
    raise ValueError("Prompts failed to load. Check Sheet URLs.")
if len(HARMFUL_PROMPTS) != len(HARMLESS_PROMPTS):
    # The difference of means is still computed, but the groups are no longer balanced.
    print(f"WARNING: Unbalanced prompt sets ({len(HARMFUL_PROMPTS)} harmful vs {len(HARMLESS_PROMPTS)} harmless)")

def setup_drive():
    """Mount Google Drive if it is not mounted yet, then ensure DRIVE_FOLDER exists.

    Prints whether the folder was created or already present.
    """
    mount_point = '/content/drive'
    if not os.path.exists(mount_point):
        drive.mount(mount_point)

    if os.path.exists(DRIVE_FOLDER):
        print(f"Using existing folder: {DRIVE_FOLDER}")
        return
    os.makedirs(DRIVE_FOLDER)
    print(f"Created folder: {DRIVE_FOLDER}")

def save_vectors_to_drive(model_id, vectors_dict):
    """Write one model's vectors to ``DRIVE_FOLDER/<model_id>.npz`` (compressed).

    Args:
        model_id: repo ID; "/" and " " are replaced by "_" to build the file name.
        vectors_dict: mapping ratio -> torch tensor. ``.numpy()`` is called on
            each value, so the tensors must live on the CPU. Keys are stored as
            ``str(ratio)``.

    Returns:
        The path of the written file.
    """
    file_stem = model_id.replace("/", "_").replace(" ", "_")
    destination = f"{DRIVE_FOLDER}/{file_stem}.npz"
    arrays = {}
    for ratio, tensor in vectors_dict.items():
        arrays[str(ratio)] = tensor.numpy()
    np.savez_compressed(destination, **arrays)
    return destination

def extract_model_id(url_or_id):
    """Return the Hugging Face repo ID ("owner/name") for a model URL or bare ID.

    Bare repo IDs such as "unsloth/Qwen3-4B" are returned unchanged. For a
    huggingface.co URL (with or without a scheme), the first two path segments
    after the host are returned. For example,
    "https://huggingface.co/owner/name/tree/main?x=1#card" becomes "owner/name".
    Query strings and fragments are never part of the ID. Anything else is
    returned unchanged.
    """
    if url_or_id.startswith("http") or "huggingface.co/" in url_or_id:
        # Stop each segment at '/', '?', '#' or whitespace so trailing
        # "?library=..." or "#model-card" does not leak into the repo ID.
        match = re.search(r"huggingface\.co/([^/?#\s]+/[^/?#\s]+)", url_or_id)
        if match:
            return match.group(1)
    return url_or_id

def find_base_model(model_name):
    """Return the base-model ID whose BASE_MODEL_MAPPING key occurs in ``model_name``.

    Matching is case-insensitive substring containment. Keys are checked in
    the mapping's insertion order and the first hit wins. Returns None when no
    key matches.
    """
    haystack = model_name.lower()
    return next(
        (base_id for key, base_id in BASE_MODEL_MAPPING.items() if key.lower() in haystack),
        None,
    )

def cleanup_memory():
    """Run Python garbage collection and, when CUDA is available, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def _hf_hub_cache_dir():
    """Return the local Hugging Face hub cache directory.

    Follows the huggingface_hub environment variables: HF_HUB_CACHE, then the
    legacy HUGGINGFACE_HUB_CACHE, then ``$HF_HOME/hub``. HF_HOME itself
    defaults to ``$XDG_CACHE_HOME/huggingface`` or ``~/.cache/huggingface``.
    Empty variables are treated as unset.
    """
    def _expand(path):
        return os.path.expandvars(os.path.expanduser(path))

    explicit = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return _expand(explicit)
    hf_home = os.environ.get("HF_HOME") or os.path.join(
        os.environ.get("XDG_CACHE_HOME") or os.path.join("~", ".cache"), "huggingface"
    )
    return os.path.join(_expand(hf_home), "hub")


def cleanup_disk_cache(model_id):
    """Deletes the specific model from Hugging Face disk cache.

    Removes ``<hub cache>/models--<owner>--<name>`` for ``model_id``. The hub
    cache location honours HF_HUB_CACHE / HF_HOME / XDG_CACHE_HOME instead of
    assuming ``~/.cache/huggingface/hub``; otherwise a relocated cache would
    silently never be cleaned. A missing entry is a no-op. This is best-effort:
    any failure is printed as a warning, not raised.
    """
    try:
        path = os.path.join(_hf_hub_cache_dir(), "models--" + model_id.replace("/", "--"))
        if os.path.exists(path):
            shutil.rmtree(path)
            print(f"  ✓ Deleted from disk cache: {path}")
    except Exception as e:
        print(f"Warning: Failed to clean disk cache for {model_id}: {e}")

def get_vram_usage():
    """Describe GPU memory use as "<used>/<total> GB" (1024**3 bytes per unit), or "CPU" without CUDA."""
    if torch.cuda.is_available():
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        unit = 1024 ** 3
        return f"{(total_bytes - free_bytes) / unit:.2f}/{total_bytes / unit:.2f} GB"
    return "CPU"

def _resolve_attr_path(obj, path):
    """Follow the attribute names in ``path`` from ``obj``; return None if any step fails.

    ``Exception`` (not a bare ``except``) is caught, so missing attributes or
    failing properties are tolerated while KeyboardInterrupt / SystemExit
    still propagate.
    """
    for attr in path:
        try:
            obj = getattr(obj, attr)
        except Exception:
            return None
    return obj


def get_model_layers(model):
    """Return the list of decoder layers (blocks) inside ``model``.

    Attribute paths are tried in this order. The first one that resolves to a
    non-empty ``nn.ModuleList`` (or list) is returned:

    1. ``model.language_model.model.layers``
    2. ``model.model.layers``
    3. ``model.layers``
    4. ``model.transformer.h``
    5. ``model.language_model.layers``
    6. ``model.model.language_model.layers``

    If none resolves, the first ``nn.ModuleList`` whose qualified module name
    contains ``language_model`` and ends with ``layers`` is returned.

    Raises:
        ValueError: if no decoder layer list can be found.
    """
    candidate_paths = (
        ("language_model", "model", "layers"),
        ("model", "layers"),
        ("layers",),
        ("transformer", "h"),
        # Wrappers whose text decoder sits directly under ``language_model``.
        # These are tried after the original four, so earlier matches are unchanged.
        ("language_model", "layers"),
        ("model", "language_model", "layers"),
    )
    for path in candidate_paths:
        layers = _resolve_attr_path(model, path)
        if isinstance(layers, (nn.ModuleList, list)) and len(layers) > 0:
            return layers

    for name, module in model.named_modules():
        if 'language_model' in name and name.endswith('layers') and isinstance(module, nn.ModuleList):
            return module

    raise ValueError(f"Could not locate text decoder layers in model of type {type(model)}")

# Analysis

def get_safety_vectors_sweep(model, tokenizer, harmful_prompts, harmless_prompts, ratios=SWEEP_RATIOS):
    """Compute a harmful-minus-harmless activation vector at several relative depths.

    For each ratio ``r``, a forward hook is placed on decoder layer
    ``min(int(num_layers * r), num_layers - 1)``. Each prompt is run through
    the model on its own (raw text, no chat template, truncated to 512 tokens).
    The hook records that layer's output at the final token position,
    detached, moved to the CPU and cast to float32.

    Args:
        model: causal LM whose decoder layers ``get_model_layers`` can locate.
        tokenizer: tokenizer matching ``model``.
        harmful_prompts: list of prompt strings for the "harmful" group.
        harmless_prompts: list of prompt strings for the "harmless" group.
        ratios: relative depths to probe (defaults to SWEEP_RATIOS).

    Returns:
        ``{ratio: mean(harmful activations) - mean(harmless activations)}``.
        Any ratio whose capture count does not equal the number of prompts is
        left out. Returns ``None`` if the layers cannot be found or the sweep
        raises; the error is printed. Hooks are removed in every case once the
        sweep has started.
    """
    try:
        decoder_layers = get_model_layers(model)
        depth = len(decoder_layers)
    except ValueError as err:
        print(f"  ERROR: {err}")
        return None

    # Relative depth -> concrete layer index (ratio 1.0 is clamped to the last layer).
    layer_index = {ratio: min(int(depth * ratio), depth - 1) for ratio in ratios}
    captured = {ratio: [] for ratio in ratios}

    def _recorder(ratio_key):
        # Factory, so each hook closes over its own ratio key.
        def _on_forward(_module, _args, output):
            # A decoder layer may return a tuple (hidden state first) or a bare tensor.
            hidden = output[0] if isinstance(output, tuple) else output
            captured[ratio_key].append(hidden[:, -1, :].detach().cpu().float())
        return _on_forward

    hook_handles = [
        decoder_layers[idx].register_forward_hook(_recorder(ratio))
        for ratio, idx in layer_index.items()
    ]

    try:
        prompts = harmful_prompts + harmless_prompts
        for text in prompts:
            encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
            with torch.no_grad():
                model(**encoded)

        # Rows [0, split_at) belong to harmful prompts, the rest to harmless ones.
        split_at = len(harmful_prompts)
        vectors = {}
        for ratio in ratios:
            rows = captured[ratio]
            if len(rows) != len(prompts):
                continue  # capture count does not line up with the prompts; skip this depth
            stacked = torch.cat(rows, dim=0)
            vectors[ratio] = stacked[:split_at].mean(dim=0) - stacked[split_at:].mean(dim=0)
        return vectors
    except Exception as err:
        print(f"  ERROR during sweep: {err}")
        return None
    finally:
        for handle in hook_handles:
            handle.remove()

def analyze_model_sweep(model_id, hf_token):
    """Load one model in 4-bit NF4, extract its safety-vector sweep, then free it.

    Args:
        model_id: Hugging Face repo ID to load.
        hf_token: token passed to ``from_pretrained`` (may be None).

    Returns:
        The dict from ``get_safety_vectors_sweep`` (ratio -> vector), or None
        if loading or extraction failed. Failures are printed, not raised.

    Side effects: whether or not extraction succeeded, the model and tokenizer
    are released, ``cleanup_memory`` runs, and the model is removed from the
    local Hugging Face disk cache via ``cleanup_disk_cache``.
    """
    cleanup_memory()
    print(f"Processing: {model_id}")
    print(f"VRAM before load: {get_vram_usage()}")

    dtype = torch.bfloat16
    model = tokenizer = vectors = None

    # NOTE(review): trust_remote_code=True lets the repo run its own Python code
    # on load. Model IDs come from the input sheet, so only list repos you trust.
    hub_kwargs = {"token": hf_token, "trust_remote_code": True}

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        quantization = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=dtype,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization,
            device_map="auto",
            torch_dtype=dtype,
            attn_implementation="sdpa",
            **hub_kwargs,
        )
        model.eval()

        vectors = get_safety_vectors_sweep(
            model, tokenizer, HARMFUL_PROMPTS, HARMLESS_PROMPTS, ratios=SWEEP_RATIOS
        )
        if vectors:
            print(f"Extracted vectors for {len(vectors)} layers")
    except Exception as e:
        print(f"FAILED: {str(e)}")
    finally:
        del model, tokenizer
        cleanup_memory()
        cleanup_disk_cache(model_id)
        print(f"VRAM after cleanup: {get_vram_usage()}")

    return vectors

# Execute

_OUTPUT_HEADER = ["Model ID", "Base Model", "Sim_Profile_JSON", "Cosine_Sim_0.6", "Status", "Notes", "Vector_File_Path"]


def _open_output_worksheet():
    """Open OUTPUT_SHEET_NAME, creating it if missing, and make sure it has a header row.

    Returns:
        ``(spreadsheet, first_worksheet)``.

    Only a genuinely missing spreadsheet triggers creation. Other API errors
    propagate instead of silently creating a duplicate sheet.
    """
    try:
        sh = gc_client.open(OUTPUT_SHEET_NAME)
    except gspread.exceptions.SpreadsheetNotFound:
        sh = gc_client.create(OUTPUT_SHEET_NAME)
        ws = sh.sheet1
        ws.append_row(_OUTPUT_HEADER)
        print(f"Created new sheet: {OUTPUT_SHEET_NAME}")
        return sh, ws

    ws = sh.sheet1
    print(f"Using existing sheet: {OUTPUT_SHEET_NAME}")
    if not ws.row_values(1):
        ws.append_row(_OUTPUT_HEADER)
    return sh, ws


def _load_processed_models(ws):
    """Return the set of model IDs already in column 1 of ``ws`` (header excluded).

    If the column cannot be read, a warning is printed and an empty set is
    returned, so every model is (re)processed.
    """
    try:
        existing = ws.col_values(1)
    except Exception as e:
        print(f"WARNING: Could not read existing results ({e}); all models will be processed.")
        return set()
    processed = set(existing[1:])
    if processed:
        print(f"Found {len(processed)} models already processed.")
    return processed


def _load_input_models():
    """Return the de-duplicated model IDs from column 1 of the input tab, or None on error.

    The header row is skipped, blank cells are ignored, and URLs are reduced to
    repo IDs with extract_model_id. Repeated IDs are kept once (first
    occurrence), so a model listed twice is not analysed and logged twice.
    """
    try:
        input_sh = gc_client.open_by_key(INPUT_SHEET_ID)
        raw_models = input_sh.worksheet(INPUT_TAB_NAME).col_values(1)[1:]
    except Exception as e:
        print(f"ERROR loading models: {e}")
        return None
    model_ids = [extract_model_id(m.strip()) for m in raw_models if m and m.strip()]
    return list(dict.fromkeys(model_ids))


def _cosine_profile(base_sweep, ft_sweep):
    """Return ``{str(ratio): cosine similarity}`` for each SWEEP_RATIOS entry present in both sweeps."""
    profile = {}
    for r in SWEEP_RATIOS:
        if r in ft_sweep and r in base_sweep:
            v1 = base_sweep[r].numpy().reshape(1, -1)
            v2 = ft_sweep[r].numpy().reshape(1, -1)
            profile[str(r)] = float(cosine_similarity(v1, v2)[0, 0])
    return profile


def main():
    """Analyse every not-yet-processed model from the input sheet and log the results.

    Steps:
      1. Mount Drive and open (or create) the output sheet.
      2. Skip models whose ID already appears in the sheet's first column.
      3. Compute and save one sweep per base model that is needed.
      4. For each fine-tuned model: compute its sweep and its per-ratio cosine
         similarity to the base model's sweep, save its vectors to Drive, and
         append one row to the output sheet.

    NOTE: every appended row puts the model ID in column 1, including error
    rows ("Error: Base Missing", "Error: Extraction Failed", ...). Failed
    models are therefore also skipped on the next run; delete those rows to
    retry them.
    """
    global HF_TOKEN

    setup_drive()
    sh, ws = _open_output_worksheet()
    processed_models = _load_processed_models(ws)

    all_models = _load_input_models()
    if all_models is None:
        return

    models_to_process = [m for m in all_models if m not in processed_models]
    if not models_to_process:
        print("All models processed!")
        return

    if HF_TOKEN is None:
        from getpass import getpass
        # An empty answer becomes None, so from_pretrained uses its default
        # credential lookup instead of receiving an empty token string.
        HF_TOKEN = getpass("Enter your HuggingFace token: ") or None

    # Pre-calculate base models needed
    base_models_needed = {base for base in map(find_base_model, models_to_process) if base}

    print(f"Computing {len(base_models_needed)} Base Model Sweeps")
    base_vectors_sweep = {}

    for base_id in sorted(base_models_needed):
        vecs = analyze_model_sweep(base_id, HF_TOKEN)
        if not vecs:
            print(f"WARNING: Failed base model {base_id}")
            continue
        base_vectors_sweep[base_id] = vecs
        try:
            save_vectors_to_drive(base_id, vecs)
        except Exception as e:
            # The sweep is still usable from memory; only the Drive copy is missing.
            print(f"WARNING: Could not save vectors for base model {base_id}: {e}")

    print(f"Analyzing {len(models_to_process)} Fine-Tuned Models")

    for model_id in tqdm(models_to_process, desc="Models"):
        base_id = find_base_model(model_id)

        if not base_id or base_id not in base_vectors_sweep:
            ws.append_row([model_id, str(base_id), "{}", "0", "Error: Base Missing", "", ""])
            continue

        ft_vecs = analyze_model_sweep(model_id, HF_TOKEN)

        if not ft_vecs:
            ws.append_row([model_id, base_id, "{}", "0", "Error: Extraction Failed", "", ""])
            continue

        try:
            sim_profile = _cosine_profile(base_vectors_sweep[base_id], ft_vecs)
            sim_06 = sim_profile.get("0.6", 0.0)
            file_path = save_vectors_to_drive(model_id, ft_vecs)

            ws.append_row([model_id, base_id, json.dumps(sim_profile), sim_06, "Success", "", file_path])
            print(f"Saved Profile & Vectors -> {file_path}")
        except Exception as e:
            print(f"Error calculating/saving: {e}")
            ws.append_row([model_id, base_id, "{}", "0", f"Error: {e}", "", ""])
        finally:
            cleanup_memory()

    print("Analysis completed")
    print(f"Sheet URL: {sh.url}")

# Entry point: run the full analysis when this file is executed directly.
# The module-level Setup section (Google auth and prompt loading) has already
# run by this point, so main() can use gc_client and the prompt lists.
if __name__ == "__main__":
    main()