{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7282012b-41d1-4be0-baa6-1ef164d99b4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# need to do this before transformer imports\n",
    "import os\n",
    "os.environ['HF_HOME'] = '/workspace/cache/huggingface/'\n",
    "\n",
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from evals.utils import *\n",
    "from models.bigram_model import *\n",
    "from models.mlp_model import *\n",
    "from models.future_model import *\n",
    "from data.utils import get_tokenizer\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from itertools import islice\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import Dataset\n",
    "from torch import nn\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "import copy\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "afd263b0-21a7-4a6d-bc2e-21729a19b20f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def invert(f, y, start=1, eps=1e-4):\n",
    "    r'''\n",
    "    Numerically inverts a monotonically increasing function f on [0, \\infty).\n",
    "\n",
    "    Returns some x >= 0 with |f(x) - y| < eps. Each pass doubles the probe step until\n",
    "    f overshoots y, halves it until it no longer does, then advances the lower bound by\n",
    "    that step and repeats. This is the iterative form of a recursive search, so there is\n",
    "    no recursion-depth limit and no growing chain of nested lambdas around f.\n",
    "\n",
    "    Args:\n",
    "        f: monotonically increasing function of one scalar (assumed continuous).\n",
    "        y: target value; must satisfy f(0) <= y + eps.\n",
    "        start: initial probe step (should be > 0).\n",
    "        eps: tolerance on |f(x) - y|.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if f(0) > y + eps, i.e. y has no preimage in [0, \\infty).\n",
    "            An explicit raise (not `assert`) so it survives `python -O`; with the assert\n",
    "            stripped, the halving loop would spin forever.\n",
    "    '''\n",
    "    offset = 0  # lower bound of the search; f(offset) <= y after the first pass\n",
    "    step = start\n",
    "    while True:\n",
    "        f_offset = f(offset)\n",
    "        if f_offset > y + eps:\n",
    "            # Only reachable on the first pass (offset == 0): every later offset comes\n",
    "            # from a halving loop that stopped with f(offset) <= y.\n",
    "            raise ValueError(f'invert: f(0) = {f_offset} exceeds y = {y}; no preimage in [0, inf)')\n",
    "        if abs(f_offset - y) < eps:\n",
    "            return offset\n",
    "        if abs(f(offset + step) - y) < eps:\n",
    "            return offset + step\n",
    "        # Grow the step until f reaches or overshoots y ...\n",
    "        while f(offset + step) < y:\n",
    "            step *= 2\n",
    "        # ... then shrink it until it no longer overshoots.\n",
    "        while f(offset + step) > y:\n",
    "            step /= 2\n",
    "        # Advance; the next pass starts probing with the step that just fit.\n",
    "        offset += step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c19ac071-e3dd-473c-a4dd-50eaa0f00296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache used by constr_lsqr: maps a key derived from A to (A^T A, np.linalg.svd(A^T A)).\n",
    "lsqr_cache = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a76fa2b7-403a-41d2-b7db-acc7a276ea97",
   "metadata": {},
   "outputs": [],
   "source": [
    "def constr_lsqr(y, A, c):\n",
    "    r'''\n",
    "    Solves min_w \\|y - Aw\\|_2 subject to \\|w\\|_2 <= c.\n",
    "\n",
    "    The minimizer is the ridge solution w(lam) = (A^T A + lam I)^{-1} A^T y for the\n",
    "    smallest lam >= 0 with \\|w(lam)\\|_2 <= c. With the cached SVD A^T A = V diag(S) V^T\n",
    "    (A^T A is symmetric PSD), \\|w(lam)\\|_2 = \\|V^T A^T y / (S + lam)\\|_2; lam is found\n",
    "    with invert() and w is assembled from the same factors, so the returned w has exactly\n",
    "    the norm that was searched over and no n x n inverse is formed.\n",
    "\n",
    "    The SVD is cached in the global lsqr_cache, keyed on A's shape, dtype and a SHA-1\n",
    "    digest of its bytes, so sweeping c with a fixed A decomposes A^T A only once.\n",
    "\n",
    "    Args:\n",
    "        y: targets with A.shape[0] rows, as a column (m, 1) or a 1-D vector (m,).\n",
    "        A: (m, n) matrix.\n",
    "        c: norm budget on w; must be >= 0.\n",
    "\n",
    "    Returns:\n",
    "        w of shape (n,) + y.shape[1:] (i.e. (n, 1) for a column y).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if c < 0 (the lam search would otherwise never terminate).\n",
    "    '''\n",
    "    import hashlib  # function-local so this cell stays self-contained\n",
    "\n",
    "    if c < 0:\n",
    "        raise ValueError(f'constr_lsqr: norm budget c must be >= 0, got {c}')\n",
    "    if c == 0:\n",
    "        # Only w = 0 is feasible, and no finite lam drives the norm to exactly 0.\n",
    "        return np.zeros((A.shape[1],) + np.shape(y)[1:])\n",
    "\n",
    "    # Key on a digest of A's bytes rather than tuple(A.flatten()): that tuple creates one\n",
    "    # Python float per entry (hundreds of millions for a vocab-sized A) on every call,\n",
    "    # and ignores A's shape.\n",
    "    A_contig = np.ascontiguousarray(A)\n",
    "    key = (A_contig.shape, A_contig.dtype.str, hashlib.sha1(A_contig).hexdigest())\n",
    "    if key not in lsqr_cache:\n",
    "        print('calcing SVD!')\n",
    "        ATA = A.T @ A\n",
    "        lsqr_cache[key] = ATA, np.linalg.svd(ATA)\n",
    "        print('SVD done!')\n",
    "\n",
    "    _ATA, (_U, S, VT) = lsqr_cache[key]\n",
    "    ATy = A.T @ y\n",
    "    VTATy = VT @ ATy\n",
    "    # Broadcast singular values along VTATy's first axis only: (n, 1) for a column y,\n",
    "    # (n,) for a 1-D y (a fixed (n, 1) reshape would broadcast (n,) to (n, n)).\n",
    "    S = S.reshape((-1,) + (1,) * (VTATy.ndim - 1))\n",
    "\n",
    "    neg_norm = lambda lam: -np.linalg.norm(VTATy / (S + lam))\n",
    "    if neg_norm(0) >= -c:\n",
    "        lam = 0  # unconstrained least-squares solution already fits the budget\n",
    "    else:\n",
    "        lam = invert(neg_norm, -c)\n",
    "    # (A^T A + lam I)^{-1} A^T y = V diag(1 / (S + lam)) V^T A^T y\n",
    "    return VT.T @ (VTATy / (S + lam))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c4eca0a6-3e13-47d8-85fc-534430f68081",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Intervene(nn.Module):\n",
    "    '''\n",
    "    Stand-in decoder layer that overwrites a prefix of the residual stream.\n",
    "\n",
    "    forward() ignores every argument except hidden_states and returns a copy of it in\n",
    "    which positions [:, :L, :] are replaced by new_states (L = new_states.shape[1]);\n",
    "    later positions pass through unchanged.\n",
    "\n",
    "    Expects new_states and hidden_states as (batch_size, seq_length, embed_dim), with\n",
    "    matching batch and embed dims and new_states no longer than hidden_states.\n",
    "\n",
    "    Returns the 3-tuple (hidden_states, None, None) in place of a decoder layer's output.\n",
    "    NOTE(review): presumably (hidden, attn_weights, present_key_value) - confirm against\n",
    "    the installed transformers version.\n",
    "    '''\n",
    "    def __init__(self, new_states):\n",
    "        super().__init__()\n",
    "        # May be None as a placeholder; forward() raises until real states are set.\n",
    "        self.new_states = new_states\n",
    "\n",
    "    def forward(self, hidden_states, **kwargs):\n",
    "        new_states = self.new_states\n",
    "        if new_states is None:\n",
    "            raise ValueError('Intervene: new_states is None (placeholder was never replaced)')\n",
    "        if (new_states.shape[0] != hidden_states.shape[0]\n",
    "                or new_states.shape[2] != hidden_states.shape[2]\n",
    "                or new_states.shape[1] > hidden_states.shape[1]):\n",
    "            raise ValueError(\n",
    "                f'Intervene: new_states {tuple(new_states.shape)} does not fit '\n",
    "                f'hidden_states {tuple(hidden_states.shape)}'\n",
    "            )\n",
    "        # Write into a copy: editing the input in place can also change tensors the caller\n",
    "        # still holds (e.g. a recorded hidden state) and can break autograd.\n",
    "        out = hidden_states.clone()\n",
    "        out[:, :new_states.shape[1], :] = new_states\n",
    "        return out, None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9ffa6c4f-3437-4c25-8c94-166a4ecdf5bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face model id; the tokenizer is loaded from it.\n",
    "model_name = \"mistralai/Mistral-7B-v0.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0877190b-7f66-49be-b066-ec4f92db78f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Reverse vocabulary: token id -> token string.\n",
    "Token = dict((token_id, token_str) for token_str, token_id in tokenizer.get_vocab().items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0dbb203d-1630-422a-b131-c943158d0bb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_tokens(s):\n",
    "    '''Prints the tokenization of s as |-separated token strings, including special tokens the tokenizer adds (e.g. <s>).'''\n",
    "    pieces = [Token[token_id] for token_id in tokenizer(s)['input_ids']]\n",
    "    print('|'.join(pieces))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e3adee90-c6e9-4d33-86bf-29da897650cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk(v, k=10, vocab=None):\n",
    "    '''\n",
    "    Returns the k largest entries of a logit vector as a DataFrame, highest first.\n",
    "\n",
    "    Args:\n",
    "        v: array-like of logits over the vocabulary; flattened, so (V,), (V, 1) and\n",
    "            (1, V) all work.\n",
    "        k: number of entries to return; k <= 0 yields an empty frame.\n",
    "        vocab: optional id -> token-string mapping; defaults to the global Token.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ['token', 'logit'], sorted by logit descending.\n",
    "    '''\n",
    "    if vocab is None:\n",
    "        vocab = Token\n",
    "    v = np.asarray(v).flatten()\n",
    "    # Reverse the ascending argsort, then take the first k. Same order as argsort()[-k:][::-1],\n",
    "    # but that slice returns the whole vocabulary for k == 0 (since -0 == 0).\n",
    "    idxs = v.argsort()[::-1][:max(k, 0)]\n",
    "    ret = [(vocab[i], v[i]) for i in idxs]\n",
    "    return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "00f1db30-dfa7-4b9f-ba8f-f9c840db09f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the MLP-neck checkpoint for hidden_idxs-31 from the sweep. glob() returns matches in\n",
    "# filesystem order, so sort them for a reproducible choice, and fail with a clear message\n",
    "# when nothing matches (indexing [0] into an empty list gives a bare IndexError).\n",
    "ckpt_pattern = '/workspace/checkpoints/MISTRAL-NECK-SWEEP_*_hidden_idxs-31_hidden_lb-0_token_lb--1_neck_cls-mlp_*'\n",
    "ckpt_matches = sorted(glob(ckpt_pattern))\n",
    "if not ckpt_matches:\n",
    "    raise FileNotFoundError(f'No checkpoint matches {ckpt_pattern}')\n",
    "if len(ckpt_matches) > 1:\n",
    "    print(f'{len(ckpt_matches)} checkpoints match; using {ckpt_matches[0]}')\n",
    "ckpt = ckpt_matches[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0187ee37-68d0-4bc9-affa-bb98b47de42e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/workspace/checkpoints/MISTRAL-NECK-SWEEP_20231231-131900-D4c1E_hidden_idxs-31_hidden_lb-0_token_lb--1_neck_cls-mlp_epoch=00-val_self_loss=4.63.ckpt'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show which checkpoint path was selected.\n",
    "ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "97718c0e-9b6e-4cda-bb39-191e656f2e35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d6d247bca582478284a1552cbd53be29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# strict=False: presumably forwarded to load_state_dict so missing/unexpected keys are tolerated.\n",
    "# NOTE(review): confirm which checkpoint keys are expected to mismatch.\n",
    "model = LitFutureModelWithNeck.load_from_checkpoint(ckpt, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b2bf1bfe-6a07-4858-805e-81de2d38d9c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Untouched copy of the decoder layers, used later to restore the model after interventions.\n",
    "# NOTE(review): deepcopy presumably keeps the copy on the same device, doubling the layers' memory.\n",
    "orig_layers = copy.deepcopy(model.base_model.model.layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a663ee15-2596-4dac-8d68-704b19f4549a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model weights as NumPy arrays (copied to CPU).\n",
    "E = model.base_model.model.embed_tokens.weight.data.cpu().numpy()  # token embedding table\n",
    "D = model.base_model.lm_head.weight.data.cpu().numpy()  # lm_head (unembedding) weights\n",
    "neck = model.future_neck.layers[0].weight.data.cpu().numpy()  # first layer of the future neck\n",
    "# Composite map from neck input to vocabulary logits (used as topk(A @ h)).\n",
    "# NOTE(review): this treats the neck as a single linear map (no bias/nonlinearity) - confirm.\n",
    "A = D @ neck"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "343cd1a2-36c7-4871-86c5-b0e1070d8e76",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [1, 549, 28250], 'attention_mask': [1, 1, 1]}"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the token ids of 'platinum'; the last two ids are used for the one-hot vectors.\n",
    "tokenizer('platinum')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c2da926e-1a6a-42e9-b47d-6b2f8559f375",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot vectors for the two pieces of 'platinum' (tokenized as <s>|▁pl|atinum):\n",
    "# t1 marks '▁pl' over the rows of E, t2 marks 'atinum' over the rows of D (the logit space of A).\n",
    "# Sizes come from the matrices instead of a hard-coded 32000, and the word is tokenized once.\n",
    "platinum_ids = tokenizer('platinum').input_ids\n",
    "t1 = np.zeros((E.shape[0], 1))\n",
    "t1[platinum_ids[-2]] = 1\n",
    "t2 = np.zeros((D.shape[0], 1))\n",
    "t2[platinum_ids[-1]] = 1\n",
    "v1 = E.T @ t1  # embedding row of '▁pl' as a column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "58de1748-2b75-459a-b724-1245cd0e3c3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7fb3986122f0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Only forward passes are run below: disable autograd globally (';' hides the returned context object).\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "24a26bcc-c934-41f6-b11f-d356a7b1d58c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>|▁pl|atinum\n"
     ]
    }
   ],
   "source": [
    "# Show how 'platinum' splits into tokens.\n",
    "print_tokens('platinum')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "68f55840-ccad-4eb2-89ab-420d3a10a616",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt used for all forward passes. NOTE: `input` shadows the Python builtin; kept because later cells pass **input.\n",
    "input = tokenizer('My favorite element of the periodic table is', return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "85964734-ef7a-4c7b-a24b-d144c0a780bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Undo interventions: drop the current (possibly modified) layer list and restore a fresh\n",
    "# copy of orig_layers. The old layers are released first so their GPU memory can be\n",
    "# reclaimed before the copy is allocated.\n",
    "del model.base_model.model.layers\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "model.base_model.model.layers = copy.deepcopy(orig_layers)\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cbc9310d-b9d7-4de9-a9f4-1bdd9721e692",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass on the prompt, also returning every layer's hidden states.\n",
    "out = model.base_model(**input, output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d7f714ce-a561-45bc-b430-9b5c69121f59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>▁the</td>\n",
       "      <td>10.955291</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>▁arg</td>\n",
       "      <td>9.615419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁carbon</td>\n",
       "      <td>9.581732</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁gold</td>\n",
       "      <td>9.542505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>▁b</td>\n",
       "      <td>9.452555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>▁hydro</td>\n",
       "      <td>9.409390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁silver</td>\n",
       "      <td>9.310298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁probably</td>\n",
       "      <td>9.113802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁flu</td>\n",
       "      <td>9.051174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>▁mer</td>\n",
       "      <td>9.013393</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       token      logit\n",
       "0       ▁the  10.955291\n",
       "1       ▁arg   9.615419\n",
       "2    ▁carbon   9.581732\n",
       "3      ▁gold   9.542505\n",
       "4         ▁b   9.452555\n",
       "5     ▁hydro   9.409390\n",
       "6    ▁silver   9.310298\n",
       "7  ▁probably   9.113802\n",
       "8       ▁flu   9.051174\n",
       "9       ▁mer   9.013393"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: top-10 next-token logits at the last prompt position.\n",
    "topk(out.logits[0,-1,:].cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2e328d84-0e8e-49b3-b7a6-5e7034d3ef92",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>on</td>\n",
       "      <td>8.596912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ium</td>\n",
       "      <td>8.572868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁element</td>\n",
       "      <td>8.255547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>icon</td>\n",
       "      <td>8.056475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>gen</td>\n",
       "      <td>7.965827</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>um</td>\n",
       "      <td>7.789818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>od</td>\n",
       "      <td>7.607457</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁the</td>\n",
       "      <td>7.550673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>an</td>\n",
       "      <td>7.540741</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>cury</td>\n",
       "      <td>7.414806</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      token     logit\n",
       "0        on  8.596912\n",
       "1       ium  8.572868\n",
       "2  ▁element  8.255547\n",
       "3      icon  8.056475\n",
       "4       gen  7.965827\n",
       "5        um  7.789818\n",
       "6        od  7.607457\n",
       "7      ▁the  7.550673\n",
       "8        an  7.540741\n",
       "9      cury  7.414806"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hidden state at index 31 for the last prompt position, as a column vector.\n",
    "# NOTE(review): hidden_states[0] is presumably the embedding output, making index 31 the input\n",
    "# to decoder layer 31 - confirm against the transformers convention.\n",
    "h = out.hidden_states[31][0, -1].detach().cpu().numpy()[:, None]\n",
    "topk(A @ h)  # neck logits from h alone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "941073d8-7499-45b3-801e-11eff6e44778",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Desired change in neck logits: a spike of height 800 on 'atinum' (t2) minus the current\n",
    "# neck logits A @ h, so minimizing ||target - A w|| pushes A @ (h + w) toward 800 * t2.\n",
    "# NOTE(review): 800 is the author's estimate of the typical logit-vector norm (see the\n",
    "# commented variant below); it is not derived from the data here.\n",
    "#target = 800 * t2 - A @ np.concatenate([h, v1], axis=0)  # output logits have norm around ~800\n",
    "target = 800 * t2 - A @ h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "ec97870c-85cd-4b2c-b39c-97fcf931caa7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 1.000094168294038\n",
      "3 3.0000024176526274\n",
      "5 4.9999936743781515\n",
      "7 6.999983563878301\n",
      "9 9.000068985477538\n",
      "11 10.999995821994638\n",
      "13 13.000050517633445\n",
      "15 15.000021930273176\n",
      "17 17.000056975894278\n",
      "19 19.00001069258778\n",
      "21 20.999901042381783\n",
      "23 22.999962233140653\n",
      "25 25.00004614683991\n",
      "27 27.000099062531113\n",
      "29 29.000081017694367\n",
      "31 30.999967419911545\n",
      "33 33.00006762607578\n",
      "35 34.99993622984038\n",
      "37 37.000032519153024\n",
      "39 38.99994214636536\n"
     ]
    }
   ],
   "source": [
    "# Sweep the norm budget c over 1, 3, ..., 39. `eps` is the budget on ||v2||_2 (not a\n",
    "# tolerance); the printed norm shows each solution sitting on its budget.\n",
    "# NOTE(review): A[:, :4096] presumably selects the hidden-state part of the neck input - confirm.\n",
    "v2_dict = {}\n",
    "for eps in range(1, 40, 2):\n",
    "    v2 = constr_lsqr(target, A[:, :4096], eps)\n",
    "    v2_dict[eps] = v2\n",
    "    print(eps, np.linalg.norm(v2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c7447715-229b-4cde-8a9f-f3c74822aa69",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "      token     logit\n",
      "0       ium  8.912441\n",
      "1      icon  8.591976\n",
      "2  ▁element  8.486414\n",
      "3       gen  8.478638\n",
      "4      cury  8.344257\n",
      "5        on  8.338239\n",
      "6       rom  7.823799\n",
      "7       ith  7.773242\n",
      "8        od  7.733794\n",
      "9        um  7.712451\n",
      "3\n",
      "      token     logit\n",
      "0      cury  9.747360\n",
      "1      icon  9.031064\n",
      "2  ▁element  8.890864\n",
      "3      obal  8.889878\n",
      "4       ium  8.809325\n",
      "5       gen  8.791047\n",
      "6      osph  8.410558\n",
      "7    atinum  8.271071\n",
      "8     xygen  8.206147\n",
      "9     rogen  8.154701\n",
      "5\n",
      "      token     logit\n",
      "0      cury  9.410071\n",
      "1    atinum  8.897562\n",
      "2  ▁element  8.372938\n",
      "3      obal  7.942448\n",
      "4     xygen  7.929935\n",
      "5   ▁atomic  7.664288\n",
      "6     ▁atom  7.656118\n",
      "7      icon  7.591521\n",
      "8      ▁bor  7.267469\n",
      "9   ▁silver  7.192181\n",
      "7\n",
      "      token     logit\n",
      "0    atinum  9.534046\n",
      "1      cury  8.490225\n",
      "2  ▁element  7.115472\n",
      "3   ▁silver  6.783089\n",
      "4     xygen  6.742288\n",
      "5      obal  6.721434\n",
      "6   ▁atomic  6.512379\n",
      "7     ▁atom  6.277526\n",
      "8      icon  6.204196\n",
      "9     ▁gold  6.107668\n",
      "9\n",
      "      token      logit\n",
      "0    atinum  10.672559\n",
      "1      cury   7.726743\n",
      "2   ▁silver   6.595661\n",
      "3     ▁gold   6.249301\n",
      "4  ▁element   6.067153\n",
      "5      obal   5.933531\n",
      "6     xygen   5.777470\n",
      "7   ▁atomic   5.586040\n",
      "8   ▁Silver   5.300992\n",
      "9      icon   5.234757\n",
      "11\n",
      "      token      logit\n",
      "0    atinum  12.133286\n",
      "1      cury   7.111580\n",
      "2   ▁silver   6.629243\n",
      "3     ▁gold   6.599260\n",
      "4   ▁Silver   5.452147\n",
      "5      obal   5.397273\n",
      "6  ▁element   5.236949\n",
      "7     ▁Gold   5.123486\n",
      "8     xygen   5.031588\n",
      "9     arium   5.015334\n",
      "13\n",
      "     token      logit\n",
      "0   atinum  13.772941\n",
      "1    ▁gold   7.079744\n",
      "2  ▁silver   6.807795\n",
      "3     cury   6.622132\n",
      "4    ▁Gold   5.729812\n",
      "5  ▁Silver   5.686037\n",
      "6     atin   5.324232\n",
      "7     gold   5.222122\n",
      "8     obal   5.023960\n",
      "9    arium   4.971543\n",
      "15\n",
      "     token      logit\n",
      "0   atinum  15.488208\n",
      "1    ▁gold   7.624703\n",
      "2  ▁silver   7.066274\n",
      "3    ▁Gold   6.363385\n",
      "4     cury   6.235078\n",
      "5  ▁Silver   5.966180\n",
      "6     atin   5.919941\n",
      "7     gold   5.658168\n",
      "8     Gold   5.379672\n",
      "9    arium   4.958552\n",
      "17\n",
      "     token      logit\n",
      "0   atinum  17.212785\n",
      "1    ▁gold   8.189899\n",
      "2  ▁silver   7.360555\n",
      "3    ▁Gold   6.995448\n",
      "4     atin   6.483948\n",
      "5  ▁Silver   6.264951\n",
      "6     gold   6.122634\n",
      "7     Gold   5.968752\n",
      "8     cury   5.929559\n",
      "9    arium   4.960556\n",
      "19\n",
      "      token      logit\n",
      "0    atinum  18.907424\n",
      "1     ▁gold   8.749065\n",
      "2   ▁silver   7.663958\n",
      "3     ▁Gold   7.608820\n",
      "4      atin   7.008349\n",
      "5      gold   6.589246\n",
      "6   ▁Silver   6.564481\n",
      "7      Gold   6.534073\n",
      "8      cury   5.688566\n",
      "9  ▁diamond   5.113125\n",
      "21\n",
      "      token      logit\n",
      "0    atinum  20.550979\n",
      "1     ▁gold   9.288062\n",
      "2     ▁Gold   8.194387\n",
      "3   ▁silver   7.961591\n",
      "4      atin   7.492143\n",
      "5      Gold   7.070187\n",
      "6      gold   7.043496\n",
      "7   ▁Silver   6.854348\n",
      "8      cury   5.498607\n",
      "9  ▁diamond   5.381237\n",
      "23\n",
      "      token      logit\n",
      "0    atinum  22.133557\n",
      "1     ▁gold   9.800173\n",
      "2     ▁Gold   8.748113\n",
      "3   ▁silver   8.245801\n",
      "4      atin   7.937316\n",
      "5      Gold   7.575045\n",
      "6      gold   7.478086\n",
      "7   ▁Silver   7.129065\n",
      "8   ▁Golden   5.688522\n",
      "9  ▁diamond   5.637721\n",
      "25\n",
      "     token      logit\n",
      "0   atinum  23.651500\n",
      "1    ▁gold  10.282832\n",
      "2    ▁Gold   9.268841\n",
      "3  ▁silver   8.513081\n",
      "4     atin   8.346899\n",
      "5     Gold   8.048532\n",
      "6     gold   7.889812\n",
      "7  ▁Silver   7.386124\n",
      "8  ▁Golden   6.091209\n",
      "9  ▁golden   5.913918\n",
      "27\n",
      "     token      logit\n",
      "0   atinum  25.104942\n",
      "1    ▁gold  10.735858\n",
      "2    ▁Gold   9.757127\n",
      "3  ▁silver   8.762306\n",
      "4     atin   8.724249\n",
      "5     Gold   8.491717\n",
      "6     gold   8.277855\n",
      "7  ▁Silver   7.624823\n",
      "8  ▁Golden   6.468758\n",
      "9  ▁golden   6.262630\n",
      "29\n",
      "     token      logit\n",
      "0   atinum  26.496013\n",
      "1    ▁gold  11.160305\n",
      "2    ▁Gold  10.214426\n",
      "3     atin   9.072636\n",
      "4  ▁silver   8.993642\n",
      "5     Gold   8.906264\n",
      "6     gold   8.642677\n",
      "7  ▁Silver   7.845472\n",
      "8  ▁Golden   6.822632\n",
      "9  ▁golden   6.589260\n",
      "31\n",
      "     token      logit\n",
      "0   atinum  27.827821\n",
      "1    ▁gold  11.557818\n",
      "2    ▁Gold  10.642614\n",
      "3     atin   9.395087\n",
      "4     Gold   9.294092\n",
      "5  ▁silver   9.207918\n",
      "6     gold   8.985394\n",
      "7  ▁Silver   8.048919\n",
      "8  ▁Golden   7.154502\n",
      "9  ▁golden   6.895201\n",
      "33\n",
      "     token      logit\n",
      "0   atinum  29.104065\n",
      "1    ▁gold  11.930329\n",
      "2    ▁Gold  11.043775\n",
      "3     atin   9.694374\n",
      "4     Gold   9.657237\n",
      "5  ▁silver   9.406306\n",
      "6     gold   9.307464\n",
      "7  ▁Silver   8.236303\n",
      "8  ▁Golden   7.466122\n",
      "9  ▁golden   7.182045\n",
      "35\n",
      "     token      logit\n",
      "0   atinum  30.328012\n",
      "1    ▁gold  12.279658\n",
      "2    ▁Gold  11.419835\n",
      "3     Gold   9.997542\n",
      "4     atin   9.972851\n",
      "5     gold   9.610312\n",
      "6  ▁silver   9.590012\n",
      "7  ▁Silver   8.408785\n",
      "8  ▁Golden   7.759072\n",
      "9  ▁golden   7.451280\n",
      "37\n",
      "     token      logit\n",
      "0   atinum  31.503375\n",
      "1    ▁gold  12.607717\n",
      "2    ▁Gold  11.772809\n",
      "3     Gold  10.316902\n",
      "4     atin  10.232715\n",
      "5     gold   9.895485\n",
      "6  ▁silver   9.760329\n",
      "7  ▁Silver   8.567614\n",
      "8  ▁Golden   8.034965\n",
      "9  ▁golden   7.704448\n",
      "39\n",
      "     token      logit\n",
      "0   atinum  32.633140\n",
      "1    ▁gold  12.916149\n",
      "2    ▁Gold  12.104427\n",
      "3     Gold  10.616934\n",
      "4     atin  10.475792\n",
      "5     gold  10.164325\n",
      "6  ▁silver   9.918422\n",
      "7  ▁Silver   8.713922\n",
      "8  ▁Golden   8.295165\n",
      "9  ▁golden   7.942870\n"
     ]
    }
   ],
   "source": [
    "# Top-10 neck logits after adding each budget's steering vector to h.\n",
    "for eps, v2 in v2_dict.items():\n",
    "    print(eps)\n",
    "    print(topk(A @ (h + v2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ffdfe3a1-914f-4329-847e-cf8712616d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put a placeholder Intervene at index 31; ModuleList.insert shifts the original layer 31 to 32,\n",
    "# so the placeholder presumably receives the residual stream recorded as out.hidden_states[31].\n",
    "# Guarded: without it, re-running this cell stacks another placeholder in front of the first,\n",
    "# and since the intervention loop only replaces index 31, a stale Intervene(None) would sit at\n",
    "# index 32 and break every forward pass.\n",
    "if not isinstance(model.base_model.model.layers[31], Intervene):\n",
    "    model.base_model.model.layers.insert(31, Intervene(None))  # dummy, to be replaced"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e4eace6b-7de1-4821-ac48-2d5ec19f5fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "       token      logit\n",
      "0       ▁the  10.734898\n",
      "1       ▁arg   9.385143\n",
      "2    ▁carbon   9.366797\n",
      "3      ▁gold   9.301023\n",
      "4    ▁silver   9.223804\n",
      "5     ▁hydro   9.166210\n",
      "6         ▁b   9.157669\n",
      "7  ▁probably   9.082529\n",
      "8      ▁iron   8.965122\n",
      "9       ▁flu   8.943411\n",
      "3\n",
      "       token      logit\n",
      "0       ▁the  10.274455\n",
      "1    ▁silver   8.973610\n",
      "2  ▁probably   8.970881\n",
      "3      ▁iron   8.847486\n",
      "4    ▁carbon   8.811537\n",
      "5       ▁arg   8.788260\n",
      "6      ▁gold   8.664562\n",
      "7       ▁flu   8.638989\n",
      "8      ▁lead   8.539852\n",
      "9     ▁hydro   8.539461\n",
      "5\n",
      "         token     logit\n",
      "0         ▁the  9.782408\n",
      "1      ▁silver  8.819941\n",
      "2    ▁probably  8.761072\n",
      "3        ▁iron  8.586790\n",
      "4        ▁gall  8.411600\n",
      "5        ▁lead  8.225731\n",
      "6         ▁cad  8.205074\n",
      "7      ▁carbon  8.153930\n",
      "8         ▁flu  8.147808\n",
      "9  ▁definitely  8.090841\n",
      "7\n",
      "         token     logit\n",
      "0         ▁the  9.435711\n",
      "1      ▁silver  8.943274\n",
      "2    ▁probably  8.581954\n",
      "3        ▁gall  8.356562\n",
      "4        ▁iron  8.279730\n",
      "5         ▁cad  8.208963\n",
      "6  ▁definitely  8.096160\n",
      "7        ▁lead  8.085360\n",
      "8        ▁Gall  7.991706\n",
      "9         ▁ces  7.941221\n",
      "9\n",
      "         token     logit\n",
      "0         ▁the  9.166679\n",
      "1      ▁silver  9.094851\n",
      "2    ▁probably  8.426750\n",
      "3        ▁gall  8.118212\n",
      "4         ▁cad  8.060874\n",
      "5  ▁definitely  8.047551\n",
      "6        ▁iron  7.972907\n",
      "7        ▁lead  7.947025\n",
      "8      ▁Silver  7.890530\n",
      "9         ▁ces  7.805481\n",
      "11\n",
      "         token     logit\n",
      "0      ▁silver  9.196067\n",
      "1         ▁the  8.917044\n",
      "2    ▁probably  8.269397\n",
      "3      ▁Silver  8.078341\n",
      "4  ▁definitely  7.949583\n",
      "5         ▁cad  7.788732\n",
      "6        ▁lead  7.766880\n",
      "7        ▁gall  7.749557\n",
      "8        ▁iron  7.669209\n",
      "9         ▁sil  7.649632\n",
      "13\n",
      "         token     logit\n",
      "0      ▁silver  9.210526\n",
      "1         ▁the  8.665117\n",
      "2      ▁Silver  8.127717\n",
      "3    ▁probably  8.094282\n",
      "4  ▁definitely  7.796915\n",
      "5         ▁sil  7.594189\n",
      "6        ▁lead  7.543444\n",
      "7         ▁ces  7.452817\n",
      "8          ▁pl  7.418966\n",
      "9         ▁cad  7.409161\n",
      "15\n",
      "         token     logit\n",
      "0      ▁silver  9.130018\n",
      "1         ▁the  8.400131\n",
      "2      ▁Silver  8.044044\n",
      "3       atinum  7.960907\n",
      "4    ▁probably  7.890465\n",
      "5  ▁definitely  7.588481\n",
      "6         ▁sil  7.549213\n",
      "7          ▁pl  7.451561\n",
      "8        ▁lead  7.281860\n",
      "9         ▁ces  7.245566\n",
      "17\n",
      "         token     logit\n",
      "0       atinum  9.234026\n",
      "1      ▁silver  8.970887\n",
      "2         ▁the  8.118745\n",
      "3      ▁Silver  7.854090\n",
      "4    ▁probably  7.653748\n",
      "5         ▁sil  7.501927\n",
      "6          ▁pl  7.429740\n",
      "7  ▁definitely  7.329061\n",
      "8         ▁ces  7.022141\n",
      "9        ▁lead  6.989369\n",
      "19\n",
      "         token      logit\n",
      "0       atinum  10.433847\n",
      "1      ▁silver   8.762617\n",
      "2         ▁the   7.823702\n",
      "3      ▁Silver   7.594760\n",
      "4         ▁sil   7.445966\n",
      "5    ▁probably   7.386724\n",
      "6          ▁pl   7.377535\n",
      "7  ▁definitely   7.028361\n",
      "8         ▁ces   6.788015\n",
      "9         ▁tit   6.680572\n",
      "21\n",
      "         token      logit\n",
      "0       atinum  11.568274\n",
      "1      ▁silver   8.536129\n",
      "2         ▁the   7.521470\n",
      "3         ▁sil   7.379837\n",
      "4          ▁pl   7.319099\n",
      "5      ▁Silver   7.302387\n",
      "6    ▁probably   7.096556\n",
      "7  ▁definitely   6.698702\n",
      "8         ▁tit   6.587831\n",
      "9         ▁ces   6.550235\n",
      "23\n",
      "         token      logit\n",
      "0       atinum  12.647903\n",
      "1      ▁silver   8.316224\n",
      "2         ▁sil   7.304702\n",
      "3          ▁pl   7.272895\n",
      "4         ▁the   7.219673\n",
      "5      ▁Silver   7.006085\n",
      "6    ▁probably   6.792184\n",
      "7         ▁tit   6.483467\n",
      "8  ▁definitely   6.352516\n",
      "9         ▁ces   6.315766\n",
      "25\n",
      "       token      logit\n",
      "0     atinum  13.680795\n",
      "1    ▁silver   8.119044\n",
      "2        ▁pl   7.249797\n",
      "3       ▁sil   7.222647\n",
      "4       ▁the   6.925265\n",
      "5    ▁Silver   6.725644\n",
      "6  ▁probably   6.482241\n",
      "7       ▁tit   6.373027\n",
      "8        ▁Pl   6.164523\n",
      "9       ▁ces   6.090194\n",
      "27\n",
      "       token      logit\n",
      "0     atinum  14.671221\n",
      "1    ▁silver   7.952577\n",
      "2        ▁pl   7.253733\n",
      "3       ▁sil   7.135666\n",
      "4       ▁the   6.643438\n",
      "5    ▁Silver   6.472198\n",
      "6       ▁tit   6.260854\n",
      "7        ▁Pl   6.182988\n",
      "8  ▁probably   6.173803\n",
      "9       ▁ces   5.877158\n",
      "29\n",
      "       token      logit\n",
      "0     atinum  15.620055\n",
      "1    ▁silver   7.818759\n",
      "2        ▁pl   7.283553\n",
      "3       ▁sil   7.045273\n",
      "4       ▁the   6.377478\n",
      "5    ▁Silver   6.250337\n",
      "6        ▁Pl   6.226800\n",
      "7       ▁tit   6.149776\n",
      "8  ▁probably   5.872057\n",
      "9       ▁ces   5.678481\n",
      "31\n",
      "       token      logit\n",
      "0     atinum  16.525879\n",
      "1    ▁silver   7.715679\n",
      "2        ▁pl   7.335047\n",
      "3       ▁sil   6.952437\n",
      "4        ▁Pl   6.289811\n",
      "5       ▁the   6.128998\n",
      "6    ▁Silver   6.060309\n",
      "7       ▁tit   6.041301\n",
      "8  ▁probably   5.580464\n",
      "9       ▁ces   5.494608\n",
      "33\n",
      "       token      logit\n",
      "0     atinum  17.386318\n",
      "1    ▁silver   7.639363\n",
      "2        ▁pl   7.402638\n",
      "3       ▁sil   6.857727\n",
      "4        ▁Pl   6.365386\n",
      "5       ▁tit   5.935988\n",
      "6    ▁Silver   5.899849\n",
      "7       ▁the   5.898376\n",
      "8       ▁ces   5.325066\n",
      "9  ▁probably   5.301084\n",
      "35\n",
      "       token      logit\n",
      "0     atinum  18.198450\n",
      "1    ▁silver   7.585096\n",
      "2        ▁pl   7.480514\n",
      "3       ▁sil   6.761475\n",
      "4        ▁Pl   6.447340\n",
      "5       ▁tit   5.833870\n",
      "6    ▁Silver   5.765538\n",
      "7       ▁the   5.685268\n",
      "8       ▁ces   5.168939\n",
      "9  ▁probably   5.035090\n",
      "37\n",
      "     token      logit\n",
      "0   atinum  18.960163\n",
      "1      ▁pl   7.563374\n",
      "2  ▁silver   7.548142\n",
      "3     ▁sil   6.663846\n",
      "4      ▁Pl   6.530471\n",
      "5     ▁tit   5.734644\n",
      "6  ▁Silver   5.653485\n",
      "7     ▁the   5.488769\n",
      "8     ▁ces   5.025025\n",
      "9  ▁carbon   4.892628\n",
      "39\n",
      "     token      logit\n",
      "0   atinum  19.669737\n",
      "1      ▁pl   7.646701\n",
      "2  ▁silver   7.524157\n",
      "3      ▁Pl   6.610625\n",
      "4     ▁sil   6.565014\n",
      "5     ▁tit   5.637928\n",
      "6  ▁Silver   5.559928\n",
      "7     ▁the   5.307813\n",
      "8     ▁ces   4.892132\n",
      "9  ▁carbon   4.811072\n"
     ]
    }
   ],
   "source": [
    "for eps in v2_dict:\n",
    "    print(eps)\n",
    "    new_state = copy.deepcopy(out.hidden_states[31])#[:,:-1,:] # Don't overwrite last token\n",
    "    new_state[0,-1,:] += torch.Tensor(v2_dict[eps].flatten()).to('cuda')\n",
    "    model.base_model.model.layers[31] = Intervene(new_state)\n",
    "    new_out = model.base_model(**input)\n",
    "    print(topk(new_out.logits[0,-1,:].cpu().numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "122830a7-2b83-4599-b343-368a732df1a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "#\n",
    "#model.base_model.model.layers.pop(31)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e468c3a9-f052-46ab-9dbd-21ec5ec884d4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "074466ce-1118-4cbb-9018-fad986a33705",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>▁the</td>\n",
       "      <td>10.955291</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>▁arg</td>\n",
       "      <td>9.615419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁carbon</td>\n",
       "      <td>9.581732</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁gold</td>\n",
       "      <td>9.542505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>▁b</td>\n",
       "      <td>9.452555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>▁hydro</td>\n",
       "      <td>9.409390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁silver</td>\n",
       "      <td>9.310298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁probably</td>\n",
       "      <td>9.113802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁flu</td>\n",
       "      <td>9.051174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>▁mer</td>\n",
       "      <td>9.013393</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       token      logit\n",
       "0       ▁the  10.955291\n",
       "1       ▁arg   9.615419\n",
       "2    ▁carbon   9.581732\n",
       "3      ▁gold   9.542505\n",
       "4         ▁b   9.452555\n",
       "5     ▁hydro   9.409390\n",
       "6    ▁silver   9.310298\n",
       "7  ▁probably   9.113802\n",
       "8       ▁flu   9.051174\n",
       "9       ▁mer   9.013393"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "topk(out.logits[0,-1,:].cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "9f989768-50e2-471d-8f90-bc0c148e6405",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>atinum</td>\n",
       "      <td>16.078453</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>▁silver</td>\n",
       "      <td>7.763600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁pl</td>\n",
       "      <td>7.306924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁sil</td>\n",
       "      <td>6.999122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>▁Pl</td>\n",
       "      <td>6.256324</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>▁the</td>\n",
       "      <td>6.251010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁Silver</td>\n",
       "      <td>6.151449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁tit</td>\n",
       "      <td>6.095162</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁probably</td>\n",
       "      <td>5.724842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>▁ces</td>\n",
       "      <td>5.584713</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       token      logit\n",
       "0     atinum  16.078453\n",
       "1    ▁silver   7.763600\n",
       "2        ▁pl   7.306924\n",
       "3       ▁sil   6.999122\n",
       "4        ▁Pl   6.256324\n",
       "5       ▁the   6.251010\n",
       "6    ▁Silver   6.151449\n",
       "7       ▁tit   6.095162\n",
       "8  ▁probably   5.724842\n",
       "9       ▁ces   5.584713"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "e6de8aba-deb2-42d9-931f-aea2aedad8dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 0 0 0 0 0 0 0\n",
      "6 0 0 0 0 0 0 0 0\n",
      "265 0 0 0 0 0 0 0 0\n",
      "265 0 0 0 0 0 0 0 0\n",
      "265 1 1 0 0 0 0 1 0\n",
      "265 1 1 1 1 1 1 1 1\n",
      "265 1 1 1 1 1 1 1 1\n",
      "265 2 2 1 1 1 2 2 1\n",
      "265 2 2 2 2 2 2 2 2\n",
      "265 4 2 2 2 2 2 2 2\n",
      "264 4 3 2 2 2 2 2 2\n",
      "265 5 3 3 3 3 3 3 3\n",
      "264 5 3 3 3 3 3 3 3\n",
      "264 5 4 3 3 3 4 4 3\n",
      "264 5 4 4 4 4 4 4 4\n",
      "264 6 5 5 5 5 5 5 5\n",
      "263 6 6 5 5 5 5 6 5\n",
      "263 8 7 6 6 6 6 7 6\n",
      "264 10 9 8 8 7 7 8 7\n",
      "264 12 10 9 9 9 8 10 8\n",
      "267 16 13 11 11 11 9 11 10\n",
      "267 19 15 12 13 12 11 13 12\n",
      "267 20 17 14 14 14 12 15 14\n",
      "268 22 19 15 16 15 13 16 15\n",
      "268 24 21 17 17 17 15 18 17\n",
      "268 24 23 18 19 19 17 19 18\n",
      "268 25 25 20 20 22 18 20 20\n",
      "268 26 26 22 21 24 20 22 21\n",
      "268 27 28 23 23 26 23 24 24\n",
      "267 28 30 25 26 28 26 26 27\n",
      "268 32 35 29 31 33 31 30 32\n",
      "237 38 40 33 36 38 35 35 36\n",
      "345 398 459 372 395 396 324 399 388\n"
     ]
    }
   ],
   "source": [
    "for i in range(33):\n",
    "    print(' '.join(\n",
    "        str(int(out.hidden_states[i][0,j,:].norm().item())) for j in range(9)\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "cad9ca24-7b08-4373-a0c6-7d6f623bccb1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 9, 4096])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.hidden_states[31]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "70f916ac-0d43-44ec-ab3e-07a14f1c195d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>▁Michigan</td>\n",
       "      <td>13.245934</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>▁Wisconsin</td>\n",
       "      <td>13.089794</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁Minnesota</td>\n",
       "      <td>12.887695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁Indiana</td>\n",
       "      <td>12.725050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>▁Iowa</td>\n",
       "      <td>12.543724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>▁Missouri</td>\n",
       "      <td>12.224394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁Ohio</td>\n",
       "      <td>12.205340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁Illinois</td>\n",
       "      <td>12.158340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁South</td>\n",
       "      <td>11.714095</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>▁definitely</td>\n",
       "      <td>11.326981</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>▁the</td>\n",
       "      <td>11.272463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>▁North</td>\n",
       "      <td>11.176455</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>▁Neb</td>\n",
       "      <td>11.065105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>▁probably</td>\n",
       "      <td>10.846264</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>▁Kentucky</td>\n",
       "      <td>10.798969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>▁without</td>\n",
       "      <td>10.536729</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>▁und</td>\n",
       "      <td>10.532318</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>▁Kansas</td>\n",
       "      <td>10.459734</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>,</td>\n",
       "      <td>10.423826</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>▁Colorado</td>\n",
       "      <td>10.289072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>▁a</td>\n",
       "      <td>10.177610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>▁also</td>\n",
       "      <td>9.959203</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>▁hands</td>\n",
       "      <td>9.846262</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>▁now</td>\n",
       "      <td>9.797674</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>▁my</td>\n",
       "      <td>9.532832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>▁Oklahoma</td>\n",
       "      <td>9.479873</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>▁by</td>\n",
       "      <td>9.411868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>▁easily</td>\n",
       "      <td>9.403625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>▁not</td>\n",
       "      <td>9.403258</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>▁no</td>\n",
       "      <td>9.347988</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>▁certainly</td>\n",
       "      <td>9.334706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>▁Mont</td>\n",
       "      <td>9.322747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>▁one</td>\n",
       "      <td>9.223373</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>▁Pennsylvania</td>\n",
       "      <td>9.220661</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>…</td>\n",
       "      <td>9.204512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>▁always</td>\n",
       "      <td>9.194786</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>▁still</td>\n",
       "      <td>9.070404</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>▁right</td>\n",
       "      <td>9.012264</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>▁West</td>\n",
       "      <td>8.999782</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>▁most</td>\n",
       "      <td>8.994453</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            token      logit\n",
       "0       ▁Michigan  13.245934\n",
       "1      ▁Wisconsin  13.089794\n",
       "2      ▁Minnesota  12.887695\n",
       "3        ▁Indiana  12.725050\n",
       "4           ▁Iowa  12.543724\n",
       "5       ▁Missouri  12.224394\n",
       "6           ▁Ohio  12.205340\n",
       "7       ▁Illinois  12.158340\n",
       "8          ▁South  11.714095\n",
       "9     ▁definitely  11.326981\n",
       "10           ▁the  11.272463\n",
       "11         ▁North  11.176455\n",
       "12           ▁Neb  11.065105\n",
       "13      ▁probably  10.846264\n",
       "14      ▁Kentucky  10.798969\n",
       "15       ▁without  10.536729\n",
       "16           ▁und  10.532318\n",
       "17        ▁Kansas  10.459734\n",
       "18              ,  10.423826\n",
       "19      ▁Colorado  10.289072\n",
       "20             ▁a  10.177610\n",
       "21          ▁also   9.959203\n",
       "22         ▁hands   9.846262\n",
       "23           ▁now   9.797674\n",
       "24            ▁my   9.532832\n",
       "25      ▁Oklahoma   9.479873\n",
       "26            ▁by   9.411868\n",
       "27        ▁easily   9.403625\n",
       "28           ▁not   9.403258\n",
       "29            ▁no   9.347988\n",
       "30     ▁certainly   9.334706\n",
       "31          ▁Mont   9.322747\n",
       "32           ▁one   9.223373\n",
       "33  ▁Pennsylvania   9.220661\n",
       "34              …   9.204512\n",
       "35        ▁always   9.194786\n",
       "36         ▁still   9.070404\n",
       "37         ▁right   9.012264\n",
       "38          ▁West   8.999782\n",
       "39          ▁most   8.994453"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "topk(out.logits.detach().cpu().numpy()[0,-1,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "907007ae-94ef-472a-8beb-58dfb98ab8ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method ModuleList.insert of ModuleList(\n",
       "  (0-31): 32 x MistralDecoderLayer(\n",
       "    (self_attn): MistralAttention(\n",
       "      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "      (rotary_emb): MistralRotaryEmbedding()\n",
       "    )\n",
       "    (mlp): MistralMLP(\n",
       "      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "      (act_fn): SiLU()\n",
       "    )\n",
       "    (input_layernorm): MistralRMSNorm()\n",
       "    (post_attention_layernorm): MistralRMSNorm()\n",
       "  )\n",
       ")>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.base_model.model.layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "69676df0-6e4e-4da1-894b-1567445e10c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>|▁My|▁favorite|▁state|▁in|▁the|▁Mid|west|▁is|▁Michigan\n"
     ]
    }
   ],
   "source": [
    "print_tokens('My favorite state in the Midwest is Michigan')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "59d8fb83-ab23-469f-8800-ece1c0b025e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<s>'"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Token[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "03079447-af96-456e-a32e-0024e4aa899f",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = np.random.normal(0, 1, (100, 10))\n",
    "y = np.random.normal(0, 1, (100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "f2c9df26-ff62-4605-80ac-2d0f8e979258",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "238.84033203125\n"
     ]
    }
   ],
   "source": [
    "w = constr_lsqr(y, A, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "6b0941d1-b9a7-48d1-ae7f-f3685aa4297e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1000000058535913"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.norm(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "dce72dbc-48e6-45e7-92dc-6c45ea915f98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.00752921],\n",
       "       [ 0.02647566],\n",
       "       [-0.0483614 ],\n",
       "       [-0.02094679],\n",
       "       [ 0.05765701],\n",
       "       [-0.01300482],\n",
       "       [-0.05031044],\n",
       "       [-0.01093733],\n",
       "       [ 0.00570315],\n",
       "       [ 0.01697087]])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.inv(A.T @ A + 238.84033 * np.eye(10)) @ A.T @ y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "0cb5534c-b667-4a04-b292-639bf7f28f8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.00752921],\n",
       "       [ 0.02647566],\n",
       "       [-0.0483614 ],\n",
       "       [-0.02094679],\n",
       "       [ 0.05765701],\n",
       "       [-0.01300482],\n",
       "       [-0.05031044],\n",
       "       [-0.01093733],\n",
       "       [ 0.00570315],\n",
       "       [ 0.01697087]])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8c5b62e-0259-4ff8-89bf-ab66f853dad8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
