{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "from src import functional\n",
    "import src.tokens as tokenization_utils\n",
    "\n",
    "torch.__version__, transformers.__version__, torch.version.cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# MODEL_PATH = \"EleutherAI/gpt-j-6B\"\n",
    "# MODEL_PATH = \"meta-llama/Llama-2-7b-hf\"\n",
    "# MODEL_PATH = \"mistralai/Mistral-7B-v0.1\"\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b-slimpj\" # state-spaces/mamba-2.8b\n",
    "\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_path=MODEL_PATH, \n",
    "    torch_dtype=torch.float32\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Cache promoting tokens for all the `down_proj` neurons\n",
    "# ---------------------------------------\n",
    "cut_off_rank = 50\n",
    "path = \"../results/neuron_prommotions/out_proj\"\n",
    "out_proj_path_format = \"layers.{}.mixer.out_proj\"\n",
    "# ---------------------------------------\n",
    "\n",
    "# os.makedirs(path, exist_ok=True)\n",
    "\n",
    "# neuron_promotions = {layer_idx: {} for layer_idx in range(mt.n_layer)}\n",
    "\n",
    "# for layer_idx in tqdm(range(mt.n_layer)):\n",
    "#     print(f\"layer {layer_idx}\")\n",
    "#     out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer_idx))\n",
    "\n",
    "#     for column in tqdm(range(out_proj.weight.shape[1])):\n",
    "#         next_tok_candidates = functional.logit_lens(\n",
    "#             mt = mt, \n",
    "#             h = out_proj.weight[:, column],\n",
    "#             k = cut_off_rank\n",
    "#         )\n",
    "\n",
    "#         neuron_promotions[layer_idx][column] = [\n",
    "#             {\"token\": tok, \"logit\": logit} for tok, logit in next_tok_candidates\n",
    "#         ]\n",
    "    \n",
    "#     with open(os.path.join(path, f\"layer_{layer_idx}.json\"), \"w\") as f:\n",
    "#         json.dump(neuron_promotions[layer_idx], f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cache_weights(mt):\n",
    "    weights_cached = {}\n",
    "    for layer in range(mt.n_layer):\n",
    "        out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))\n",
    "        weights_cached[layer] = out_proj.weight.clone().detach()\n",
    "    return weights_cached\n",
    "\n",
    "def restore_weights(mt, weights_cached):\n",
    "    for layer in range(mt.n_layer):\n",
    "        out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))\n",
    "        with torch.no_grad():\n",
    "            out_proj.weight[...] = weights_cached[layer]\n",
    "\n",
    "# weights_cached = cache_weights(mt)\n",
    "# restore_weights(mt, weights_cached)\n",
    "            \n",
    "####################################\n",
    "WEIGHTS_CACHED = cache_weights(mt)\n",
    "####################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################################################\n",
    "concepts = [\n",
    "    \"doctor\", \"nurse\", \"therapist\",\n",
    "    \"healthcare\", \"medicine\", \"medical\"\n",
    "]\n",
    "# concepts = [\n",
    "#     \"computer\", \"software\", \"engineer\", \"programmer\", \"developer\", \"hacker\"\n",
    "# ]\n",
    "########################################################################\n",
    "\n",
    "concept_start_token_ids= mt.tokenizer([\n",
    "    \" \" + concept + \" \" for concept in concepts\n",
    "], return_tensors=\"pt\", padding=True).input_ids[:, 0].tolist()\n",
    "\n",
    "[(id, mt.tokenizer.decode(id)) for id in concept_start_token_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_neuron_promotions(path, layer_idx):\n",
    "    with open(os.path.join(path, f\"layer_{layer_idx}.json\")) as f:\n",
    "        loaded_dict = json.load(f)\n",
    "        return {\n",
    "            int(k): v for k, v in loaded_dict.items()\n",
    "        }\n",
    "\n",
    "neuron_promotions = {layer_idx: {} for layer_idx in range(mt.n_layer)}\n",
    "for layer_idx in range(mt.n_layer):\n",
    "    neuron_promotions[layer_idx] = load_neuron_promotions(path, layer_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache promoting tokens for all the `down_proj` neurons\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "cut_off_rank = 1\n",
    "concept_drivers = {layer_idx: [] for layer_idx in range(mt.n_layer)}\n",
    "\n",
    "for layer_idx in range(mt.n_layer):\n",
    "    out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer_idx))\n",
    "\n",
    "    found_neurons = []\n",
    "\n",
    "    for column in range(out_proj.weight.shape[1]):\n",
    "        # for t in concept_start_token_ids:\n",
    "        #     if concept_ranks[t]['rank'] <= cut_off_rank:\n",
    "        #         concept_driver_neurons.append({\n",
    "        #             \"layer\": layer_idx,\n",
    "        #             \"neuron\": column,\n",
    "        #             \"concept_ranks\": concept_ranks,\n",
    "        #         })\n",
    "        #         break\n",
    "        candidate_tokens = [\n",
    "            candidate[\"token\"] for candidate in neuron_promotions[layer_idx][column]\n",
    "        ][:cut_off_rank]\n",
    "        for target_token in concepts:\n",
    "            found = False\n",
    "            for candidate in candidate_tokens:\n",
    "                if len(candidate.strip()) < 4:  # skip very short trivial tokens\n",
    "                    continue\n",
    "                if functional.is_nontrivial_prefix(\n",
    "                    prediction=candidate, target=target_token\n",
    "                ):\n",
    "                    found = True\n",
    "                    concept_drivers[layer_idx].append(column)\n",
    "                    found_neurons.append(column)\n",
    "                    break\n",
    "            if found:\n",
    "                break\n",
    "\n",
    "    if len(found_neurons) > 0:\n",
    "        print(\n",
    "            f\"found {len(found_neurons)} neurons in layer {layer_idx} > {found_neurons}\"\n",
    "        )\n",
    "    # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer = 25\n",
    "# neuron = 1799\n",
    "\n",
    "\n",
    "# out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))\n",
    "# logits = mt.lm_head(out_proj.weight[:, neuron])\n",
    "\n",
    "# logit_values = logits.sort(descending=True).values.detach().cpu().numpy()[:30]\n",
    "# logit_tokens = logits.sort(descending=True).indices.detach().cpu().numpy()[:30]\n",
    "\n",
    "# logit_tokens = ['\"{}\"'.format(mt.tokenizer.decode([t])) for t in logit_tokens]\n",
    "\n",
    "# from matplotlib import pyplot as plt\n",
    "# plt.bar(range(len(logit_values)), logit_values)\n",
    "# plt.xticks(range(len(logit_values)), logit_tokens, rotation=90)\n",
    "\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! Don't forget to restore the weights\n",
    "restore_weights(mt, WEIGHTS_CACHED)\n",
    "\n",
    "prompt = \"Eric works as a\"\n",
    "prompt = tokenization_utils.maybe_prefix_eos(mt.tokenizer, prompt)\n",
    "\n",
    "functional.predict_next_token(\n",
    "    mt = mt,\n",
    "    prompt = prompt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "magnify_scale = 100\n",
    "\n",
    "# ! Don't forget to restore the weights\n",
    "restore_weights(mt, WEIGHTS_CACHED)\n",
    "\n",
    "for layer in concept_drivers:\n",
    "    for neuron in concept_drivers[layer]:\n",
    "        out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))\n",
    "        with torch.no_grad():\n",
    "            out_proj.weight[:, neuron] *= magnify_scale\n",
    "\n",
    "functional.predict_next_token(\n",
    "    mt = mt,\n",
    "    prompt = prompt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "restore_weights(mt, WEIGHTS_CACHED)\n",
    "\n",
    "functional.mamba_generate(\n",
    "    mt = mt, \n",
    "    prompt = prompt,\n",
    "    topk=1\n",
    ").generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Driving LM by replacing certain neurons with a set random concept directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[(id, mt.tokenizer.decode(id)) for id in concept_start_token_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lm_head = baukit.get_module(mt.model, \"lm_head\")\n",
    "lm_head.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models\n",
    "\n",
    "noise_level = .1\n",
    "num_rand_vectors = 10\n",
    "\n",
    "random_neurons = []\n",
    "\n",
    "for id in concept_start_token_ids:\n",
    "    unembed_row = lm_head.weight[id].squeeze().clone().detach()\n",
    "    for _ in range(num_rand_vectors):\n",
    "        random_neurons.append(unembed_row + torch.randn_like(unembed_row) * noise_level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "all_neurons = [\n",
    "    (layer, neuron) for layer in range(mt.n_layer) for neuron in range(lm_head.weight.shape[1])\n",
    "]\n",
    "\n",
    "random_neuron_idxes = random.sample(all_neurons, k=len(random_neurons))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "magnify_scale = 100\n",
    "\n",
    "# ! Don't forget to restore the weights\n",
    "restore_weights(mt, WEIGHTS_CACHED)\n",
    "\n",
    "for (layer, neuron_idx), neuron in zip(random_neuron_idxes, random_neurons):\n",
    "    out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))\n",
    "    with torch.no_grad():\n",
    "        out_proj.weight[:, neuron_idx] = magnify_scale * neuron\n",
    "\n",
    "functional.predict_next_token(\n",
    "    mt = mt,\n",
    "    prompt = prompt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
