{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "68d040d2-bcbc-453a-bc25-d1e8345914c7",
   "metadata": {},
   "source": [
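    "# Align Codex tokens to Flan-T5 tokens\n",
    "\n",
    "Map each Codex token (and its top-k alternatives, with their per-step probabilities) onto a token from the Flan-T5 vocabulary using `ClosestToken`, so that Codex and Flan-T5 predictions can be inspected in the same token space."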
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9259214c-72cc-414e-8853-3f68e1f01b93",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d3d28f6-20a3-49ad-932d-6e63917e7686",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys \n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f2a543f8-d3ae-470d-82a5-23e67d17de87",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/llm/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import random\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import editdistance\n",
    "from tqdm import tqdm\n",
    "from transformers import T5Tokenizer, GPT2Tokenizer\n",
    "from src.utils import (parse_codex_outputs, \n",
    "                       parse_flan_t5_outputs, \n",
    "                       vis_prob_flow, \n",
    "                       vis_heatmap, \n",
    "                       test_acc, \n",
    "                       majority_vote_acc,\n",
    "                       ClosestToken,\n",
    "                       transform_codex_token_to_t5_token,\n",
    "                       print_transformed_probs\n",
    "                      )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "851c69bf-41c7-430b-98c2-2895f999a903",
   "metadata": {},
   "source": [
    "# Load Codex and Flan-T5 predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "62476234-fe3c-49b3-a081-dcae662d1a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached Codex outputs. Path.read_bytes() opens and closes each file,\n",
    "# so no file handles are left for the garbage collector to clean up.\n",
    "# NOTE: pickle can execute arbitrary code on load -- only open trusted files.\n",
    "(codex_questions,\n",
    " codex_answers,\n",
    " codex_predictions,\n",
    " codex_per_step_probs,\n",
    " codex_prediction_labels) = [\n",
    "    pickle.loads(Path(f'codex_{name}.pkl').read_bytes())\n",
    "    for name in ('questions', 'answers', 'predictions', 'per_step_probs', 'prediction_labels')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "76b5c010-cfcb-47d2-905b-e7ee768f7371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.dump(codex_prediction_labels, open('codex_prediction_labels.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "23305687-b007-4d7f-89d1-ab7e75eb1415",
   "metadata": {},
   "outputs": [],
   "source": [
    "# codex_prediction_labels = pickle.load(open('codex_prediction_labels.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "fa71daa4-cf1d-4657-9d35-3d9c15def557",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 7473, pred 6262, acc 0.8379\n"
     ]
    }
   ],
   "source": [
    "_, codex_prediction_labels = majority_vote_acc(codex_predictions, codex_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "2bf01ca0-1e3e-42f8-84e8-690f04e18eaf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(codex_prediction_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "85d920b7-3a33-4993-9561-ce2308a92d49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(codex_per_step_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "52675043-e893-4cc9-a6a5-e85b4d914e0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached Flan-T5 outputs. Path.read_bytes() opens and closes each file,\n",
    "# so no file handles are left for the garbage collector to clean up.\n",
    "# NOTE: pickle can execute arbitrary code on load -- only open trusted files.\n",
    "(flan_questions,\n",
    " flan_answers,\n",
    " flan_predictions,\n",
    " flan_per_step_probs,\n",
    " flan_prediction_labels) = [\n",
    "    pickle.loads(Path(f'flan_{name}.pkl').read_bytes())\n",
    "    for name in ('questions', 'answers', 'predictions', 'per_step_probs', 'prediction_labels')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "7ee937f7-524c-44a7-b9ac-be2b279f73ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 7473, pred 1512, acc 0.2023\n"
     ]
    }
   ],
   "source": [
    "# _, flan_prediction_labels = majority_vote_acc(flan_predictions, flan_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "b4d2a139-9357-4026-a17f-fec2f1d78734",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.dump(flan_prediction_labels, open('flan_prediction_labels.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "80ba6858-94a4-457c-8ac0-5398e889428f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(flan_per_step_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1c299874-93dc-40e0-a1a1-9a1d772946ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Model output 0: \\nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\\nThe answer is 72\\n'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_predictions[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e6b23c24-8dab-4600-a3cb-5af17be90ce7",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[' step', '\\n', 'If', ' Nat', 'alia', ' sold', ' 48', ' clips', ' in', ' April', ',', ' she', ' sold', ' 48', '/', '2', ' =', ' 24', ' clips', ' in', ' May', '.', '\\n', 'In', ' total', ',', ' Nat', 'alia', ' sold', ' 48', ' +', ' 24', ' =', ' 72', ' clips', ' in', ' April', ' and', ' May', '.', '\\n', 'The', ' answer', ' is', ' 72', '\\n', '\\n']\n"
     ]
    }
   ],
   "source": [
    "print([tk[0][0] for tk in codex_per_step_probs[0][0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a9bf9854-7464-4a77-b47e-5598a4ae3c05",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('▁In', 0.2814),\n",
       " ('▁Nat', 0.5805),\n",
       " ('▁In', 0.2814),\n",
       " ('▁She', 0.036),\n",
       " ('▁First', 0.0167),\n",
       " ('▁If', 0.0123)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "flan_per_step_probs[0][0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "90fcffee-79dd-45b6-a894-80fd434bcd97",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xxl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e46a9325-cb23-4a13-a51c-f6871c3134a7",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'eos_token': '</s>',\n",
       " 'unk_token': '<unk>',\n",
       " 'pad_token': '<pad>',\n",
       " 'additional_special_tokens': ['<extra_id_0>',\n",
       "  '<extra_id_1>',\n",
       "  '<extra_id_2>',\n",
       "  '<extra_id_3>',\n",
       "  '<extra_id_4>',\n",
       "  '<extra_id_5>',\n",
       "  '<extra_id_6>',\n",
       "  '<extra_id_7>',\n",
       "  '<extra_id_8>',\n",
       "  '<extra_id_9>',\n",
       "  '<extra_id_10>',\n",
       "  '<extra_id_11>',\n",
       "  '<extra_id_12>',\n",
       "  '<extra_id_13>',\n",
       "  '<extra_id_14>',\n",
       "  '<extra_id_15>',\n",
       "  '<extra_id_16>',\n",
       "  '<extra_id_17>',\n",
       "  '<extra_id_18>',\n",
       "  '<extra_id_19>',\n",
       "  '<extra_id_20>',\n",
       "  '<extra_id_21>',\n",
       "  '<extra_id_22>',\n",
       "  '<extra_id_23>',\n",
       "  '<extra_id_24>',\n",
       "  '<extra_id_25>',\n",
       "  '<extra_id_26>',\n",
       "  '<extra_id_27>',\n",
       "  '<extra_id_28>',\n",
       "  '<extra_id_29>',\n",
       "  '<extra_id_30>',\n",
       "  '<extra_id_31>',\n",
       "  '<extra_id_32>',\n",
       "  '<extra_id_33>',\n",
       "  '<extra_id_34>',\n",
       "  '<extra_id_35>',\n",
       "  '<extra_id_36>',\n",
       "  '<extra_id_37>',\n",
       "  '<extra_id_38>',\n",
       "  '<extra_id_39>',\n",
       "  '<extra_id_40>',\n",
       "  '<extra_id_41>',\n",
       "  '<extra_id_42>',\n",
       "  '<extra_id_43>',\n",
       "  '<extra_id_44>',\n",
       "  '<extra_id_45>',\n",
       "  '<extra_id_46>',\n",
       "  '<extra_id_47>',\n",
       "  '<extra_id_48>',\n",
       "  '<extra_id_49>',\n",
       "  '<extra_id_50>',\n",
       "  '<extra_id_51>',\n",
       "  '<extra_id_52>',\n",
       "  '<extra_id_53>',\n",
       "  '<extra_id_54>',\n",
       "  '<extra_id_55>',\n",
       "  '<extra_id_56>',\n",
       "  '<extra_id_57>',\n",
       "  '<extra_id_58>',\n",
       "  '<extra_id_59>',\n",
       "  '<extra_id_60>',\n",
       "  '<extra_id_61>',\n",
       "  '<extra_id_62>',\n",
       "  '<extra_id_63>',\n",
       "  '<extra_id_64>',\n",
       "  '<extra_id_65>',\n",
       "  '<extra_id_66>',\n",
       "  '<extra_id_67>',\n",
       "  '<extra_id_68>',\n",
       "  '<extra_id_69>',\n",
       "  '<extra_id_70>',\n",
       "  '<extra_id_71>',\n",
       "  '<extra_id_72>',\n",
       "  '<extra_id_73>',\n",
       "  '<extra_id_74>',\n",
       "  '<extra_id_75>',\n",
       "  '<extra_id_76>',\n",
       "  '<extra_id_77>',\n",
       "  '<extra_id_78>',\n",
       "  '<extra_id_79>',\n",
       "  '<extra_id_80>',\n",
       "  '<extra_id_81>',\n",
       "  '<extra_id_82>',\n",
       "  '<extra_id_83>',\n",
       "  '<extra_id_84>',\n",
       "  '<extra_id_85>',\n",
       "  '<extra_id_86>',\n",
       "  '<extra_id_87>',\n",
       "  '<extra_id_88>',\n",
       "  '<extra_id_89>',\n",
       "  '<extra_id_90>',\n",
       "  '<extra_id_91>',\n",
       "  '<extra_id_92>',\n",
       "  '<extra_id_93>',\n",
       "  '<extra_id_94>',\n",
       "  '<extra_id_95>',\n",
       "  '<extra_id_96>',\n",
       "  '<extra_id_97>',\n",
       "  '<extra_id_98>',\n",
       "  '<extra_id_99>']}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.special_tokens_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "301d61da-ea1a-480a-ba64-bff8c7f31bdb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Model output 0: \\nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\\nThe answer is 72\\n'"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_predictions[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c2865281-db16-4bff-9b0c-23bdb8d27ae1",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(' step', 1.0),\n",
       "  ('!', 0.0),\n",
       "  (' stepped', 0.0),\n",
       "  (' stepping', 0.0),\n",
       "  (' step', 1.0),\n",
       "  (' steps', 0.0)],\n",
       " [('\\n', 0.9862),\n",
       "  ('!', 0.0002),\n",
       "  ('.', 0.009),\n",
       "  ('\\n', 0.9862),\n",
       "  (' ', 0.001),\n",
       "  (':', 0.0025)],\n",
       " [('If', 0.1009),\n",
       "  ('If', 0.1009),\n",
       "  ('She', 0.0416),\n",
       "  ('The', 0.0394),\n",
       "  ('Nat', 0.3286),\n",
       "  ('In', 0.2685)],\n",
       " [(' Nat', 0.6018),\n",
       "  (' Nat', 0.6018),\n",
       "  (' in', 0.0261),\n",
       "  (' we', 0.0125),\n",
       "  (' 48', 0.0335),\n",
       "  (' she', 0.2588)],\n",
       " [('alia', 0.9968),\n",
       "  ('ale', 0.0001),\n",
       "  ('ilia', 0.0001),\n",
       "  ('lia', 0.0002),\n",
       "  (' sold', 0.0013),\n",
       "  ('alia', 0.9968)],\n",
       " [(' sold', 0.9573),\n",
       "  (' sells', 0.0217),\n",
       "  (' sold', 0.9573),\n",
       "  (' is', 0.0022),\n",
       "  (\"'s\", 0.0028),\n",
       "  (' had', 0.0024)],\n",
       " [(' 48', 0.4455),\n",
       "  (' clips', 0.3353),\n",
       "  (' half', 0.1318),\n",
       "  (' to', 0.0351),\n",
       "  (' 48', 0.4455),\n",
       "  (' twice', 0.0058)],\n",
       " [(' clips', 0.9248),\n",
       "  (' clips', 0.9248),\n",
       "  (' friends', 0.0062),\n",
       "  (' to', 0.0058),\n",
       "  (' of', 0.0109),\n",
       "  (' in', 0.0345)],\n",
       " [(' in', 0.7597),\n",
       "  (',', 0.0206),\n",
       "  (' to', 0.2018),\n",
       "  (' in', 0.7597),\n",
       "  (' and', 0.0031),\n",
       "  (' then', 0.0028)],\n",
       " [(' April', 0.9759),\n",
       "  (' May', 0.0039),\n",
       "  (' total', 0.0025),\n",
       "  (' the', 0.0067),\n",
       "  (' April', 0.9759),\n",
       "  (' apr', 0.0026)],\n",
       " [(',', 0.7709),\n",
       "  (',', 0.7709),\n",
       "  (' to', 0.0033),\n",
       "  (' and', 0.1443),\n",
       "  (' she', 0.0074),\n",
       "  (' then', 0.0602)],\n",
       " [(' she', 0.1407),\n",
       "  (' in', 0.0182),\n",
       "  (' and', 0.2126),\n",
       "  (' that', 0.0235),\n",
       "  (' she', 0.1407),\n",
       "  (' then', 0.5304)],\n",
       " [(' sold', 0.815),\n",
       "  (' must', 0.0433),\n",
       "  (' sold', 0.815),\n",
       "  (' will', 0.0162),\n",
       "  (' would', 0.0389),\n",
       "  (' then', 0.0111)],\n",
       " [(' 48', 0.556),\n",
       "  (' 24', 0.1033),\n",
       "  (' half', 0.1671),\n",
       "  (' 1', 0.0266),\n",
       "  (' (', 0.0178),\n",
       "  (' 48', 0.556)],\n",
       " [('/', 0.7438),\n",
       "  (' /', 0.0471),\n",
       "  ('/', 0.7438),\n",
       "  (' *', 0.0386),\n",
       "  (' x', 0.0391),\n",
       "  ('*', 0.0442)],\n",
       " [('2', 0.9811),\n",
       "  ('12', 0.0016),\n",
       "  ('1', 0.0023),\n",
       "  ('2', 0.9811),\n",
       "  ('4', 0.0085),\n",
       "  (' 2', 0.0019)],\n",
       " [(' =', 0.8517),\n",
       "  (' clips', 0.0251),\n",
       "  ('=', 0.099),\n",
       "  (' in', 0.0068),\n",
       "  (' or', 0.0085),\n",
       "  (' =', 0.8517)],\n",
       " [(' 24', 0.9896),\n",
       "  (' 12', 0.0005),\n",
       "  ('24', 0.007),\n",
       "  (' 24', 0.9896),\n",
       "  (' ', 0.0002),\n",
       "  (' 2', 0.0004)],\n",
       " [(' clips', 0.8775),\n",
       "  (' less', 0.0026),\n",
       "  (' clips', 0.8775),\n",
       "  (' in', 0.106),\n",
       "  (' more', 0.0026),\n",
       "  (' fewer', 0.0034)],\n",
       " [(' in', 0.9914),\n",
       "  (' during', 0.0008),\n",
       "  (' less', 0.001),\n",
       "  (' the', 0.0011),\n",
       "  (' to', 0.0013),\n",
       "  (' in', 0.9914)],\n",
       " [(' May', 0.9877),\n",
       "  (' May', 0.9877),\n",
       "  (' the', 0.0026),\n",
       "  (' March', 0.0006),\n",
       "  (' April', 0.0012),\n",
       "  (' may', 0.0054)],\n",
       " [('.', 0.8685),\n",
       "  (',', 0.0226),\n",
       "  (' since', 0.0052),\n",
       "  ('.', 0.8685),\n",
       "  ('\\n', 0.0827),\n",
       "  (' because', 0.0075)],\n",
       " [('\\n', 0.9843),\n",
       "  (' She', 0.0007),\n",
       "  (' So', 0.0015),\n",
       "  ('\\n', 0.9843),\n",
       "  (' ', 0.0052),\n",
       "  (' In', 0.001)],\n",
       " [('In', 0.2022),\n",
       "  ('So', 0.1476),\n",
       "  ('Therefore', 0.0765),\n",
       "  ('She', 0.0584),\n",
       "  ('The', 0.1467),\n",
       "  ('In', 0.2022)],\n",
       " [(' total', 0.8519),\n",
       "  (' both', 0.0057),\n",
       "  (' May', 0.0097),\n",
       "  (' total', 0.8519),\n",
       "  (' the', 0.0058),\n",
       "  (' April', 0.1015)],\n",
       " [(',', 0.8402),\n",
       "  (',', 0.8402),\n",
       "  (' Nat', 0.0197),\n",
       "  (' in', 0.0042),\n",
       "  (' she', 0.114),\n",
       "  (' then', 0.0114)],\n",
       " [(' Nat', 0.2935),\n",
       "  (' Nat', 0.2935),\n",
       "  (' in', 0.0108),\n",
       "  (' 48', 0.0141),\n",
       "  (' she', 0.5699),\n",
       "  (' then', 0.0726)],\n",
       " [('alia', 0.9996),\n",
       "  ('lia', 0.0001),\n",
       "  ('elia', 0.0002),\n",
       "  ('alias', 0.0),\n",
       "  ('ala', 0.0),\n",
       "  ('alia', 0.9996)],\n",
       " [(' sold', 0.9757),\n",
       "  (' sells', 0.0026),\n",
       "  (' sold', 0.9757),\n",
       "  (' has', 0.0071),\n",
       "  (' will', 0.0025),\n",
       "  (' then', 0.0048)],\n",
       " [(' 48', 0.9477),\n",
       "  (' clips', 0.0153),\n",
       "  (' 24', 0.0158),\n",
       "  (' a', 0.0032),\n",
       "  (' 48', 0.9477),\n",
       "  (' 72', 0.0048)],\n",
       " [(' +', 0.6164),\n",
       "  ('+', 0.3547),\n",
       "  (' +', 0.6164),\n",
       "  (' clips', 0.0236),\n",
       "  (' in', 0.0019),\n",
       "  (' April', 0.0005)],\n",
       " [(' 24', 0.992),\n",
       "  (' 12', 0.0002),\n",
       "  ('24', 0.0065),\n",
       "  (' 24', 0.992),\n",
       "  (' 2', 0.0002),\n",
       "  (' 48', 0.0002)],\n",
       " [(' =', 0.9853),\n",
       "  (',', 0.0002),\n",
       "  (' clips', 0.0072),\n",
       "  (' ', 0.0002),\n",
       "  ('=', 0.0066),\n",
       "  (' =', 0.9853)],\n",
       " [(' 72', 0.9951),\n",
       "  (' ', 0.0002),\n",
       "  (' 70', 0.0003),\n",
       "  ('72', 0.0026),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9951)],\n",
       " [(' clips', 0.9888),\n",
       "  (' clip', 0.001),\n",
       "  ('.', 0.0012),\n",
       "  (' clips', 0.9888),\n",
       "  ('\\n', 0.003),\n",
       "  (' total', 0.0018)],\n",
       " [(' in', 0.359),\n",
       "  ('.', 0.3778),\n",
       "  (' altogether', 0.0955),\n",
       "  ('\\n', 0.1112),\n",
       "  (' in', 0.359),\n",
       "  (' over', 0.0129)],\n",
       " [(' April', 0.8876),\n",
       "  (' both', 0.0564),\n",
       "  (' May', 0.0049),\n",
       "  (' total', 0.0139),\n",
       "  (' the', 0.0216),\n",
       "  (' April', 0.8876)],\n",
       " [(' and', 0.9928),\n",
       "  (',', 0.0004),\n",
       "  (' &', 0.0024),\n",
       "  (' +', 0.0024),\n",
       "  (' and', 0.9928),\n",
       "  (' plus', 0.0005)],\n",
       " [(' May', 0.9962),\n",
       "  (' May', 0.9962),\n",
       "  (' June', 0.0001),\n",
       "  (' in', 0.0023),\n",
       "  (' may', 0.0005),\n",
       "  (' then', 0.0001)],\n",
       " [('.', 0.8727),\n",
       "  ('.', 0.8727),\n",
       "  (' altogether', 0.004),\n",
       "  (' together', 0.0127),\n",
       "  ('\\n', 0.0838),\n",
       "  (' combined', 0.0231)],\n",
       " [('\\n', 0.995),\n",
       "  ('\\n', 0.995),\n",
       "  (' ', 0.0035),\n",
       "  (' The', 0.0005),\n",
       "  ('  ', 0.0002),\n",
       "  ('\\n\\n', 0.0002)],\n",
       " [('The', 0.9803),\n",
       "  ('\\n', 0.0085),\n",
       "  ('So', 0.001),\n",
       "  ('Therefore', 0.0009),\n",
       "  ('Answer', 0.0008),\n",
       "  ('The', 0.9803)],\n",
       " [(' answer', 0.9972),\n",
       "  (' number', 0.0003),\n",
       "  (' Answer', 0.0006),\n",
       "  (' total', 0.0005),\n",
       "  (' an', 0.0002),\n",
       "  (' answer', 0.9972)],\n",
       " [(' is', 0.999),\n",
       "  (':', 0.0001),\n",
       "  (' to', 0.0001),\n",
       "  (' in', 0.0001),\n",
       "  (' is', 0.999),\n",
       "  (' 72', 0.0004)],\n",
       " [(' 72', 0.9981),\n",
       "  (' 24', 0.0001),\n",
       "  (' ', 0.0004),\n",
       "  (' 48', 0.0001),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9981)],\n",
       " [('\\n', 0.9577),\n",
       "  ('.', 0.0077),\n",
       "  ('\\n', 0.9577),\n",
       "  (' ', 0.0018),\n",
       "  ('<|endoftext|>', 0.0083),\n",
       "  ('\\n\\n', 0.0214)],\n",
       " [('\\n', 0.973),\n",
       "  ('``', 0.0018),\n",
       "  ('\\n', 0.973),\n",
       "  (' ', 0.0083),\n",
       "  ('\"\"\"', 0.0046),\n",
       "  (\"''\", 0.0028)]]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_per_step_probs[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "10a48d61-3f71-488c-8277-f62508313b27",
   "metadata": {},
   "outputs": [],
   "source": [
    "closest_token = ClosestToken(tokenizer.get_vocab())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "eec19cb5-b800-4409-a494-ad66d2e2b470",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(' step', 1.0),\n",
       "  ('!', 0.0),\n",
       "  (' stepped', 0.0),\n",
       "  (' stepping', 0.0),\n",
       "  (' step', 1.0),\n",
       "  (' steps', 0.0)],\n",
       " [('\\n', 0.9862),\n",
       "  ('!', 0.0002),\n",
       "  ('.', 0.009),\n",
       "  ('\\n', 0.9862),\n",
       "  (' ', 0.001),\n",
       "  (':', 0.0025)],\n",
       " [('If', 0.1009),\n",
       "  ('If', 0.1009),\n",
       "  ('She', 0.0416),\n",
       "  ('The', 0.0394),\n",
       "  ('Nat', 0.3286),\n",
       "  ('In', 0.2685)],\n",
       " [(' Nat', 0.6018),\n",
       "  (' Nat', 0.6018),\n",
       "  (' in', 0.0261),\n",
       "  (' we', 0.0125),\n",
       "  (' 48', 0.0335),\n",
       "  (' she', 0.2588)],\n",
       " [('alia', 0.9968),\n",
       "  ('ale', 0.0001),\n",
       "  ('ilia', 0.0001),\n",
       "  ('lia', 0.0002),\n",
       "  (' sold', 0.0013),\n",
       "  ('alia', 0.9968)],\n",
       " [(' sold', 0.9573),\n",
       "  (' sells', 0.0217),\n",
       "  (' sold', 0.9573),\n",
       "  (' is', 0.0022),\n",
       "  (\"'s\", 0.0028),\n",
       "  (' had', 0.0024)],\n",
       " [(' 48', 0.4455),\n",
       "  (' clips', 0.3353),\n",
       "  (' half', 0.1318),\n",
       "  (' to', 0.0351),\n",
       "  (' 48', 0.4455),\n",
       "  (' twice', 0.0058)],\n",
       " [(' clips', 0.9248),\n",
       "  (' clips', 0.9248),\n",
       "  (' friends', 0.0062),\n",
       "  (' to', 0.0058),\n",
       "  (' of', 0.0109),\n",
       "  (' in', 0.0345)],\n",
       " [(' in', 0.7597),\n",
       "  (',', 0.0206),\n",
       "  (' to', 0.2018),\n",
       "  (' in', 0.7597),\n",
       "  (' and', 0.0031),\n",
       "  (' then', 0.0028)],\n",
       " [(' April', 0.9759),\n",
       "  (' May', 0.0039),\n",
       "  (' total', 0.0025),\n",
       "  (' the', 0.0067),\n",
       "  (' April', 0.9759),\n",
       "  (' apr', 0.0026)],\n",
       " [(',', 0.7709),\n",
       "  (',', 0.7709),\n",
       "  (' to', 0.0033),\n",
       "  (' and', 0.1443),\n",
       "  (' she', 0.0074),\n",
       "  (' then', 0.0602)],\n",
       " [(' she', 0.1407),\n",
       "  (' in', 0.0182),\n",
       "  (' and', 0.2126),\n",
       "  (' that', 0.0235),\n",
       "  (' she', 0.1407),\n",
       "  (' then', 0.5304)],\n",
       " [(' sold', 0.815),\n",
       "  (' must', 0.0433),\n",
       "  (' sold', 0.815),\n",
       "  (' will', 0.0162),\n",
       "  (' would', 0.0389),\n",
       "  (' then', 0.0111)],\n",
       " [(' 48', 0.556),\n",
       "  (' 24', 0.1033),\n",
       "  (' half', 0.1671),\n",
       "  (' 1', 0.0266),\n",
       "  (' (', 0.0178),\n",
       "  (' 48', 0.556)],\n",
       " [('/', 0.7438),\n",
       "  (' /', 0.0471),\n",
       "  ('/', 0.7438),\n",
       "  (' *', 0.0386),\n",
       "  (' x', 0.0391),\n",
       "  ('*', 0.0442)],\n",
       " [('2', 0.9811),\n",
       "  ('12', 0.0016),\n",
       "  ('1', 0.0023),\n",
       "  ('2', 0.9811),\n",
       "  ('4', 0.0085),\n",
       "  (' 2', 0.0019)],\n",
       " [(' =', 0.8517),\n",
       "  (' clips', 0.0251),\n",
       "  ('=', 0.099),\n",
       "  (' in', 0.0068),\n",
       "  (' or', 0.0085),\n",
       "  (' =', 0.8517)],\n",
       " [(' 24', 0.9896),\n",
       "  (' 12', 0.0005),\n",
       "  ('24', 0.007),\n",
       "  (' 24', 0.9896),\n",
       "  (' ', 0.0002),\n",
       "  (' 2', 0.0004)],\n",
       " [(' clips', 0.8775),\n",
       "  (' less', 0.0026),\n",
       "  (' clips', 0.8775),\n",
       "  (' in', 0.106),\n",
       "  (' more', 0.0026),\n",
       "  (' fewer', 0.0034)],\n",
       " [(' in', 0.9914),\n",
       "  (' during', 0.0008),\n",
       "  (' less', 0.001),\n",
       "  (' the', 0.0011),\n",
       "  (' to', 0.0013),\n",
       "  (' in', 0.9914)],\n",
       " [(' May', 0.9877),\n",
       "  (' May', 0.9877),\n",
       "  (' the', 0.0026),\n",
       "  (' March', 0.0006),\n",
       "  (' April', 0.0012),\n",
       "  (' may', 0.0054)],\n",
       " [('.', 0.8685),\n",
       "  (',', 0.0226),\n",
       "  (' since', 0.0052),\n",
       "  ('.', 0.8685),\n",
       "  ('\\n', 0.0827),\n",
       "  (' because', 0.0075)],\n",
       " [('\\n', 0.9843),\n",
       "  (' She', 0.0007),\n",
       "  (' So', 0.0015),\n",
       "  ('\\n', 0.9843),\n",
       "  (' ', 0.0052),\n",
       "  (' In', 0.001)],\n",
       " [('In', 0.2022),\n",
       "  ('So', 0.1476),\n",
       "  ('Therefore', 0.0765),\n",
       "  ('She', 0.0584),\n",
       "  ('The', 0.1467),\n",
       "  ('In', 0.2022)],\n",
       " [(' total', 0.8519),\n",
       "  (' both', 0.0057),\n",
       "  (' May', 0.0097),\n",
       "  (' total', 0.8519),\n",
       "  (' the', 0.0058),\n",
       "  (' April', 0.1015)],\n",
       " [(',', 0.8402),\n",
       "  (',', 0.8402),\n",
       "  (' Nat', 0.0197),\n",
       "  (' in', 0.0042),\n",
       "  (' she', 0.114),\n",
       "  (' then', 0.0114)],\n",
       " [(' Nat', 0.2935),\n",
       "  (' Nat', 0.2935),\n",
       "  (' in', 0.0108),\n",
       "  (' 48', 0.0141),\n",
       "  (' she', 0.5699),\n",
       "  (' then', 0.0726)],\n",
       " [('alia', 0.9996),\n",
       "  ('lia', 0.0001),\n",
       "  ('elia', 0.0002),\n",
       "  ('alias', 0.0),\n",
       "  ('ala', 0.0),\n",
       "  ('alia', 0.9996)],\n",
       " [(' sold', 0.9757),\n",
       "  (' sells', 0.0026),\n",
       "  (' sold', 0.9757),\n",
       "  (' has', 0.0071),\n",
       "  (' will', 0.0025),\n",
       "  (' then', 0.0048)],\n",
       " [(' 48', 0.9477),\n",
       "  (' clips', 0.0153),\n",
       "  (' 24', 0.0158),\n",
       "  (' a', 0.0032),\n",
       "  (' 48', 0.9477),\n",
       "  (' 72', 0.0048)],\n",
       " [(' +', 0.6164),\n",
       "  ('+', 0.3547),\n",
       "  (' +', 0.6164),\n",
       "  (' clips', 0.0236),\n",
       "  (' in', 0.0019),\n",
       "  (' April', 0.0005)],\n",
       " [(' 24', 0.992),\n",
       "  (' 12', 0.0002),\n",
       "  ('24', 0.0065),\n",
       "  (' 24', 0.992),\n",
       "  (' 2', 0.0002),\n",
       "  (' 48', 0.0002)],\n",
       " [(' =', 0.9853),\n",
       "  (',', 0.0002),\n",
       "  (' clips', 0.0072),\n",
       "  (' ', 0.0002),\n",
       "  ('=', 0.0066),\n",
       "  (' =', 0.9853)],\n",
       " [(' 72', 0.9951),\n",
       "  (' ', 0.0002),\n",
       "  (' 70', 0.0003),\n",
       "  ('72', 0.0026),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9951)],\n",
       " [(' clips', 0.9888),\n",
       "  (' clip', 0.001),\n",
       "  ('.', 0.0012),\n",
       "  (' clips', 0.9888),\n",
       "  ('\\n', 0.003),\n",
       "  (' total', 0.0018)],\n",
       " [(' in', 0.359),\n",
       "  ('.', 0.3778),\n",
       "  (' altogether', 0.0955),\n",
       "  ('\\n', 0.1112),\n",
       "  (' in', 0.359),\n",
       "  (' over', 0.0129)],\n",
       " [(' April', 0.8876),\n",
       "  (' both', 0.0564),\n",
       "  (' May', 0.0049),\n",
       "  (' total', 0.0139),\n",
       "  (' the', 0.0216),\n",
       "  (' April', 0.8876)],\n",
       " [(' and', 0.9928),\n",
       "  (',', 0.0004),\n",
       "  (' &', 0.0024),\n",
       "  (' +', 0.0024),\n",
       "  (' and', 0.9928),\n",
       "  (' plus', 0.0005)],\n",
       " [(' May', 0.9962),\n",
       "  (' May', 0.9962),\n",
       "  (' June', 0.0001),\n",
       "  (' in', 0.0023),\n",
       "  (' may', 0.0005),\n",
       "  (' then', 0.0001)],\n",
       " [('.', 0.8727),\n",
       "  ('.', 0.8727),\n",
       "  (' altogether', 0.004),\n",
       "  (' together', 0.0127),\n",
       "  ('\\n', 0.0838),\n",
       "  (' combined', 0.0231)],\n",
       " [('\\n', 0.995),\n",
       "  ('\\n', 0.995),\n",
       "  (' ', 0.0035),\n",
       "  (' The', 0.0005),\n",
       "  ('  ', 0.0002),\n",
       "  ('\\n\\n', 0.0002)],\n",
       " [('The', 0.9803),\n",
       "  ('\\n', 0.0085),\n",
       "  ('So', 0.001),\n",
       "  ('Therefore', 0.0009),\n",
       "  ('Answer', 0.0008),\n",
       "  ('The', 0.9803)],\n",
       " [(' answer', 0.9972),\n",
       "  (' number', 0.0003),\n",
       "  (' Answer', 0.0006),\n",
       "  (' total', 0.0005),\n",
       "  (' an', 0.0002),\n",
       "  (' answer', 0.9972)],\n",
       " [(' is', 0.999),\n",
       "  (':', 0.0001),\n",
       "  (' to', 0.0001),\n",
       "  (' in', 0.0001),\n",
       "  (' is', 0.999),\n",
       "  (' 72', 0.0004)],\n",
       " [(' 72', 0.9981),\n",
       "  (' 24', 0.0001),\n",
       "  (' ', 0.0004),\n",
       "  (' 48', 0.0001),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9981)],\n",
       " [('\\n', 0.9577),\n",
       "  ('.', 0.0077),\n",
       "  ('\\n', 0.9577),\n",
       "  (' ', 0.0018),\n",
       "  ('<|endoftext|>', 0.0083),\n",
       "  ('\\n\\n', 0.0214)],\n",
       " [('\\n', 0.973),\n",
       "  ('``', 0.0018),\n",
       "  ('\\n', 0.973),\n",
       "  (' ', 0.0083),\n",
       "  ('\"\"\"', 0.0046),\n",
       "  (\"''\", 0.0028)]]"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_per_step_probs[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "70ceada1-9be4-48a0-b0e0-896e7c2b0968",
   "metadata": {},
   "outputs": [],
   "source": [
    "transferred_per_step_probs, mask, transform_result = transform_codex_token_to_t5_token(codex_per_step_probs[0][0], closest_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f66ee13e-2de8-42b8-ae1b-d63884cfffef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'blank_before_number': 6,\n",
       " 'blank_after_number': 3,\n",
       " 'blank_step_before_number': 0,\n",
       " 'blank_step_after_number': 1}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eb4bfaf5-465c-49ff-a9d3-4c54d058b379",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-step mask returned for the single sample. In the saved outputs the one 0 (index 12) lines up with\n",
    "# the step whose transformed candidates are all ('▁', 0) placeholders -- presumably masked out; TODO confirm.\n",
    "mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6885e5a2-e9a4-47e2-a7a4-aa893de993f0",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[('If', 0.1009),\n",
       "  ('If', 0.1009),\n",
       "  ('She', 0.0416),\n",
       "  ('The', 0.0394),\n",
       "  ('▁at', 0.3286),\n",
       "  ('In', 0.2685)],\n",
       " [('▁Nat', 0.6018),\n",
       "  ('▁Nat', 0.6018),\n",
       "  ('▁in', 0.0261),\n",
       "  ('▁we', 0.0125),\n",
       "  ('48', 0.0335),\n",
       "  ('▁she', 0.2588)],\n",
       " [('alia', 0.9968),\n",
       "  ('▁le', 0.0001),\n",
       "  ('ilia', 0.0001),\n",
       "  ('la', 0.0002),\n",
       "  ('▁sold', 0.0013),\n",
       "  ('alia', 0.9968)],\n",
       " [('▁sold', 0.9573),\n",
       "  ('▁sell', 0.0217),\n",
       "  ('▁sold', 0.9573),\n",
       "  ('▁is', 0.0022),\n",
       "  ('s', 0.0028),\n",
       "  ('▁had', 0.0024)],\n",
       " [('▁48', 0.4455),\n",
       "  ('▁clips', 0.3353),\n",
       "  ('▁half', 0.1318),\n",
       "  ('▁to', 0.0351),\n",
       "  ('48', 0.4455),\n",
       "  ('▁twice', 0.0058)],\n",
       " [('▁clips', 0.9248),\n",
       "  ('▁clips', 0.9248),\n",
       "  ('▁friends', 0.0062),\n",
       "  ('▁to', 0.0058),\n",
       "  ('▁of', 0.0109),\n",
       "  ('▁in', 0.0345)],\n",
       " [('▁in', 0.7597),\n",
       "  (',', 0.0206),\n",
       "  ('▁to', 0.2018),\n",
       "  ('▁in', 0.7597),\n",
       "  ('▁and', 0.0031),\n",
       "  ('▁then', 0.0028)],\n",
       " [('▁April', 0.9759),\n",
       "  ('▁May', 0.0039),\n",
       "  ('▁total', 0.0025),\n",
       "  ('▁the', 0.0067),\n",
       "  ('▁April', 0.9759),\n",
       "  ('ar', 0.0026)],\n",
       " [(',', 0.7709),\n",
       "  (',', 0.7709),\n",
       "  ('▁to', 0.0033),\n",
       "  ('▁and', 0.1443),\n",
       "  ('▁she', 0.0074),\n",
       "  ('▁then', 0.0602)],\n",
       " [('▁she', 0.1407),\n",
       "  ('▁in', 0.0182),\n",
       "  ('▁and', 0.2126),\n",
       "  ('▁that', 0.0235),\n",
       "  ('▁she', 0.1407),\n",
       "  ('▁then', 0.5304)],\n",
       " [('▁sold', 0.815),\n",
       "  ('▁must', 0.0433),\n",
       "  ('▁sold', 0.815),\n",
       "  ('▁will', 0.0162),\n",
       "  ('▁would', 0.0389),\n",
       "  ('▁then', 0.0111)],\n",
       " [('▁48', 0.556),\n",
       "  ('▁24', 0.1033),\n",
       "  ('▁half', 0.1671),\n",
       "  ('▁1', 0.0266),\n",
       "  ('▁(', 0.0178),\n",
       "  ('48', 0.556)],\n",
       " [('▁', 0), ('▁', 0), ('▁', 0), ('▁', 0), ('▁', 0), ('▁', 0)],\n",
       " [('/', 0.7438),\n",
       "  ('/', 0.0471),\n",
       "  ('/', 0.7438),\n",
       "  ('▁*', 0.0386),\n",
       "  ('x', 0.0391),\n",
       "  ('*', 0.0442)],\n",
       " [('▁2', 0.9811),\n",
       "  ('12', 0.0016),\n",
       "  ('1', 0.0023),\n",
       "  ('2', 0.9811),\n",
       "  ('4', 0.0085),\n",
       "  ('▁2', 0.0019)],\n",
       " [('▁=', 0.8517),\n",
       "  ('▁clips', 0.0251),\n",
       "  ('=', 0.099),\n",
       "  ('▁in', 0.0068),\n",
       "  ('▁or', 0.0085),\n",
       "  ('=', 0.8517)],\n",
       " [('▁24', 0.9896),\n",
       "  ('▁12', 0.0005),\n",
       "  ('24', 0.007),\n",
       "  ('▁24', 0.9896),\n",
       "  ('▁', 0.0002),\n",
       "  ('▁2', 0.0004)],\n",
       " [('▁clips', 0.8775),\n",
       "  ('▁less', 0.0026),\n",
       "  ('▁clips', 0.8775),\n",
       "  ('▁in', 0.106),\n",
       "  ('▁more', 0.0026),\n",
       "  ('fewer', 0.0034)],\n",
       " [('▁in', 0.9914),\n",
       "  ('▁during', 0.0008),\n",
       "  ('▁less', 0.001),\n",
       "  ('▁the', 0.0011),\n",
       "  ('▁to', 0.0013),\n",
       "  ('▁in', 0.9914)],\n",
       " [('▁May', 0.9877),\n",
       "  ('▁May', 0.9877),\n",
       "  ('▁the', 0.0026),\n",
       "  ('▁March', 0.0006),\n",
       "  ('▁April', 0.0012),\n",
       "  ('▁may', 0.0054)],\n",
       " [('.', 0.8685),\n",
       "  (',', 0.0226),\n",
       "  ('▁since', 0.0052),\n",
       "  ('.', 0.8685),\n",
       "  ('▁', 0.0827),\n",
       "  ('▁because', 0.0075)],\n",
       " [('▁', 0.9843),\n",
       "  ('▁She', 0.0007),\n",
       "  ('▁So', 0.0015),\n",
       "  ('▁', 0.9843),\n",
       "  ('▁', 0.0052),\n",
       "  ('▁In', 0.001)],\n",
       " [('In', 0.2022),\n",
       "  ('So', 0.1476),\n",
       "  ('▁Therefore', 0.0765),\n",
       "  ('She', 0.0584),\n",
       "  ('The', 0.1467),\n",
       "  ('In', 0.2022)],\n",
       " [('▁total', 0.8519),\n",
       "  ('▁both', 0.0057),\n",
       "  ('▁May', 0.0097),\n",
       "  ('▁total', 0.8519),\n",
       "  ('▁the', 0.0058),\n",
       "  ('▁April', 0.1015)],\n",
       " [(',', 0.8402),\n",
       "  (',', 0.8402),\n",
       "  ('▁Nat', 0.0197),\n",
       "  ('▁in', 0.0042),\n",
       "  ('▁she', 0.114),\n",
       "  ('▁then', 0.0114)],\n",
       " [('▁Nat', 0.2935),\n",
       "  ('▁Nat', 0.2935),\n",
       "  ('▁in', 0.0108),\n",
       "  ('48', 0.0141),\n",
       "  ('▁she', 0.5699),\n",
       "  ('▁then', 0.0726)],\n",
       " [('alia', 0.9996),\n",
       "  ('la', 0.0001),\n",
       "  ('elia', 0.0002),\n",
       "  ('alia', 0.0),\n",
       "  ('▁la', 0.0),\n",
       "  ('alia', 0.9996)],\n",
       " [('▁sold', 0.9757),\n",
       "  ('▁sell', 0.0026),\n",
       "  ('▁sold', 0.9757),\n",
       "  ('▁has', 0.0071),\n",
       "  ('▁will', 0.0025),\n",
       "  ('▁then', 0.0048)],\n",
       " [('▁48', 0.9477),\n",
       "  ('▁clips', 0.0153),\n",
       "  ('▁24', 0.0158),\n",
       "  ('a', 0.0032),\n",
       "  ('48', 0.9477),\n",
       "  ('72', 0.0048)],\n",
       " [('▁+', 0.6164),\n",
       "  ('+', 0.3547),\n",
       "  ('+', 0.6164),\n",
       "  ('▁clips', 0.0236),\n",
       "  ('▁in', 0.0019),\n",
       "  ('▁April', 0.0005)],\n",
       " [('▁24', 0.992),\n",
       "  ('▁12', 0.0002),\n",
       "  ('24', 0.0065),\n",
       "  ('▁24', 0.992),\n",
       "  ('▁2', 0.0002),\n",
       "  ('48', 0.0002)],\n",
       " [('▁=', 0.9853),\n",
       "  (',', 0.0002),\n",
       "  ('▁clips', 0.0072),\n",
       "  ('▁', 0.0002),\n",
       "  ('=', 0.0066),\n",
       "  ('=', 0.9853)],\n",
       " [('▁72', 0.9951),\n",
       "  ('▁', 0.0002),\n",
       "  ('70', 0.0003),\n",
       "  ('72', 0.0026),\n",
       "  ('▁7', 0.0001),\n",
       "  ('72', 0.9951)],\n",
       " [('▁clips', 0.9888),\n",
       "  ('▁clip', 0.001),\n",
       "  ('.', 0.0012),\n",
       "  ('▁clips', 0.9888),\n",
       "  ('▁', 0.003),\n",
       "  ('▁total', 0.0018)],\n",
       " [('▁in', 0.359),\n",
       "  ('.', 0.3778),\n",
       "  ('▁altogether', 0.0955),\n",
       "  ('▁', 0.1112),\n",
       "  ('▁in', 0.359),\n",
       "  ('▁over', 0.0129)],\n",
       " [('▁April', 0.8876),\n",
       "  ('▁both', 0.0564),\n",
       "  ('▁May', 0.0049),\n",
       "  ('▁total', 0.0139),\n",
       "  ('▁the', 0.0216),\n",
       "  ('▁April', 0.8876)],\n",
       " [('▁and', 0.9928),\n",
       "  (',', 0.0004),\n",
       "  ('&', 0.0024),\n",
       "  ('+', 0.0024),\n",
       "  ('▁and', 0.9928),\n",
       "  ('▁plus', 0.0005)],\n",
       " [('▁May', 0.9962),\n",
       "  ('▁May', 0.9962),\n",
       "  ('▁June', 0.0001),\n",
       "  ('▁in', 0.0023),\n",
       "  ('▁may', 0.0005),\n",
       "  ('▁then', 0.0001)],\n",
       " [('.', 0.8727),\n",
       "  ('.', 0.8727),\n",
       "  ('▁altogether', 0.004),\n",
       "  ('▁together', 0.0127),\n",
       "  ('▁', 0.0838),\n",
       "  ('▁combined', 0.0231)],\n",
       " [('▁', 0.995),\n",
       "  ('▁', 0.995),\n",
       "  ('▁', 0.0035),\n",
       "  ('▁The', 0.0005),\n",
       "  ('▁', 0.0002),\n",
       "  ('▁', 0.0002)],\n",
       " [('The', 0.9803),\n",
       "  ('▁', 0.0085),\n",
       "  ('So', 0.001),\n",
       "  ('▁Therefore', 0.0009),\n",
       "  ('▁Answer', 0.0008),\n",
       "  ('The', 0.9803)],\n",
       " [('▁answer', 0.9972),\n",
       "  ('▁number', 0.0003),\n",
       "  ('▁Answer', 0.0006),\n",
       "  ('▁total', 0.0005),\n",
       "  ('▁an', 0.0002),\n",
       "  ('▁answer', 0.9972)],\n",
       " [('▁is', 0.999),\n",
       "  (':', 0.0001),\n",
       "  ('▁to', 0.0001),\n",
       "  ('▁in', 0.0001),\n",
       "  ('▁is', 0.999),\n",
       "  ('72', 0.0004)],\n",
       " [('▁72', 0.9981),\n",
       "  ('▁24', 0.0001),\n",
       "  ('▁', 0.0004),\n",
       "  ('48', 0.0001),\n",
       "  ('▁7', 0.0001),\n",
       "  ('72', 0.9981)],\n",
       " [('▁', 0.9577),\n",
       "  ('.', 0.0077),\n",
       "  ('▁', 0.9577),\n",
       "  ('▁', 0.0018),\n",
       "  ('▁context', 0.0083),\n",
       "  ('▁', 0.0214)],\n",
       " [('▁', 0.973),\n",
       "  ('▁', 0.0018),\n",
       "  ('▁', 0.973),\n",
       "  ('▁', 0.0083),\n",
       "  ('▁\"', 0.0046),\n",
       "  (\"'\", 0.0028)],\n",
       " [('</s>', 1.0), ('</s>', 1.0), ('▁', 0), ('▁', 0), ('▁', 0), ('▁', 0)]]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transformed per-step (T5 token, prob) candidates for the single sample\n",
    "transferred_per_step_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "ea194a2b-6161-49b9-bd33-73f9ec587ae8",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁step',\n",
       " '▁',\n",
       " 'If',\n",
       " '▁Nat',\n",
       " 'alia',\n",
       " '▁sold',\n",
       " '48',\n",
       " '▁clips',\n",
       " '▁in',\n",
       " '▁April',\n",
       " ',',\n",
       " '▁she',\n",
       " '▁sold',\n",
       " '48',\n",
       " '/',\n",
       " '2',\n",
       " '=',\n",
       " '▁24',\n",
       " '▁clips',\n",
       " '▁in',\n",
       " '▁May',\n",
       " '.',\n",
       " '▁',\n",
       " 'In',\n",
       " '▁total',\n",
       " ',',\n",
       " '▁Nat',\n",
       " 'alia',\n",
       " '▁sold',\n",
       " '48',\n",
       " '+',\n",
       " '▁24',\n",
       " '=',\n",
       " '72',\n",
       " '▁clips',\n",
       " '▁in',\n",
       " '▁April',\n",
       " '▁and',\n",
       " '▁May',\n",
       " '.',\n",
       " '▁',\n",
       " 'The',\n",
       " '▁answer',\n",
       " '▁is',\n",
       " '72',\n",
       " '▁',\n",
       " '▁']"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First T5 token of every transformed step of the single sample\n",
    "[step_pairs[0][0] for step_pairs in transferred_per_step_probs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "0af26fd5-4405-47a2-8662-490fc2d91e97",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'If Natalia sold 48 clips in April, she sold 48 / 2 = 24 clips in May. In total, Natalia sold 48 + 24 = 72 clips in April and May. The answer is 72</s>'"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the first T5 token of each transformed step back to text to eyeball the alignment\n",
    "tokenizer.decode(\n",
    "    tokenizer.convert_tokens_to_ids([step_pairs[0][0] for step_pairs in transferred_per_step_probs])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "05f4e679-539c-460d-a376-419b96fa96df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch value for checking slice semantics in the next cell; not part of the pipeline\n",
    "x = 'abcdef'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1d548a68-5861-4db1-8bd3-70b53d75def6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'abcd'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# [:-2] drops the last two characters ('abcdef' -> 'abcd'); scratch check only\n",
    "x[:-2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a0a7333-5b39-46c6-9e84-e1a51e21d4eb",
   "metadata": {},
   "source": [
    "# Transform the Codex-decoded tokens into Flan-T5 tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "221a7683-5776-4e46-9b7a-eea17b288f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accumulators for the batch transform (one entry per question)\n",
    "transformed_codex_per_step_probs, transformed_mask = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "9eb7c186-99a5-4eb7-a0ae-8558d785fc50",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "672"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of questions transformed so far.\n",
    "# NOTE(review): in top-to-bottom order this runs before the transform loop and shows 0;\n",
    "# the saved 672 comes from out-of-order execution.\n",
    "len(transformed_codex_per_step_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ce64f0d9-16a3-4888-ad8c-6b8b4bbf6023",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [20:18<00:00,  6.13it/s]\n"
     ]
    }
   ],
   "source": [
    "# Transform every Codex generation of every question into Flan-T5 tokens.\n",
    "# The accumulators are created here, so re-running this cell never appends a second copy of the\n",
    "# results and never depends on whether the init cell above was re-run first.\n",
    "transformed_codex_per_step_probs = []  # per question: list of per-generation transformed steps\n",
    "transformed_mask = []  # per question: list of per-generation step masks\n",
    "for qi, q in tqdm(enumerate(codex_per_step_probs), total=len(codex_per_step_probs)):\n",
    "    transformed_q = []\n",
    "    transformed_m = []\n",
    "    for aid, ai in enumerate(q):\n",
    "        # distinct names avoid clobbering the single-sample `mask` and IPython's `_` output history\n",
    "        probs, step_mask, _stats = transform_codex_token_to_t5_token(qi, aid, ai, closest_token)\n",
    "        transformed_q.append(probs)\n",
    "        transformed_m.append(step_mask)\n",
    "    transformed_codex_per_step_probs.append(transformed_q)\n",
    "    transformed_mask.append(transformed_m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "83813551-10c0-4b06-bac8-0656c31de667",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "If Tin works 10 hours, she will work 8 hours at her regular pay of $ 18.00 and 2 hours of overtime. The overtime pay is calculated by adding the regular pay plus half the regular pay. So, she will earn $ 18.00 + 1 / 2 ( 18.00 )= $ 18.00 + $ 9.00 = $ 27.00 per hour for her overtime pay. She will earn $ 18.00 x 8 = $ 144.00 for her regular pay and $ 27.00 x 2 = $ 54.00 for her overtime pay. In total, she will earn $ 144.00 + $ 54.00 = $ 198.00 for the day. If she works 10 hours every day for 5 days, she will earn $ 198.00 x 5 = $ 990.00 The answer is 990</s>\n"
     ]
    }
   ],
   "source": [
    "# Print the single transformed sample (the saved output shows it as decoded text ending in </s>).\n",
    "# NOTE(review): the saved output is a different question from the Natalia example above,\n",
    "# so it comes from an out-of-order execution.\n",
    "print_transformed_probs(tokenizer, transferred_per_step_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fd934a46-7119-42e9-b435-3308241c4a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the transformed per-step probabilities; `with` guarantees the file is flushed and closed\n",
    "with open('codex_transformed_per_step_probs.pkl', 'wb') as f:\n",
    "    pickle.dump(transformed_codex_per_step_probs, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "f4768227-05a1-42bf-b70c-40fd44c4d245",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-step masks; `with` guarantees the file is flushed and closed\n",
    "with open('codex_mask_after_transform.pkl', 'wb') as f:\n",
    "    pickle.dump(transformed_mask, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9272f128-0adf-4234-936c-5f2d5bb615b3",
   "metadata": {},
   "source": [
    "# Use the T5 tokenizer to convert questions, answers and transformed tokens to vocabulary indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "40f68da5-1208-4e02-a412-df74b78b87a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the text after the first ':' (drops the label prefix) of each raw Codex question; '' if there is no ':'\n",
    "questions = [raw_question.partition(':')[2].strip() for raw_question in codex_questions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "0c4dc615-face-4143-81ea-3dfa369daf17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?'"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First cleaned question, to check the prefix was stripped\n",
    "questions[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3ca5552b-205f-4795-a26b-704492833759",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the cleaned questions with the T5 tokenizer (input ids only).\n",
    "# [:-1] drops the last entry of `questions` -- presumably an empty trailing record (TODO confirm);\n",
    "# the assert keeps question ids aligned 1:1 with the Codex generations instead of failing silently.\n",
    "codex_questions_idx = tokenizer(questions[:-1], return_attention_mask=False)['input_ids']\n",
    "assert len(codex_questions_idx) == len(codex_per_step_probs), 'question ids are misaligned with Codex generations'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "4a611329-431d-4bb6-a65e-6db40e186506",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of tokenized questions (7473 in the saved output)\n",
    "len(codex_questions_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "2db18f65-8a60-48be-9b19-f4498ce80be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the question token ids; `with` guarantees the file is flushed and closed\n",
    "with open('codex_questions_idx.pkl', 'wb') as f:\n",
    "    pickle.dump(codex_questions_idx, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "f312ab2d-059d-4bbc-8372-9cddee480add",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?</s>'"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip check: decoding recovers the question text plus the </s> the tokenizer appends\n",
    "tokenizer.decode(codex_questions_idx[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "e973bae8-8c06-4ca2-b635-fb1e375ba6c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May.\\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\\n#### 72\\n'"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw reference answer for question 0 (note the 'Answer:' label prefix)\n",
    "codex_answers[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "d16a8f80-c847-4eeb-b685-fbb165b542af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the text after the first ':' (drops the 'Answer:' label) of each raw answer; '' if there is no ':'\n",
    "answers = [raw_answer.partition(':')[2].strip() for raw_answer in codex_answers]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b6a5e2c7-c53a-4903-b7b3-79b6e4406f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the cleaned answers with the T5 tokenizer (input ids only).\n",
    "# [:-1] drops the last entry of `answers` -- presumably an empty trailing record (TODO confirm);\n",
    "# the assert keeps answer ids aligned 1:1 with the Codex generations instead of failing silently.\n",
    "codex_answers_idx = tokenizer(answers[:-1], return_attention_mask=False)['input_ids']\n",
    "assert len(codex_answers_idx) == len(codex_per_step_probs), 'answer ids are misaligned with Codex generations'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "fb180c33-4ae1-4696-9438-e7be959d6417",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the answer token ids; `with` guarantees the file is flushed and closed\n",
    "with open('codex_answers_idx.pkl', 'wb') as f:\n",
    "    pickle.dump(codex_answers_idx, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "25188a43-afae-46f8-aed9-e6a8b8fc0ddc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Natalia sold 48/2 = <unk> 48/2=24>>24 clips in May. Natalia sold 48+24 = <unk> 48+24=72>>72 clips altogether in April and May. #### 72</s>'"
      ]
     },
     "execution_count": 108,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip check on the answer ids: in the saved output the '<<' markers decode to <unk> and newlines are lost\n",
    "tokenizer.decode(codex_answers_idx[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "115918b8-3b2d-4615-807a-4c73b5f4a3b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# After the transform loop this should equal len(codex_per_step_probs) (7473 in the saved output)\n",
    "len(transformed_codex_per_step_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b0031741-3f4f-49fb-807f-40b6a6953f62",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [03:39<00:00, 34.11it/s]\n"
     ]
    }
   ],
   "source": [
    "# Map every transformed (T5 token, prob) pair to (T5 vocab id, prob), keeping the\n",
    "# question -> generation -> step -> candidates nesting; a token missing from the vocab raises KeyError\n",
    "vocab = tokenizer.get_vocab()\n",
    "codex_per_step_probs_idx = [\n",
    "    [\n",
    "        [[(vocab[token], prob) for token, prob in step] for step in generation]\n",
    "        for generation in question_gens\n",
    "    ]\n",
    "    for question_gens in tqdm(transformed_codex_per_step_probs)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "71efdd2c-d83f-4f9b-a9e4-d1da5ab9fdb1",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁',\n",
       " 't',\n",
       " '▁30',\n",
       " '▁',\n",
       " ',',\n",
       " '▁An',\n",
       " 'ika',\n",
       " '▁is',\n",
       " '▁4',\n",
       " '▁',\n",
       " '/',\n",
       " '▁3',\n",
       " '▁the',\n",
       " 'age',\n",
       " '▁of',\n",
       " '▁add',\n",
       " 'e',\n",
       " '.',\n",
       " '▁',\n",
       " 'So',\n",
       " ',',\n",
       " '▁add',\n",
       " 'e',\n",
       " '▁is',\n",
       " '▁30',\n",
       " '▁',\n",
       " 'x',\n",
       " '▁3',\n",
       " '▁',\n",
       " '/',\n",
       " '▁4',\n",
       " '▁=',\n",
       " '▁22',\n",
       " '.',\n",
       " '5',\n",
       " '▁years',\n",
       " '▁old',\n",
       " '.',\n",
       " '▁',\n",
       " 'In',\n",
       " '▁15',\n",
       " '▁years',\n",
       " ',',\n",
       " '▁An',\n",
       " 'ika',\n",
       " '▁would',\n",
       " '▁be',\n",
       " '▁30',\n",
       " '▁+',\n",
       " '▁15',\n",
       " '▁=',\n",
       " '▁45',\n",
       " '▁years',\n",
       " '▁old',\n",
       " '.',\n",
       " '▁',\n",
       " 'And',\n",
       " '▁add',\n",
       " 'e',\n",
       " '▁would',\n",
       " '▁be',\n",
       " '▁22',\n",
       " '.',\n",
       " '5',\n",
       " '▁+',\n",
       " '▁15',\n",
       " '▁=',\n",
       " '▁37',\n",
       " '.',\n",
       " '5',\n",
       " '▁years',\n",
       " '▁old',\n",
       " '.',\n",
       " '▁',\n",
       " '▁Their',\n",
       " '▁average',\n",
       " 'age',\n",
       " '▁in',\n",
       " '▁15',\n",
       " '▁years',\n",
       " '▁would',\n",
       " '▁be',\n",
       " '▁(',\n",
       " '▁45',\n",
       " '▁+',\n",
       " '▁37',\n",
       " '.',\n",
       " '5',\n",
       " '▁',\n",
       " ')',\n",
       " '/',\n",
       " '▁2',\n",
       " '▁=',\n",
       " '▁41',\n",
       " '.',\n",
       " '25',\n",
       " '▁',\n",
       " 'The',\n",
       " '▁answer',\n",
       " '▁is',\n",
       " '▁41',\n",
       " '.',\n",
       " '25',\n",
       " '▁',\n",
       " '</s>']"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First T5 token per step for the last question (index 7472), generation 39\n",
    "[step[0][0] for step in transformed_codex_per_step_probs[7472][39]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "781b9951-d8ab-4803-84a5-44b4350f4233",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁',\n",
       " 'If',\n",
       " '▁Nat',\n",
       " 'alia',\n",
       " '▁sold',\n",
       " '▁48',\n",
       " '▁clips',\n",
       " '▁in',\n",
       " '▁April',\n",
       " ',',\n",
       " '▁she',\n",
       " '▁sold',\n",
       " '▁48',\n",
       " '▁',\n",
       " '/',\n",
       " '▁2',\n",
       " '▁=',\n",
       " '▁24',\n",
       " '▁clips',\n",
       " '▁in',\n",
       " '▁May',\n",
       " '.',\n",
       " '▁',\n",
       " 'In',\n",
       " '▁total',\n",
       " ',',\n",
       " '▁Nat',\n",
       " 'alia',\n",
       " '▁sold',\n",
       " '▁48',\n",
       " '▁+',\n",
       " '▁24',\n",
       " '▁=',\n",
       " '▁72',\n",
       " '▁clips',\n",
       " '▁in',\n",
       " '▁April',\n",
       " '▁and',\n",
       " '▁May',\n",
       " '.',\n",
       " '▁',\n",
       " 'The',\n",
       " '▁answer',\n",
       " '▁is',\n",
       " '▁72',\n",
       " '▁',\n",
       " '</s>']"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map the first vocab id of each step back to T5 tokens for question 0, generation 0\n",
    "tokenizer.convert_ids_to_tokens([step[0][0] for step in codex_per_step_probs_idx[0][0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "bdb8cdca-c0ad-400c-ba72-b7839d77bcc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'t 30, Anika is 4 / 3 theage of adde. So, adde is 30 x 3 / 4 = 22.5 years old. In 15 years, Anika would be 30 + 15 = 45 years old. And adde would be 22.5 + 15 = 37.5 years old.  Their averageage in 15 years would be ( 45 + 37.5 )/ 2 = 41.25 The answer is 41.25 </s>'"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the first vocab id of each step for question 7472, generation 39 back into text\n",
    "tokenizer.decode([step[0][0] for step in codex_per_step_probs_idx[7472][39]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "5bdd0ac4-d8ef-4ed3-9070-2b0e67536244",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' step\\nAt 30, Anika is 4/3 the age of Maddie.\\nSo, Maddie is 30 x 3/4 = 22.5 years old.\\nIn 15 years, Anika would be 30 + 15 = 45 years old.\\nAnd Maddie would be 22.5 + 15 = 37.5 years old.\\nTheir average age in 15 years would be (45 + 37.5) / 2 = 41.25\\nThe answer is 41.25\\n\\n'"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Original Codex text for question 7472, generation 39, to compare with the decoded T5 ids\n",
    "''.join([step[0][0] for step in codex_per_step_probs[7472][39]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "10e4e710-0a37-4c41-a755-a501ab8496ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the (vocab id, prob) per-step candidates; `with` guarantees the file is flushed and closed\n",
    "with open('codex_per_step_probs_idx.pkl', 'wb') as f:\n",
    "    pickle.dump(codex_per_step_probs_idx, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f704a70-beec-47a8-993e-e35137f15ae6",
   "metadata": {},
   "source": [
    "# TODO: use the tokenizer to tokenize the Flan-T5 generations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ce99b74-8457-4f0d-983d-0e44648158cd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f6b602c2-6792-4aae-8779-0089691eaa62",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32100"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the T5 tokenizer vocabulary (32100 in the saved output)\n",
    "len(tokenizer.get_vocab())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce874b2d-efa1-4471-8222-8a9358fe6068",
   "metadata": {},
   "source": [
    "# Generate a random permutation of question indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1c5af8a8-674f-4cde-bcb6-85940f897ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random permutation of question indices, seeded with a dedicated RNG so the saved permutation is reproducible.\n",
    "# NOTE: this yields a different order than any permuted_idx.pkl produced earlier by the unseeded shuffle.\n",
    "PERMUTATION_SEED = 42\n",
    "idx = list(range(len(transformed_codex_per_step_probs)))\n",
    "random.Random(PERMUTATION_SEED).shuffle(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8b12f8b8-36d2-4bb7-b280-d679ea16ef8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the index permutation; `with` guarantees the file is flushed and closed\n",
    "with open('permuted_idx.pkl', 'wb') as f:\n",
    "    pickle.dump(idx, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
