{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel, PeftConfig\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.metrics import silhouette_score\n",
    "import gc\n",
    "import json\n",
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model family to analyse and the hub namespace that hosts the fine-tuned adapters.\n",
    "used_model = \"llama\"\n",
    "username = \"Anonymous19782130\"\n",
    "\n",
    "domains = [\"legal\", \"math\", \"medical\", \"commonsense\", \"coding\"]\n",
    "# Three adapter repos per domain, distinguished by the split suffix in the repo name.\n",
    "splits = [\"first\", \"second\", \"third\"]\n",
    "\n",
    "if used_model == 'llama':\n",
    "    MODEL_NAME = \"meta-llama/Llama-3.1-8B\"\n",
    "    # (domain, adapter repo id) pairs in domain-major order.\n",
    "    model_list = [(domain, f\"{username}/llama-3.1-8b-{domain}-{split}\") for domain in domains for split in splits]\n",
    "else:\n",
    "    # Fail fast instead of hitting a NameError on MODEL_NAME / model_list in a later cell.\n",
    "    raise ValueError(f\"Unsupported used_model: {used_model!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paraphrased instruction/response prompt templates used as probes. All of them are filled\n",
    "# with the same placeholder task and input, so the probes differ only in their wording.\n",
    "task_prompt = \"Please provide a response.\"\n",
    "input_text = \"Input.\"\n",
    "\n",
    "probe_templates = [\n",
    "    \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{task} Input:{input}\\n\\n### Response:\",\n",
    "    \"The task described below requires a response that completes the request accurately.\\n\\n### Instruction:\\n{task} Input:{input}\\n\\n### Response:\",\n",
    "    \"Below is a description of a task. Provide a response that aligns with the requirements.\\n\\n### Instruction:\\n{task} Input:{input}\\n\\n### Response:\",\n",
    "    \"The following instruction outlines a task. Generate a response that meets the specified request.\\n\\n### Instruction:\\n{task} Input:{input}\\n\\n### Response:\",\n",
    "    \"You are given an instruction and input. Write a response that completes the task as requested.\\n\\n### Instruction:\\n{task} Input:{input}\\n\\n### Response:\",\n",
    "]\n",
    "\n",
    "formatted_probes = []\n",
    "for probe_template in probe_templates:\n",
    "    formatted_probes.append(probe_template.format(task=task_prompt, input=input_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_average_activation(model_instance, texts, last_token=True, layer=-1, max_length=256):\n",
    "    \"\"\"Average a model's hidden-state representation over a set of probe texts.\n",
    "\n",
    "    Each text is tokenized on its own (batch of one, no padding) with the global\n",
    "    `tokenizer` and run through `model_instance` without gradients. From\n",
    "    `outputs.hidden_states[layer]` it takes either the final token's vector or the\n",
    "    mean over all token positions, then averages those vectors across all texts.\n",
    "\n",
    "    Args:\n",
    "        model_instance: causal LM (base or PEFT-wrapped) callable as `model_instance(**inputs)`.\n",
    "        texts: non-empty iterable of probe strings.\n",
    "        last_token: if True use the last token position, otherwise the mean over positions.\n",
    "        layer: index into `outputs.hidden_states`; the default -1 selects the last entry.\n",
    "        max_length: tokenizer truncation length. NOTE(review): with right-side truncation an\n",
    "            over-long probe loses its final tokens, which changes the `last_token` vector.\n",
    "\n",
    "    Returns:\n",
    "        1-D float32 numpy array (last dimension of the hidden states): the mean over `texts`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `texts` is empty.\n",
    "    \"\"\"\n",
    "    model_instance.eval()\n",
    "    # Send inputs to the device holding the model's parameters rather than the global\n",
    "    # `device`, so this also works when the model was placed with a different device_map.\n",
    "    model_device = next(model_instance.parameters()).device\n",
    "    activations = []\n",
    "\n",
    "    for text in tqdm(texts, desc=\"Processing probes\", leave=False):\n",
    "        inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=max_length)\n",
    "        inputs = {k: v.to(model_device) for k, v in inputs.items()}\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Request hidden states per call instead of setting config.output_hidden_states,\n",
    "            # which would persist on the model after this function returns.\n",
    "            outputs = model_instance(**inputs, output_hidden_states=True)\n",
    "\n",
    "        # Upcast so .numpy() works for half/bfloat16 models (numpy has no bfloat16).\n",
    "        hidden = outputs.hidden_states[layer].float()\n",
    "\n",
    "        if last_token:\n",
    "            # Batch of one: take the single sequence's final position.\n",
    "            representation = hidden[0, -1, :]\n",
    "        else:\n",
    "            # Mean over all token positions of the single sequence.\n",
    "            representation = hidden[0].mean(dim=0)\n",
    "        activations.append(representation.cpu().numpy())\n",
    "\n",
    "    if not activations:\n",
    "        raise ValueError(\"get_average_activation needs at least one probe text\")\n",
    "    return np.mean(np.stack(activations), axis=0)\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Load the tokenizer once; get_average_activation reads it as a global.\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "print(\"Loading base model...\")\n",
    "# Place the model on the detected device: a hard-coded \"cuda\" fails on CPU-only\n",
    "# machines even though `device` falls back to \"cpu\".\n",
    "base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map=device)\n",
    "base_model.config.output_hidden_states = True\n",
    "base_model.eval()\n",
    "\n",
    "# Reference activation of the un-adapted model.\n",
    "print(\"Computing base model activation...\")\n",
    "base_activation = get_average_activation(base_model, formatted_probes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delta activation per adapter: mean probe activation with the adapter applied minus\n",
    "# the base model's mean probe activation (base_activation), keyed by adapter repo id.\n",
    "adapter_delta_embeddings = {}\n",
    "\n",
    "for domain, model_path in model_list:\n",
    "    print(f\"Loading model for {domain}: {model_path}\")\n",
    "    try:\n",
    "        # PeftModel.from_pretrained injects the adapter layers into base_model in place.\n",
    "        peft_model = PeftModel.from_pretrained(base_model, model_path)\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading model for {model_path}: {e}\")\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        peft_model.eval()\n",
    "        peft_model_activation = get_average_activation(peft_model, formatted_probes)\n",
    "\n",
    "        # Compute Delta Activations\n",
    "        adapter_delta_embeddings[model_path] = peft_model_activation - base_activation\n",
    "    except Exception as e:\n",
    "        print(f\"Error computing activations for {model_path}: {e}\")\n",
    "    finally:\n",
    "        # Strip this adapter's layers from base_model before the next iteration; otherwise\n",
    "        # the next adapter is loaded onto an already-modified model instead of the clean base.\n",
    "        base_model = peft_model.unload()\n",
    "        del peft_model\n",
    "        gc.collect()\n",
    "        torch.cuda.empty_cache()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Silhouette score of the delta embeddings grouped by the domain each adapter belongs to.\n",
    "model_names = list(adapter_delta_embeddings.keys())\n",
    "# Look labels up per adapter so they stay aligned with the embedding rows even when some\n",
    "# adapters failed to load and are missing from adapter_delta_embeddings.\n",
    "domain_by_path = {path: domain for domain, path in model_list}\n",
    "labels = [domain_by_path[name] for name in model_names]\n",
    "n_domains = len(set(labels))\n",
    "\n",
    "# silhouette_score requires 2 <= number of distinct labels <= n_samples - 1.\n",
    "if 2 <= n_domains <= len(model_names) - 1:\n",
    "    print(\"\\nEvaluating clustering quality by dataset...\")\n",
    "    delta_matrix = np.stack([adapter_delta_embeddings[name] for name in model_names])\n",
    "\n",
    "    # Cosine similarity rescaled from [-1, 1] to [0, 1]; the resulting distance is (1 - cos) / 2.\n",
    "    similarity_matrix = cosine_similarity(delta_matrix)\n",
    "    similarity_matrix = (similarity_matrix + 1) / 2\n",
    "    distance_matrix = np.clip(1 - similarity_matrix, 0, 1)\n",
    "    # sklearn rejects a precomputed distance matrix with non-zero diagonal entries; rounding\n",
    "    # in the self-similarities can leave tiny residues there, so zero it explicitly.\n",
    "    np.fill_diagonal(distance_matrix, 0)\n",
    "\n",
    "    dataset_silhouette_avg = silhouette_score(distance_matrix, labels, metric=\"precomputed\")\n",
    "    print(f\"silhouette score for dataset clustering: {dataset_silhouette_avg:.4f}\")\n",
    "else:\n",
    "    print(f\"Not enough models to calculate silhouette score (have {len(model_names)} models \"\n",
    "          f\"in {n_domains} domains; need at least 2 domains and more models than domains).\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D t-SNE projection of the delta embeddings, coloured by adapter domain.\n",
    "adapter_names = list(adapter_delta_embeddings.keys())\n",
    "delta_matrix = np.stack([adapter_delta_embeddings[name] for name in adapter_names])\n",
    "# t-SNE requires perplexity < number of points; 2 keeps it valid for a small set of adapters.\n",
    "tsne = TSNE(n_components=2, random_state=41, perplexity=2)\n",
    "delta_matrix_2d = tsne.fit_transform(delta_matrix.astype(np.float32))\n",
    "fig, ax = plt.subplots(figsize=(6, 3))\n",
    "\n",
    "dataset_colors = {\n",
    "    'legal': '#648FFF',      # Blue (IBM colorblind-friendly)\n",
    "    'math': '#FFB000',       # Orange/Yellow (IBM colorblind-friendly)\n",
    "    'medical': '#DC267F',    # Magenta (IBM colorblind-friendly)\n",
    "    'commonsense': '#785EF0', # Purple (IBM colorblind-friendly)\n",
    "    'coding': '#FE6100'      # Orange/Red (IBM colorblind-friendly)\n",
    "}\n",
    "\n",
    "# Colour each point by its adapter's domain, looked up per adapter so colours stay aligned\n",
    "# with the rows of delta_matrix_2d even when some adapters failed to load.\n",
    "domain_by_path = {path: domain for domain, path in model_list}\n",
    "colors = [dataset_colors[domain_by_path[name]] for name in adapter_names]\n",
    "ax.scatter(delta_matrix_2d[:, 0], delta_matrix_2d[:, 1], c=colors, s=60)\n",
    "\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], marker='o', color='w', markerfacecolor=color, \n",
    "           markersize=10, label=dataset)\n",
    "    for dataset, color in dataset_colors.items()\n",
    "]\n",
    "\n",
    "ax.set_title('Delta Activations', fontsize=16, fontname='DejaVu Serif')\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "fig.legend(handles=legend_elements, \n",
    "          bbox_to_anchor=(0.5, -0.02),\n",
    "          loc='lower center',\n",
    "          title=\"Domains\", \n",
    "          ncol=min(5, len(dataset_colors)),\n",
    "          fontsize=10, \n",
    "          frameon=True)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.subplots_adjust(bottom=0.2)  # Make room for the legend\n",
    "fig.savefig('delta_activations.png', dpi=300, bbox_inches='tight', format='png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
