{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "901a4302-a34a-4aad-901a-279ab2a54cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from evals.utils import *\n",
    "from models.bigram_model import *\n",
    "from models.mlp_model import *\n",
    "from models.future_model import *\n",
    "from data.utils import get_tokenizer\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from itertools import islice\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4592e074-143a-48ba-98da-3630fae2c6be",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = 'MISTRAL'\n",
    "dataset = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_{MODEL}_64tokens_1m').with_format('torch', device=torch.device('cuda'))\n",
    "test = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e993eeed-8586-4de0-be4b-bc009c2bd3f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = DataLoader(test, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5c3e917e-f25b-4c07-8afb-d4862d312912",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt_path = '/workspace/checkpoints/MISTRAL-NECK-SWEEP_20240102-191556-Ec6c4_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-lstm_epoch=00-val_self_loss=3.81.ckpt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "49b2660d-a998-4a81-9eae-22734b4242fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/FutureGPT2/src/models/future_model.py:399: UserWarning: WARN: layer_dims present in config but not used in either neck or model!\n",
      "  warnings.warn(f'WARN: {k} present in config but not used in either neck or model!')\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9444ebfa3b304c749db69fd2fcf4aaf3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/FutureGPT2/src/models/future_model.py:77: UserWarning: BASE SIZE: 28649.0234375 MB\n",
      "  warnings.warn(f'BASE SIZE: {size_mb(self.base_model)} MB')\n",
      "/workspace/FutureGPT2/src/models/future_model.py:98: UserWarning: NECK SIZE: 832.140625 MB\n",
      "  warnings.warn(f'NECK SIZE: {size_mb(self.future_neck)} MB')\n",
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:173: Found keys that are in the model state dict but not in the checkpoint: ['base_model.model.embed_tokens.weight', 'base_model.model.layers.0.self_attn.q_proj.weight', 'base_model.model.layers.0.self_attn.k_proj.weight', 'base_model.model.layers.0.self_attn.v_proj.weight', 'base_model.model.layers.0.self_attn.o_proj.weight', 'base_model.model.layers.0.mlp.gate_proj.weight', 'base_model.model.layers.0.mlp.up_proj.weight', 'base_model.model.layers.0.mlp.down_proj.weight', 'base_model.model.layers.0.input_layernorm.weight', 'base_model.model.layers.0.post_attention_layernorm.weight', 'base_model.model.layers.1.self_attn.q_proj.weight', 'base_model.model.layers.1.self_attn.k_proj.weight', 'base_model.model.layers.1.self_attn.v_proj.weight', 'base_model.model.layers.1.self_attn.o_proj.weight', 'base_model.model.layers.1.mlp.gate_proj.weight', 'base_model.model.layers.1.mlp.up_proj.weight', 'base_model.model.layers.1.mlp.down_proj.weight', 'base_model.model.layers.1.input_layernorm.weight', 'base_model.model.layers.1.post_attention_layernorm.weight', 'base_model.model.layers.2.self_attn.q_proj.weight', 'base_model.model.layers.2.self_attn.k_proj.weight', 'base_model.model.layers.2.self_attn.v_proj.weight', 'base_model.model.layers.2.self_attn.o_proj.weight', 'base_model.model.layers.2.mlp.gate_proj.weight', 'base_model.model.layers.2.mlp.up_proj.weight', 'base_model.model.layers.2.mlp.down_proj.weight', 'base_model.model.layers.2.input_layernorm.weight', 'base_model.model.layers.2.post_attention_layernorm.weight', 'base_model.model.layers.3.self_attn.q_proj.weight', 'base_model.model.layers.3.self_attn.k_proj.weight', 'base_model.model.layers.3.self_attn.v_proj.weight', 'base_model.model.layers.3.self_attn.o_proj.weight', 'base_model.model.layers.3.mlp.gate_proj.weight', 'base_model.model.layers.3.mlp.up_proj.weight', 'base_model.model.layers.3.mlp.down_proj.weight', 'base_model.model.layers.3.input_layernorm.weight', 'base_model.model.layers.3.post_attention_layernorm.weight', 'base_model.model.layers.4.self_attn.q_proj.weight', 'base_model.model.layers.4.self_attn.k_proj.weight', 'base_model.model.layers.4.self_attn.v_proj.weight', 'base_model.model.layers.4.self_attn.o_proj.weight', 'base_model.model.layers.4.mlp.gate_proj.weight', 'base_model.model.layers.4.mlp.up_proj.weight', 'base_model.model.layers.4.mlp.down_proj.weight', 'base_model.model.layers.4.input_layernorm.weight', 'base_model.model.layers.4.post_attention_layernorm.weight', 'base_model.model.layers.5.self_attn.q_proj.weight', 'base_model.model.layers.5.self_attn.k_proj.weight', 'base_model.model.layers.5.self_attn.v_proj.weight', 'base_model.model.layers.5.self_attn.o_proj.weight', 'base_model.model.layers.5.mlp.gate_proj.weight', 'base_model.model.layers.5.mlp.up_proj.weight', 'base_model.model.layers.5.mlp.down_proj.weight', 'base_model.model.layers.5.input_layernorm.weight', 'base_model.model.layers.5.post_attention_layernorm.weight', 'base_model.model.layers.6.self_attn.q_proj.weight', 'base_model.model.layers.6.self_attn.k_proj.weight', 'base_model.model.layers.6.self_attn.v_proj.weight', 'base_model.model.layers.6.self_attn.o_proj.weight', 'base_model.model.layers.6.mlp.gate_proj.weight', 'base_model.model.layers.6.mlp.up_proj.weight', 'base_model.model.layers.6.mlp.down_proj.weight', 'base_model.model.layers.6.input_layernorm.weight', 'base_model.model.layers.6.post_attention_layernorm.weight', 'base_model.model.layers.7.self_attn.q_proj.weight', 'base_model.model.layers.7.self_attn.k_proj.weight', 'base_model.model.layers.7.self_attn.v_proj.weight', 'base_model.model.layers.7.self_attn.o_proj.weight', 'base_model.model.layers.7.mlp.gate_proj.weight', 'base_model.model.layers.7.mlp.up_proj.weight', 'base_model.model.layers.7.mlp.down_proj.weight', 'base_model.model.layers.7.input_layernorm.weight', 'base_model.model.layers.7.post_attention_layernorm.weight', 'base_model.model.layers.8.self_attn.q_proj.weight', 'base_model.model.layers.8.self_attn.k_proj.weight', 'base_model.model.layers.8.self_attn.v_proj.weight', 'base_model.model.layers.8.self_attn.o_proj.weight', 'base_model.model.layers.8.mlp.gate_proj.weight', 'base_model.model.layers.8.mlp.up_proj.weight', 'base_model.model.layers.8.mlp.down_proj.weight', 'base_model.model.layers.8.input_layernorm.weight', 'base_model.model.layers.8.post_attention_layernorm.weight', 'base_model.model.layers.9.self_attn.q_proj.weight', 'base_model.model.layers.9.self_attn.k_proj.weight', 'base_model.model.layers.9.self_attn.v_proj.weight', 'base_model.model.layers.9.self_attn.o_proj.weight', 'base_model.model.layers.9.mlp.gate_proj.weight', 'base_model.model.layers.9.mlp.up_proj.weight', 'base_model.model.layers.9.mlp.down_proj.weight', 'base_model.model.layers.9.input_layernorm.weight', 'base_model.model.layers.9.post_attention_layernorm.weight', 'base_model.model.layers.10.self_attn.q_proj.weight', 'base_model.model.layers.10.self_attn.k_proj.weight', 'base_model.model.layers.10.self_attn.v_proj.weight', 'base_model.model.layers.10.self_attn.o_proj.weight', 'base_model.model.layers.10.mlp.gate_proj.weight', 'base_model.model.layers.10.mlp.up_proj.weight', 'base_model.model.layers.10.mlp.down_proj.weight', 'base_model.model.layers.10.input_layernorm.weight', 'base_model.model.layers.10.post_attention_layernorm.weight', 'base_model.model.layers.11.self_attn.q_proj.weight', 'base_model.model.layers.11.self_attn.k_proj.weight', 'base_model.model.layers.11.self_attn.v_proj.weight', 'base_model.model.layers.11.self_attn.o_proj.weight', 'base_model.model.layers.11.mlp.gate_proj.weight', 'base_model.model.layers.11.mlp.up_proj.weight', 'base_model.model.layers.11.mlp.down_proj.weight', 'base_model.model.layers.11.input_layernorm.weight', 'base_model.model.layers.11.post_attention_layernorm.weight', 'base_model.model.layers.12.self_attn.q_proj.weight', 'base_model.model.layers.12.self_attn.k_proj.weight', 'base_model.model.layers.12.self_attn.v_proj.weight', 'base_model.model.layers.12.self_attn.o_proj.weight', 'base_model.model.layers.12.mlp.gate_proj.weight', 'base_model.model.layers.12.mlp.up_proj.weight', 'base_model.model.layers.12.mlp.down_proj.weight', 'base_model.model.layers.12.input_layernorm.weight', 'base_model.model.layers.12.post_attention_layernorm.weight', 'base_model.model.layers.13.self_attn.q_proj.weight', 'base_model.model.layers.13.self_attn.k_proj.weight', 'base_model.model.layers.13.self_attn.v_proj.weight', 'base_model.model.layers.13.self_attn.o_proj.weight', 'base_model.model.layers.13.mlp.gate_proj.weight', 'base_model.model.layers.13.mlp.up_proj.weight', 'base_model.model.layers.13.mlp.down_proj.weight', 'base_model.model.layers.13.input_layernorm.weight', 'base_model.model.layers.13.post_attention_layernorm.weight', 'base_model.model.layers.14.self_attn.q_proj.weight', 'base_model.model.layers.14.self_attn.k_proj.weight', 'base_model.model.layers.14.self_attn.v_proj.weight', 'base_model.model.layers.14.self_attn.o_proj.weight', 'base_model.model.layers.14.mlp.gate_proj.weight', 'base_model.model.layers.14.mlp.up_proj.weight', 'base_model.model.layers.14.mlp.down_proj.weight', 'base_model.model.layers.14.input_layernorm.weight', 'base_model.model.layers.14.post_attention_layernorm.weight', 'base_model.model.layers.15.self_attn.q_proj.weight', 'base_model.model.layers.15.self_attn.k_proj.weight', 'base_model.model.layers.15.self_attn.v_proj.weight', 'base_model.model.layers.15.self_attn.o_proj.weight', 'base_model.model.layers.15.mlp.gate_proj.weight', 'base_model.model.layers.15.mlp.up_proj.weight', 'base_model.model.layers.15.mlp.down_proj.weight', 'base_model.model.layers.15.input_layernorm.weight', 'base_model.model.layers.15.post_attention_layernorm.weight', 'base_model.model.layers.16.self_attn.q_proj.weight', 'base_model.model.layers.16.self_attn.k_proj.weight', 'base_model.model.layers.16.self_attn.v_proj.weight', 'base_model.model.layers.16.self_attn.o_proj.weight', 'base_model.model.layers.16.mlp.gate_proj.weight', 'base_model.model.layers.16.mlp.up_proj.weight', 'base_model.model.layers.16.mlp.down_proj.weight', 'base_model.model.layers.16.input_layernorm.weight', 'base_model.model.layers.16.post_attention_layernorm.weight', 'base_model.model.layers.17.self_attn.q_proj.weight', 'base_model.model.layers.17.self_attn.k_proj.weight', 'base_model.model.layers.17.self_attn.v_proj.weight', 'base_model.model.layers.17.self_attn.o_proj.weight', 'base_model.model.layers.17.mlp.gate_proj.weight', 'base_model.model.layers.17.mlp.up_proj.weight', 'base_model.model.layers.17.mlp.down_proj.weight', 'base_model.model.layers.17.input_layernorm.weight', 'base_model.model.layers.17.post_attention_layernorm.weight', 'base_model.model.layers.18.self_attn.q_proj.weight', 'base_model.model.layers.18.self_attn.k_proj.weight', 'base_model.model.layers.18.self_attn.v_proj.weight', 'base_model.model.layers.18.self_attn.o_proj.weight', 'base_model.model.layers.18.mlp.gate_proj.weight', 'base_model.model.layers.18.mlp.up_proj.weight', 'base_model.model.layers.18.mlp.down_proj.weight', 'base_model.model.layers.18.input_layernorm.weight', 'base_model.model.layers.18.post_attention_layernorm.weight', 'base_model.model.layers.19.self_attn.q_proj.weight', 'base_model.model.layers.19.self_attn.k_proj.weight', 'base_model.model.layers.19.self_attn.v_proj.weight', 'base_model.model.layers.19.self_attn.o_proj.weight', 'base_model.model.layers.19.mlp.gate_proj.weight', 'base_model.model.layers.19.mlp.up_proj.weight', 'base_model.model.layers.19.mlp.down_proj.weight', 'base_model.model.layers.19.input_layernorm.weight', 'base_model.model.layers.19.post_attention_layernorm.weight', 'base_model.model.layers.20.self_attn.q_proj.weight', 'base_model.model.layers.20.self_attn.k_proj.weight', 'base_model.model.layers.20.self_attn.v_proj.weight', 'base_model.model.layers.20.self_attn.o_proj.weight', 'base_model.model.layers.20.mlp.gate_proj.weight', 'base_model.model.layers.20.mlp.up_proj.weight', 'base_model.model.layers.20.mlp.down_proj.weight', 'base_model.model.layers.20.input_layernorm.weight', 'base_model.model.layers.20.post_attention_layernorm.weight', 'base_model.model.layers.21.self_attn.q_proj.weight', 'base_model.model.layers.21.self_attn.k_proj.weight', 'base_model.model.layers.21.self_attn.v_proj.weight', 'base_model.model.layers.21.self_attn.o_proj.weight', 'base_model.model.layers.21.mlp.gate_proj.weight', 'base_model.model.layers.21.mlp.up_proj.weight', 'base_model.model.layers.21.mlp.down_proj.weight', 'base_model.model.layers.21.input_layernorm.weight', 'base_model.model.layers.21.post_attention_layernorm.weight', 'base_model.model.layers.22.self_attn.q_proj.weight', 'base_model.model.layers.22.self_attn.k_proj.weight', 'base_model.model.layers.22.self_attn.v_proj.weight', 'base_model.model.layers.22.self_attn.o_proj.weight', 'base_model.model.layers.22.mlp.gate_proj.weight', 'base_model.model.layers.22.mlp.up_proj.weight', 'base_model.model.layers.22.mlp.down_proj.weight', 'base_model.model.layers.22.input_layernorm.weight', 'base_model.model.layers.22.post_attention_layernorm.weight', 'base_model.model.layers.23.self_attn.q_proj.weight', 'base_model.model.layers.23.self_attn.k_proj.weight', 'base_model.model.layers.23.self_attn.v_proj.weight', 'base_model.model.layers.23.self_attn.o_proj.weight', 'base_model.model.layers.23.mlp.gate_proj.weight', 'base_model.model.layers.23.mlp.up_proj.weight', 'base_model.model.layers.23.mlp.down_proj.weight', 'base_model.model.layers.23.input_layernorm.weight', 'base_model.model.layers.23.post_attention_layernorm.weight', 'base_model.model.layers.24.self_attn.q_proj.weight', 'base_model.model.layers.24.self_attn.k_proj.weight', 'base_model.model.layers.24.self_attn.v_proj.weight', 'base_model.model.layers.24.self_attn.o_proj.weight', 'base_model.model.layers.24.mlp.gate_proj.weight', 'base_model.model.layers.24.mlp.up_proj.weight', 'base_model.model.layers.24.mlp.down_proj.weight', 'base_model.model.layers.24.input_layernorm.weight', 'base_model.model.layers.24.post_attention_layernorm.weight', 'base_model.model.layers.25.self_attn.q_proj.weight', 'base_model.model.layers.25.self_attn.k_proj.weight', 'base_model.model.layers.25.self_attn.v_proj.weight', 'base_model.model.layers.25.self_attn.o_proj.weight', 'base_model.model.layers.25.mlp.gate_proj.weight', 'base_model.model.layers.25.mlp.up_proj.weight', 'base_model.model.layers.25.mlp.down_proj.weight', 'base_model.model.layers.25.input_layernorm.weight', 'base_model.model.layers.25.post_attention_layernorm.weight', 'base_model.model.layers.26.self_attn.q_proj.weight', 'base_model.model.layers.26.self_attn.k_proj.weight', 'base_model.model.layers.26.self_attn.v_proj.weight', 'base_model.model.layers.26.self_attn.o_proj.weight', 'base_model.model.layers.26.mlp.gate_proj.weight', 'base_model.model.layers.26.mlp.up_proj.weight', 'base_model.model.layers.26.mlp.down_proj.weight', 'base_model.model.layers.26.input_layernorm.weight', 'base_model.model.layers.26.post_attention_layernorm.weight', 'base_model.model.layers.27.self_attn.q_proj.weight', 'base_model.model.layers.27.self_attn.k_proj.weight', 'base_model.model.layers.27.self_attn.v_proj.weight', 'base_model.model.layers.27.self_attn.o_proj.weight', 'base_model.model.layers.27.mlp.gate_proj.weight', 'base_model.model.layers.27.mlp.up_proj.weight', 'base_model.model.layers.27.mlp.down_proj.weight', 'base_model.model.layers.27.input_layernorm.weight', 'base_model.model.layers.27.post_attention_layernorm.weight', 'base_model.model.layers.28.self_attn.q_proj.weight', 'base_model.model.layers.28.self_attn.k_proj.weight', 'base_model.model.layers.28.self_attn.v_proj.weight', 'base_model.model.layers.28.self_attn.o_proj.weight', 'base_model.model.layers.28.mlp.gate_proj.weight', 'base_model.model.layers.28.mlp.up_proj.weight', 'base_model.model.layers.28.mlp.down_proj.weight', 'base_model.model.layers.28.input_layernorm.weight', 'base_model.model.layers.28.post_attention_layernorm.weight', 'base_model.model.layers.29.self_attn.q_proj.weight', 'base_model.model.layers.29.self_attn.k_proj.weight', 'base_model.model.layers.29.self_attn.v_proj.weight', 'base_model.model.layers.29.self_attn.o_proj.weight', 'base_model.model.layers.29.mlp.gate_proj.weight', 'base_model.model.layers.29.mlp.up_proj.weight', 'base_model.model.layers.29.mlp.down_proj.weight', 'base_model.model.layers.29.input_layernorm.weight', 'base_model.model.layers.29.post_attention_layernorm.weight', 'base_model.model.layers.30.self_attn.q_proj.weight', 'base_model.model.layers.30.self_attn.k_proj.weight', 'base_model.model.layers.30.self_attn.v_proj.weight', 'base_model.model.layers.30.self_attn.o_proj.weight', 'base_model.model.layers.30.mlp.gate_proj.weight', 'base_model.model.layers.30.mlp.up_proj.weight', 'base_model.model.layers.30.mlp.down_proj.weight', 'base_model.model.layers.30.input_layernorm.weight', 'base_model.model.layers.30.post_attention_layernorm.weight', 'base_model.model.layers.31.self_attn.q_proj.weight', 'base_model.model.layers.31.self_attn.k_proj.weight', 'base_model.model.layers.31.self_attn.v_proj.weight', 'base_model.model.layers.31.self_attn.o_proj.weight', 'base_model.model.layers.31.mlp.gate_proj.weight', 'base_model.model.layers.31.mlp.up_proj.weight', 'base_model.model.layers.31.mlp.down_proj.weight', 'base_model.model.layers.31.input_layernorm.weight', 'base_model.model.layers.31.post_attention_layernorm.weight', 'base_model.model.norm.weight', 'base_model.lm_head.weight']\n"
     ]
    }
   ],
   "source": [
    "model = LitFutureModelWithNeck.load_from_checkpoint(ckpt_path, strict=False).to('cuda')\n",
    "# don't reduce loss\n",
    "model.loss_func = nn.CrossEntropyLoss(reduction='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a10bf518-d445-457f-b543-e96dc4a45a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01446e33-d9e9-4e25-9303-ecc4e49a9b27",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 704/704 [1:23:52<00:00,  7.15s/it]\n"
     ]
    }
   ],
   "source": [
    "losses = []\n",
    "ids = []\n",
    "test_iter = iter(loader)\n",
    "for batch in tqdm(test_iter):\n",
    "    loss = model._compute_loss(batch)\n",
    "    losses.append(loss.self_loss.reshape(-1, 63).cpu().detach())  # (seq_length-1)=63\n",
    "    ids += batch['id']\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "losses = torch.concatenate(losses, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "45826a9a-e8c8-4909-b01f-22161d51ab3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "topk_val, topk_ind = losses.flatten().topk(10)\n",
    "#topk_val, topk_ind = (-losses).flatten().topk(10)\n",
    "topk_ind = np.array(np.unravel_index(topk_ind.numpy(), losses.shape)).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a60451dc-3fb1-444d-93e4-34d04385d60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_row(data, id):\n",
    "    idx = data['id'].index(id)\n",
    "    return {k: data[k][idx] for k in ['text', 'input_ids', 'attention_mask']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f2b3019a-db5a-4e49-b145-4aedc5901bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'mistralai/Mistral-7B-v0.1'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}\n",
    "\n",
    "def topk(v, k=40):\n",
    "    # Takes in logits\n",
    "    #v = softmax(v.flatten())\n",
    "    v = v.flatten()\n",
    "    idxs = v.argsort()[-k:][::-1]\n",
    "    ret = [(Token[i], v[i]) for i in idxs]\n",
    "    return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ccc10874-48e0-4037-952f-bb2ab6f1e87b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>|▁In|▁June|▁|2|0|1|2|,|▁Den|ny|'|s|▁opened|▁a|▁location|▁in|▁the|▁Las|▁Am(Ã)\n",
      "BASE vs FUTURE:\n",
      "   token      logit token     logit\n",
      "0   éric  20.796623    ag  9.692738\n",
      "1      Ã  12.286123    as  9.523293\n",
      "2      é  11.724434    es  9.291723\n",
      "3    ist  11.619908    os  8.810611\n",
      "4  érica  11.352891    ad  8.638937\n",
      "5     ig  11.206367     e  8.476806\n",
      "6    ric  10.407958   pes  8.458270\n",
      "7     el  10.261961    ac  8.347697\n",
      "8    éri  10.244709     a  8.338440\n",
      "9     ér   9.979428   ena  8.296510\n",
      "LOSS: 18.58062171936035 18.580528259277344\n",
      "<s>|▁|1|1|▁|3|▁|Ã|·|▁|1|1|▁|3|▁=|▁|1|,|3|3|1|▁True|▁False|.|▁We|egy|:|▁False|▁User|:|▁Sim|pl|ify|.|▁y|▁|5|▁|Â|·|▁y|▁|3|▁|Ã|·|▁y|▁|2|▁We(egy)\n",
      "BASE vs FUTURE:\n",
      "   token      logit  token     logit\n",
      "0    egy  16.067135      '  8.341678\n",
      "1    edy   9.887799   ▁can  8.096248\n",
      "2   ▁can   9.612230     ▁=  8.027686\n",
      "3      e   8.764141     ▁x  7.747435\n",
      "4  ▁have   8.721185      â  7.640388\n",
      "5  ▁need   8.693643  ▁have  7.571069\n",
      "6  ▁will   7.806245   ▁are  7.564444\n",
      "7  ▁know   7.727398      x  7.316980\n",
      "8   ▁don   7.716074    ▁is  7.259876\n",
      "9   ▁are   7.683526      ▁  7.207473\n",
      "LOSS: 17.593915939331055 17.593936920166016\n",
      "<s>|▁I|▁heard|▁that|▁N|isk|ay|una|,|▁G(uilder)\n",
      "BASE vs FUTURE:\n",
      "    token      logit  token      logit\n",
      "0  uilder  18.142620      .  10.468679\n",
      "1      lo  14.882296      Y   9.658442\n",
      "2       E  12.209717      H   9.244341\n",
      "3    ales  10.907859      N   8.890127\n",
      "4       .  10.826200      ,   8.834965\n",
      "5    ates  10.681667      J   8.454570\n",
      "6    raft  10.608150      I   8.267923\n",
      "7    anse  10.575612     .,   8.196075\n",
      "8     off  10.561547      -   7.844780\n",
      "9     arn  10.495750  ▁York   7.637235\n",
      "LOSS: 17.194414138793945 17.194263458251953\n",
      "<s>|▁John|s|▁Hop|kins|▁Queen|▁of|▁Tur|f|▁(|B|alt|imore|,|▁MD|)|▁-|▁October|▁|1|8|th|▁(|optional|,|▁not|▁included|▁in|▁fees|)|▁Prime|▁Time|▁Rec|ru|iting|▁Show|case|▁(|Bel|▁Air|,|▁MD|)-|▁November|▁|7|th|.|▁L|ax|▁Cl|ash|▁(|Bel(▁Air)\n",
      "BASE vs FUTURE:\n",
      "  token      logit    token      logit\n",
      "0  ▁Air  17.423073       ly  11.909500\n",
      "1  mont  12.170218     ster  11.329329\n",
      "2    ts  11.465906      ley  11.209556\n",
      "3   air  10.625484     ater  11.050723\n",
      "4   Air  10.546558  chester  11.033018\n",
      "5   lev  10.473010      ton  10.913441\n",
      "6   ved  10.448994       st  10.585052\n",
      "7  camp  10.328986     ston  10.509165\n",
      "8     -  10.065748       on  10.500232\n",
      "9   ▁Al   9.782031       le  10.457557\n",
      "LOSS: 17.12163543701172 17.121728897094727\n",
      "<s>|▁By|▁knowing|▁what|▁these|▁cancer|-|f|ight|ing|▁foods|▁are|,|▁we|▁can|▁take|▁positive|▁steps|▁in|▁the|▁fight|▁against|▁cancer|.|▁Lead|ing|▁cancer|▁authority|▁Richard|▁Bel(iveau)\n",
      "BASE vs FUTURE:\n",
      "   token      logit token      logit\n",
      "0  iveau  16.420172    an  10.884651\n",
      "1     kn  12.188009   son  10.763093\n",
      "2    itz  11.913871     f  10.646927\n",
      "3   ding  11.851681    is  10.640945\n",
      "4    its  11.594748    on  10.447121\n",
      "5  insky  11.176389    ry  10.292603\n",
      "6    zer  11.021322    us  10.103113\n",
      "7   cher  10.983014   ell   9.878038\n",
      "8      k  10.533964   ley   9.861928\n",
      "9  court  10.279200   man   9.845502\n",
      "LOSS: 16.553255081176758 16.55303382873535\n",
      "<s>|▁List|▁of|▁Average|▁Tem|po(▁()\n",
      "BASE vs FUTURE:\n",
      "    token      logit   token      logit\n",
      "0  ▁Songs  12.204136  atures  17.953848\n",
      "1  ▁songs  10.818007   ature  17.685072\n",
      "2     ▁Tr   9.519403   aries  13.614298\n",
      "3    ▁Run   9.337012  ations  13.150720\n",
      "4      ▁R   9.331960    ages  11.608529\n",
      "5  ▁Music   9.258056    ures  11.585028\n",
      "6  <0x0A>   8.907178    atur  11.523750\n",
      "7      ▁P   8.648770   ation  11.381439\n",
      "8     ▁Dr   8.645649   ities  11.042308\n",
      "9     ▁Gu   8.488974   ators  10.907338\n",
      "LOSS: 16.525625228881836 16.525827407836914\n",
      "<s>|▁Welcome|▁at|▁El|▁Ref|ug|io|.|▁|▁El|▁Ref|ug|io|▁|▁restaurant|▁and|▁gu|es|th|ouse|▁lies|▁at|▁one|▁of|▁the|▁most|▁beautiful|▁places|▁in|▁the|▁south|▁of|▁T|ener|ife|.|▁|9|4|0|▁meters|▁above|▁seal(evel)\n",
      "BASE vs FUTURE:\n",
      "    token      logit    token      logit\n",
      "0    evel  19.300224   ▁level  12.271986\n",
      "1  ▁level  15.445610     ▁sea  11.582130\n",
      "2      ev  13.317312     ▁the  11.158568\n",
      "3       -  11.538998        ,  10.669075\n",
      "4     vel  11.229339     ▁and  10.610079\n",
      "5      vl  10.848573        .  10.528693\n",
      "6   level  10.529313      ▁in   9.221312\n",
      "7     ▁le  10.263854  ▁levels   9.029104\n",
      "8       v  10.064889      ▁of   8.716483\n",
      "9     eve   9.906376      ▁on   8.613585\n",
      "LOSS: 16.494029998779297 16.49410057067871\n",
      "<s>|▁A|▁regular|▁poly|hed|ron|▁is|▁identified|▁by|▁its|▁Sch(la)\n",
      "BASE vs FUTURE:\n",
      "  token      logit   token      logit\n",
      "0    lä  17.621168     ell  10.471432\n",
      "1    la  15.084376      id   9.785460\n",
      "2   leg  14.629351      ot   9.639959\n",
      "3  warz  12.569276      om   9.554816\n",
      "4     a  12.552008      et   9.415544\n",
      "5    ön  12.443191    agon   9.378013\n",
      "6   oen  12.038251    olar   9.281439\n",
      "7    le  10.951105     ism   8.934696\n",
      "8     l  10.798338      el   8.691120\n",
      "9     ö  10.510221  ateral   8.683899\n",
      "LOSS: 16.37184715270996 16.371850967407227\n",
      "<s>|▁UN|IVER|S|ITY|▁OF|▁P|HO|EN|IX|▁F|AF|SA|▁Code|▁Your|▁one|▁stop|▁resource|▁to|▁anything|▁relating|▁to|▁UN(IVER)\n",
      "BASE vs FUTURE:\n",
      "   token      logit token      logit\n",
      "0   IVER  16.826038     F  11.244146\n",
      "1     IV  12.425361    ▁F  10.465271\n",
      "2      I  11.332151    FA   9.779150\n",
      "3  ivers   9.926439    PA   9.224319\n",
      "4     IT   9.903395    AF   9.046834\n",
      "5      P   9.620754    SA   8.759499\n",
      "6      C   9.508973    FS   8.756845\n",
      "7      i   8.864655     U   8.507065\n",
      "8   iver   8.588744    AA   8.495877\n",
      "9      L   8.398342    FC   8.357748\n",
      "LOSS: 16.327285766601562 16.32726287841797\n",
      "<s>|▁Gran|by|▁is|▁br|im|ming|▁with|▁exciting|▁activities|▁to|▁enjoy|▁all|▁summer|▁long|▁like|▁h|iking|,|▁world|-|class|▁golf|ing|,|▁mountain|▁b|iking|,|▁bo|ating|,|▁r|aft|ing|,|▁horse|back|▁riding|,|▁hunting|,|▁Gold(▁Medal)\n",
      "BASE vs FUTURE:\n",
      "    token      logit     token      logit\n",
      "0  ▁Medal  15.238127         ,  10.858891\n",
      "1      ▁R  11.176836         f  10.540323\n",
      "2  ▁medal  11.174384       ing  10.188580\n",
      "3       -  10.637672      ▁and   9.913383\n",
      "4    ▁Min  10.485248         -   9.560018\n",
      "5    ▁Pro  10.388824  ▁fishing   8.985096\n",
      "6   ▁Camp  10.066567      ▁ski   8.718757\n",
      "7      ▁P   9.585752      fish   8.674867\n",
      "8   ▁King   9.426716     water   8.342726\n",
      "9      ▁p   9.383149       ery   8.127048\n",
      "LOSS: 16.29709243774414 16.29722785949707\n"
     ]
    }
   ],
   "source": [
    "for ind, val in zip(topk_ind, topk_val):\n",
    "    row = get_row(test, ids[ind[0]])\n",
    "    input_ids = row['input_ids'][:ind[1] + 2] # loss at seq idx n corresponds to forward pass at idx n+1\n",
    "    input_ids = input_ids.unsqueeze(0) # add batch dim\n",
    "    out = model({'input_ids': input_ids.to('cuda'), 'attention_mask': torch.ones(input_ids.shape).to('cuda')})\n",
    "    base = out.logits[0, ind[1] + 1,:]\n",
    "    future = out.future_logits[0, ind[1],:]\n",
    "    out_str = '|'.join(Token[i] for i in input_ids.cpu().flatten().numpy())\n",
    "    if ind[1] + 2 < 64:\n",
    "        out_str += '(' + Token[row['input_ids'][ind[1] + 2].item()] + ')'\n",
    "    print(out_str)\n",
    "    print('BASE vs FUTURE:')\n",
    "    print(pd.concat([\n",
    "        topk(base.cpu().numpy(), k=10),\n",
    "        topk(future.cpu().detach().numpy(), k=10)\n",
    "    ], axis=1))\n",
    "    print('LOSS:', val.item(), nn.CrossEntropyLoss()(future, torch.softmax(base, dim=0)).item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20efad8f-4e28-4c17-adb0-80da6cc6b33d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
