{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46dd1463-a71a-45a8-87ab-b1a6a8cb7a0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from evals.utils import *\n",
    "from models.bigram_model import *\n",
    "from models.mlp_model import *\n",
    "from models.future_model import *\n",
    "from data.utils import get_tokenizer\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8c426b34-e173-4d97-85ef-358f12e04b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suppress warnings (there are a lot b/c of how the checkpoints are stored...)\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9f8d767c-1ba0-4509-859b-4f1ff6f0e9d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "24a53c7d-274a-48f8-9759-c54492ba2f1b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fe5ca453f93468a9002e191cb2ca8ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89a8821eba894f59a89dc7a2f5f6a600",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "debf8f1c0deb4b57a9eb4a170839e396",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a37b6e948cc84d2ca253e7424ee3788e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6a76643e84149b9bce564145c1004db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1092aec034834501978a6f0a711de831",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6096ae0e6644646b10b84f5d9f966a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b8eec4745ccd4979ba0e0514b8752783",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# get_tokenizer comes from data.utils; one tokenizer is created per model family used below.\n",
    "gpt2_tokenizer = get_tokenizer('gpt2')\n",
    "mistral_tokenizer = get_tokenizer('mistralai/Mistral-7B-v0.1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3091cc8-0251-4855-b480-b2d10b28a2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(model_cls, ckpt_path, kwargs=None):\n",
    "    \"\"\"Load a model from a checkpoint via ``model_cls.load_from_checkpoint``.\n",
    "\n",
    "    Args:\n",
    "        model_cls: class exposing a ``load_from_checkpoint`` callable.\n",
    "        ckpt_path: checkpoint path, passed through as the first positional argument.\n",
    "        kwargs: optional dict of extra keyword arguments forwarded to\n",
    "            ``load_from_checkpoint``; ``None`` (the default) forwards none.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``model_cls.load_from_checkpoint`` returns.\n",
    "    \"\"\"\n",
    "    # strict=False is always passed, independent of the caller-supplied kwargs.\n",
    "    return model_cls.load_from_checkpoint(ckpt_path, strict=False, **(kwargs or {}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ac8d3f2-be0e-4cd2-8120-76f7a6162e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_futures(prompt, models, tokenizer, max_tokens=1, topk=10):\n",
    "    \"\"\"Print the ``generate_future`` table for ``prompt`` under every model in ``models``.\n",
    "\n",
    "    Models are loaded one at a time and released before the next one is loaded.\n",
    "\n",
    "    Args:\n",
    "        prompt: prompt passed to ``generate_future``.\n",
    "        models: dict mapping a display name to the positional arguments of\n",
    "            ``load_model``, i.e. ``(model_cls, ckpt_path, kwargs)``.\n",
    "        tokenizer: tokenizer passed to ``generate_future``.\n",
    "        max_tokens: forwarded to ``generate_future``.\n",
    "        topk: forwarded to ``generate_future``.\n",
    "    \"\"\"\n",
    "    for name, args in models.items():\n",
    "        print('MODEL:', name)\n",
    "        model = load_model(*args)\n",
    "        try:\n",
    "            # generate_future from FutureGPT2/src/evals/utils.py\n",
    "            print(generate_future(prompt, model, tokenizer, max_tokens=max_tokens, topk=topk).T)\n",
    "        finally:\n",
    "            # Free up GPU memory even when generation fails, so a failed model\n",
    "            # is not left resident while the remaining cells keep running.\n",
    "            # NOTE(review): torch is not imported explicitly in this notebook;\n",
    "            # presumably it arrives through one of the star-imports - confirm.\n",
    "            del model\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "771da756-57db-4974-947e-c71e61684a20",
   "metadata": {},
   "source": [
    "## Neck hdim sweep (GPT2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3d933b0-cd63-4bad-a5b6-17aeb4f909b5",
   "metadata": {},
   "source": [
    "`neck_h4^i` is a 2-layer neck with a hidden dimension of 4^i."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73bf93af-626c-4125-a441-e83787391580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per hidden size 4^i: name -> (model class, checkpoint path, load kwargs),\n",
    "# i.e. the positional arguments that gen_futures forwards to load_model.\n",
    "models = dict(\n",
    "    (\n",
    "        f'neck_h4^{exponent}',\n",
    "        (\n",
    "            LitFutureModelWithNeck,\n",
    "            f'/workspace/checkpoints/fc_neck_sweep/h4_{exponent}.ckpt',\n",
    "            {'neck_cls': 'mlp', 'hidden_idxs': 12, 'hidden_lb': 0, 'token_lb': 0},\n",
    "        ),\n",
    "    )\n",
    "    for exponent in range(8)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "042275da-6ae1-4dd1-b212-4a2d3dd2c181",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = 'Alice is visiting Japan, so she\\'s exchanging her dollars for'\n",
    "# next token prediction; top10:\n",
    "gen_futures(prompt, models, gpt2_tokenizer, max_tokens=1, topk=10)\n",
    "# next 20 tokens, top1 (the tokenizer is the third positional argument of gen_futures):\n",
    "# gen_futures(prompt, models, gpt2_tokenizer, max_tokens=20, topk=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b827c28e-15d4-4643-8673-d47203be2223",
   "metadata": {},
   "source": [
    "## Finetune sweep (GPT2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5953319f-a92c-493c-b186-37c6c0d5704b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_finetune_model(blr, kappa):\n",
    "    \"\"\"Return the checkpoint path of the finetune run with the given hyperparameters.\n",
    "\n",
    "    Args:\n",
    "        blr: base learning rate, spelled as it appears in the checkpoint filename.\n",
    "        kappa: kappa value, spelled as it appears in the checkpoint filename.\n",
    "\n",
    "    Returns:\n",
    "        The first matching path in sorted order.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no checkpoint matches the pattern.\n",
    "    \"\"\"\n",
    "    pattern = f'/workspace/checkpoints/FINETUNE-KAPPA-SWEEP_*_base_lr-{blr}_kappa-{kappa}_*.ckpt'\n",
    "    # glob() gives no ordering guarantee; sort so the chosen file is reproducible.\n",
    "    matches = sorted(glob(pattern))\n",
    "    if not matches:\n",
    "        raise FileNotFoundError(f'No checkpoint matches {pattern}')\n",
    "    return matches[0]\n",
    "\n",
    "\n",
    "# Values are kept as strings because they are substituted verbatim into the filename pattern.\n",
    "kappas = ['0.0001', '0.001', '0.01', '0.1', '1']\n",
    "blrs = ['1e-05']   # base learning rate\n",
    "\n",
    "# name -> (model class, checkpoint path, load kwargs), as expected by gen_futures.\n",
    "models = {\n",
    "    f'finetune_blr:{blr}_k:{k}': (\n",
    "        LitFutureModelWithNeck, \n",
    "        get_finetune_model(blr, k), \n",
    "        {'neck_cls': 'mlp', 'hidden_idxs': 12, 'hidden_lb': 0, 'token_lb': 0}\n",
    "    )\n",
    "    for blr in blrs\n",
    "    for k in kappas\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a25421dd-7d1c-4fde-8f80-93d6ac0a1b28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `models` is the finetune-sweep dict from the previous cell (the name is reused from the neck sweep).\n",
    "# gen_futures defaults apply: max_tokens=1, topk=10.\n",
    "gen_futures('Alice is visiting Japan, so she\\'s exchanging her dollars for', models, gpt2_tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6af41e40-a75a-47d9-90ec-c4504eb37d93",
   "metadata": {},
   "source": [
    "## Hidden/Token lookbacks, MLP/LSTM Sweep (GPT2 and Mistral)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8669d516-719e-4a81-a11e-062930f5faef",
   "metadata": {},
   "source": [
    "### Checkpoint arguments:\n",
    "- hidden_idxs: index of hidden state layer\n",
    "\n",
    "- hidden_lb:\n",
    "  - if set to k, the hidden states from positions {t-k, ..., t} are used as input\n",
    "  - set to -1 to ignore the hidden state\n",
    "  - for the LSTM neck, this is the lookback *per position* fed into the LSTM\n",
    "  - since the LSTM is already recurrent over the full history, it should only be set to 0 or -1 (to ignore the hidden state)\n",
    "\n",
    "- token_lb: if set to k, the embedded tokens from positions {t-k+1, ..., t+1} are used as input\n",
    "  - set to -1 to ignore the input tokens\n",
    "  \n",
    "- neck_cls: either 'mlp' or 'lstm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3b5257c6-f81d-48bf-97fe-3c15f124d05c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ckpt(sweep_name, hidden_idxs, hidden_lb, token_lb, neck_cls, ckpt_dir='/workspace/checkpoints'):\n",
    "    \"\"\"Find the checkpoint file of one sweep configuration by filename pattern.\n",
    "\n",
    "    Args:\n",
    "        sweep_name: filename prefix of the sweep (e.g. 'NECK-SWEEP2').\n",
    "        hidden_idxs: value of the ``hidden_idxs-`` field in the filename.\n",
    "        hidden_lb: value of the ``hidden_lb-`` field in the filename.\n",
    "        token_lb: value of the ``token_lb-`` field in the filename.\n",
    "        neck_cls: value of the ``neck_cls-`` field in the filename.\n",
    "        ckpt_dir: directory that is searched; defaults to the original\n",
    "            hard-coded location.\n",
    "\n",
    "    Returns:\n",
    "        The first matching path in sorted order, or ``None`` when nothing matches.\n",
    "    \"\"\"\n",
    "    name_pattern = '_'.join([\n",
    "        sweep_name,\n",
    "        # The wildcard is fused with this field (no '_' between them). The original\n",
    "        # relied on implicit string concatenation from a missing comma; the resulting\n",
    "        # pattern is kept unchanged and only made explicit here.\n",
    "        f'*hidden_idxs-{hidden_idxs}',\n",
    "        f'hidden_lb-{hidden_lb}',\n",
    "        f'token_lb-{token_lb}',\n",
    "        f'neck_cls-{neck_cls}',\n",
    "        '*',\n",
    "    ]) + '.ckpt'\n",
    "    # glob() gives no ordering guarantee; sort so the chosen file is reproducible.\n",
    "    matches = sorted(glob(os.path.join(ckpt_dir, name_pattern)))\n",
    "    return matches[0] if matches else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b922d4dd-3520-4fe3-9d1d-7245ecae53f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "# NECK-SWEEP2 is GPT2; MISTRAL-NECK-SWEEP is the Mistral sweep.\n",
    "sweeps = ['NECK-SWEEP2', 'MISTRAL-NECK-SWEEP']\n",
    "hidden_idxs = [0, 11, 12, 31, 32]\n",
    "hidden_lbs = [-1, 0, 1]\n",
    "token_lbs = [-1, 0, 1]\n",
    "neck_cls = ['mlp', 'lstm']\n",
    "\n",
    "\n",
    "def find_sweep_models(sweep_name):\n",
    "    \"\"\"Collect every configuration of ``sweep_name`` that has a checkpoint on disk.\n",
    "\n",
    "    Only some subset of the full product exists, but it's easier to just try\n",
    "    every combination and ignore the misses than to enumerate them explicitly.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping '<sweep>_<hidden_idxs>_<hidden_lb>_<token_lb>_<neck_cls>' to\n",
    "        ``(LitFutureModelWithNeck, ckpt_path, {})``, the argument tuple that\n",
    "        gen_futures forwards to load_model.\n",
    "    \"\"\"\n",
    "    found = {}\n",
    "    for args in product([sweep_name], hidden_idxs, hidden_lbs, token_lbs, neck_cls):\n",
    "        # Resolve once per configuration (the lookup globs the checkpoint directory).\n",
    "        ckpt_path = get_ckpt(*args)\n",
    "        if ckpt_path is None:\n",
    "            continue\n",
    "        # '_' is the separator for both sweeps: '-' is ambiguous next to negative lookbacks.\n",
    "        found['_'.join(str(a) for a in args)] = (LitFutureModelWithNeck, ckpt_path, {})\n",
    "    return found\n",
    "\n",
    "\n",
    "gpt2_models = find_sweep_models(sweeps[0])\n",
    "mistral_models = find_sweep_models(sweeps[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a88277e2-78a5-48de-9dfc-957645c24350",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/workspace/checkpoints/MISTRAL-NECK-SWEEP_20240102-191556-Ec6c4_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-lstm_epoch=00-val_self_loss=3.81.ckpt'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: resolve one Mistral configuration and display the matched path (None if nothing matches).\n",
    "get_ckpt('MISTRAL-NECK-SWEEP', 31, 0, 0, 'lstm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e268244f-14d2-489b-bc2f-8129d3baa9b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MODEL: NECK-SWEEP2_0_-1_-1_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.000519\n",
      "base_token_1     Japanese\n",
      "base_prob_1       0.00006\n",
      "base_token_2        money\n",
      "base_prob_2      0.000199\n",
      "base_token_3            a\n",
      "base_prob_3      0.013693\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000019\n",
      "base_token_5          the\n",
      "base_prob_5      0.027852\n",
      "base_token_6         cash\n",
      "base_prob_6      0.000039\n",
      "base_token_7         some\n",
      "base_prob_7      0.000852\n",
      "base_token_8      dollars\n",
      "base_prob_8       0.00013\n",
      "base_token_9       things\n",
      "base_prob_9      0.000173\n",
      "future_token_0          ,\n",
      "future_prob_0    0.036268\n",
      "future_token_1          .\n",
      "future_prob_1    0.035682\n",
      "future_token_2        the\n",
      "future_prob_2    0.027852\n",
      "future_token_3         of\n",
      "future_prob_3    0.020562\n",
      "future_token_4        and\n",
      "future_prob_4    0.018334\n",
      "future_token_5         to\n",
      "future_prob_5    0.016543\n",
      "future_token_6          a\n",
      "future_prob_6    0.013693\n",
      "future_token_7         in\n",
      "future_prob_7      0.0119\n",
      "future_token_8         \\n\n",
      "future_prob_8    0.011599\n",
      "future_token_9         is\n",
      "future_prob_9    0.010196\n",
      "MODEL: NECK-SWEEP2_0_-1_0_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.003132\n",
      "base_token_1     Japanese\n",
      "base_prob_1       0.00007\n",
      "base_token_2        money\n",
      "base_prob_2      0.000423\n",
      "base_token_3            a\n",
      "base_prob_3      0.071177\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000002\n",
      "base_token_5          the\n",
      "base_prob_5      0.108302\n",
      "base_token_6         cash\n",
      "base_prob_6      0.000286\n",
      "base_token_7         some\n",
      "base_prob_7      0.005296\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000043\n",
      "base_token_9       things\n",
      "base_prob_9      0.000404\n",
      "future_token_0        the\n",
      "future_prob_0    0.108302\n",
      "future_token_1          a\n",
      "future_prob_1    0.071177\n",
      "future_token_2       your\n",
      "future_prob_2    0.020452\n",
      "future_token_3       this\n",
      "future_prob_3    0.013381\n",
      "future_token_4         an\n",
      "future_prob_4    0.013194\n",
      "future_token_5        all\n",
      "future_prob_5    0.011225\n",
      "future_token_6       more\n",
      "future_prob_6    0.010777\n",
      "future_token_7        you\n",
      "future_prob_7    0.010492\n",
      "future_token_8       each\n",
      "future_prob_8    0.008977\n",
      "future_token_9    example\n",
      "future_prob_9    0.008188\n",
      "MODEL: NECK-SWEEP2_0_-1_0_lstm\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.124284\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.000556\n",
      "base_token_2        money\n",
      "base_prob_2        0.0142\n",
      "base_token_3            a\n",
      "base_prob_3      0.064785\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000159\n",
      "base_token_5          the\n",
      "base_prob_5      0.059456\n",
      "base_token_6         cash\n",
      "base_prob_6      0.001663\n",
      "base_token_7         some\n",
      "base_prob_7      0.006372\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.001671\n",
      "base_token_9       things\n",
      "base_prob_9      0.001635\n",
      "future_token_0        her\n",
      "future_prob_0    0.124284\n",
      "future_token_1          a\n",
      "future_prob_1    0.064785\n",
      "future_token_2        the\n",
      "future_prob_2    0.059456\n",
      "future_token_3        you\n",
      "future_prob_3    0.017692\n",
      "future_token_4          $\n",
      "future_prob_4    0.016802\n",
      "future_token_5      money\n",
      "future_prob_5      0.0142\n",
      "future_token_6         an\n",
      "future_prob_6    0.011211\n",
      "future_token_7       your\n",
      "future_prob_7    0.010907\n",
      "future_token_8      their\n",
      "future_prob_8    0.010238\n",
      "future_token_9       more\n",
      "future_prob_9    0.009674\n",
      "MODEL: NECK-SWEEP2_0_-1_1_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.001255\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.000101\n",
      "base_token_2        money\n",
      "base_prob_2      0.001584\n",
      "base_token_3            a\n",
      "base_prob_3       0.11584\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000028\n",
      "base_token_5          the\n",
      "base_prob_5      0.093858\n",
      "base_token_6         cash\n",
      "base_prob_6      0.003401\n",
      "base_token_7         some\n",
      "base_prob_7      0.006826\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.001972\n",
      "base_token_9       things\n",
      "base_prob_9      0.000087\n",
      "future_token_0          a\n",
      "future_prob_0     0.11584\n",
      "future_token_1        the\n",
      "future_prob_1    0.093858\n",
      "future_token_2       each\n",
      "future_prob_2    0.060341\n",
      "future_token_3       more\n",
      "future_prob_3    0.037576\n",
      "future_token_4         an\n",
      "future_prob_4    0.023519\n",
      "future_token_5          $\n",
      "future_prob_5    0.022551\n",
      "future_token_6       your\n",
      "future_prob_6    0.018999\n",
      "future_token_7      every\n",
      "future_prob_7    0.016455\n",
      "future_token_8        you\n",
      "future_prob_8    0.013637\n",
      "future_token_9       this\n",
      "future_prob_9     0.01298\n",
      "MODEL: NECK-SWEEP2_11_0_-1_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.090438\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.000487\n",
      "base_token_2        money\n",
      "base_prob_2      0.002836\n",
      "base_token_3            a\n",
      "base_prob_3      0.076708\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000523\n",
      "base_token_5          the\n",
      "base_prob_5      0.066191\n",
      "base_token_6         cash\n",
      "base_prob_6      0.002885\n",
      "base_token_7         some\n",
      "base_prob_7      0.004967\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000516\n",
      "base_token_9       things\n",
      "base_prob_9      0.000418\n",
      "future_token_0        her\n",
      "future_prob_0    0.090438\n",
      "future_token_1          a\n",
      "future_prob_1    0.076708\n",
      "future_token_2        the\n",
      "future_prob_2    0.066191\n",
      "future_token_3       with\n",
      "future_prob_3    0.026052\n",
      "future_token_4        and\n",
      "future_prob_4    0.018284\n",
      "future_token_5        she\n",
      "future_prob_5    0.014421\n",
      "future_token_6         to\n",
      "future_prob_6    0.012691\n",
      "future_token_7         an\n",
      "future_prob_7    0.012276\n",
      "future_token_8      their\n",
      "future_prob_8    0.011546\n",
      "future_token_9        for\n",
      "future_prob_9    0.011402\n",
      "MODEL: NECK-SWEEP2_11_0_-1_lstm\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.215586\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.001449\n",
      "base_token_2        money\n",
      "base_prob_2      0.015469\n",
      "base_token_3            a\n",
      "base_prob_3      0.044468\n",
      "base_token_4          yen\n",
      "base_prob_4      0.001208\n",
      "base_token_5          the\n",
      "base_prob_5      0.041694\n",
      "base_token_6         cash\n",
      "base_prob_6      0.011249\n",
      "base_token_7         some\n",
      "base_prob_7      0.003291\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.002287\n",
      "base_token_9       things\n",
      "base_prob_9      0.000719\n",
      "future_token_0        her\n",
      "future_prob_0    0.215586\n",
      "future_token_1          a\n",
      "future_prob_1    0.044468\n",
      "future_token_2        she\n",
      "future_prob_2    0.042884\n",
      "future_token_3        the\n",
      "future_prob_3    0.041694\n",
      "future_token_4      money\n",
      "future_prob_4    0.015469\n",
      "future_token_5         \\n\n",
      "future_prob_5    0.014646\n",
      "future_token_6        She\n",
      "future_prob_6    0.012819\n",
      "future_token_7       cash\n",
      "future_prob_7    0.011249\n",
      "future_token_8        his\n",
      "future_prob_8     0.00837\n",
      "future_token_9      other\n",
      "future_prob_9    0.007123\n",
      "MODEL: NECK-SWEEP2_11_0_0_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.303324\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.000228\n",
      "base_token_2        money\n",
      "base_prob_2       0.01203\n",
      "base_token_3            a\n",
      "base_prob_3       0.05737\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000358\n",
      "base_token_5          the\n",
      "base_prob_5      0.040195\n",
      "base_token_6         cash\n",
      "base_prob_6      0.005734\n",
      "base_token_7         some\n",
      "base_prob_7      0.004449\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.002195\n",
      "base_token_9       things\n",
      "base_prob_9      0.000855\n",
      "future_token_0        her\n",
      "future_prob_0    0.303324\n",
      "future_token_1          a\n",
      "future_prob_1     0.05737\n",
      "future_token_2        the\n",
      "future_prob_2    0.040195\n",
      "future_token_3        his\n",
      "future_prob_3     0.01246\n",
      "future_token_4      money\n",
      "future_prob_4     0.01203\n",
      "future_token_5      their\n",
      "future_prob_5    0.009084\n",
      "future_token_6         an\n",
      "future_prob_6    0.008781\n",
      "future_token_7    herself\n",
      "future_prob_7    0.008343\n",
      "future_token_8        him\n",
      "future_prob_8    0.007028\n",
      "future_token_9       this\n",
      "future_prob_9    0.007009\n",
      "MODEL: NECK-SWEEP2_11_0_0_lstm\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.267138\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.002888\n",
      "base_token_2        money\n",
      "base_prob_2      0.013923\n",
      "base_token_3            a\n",
      "base_prob_3      0.097506\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000681\n",
      "base_token_5          the\n",
      "base_prob_5      0.058863\n",
      "base_token_6         cash\n",
      "base_prob_6      0.008179\n",
      "base_token_7         some\n",
      "base_prob_7      0.004916\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000933\n",
      "base_token_9       things\n",
      "base_prob_9      0.004609\n",
      "future_token_0        her\n",
      "future_prob_0    0.267138\n",
      "future_token_1          a\n",
      "future_prob_1    0.097506\n",
      "future_token_2        the\n",
      "future_prob_2    0.058863\n",
      "future_token_3          $\n",
      "future_prob_3    0.016801\n",
      "future_token_4      money\n",
      "future_prob_4    0.013923\n",
      "future_token_5        two\n",
      "future_prob_5    0.011843\n",
      "future_token_6        his\n",
      "future_prob_6    0.010378\n",
      "future_token_7       more\n",
      "future_prob_7    0.010138\n",
      "future_token_8         an\n",
      "future_prob_8    0.008892\n",
      "future_token_9       cash\n",
      "future_prob_9    0.008179\n",
      "MODEL: NECK-SWEEP2_11_0_1_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.134753\n",
      "base_token_1     Japanese\n",
      "base_prob_1       0.00131\n",
      "base_token_2        money\n",
      "base_prob_2       0.00281\n",
      "base_token_3            a\n",
      "base_prob_3      0.122515\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000589\n",
      "base_token_5          the\n",
      "base_prob_5      0.088128\n",
      "base_token_6         cash\n",
      "base_prob_6      0.002026\n",
      "base_token_7         some\n",
      "base_prob_7      0.010687\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000407\n",
      "base_token_9       things\n",
      "base_prob_9      0.000638\n",
      "future_token_0        her\n",
      "future_prob_0    0.134753\n",
      "future_token_1          a\n",
      "future_prob_1    0.122515\n",
      "future_token_2        the\n",
      "future_prob_2    0.088128\n",
      "future_token_3        his\n",
      "future_prob_3    0.025511\n",
      "future_token_4      their\n",
      "future_prob_4    0.014786\n",
      "future_token_5         an\n",
      "future_prob_5    0.014081\n",
      "future_token_6       some\n",
      "future_prob_6    0.010687\n",
      "future_token_7       more\n",
      "future_prob_7    0.009619\n",
      "future_token_8        him\n",
      "future_prob_8    0.009302\n",
      "future_token_9       this\n",
      "future_prob_9    0.009002\n",
      "MODEL: NECK-SWEEP2_12_0_-1_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.085628\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.001441\n",
      "base_token_2        money\n",
      "base_prob_2      0.003858\n",
      "base_token_3            a\n",
      "base_prob_3      0.082581\n",
      "base_token_4          yen\n",
      "base_prob_4       0.00053\n",
      "base_token_5          the\n",
      "base_prob_5      0.067411\n",
      "base_token_6         cash\n",
      "base_prob_6      0.004092\n",
      "base_token_7         some\n",
      "base_prob_7      0.003151\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000712\n",
      "base_token_9       things\n",
      "base_prob_9       0.00066\n",
      "future_token_0        her\n",
      "future_prob_0    0.085628\n",
      "future_token_1          a\n",
      "future_prob_1    0.082581\n",
      "future_token_2        the\n",
      "future_prob_2    0.067411\n",
      "future_token_3         \\n\n",
      "future_prob_3     0.02716\n",
      "future_token_4         to\n",
      "future_prob_4    0.022787\n",
      "future_token_5        she\n",
      "future_prob_5    0.020903\n",
      "future_token_6        and\n",
      "future_prob_6    0.019087\n",
      "future_token_7       with\n",
      "future_prob_7    0.019059\n",
      "future_token_8        for\n",
      "future_prob_8    0.017628\n",
      "future_token_9        you\n",
      "future_prob_9    0.009555\n",
      "MODEL: NECK-SWEEP2_12_0_-1_lstm\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.207444\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.004069\n",
      "base_token_2        money\n",
      "base_prob_2       0.00399\n",
      "base_token_3            a\n",
      "base_prob_3      0.058028\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000877\n",
      "base_token_5          the\n",
      "base_prob_5      0.053711\n",
      "base_token_6         cash\n",
      "base_prob_6      0.001326\n",
      "base_token_7         some\n",
      "base_prob_7      0.003419\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000644\n",
      "base_token_9       things\n",
      "base_prob_9      0.001333\n",
      "future_token_0        her\n",
      "future_prob_0    0.207444\n",
      "future_token_1          a\n",
      "future_prob_1    0.058028\n",
      "future_token_2        the\n",
      "future_prob_2    0.053711\n",
      "future_token_3        she\n",
      "future_prob_3    0.017125\n",
      "future_token_4         \\n\n",
      "future_prob_4    0.013503\n",
      "future_token_5        his\n",
      "future_prob_5    0.012602\n",
      "future_token_6        She\n",
      "future_prob_6     0.01231\n",
      "future_token_7        him\n",
      "future_prob_7    0.011678\n",
      "future_token_8        and\n",
      "future_prob_8    0.008884\n",
      "future_token_9         an\n",
      "future_prob_9    0.007484\n",
      "MODEL: NECK-SWEEP2_12_0_0_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.205952\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.000998\n",
      "base_token_2        money\n",
      "base_prob_2        0.0094\n",
      "base_token_3            a\n",
      "base_prob_3      0.083687\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000207\n",
      "base_token_5          the\n",
      "base_prob_5      0.092602\n",
      "base_token_6         cash\n",
      "base_prob_6      0.004529\n",
      "base_token_7         some\n",
      "base_prob_7      0.005016\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000742\n",
      "base_token_9       things\n",
      "base_prob_9      0.003051\n",
      "future_token_0        her\n",
      "future_prob_0    0.205952\n",
      "future_token_1        the\n",
      "future_prob_1    0.092602\n",
      "future_token_2          a\n",
      "future_prob_2    0.083687\n",
      "future_token_3        his\n",
      "future_prob_3    0.015557\n",
      "future_token_4       this\n",
      "future_prob_4    0.013158\n",
      "future_token_5      their\n",
      "future_prob_5    0.010844\n",
      "future_token_6         an\n",
      "future_prob_6    0.010296\n",
      "future_token_7      money\n",
      "future_prob_7      0.0094\n",
      "future_token_8     people\n",
      "future_prob_8    0.008999\n",
      "future_token_9      those\n",
      "future_prob_9    0.007512\n",
      "MODEL: NECK-SWEEP2_12_0_0_lstm\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.112875\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.003578\n",
      "base_token_2        money\n",
      "base_prob_2      0.013261\n",
      "base_token_3            a\n",
      "base_prob_3      0.060213\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000421\n",
      "base_token_5          the\n",
      "base_prob_5      0.062743\n",
      "base_token_6         cash\n",
      "base_prob_6      0.005423\n",
      "base_token_7         some\n",
      "base_prob_7       0.00648\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000356\n",
      "base_token_9       things\n",
      "base_prob_9      0.006009\n",
      "future_token_0        her\n",
      "future_prob_0    0.112875\n",
      "future_token_1        the\n",
      "future_prob_1    0.062743\n",
      "future_token_2          a\n",
      "future_prob_2    0.060213\n",
      "future_token_3        him\n",
      "future_prob_3    0.014481\n",
      "future_token_4      money\n",
      "future_prob_4    0.013261\n",
      "future_token_5        his\n",
      "future_prob_5    0.011398\n",
      "future_token_6         me\n",
      "future_prob_6     0.00894\n",
      "future_token_7         an\n",
      "future_prob_7    0.008876\n",
      "future_token_8       this\n",
      "future_prob_8    0.008765\n",
      "future_token_9      gifts\n",
      "future_prob_9    0.007159\n",
      "MODEL: NECK-SWEEP2_12_0_1_mlp\n",
      "                        0\n",
      "base_token_0          her\n",
      "base_prob_0      0.097632\n",
      "base_token_1     Japanese\n",
      "base_prob_1      0.002428\n",
      "base_token_2        money\n",
      "base_prob_2      0.006128\n",
      "base_token_3            a\n",
      "base_prob_3      0.087783\n",
      "base_token_4          yen\n",
      "base_prob_4      0.000457\n",
      "base_token_5          the\n",
      "base_prob_5      0.099261\n",
      "base_token_6         cash\n",
      "base_prob_6      0.004875\n",
      "base_token_7         some\n",
      "base_prob_7      0.008205\n",
      "base_token_8      dollars\n",
      "base_prob_8      0.000596\n",
      "base_token_9       things\n",
      "base_prob_9      0.002268\n",
      "future_token_0        the\n",
      "future_prob_0    0.099261\n",
      "future_token_1        her\n",
      "future_prob_1    0.097632\n",
      "future_token_2          a\n",
      "future_prob_2    0.087783\n",
      "future_token_3        his\n",
      "future_prob_3    0.014626\n",
      "future_token_4       that\n",
      "future_prob_4    0.012771\n",
      "future_token_5         an\n",
      "future_prob_5    0.012099\n",
      "future_token_6       this\n",
      "future_prob_6    0.011954\n",
      "future_token_7     people\n",
      "future_prob_7    0.010593\n",
      "future_token_8       more\n",
      "future_prob_8    0.009798\n",
      "future_token_9      their\n",
      "future_prob_9     0.00908\n"
     ]
    }
   ],
   "source": [
    "# Top-10 next-token table (gen_futures defaults: max_tokens=1, topk=10) for every GPT2 checkpoint found.\n",
    "prompt = 'Alice is visiting Japan, so she\\'s exchanging her dollars for'\n",
    "gen_futures(prompt, gpt2_models, gpt2_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675ec2cd-2b3b-4626-99ab-1c1448a6de6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses `prompt` as left by the most recently run cell that assigned it; Mistral checkpoints need the Mistral tokenizer.\n",
    "gen_futures(prompt, mistral_models, mistral_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d26c430-e42d-4033-ba77-8b21452ad97f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
