{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bcf1224-40eb-48cd-b528-b2f7dfdad584",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Makes the folder where the h5py database is saved.\n",
    "# `exist_ok=True` keeps this cell safe to re-run (a plain `!mkdir` complains once the folder exists).\n",
    "import os\n",
    "\n",
    "os.makedirs('Database', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b332fe4-996d-4f54-922e-8355caed189b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Needed so that the APMAE model can be imported from the `Model` directory next to this one.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Resolves to <current working directory>/../Model\n",
    "path2add = os.path.normpath(\n",
    "    os.path.abspath(\n",
    "        os.path.join(os.path.dirname('./run.ipynb'), os.path.pardir, 'Model')\n",
    "    )\n",
    ")\n",
    "\n",
    "# Only append once, so re-running the cell does not grow sys.path\n",
    "if path2add not in sys.path:\n",
    "    sys.path.append(path2add)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7c11b6-fcfd-4bf4-a49e-c050ab8566ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Our code\n",
    "from DataUtil.DataLoader import IterableAttentionLoader\n",
    "from DataUtil.AttentionData import AttentionData\n",
    "from ap_mae import APMAE\n",
    "\n",
    "#Imported packages\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from collections import Counter\n",
    "\n",
    "#We recommend to use the cuml package for quicker computation if a decent gpu is available, can be replaced by the corresponding sklearn packages\n",
    "from cuml import UMAP\n",
    "from cuml import HDBSCAN\n",
    "from cuml.metrics.pairwise_distances import pairwise_distances\n",
    "\n",
    "\n",
    "#for classification\n",
    "from sklearn.model_selection import train_test_split\n",
    "from catboost import CatBoostClassifier, Pool\n",
    "\n",
    "from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import shap\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "635abcfe-6dec-4a0d-a1a6-854bfaf52684",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd95bfb6-94dd-4373-8011-f81f1bb10fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the target model: '3B', '7B' or '15B'.\n",
    "# Storage required:  3B -> 2TB,  7B -> 3.5TB,  15B -> 5.5TB\n",
    "size = '3B'\n",
    "\n",
    "# Database / model names derived from the chosen size\n",
    "db_name = f'reproduction_{size}'\n",
    "target_model_name = f'bigcode/starcoder2-{size.lower()}'\n",
    "encoding_model_name = f'LaughingLogits/AP-MAE-SC2-{size}'\n",
    "\n",
    "# Dataset used to build the samples\n",
    "dataset_name = 'LaughingLogits/Stackless_Java_V2'\n",
    "split = 'test'\n",
    "\n",
    "device = 'cpu'\n",
    "languages = ['java']\n",
    "\n",
    "tasks = [\n",
    "    'noise',\n",
    "    'random',\n",
    "    'identifiers',\n",
    "    'boolean_literals',\n",
    "    'numeric_literals',\n",
    "    'string_literals',\n",
    "    'boolean_operators',\n",
    "    'mathematical_operators',\n",
    "    'assignment_operators',\n",
    "    'eol',\n",
    "    'closing_bracket',\n",
    "]\n",
    "\n",
    "samples_per_task = 1000\n",
    "context_length = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd044f9b-ddad-4460-abf0-7f0c9afd25c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# These can be replaced with a list of keys, but we used all values (\"*\") in our investigation\n",
    "# e.g. incorrect java predictions for the eol task, all heads from layers 4 and 7\n",
    "# langs = ['java']\n",
    "# corrects = ['incorrect']\n",
    "# querys = ['eol']\n",
    "# layers = ['4','7']\n",
    "# heads = \"*\"\n",
    "langs = \"*\"\n",
    "corrects = \"*\"\n",
    "querys = \"*\"\n",
    "layers = \"*\"\n",
    "heads = \"*\"\n",
    "\n",
    "# (number of layers, number of attention heads) for every supported target model size\n",
    "MODEL_SHAPES = {\n",
    "    '3B': (30, 24),\n",
    "    '7B': (32, 36),\n",
    "    '15B': (40, 48),\n",
    "}\n",
    "\n",
    "# Fail here with a clear message instead of a NameError on n_layers/n_heads further down\n",
    "if size not in MODEL_SHAPES:\n",
    "    raise ValueError(f\"Unsupported size {size!r}; expected one of {list(MODEL_SHAPES)}\")\n",
    "n_layers, n_heads = MODEL_SHAPES[size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "775a3422-edcf-4433-bcb3-bb5c739a15ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The target model (see `target_model_name`)\n",
    "target_model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map=\"auto\")\n",
    "\n",
    "# The pre-trained AP-MAE encoding model (see `encoding_model_name`)\n",
    "encoding_model = APMAE.from_pretrained(pretrained_model_name_or_path=encoding_model_name)\n",
    "\n",
    "# Database wrapper, stored under `db_name`\n",
    "attention_data = AttentionData(target_model.config, tasks, languages, db_name)\n",
    "\n",
    "# NOTE(review): all arguments are positional; the meaning of the boolean flags is defined by\n",
    "# IterableAttentionLoader's signature, which is not visible from this notebook - confirm before changing them.\n",
    "attention_loader = IterableAttentionLoader(\n",
    "    dataset_name,\n",
    "    samples_per_task,\n",
    "    context_length,\n",
    "    tasks,\n",
    "    languages[0],\n",
    "    target_model_name,\n",
    "    False,\n",
    "    target_model,\n",
    "    device,\n",
    "    split,\n",
    "    True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1e7edac-1db1-4ee6-8710-0804b0a5f1cb",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Generate patterns and encode - Section 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "853ca6d3-e88e-4bdc-8fb3-98c009288c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this and all patterns are saved; it takes up a lot of storage (up to 5.5TB per 10,000 samples).\n",
    "# Step 1 generates the patterns from the loader, step 2 encodes them with the encoding model.\n",
    "attention_data.generate_patterns(attention_loader)\n",
    "attention_data.encode(encoding_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9bdfa23-968b-4721-bad4-6934b5b17c0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this and it won't save the actual pattern, only the encoding (up to 750GB per 100,000 samples).\n",
    "encoder_batch_size = 10  # batch size for the encoder\n",
    "attention_data.generate_and_encode(attention_loader, encoding_model, encoder_batch_size)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1fcb7a0-6ff5-4aeb-ba60-f08a1f40bc9e",
   "metadata": {},
   "source": [
    "# Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89fd76f6-2f38-4760-91f0-1c2b9767e0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def jitter(X, jitter = 1e-3, rng = None):\n",
    "    \"\"\"Return a copy of `X` in which repeated rows are nudged apart with Gaussian noise.\n",
    "\n",
    "    The first occurrence of every distinct row is left untouched; every further\n",
    "    copy of that row gets `jitter * N(0, 1)` noise added to each element.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array-like of shape (n_samples, n_features). It is not modified.\n",
    "        jitter: scale (standard deviation) of the added noise.\n",
    "        rng: optional `np.random.Generator` for reproducible noise. When None,\n",
    "            NumPy's global random state is used.\n",
    "\n",
    "    Returns:\n",
    "        A new ndarray with the same shape as `X`. Non-floating input is\n",
    "        converted to float64 so that the noise can be stored.\n",
    "    \"\"\"\n",
    "    X = np.array(X)  # np.array copies, so the caller's data stays untouched\n",
    "    if not np.issubdtype(X.dtype, np.inexact):\n",
    "        # integer/bool arrays cannot take the in-place float addition below\n",
    "        X = X.astype(np.float64)\n",
    "    if X.size == 0:\n",
    "        return X\n",
    "\n",
    "    # Step 1: Identify duplicate rows\n",
    "    _, first_idx, inverse, counts = np.unique(\n",
    "        X, axis=0, return_index=True, return_inverse=True, return_counts=True\n",
    "    )\n",
    "    # Some NumPy 2.0 releases return a 2-D inverse when `axis` is given; flatten it so the masks stay 1-D\n",
    "    inverse = np.ravel(inverse)\n",
    "\n",
    "    # Step 2: Rows that are repeated, excluding the first occurrence of each\n",
    "    is_first_occurrence = np.zeros(X.shape[0], dtype=bool)\n",
    "    is_first_occurrence[first_idx] = True\n",
    "    extra_copies = (counts[inverse] > 1) & ~is_first_occurrence\n",
    "\n",
    "    # Step 3: Add noise to just the duplicated rows\n",
    "    n_extra = int(extra_copies.sum())\n",
    "    if n_extra:\n",
    "        if rng is None:\n",
    "            noise = np.random.randn(n_extra, X.shape[1])\n",
    "        else:\n",
    "            noise = rng.standard_normal((n_extra, X.shape[1]))\n",
    "        X[extra_copies] += jitter * noise\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73494749-485e-4746-969f-6bf02cbd39e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cupy as cp  # `cp` is used below to move the arrays to the GPU for cuml\n",
    "\n",
    "MAX_SELECTIVE_ATTEMPTS = 3  # attempts with duplicate-only jitter before jittering everywhere\n",
    "\n",
    "\n",
    "def embed_with_umap(X_host):\n",
    "    \"\"\"Fit the UMAP model, with the hyperparameters we used, on `X_host` and return the embedding.\n",
    "\n",
    "    `X_host` is a NumPy array; it is converted with `cp.asarray` before fitting.\n",
    "    \"\"\"\n",
    "    umap_model = UMAP(\n",
    "        n_components=8,\n",
    "        n_neighbors=20,\n",
    "        min_dist=0.05,\n",
    "        metric='cosine'\n",
    "    )\n",
    "    return umap_model.fit_transform(cp.asarray(X_host))\n",
    "\n",
    "\n",
    "for l in range(n_layers):\n",
    "    for h in range(n_heads):\n",
    "        # Keep the encodings as a NumPy array: `jitter` works on NumPy data, and every\n",
    "        # attempt below starts again from these un-jittered values.\n",
    "        X = np.asarray(attention_data.data.get_grouped_samples(langs, corrects, querys, [l], [h], 'enc_cls'))\n",
    "\n",
    "        # We add jitter only where the values match; after each failure the jitter is multiplied by 5.\n",
    "        # If it fails 3 times, we add jitter everywhere.\n",
    "        jitter_val = 1e-3\n",
    "        X_embed = None\n",
    "        for attempt in range(MAX_SELECTIVE_ATTEMPTS):\n",
    "            try:\n",
    "                X_embed = embed_with_umap(jitter(X, jitter = jitter_val))\n",
    "                break #Exit retry loop if it worked\n",
    "            except Exception as e:\n",
    "                print(f\"Attempt {attempt + 1} failed: {e!r}\")\n",
    "                print(jitter_val)\n",
    "                jitter_val = 5*jitter_val\n",
    "\n",
    "        if X_embed is None:\n",
    "            print(\"selective jitter failed, jittering everywhere\")\n",
    "            # Last resort; if this fails as well the error is raised\n",
    "            X_embed = embed_with_umap(X + 1e-3 * np.random.randn(X.shape[0], X.shape[1]))\n",
    "\n",
    "        # The HDBSCAN model with the hyperparameters we used.\n",
    "        hdbscan_model = HDBSCAN(min_samples=20, min_cluster_size=25, allow_single_cluster=True)\n",
    "        labels = hdbscan_model.fit_predict(X_embed.get())\n",
    "\n",
    "        #Save the clusters in our H5PY Database\n",
    "        attention_data.data.write_grouped_samples(langs, corrects, querys, [l], [h], \"class_cls\", labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78c0087c-88a6-4427-889d-897c3fcbe26c",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57304f9a-096b-4e25-8c0c-9178f8ca4c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Categorical value features to pass to CatBoost: one name per (layer, head).\n",
    "#Built from n_layers / n_heads (not a fixed 40 x 48) so the names match the\n",
    "#l{layer}h{head} columns that exist for the selected model size.\n",
    "col_names = [f'l{l}h{h}' for l in range(n_layers) for h in range(n_heads)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1d9056e-490c-4aa6-8f06-dbf890f22ccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['random', 'identifiers', 'boolean_literals', 'numeric_literals', 'string_literals', 'boolean_operators', 'mathematical_operators', 'assignment_operators', 'eol', 'closing_bracket']\n",
    "\n",
    "# Results per task, so they are not overwritten by the next iteration\n",
    "accuracies = {}\n",
    "explanations = {}\n",
    "\n",
    "for t in tasks:\n",
    "    # One categorical feature per attention head, restricted to the samples of task `t`.\n",
    "    # NOTE(review): this reads the cluster labels stored under 'class_cls' by the clustering step and\n",
    "    # selects the task with [t]. The previous version read 'enc_cls' with the global `querys`, which made\n",
    "    # every iteration identical and passed non-categorical values to CatBoost - confirm this is the intent.\n",
    "    head_features = {}\n",
    "    for l in tqdm(range(n_layers)):\n",
    "        for h in range(n_heads):\n",
    "            head_features[f'l{l}h{h}'] = attention_data.data.get_grouped_samples(langs, corrects, [t], [l], [h], 'class_cls')\n",
    "    # Build the frame in one go (the frame was previously never created before columns were assigned)\n",
    "    df = pd.DataFrame(head_features)\n",
    "    cat_feature_names = df.columns.tolist()\n",
    "\n",
    "    # returns the labels correct, or incorrect for each prediction\n",
    "    y = attention_data.data.get_grouped_clusters(langs, corrects, [t], layers, heads, 'enc_cls', True, False, True, True, True)\n",
    "\n",
    "    X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.1, random_state=42, stratify=y)\n",
    "\n",
    "    model = CatBoostClassifier(\n",
    "        iterations=1000,\n",
    "        depth=6,\n",
    "        learning_rate=0.1,\n",
    "        loss_function='Logloss',\n",
    "        cat_features=cat_feature_names,\n",
    "        verbose=0,\n",
    "        early_stopping_rounds=25,\n",
    "        eval_fraction=0.1,\n",
    "        task_type='GPU', #remove if no GPU is available\n",
    "        devices='0'\n",
    "    )\n",
    "\n",
    "    # Fit the model\n",
    "    model.fit(X_train, y_train)\n",
    "\n",
    "    # Predict\n",
    "    y_pred = model.predict(X_test)\n",
    "\n",
    "    #Evaluate the classification\n",
    "    cmd = ConfusionMatrixDisplay.from_predictions(y_test, y_pred)\n",
    "    cmd.ax_.set_title(t)\n",
    "    plt.show()\n",
    "    accuracy = accuracy_score(y_test, y_pred)\n",
    "    accuracies[t] = accuracy\n",
    "    print(f'{t} accuracy: {accuracy}')\n",
    "\n",
    "    ##### SHAP VALUES #####\n",
    "    pool = Pool(X_test, y_test, cat_features=cat_feature_names)\n",
    "    # Get SHAP values from CatBoost. The result has one column per feature followed by\n",
    "    # the expected (base) value in the LAST column.\n",
    "    shap_values = model.get_feature_importance(pool, type='ShapValues')\n",
    "\n",
    "    # Extract only per-feature SHAP values (everything except the last column)\n",
    "    feature_shap_values = shap_values[:, :-1]\n",
    "\n",
    "    # Build a shap.Explanation object\n",
    "    expl = shap.Explanation(\n",
    "        values=feature_shap_values,\n",
    "        base_values=shap_values[:, -1],\n",
    "        data=X_test.values,\n",
    "        feature_names=X_test.columns.tolist()\n",
    "    )\n",
    "    explanations[t] = expl  # `expl` itself still holds the last task's explanation after the loop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b21922bc-90ed-40e2-9f82-344e25fc59a1",
   "metadata": {},
   "source": [
    "# Intervention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1faf71-b0ed-46d9-bd52-f49e5f954936",
   "metadata": {},
   "outputs": [],
   "source": [
    "########## Helper functions for the main intervention loop ##################################\n",
    "import contextlib\n",
    "import random\n",
    "from typing import Dict, Iterable, Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "\n",
    "def get_global_pos_neg_features(explanation, setting = 'pos_only', top_n=10, neutral_threshold = 1e-5):\n",
    "    \"\"\"\n",
    "    Get up to `top_n` feature names from a SHAP Explanation object, chosen by\n",
    "    their global (averaged over all samples) contribution.\n",
    "\n",
    "    Args:\n",
    "        explanation: object with `.values` of shape (n_samples, n_features) and\n",
    "            `.feature_names` holding one name per feature.\n",
    "        setting: 'pos_only' -> features with the largest mean positive contribution,\n",
    "            'neg_only' -> features with the most negative mean negative contribution,\n",
    "            'neutral' -> random features whose |mean contribution| <= `neutral_threshold`.\n",
    "        top_n: maximum number of features to return.\n",
    "        neutral_threshold: bound used by the 'neutral' setting.\n",
    "\n",
    "    Returns:\n",
    "        A list of feature names. For 'neutral' the names are distinct and the list\n",
    "        is shorter than `top_n` when fewer features qualify.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `setting` is not one of the three values above.\n",
    "    \"\"\"\n",
    "    shap_values = np.asarray(explanation.values)  # (n_samples, n_features)\n",
    "    feature_names = explanation.feature_names\n",
    "\n",
    "    if setting == 'pos_only':\n",
    "        # mean positive contribution per feature, largest first\n",
    "        mean_pos = np.where(shap_values > 0, shap_values, 0).mean(axis=0)\n",
    "        ranked = sorted(zip(feature_names, mean_pos), key=lambda x: x[1], reverse=True)\n",
    "        return [name for name, _ in ranked[:top_n]]\n",
    "\n",
    "    if setting == 'neg_only':\n",
    "        # mean negative contribution per feature, most negative first\n",
    "        mean_neg = np.where(shap_values < 0, shap_values, 0).mean(axis=0)\n",
    "        ranked = sorted(zip(feature_names, mean_neg), key=lambda x: x[1])\n",
    "        return [name for name, _ in ranked[:top_n]]\n",
    "\n",
    "    if setting == 'neutral':\n",
    "        mean_total = shap_values.mean(axis=0)\n",
    "        neutral = [f for f, v in zip(feature_names, mean_total) if abs(v) <= neutral_threshold]\n",
    "        # Draw WITHOUT replacement: random.choices could return the same head several times\n",
    "        # and raises on an empty list.\n",
    "        return random.sample(neutral, k=min(top_n, len(neutral)))\n",
    "\n",
    "    raise ValueError(f\"Unknown setting {setting!r}; expected 'pos_only', 'neg_only' or 'neutral'\")\n",
    "\n",
    "\n",
    "def get_random_heads(n_heads = 5, num_layers: Optional[int] = None, num_heads: Optional[int] = None):\n",
    "    \"\"\"Return `n_heads` distinct, randomly chosen head names of the form 'l{layer}h{head}'.\n",
    "\n",
    "    Args:\n",
    "        n_heads: how many head names to return.\n",
    "        num_layers: number of layers of the model. When None it is looked up in the\n",
    "            notebook globals (`model_layers` if defined, otherwise `n_layers`).\n",
    "        num_heads: number of heads per layer. When None it is looked up in the\n",
    "            notebook globals (`model_heads` if defined, otherwise `n_heads`).\n",
    "\n",
    "    Raises:\n",
    "        NameError: if the model size is neither passed in nor defined globally.\n",
    "    \"\"\"\n",
    "    scope = globals()\n",
    "    if num_layers is None:\n",
    "        num_layers = scope.get('model_layers', scope.get('n_layers'))\n",
    "    if num_heads is None:\n",
    "        # the parameter `n_heads` hides the global with the same name, hence the explicit lookup\n",
    "        num_heads = scope.get('model_heads', scope.get('n_heads'))\n",
    "    if num_layers is None or num_heads is None:\n",
    "        raise NameError('model size unknown: pass num_layers/num_heads or define n_layers/n_heads')\n",
    "\n",
    "    heads = [f\"l{l}h{h}\" for l in range(num_layers) for h in range(num_heads)]\n",
    "    random.shuffle(heads)\n",
    "    return heads[:n_heads]\n",
    "\n",
    "\n",
    "###############################################################################\n",
    "# HeadSkipper: zeroes specified attention heads per layer during inference.\n",
    "# - Works with kv caching (past_key_values) because we don't mutate K/V.\n",
    "# - Operates right before the attention o_proj, so it's fast and robust.\n",
    "###############################################################################\n",
    "\n",
    "class HeadSkipper:\n",
    "    \"\"\"Zeroes the output of selected attention heads while `apply()` is active.\n",
    "\n",
    "    Args:\n",
    "        model: a model exposing `model.model.layers[i].self_attn.o_proj` and\n",
    "            `model.config.num_attention_heads`.\n",
    "        heads_by_layer: mapping layer index -> head indices to zero in that layer.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a layer or head index is out of range for `model`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, heads_by_layer: Dict[int, Iterable[int]]):\n",
    "        self.model = model\n",
    "        # normalise: integer layer keys, sorted unique head indices\n",
    "        self.heads_by_layer = {int(l): sorted(set(v)) for l, v in heads_by_layer.items()}\n",
    "        # Validate layer indices\n",
    "        num_layers = len(self.model.model.layers)\n",
    "        for layer_idx in self.heads_by_layer:\n",
    "            if not (0 <= layer_idx < num_layers):\n",
    "                raise ValueError(f\"Layer index {layer_idx} out of range (0 to {num_layers-1})\")\n",
    "        # Validate head indices\n",
    "        num_heads = model.config.num_attention_heads\n",
    "        for heads in self.heads_by_layer.values():\n",
    "            for h in heads:\n",
    "                if not (0 <= h < num_heads):\n",
    "                    raise ValueError(f\"Head index {h} out of range (0 to {num_heads-1})\")\n",
    "\n",
    "        # hook handles registered by apply(); removed again when the context exits\n",
    "        self.handles = []\n",
    "\n",
    "    @contextlib.contextmanager\n",
    "    def apply(self):\n",
    "        \"\"\"Context manager: the selected heads are zeroed inside the `with` block only.\"\"\"\n",
    "        try:\n",
    "            for layer_idx, heads in self.heads_by_layer.items():\n",
    "                block = self.model.model.layers[layer_idx]\n",
    "                o_proj = block.self_attn.o_proj\n",
    "\n",
    "                # factory, so every hook keeps its own list of heads\n",
    "                def make_pre_hook(heads_to_zero):\n",
    "                    def pre_hook(module, inputs):\n",
    "                        x, = inputs  # shape: [batch, seq_len, hidden_size]\n",
    "                        num_heads = self.model.config.num_attention_heads\n",
    "                        head_dim = x.shape[-1] // num_heads\n",
    "                        # split the last dimension into (head, head_dim) and zero the chosen heads\n",
    "                        x_view = x.view(x.size(0), x.size(1), num_heads, head_dim)\n",
    "                        x_view[:, :, heads_to_zero, :] = 0\n",
    "                        return (x_view.reshape(x.shape[0], x.shape[1], -1),)\n",
    "                    return pre_hook\n",
    "\n",
    "                handle = o_proj.register_forward_pre_hook(make_pre_hook(heads))\n",
    "                self.handles.append(handle)\n",
    "            yield\n",
    "        finally:\n",
    "            # always unregister, also when the body of the `with` block raised\n",
    "            for h in self.handles:\n",
    "                h.remove()\n",
    "            self.handles.clear()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8fb646a-a0b3-4847-aff4-a47f123e81aa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
