{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd7f73e8-44ee-4802-bec8-b7185b33dca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# need to do this before transformer imports\n",
    "import os\n",
    "os.environ['HF_HOME'] = '/workspace/cache/huggingface/'\n",
    "\n",
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from evals.utils import *\n",
    "from models.bigram_model import *\n",
    "from models.mlp_model import *\n",
    "from models.future_model import *\n",
    "from data.utils import get_tokenizer\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from itertools import islice\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import Dataset\n",
    "from torch import nn\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "import copy\n",
    "from num2words import num2words\n",
    "from collections import defaultdict\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ced0d178-8176-4cb8-9ab9-a9e882564fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'mistralai/Mistral-7B-v0.1'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ec6d897-0b7a-45d3-8527-9a6de94fa65b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_tokens(s):\n",
    "    tokens = tokenizer(s)['input_ids']\n",
    "    print('|'.join(Token[t] for t in tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cb73a94f-9cf5-4ae6-96d2-9ecd64847fa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk(v, k=10):\n",
    "    # Takes in logits\n",
    "    if isinstance(v, torch.Tensor):\n",
    "        v = v.detach().cpu().numpy()\n",
    "    v = v.flatten()\n",
    "    idxs = v.argsort()[-k:][::-1]\n",
    "    ret = [(Token[i], v[i]) for i in idxs]\n",
    "    return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1f5f6ed6-0bd1-48ac-90e9-1267d3311bde",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = glob(\n",
    "    '/workspace/checkpoints/MISTRAL-NECK-SWEEP_*_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-lstm_*',\n",
    ")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "094beeab-85bc-413d-a50a-294455b301e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/workspace/checkpoints/MISTRAL-NECK-SWEEP_20240102-191556-Ec6c4_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-lstm_epoch=00-val_self_loss=3.81.ckpt'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7d7284c0-4235-4b00-9a58-a514cd17c90d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8e5380839bd4073b9b9e673a2ffb053",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = LitFutureModelWithNeck.load_from_checkpoint(ckpt, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9c03cffa-7390-4dce-8605-60172520af99",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f863c240eb0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "057a4775-a50e-4385-a9c4-6edada62fbe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "input = tokenizer('Alice has two apples and Bob has four apples. In total, they have', return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fa5287d0-73ec-4d92-8f6c-aa5a46ed75a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "nums = {i: num2words(i) for i in range(100) if len(tokenizer(num2words(i)).input_ids) == 2} # 1-token numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ab9b9829-4468-4178-9bc6-15355e65bc3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gc.collect(); torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b95b38bd-0e10-4be2-a03a-fe243d41bd52",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 300/300 [01:08<00:00,  4.41it/s]\n"
     ]
    }
   ],
   "source": [
    "from itertools import combinations_with_replacement as car\n",
    "prompt = 'Alice has five apples and Bob has twelve apples. Combined, Alice and Bob have seventeen apples. Carol has {} apples and Dave has {} apples. Combined, Carol and Dave have' \n",
    "\n",
    "total = 0\n",
    "base_topk = defaultdict(lambda: 0)\n",
    "future_topk = defaultdict(lambda: 0)\n",
    "ks = [1, 2, 3, 5, 10]\n",
    "for i, j in tqdm(list(car(nums, 2))):\n",
    "    total += 1\n",
    "    gc.collect(); torch.cuda.empty_cache()\n",
    "    input = tokenizer(\n",
    "        prompt.format(nums[i], nums[j]),\n",
    "        return_tensors='pt'\n",
    "    ).to('cuda')\n",
    "    out = model(input)\n",
    "    if i + j not in nums:\n",
    "        continue\n",
    "    base_tok = topk(out.logits[0,-1,:]).token.tolist()\n",
    "    future_tok = topk(out.future_logits[0,-2,:]).token.tolist()\n",
    "    for k in ks:\n",
    "        if '▁' + nums[i+j] in base_tok[:k]:\n",
    "            base_topk[k] += 1\n",
    "        if '▁' + nums[i+j] in future_tok[:k]:\n",
    "            future_topk[k] += 1\n",
    "for k in base_topk:\n",
    "    base_topk[k] /= total\n",
    "for k in future_topk:\n",
    "    future_topk[k] /= total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1bc8f9e5-0f56-4ed2-afb3-c709a549dc61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(<function __main__.<lambda>()>,\n",
       "            {1: 0.37,\n",
       "             2: 0.38,\n",
       "             3: 0.38666666666666666,\n",
       "             5: 0.3933333333333333,\n",
       "             10: 0.3933333333333333})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_topk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3e7a3664-0865-4be3-a7cb-b622243b5787",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(<function __main__.<lambda>()>,\n",
       "            {10: 0.06666666666666667,\n",
       "             1: 0.006666666666666667,\n",
       "             2: 0.013333333333333334,\n",
       "             3: 0.016666666666666666,\n",
       "             5: 0.023333333333333334})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "future_topk"
   ]
  },
  {
   "cell_type": "raw",
   "id": "e1a2968e-bd0c-4483-ba89-0af8d4a4e5e9",
   "metadata": {},
   "source": [
    "base MISTRAL:\n",
    "{1: 0.37, 2: 0.38, 3: 0.38666666666666666, 5: 0.3933333333333333, 10: 0.3933333333333333})\n",
    "hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-mlp:\n",
    "    {0: 0, 3: 0.0033333333333333335, 5: 0.01, 10: 0.04})\n",
    "hidden_idxs-32_hidden_lb-0_token_lb-0_neck_cls-mlp:\n",
    "    {0: 0, 3: 0.01, 5: 0.02666666666666667, 10: 0.09666666666666666}\n",
    "hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-lstm:\n",
    "    {1: 0.006666666666666667 ,2: 0.013333333333333334 ,3: 0.016666666666666666, 5: 0.023333333333333334, 10: 0.06666666666666667,}\n",
    "hidden_idxs-32_hidden_lb-0_token_lb-0_neck_cls-lstm:\n",
    "    {0: 0,3: 0.0033333333333333335, 5: 0.016666666666666666, 10: 0.03333333333333333})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f1090fb1-51b1-483b-952f-1ba0003eec72",
   "metadata": {},
   "outputs": [],
   "source": [
    "input = tokenizer(\n",
    "    #prompt.format('five', 'eleven'),\n",
    "    #'Fried foods are high in cholesterol',\n",
    "    'The director of the Manhattan project was J Robert Oppenheimer',\n",
    "    return_tensors='pt'\n",
    ").to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "05b4afee-5da8-411f-a346-6881f1ab2f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "6d4f9801-6a80-40e1-be9b-0dae73bc239a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>19.516129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>h</td>\n",
       "      <td>13.359095</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>he</td>\n",
       "      <td>12.359792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>heim</td>\n",
       "      <td>12.232853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>her</td>\n",
       "      <td>10.319977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ha</td>\n",
       "      <td>10.146824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>hem</td>\n",
       "      <td>10.030427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ham</td>\n",
       "      <td>9.772432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>him</td>\n",
       "      <td>9.236302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>-</td>\n",
       "      <td>8.828691</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    token      logit\n",
       "0  heimer  19.516129\n",
       "1       h  13.359095\n",
       "2      he  12.359792\n",
       "3    heim  12.232853\n",
       "4     her  10.319977\n",
       "5      ha  10.146824\n",
       "6     hem  10.030427\n",
       "7     ham   9.772432\n",
       "8     him   9.236302\n",
       "9       -   8.828691"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "topk(out.logits[0, -2,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7ae844af-0ad3-4cfa-9d0d-7c642d8b6f7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>.</td>\n",
       "      <td>8.544149</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>,</td>\n",
       "      <td>8.523829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>hen</td>\n",
       "      <td>8.339357</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>en</td>\n",
       "      <td>7.998469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ard</td>\n",
       "      <td>7.784670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>man</td>\n",
       "      <td>7.768543</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>hal</td>\n",
       "      <td>7.739163</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>berg</td>\n",
       "      <td>7.691470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>h</td>\n",
       "      <td>7.520702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>pe</td>\n",
       "      <td>7.334779</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  token     logit\n",
       "0     .  8.544149\n",
       "1     ,  8.523829\n",
       "2   hen  8.339357\n",
       "3    en  7.998469\n",
       "4   ard  7.784670\n",
       "5   man  7.768543\n",
       "6   hal  7.739163\n",
       "7  berg  7.691470\n",
       "8     h  7.520702\n",
       "9    pe  7.334779"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "topk(out.future_logits[0,-3,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6b82f436-5f3f-46bd-95db-a294d80d4e96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 10, 32000])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.future_logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8d2aad86-3424-4318-b516-99d326b09ce5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'▁six' in topk(out.logits[0, -1,:]).token.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c280a5e1-37a1-40e4-a60f-59c77dd2ec06",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
