{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "901a4302-a34a-4aad-901a-279ab2a54cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HF_HOME has to be set before the transformers imports below.\n",
    "import os\n",
    "os.environ['HF_HOME'] = '/workspace/cache/huggingface/'\n",
    "\n",
    "# NOTE(review): presumably this chdir is what lets the local packages (evals, models, data)\n",
    "# resolve, so it has to stay ahead of those imports -- confirm.\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from evals.utils import *\n",
    "from models.bigram_model import *\n",
    "from models.mlp_model import *\n",
    "from models.future_model import *\n",
    "from data.utils import get_tokenizer\n",
    "import datasets\n",
    "# torch and np are used directly in later cells; import them explicitly instead of\n",
    "# relying on whatever the star imports above happen to re-export.\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from itertools import islice\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Silences every warning for the rest of the session.\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4592e074-143a-48ba-98da-3630fae2c6be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the on-disk MS MARCO dataset for MODEL and have it return torch tensors on the GPU.\n",
    "MODEL = 'MISTRAL'\n",
    "dataset = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_{MODEL}_64tokens_1m').with_format('torch', device=torch.device('cuda'))\n",
    "# Only the 'test' split is used below.\n",
    "test = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e993eeed-8586-4de0-be4b-bc009c2bd3f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 128 test rows; no shuffle argument is passed.\n",
    "loader = DataLoader(test, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7cb9205c-34a1-4327-8314-dbb19af26baa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/workspace/checkpoints/MISTRAL-NECK-SWEEP_20240102-003452-6d828_hidden_idxs-0_hidden_lb--1_token_lb-0_neck_cls-mlp_epoch=00-val_self_loss=6.85.ckpt'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: display the first checkpoint path matched by the pattern used for the 't' model.\n",
    "glob('/workspace/checkpoints/MISTRAL-NECK-SWEEP_*_hidden_idxs-*_hidden_lb--1_token_lb-0_neck_cls-mlp*')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5c3e917e-f25b-4c07-8afb-d4862d312912",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _first_ckpt(pattern):\n",
    "    \"\"\"Return the first path matching `pattern`.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: when nothing matches, naming the pattern (a bare\n",
    "            glob(...)[0] would only raise an IndexError with no context).\n",
    "    \"\"\"\n",
    "    matches = glob(pattern)\n",
    "    if not matches:\n",
    "        raise FileNotFoundError(f'no checkpoint matches {pattern!r}')\n",
    "    # NOTE(review): glob does not guarantee an order; if several checkpoints match,\n",
    "    # which one comes first depends on the filesystem.\n",
    "    return matches[0]\n",
    "\n",
    "# Checkpoint paths keyed by the short names used throughout the notebook.\n",
    "ckpt_path_d = {\n",
    "    't': _first_ckpt('/workspace/checkpoints/MISTRAL-NECK-SWEEP_*_hidden_idxs-*_hidden_lb--1_token_lb-0_neck_cls-mlp*'),\n",
    "    'ht': _first_ckpt('/workspace/checkpoints/MISTRAL-NECK-SWEEP_*_hidden_idxs-32_hidden_lb-0_token_lb-0_neck_cls-mlp*'),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e33bb6f2-07a9-49aa-a059-e8de390f39c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7fb3d90c0460>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference only: switch autograd off globally for the cells that follow.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8bfe6fa9-0731-49f3-be00-a3e312836cd2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a28f31fe6ee4ae09700f57f22d8ea21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "80059ce85cf0431896c2ce20b858cad8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load every checkpoint in ckpt_path_d onto the GPU, keyed by the same short names.\n",
    "model_d = dict()\n",
    "for name in ckpt_path_d:\n",
    "    # NOTE(review): strict=False presumably tolerates missing/unexpected state-dict keys --\n",
    "    # confirm nothing important is silently skipped.\n",
    "    model_d[name] = LitFutureModelWithNeck.load_from_checkpoint(ckpt_path_d[name], strict=False).to('cuda')\n",
    "    # Release Python garbage and cached CUDA blocks between loads.\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dc547f1e-9df1-4ab7-86e1-522299123cf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace each model's loss_func with an unreduced cross-entropy (one value per element).\n",
    "# presumably _compute_loss reads loss_func, so losses come back per position -- TODO confirm\n",
    "for name in model_d:\n",
    "    model_d[name].loss_func = nn.CrossEntropyLoss(reduction='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "01446e33-d9e9-4e25-9303-ecc4e49a9b27",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TODO!!!!! uncomment total_loss term!!!!\n",
      "TODO!!!!! uncomment total_loss term!!!!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:13, 13.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TODO!!!!! uncomment total_loss term!!!!\n",
      "TODO!!!!! uncomment total_loss term!!!!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2it [00:27, 13.77s/it]\n"
     ]
    }
   ],
   "source": [
    "N_BATCHES = 2  # number of test batches to score\n",
    "\n",
    "# Per-model list of per-batch loss tensors; joined into `losses` after the loop.\n",
    "loss_chunks = {k: [] for k in model_d}\n",
    "ids = []\n",
    "test_iter = iter(loader)\n",
    "# islice has no length, so pass total explicitly to get a real progress bar.\n",
    "for batch in tqdm(islice(test_iter, N_BATCHES), total=N_BATCHES):\n",
    "    # seq_length - 1 loss values per row (63 for the 64-token dataset).\n",
    "    n_positions = batch['input_ids'].shape[-1] - 1\n",
    "    for name, model in model_d.items():\n",
    "        loss = model._compute_loss(batch)\n",
    "        loss_chunks[name].append(loss.self_loss.reshape(-1, n_positions).cpu().detach())\n",
    "    # Keep row ids in the same order as the loss rows so they can be looked up later.\n",
    "    ids += batch['id']\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "# One (n_rows, seq_length - 1) tensor per model.\n",
    "losses = {k: torch.cat(chunks, dim=0) for k, chunks in loss_chunks.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "65329cfc-00dd-46de-80d5-8b4a101f0689",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positive entries are positions where the 't' model's loss exceeds the 'ht' model's.\n",
    "loss_diffs = losses['t'] - losses['ht']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "45826a9a-e8c8-4909-b01f-22161d51ab3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten largest loss gaps across all (row, position) pairs.\n",
    "flat_diffs = loss_diffs.flatten()\n",
    "topk_val, flat_ind = flat_diffs.topk(10)\n",
    "# Map the flat indices back to coordinates; each row of topk_ind is one (row, position) pair.\n",
    "rows, positions = np.unravel_index(flat_ind.numpy(), loss_diffs.shape)\n",
    "topk_ind = np.array([rows, positions]).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0e137d16-2f0e-402f-a00d-60f8b5440e6b",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "object of type 'module' has no len()",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mTypeError\u001b[0m: object of type 'module' has no len()"
     ]
    }
   ],
   "source": [
    "# Number of rows in the test split. (`data` is a module in this namespace, so len(data) raised TypeError.)\n",
    "len(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a60451dc-3fb1-444d-93e4-34d04385d60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_row(data, id):\n",
    "    idx = data['id'].index(id)\n",
    "    return {k: data[k][idx] for k in ['text', 'input_ids', 'attention_mask']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 230,
   "id": "f2b3019a-db5a-4e49-b145-4aedc5901bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'mistralai/Mistral-7B-v0.1'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Inverted vocabulary: get_vocab() items are swapped so Token maps token id -> token string.\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}\n",
    "\n",
    "def topk(v, k=40, aux=None):\n",
    "    \"\"\"Tabulate the k highest-scoring tokens of a logit vector.\n",
    "\n",
    "    Args:\n",
    "        v: logits indexed by token id; flattened before ranking. Needs .flatten(),\n",
    "            .argsort() and reversed slicing, e.g. a NumPy array.\n",
    "        k: how many top entries to keep; k <= 0 gives an empty table.\n",
    "        aux: optional per-token extras (list or array); aux[i] is appended to the\n",
    "            row of token i. None or empty means no extra columns.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns 'token' and 'logit' (plus 0..len(aux[0])-1 when aux\n",
    "        is given), ordered from highest to lowest logit.\n",
    "    \"\"\"\n",
    "    # Takes in logits\n",
    "    v = v.flatten()\n",
    "    # Reverse before truncating: argsort()[-k:] would keep every entry when k == 0.\n",
    "    idxs = v.argsort()[::-1][:max(k, 0)]\n",
    "    # A bare `if aux:` raises for multi-element NumPy arrays, so test explicitly.\n",
    "    if aux is not None and len(aux) > 0:\n",
    "        ret = [(Token[i], v[i]) + tuple(aux[i]) for i in idxs]\n",
    "        return pd.DataFrame(ret, columns=['token', 'logit'] + list(range(len(aux[0]))))\n",
    "    ret = [(Token[i], v[i]) for i in idxs]\n",
    "    return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ccc10874-48e0-4037-952f-bb2ab6f1e87b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>|▁A|▁T|AM|▁plane|▁before|▁landing|▁at|▁Sant|os|▁Dum|ont|▁Airport|▁in|▁Rio|▁de|▁Jane|iro|.|▁L|AT|AM|▁Airlines|▁Brasil|,|▁formerly|▁T|AM|▁Airlines|▁(|Port|ug(uese)\n",
      "BASE vs FUTURES:\n",
      "   base      logit           ht      logit       t     logit\n",
      "0  uese  24.244495         uese  15.929122       ,  5.838762\n",
      "1   ese  16.892418           al   9.846756       .  5.667567\n",
      "2    al  15.194197           ug   9.845806    ▁and  4.972284\n",
      "3   ues  13.136635          ian   9.347416       ▁  4.647301\n",
      "4    ue  12.453605  ▁Portuguese   9.073431       -  4.646937\n",
      "5     :  12.266930          age   8.388430     ▁of  4.460546\n",
      "6     .  11.328085          ate   8.297344  <0x0A>  4.394995\n",
      "7   ies  11.047367           te   8.078945     ▁in  4.393285\n",
      "8    es  10.708273            ,   7.970072      ▁(  4.271605\n",
      "9   ess  10.263073     ▁Spanish   7.876432     ▁to  4.180522\n",
      "LOSS_HT: 0.023796604946255684\n",
      "LOSS_T: 13.652191162109375\n",
      "<s>|▁In|▁the|▁summer|time|,|▁cold|▁b|ors|cht|▁is|▁a|▁popular|▁alternative|▁to|▁the|▁a|fore|ment(ioned)\n",
      "BASE vs FUTURES:\n",
      "     base      logit     ht      logit       t     logit\n",
      "0   ioned  20.679462  ioned  15.171043       ,  6.712155\n",
      "1     ion  12.532279     ed   9.269307       .  6.370513\n",
      "2    ined  11.286148   ated   7.987817    ▁and  5.867963\n",
      "3      ed  11.148052  ▁soup   7.861340     ▁of  5.738740\n",
      "4  inated  10.722492      .   7.517274     ▁in  5.411478\n",
      "5    ient  10.433092     ▁“   7.422709       ▁  5.197626\n",
      "6      io  10.319535  ▁meal   7.420746     ▁to  5.101018\n",
      "7    oned  10.179482     on   7.283975       -  4.994463\n",
      "8    ated  10.077990      ,   7.042997      ▁(  4.862326\n",
      "9       o   9.844437    aid   6.990074  <0x0A>  4.827489\n",
      "LOSS_HT: 0.029042495414614677\n",
      "LOSS_T: 12.906168937683105\n",
      "<s>|▁In|▁contrast|,|▁a|▁wrist|▁fract|ure|▁occurs|▁when|▁one|▁of|▁the|▁bones|▁in|▁the|▁wrist|▁breaks|.|▁Sometimes|▁it|▁can|▁be|▁difficult|▁to|▁tell|▁the|▁difference|▁between|▁a|▁wrist|▁spr|ain|▁and|▁fract|ure|,|▁as|▁both|▁injuries|▁generate|▁similar|▁symptoms|▁and|▁are|▁caused|▁by|▁similar|▁accidents|▁â|||▁falls|▁on|▁an|▁out(stret)\n",
      "BASE vs FUTURES:\n",
      "         base      logit          ht      logit     t     logit\n",
      "0       stret  21.085928       stret  17.235571  ▁the  6.546217\n",
      "1          re  15.259857        ▁out  13.766916     ,  6.092660\n",
      "2           -  14.664658        ched  10.280686   ▁of  5.925056\n",
      "3        stre  12.824176       ▁hand  10.029988     .  5.796539\n",
      "4  ▁stretched  12.271562        ▁arm   9.897739    ▁a  5.738971\n",
      "5        ward  10.988081         ▁of   9.764143   ▁to  5.716238\n",
      "6          sp  10.738475         ▁to   9.606495   ▁in  5.681736\n",
      "7       reach  10.707052         ▁or   9.467995     ▁  5.648721\n",
      "8          st   9.869334  ▁stretched   9.239226  ▁and  5.601653\n",
      "9        held   9.100565        ▁the   9.216121  ▁for  4.910690\n",
      "LOSS_HT: 0.09414895623922348\n",
      "LOSS_T: 12.868233680725098\n",
      "<s>|▁An|▁after|sh|ave|▁is|▁a|▁lot|ion|,|▁gel|,|▁bal|m|,|▁powder|,|▁or|▁liquid|▁used|▁mainly|▁by|▁men|▁after|▁they|▁have|▁finished|▁sh|aving|.|▁It|▁may|▁contain|▁an|▁ant|ise(ptic)\n",
      "BASE vs FUTURES:\n",
      "    base      logit    ht      logit       t     logit\n",
      "0   ptic  23.082886  ptic  15.155146       ,  6.445726\n",
      "1    bor  15.710505  icro  10.354185       .  6.003810\n",
      "2     pt  14.863027    id   9.250516    ▁and  5.616667\n",
      "3      b  12.069398    ic   8.970553     ▁in  5.195916\n",
      "4    bum  10.696212     x   8.631854       ▁  5.146494\n",
      "5   ctic   9.552641     .   8.616474     ▁of  5.037909\n",
      "6  metic   9.262023    pt   8.514505     ▁to  4.868144\n",
      "7     iz   8.946691   ung   8.475479    ▁the  4.829064\n",
      "8    ipt   8.929862     ,   8.434600       -  4.801855\n",
      "9  ption   8.370919    ar   8.171955  <0x0A>  4.725385\n",
      "LOSS_HT: 0.05348242446780205\n",
      "LOSS_T: 12.685019493103027\n",
      "<s>|▁Left|-|s|ided|▁hem|ip|ares|is|.|▁After|▁a|▁stroke|▁in|▁the|▁right|▁hem|is|phere|▁the|▁patient|▁is|▁par|aly(zed)\n",
      "BASE vs FUTURES:\n",
      "   base      logit    ht      logit       t     logit\n",
      "0   zed  20.188988   zed  13.666735       ,  6.703173\n",
      "1   zes  12.747564    ed  11.036194       .  6.261438\n",
      "2    ed  12.449883   ▁or   8.763422    ▁and  5.669623\n",
      "3    ze  12.185181     .   8.734177       ▁  5.305880\n",
      "4     z  11.154886    ys   8.725369       -  5.159651\n",
      "5  zing  11.093264     y   8.694987     ▁in  5.137856\n",
      "6  ized  10.745083    ze   8.682026     ▁of  5.116590\n",
      "7    se  10.628230     ,   8.595679  <0x0A>  4.920159\n",
      "8  ised  10.072829  ized   8.452236     ▁to  4.899116\n",
      "9   ced  10.049596   ied   8.404403      ▁(  4.871181\n",
      "LOSS_HT: 0.2087990790605545\n",
      "LOSS_T: 12.477593421936035\n",
      "<s>|▁C|ay|enne|▁pepper|▁is|▁pure|▁ground|▁dried|▁ch|iles|.|▁|1|▁|▁Ch|ili|▁powder|▁is|▁a|▁blend|▁of|▁sp|ices|,|▁of|▁which|▁c|ay(enne)\n",
      "BASE vs FUTURES:\n",
      "    base      logit    ht      logit       t     logit\n",
      "0   enne  20.494986  enne  17.091803       ,  5.963952\n",
      "1  ennes  13.095238    ay  10.822346       .  5.683775\n",
      "2    ene  12.150144   ▁is  10.479982    ▁and  5.072532\n",
      "3   anne  11.372098    um   9.730822       ▁  4.693603\n",
      "4    enn  11.117987     a   9.677046       -  4.669097\n",
      "5     en   9.788445   ote   9.464359     ▁in  4.520266\n",
      "6     ne   9.032915  ▁and   9.252120  <0x0A>  4.415384\n",
      "7      -   8.518456     ,   9.161216     ▁of  4.399898\n",
      "8    een   8.199695     .   9.045592      ▁(  4.341356\n",
      "9      n   8.012783  amon   8.931628     ▁to  4.286759\n",
      "LOSS_HT: 0.030322182923555374\n",
      "LOSS_T: 12.239319801330566\n",
      "<s>|▁Pr|inceton|'|s|▁Word|Net|(|0|.|0|0|▁/|▁|0(▁votes)\n",
      "BASE vs FUTURES:\n",
      "       base      logit        ht      logit  t     logit\n",
      "0    ▁votes  26.362278    ▁votes  18.857811  0  8.393703\n",
      "1   ▁voters  18.318987     ▁vote  14.137239  1  7.406851\n",
      "2     ▁vote  17.123680    ▁Votes  12.709383  .  6.991696\n",
      "3   ▁voting  15.945398         0  12.030953  ,  6.787486\n",
      "4    ▁voted  15.507130    ▁voted  10.021310  2  6.774632\n",
      "5      ▁vot  14.797101  ▁ratings   9.064949  5  6.322122\n",
      "6    ▁Votes  14.541542         )   8.712883  3  6.261115\n",
      "7  ▁ratings  12.766026      vote   8.631129  4  6.163659\n",
      "8      vote  12.061988   ▁rating   7.793754  -  6.105463\n",
      "9        ▁v  12.039245   ▁voting   7.633617  9  6.030373\n",
      "LOSS_HT: 0.017529422417283058\n",
      "LOSS_T: 12.142498016357422\n",
      "<s>|▁In|▁spite|▁of|▁the|▁side|▁effects|▁discussed|▁above|,|▁le|c|ith|in|▁has|▁been|▁used|▁for|▁years|▁to|▁treat|▁a|▁number|▁of|▁diseases|▁including|:|▁(|7|)|▁|1|▁|▁Al|z(heimer)\n",
      "BASE vs FUTURES:\n",
      "     base      logit        ht      logit       t     logit\n",
      "0  heimer  18.982346    heimer  13.571813       ,  5.164254\n",
      "1    heim  15.241155       ine   8.138012       .  5.115744\n",
      "2     hem  11.317997      heim   7.752439    ▁and  4.334286\n",
      "3      he  11.241241         ▁   7.549911       ▁  4.186120\n",
      "4     iem   9.619236    ▁abuse   7.522547       -  4.107759\n",
      "5       ­   8.295248         ,   7.433541  <0x0A>  4.025543\n",
      "6    imer   8.233561       ias   7.310748     ▁of  3.809214\n",
      "7     her   8.005990      osis   7.192941     ▁in  3.776169\n",
      "8       .   7.971806  ▁disease   7.063903      ▁(  3.767641\n",
      "9     ham   7.485911       ▁to   6.933635     ▁to  3.569146\n",
      "LOSS_HT: 0.2354952096939087\n",
      "LOSS_T: 12.33509635925293\n",
      "<s>|▁C|ef|pod|ox|ime|▁is|▁an|▁oral|,|▁third|-|gener|ation|▁c|ep|hal|os|por|in|▁antib|i|otic|.|▁It|▁is|▁active|▁against|▁most|▁Gram|-|pos|itive|▁and|▁Gram|-|negative|▁organ|isms|.|▁Not|able|▁exceptions|▁include|▁P|se|ud|omon|as|▁aer(ugin)\n",
      "BASE vs FUTURES:\n",
      "    base      logit    ht      logit     t     logit\n",
      "0   ugin  21.852465  ugin  14.641742     .  4.603627\n",
      "1     og  15.722291  ▁aer  11.998550     ,  4.562419\n",
      "2     ug  14.971807    er  10.584324  ▁and  3.908187\n",
      "3   ogen  12.880161     .  10.399773   ▁of  3.832026\n",
      "4     ig  12.490062     ,  10.146489     -  3.586709\n",
      "5    gin  12.279133    es   9.631579   ▁in  3.504705\n",
      "6      u  12.235028    ob   9.547806   ▁to  3.483861\n",
      "7   igin  11.669842    ic   9.479141     ▁  3.395664\n",
      "8     ou  11.366109  ▁and   9.409467    ▁(  3.174887\n",
      "9  ugins  11.316910     a   9.335699  ▁for  3.123882\n",
      "LOSS_HT: 0.24056130647659302\n",
      "LOSS_T: 12.238716125488281\n",
      "<s>|▁celebr|ant|▁(|pl|ural|▁celebr|ants|)|▁|1|▁|▁A|▁person|▁who|▁off|ici|ates|▁at|▁a|▁religious|▁ceremony|,|▁especially|▁a|▁marriage|▁or|▁the|▁E(uchar)\n",
      "BASE vs FUTURES:\n",
      "     base      logit         ht      logit       t     logit\n",
      "0   uchar  19.861082      uchar  16.033653       .  4.968087\n",
      "1     uch  11.577105        ion  11.888482       -  4.171089\n",
      "2     lev  11.088056         at  10.658401       ,  4.096358\n",
      "3      uc  10.477900   ▁funeral  10.210421       ▁  4.017223\n",
      "4  ternal   9.335527        ial   9.956802  <0x0A>  3.674757\n",
      "5      uk   8.895060          u   9.864045       1  3.571160\n",
      "6      ur   8.818613         ic   9.533352    ▁and  3.364197\n",
      "7       .   8.770119        ▁of   9.461826       2  3.276031\n",
      "8    urch   8.710697  ▁ceremony   9.438564      ▁(  3.203153\n",
      "9     pic   8.243421         ▁e   9.298520       s  3.198308\n",
      "LOSS_HT: 0.0764482319355011\n",
      "LOSS_T: 12.018959045410156\n"
     ]
    }
   ],
   "source": [
    "# For each of the largest loss gaps, print the context, the top-10 tokens from the three\n",
    "# logit sources, and the cross-entropy of each future head against softmax(base).\n",
    "ce_loss = nn.CrossEntropyLoss()\n",
    "for ind, val in zip(topk_ind, topk_val):  # val: the loss gap (t - ht) at this position\n",
    "    row_idx, pos = int(ind[0]), int(ind[1])\n",
    "    row = get_row(test, ids[row_idx])\n",
    "    seq_len = row['input_ids'].shape[0]  # actual row length instead of a hard-coded 64\n",
    "    input_ids = row['input_ids'][:pos + 2] # loss at seq idx n corresponds to forward pass at idx n+1\n",
    "    input_ids = input_ids.unsqueeze(0) # add batch dim\n",
    "    # Move/allocate the inputs once; each model still receives its own dict.\n",
    "    model_inputs = {'input_ids': input_ids.to('cuda'), 'attention_mask': torch.ones(input_ids.shape).to('cuda')}\n",
    "    out = model_d['ht'](dict(model_inputs))\n",
    "    base_prev = out.logits[0, pos,:]  # kept for ad-hoc comparison; not printed below\n",
    "    base = out.logits[0, pos + 1,:]\n",
    "    future = out.future_logits[0, pos,:]\n",
    "    out_t = model_d['t'](dict(model_inputs))\n",
    "    future_t = out_t.future_logits[0, pos,:]\n",
    "    out_str = '|'.join(Token[i] for i in input_ids.cpu().flatten().numpy())\n",
    "    # Append the token that follows the context, in parentheses, when the row has one.\n",
    "    if pos + 2 < seq_len:\n",
    "        out_str += '(' + Token[row['input_ids'][pos + 2].item()] + ')'\n",
    "    print(out_str)\n",
    "    print('BASE vs FUTURES:')\n",
    "    print(pd.concat([\n",
    "        topk(base.cpu().numpy(), k=10).rename(columns={'token': 'base'}),\n",
    "        topk(future.cpu().detach().numpy(), k=10).rename(columns={'token': 'ht'}),\n",
    "        topk(future_t.cpu().detach().numpy(), k=10).rename(columns={'token': 't'})\n",
    "    ], axis=1))\n",
    "    # Soft-target cross-entropy: the target distribution is softmax(base).\n",
    "    base_probs = torch.softmax(base, dim=0)\n",
    "    print('LOSS_HT:', ce_loss(future, base_probs).item())\n",
    "    print('LOSS_T:', ce_loss(future_t, base_probs).item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "66aaa059-b3ad-4c82-8532-0691fc5b45ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_tokens(s):\n",
    "    \"\"\"Print the tokenization of `s` as token strings joined by '|'.\"\"\"\n",
    "    token_ids = tokenizer(s)['input_ids']\n",
    "    pieces = [Token[token_id] for token_id in token_ids]\n",
    "    print('|'.join(pieces))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "cb16e141-9bcf-463b-b885-fe999b38e50a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>|▁He|▁is|▁suffering|▁from|▁Al|z\n"
     ]
    }
   ],
   "source": [
    "# Show how the probe prompt is split into tokens.\n",
    "print_tokens('He is suffering from Alz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "20efad8f-4e28-4c17-adb0-80da6cc6b33d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the probe prompt as PyTorch tensors on the GPU.\n",
    "# NOTE(review): `input` shadows the Python builtin of the same name.\n",
    "prompt = \"He is suffering from Alz\"\n",
    "input = tokenizer(prompt, return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "309d6d1c-8539-4159-b21d-b24bad7e789f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward the same tokenized prompt through both models.\n",
    "out_ht = model_d['ht'](input)\n",
    "out_t = model_d['t'](input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "8f302957-10c4-461f-8bc9-6beed09d6ee6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>18.213142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>14.982849</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>he</td>\n",
       "      <td>11.628552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>hem</td>\n",
       "      <td>11.411487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>iem</td>\n",
       "      <td>9.082198</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>her</td>\n",
       "      <td>8.648829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>imer</td>\n",
       "      <td>8.392187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ham</td>\n",
       "      <td>7.791391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ah</td>\n",
       "      <td>7.698751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>him</td>\n",
       "      <td>7.618369</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    token      logit\n",
       "0  heimer  18.213142\n",
       "1    heim  14.982849\n",
       "2      he  11.628552\n",
       "3     hem  11.411487\n",
       "4     iem   9.082198\n",
       "5     her   8.648829\n",
       "6    imer   8.392187\n",
       "7     ham   7.791391\n",
       "8      ah   7.698751\n",
       "9     him   7.618369"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-10 entries of the 'ht' model's `logits` output at the last prompt position.\n",
    "topk(out_ht.logits.cpu().numpy()[0,-1,:], k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "74c0a930-88ee-474c-9763-a25fe4d81bc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>15.253980</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>10.111232</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>,</td>\n",
       "      <td>8.829756</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁and</td>\n",
       "      <td>8.012557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>.</td>\n",
       "      <td>7.978244</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ar</td>\n",
       "      <td>7.766553</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>ine</td>\n",
       "      <td>7.724874</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>are</td>\n",
       "      <td>7.595370</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁D</td>\n",
       "      <td>7.429148</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ur</td>\n",
       "      <td>7.425232</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    token      logit\n",
       "0  heimer  15.253980\n",
       "1    heim  10.111232\n",
       "2       ,   8.829756\n",
       "3    ▁and   8.012557\n",
       "4       .   7.978244\n",
       "5      ar   7.766553\n",
       "6     ine   7.724874\n",
       "7     are   7.595370\n",
       "8      ▁D   7.429148\n",
       "9      ur   7.425232"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-10 `future_logits` of the 'ht' model at the second-to-last position (index -2).\n",
    "topk(out_ht.future_logits.cpu().numpy()[0,-2,:], k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "f1ee0bc3-bbae-4fa8-a969-0ac938ed04b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>,</td>\n",
       "      <td>5.164260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>.</td>\n",
       "      <td>5.115747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁and</td>\n",
       "      <td>4.334296</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁</td>\n",
       "      <td>4.186111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-</td>\n",
       "      <td>4.107753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>&lt;0x0A&gt;</td>\n",
       "      <td>4.025540</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁of</td>\n",
       "      <td>3.809211</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁in</td>\n",
       "      <td>3.776166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁(</td>\n",
       "      <td>3.767645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>▁to</td>\n",
       "      <td>3.569150</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    token     logit\n",
       "0       ,  5.164260\n",
       "1       .  5.115747\n",
       "2    ▁and  4.334296\n",
       "3       ▁  4.186111\n",
       "4       -  4.107753\n",
       "5  <0x0A>  4.025540\n",
       "6     ▁of  3.809211\n",
       "7     ▁in  3.776166\n",
       "8      ▁(  3.767645\n",
       "9     ▁to  3.569150"
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-10 `future_logits` of the 't' model at the second-to-last position (index -2).\n",
    "topk(out_t.future_logits.cpu().numpy()[0,-2,:], k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "681f5480-a1a4-4677-bef5-45a85a1102a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight matrix of the first layer of the 'ht' model's future_neck, as a NumPy array.\n",
    "A = model_d['ht'].future_neck.layers[0].weight.data.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "80b60d4b-a027-4b18-ba6d-b8068c3c95b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shortcut to the `base_model` attribute of the 'ht' model.\n",
    "base_model = model_d['ht'].base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "e721e906-ffcc-4233-8b0e-02f0e8f3284d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the tokenized prompt through base_model directly, requesting all hidden states.\n",
    "# NOTE(review): this rebinds `out`, which the loss-gap comparison loop also assigns.\n",
    "out = base_model(**input, output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "acd55244-352c-4a77-a570-0ccba6519f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column vector: hidden_states[32] at the second-to-last prompt position (index -2).\n",
    "h = out.hidden_states[32][0,-2,:].reshape((-1, 1)).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "4c5ff815-3090-430b-a562-92a5d82ca116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column vector: hidden_states[0] at the last prompt position (index -1).\n",
    "# presumably hidden_states[0] is the token-embedding output -- TODO confirm\n",
    "t = out.hidden_states[0][0,-1,:].reshape((-1, 1)).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "020a3559-bff7-4a25-994c-d828a781b5e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed_tokens weight matrix of the base model, as a NumPy array.\n",
    "E = base_model.model.embed_tokens.weight.data.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "be07b4fb-728e-4b6a-8182-f620b5c6bee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight of the base model's `lm_head` (the unembedding matrix) as a NumPy array.\n",
    "D = base_model.lm_head.weight.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "aba2ac89-24c3-4739-af1b-77bf3a91c9c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Moore-Penrose pseudo-inverse of the lm_head weight D.\n",
    "# Computed once here and reused by the decomposition cells below.\n",
    "Dpinv = np.linalg.pinv(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "c617b8f5-9765-4383-88de-219c44f01aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack h on top of t, map through the neck weight A, then project with D.\n",
    "pred = D @ (A @ np.concatenate([h, t], axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "591bf895-a6bc-447c-87e1-ca3142d17e75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [1, 976, 28764, 24556], 'attention_mask': [1, 1, 1, 1]}"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the token ids of 'Alzheimer'; the last id is used as a row index into D below.\n",
    "tokenizer('Alzheimer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "44e2f641-c83d-4f44-aae6-0a294ffa0677",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row 24556 of D (the last id in the stored tokenizer('Alzheimer') output) composed\n",
    "# with A, kept 2-D: one weight per dimension of the stacked [h; t] neck input.\n",
    "heimer = D[24556:24557, :] @ A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "id": "153d217b-0056-4950-9ba6-e2a8bda335d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[15.601598]], dtype=float32)"
      ]
     },
     "execution_count": 189,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the `heimer` row vector to the stacked [h; t] column.\n",
    "heimer @ np.concatenate([h, t], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "d1c31b45-f840-4694-bc6a-f72219ac19b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every row of E against the part of `heimer` from column 4096 onward.\n",
    "t_score = heimer[:, 4096:] @ E.transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "id": "4e96e3fb-5a5b-46e7-ab6f-0b07d0f8aebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every row of D against the first 4096 columns of `heimer`.\n",
    "h_score = heimer[:, :4096] @ D.transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "id": "1f85c260-dec3-40ce-846b-36ffbaaa866a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>0.104696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>0.034372</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>▁/******/</td>\n",
       "      <td>0.034076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>BPACK</td>\n",
       "      <td>0.024642</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>pgfscope</td>\n",
       "      <td>0.023644</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>丶</td>\n",
       "      <td>0.023437</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>qpoint</td>\n",
       "      <td>0.022994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>dhd</td>\n",
       "      <td>0.022988</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>:%.*</td>\n",
       "      <td>0.022927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ҽ</td>\n",
       "      <td>0.022492</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       token     logit\n",
       "0     heimer  0.104696\n",
       "1       heim  0.034372\n",
       "2  ▁/******/  0.034076\n",
       "3      BPACK  0.024642\n",
       "4   pgfscope  0.023644\n",
       "5          丶  0.023437\n",
       "6     qpoint  0.022994\n",
       "7        dhd  0.022988\n",
       "8       :%.*  0.022927\n",
       "9          ҽ  0.022492"
      ]
     },
     "execution_count": 195,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten highest entries of h_score, flattened to 1-D first.\n",
    "topk(\n",
    "    h_score.flatten(),\n",
    "    k=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "id": "cc1c760e-de9f-49ef-ae65-d82164f8e8eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.18687496"
      ]
     },
     "execution_count": 187,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Entry of h_score for token id 24556.\n",
    "# NOTE(review): the stored output (0.18687496) equals the `heimer[:,:4096] @ Dpinv` value\n",
    "# shown further down, not the D.T-based score (0.104696) - it looks stale; re-run to confirm.\n",
    "h_score[0][24556]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "id": "d4b1a595-1e5c-4bcc-860e-3c0002ed6f38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[15.427982]], dtype=float32)"
      ]
     },
     "execution_count": 190,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contribution of the h half: the first 4096 columns of `heimer` applied to h.\n",
    "heimer[:, :4096] @ h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "id": "323ed76c-a035-4291-911b-cd95839eeb9e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>0.186875</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>0.032063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>iffs</td>\n",
       "      <td>0.019978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁que</td>\n",
       "      <td>0.019484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>▁S</td>\n",
       "      <td>0.019187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>&lt;0x0A&gt;</td>\n",
       "      <td>0.018621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁vector</td>\n",
       "      <td>0.017973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ً</td>\n",
       "      <td>0.017171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ativity</td>\n",
       "      <td>0.017056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ister</td>\n",
       "      <td>0.016837</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>▁​</td>\n",
       "      <td>0.016629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>▁Ron</td>\n",
       "      <td>0.016616</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>&lt;0x8F&gt;</td>\n",
       "      <td>0.016589</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>ˈ</td>\n",
       "      <td>0.016547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>▁Frank</td>\n",
       "      <td>0.016363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>cam</td>\n",
       "      <td>0.016220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>▁Hur</td>\n",
       "      <td>0.016144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>σ</td>\n",
       "      <td>0.016063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>ead</td>\n",
       "      <td>0.016021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>sen</td>\n",
       "      <td>0.016008</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      token     logit\n",
       "0    heimer  0.186875\n",
       "1      heim  0.032063\n",
       "2      iffs  0.019978\n",
       "3      ▁que  0.019484\n",
       "4        ▁S  0.019187\n",
       "5    <0x0A>  0.018621\n",
       "6   ▁vector  0.017973\n",
       "7         ً  0.017171\n",
       "8   ativity  0.017056\n",
       "9     ister  0.016837\n",
       "10       ▁​  0.016629\n",
       "11     ▁Ron  0.016616\n",
       "12   <0x8F>  0.016589\n",
       "13        ˈ  0.016547\n",
       "14   ▁Frank  0.016363\n",
       "15      cam  0.016220\n",
       "16     ▁Hur  0.016144\n",
       "17        σ  0.016063\n",
       "18      ead  0.016021\n",
       "19      sen  0.016008"
      ]
     },
     "execution_count": 180,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Twenty highest entries of h_score.\n",
    "# NOTE(review): the stored table matches the Dpinv-based score rather than\n",
    "# `heimer[:,:4096] @ D.T` - likely stale output from an earlier definition.\n",
    "topk(\n",
    "    h_score,\n",
    "    k=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "489ee48f-7a73-485c-aaad-2a167e46eba5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>15.601603</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>10.373470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>,</td>\n",
       "      <td>8.233823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ine</td>\n",
       "      <td>7.715692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ar</td>\n",
       "      <td>7.697968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>are</td>\n",
       "      <td>7.584651</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁and</td>\n",
       "      <td>7.517624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>.</td>\n",
       "      <td>7.445228</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ur</td>\n",
       "      <td>7.419033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>hen</td>\n",
       "      <td>7.400190</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>▁syndrome</td>\n",
       "      <td>7.302213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>▁Park</td>\n",
       "      <td>7.276307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>▁D</td>\n",
       "      <td>7.243225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>osis</td>\n",
       "      <td>7.121342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>ore</td>\n",
       "      <td>7.033943</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>pec</td>\n",
       "      <td>7.031528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>z</td>\n",
       "      <td>6.996535</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>ism</td>\n",
       "      <td>6.949256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>ra</td>\n",
       "      <td>6.900755</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>▁in</td>\n",
       "      <td>6.803374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>ia</td>\n",
       "      <td>6.757267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>ian</td>\n",
       "      <td>6.674518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>yn</td>\n",
       "      <td>6.654577</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>▁disease</td>\n",
       "      <td>6.550241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>ins</td>\n",
       "      <td>6.447694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>im</td>\n",
       "      <td>6.447405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>ist</td>\n",
       "      <td>6.329059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>▁de</td>\n",
       "      <td>6.311876</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>port</td>\n",
       "      <td>6.271644</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>ve</td>\n",
       "      <td>6.265914</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>'</td>\n",
       "      <td>6.241383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>aria</td>\n",
       "      <td>6.218432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>an</td>\n",
       "      <td>6.198898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>oid</td>\n",
       "      <td>6.152793</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>my</td>\n",
       "      <td>6.131958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>ic</td>\n",
       "      <td>6.126956</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>▁Al</td>\n",
       "      <td>6.058192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>▁Sy</td>\n",
       "      <td>6.049191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>ment</td>\n",
       "      <td>6.031973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>▁with</td>\n",
       "      <td>6.013769</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        token      logit\n",
       "0      heimer  15.601603\n",
       "1        heim  10.373470\n",
       "2           ,   8.233823\n",
       "3         ine   7.715692\n",
       "4          ar   7.697968\n",
       "5         are   7.584651\n",
       "6        ▁and   7.517624\n",
       "7           .   7.445228\n",
       "8          ur   7.419033\n",
       "9         hen   7.400190\n",
       "10  ▁syndrome   7.302213\n",
       "11      ▁Park   7.276307\n",
       "12         ▁D   7.243225\n",
       "13       osis   7.121342\n",
       "14        ore   7.033943\n",
       "15        pec   7.031528\n",
       "16          z   6.996535\n",
       "17        ism   6.949256\n",
       "18         ra   6.900755\n",
       "19        ▁in   6.803374\n",
       "20         ia   6.757267\n",
       "21        ian   6.674518\n",
       "22         yn   6.654577\n",
       "23   ▁disease   6.550241\n",
       "24        ins   6.447694\n",
       "25         im   6.447405\n",
       "26        ist   6.329059\n",
       "27        ▁de   6.311876\n",
       "28       port   6.271644\n",
       "29         ve   6.265914\n",
       "30          '   6.241383\n",
       "31       aria   6.218432\n",
       "32         an   6.198898\n",
       "33        oid   6.152793\n",
       "34         my   6.131958\n",
       "35         ic   6.126956\n",
       "36        ▁Al   6.058192\n",
       "37        ▁Sy   6.049191\n",
       "38       ment   6.031973\n",
       "39      ▁with   6.013769"
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Highest entries of `pred`, using topk's default k (the stored output has 40 rows).\n",
    "topk(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "id": "ce382ecb-3eac-4e8d-a032-642eef970067",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>z</td>\n",
       "      <td>31.079538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ugno</td>\n",
       "      <td>24.306639</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>gorith</td>\n",
       "      <td>23.458277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>cohol</td>\n",
       "      <td>22.437153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>typen</td>\n",
       "      <td>22.169466</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>nitt</td>\n",
       "      <td>20.925621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>zh</td>\n",
       "      <td>20.354841</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>umin</td>\n",
       "      <td>18.984793</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>▁laug</td>\n",
       "      <td>18.823332</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>kap</td>\n",
       "      <td>18.677513</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    token      logit\n",
       "0       z  31.079538\n",
       "1    ugno  24.306639\n",
       "2  gorith  23.458277\n",
       "3   cohol  22.437153\n",
       "4   typen  22.169466\n",
       "5    nitt  20.925621\n",
       "6      zh  20.354841\n",
       "7    umin  18.984793\n",
       "8   ▁laug  18.823332\n",
       "9     kap  18.677513"
      ]
     },
     "execution_count": 192,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten highest entries of Dpinv.T applied to h.\n",
    "topk(Dpinv.transpose() @ h, k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "id": "8e77dcba-ef23-4c8a-992f-e1955039e49a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>z</td>\n",
       "      <td>17.002439</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>cohol</td>\n",
       "      <td>12.438921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ze</td>\n",
       "      <td>11.413498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>zh</td>\n",
       "      <td>11.141702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>port</td>\n",
       "      <td>11.036744</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>ve</td>\n",
       "      <td>10.207431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>ien</td>\n",
       "      <td>9.540986</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>kap</td>\n",
       "      <td>9.415344</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>-</td>\n",
       "      <td>9.073519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>bin</td>\n",
       "      <td>8.962482</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   token      logit\n",
       "0      z  17.002439\n",
       "1  cohol  12.438921\n",
       "2     ze  11.413498\n",
       "3     zh  11.141702\n",
       "4   port  11.036744\n",
       "5     ve  10.207431\n",
       "6    ien   9.540986\n",
       "7    kap   9.415344\n",
       "8      -   9.073519\n",
       "9    bin   8.962482"
      ]
     },
     "execution_count": 200,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten highest entries of D applied to h (lm_head weight times the hidden-state column).\n",
    "topk(np.matmul(D, h), k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "id": "df1e51c8-98b1-4f98-8ad5-0f8b46a47861",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>0.186875</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>0.032063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>iffs</td>\n",
       "      <td>0.019978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>▁que</td>\n",
       "      <td>0.019484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>▁S</td>\n",
       "      <td>0.019187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>&lt;0x0A&gt;</td>\n",
       "      <td>0.018621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>▁vector</td>\n",
       "      <td>0.017973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>ً</td>\n",
       "      <td>0.017171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>ativity</td>\n",
       "      <td>0.017056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ister</td>\n",
       "      <td>0.016837</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     token     logit\n",
       "0   heimer  0.186875\n",
       "1     heim  0.032063\n",
       "2     iffs  0.019978\n",
       "3     ▁que  0.019484\n",
       "4       ▁S  0.019187\n",
       "5   <0x0A>  0.018621\n",
       "6  ▁vector  0.017973\n",
       "7        ً  0.017171\n",
       "8  ativity  0.017056\n",
       "9    ister  0.016837"
      ]
     },
     "execution_count": 231,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten highest entries of the h-half of `heimer` expressed through Dpinv.\n",
    "topk(np.matmul(heimer[:, :4096], Dpinv), k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04cdf7f3-df86-44b6-9a08-f361bd5485c1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "id": "32024306-a9de-4400-8286-15569bcd3c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise product of two per-token vectors: (heimer[:,:4096] @ Dpinv) and (D @ h).\n",
    "# Presumably a per-token decomposition of heimer[:,:4096] @ h (valid if Dpinv @ D is the\n",
    "# identity) - confirm.\n",
    "# NOTE(review): this rebinds h_score from the earlier 2-D score row to a 1-D array.\n",
    "h_score = (heimer[:, :4096] @ Dpinv).ravel() * (D @ h).ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "id": "759827e3-a9db-460d-ad89-596739ae63a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-column table: column 0 = (heimer[:,:4096] @ Dpinv).T, column 1 = D @ h.\n",
    "aux = np.hstack([(heimer[:, :4096] @ Dpinv).T, D @ h])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "id": "afcfff71-cd95-44d4-a621-45731d9c5e3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.441890125"
      ]
     },
     "execution_count": 221,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the product for token id 24556 from the arrays themselves instead of\n",
    "# hand-copied, rounded numbers, so the check stays correct when the notebook is re-run.\n",
    "aux[24556, 0] * aux[24556, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "id": "5308521b-836a-4093-96e5-9d1e174f3bed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>logit</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>heimer</td>\n",
       "      <td>1.441899</td>\n",
       "      <td>0.186875</td>\n",
       "      <td>7.715849</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>heim</td>\n",
       "      <td>0.246545</td>\n",
       "      <td>0.032063</td>\n",
       "      <td>7.689373</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>z</td>\n",
       "      <td>0.186114</td>\n",
       "      <td>0.010946</td>\n",
       "      <td>17.002439</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ag</td>\n",
       "      <td>0.109592</td>\n",
       "      <td>0.014028</td>\n",
       "      <td>7.812175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>k</td>\n",
       "      <td>0.098598</td>\n",
       "      <td>0.012084</td>\n",
       "      <td>8.159696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>雅</td>\n",
       "      <td>0.096432</td>\n",
       "      <td>-0.013314</td>\n",
       "      <td>-7.243120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>&lt;0x0A&gt;</td>\n",
       "      <td>0.095901</td>\n",
       "      <td>0.018621</td>\n",
       "      <td>5.150013</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>▁Marcel</td>\n",
       "      <td>0.094632</td>\n",
       "      <td>-0.019627</td>\n",
       "      <td>-4.821455</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>лек</td>\n",
       "      <td>0.090533</td>\n",
       "      <td>-0.010441</td>\n",
       "      <td>-8.671231</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ган</td>\n",
       "      <td>0.089963</td>\n",
       "      <td>-0.012672</td>\n",
       "      <td>-7.099144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>▁Al</td>\n",
       "      <td>0.089942</td>\n",
       "      <td>0.013748</td>\n",
       "      <td>6.542246</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>▁S</td>\n",
       "      <td>0.087304</td>\n",
       "      <td>0.019187</td>\n",
       "      <td>4.550056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>jsce</td>\n",
       "      <td>0.086272</td>\n",
       "      <td>-0.012469</td>\n",
       "      <td>-6.919111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>&lt;0xA0&gt;</td>\n",
       "      <td>0.086086</td>\n",
       "      <td>-0.013558</td>\n",
       "      <td>-6.349300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>c</td>\n",
       "      <td>0.082310</td>\n",
       "      <td>0.010584</td>\n",
       "      <td>7.777178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>▁listade</td>\n",
       "      <td>0.080657</td>\n",
       "      <td>-0.011159</td>\n",
       "      <td>-7.228253</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>▁обо</td>\n",
       "      <td>0.079915</td>\n",
       "      <td>-0.009594</td>\n",
       "      <td>-8.329355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>IsNull</td>\n",
       "      <td>0.079559</td>\n",
       "      <td>-0.012768</td>\n",
       "      <td>-6.230930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>sen</td>\n",
       "      <td>0.079439</td>\n",
       "      <td>0.016008</td>\n",
       "      <td>4.962504</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>ں</td>\n",
       "      <td>0.079127</td>\n",
       "      <td>-0.011647</td>\n",
       "      <td>-6.793554</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       token     logit         0          1\n",
       "0     heimer  1.441899  0.186875   7.715849\n",
       "1       heim  0.246545  0.032063   7.689373\n",
       "2          z  0.186114  0.010946  17.002439\n",
       "3         ag  0.109592  0.014028   7.812175\n",
       "4          k  0.098598  0.012084   8.159696\n",
       "5          雅  0.096432 -0.013314  -7.243120\n",
       "6     <0x0A>  0.095901  0.018621   5.150013\n",
       "7    ▁Marcel  0.094632 -0.019627  -4.821455\n",
       "8        лек  0.090533 -0.010441  -8.671231\n",
       "9        ган  0.089963 -0.012672  -7.099144\n",
       "10       ▁Al  0.089942  0.013748   6.542246\n",
       "11        ▁S  0.087304  0.019187   4.550056\n",
       "12      jsce  0.086272 -0.012469  -6.919111\n",
       "13    <0xA0>  0.086086 -0.013558  -6.349300\n",
       "14         c  0.082310  0.010584   7.777178\n",
       "15  ▁listade  0.080657 -0.011159  -7.228253\n",
       "16      ▁обо  0.079915 -0.009594  -8.329355\n",
       "17    IsNull  0.079559 -0.012768  -6.230930\n",
       "18       sen  0.079439  0.016008   4.962504\n",
       "19         ں  0.079127 -0.011647  -6.793554"
      ]
     },
     "execution_count": 227,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tabulate the k=20 top-ranked tokens for h_score, with the aux values shown alongside.\n",
    "# NOTE(review): topk is defined outside this cell -- confirm there what the unnamed aux columns (0, 1) mean.\n",
    "topk(h_score, k=20, aux=aux)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "id": "8e171c8e-580e-4d1f-b3e5-0ecd883f8eae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.00014942093, -6.4162025)"
      ]
     },
     "execution_count": 218,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at the first aux entry, unpacked into a plain tuple for display.\n",
    "(*aux[0],)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "id": "0077a8b6-44aa-44fe-a2e8-076ab8718e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "id": "8e38a6ad-2a67-4a80-bc45-656d8f7331a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fb2f1946050>]"
      ]
     },
     "execution_count": 225,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtEUlEQVR4nO3de3yU1YH/8e/MJDNJgFww5IbhJgpFJCCUNFKtlmhES+vu9lVWXKG04mphX2haK/ECpd011ipL16KsWmT72yqgL8VuoVgbjKwapVxSQRFFwGSBhJvJhCTkMnN+fyQZMiSBDJKcJPN5v17Pa2bOc57nOc9hMufLc5lxGGOMAAAALHHabgAAAAhvhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVkXYbkBn+P1+HTp0SAMGDJDD4bDdHAAA0AnGGFVVVSktLU1OZ8fHP3pFGDl06JDS09NtNwMAAJyH0tJSXXzxxR3O7xVhZMCAAZKadiY2NtZyawAAQGd4vV6lp6cHxvGO9Iow0nJqJjY2ljACAEAvc65LLLiAFQAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYFWv+KE8AADQNf6j4FN9UVOvH0wZrvSBMVbawJERAADC2EvbSvX8Owd07GSdtTYQRgAAgFWEEQAAYBVhBAAAWEUYAQAAVoUcRjZv3qzp06crLS1NDodD69at6/Sy77zzjiIiIjR+/PhQNwsAAPqokMNIdXW1MjIytHz58pCWq6io0KxZszR16tRQNwkAAPqwkL9nZNq0aZo2bVrIG7rrrrs0c+ZMuVyukI6mAACAvq1brhl5/vnntW/fPi1evLhT9evq6uT1eoMmAADQN3V5GPn000+1cOFC/fd//7ciIjp3ICY/P19xcXGBKT09vYtbCQAAbOnSMOLz+TRz5kwtWbJEl112WaeXy8vLU2VlZWAqLS3twlYCAACbuvS3aaqqqrR161bt2LFD8+fPlyT5/X4ZYxQREaE///nP+uY3v9lmOY/HI4/H05VNAwAAPUSXhpHY2Fjt3LkzqOypp57Spk2b9PLLL2v48OFduXkAANALhBxGTp48qb179wZe79+/X8XFxRo4cKCGDBmivLw8HTx4UL/73e/kdDo1duzYoOWTkpIUFRXVphwAAISnkMPI1q1bdd111wVe5+bmSpJmz56tVatW6fDhwyopKblwLQQAAH2awxhjbDfiXLxer+Li4lRZWanY2FjbzQEAoM+4+rFNKj1Rq1d/dJUmDEm4oOvu7PjNb9MAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArAo5jGzevFnTp09XWlqaHA6H1q1bd9b6r7zyiq6//noNGjRIsbGxysrK0uuvv36+7QUAAH1MyGGkurpaGRkZWr58eafqb968Wddff702bNigbdu26brrrtP06dO1Y8eOkBsLAAD6nohQF5g2bZqmTZvW6frLli0Lev3II4/otdde0//8z/9owoQJoW4eAAD0Md1+zYjf71dVVZUGDhzY3ZsGAAA9UMhHRr6sxx9/XCdPntT3vve9DuvU1dWprq4u8Nrr9XZH0wAAgAXdemTkhRde0JIlS7R27VolJSV1WC8/P19xcXGBKT09vRtbCQAAulO3hZHVq1frjjvu0Nq1a5WdnX3Wunl5eaqsrAxMpaWl3dRKAADQ3brlNM2LL76oH/zgB1q9
erVuvvnmc9b3eDzyeDzd0DIAAGBbyGHk5MmT2rt3b+D1/v37VVxcrIEDB2rIkCHKy8vTwYMH9bvf/U5S06mZ2bNn69e//rUyMzNVVlYmSYqOjlZcXNwF2g0AANBbhXyaZuvWrZowYULgttzc3FxNmDBBixYtkiQdPnxYJSUlgfrPPPOMGhsbNW/ePKWmpgamBQsWXKBdAAAAvVnIR0auvfZaGWM6nL9q1aqg14WFhaFuAgAAhBF+mwYAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAgjJ3lBtluQxgBAAByOBzWtk0YAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVoUcRjZv3qzp06crLS1NDodD69atO+cyhYWFuvLKK+XxeDRy5EitWrXqPJoKAAD6opDDSHV1tTIyMrR8+fJO1d+/f79uvvlmXXfddSouLtY999yjO+64Q6+//nrIjQUAAH1PRKgLTJs2TdOmTet0/RUrVmj48OF64oknJElf+cpX9Pbbb+vf//3flZOTE+rmAQBAH9Pl14wUFRUpOzs7qCwnJ0dFRUVdvWkAANALhHxkJFRlZWVKTk4OKktOTpbX61Vtba2io6PbLFNXV6e6urrAa6/X29XNBAAAlvTIu2ny8/MVFxcXmNLT0203CQAAdJEuDyMpKSkqLy8PKisvL1dsbGy7R0UkKS8vT5WVlYGptLS0q5sJAAAs6fLTNFlZWdqwYUNQ2RtvvKGsrKwOl/F4PPJ4PF3dNAAA0AOEfGTk5MmTKi4uVnFxsaSmW3eLi4tVUlIiqemoxqxZswL177rrLu3bt08//elP9fHHH+upp57S2rVrde+9916YPQAAAL1ayGFk69atmjBhgiZMmCBJys3N1YQJE7Ro0SJJ0uHDhwPBRJKGDx+u9evX64033lBGRoaeeOIJPffcc9zWCwAAJJ3HaZprr71WxpgO57f37arXXnutduzYEeqmAABAGOiRd9MAAIDwQRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAAhjxthuAWEEAABIcljcNmEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFh1XmFk+fLlGjZsmKKiopSZmaktW7actf6yZcs0atQoRUdHKz09Xffee69OnTp1Xg0GAAB9S8hhZM2aNcrNzdXixYu1fft2ZWRkKCcnR0eOHGm3/gsvvKCFCxdq8eLF2r17t377299qzZo1euCBB7504wEAQO8XchhZunSp5s6dqzlz5mjMmDFasWKFYmJitHLlynbrv/vuu5oyZYpmzpypYcOG6YYbbtCtt956zqMpAAAgPIQURurr67Vt2zZlZ2efXoHTqezsbBUVFbW7zFVXXaVt27YFwse+ffu0YcMG3XTTTR1up66uTl6vN2gCAAB9U0QolY8dOyafz6fk5OSg8uTkZH388cftLjNz5kwdO3ZMX//612WMUWNjo+66666znqbJz8/XkiVLQmkaAADopbr8bprCwkI98sgjeuqpp7R9+3a98sorWr9+vX7xi190uExeXp4qKysDU2lpaVc3EwAAWBLSkZHExES5XC6Vl5cHlZeXlyslJaXdZR5++GHdfvvtuuOOOyRJV1xxhaqrq3XnnXfqwQcflNPZNg95PB55PJ5QmgYAAHqpkI6MuN1uTZw4UQUFBYEyv9+v
goICZWVltbtMTU1Nm8DhcrkkSaYn/FQgAACwKqQjI5KUm5ur2bNna9KkSZo8ebKWLVum6upqzZkzR5I0a9YsDR48WPn5+ZKk6dOna+nSpZowYYIyMzO1d+9ePfzww5o+fXoglAAAgPAVchiZMWOGjh49qkWLFqmsrEzjx4/Xxo0bAxe1lpSUBB0Jeeihh+RwOPTQQw/p4MGDGjRokKZPn65/+7d/u3B7AQAAei2H6QXnSrxer+Li4lRZWanY2FjbzQEAoM+Y8ugmHayo1WvzpigjPf6Crruz4ze/TQMAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACw6rzCyPLlyzVs2DBFRUUpMzNTW7ZsOWv9iooKzZs3T6mpqfJ4PLrsssu0YcOG82owAADoWyJCXWDNmjXKzc3VihUrlJmZqWXLliknJ0d79uxRUlJSm/r19fW6/vrrlZSUpJdfflmDBw/W559/rvj4+AvRfgAA0MuFHEaWLl2quXPnas6cOZKkFStWaP369Vq5cqUWLlzYpv7KlSt14sQJvfvuu4qMjJQkDRs27Mu1GgAA9Bkhnaapr6/Xtm3blJ2dfXoFTqeys7NVVFTU7jJ/+MMflJWVpXnz5ik5OVljx47VI488Ip/P1+F26urq5PV6gyYAANA3hRRGjh07Jp/Pp+Tk5KDy5ORklZWVtbvMvn379PLLL8vn82nDhg16+OGH9cQTT+hf//VfO9xOfn6+4uLiAlN6enoozQQAAL1Il99N4/f7lZSUpGeeeUYTJ07UjBkz9OCDD2rFihUdLpOXl6fKysrAVFpa2tXNBAAAloR0zUhiYqJcLpfKy8uDysvLy5WSktLuMqmpqYqMjJTL5QqUfeUrX1FZWZnq6+vldrvbLOPxeOTxeEJpGgAA6KVCOjLidrs1ceJEFRQUBMr8fr8KCgqUlZXV7jJTpkzR3r175ff7A2WffPKJUlNT2w0iAAAgvIR8miY3N1fPPvus/uu//ku7d+/W3Xffrerq6sDdNbNmzVJeXl6g/t13360TJ05owYIF+uSTT7R+/Xo98sgjmjdv3oXbCwAA0GuFfGvvjBkzdPToUS1atEhlZWUaP368Nm7cGLiotaSkRE7n6YyTnp6u119/Xffee6/GjRunwYMHa8GCBbr//vsv3F4AAIBey2GMMbYbcS5er1dxcXGqrKxUbGys7eYAANBnTHl0kw5W1Oq1eVOUkR5/Qdfd2fGb36YBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgBwOe9smjAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwKrzCiPLly/XsGHDFBUVpczMTG3ZsqVTy61evVoOh0O33HLL+WwWAAD0QSGHkTVr1ig3N1eLFy/W9u3blZGRoZycHB05cuSsyx04cEA/+clPdPXVV593YwEAQN8TchhZunSp5s6dqzlz5mjMmDFasWKFYmJitHLlyg6X8fl8uu2227RkyRKNGDHiSzUYAAD0LSGFkfr6em3btk3Z2dmnV+B0Kjs7W0VFRR0u9/Of/1xJSUn64Q9/2Knt1NXVyev1Bk0AAKBvCimMHDt2TD6fT8nJyUHlycnJKisra3eZt99+W7/97W/17LPPdno7
+fn5iouLC0zp6emhNBMAAPQiXXo3TVVVlW6//XY9++yzSkxM7PRyeXl5qqysDEylpaVd2EoAAGBTRCiVExMT5XK5VF5eHlReXl6ulJSUNvU/++wzHThwQNOnTw+U+f3+pg1HRGjPnj265JJL2izn8Xjk8XhCaRoAAOilQjoy4na7NXHiRBUUFATK/H6/CgoKlJWV1ab+6NGjtXPnThUXFwemb3/727ruuutUXFzM6RcAABDakRFJys3N1ezZszVp0iRNnjxZy5YtU3V1tebMmSNJmjVrlgYPHqz8/HxFRUVp7NixQcvHx8dLUptyAAAQnkIOIzNmzNDRo0e1aNEilZWVafz48dq4cWPgotaSkhI5nXyxKwAA6ByHMcbYbsS5eL1excXFqbKyUrGxsbabAwBAnzHl0U06WFGrP8yfonEXx1/QdXd2/OYQBgAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgBAGDPG2G4CYQQAAEgOOaxtmzACAACsIowAAACrCCMAAMAqwggAAGHM/uWrhBEAACDJYe/61fMLI8uXL9ewYcMUFRWlzMxMbdmypcO6zz77rK6++molJCQoISFB2dnZZ60PAADCS8hhZM2aNcrNzdXixYu1fft2ZWRkKCcnR0eOHGm3fmFhoW699Va9+eabKioqUnp6um644QYdPHjwSzceAAD0fiGHkaVLl2ru3LmaM2eOxowZoxUrVigmJkYrV65st/7vf/97/ehHP9L48eM1evRoPffcc/L7/SooKPjSjQcAAF9OD/jOs9DCSH19vbZt26bs7OzTK3A6lZ2draKiok6to6amRg0NDRo4cGCHderq6uT1eoMmAADQN4UURo4dOyafz6fk5OSg8uTkZJWVlXVqHffff7/S0tKCAs2Z8vPzFRcXF5jS09NDaSYAAOhFuvVumkcffVSrV6/Wq6++qqioqA7r5eXlqbKyMjCVlpZ2YysBAEB3igilcmJiolwul8rLy4PKy8vLlZKSctZlH3/8cT366KP6y1/+onHjxp21rsfjkcfjCaVpAADgPJge8E0jIR0ZcbvdmjhxYtDFpy0Xo2ZlZXW43GOPPaZf/OIX2rhxoyZNmnT+rQUAAF3C5veMhHRkRJJyc3M1e/ZsTZo0SZMnT9ayZctUXV2tOXPmSJJmzZqlwYMHKz8/X5L0y1/+UosWLdILL7ygYcOGBa4t6d+/v/r3738BdwUAAPRGIYeRGTNm6OjRo1q0aJHKyso0fvx4bdy4MXBRa0lJiZzO0wdcnn76adXX1+u73/1u0HoWL16sn/3sZ1+u9QAAoNcLOYxI0vz58zV//vx25xUWFga9PnDgwPlsAgAAdINe9z0jAACgb3LI3kUjhBEAAGAVYQQAAFhFGAEAIIz1gEtGCCMAAMDu94wQRgAAgFWEEQAAYBVhBAAAWEUYAQAgjPGlZwAAoEfgAlYAABC2CCMAAMAqwggAAGHMNF80wm/TAAAAK+ob/ZIkd4S9SEAYAQAgjNX5CCMAAMASY8zpIyMuwggAAOhmDb7TXzLCkREAANDt6ptP0UgcGQEAABbU1vsCzzkyAgAAut3RqjpJUmJ/j1xObu0FAADd7OjJpjAyaIDHajsIIwAAhKnDFbWSpMT+bqvtIIwAABCm3vnsuCRp3MVxVttBGAEAIAydrGtU4cdHJEnXjUqy2hbCCAAAYeiZzftUVdeo4Yn9dOWQBKttIYwAABBmXis+qKfe3CtJyr3+Mjkt3kkjSRFWtw4AALrN7sNe/UfBp/rTrjJJ0rfGpWp6RprlVhFGAADos6pONWj34SrtKPlC64oPafdhryTJ6ZDmXTdSC6ZearmF
TQgjAAD0AUeqTunDQ159dMirDw9V6sNDXn1+vCaoTqTLoevHJOtfvnmpvpIaa6mlbRFGAADo4Rp9fh2pqtPhylMqqzylw5W1TY/eUzr4Ra1KTtToRHV9u8umxkVpTGqsvjFqkKaPS1NCP7vfKdIewggAAN3oVINPVacaVXWqofmxURW19fqipkEV1c2PNfU6Xl2vo1V1OlJVpxPVdfKbs6/X4ZBGJPbT5WlxujwtVmPSYjUmNVYX9bf77aqdQRgBAKADfr9RbYOvaar3qabep5r6xuZHn6rrGlVd36iaOl/TY3NZ63ne2uDg0fqXckMR4XQoOTZKqXFRSolreYxWalyUhl4Uo+GJ/RTj7p3D+nm1evny5frVr36lsrIyZWRk6Mknn9TkyZM7rP/SSy/p4Ycf1oEDB3TppZfql7/8pW666abzbjQAoO/y+43qff6mqbFpamh+XtfYVN7QGDw/8NwXXL++0a+6dsqa6jVvp9Gn2ga/6hp8OtXg06kGv041NoWPusbzCw6d0d8ToQFRTVN8jFvx0ZFKiHEroZ9bCTGRSujn1qABHiUN8GjQAI8S+3ms34LbVUIOI2vWrFFubq5WrFihzMxMLVu2TDk5OdqzZ4+Sktp+g9u7776rW2+9Vfn5+frWt76lF154Qbfccou2b9+usWPHXpCdAIBwZYyR30iNfr98fqNGv5HP1/zoN8HlfqNGX1NZ69ft1vMb+fz+VvODyxtayn1N62r0GzX4/IH1N/ia5jX4jOpaBwFfB+GiVVBoPNf5CEuiIp2KcUcoOtKl/p4IRbubHmPcLvVr9djPHaF+Hpdi3E1lA6IiFBsdqQFREc0BJFL9PRFWfyW3p3EYY0L6V8/MzNRXv/pV/eY3v5Ek+f1+paen61/+5V+0cOHCNvVnzJih6upq/fGPfwyUfe1rX9P48eO1YsWKTm3T6/UqLi5OlZWVio3tOVf/AmhijJExkt8YGTU/GgXK/M0Dplq9bl3vzEdzRr2WAbe9+n4j+fxGxjQNjv5WyzaVq7ncBNUPvPZLPtOyfKv2+o18Rm3X6zfytXpsWcbnN4H1tn5s9J9el695wD8zHHQcBprLm8NFY/Pgf2a9vs7tcirS5ZA7whmYIl1OuV1OeVq/jmgqC3psfh7ZqsxzRv3ICKeiI12KjnQpKtIpT0TTY7S7qSza7VJUhKvPHpXoSp0dv0M6MlJfX69t27YpLy8vUOZ0OpWdna2ioqJ2lykqKlJubm5QWU5OjtatW9fhdurq6lRXVxd47fV6Q2km0CFjzOn/0bUaGHyBgSV4UGn5wPf7m/7n6W8efBr9/sAg1jTAKGg9/k6vt3mwOqOO74zBrdF3er3t1z890PnN6cG7ZZD2txq4/a0GYmOa9qFlUDameTsmOAC0DOqtB/QzAwJ6nginQxEuhyKcTrmcDkU4HacfOyp3tip3dVDe8rp5HREuhyJdzqbtOR2KcDWXOc8ICoFA4JDb5WoVLBxNoaK5LCh4uJxyOAgBfV1IYeTYsWPy+XxKTk4OKk9OTtbHH3/c7jJlZWXt1i8rK+twO/n5+VqyZEkoTUM3aRnMW87ZNvhaH3Y1gdcNZ7xuOSzb4As+fNvQfBi4dXlj8/nixuZDvg1+o4bGlkPBpw8Dt8zztVpHY+vD0L7T/5NsaDVwo+dwOiSHwxH8qKZHp8MhNT+2vHa0V8/pkNPRNDg6HJLL0fS6qVzN5Q65WtbVqtzZUjeonkNOpwLz2q5XgfKWx9PPm+o5WpW3DN5OR9Pg7XQ4FOlyyOV0tg0BHZV3Jhy4gstb+hToDXrkZbd5eXlBR1O8Xq/S09MttqjnMqbpfGzLFd6nr/b2qbahUbX1/qbyhtPldQ1NF2WdOuNirZbndY1NF3QFXzxmAud7+6LWg1OEs2nAaj2ItB5YXG0GoA7mBS2v5nJn06DYakAJ2qajaVBxOdqZd5b2BLbV8rrVoN2yHscZz08PqC11Hc0DaXvz2i7rPGM7bcLCmeVq
Ww8ApBDDSGJiolwul8rLy4PKy8vLlZKS0u4yKSkpIdWXJI/HI4+n598Xfb78fqOqukZV1jTIe6p5qm2Qt7ZRlbUNOlnXGLglrLrudLCoqfcFh43mkGHzP/tOh4LO30a6nIqMcASet8yLbD6M2/I8orl+yyFdtyv40G5ky/PWh3+bHyNdTf8bjGw5zNy8TITLEVQvcAi5+Xlguea6rf+HycAIAPaEFEbcbrcmTpyogoIC3XLLLZKaLmAtKCjQ/Pnz210mKytLBQUFuueeewJlb7zxhrKyss670T2NMUbe2kYdqqzV4cpaHfHW6Xh1vY6frNcXNfU6UV2vipp6VdQ2qKI5gHTFOXZ3hFMxbpdimi+4inFHND82TdGREYp2OxUV4VJU84VaUZEueSJd8kQ0PY9qfmx9vvbMC8QiXacvAONqcADAlxXyaZrc3FzNnj1bkyZN0uTJk7Vs2TJVV1drzpw5kqRZs2Zp8ODBys/PlyQtWLBA3/jGN/TEE0/o5ptv1urVq7V161Y988wzF3ZPulijz6/PT9To0/IqfXa0WgeOVetw5ammAFJxSrUNvpDXGR3ZdMtX/6gIxUVHBqYBURHq54lQf3eEYppvF4tpvqq7KWA4FR3Zqrx5XoTL2QV7DgBA1wo5jMyYMUNHjx7VokWLVFZWpvHjx2vjxo2Bi1RLSkrkdJ4eFK+66iq98MILeuihh/TAAw/o0ksv1bp163r8d4wcP1mn9/ad0Pv7j2vrgS/02dGT5/zym4H93EqNi1JybJQu6ufWwP5uXdTP3fQlNjFuxcc0B47mR0+Eq5v2BgCAnivk7xmxobu+Z8TnN/rzh2X6f+99rqJ9x9ucSomOdOnS5P4aOai/hiX20+D4aKXGRyktLlopcVGKiiRcAADQoku+Z6Qv++iQV7lri/VxWVWgbHTKAGUOH6jMERfp8rRYpSfE8KU3AABcYIQRSdtLvtD3V26R91SjBkRF6J++NlS3ZQ7RxQkxtpsGAECfF/ZhpLi0Qv/03Puqqfdp0tAEPTtrkhL6uW03CwCAsBH2YSTvlZ2qqffp6yMT9cysib3255cBAOitwvpe0NITNdp92CuX06HfzJxAEAEAwIKwDiN/PXBCkpRxcZziYzg1AwCADWEdRkpP1EqSLkseYLklAACEr7AOI2XeU5Kk5Ngoyy0BACB8hXUYqalvlCQNiOJaEQAAbAnrMOJr/rlbfuwNAAB7CCOSIggjAABYQxiR+Ip3AAAsIoxIcjkIIwAA2BLeYcRwzQgAALaFdxjhAlYAAKwjjIgwAgCATYQREUYAALCJMCIuYAUAwKbwDiNcwAoAgHVh/T3o3514sbJGXKQRg/rZbgoAAGErrMPIbZlDbTcBAICwF9anaQAAgH2EEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFW94ld7jTGSJK/Xa7klAACgs1rG7ZZxvCO9IoxUVVVJktLT0y23BAAAhKqqqkpxcXEdzneYc8WVHsDv9+vQoUMaMGCAHA7HBVuv1+tVenq6SktLFRsbe8HW2xfQN+2jXzpG37SPfukYfdO+vtQvxhhVVVUpLS1NTmfHV4b0iiMjTqdTF198cZetPzY2ttf/g3cV+qZ99EvH6Jv20S8do2/a11f65WxHRFpwASsAALCKMAIAAKwK6zDi8Xi0ePFieTwe203pceib9tEvHaNv2ke/dIy+aV849kuvuIAVAAD0XWF9ZAQAANhHGAEAAFYRRgAAgFWEEQAAYFVYh5Hly5dr2LBhioqKUmZmprZs2WK7SRfMz372MzkcjqBp9OjRgfmnTp3SvHnzdNFFF6l///76h3/4B5WXlweto6SkRDfffLNiYmKUlJSk++67T42NjUF1CgsLdeWVV8rj8WjkyJFatWpVd+xeSDZv3qzp06crLS1NDodD69atC5pvjNGiRYuUmpqq6Oho
ZWdn69NPPw2qc+LECd12222KjY1VfHy8fvjDH+rkyZNBdT744ANdffXVioqKUnp6uh577LE2bXnppZc0evRoRUVF6YorrtCGDRsu+P521rn65fvf/36b99CNN94YVKcv9kt+fr6++tWvasCAAUpKStItt9yiPXv2BNXpzr+fnvQ51Zm+ufbaa9u8b+66666gOn2tb55++mmNGzcu8CVlWVlZ+tOf/hSYH67vl5CYMLV69WrjdrvNypUrzYcffmjmzp1r4uPjTXl5ue2mXRCLFy82l19+uTl8+HBgOnr0aGD+XXfdZdLT001BQYHZunWr+drXvmauuuqqwPzGxkYzduxYk52dbXbs2GE2bNhgEhMTTV5eXqDOvn37TExMjMnNzTUfffSRefLJJ43L5TIbN27s1n09lw0bNpgHH3zQvPLKK0aSefXVV4PmP/rooyYuLs6sW7fO/O1vfzPf/va3zfDhw01tbW2gzo033mgyMjLMe++9Z/73f//XjBw50tx6662B+ZWVlSY5OdncdtttZteuXebFF1800dHR5j//8z8Ddd555x3jcrnMY489Zj766CPz0EMPmcjISLNz584u74P2nKtfZs+ebW688cag99CJEyeC6vTFfsnJyTHPP/+82bVrlykuLjY33XSTGTJkiDl58mSgTnf9/fS0z6nO9M03vvENM3fu3KD3TWVlZWB+X+ybP/zhD2b9+vXmk08+MXv27DEPPPCAiYyMNLt27TLGhO/7JRRhG0YmT55s5s2bF3jt8/lMWlqayc/Pt9iqC2fx4sUmIyOj3XkVFRUmMjLSvPTSS4Gy3bt3G0mmqKjIGNM0UDmdTlNWVhao8/TTT5vY2FhTV1dnjDHmpz/9qbn88suD1j1jxgyTk5Nzgffmwjlz0PX7/SYlJcX86le/CpRVVFQYj8djXnzxRWOMMR999JGRZP76178G6vzpT38yDofDHDx40BhjzFNPPWUSEhICfWOMMffff78ZNWpU4PX3vvc9c/PNNwe1JzMz0/zzP//zBd3H89FRGPnOd77T4TLh0C/GGHPkyBEjybz11lvGmO79++npn1Nn9o0xTWFkwYIFHS4TLn2TkJBgnnvuOd4vnRSWp2nq6+u1bds2ZWdnB8qcTqeys7NVVFRksWUX1qeffqq0tDSNGDFCt912m0pKSiRJ27ZtU0NDQ9D+jx49WkOGDAnsf1FRka644golJycH6uTk5Mjr9erDDz8M1Gm9jpY6vakP9+/fr7KysqD9iIuLU2ZmZlBfxMfHa9KkSYE62dnZcjqdev/99wN1rrnmGrnd7kCdnJwc7dmzR1988UWgTm/rr8LCQiUlJWnUqFG6++67dfz48cC8cOmXyspKSdLAgQMldd/fT2/4nDqzb1r8/ve/V2JiosaOHau8vDzV1NQE5vX1vvH5fFq9erWqq6uVlZXF+6WTesUP5V1ox44dk8/nC/qHl6Tk5GR9/PHHllp1YWVmZmrVqlUaNWqUDh8+rCVLlujqq6/Wrl27VFZWJrfbrfj4+KBlkpOTVVZWJkkqKytrt39a5p2tjtfrVW1traKjo7to7y6cln1pbz9a72dSUlLQ/IiICA0cODCozvDhw9uso2VeQkJCh/3Vso6e5sYbb9Tf//3fa/jw4frss8/0wAMPaNq0aSoqKpLL5QqLfvH7/brnnns0ZcoUjR07VpK67e/niy++6NGfU+31jSTNnDlTQ4cOVVpamj744APdf//92rNnj1555RVJfbdvdu7cqaysLJ06dUr9+/fXq6++qjFjxqi4uJj3SyeEZRgJB9OmTQs8HzdunDIzMzV06FCtXbu2V4QE2PeP//iPgedXXHGFxo0bp0suuUSFhYWaOnWqxZZ1n3nz5mnXrl16++23bTelx+mob+68887A8yuuuEKpqamaOnWqPvvsM11yySXd3cxuM2rUKBUXF6uyslIvv/yyZs+erbfeest2s3qNsDxNk5iYKJfL1eZq5vLycqWkpFhq
VdeKj4/XZZddpr179yolJUX19fWqqKgIqtN6/1NSUtrtn5Z5Z6sTGxvbawJPy76c7b2QkpKiI0eOBM1vbGzUiRMnLkh/9Zb33IgRI5SYmKi9e/dK6vv9Mn/+fP3xj3/Um2++qYsvvjhQ3l1/Pz35c6qjvmlPZmamJAW9b/pi37jdbo0cOVITJ05Ufn6+MjIy9Otf/5r3SyeFZRhxu92aOHGiCgoKAmV+v18FBQXKysqy2LKuc/LkSX322WdKTU3VxIkTFRkZGbT/e/bsUUlJSWD/s7KytHPnzqDB5o033lBsbKzGjBkTqNN6HS11elMfDh8+XCkpKUH74fV69f777wf1RUVFhbZt2xaos2nTJvn9/sAHbVZWljZv3qyGhoZAnTfeeEOjRo1SQkJCoE5v7q//+7//0/Hjx5Wamiqp7/aLMUbz58/Xq6++qk2bNrU5zdRdfz898XPqXH3TnuLiYkkKet/0xb45k9/vV11dXVi/X0Ji+wpaW1avXm08Ho9ZtWqV+eijj8ydd95p4uPjg65m7s1+/OMfm8LCQrN//37zzjvvmOzsbJOYmGiOHDlijGm61WzIkCFm06ZNZuvWrSYrK8tkZWUFlm+51eyGG24wxcXFZuPGjWbQoEHt3mp23333md27d5vly5f3yFt7q6qqzI4dO8yOHTuMJLN06VKzY8cO8/nnnxtjmm7tjY+PN6+99pr54IMPzHe+8512b+2dMGGCef/9983bb79tLr300qBbWCsqKkxycrK5/fbbza5du8zq1atNTExMm1tYIyIizOOPP252795tFi9ebPUW1rP1S1VVlfnJT35iioqKzP79+81f/vIXc+WVV5pLL73UnDp1KrCOvtgvd999t4mLizOFhYVBt6fW1NQE6nTX309P+5w6V9/s3bvX/PznPzdbt241+/fvN6+99poZMWKEueaaawLr6It9s3DhQvPWW2+Z/fv3mw8++MAsXLjQOBwO8+c//9kYE77vl1CEbRgxxpgnn3zSDBkyxLjdbjN58mTz3nvv2W7SBTNjxgyTmppq3G63GTx4sJkxY4bZu3dvYH5tba350Y9+ZBISEkxMTIz5u7/7O3P48OGgdRw4cMBMmzbNREdHm8TERPPjH//YNDQ0BNV58803zfjx443b7TYjRowwzz//fHfsXkjefPNNI6nNNHv2bGNM0+29Dz/8sElOTjYej8dMnTrV7NmzJ2gdx48fN7feeqvp37+/iY2NNXPmzDFVVVVBdf72t7+Zr3/968bj8ZjBgwebRx99tE1b1q5day677DLjdrvN5ZdfbtavX99l+30uZ+uXmpoac8MNN5hBgwaZyMhIM3ToUDN37tw2H2p9sV/a6xNJQe/t7vz76UmfU+fqm5KSEnPNNdeYgQMHGo/HY0aOHGnuu+++oO8ZMabv9c0PfvADM3ToUON2u82gQYPM1KlTA0HEmPB9v4TCYYwx3XccBgAAIFhYXjMCAAB6DsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAq/4/55sM182p1uAAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the h_score values in ascending order to show how the scores are distributed.\n",
    "# sorted() accepts any iterable, so the intermediate list() copy is not needed.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(sorted(h_score))\n",
    "ax.set(\n",
    "    xlabel='rank after ascending sort',\n",
    "    ylabel='h_score',\n",
    "    title='h_score values, sorted ascending',\n",
    ")\n",
    "# plt.show() renders the figure without echoing the Line2D repr as cell output.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "861753ba-1a2e-4529-9133-13458c803908",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
