{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6418b0d2-945d-4cc9-9ce2-baed6f2ba603",
   "metadata": {},
   "source": [
    "# Align Codex tokens to FlanT5 tokens using dynamic time warping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "31e7ec24-b9a8-4d3f-a511-d8b45450929a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "18852c6c-c70b-46b3-90b3-0bbabe93b468",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys \n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "e05cb005-ab03-4296-a3bf-ebb519238c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import pickle\n",
    "import random\n",
    "import editdistance\n",
    "\n",
    "from collections import OrderedDict\n",
    "from tqdm import tqdm\n",
    "from transformers import T5Tokenizer, GPT2Tokenizer\n",
    "from src.utils import (parse_codex_outputs, \n",
    "                       parse_flan_t5_outputs, \n",
    "                       vis_prob_flow, \n",
    "                       vis_heatmap, \n",
    "                       test_acc, \n",
    "                       majority_vote_acc,\n",
    "                       ClosestToken,\n",
    "                       transform_codex_token_to_t5_token,\n",
    "                       print_transformed_probs,\n",
    "                       dtw\n",
    "                      )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "fc647324-37e7-484f-bdce-73ea10278166",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_pickle(path):\n",
    "    \"\"\"Unpickle and return the object stored at `path`, closing the file afterwards.\"\"\"\n",
    "    # SECURITY: pickle.load can run arbitrary code -- only use on trusted, locally produced files.\n",
    "    with open(path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "codex_questions = _load_pickle('../processed_data/codex_questions.pkl')\n",
    "codex_answers = _load_pickle('../processed_data/codex_answers.pkl')\n",
    "codex_predictions = _load_pickle('../processed_data/codex_predictions.pkl')\n",
    "codex_per_step_probs = _load_pickle('../processed_data/codex_per_step_probs.pkl')\n",
    "codex_prediction_labels = _load_pickle('../processed_data/codex_prediction_labels.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "8d7bd10d-00f6-4355-afdb-4c91d364b35e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 144,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(codex_questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "756f8566-049c-45d2-9a79-a75b808caa33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Question 3324: Monica and Sheila are twins. Their mother gave them $50 and told them to buy some toilet paper and spend the remainder on groceries. The toilet paper cost $12. They bought apples, butter, eggs, and a large ham for twice the cost of the toilet paper. Since they still had some leftover money, they called their mother and she gave them permission to buy whatever they wanted for themselves as long as they shared the money evenly. They saw some boots they really liked, but a pair of boots costs 3 times the amount they had left. How much more would Monica and Sheila each have to add of their own money to buy two pairs of boots?\\n'"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_questions[3324]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "ac4a0de5-533d-4723-a357-aeb56b49fc71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Question 3324: Monica and Sheila are twins. Their mother gave them $50 and told them to buy some toilet paper and spend the remainder on groceries. The toilet paper cost $12. They bought apples, butter, eggs, and a large ham for twice the cost of the toilet paper. Since they still had some leftover money, they called their mother and she gave them permission to buy whatever they wanted for themselves as long as they shared the money evenly. They saw some boots they really liked, but a pair of boots costs 3 times the amount they had left. How much more would Monica and Sheila each have to add of their own money to buy two pairs of boots?\\n'"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_questions[3325]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "05fe3c98-8fec-4069-bd25-d698bc9dfdb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Question 3325: Billy's family likes to keep their bicycles stored in the garage when they're not being used.  They own a total of 4 bicycles.  Each bicycle wheel has 10 spokes.  How many spokes are inside the garage?\\n\""
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_questions[3326]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "7bcc669d-ae98-4ffa-ad86-c427a79dea4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# codex_questions[3325] repeats the text of codex_questions[3324] (both are displayed above as\n",
    "# 'Question 3324'), so build a copy of the list without that duplicated entry.\n",
    "# NOTE(review): this presumably realigns the questions with codex_answers from index 3325 on -- confirm.\n",
    "codex_questions_ = list(codex_questions[:3325] + codex_questions[3326:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "bbbf8c60-5953-45f5-b32f-63015e04c79c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Question 3326: The largest animal to have ever lived on earth is the blue whale.  The tongue of an adult blue whale can weigh 6000 pounds.  If one ton is 2000 pounds, how many tons can the tongue of an adult blue whale weigh?\\n'"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_questions_[3326]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "afcdadce-795c-4e34-a68a-baf5779a9f08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Answer: Each bicycle has 2 wheels, so there are a total of 4*2=<<4*2=8>>8 wheels in the garage as there are 4 bicycles.\\nSince each wheel has 10 spokes, this means there are 8*10=<<8*10=80>>80 spokes in total.\\n#### 80\\n'"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_answers[3325]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "8214e75c-ab54-4ce2-ba47-4e2e823e47df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Model output 1: \\nEach bicycle has two wheels, so there are a total of 4*2 = 8 wheels.\\nEach wheel has 10 spokes, so there are a total of 8*10 = 80 spokes in total.\\nThe answer is 80\\n'"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_predictions[3325][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "121a9f15-03d7-438d-b82b-6343dd9c445b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(codex_prediction_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "db910382-472e-4d98-b5ff-5f4efff503a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xxl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "49c65d85-aa25-4a9c-bb76-ff6b261ad272",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Model output 0: \\nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\\nThe answer is 72\\n'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_predictions[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8a8271bd-0fe3-4df4-977f-ff088326243f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep everything after the first ': ' (the 'Model output N: ' header). partition() splits only once,\n",
    "# so any later ': ' inside the prediction text is preserved; it yields '' when no ': ' is present.\n",
    "pred = codex_predictions[1][0].partition(': ')[2].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "2a6c7d7b-6be5-48cc-8001-3509f641e4e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Weng earns $12 for every hour of babysitting.\\nShe babysat for 50 minutes.\\n50 minutes is less than an hour, so she earned less than $12.\\nShe earned $12/60 * 50 = $10.\\nThe answer is 10'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9d924cfe-939c-4700-b9b7-26738c42e087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize `pred` with the FlanT5 tokenizer and map the ids back to token strings. The [:-1] drops\n",
    "# the last token -- presumably the EOS marker the tokenizer appends; TODO confirm.\n",
    "tokens = tokenizer.convert_ids_to_tokens(tokenizer(pred)['input_ids'])[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9f1489e9-a4eb-4f44-8e17-219c9cc0da52",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁We',\n",
       " 'ng',\n",
       " '▁earn',\n",
       " 's',\n",
       " '▁$12',\n",
       " '▁for',\n",
       " '▁every',\n",
       " '▁hour',\n",
       " '▁of',\n",
       " '▁baby',\n",
       " 's',\n",
       " 'i',\n",
       " 'tting',\n",
       " '.',\n",
       " '▁She',\n",
       " '▁baby',\n",
       " 's',\n",
       " 'at',\n",
       " '▁for',\n",
       " '▁50',\n",
       " '▁minutes',\n",
       " '.',\n",
       " '▁50',\n",
       " '▁minutes',\n",
       " '▁is',\n",
       " '▁less',\n",
       " '▁than',\n",
       " '▁an',\n",
       " '▁hour',\n",
       " ',',\n",
       " '▁so',\n",
       " '▁she',\n",
       " '▁earned',\n",
       " '▁less',\n",
       " '▁than',\n",
       " '▁$12',\n",
       " '.',\n",
       " '▁She',\n",
       " '▁earned',\n",
       " '▁$12',\n",
       " '/',\n",
       " '60',\n",
       " '▁*',\n",
       " '▁50',\n",
       " '▁=',\n",
       " '▁$10',\n",
       " '.',\n",
       " '▁The',\n",
       " '▁answer',\n",
       " '▁is',\n",
       " '▁10']"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "66010c2b-923e-44a7-bb6d-1cf85e9c3705",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first (token, prob) pair of every generation step -- presumably the token Codex actually\n",
    "# emitted; TODO confirm. The first two steps (' step' and '\\n' in the dump of this list) are skipped.\n",
    "steps = [s[0][0] for s in codex_per_step_probs[1][0][2:]]\n",
    "\n",
    "# Trim trailing newline tokens. The bounds check stops the scan at the start of the list (no IndexError\n",
    "# for an empty or all-newline sequence), and the length-relative slice keeps the whole list when there\n",
    "# is nothing to trim (steps[:idx + 1] would be steps[:0] == [] in that case).\n",
    "idx = -1\n",
    "while -idx <= len(steps) and steps[idx] == '\\n': idx -= 1\n",
    "steps = steps[:len(steps) + idx + 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e3292217-e0e7-4cfd-9536-8de6bb4114ce",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['W',\n",
       " 'eng',\n",
       " ' earns',\n",
       " ' $',\n",
       " '12',\n",
       " ' for',\n",
       " ' every',\n",
       " ' hour',\n",
       " ' of',\n",
       " ' babys',\n",
       " 'itting',\n",
       " '.',\n",
       " '\\n',\n",
       " 'She',\n",
       " ' babys',\n",
       " 'at',\n",
       " ' for',\n",
       " ' 50',\n",
       " ' minutes',\n",
       " '.',\n",
       " '\\n',\n",
       " '50',\n",
       " ' minutes',\n",
       " ' is',\n",
       " ' less',\n",
       " ' than',\n",
       " ' an',\n",
       " ' hour',\n",
       " ',',\n",
       " ' so',\n",
       " ' she',\n",
       " ' earned',\n",
       " ' less',\n",
       " ' than',\n",
       " ' $',\n",
       " '12',\n",
       " '.',\n",
       " '\\n',\n",
       " 'She',\n",
       " ' earned',\n",
       " ' $',\n",
       " '12',\n",
       " '/',\n",
       " '60',\n",
       " ' *',\n",
       " ' 50',\n",
       " ' =',\n",
       " ' $',\n",
       " '10',\n",
       " '.',\n",
       " '\\n',\n",
       " 'The',\n",
       " ' answer',\n",
       " ' is',\n",
       " ' 10']"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7931d279-802e-467c-b8bb-a6140c56608f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist_fn(codex_a, flan_b):\n",
    "    \"\"\"Return the edit distance between a Codex token and a FlanT5 token.\n",
    "\n",
    "    Spaces are removed from `codex_a` and '▁' characters from `flan_b` before the two\n",
    "    strings are compared with `editdistance.eval`.\n",
    "    \"\"\"\n",
    "    codex_text = codex_a.replace(' ', '')\n",
    "    flan_text = flan_b.replace('▁', '')\n",
    "    return editdistance.eval(codex_text, flan_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "3e368396-5f63-44a8-95a3-447881cfe6b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Nat'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "steps[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "76ea69c2-5784-49ac-a29c-4ed2772b7345",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'▁Nat'"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "892dda9c-fe56-4fb9-9717-b6ff443c4532",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dist_fn(steps[1], tokens[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b2e65141-ad40-44ed-aa31-0df2253ee7ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(45, 40)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(steps), len(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d590d9ee-7558-4f62-9ca3-67ef640bd1d3",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(' step', 1.0),\n",
       "  ('!', 0.0),\n",
       "  (' stepped', 0.0),\n",
       "  (' stepping', 0.0),\n",
       "  (' step', 1.0),\n",
       "  (' steps', 0.0)],\n",
       " [('\\n', 0.9878),\n",
       "  ('.', 0.0067),\n",
       "  ('\\n', 0.9878),\n",
       "  (' ', 0.0011),\n",
       "  (':', 0.0027),\n",
       "  ('\\n\\n', 0.0003)],\n",
       " [('W', 0.3172),\n",
       "  ('We', 0.0438),\n",
       "  ('If', 0.1324),\n",
       "  ('She', 0.0625),\n",
       "  ('W', 0.3172),\n",
       "  ('Since', 0.0693)],\n",
       " [('eng', 0.9903),\n",
       "  ('eng', 0.9903),\n",
       "  ('en', 0.0034),\n",
       "  ('end', 0.0035),\n",
       "  ('ong', 0.0004),\n",
       "  ('ang', 0.0013)],\n",
       " [(' earns', 0.7358),\n",
       "  (' makes', 0.0283),\n",
       "  (' earns', 0.7358),\n",
       "  (' gets', 0.0301),\n",
       "  (' is', 0.0231),\n",
       "  (' earned', 0.0544)],\n",
       " [(' $', 0.8506),\n",
       "  (' 12', 0.1252),\n",
       "  (' money', 0.0031),\n",
       "  (' a', 0.0031),\n",
       "  (' 1', 0.0033),\n",
       "  (' $', 0.8506)],\n",
       " [('12', 0.9898),\n",
       "  ('12', 0.9898),\n",
       "  (' 12', 0.0016),\n",
       "  ('0', 0.0007),\n",
       "  ('1', 0.0053),\n",
       "  ('24', 0.0004)],\n",
       " [(' for', 0.1161),\n",
       "  ('/', 0.0625),\n",
       "  (' an', 0.4909),\n",
       "  (' for', 0.1161),\n",
       "  (' per', 0.2627),\n",
       "  (' every', 0.0137)],\n",
       " [(' every', 0.3063),\n",
       "  (' each', 0.2902),\n",
       "  (' an', 0.1321),\n",
       "  (' 1', 0.0732),\n",
       "  (' babys', 0.0787),\n",
       "  (' every', 0.3063)],\n",
       " [(' hour', 0.8344),\n",
       "  (' full', 0.0238),\n",
       "  (' hour', 0.8344),\n",
       "  (' 60', 0.0744),\n",
       "  (' 1', 0.0349),\n",
       "  (' one', 0.0137)],\n",
       " [(' of', 0.3404),\n",
       "  (',', 0.0729),\n",
       "  ('.', 0.1042),\n",
       "  (' of', 0.3404),\n",
       "  (' that', 0.0767),\n",
       "  (' she', 0.3329)],\n",
       " [(' babys', 0.9566),\n",
       "  (' babys', 0.9566),\n",
       "  (' baby', 0.0028),\n",
       "  (' sitting', 0.001),\n",
       "  (' her', 0.0011),\n",
       "  (' work', 0.0344)],\n",
       " [('itting', 0.9981),\n",
       "  ('iting', 0.0013),\n",
       "  ('itter', 0.0001),\n",
       "  ('itting', 0.9981),\n",
       "  ('it', 0.0002),\n",
       "  ('itt', 0.0002)],\n",
       " [('.', 0.6357),\n",
       "  (',', 0.1517),\n",
       "  ('.', 0.6357),\n",
       "  ('\\n', 0.0545),\n",
       "  (' that', 0.0259),\n",
       "  (' she', 0.0983)],\n",
       " [('\\n', 0.8175),\n",
       "  (' She', 0.0223),\n",
       "  (' So', 0.0265),\n",
       "  ('\\n', 0.8175),\n",
       "  (' ', 0.0203),\n",
       "  (' Since', 0.0209)],\n",
       " [('She', 0.2205),\n",
       "  ('If', 0.1238),\n",
       "  ('Yesterday', 0.0847),\n",
       "  ('She', 0.2205),\n",
       "  ('W', 0.073),\n",
       "  ('Since', 0.0929)],\n",
       " [(' babys', 0.3703),\n",
       "  (' worked', 0.0498),\n",
       "  (' babys', 0.3703),\n",
       "  (' just', 0.0872),\n",
       "  (' earned', 0.0952),\n",
       "  (' did', 0.2147)],\n",
       " [('at', 0.921),\n",
       "  ('itted', 0.0182),\n",
       "  ('itting', 0.0007),\n",
       "  ('at', 0.921),\n",
       "  ('it', 0.0026),\n",
       "  ('its', 0.0552)],\n",
       " [(' for', 0.9274),\n",
       "  (' 50', 0.0506),\n",
       "  (' a', 0.0028),\n",
       "  (' for', 0.9274),\n",
       "  (' only', 0.0024),\n",
       "  (' yesterday', 0.006)],\n",
       " [(' 50', 0.8729),\n",
       "  (' 50', 0.8729),\n",
       "  (' a', 0.0343),\n",
       "  (' 1', 0.0204),\n",
       "  (' 0', 0.009),\n",
       "  (' only', 0.0174)],\n",
       " [(' minutes', 0.9476),\n",
       "  ('/', 0.0227),\n",
       "  (' mins', 0.005),\n",
       "  (' minutes', 0.9476),\n",
       "  (' out', 0.0047),\n",
       "  (' min', 0.0069)],\n",
       " [('.', 0.3746),\n",
       "  (',', 0.2816),\n",
       "  ('.', 0.3746),\n",
       "  (' and', 0.0241),\n",
       "  (' which', 0.0437),\n",
       "  (' yesterday', 0.1726)],\n",
       " [('\\n', 0.8499),\n",
       "  (' There', 0.019),\n",
       "  ('\\n', 0.8499),\n",
       "  (' 50', 0.018),\n",
       "  (' ', 0.0161),\n",
       "  (' Since', 0.016)],\n",
       " [('50', 0.1482),\n",
       "  ('50', 0.1482),\n",
       "  ('If', 0.0569),\n",
       "  ('There', 0.112),\n",
       "  ('To', 0.0641),\n",
       "  ('Since', 0.0784)],\n",
       " [(' minutes', 0.8502),\n",
       "  (' /', 0.015),\n",
       "  ('/', 0.0656),\n",
       "  (' minutes', 0.8502),\n",
       "  (' is', 0.0133),\n",
       "  (' min', 0.0294)],\n",
       " [(' is', 0.6015),\n",
       "  (' /', 0.0437),\n",
       "  (' of', 0.021),\n",
       "  (' is', 0.6015),\n",
       "  (' can', 0.0196),\n",
       "  (' =', 0.1733)],\n",
       " [(' less', 0.1926),\n",
       "  (' less', 0.1926),\n",
       "  (' 50', 0.1168),\n",
       "  (' the', 0.0736),\n",
       "  (' 1', 0.1932),\n",
       "  (' equal', 0.0832)],\n",
       " [(' than', 0.9922),\n",
       "  (' an', 0.0002),\n",
       "  (' that', 0.0019),\n",
       "  (' than', 0.9922),\n",
       "  (' time', 0.0017),\n",
       "  (' then', 0.0026)],\n",
       " [(' an', 0.391),\n",
       "  (' a', 0.0082),\n",
       "  (' an', 0.391),\n",
       "  (' 60', 0.1204),\n",
       "  (' 1', 0.2775),\n",
       "  (' one', 0.1906)],\n",
       " [(' hour', 0.9975),\n",
       "  (' full', 0.0001),\n",
       "  (' hour', 0.9975),\n",
       "  (' entire', 0.0015),\n",
       "  (' hours', 0.0002),\n",
       "  (' 1', 0.0002)],\n",
       " [(',', 0.5736),\n",
       "  (',', 0.5736),\n",
       "  ('.', 0.2524),\n",
       "  (' and', 0.0148),\n",
       "  (' (', 0.0202),\n",
       "  (' so', 0.0807)],\n",
       " [(' so', 0.7408),\n",
       "  (' and', 0.0452),\n",
       "  (' meaning', 0.0181),\n",
       "  (' but', 0.1071),\n",
       "  (' so', 0.7408),\n",
       "  (' which', 0.0251)],\n",
       " [(' she', 0.519),\n",
       "  (' let', 0.0205),\n",
       "  (' it', 0.0666),\n",
       "  (' we', 0.1232),\n",
       "  (' W', 0.0929),\n",
       "  (' she', 0.519)],\n",
       " [(' earned', 0.3897),\n",
       "  (' didn', 0.0871),\n",
       "  (' will', 0.0619),\n",
       "  (' would', 0.0401),\n",
       "  (' earned', 0.3897),\n",
       "  (' did', 0.1004)],\n",
       " [(' less', 0.8115),\n",
       "  (' less', 0.8115),\n",
       "  (' 50', 0.0081),\n",
       "  (' a', 0.0153),\n",
       "  (' only', 0.0066),\n",
       "  (' $', 0.1121)],\n",
       " [(' than', 0.978),\n",
       "  (' money', 0.012),\n",
       "  (' that', 0.0022),\n",
       "  (' than', 0.978),\n",
       "  (' $', 0.0022),\n",
       "  (' then', 0.0014)],\n",
       " [(' $', 0.9346),\n",
       "  (' 12', 0.022),\n",
       "  (' the', 0.0106),\n",
       "  (' an', 0.0083),\n",
       "  (' one', 0.0065),\n",
       "  (' $', 0.9346)],\n",
       " [('12', 0.9992),\n",
       "  ('120', 0.0001),\n",
       "  ('12', 0.9992),\n",
       "  (' 12', 0.0003),\n",
       "  ('1', 0.0001),\n",
       "  ('10', 0.0001)],\n",
       " [('.', 0.7548),\n",
       "  (',', 0.0196),\n",
       "  ('.', 0.7548),\n",
       "  ('\\n', 0.0656),\n",
       "  (' in', 0.0057),\n",
       "  (' for', 0.1187)],\n",
       " [('\\n', 0.9791),\n",
       "  (' She', 0.0015),\n",
       "  ('\\n', 0.9791),\n",
       "  (' ', 0.0058),\n",
       "  (' We', 0.0012),\n",
       "  (' But', 0.0013)],\n",
       " [('She', 0.1063),\n",
       "  ('50', 0.1235),\n",
       "  ('We', 0.0914),\n",
       "  ('If', 0.0686),\n",
       "  ('To', 0.1049),\n",
       "  ('She', 0.1063)],\n",
       " [(' earned', 0.6041),\n",
       "  (' earns', 0.0552),\n",
       "  (' can', 0.0388),\n",
       "  (' babys', 0.0418),\n",
       "  (' earned', 0.6041),\n",
       "  (' did', 0.045)],\n",
       " [(' $', 0.3685),\n",
       "  (' 12', 0.1196),\n",
       "  (' less', 0.0552),\n",
       "  (' 50', 0.2338),\n",
       "  (' 1', 0.0818),\n",
       "  (' $', 0.3685)],\n",
       " [('12', 0.9741),\n",
       "  ('12', 0.9741),\n",
       "  ('0', 0.0034),\n",
       "  ('1', 0.0091),\n",
       "  ('6', 0.0034),\n",
       "  ('10', 0.0036)],\n",
       " [('/', 0.2705),\n",
       "  (' /', 0.0543),\n",
       "  ('/', 0.2705),\n",
       "  (' *', 0.064),\n",
       "  (' for', 0.3772),\n",
       "  (' per', 0.0815)],\n",
       " [('60', 0.9634),\n",
       "  ('hr', 0.0009),\n",
       "  ('1', 0.0142),\n",
       "  ('60', 0.9634),\n",
       "  (' 60', 0.0045),\n",
       "  ('hour', 0.0121)],\n",
       " [(' *', 0.3366),\n",
       "  (' *', 0.3366),\n",
       "  (' x', 0.051),\n",
       "  (' minutes', 0.3076),\n",
       "  (' =', 0.1227),\n",
       "  ('*', 0.0682)],\n",
       " [(' 50', 0.9881),\n",
       "  ('50', 0.0086),\n",
       "  (' 50', 0.9881),\n",
       "  (' minutes', 0.0005),\n",
       "  (' (', 0.0004),\n",
       "  (' $', 0.0004)],\n",
       " [(' =', 0.9169),\n",
       "  (',', 0.0168),\n",
       "  (' minutes', 0.0094),\n",
       "  ('=', 0.0102),\n",
       "  (' dollars', 0.0071),\n",
       "  (' =', 0.9169)],\n",
       " [(' $', 0.8299),\n",
       "  (' 12', 0.0701),\n",
       "  (' 50', 0.0154),\n",
       "  (' (', 0.0045),\n",
       "  (' $', 0.8299),\n",
       "  (' 10', 0.0511)],\n",
       " [('10', 0.9744),\n",
       "  ('12', 0.0188),\n",
       "  ('1', 0.0007),\n",
       "  ('6', 0.0027),\n",
       "  (' 10', 0.001),\n",
       "  ('10', 0.9744)],\n",
       " [('.', 0.3116),\n",
       "  (',', 0.016),\n",
       "  ('.', 0.3116),\n",
       "  ('\\n', 0.492),\n",
       "  (' in', 0.0116),\n",
       "  (' for', 0.0971)],\n",
       " [('\\n', 0.9534),\n",
       "  (' So', 0.0034),\n",
       "  ('\\n', 0.9534),\n",
       "  (' ', 0.0108),\n",
       "  ('00', 0.0053),\n",
       "  (' This', 0.0054)],\n",
       " [('The', 0.9161),\n",
       "  ('\\n', 0.0052),\n",
       "  ('So', 0.0101),\n",
       "  ('Therefore', 0.0083),\n",
       "  ('She', 0.0107),\n",
       "  ('The', 0.9161)],\n",
       " [(' answer', 0.9943),\n",
       "  (' amount', 0.0006),\n",
       "  (' total', 0.0006),\n",
       "  (' an', 0.0003),\n",
       "  (' answer', 0.9943),\n",
       "  (' correct', 0.0005)],\n",
       " [(' is', 0.9989),\n",
       "  (' in', 0.0001),\n",
       "  (' is', 0.9989),\n",
       "  (' was', 0.0001),\n",
       "  (' $', 0.0001),\n",
       "  (' 10', 0.0003)],\n",
       " [(' 10', 0.9625),\n",
       "  (' 12', 0.0024),\n",
       "  (' ten', 0.0037),\n",
       "  (' 1', 0.0004),\n",
       "  (' $', 0.0279),\n",
       "  (' 10', 0.9625)],\n",
       " [('\\n', 0.9511),\n",
       "  ('.', 0.0137),\n",
       "  ('\\n', 0.9511),\n",
       "  (' ', 0.0024),\n",
       "  ('<|endoftext|>', 0.0092),\n",
       "  ('\\n\\n', 0.019)],\n",
       " [('\\n', 0.9722),\n",
       "  ('``', 0.0023),\n",
       "  ('\\n', 0.9722),\n",
       "  (' ', 0.0068),\n",
       "  ('\"\"\"', 0.0055),\n",
       "  (\"''\", 0.0029)]]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_per_step_probs[1][0][:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f16e8361-4565-4da7-abbe-11bd0da03a59",
   "metadata": {},
   "outputs": [],
   "source": [
    "matches, matrix_, mappings_series_1, mappings_series_2, matrix = dtw(steps, tokens, norm_func=dist_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "5f2b1722-48df-49e9-93ad-65f6cfa996d2",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'▁We' | 'W' \n",
      "'ng' | 'eng' \n",
      "'▁earn' | ' earns' \n",
      "'s' | ' $' \n",
      "'▁$12' | '12' \n",
      "'▁for' | ' for' \n",
      "'▁every' | ' every' \n",
      "'▁hour' | ' hour' \n",
      "'▁of' | ' of' \n",
      "'▁baby' | ' babys' \n",
      "'s' | ' babys' \n",
      "'i' | ' babys' \n",
      "'tting' | 'itting' \n",
      "'.' | '.' '\\n' \n",
      "'▁She' | 'She' \n",
      "'▁baby' | ' babys' \n",
      "'s' | 'at' \n",
      "'at' | 'at' \n",
      "'▁for' | ' for' \n",
      "'▁50' | ' 50' \n",
      "'▁minutes' | ' minutes' \n",
      "'.' | '.' '\\n' \n",
      "'▁50' | '50' \n",
      "'▁minutes' | ' minutes' \n",
      "'▁is' | ' is' \n",
      "'▁less' | ' less' \n",
      "'▁than' | ' than' \n",
      "'▁an' | ' an' \n",
      "'▁hour' | ' hour' \n",
      "',' | ',' \n",
      "'▁so' | ' so' \n",
      "'▁she' | ' she' \n",
      "'▁earned' | ' earned' \n",
      "'▁less' | ' less' \n",
      "'▁than' | ' than' \n",
      "'▁$12' | ' $' '12' \n",
      "'.' | '.' '\\n' \n",
      "'▁She' | 'She' \n",
      "'▁earned' | ' earned' \n",
      "'▁$12' | ' $' '12' \n",
      "'/' | '/' \n",
      "'60' | '60' \n",
      "'▁*' | ' *' \n",
      "'▁50' | ' 50' \n",
      "'▁=' | ' =' ' $' \n",
      "'▁$10' | '10' \n",
      "'.' | '.' '\\n' \n",
      "'▁The' | 'The' \n",
      "'▁answer' | ' answer' \n",
      "'▁is' | ' is' \n",
      "'▁10' | ' 10' \n"
     ]
    }
   ],
   "source": [
    "# For each FlanT5 token, print the Codex steps that the alignment mapped onto it.\n",
    "for tok_idx, step_indices in enumerate(mappings_series_2):\n",
    "    aligned_steps = ''.join(f'{steps[k]!r} ' for k in step_indices)\n",
    "    print(f'{tokens[tok_idx]!r} | {aligned_steps}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8dbc05a3-4a2b-4722-bdee-4ff789d6fdeb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  0.,   3.,   7., ..., 112., 114., 116.],\n",
       "       [  3.,   0.,   4., ..., 112., 115., 117.],\n",
       "       [  7.,   4.,   0., ..., 111., 114., 118.],\n",
       "       ...,\n",
       "       [122., 123., 124., ...,  15.,   6.,   4.],\n",
       "       [124., 125., 127., ...,  21.,   8.,   6.],\n",
       "       [126., 127., 129., ...,  24.,  10.,   8.]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0de1f22a-1b26-4b88-842b-9fa2fbb9a8e3",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(' step', 1.0),\n",
       "  ('!', 0.0),\n",
       "  (' stepped', 0.0),\n",
       "  (' stepping', 0.0),\n",
       "  (' step', 1.0),\n",
       "  (' steps', 0.0)],\n",
       " [('\\n', 0.9862),\n",
       "  ('!', 0.0002),\n",
       "  ('.', 0.009),\n",
       "  ('\\n', 0.9862),\n",
       "  (' ', 0.001),\n",
       "  (':', 0.0025)],\n",
       " [('If', 0.1009),\n",
       "  ('If', 0.1009),\n",
       "  ('She', 0.0416),\n",
       "  ('The', 0.0394),\n",
       "  ('Nat', 0.3286),\n",
       "  ('In', 0.2685)],\n",
       " [(' Nat', 0.6018),\n",
       "  (' Nat', 0.6018),\n",
       "  (' in', 0.0261),\n",
       "  (' we', 0.0125),\n",
       "  (' 48', 0.0335),\n",
       "  (' she', 0.2588)],\n",
       " [('alia', 0.9968),\n",
       "  ('ale', 0.0001),\n",
       "  ('ilia', 0.0001),\n",
       "  ('lia', 0.0002),\n",
       "  (' sold', 0.0013),\n",
       "  ('alia', 0.9968)],\n",
       " [(' sold', 0.9573),\n",
       "  (' sells', 0.0217),\n",
       "  (' sold', 0.9573),\n",
       "  (' is', 0.0022),\n",
       "  (\"'s\", 0.0028),\n",
       "  (' had', 0.0024)],\n",
       " [(' 48', 0.4455),\n",
       "  (' clips', 0.3353),\n",
       "  (' half', 0.1318),\n",
       "  (' to', 0.0351),\n",
       "  (' 48', 0.4455),\n",
       "  (' twice', 0.0058)],\n",
       " [(' clips', 0.9248),\n",
       "  (' clips', 0.9248),\n",
       "  (' friends', 0.0062),\n",
       "  (' to', 0.0058),\n",
       "  (' of', 0.0109),\n",
       "  (' in', 0.0345)],\n",
       " [(' in', 0.7597),\n",
       "  (',', 0.0206),\n",
       "  (' to', 0.2018),\n",
       "  (' in', 0.7597),\n",
       "  (' and', 0.0031),\n",
       "  (' then', 0.0028)],\n",
       " [(' April', 0.9759),\n",
       "  (' May', 0.0039),\n",
       "  (' total', 0.0025),\n",
       "  (' the', 0.0067),\n",
       "  (' April', 0.9759),\n",
       "  (' apr', 0.0026)],\n",
       " [(',', 0.7709),\n",
       "  (',', 0.7709),\n",
       "  (' to', 0.0033),\n",
       "  (' and', 0.1443),\n",
       "  (' she', 0.0074),\n",
       "  (' then', 0.0602)],\n",
       " [(' she', 0.1407),\n",
       "  (' in', 0.0182),\n",
       "  (' and', 0.2126),\n",
       "  (' that', 0.0235),\n",
       "  (' she', 0.1407),\n",
       "  (' then', 0.5304)],\n",
       " [(' sold', 0.815),\n",
       "  (' must', 0.0433),\n",
       "  (' sold', 0.815),\n",
       "  (' will', 0.0162),\n",
       "  (' would', 0.0389),\n",
       "  (' then', 0.0111)],\n",
       " [(' 48', 0.556),\n",
       "  (' 24', 0.1033),\n",
       "  (' half', 0.1671),\n",
       "  (' 1', 0.0266),\n",
       "  (' (', 0.0178),\n",
       "  (' 48', 0.556)],\n",
       " [('/', 0.7438),\n",
       "  (' /', 0.0471),\n",
       "  ('/', 0.7438),\n",
       "  (' *', 0.0386),\n",
       "  (' x', 0.0391),\n",
       "  ('*', 0.0442)],\n",
       " [('2', 0.9811),\n",
       "  ('12', 0.0016),\n",
       "  ('1', 0.0023),\n",
       "  ('2', 0.9811),\n",
       "  ('4', 0.0085),\n",
       "  (' 2', 0.0019)],\n",
       " [(' =', 0.8517),\n",
       "  (' clips', 0.0251),\n",
       "  ('=', 0.099),\n",
       "  (' in', 0.0068),\n",
       "  (' or', 0.0085),\n",
       "  (' =', 0.8517)],\n",
       " [(' 24', 0.9896),\n",
       "  (' 12', 0.0005),\n",
       "  ('24', 0.007),\n",
       "  (' 24', 0.9896),\n",
       "  (' ', 0.0002),\n",
       "  (' 2', 0.0004)],\n",
       " [(' clips', 0.8775),\n",
       "  (' less', 0.0026),\n",
       "  (' clips', 0.8775),\n",
       "  (' in', 0.106),\n",
       "  (' more', 0.0026),\n",
       "  (' fewer', 0.0034)],\n",
       " [(' in', 0.9914),\n",
       "  (' during', 0.0008),\n",
       "  (' less', 0.001),\n",
       "  (' the', 0.0011),\n",
       "  (' to', 0.0013),\n",
       "  (' in', 0.9914)],\n",
       " [(' May', 0.9877),\n",
       "  (' May', 0.9877),\n",
       "  (' the', 0.0026),\n",
       "  (' March', 0.0006),\n",
       "  (' April', 0.0012),\n",
       "  (' may', 0.0054)],\n",
       " [('.', 0.8685),\n",
       "  (',', 0.0226),\n",
       "  (' since', 0.0052),\n",
       "  ('.', 0.8685),\n",
       "  ('\\n', 0.0827),\n",
       "  (' because', 0.0075)],\n",
       " [('\\n', 0.9843),\n",
       "  (' She', 0.0007),\n",
       "  (' So', 0.0015),\n",
       "  ('\\n', 0.9843),\n",
       "  (' ', 0.0052),\n",
       "  (' In', 0.001)],\n",
       " [('In', 0.2022),\n",
       "  ('So', 0.1476),\n",
       "  ('Therefore', 0.0765),\n",
       "  ('She', 0.0584),\n",
       "  ('The', 0.1467),\n",
       "  ('In', 0.2022)],\n",
       " [(' total', 0.8519),\n",
       "  (' both', 0.0057),\n",
       "  (' May', 0.0097),\n",
       "  (' total', 0.8519),\n",
       "  (' the', 0.0058),\n",
       "  (' April', 0.1015)],\n",
       " [(',', 0.8402),\n",
       "  (',', 0.8402),\n",
       "  (' Nat', 0.0197),\n",
       "  (' in', 0.0042),\n",
       "  (' she', 0.114),\n",
       "  (' then', 0.0114)],\n",
       " [(' Nat', 0.2935),\n",
       "  (' Nat', 0.2935),\n",
       "  (' in', 0.0108),\n",
       "  (' 48', 0.0141),\n",
       "  (' she', 0.5699),\n",
       "  (' then', 0.0726)],\n",
       " [('alia', 0.9996),\n",
       "  ('lia', 0.0001),\n",
       "  ('elia', 0.0002),\n",
       "  ('alias', 0.0),\n",
       "  ('ala', 0.0),\n",
       "  ('alia', 0.9996)],\n",
       " [(' sold', 0.9757),\n",
       "  (' sells', 0.0026),\n",
       "  (' sold', 0.9757),\n",
       "  (' has', 0.0071),\n",
       "  (' will', 0.0025),\n",
       "  (' then', 0.0048)],\n",
       " [(' 48', 0.9477),\n",
       "  (' clips', 0.0153),\n",
       "  (' 24', 0.0158),\n",
       "  (' a', 0.0032),\n",
       "  (' 48', 0.9477),\n",
       "  (' 72', 0.0048)],\n",
       " [(' +', 0.6164),\n",
       "  ('+', 0.3547),\n",
       "  (' +', 0.6164),\n",
       "  (' clips', 0.0236),\n",
       "  (' in', 0.0019),\n",
       "  (' April', 0.0005)],\n",
       " [(' 24', 0.992),\n",
       "  (' 12', 0.0002),\n",
       "  ('24', 0.0065),\n",
       "  (' 24', 0.992),\n",
       "  (' 2', 0.0002),\n",
       "  (' 48', 0.0002)],\n",
       " [(' =', 0.9853),\n",
       "  (',', 0.0002),\n",
       "  (' clips', 0.0072),\n",
       "  (' ', 0.0002),\n",
       "  ('=', 0.0066),\n",
       "  (' =', 0.9853)],\n",
       " [(' 72', 0.9951),\n",
       "  (' ', 0.0002),\n",
       "  (' 70', 0.0003),\n",
       "  ('72', 0.0026),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9951)],\n",
       " [(' clips', 0.9888),\n",
       "  (' clip', 0.001),\n",
       "  ('.', 0.0012),\n",
       "  (' clips', 0.9888),\n",
       "  ('\\n', 0.003),\n",
       "  (' total', 0.0018)],\n",
       " [(' in', 0.359),\n",
       "  ('.', 0.3778),\n",
       "  (' altogether', 0.0955),\n",
       "  ('\\n', 0.1112),\n",
       "  (' in', 0.359),\n",
       "  (' over', 0.0129)],\n",
       " [(' April', 0.8876),\n",
       "  (' both', 0.0564),\n",
       "  (' May', 0.0049),\n",
       "  (' total', 0.0139),\n",
       "  (' the', 0.0216),\n",
       "  (' April', 0.8876)],\n",
       " [(' and', 0.9928),\n",
       "  (',', 0.0004),\n",
       "  (' &', 0.0024),\n",
       "  (' +', 0.0024),\n",
       "  (' and', 0.9928),\n",
       "  (' plus', 0.0005)],\n",
       " [(' May', 0.9962),\n",
       "  (' May', 0.9962),\n",
       "  (' June', 0.0001),\n",
       "  (' in', 0.0023),\n",
       "  (' may', 0.0005),\n",
       "  (' then', 0.0001)],\n",
       " [('.', 0.8727),\n",
       "  ('.', 0.8727),\n",
       "  (' altogether', 0.004),\n",
       "  (' together', 0.0127),\n",
       "  ('\\n', 0.0838),\n",
       "  (' combined', 0.0231)],\n",
       " [('\\n', 0.995),\n",
       "  ('\\n', 0.995),\n",
       "  (' ', 0.0035),\n",
       "  (' The', 0.0005),\n",
       "  ('  ', 0.0002),\n",
       "  ('\\n\\n', 0.0002)],\n",
       " [('The', 0.9803),\n",
       "  ('\\n', 0.0085),\n",
       "  ('So', 0.001),\n",
       "  ('Therefore', 0.0009),\n",
       "  ('Answer', 0.0008),\n",
       "  ('The', 0.9803)],\n",
       " [(' answer', 0.9972),\n",
       "  (' number', 0.0003),\n",
       "  (' Answer', 0.0006),\n",
       "  (' total', 0.0005),\n",
       "  (' an', 0.0002),\n",
       "  (' answer', 0.9972)],\n",
       " [(' is', 0.999),\n",
       "  (':', 0.0001),\n",
       "  (' to', 0.0001),\n",
       "  (' in', 0.0001),\n",
       "  (' is', 0.999),\n",
       "  (' 72', 0.0004)],\n",
       " [(' 72', 0.9981),\n",
       "  (' 24', 0.0001),\n",
       "  (' ', 0.0004),\n",
       "  (' 48', 0.0001),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9981)],\n",
       " [('\\n', 0.9577),\n",
       "  ('.', 0.0077),\n",
       "  ('\\n', 0.9577),\n",
       "  (' ', 0.0018),\n",
       "  ('<|endoftext|>', 0.0083),\n",
       "  ('\\n\\n', 0.0214)],\n",
       " [('\\n', 0.973),\n",
       "  ('``', 0.0018),\n",
       "  ('\\n', 0.973),\n",
       "  (' ', 0.0083),\n",
       "  ('\"\"\"', 0.0046),\n",
       "  (\"''\", 0.0028)]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_per_step_probs[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "35d0b79f-b04e-4b48-979e-e3d1acedee26",
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_step_probs(qi, aj, vocab, pred, per_step_prob):\n",
    "    \"\"\"Re-express Codex per-step top-k probabilities on the FlanT5 tokenization of `pred`.\n",
    "\n",
    "    Args:\n",
    "        qi, aj: question / answer indices; only used in the debug message.\n",
    "        vocab: token container (the notebook passes `tokenizer.get_vocab()`); Codex\n",
    "            alternatives that are not in it are dropped.\n",
    "        pred: Codex prediction text carrying a 'Model output i: ' prefix.\n",
    "        per_step_prob: per decoding step, a list of (token, prob) pairs whose first\n",
    "            pair is read as the Codex token of that step.\n",
    "\n",
    "    Returns:\n",
    "        (1, probs) with one {flan_token: prob} mapping per aligned FlanT5 token, or\n",
    "        (-1, None) when the steps are not wrapped in the expected ' step' / newline\n",
    "        header and two-newline footer.\n",
    "\n",
    "    Uses the notebook-level `tokenizer` and `dist_fn`.\n",
    "    \"\"\"\n",
    "    # Strip only the leading \"Model output i: \" prefix; any later ': ' belongs to the answer.\n",
    "    pred = pred.partition(': ')[2].strip()\n",
    "    # [:-1] drops the final token (presumably the EOS the tokenizer appends -- confirm)\n",
    "    flan_tokens = tokenizer.convert_ids_to_tokens(tokenizer(pred)['input_ids'])[:-1]\n",
    "\n",
    "    # Two header steps plus two footer steps must exist before they can be checked.\n",
    "    if len(per_step_prob) < 4:\n",
    "        return -1, None\n",
    "    framed = (\n",
    "        per_step_prob[0][0][0] == ' step'\n",
    "        and per_step_prob[1][0][0] == '\\n'\n",
    "        and per_step_prob[-2][0][0] in ('\\n', '\\n\\n')\n",
    "        and per_step_prob[-1][0][0] == '\\n'\n",
    "    )\n",
    "    if not framed:\n",
    "        return -1, None\n",
    "\n",
    "    per_step_prob = per_step_prob[2:-2]\n",
    "    codex_tokens = [step[0][0] for step in per_step_prob]\n",
    "\n",
    "    # flan2codex[i] is used as the list of Codex step indices aligned to FlanT5 token i\n",
    "    _, _, _, flan2codex, _ = dtw(codex_tokens, flan_tokens, norm_func=dist_fn)\n",
    "    transformed_step_probs = []\n",
    "    for i, codex_idx in enumerate(flan2codex):\n",
    "        flan_token = flan_tokens[i]\n",
    "        # Fallback (no one-to-one match): put all probability on the FlanT5 token itself.\n",
    "        probs = {flan_token: 1.0}\n",
    "        if len(codex_idx) == 1:\n",
    "            j = codex_idx[0]\n",
    "            # Codex tokens carry a leading ' ' where the FlanT5 tokens carry '▁'\n",
    "            if flan_token == codex_tokens[j].replace(' ', '▁'):\n",
    "                # Same token on both sides: carry over the Codex alternatives, first occurrence wins\n",
    "                probs = OrderedDict()\n",
    "                for t, p in per_step_prob[j]:\n",
    "                    flan_t = t.replace(' ', '▁')\n",
    "                    if flan_t in vocab and flan_t not in probs:\n",
    "                        probs[flan_t] = p\n",
    "        transformed_step_probs.append(probs)\n",
    "    # Compare lengths instead of the loop index so an empty alignment cannot raise NameError.\n",
    "    if len(transformed_step_probs) != len(flan_tokens):\n",
    "        print('q %d a %d debug 5' % (qi, aj))\n",
    "    return 1, transformed_step_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "730444d0-5c2e-4b19-904d-747472c2d76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = tokenizer.get_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "530ee487-98a1-42f2-aa18-9d4ef804049f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# transform_step_probs returns (status, probs); keep the per-token probabilities for the cells below\n",
    "ret_code, transformed_step_probs = transform_step_probs(0, 0, vocab, codex_predictions[0][0], codex_per_step_probs[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "17882f2c-68d7-4893-895f-9b98b04a00f9",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(' step', 1.0),\n",
       "  ('!', 0.0),\n",
       "  (' stepped', 0.0),\n",
       "  (' stepping', 0.0),\n",
       "  (' step', 1.0),\n",
       "  (' steps', 0.0)],\n",
       " [('\\n', 0.9862),\n",
       "  ('!', 0.0002),\n",
       "  ('.', 0.009),\n",
       "  ('\\n', 0.9862),\n",
       "  (' ', 0.001),\n",
       "  (':', 0.0025)],\n",
       " [('If', 0.1009),\n",
       "  ('If', 0.1009),\n",
       "  ('She', 0.0416),\n",
       "  ('The', 0.0394),\n",
       "  ('Nat', 0.3286),\n",
       "  ('In', 0.2685)],\n",
       " [(' Nat', 0.6018),\n",
       "  (' Nat', 0.6018),\n",
       "  (' in', 0.0261),\n",
       "  (' we', 0.0125),\n",
       "  (' 48', 0.0335),\n",
       "  (' she', 0.2588)],\n",
       " [('alia', 0.9968),\n",
       "  ('ale', 0.0001),\n",
       "  ('ilia', 0.0001),\n",
       "  ('lia', 0.0002),\n",
       "  (' sold', 0.0013),\n",
       "  ('alia', 0.9968)],\n",
       " [(' sold', 0.9573),\n",
       "  (' sells', 0.0217),\n",
       "  (' sold', 0.9573),\n",
       "  (' is', 0.0022),\n",
       "  (\"'s\", 0.0028),\n",
       "  (' had', 0.0024)],\n",
       " [(' 48', 0.4455),\n",
       "  (' clips', 0.3353),\n",
       "  (' half', 0.1318),\n",
       "  (' to', 0.0351),\n",
       "  (' 48', 0.4455),\n",
       "  (' twice', 0.0058)],\n",
       " [(' clips', 0.9248),\n",
       "  (' clips', 0.9248),\n",
       "  (' friends', 0.0062),\n",
       "  (' to', 0.0058),\n",
       "  (' of', 0.0109),\n",
       "  (' in', 0.0345)],\n",
       " [(' in', 0.7597),\n",
       "  (',', 0.0206),\n",
       "  (' to', 0.2018),\n",
       "  (' in', 0.7597),\n",
       "  (' and', 0.0031),\n",
       "  (' then', 0.0028)],\n",
       " [(' April', 0.9759),\n",
       "  (' May', 0.0039),\n",
       "  (' total', 0.0025),\n",
       "  (' the', 0.0067),\n",
       "  (' April', 0.9759),\n",
       "  (' apr', 0.0026)],\n",
       " [(',', 0.7709),\n",
       "  (',', 0.7709),\n",
       "  (' to', 0.0033),\n",
       "  (' and', 0.1443),\n",
       "  (' she', 0.0074),\n",
       "  (' then', 0.0602)],\n",
       " [(' she', 0.1407),\n",
       "  (' in', 0.0182),\n",
       "  (' and', 0.2126),\n",
       "  (' that', 0.0235),\n",
       "  (' she', 0.1407),\n",
       "  (' then', 0.5304)],\n",
       " [(' sold', 0.815),\n",
       "  (' must', 0.0433),\n",
       "  (' sold', 0.815),\n",
       "  (' will', 0.0162),\n",
       "  (' would', 0.0389),\n",
       "  (' then', 0.0111)],\n",
       " [(' 48', 0.556),\n",
       "  (' 24', 0.1033),\n",
       "  (' half', 0.1671),\n",
       "  (' 1', 0.0266),\n",
       "  (' (', 0.0178),\n",
       "  (' 48', 0.556)],\n",
       " [('/', 0.7438),\n",
       "  (' /', 0.0471),\n",
       "  ('/', 0.7438),\n",
       "  (' *', 0.0386),\n",
       "  (' x', 0.0391),\n",
       "  ('*', 0.0442)],\n",
       " [('2', 0.9811),\n",
       "  ('12', 0.0016),\n",
       "  ('1', 0.0023),\n",
       "  ('2', 0.9811),\n",
       "  ('4', 0.0085),\n",
       "  (' 2', 0.0019)],\n",
       " [(' =', 0.8517),\n",
       "  (' clips', 0.0251),\n",
       "  ('=', 0.099),\n",
       "  (' in', 0.0068),\n",
       "  (' or', 0.0085),\n",
       "  (' =', 0.8517)],\n",
       " [(' 24', 0.9896),\n",
       "  (' 12', 0.0005),\n",
       "  ('24', 0.007),\n",
       "  (' 24', 0.9896),\n",
       "  (' ', 0.0002),\n",
       "  (' 2', 0.0004)],\n",
       " [(' clips', 0.8775),\n",
       "  (' less', 0.0026),\n",
       "  (' clips', 0.8775),\n",
       "  (' in', 0.106),\n",
       "  (' more', 0.0026),\n",
       "  (' fewer', 0.0034)],\n",
       " [(' in', 0.9914),\n",
       "  (' during', 0.0008),\n",
       "  (' less', 0.001),\n",
       "  (' the', 0.0011),\n",
       "  (' to', 0.0013),\n",
       "  (' in', 0.9914)],\n",
       " [(' May', 0.9877),\n",
       "  (' May', 0.9877),\n",
       "  (' the', 0.0026),\n",
       "  (' March', 0.0006),\n",
       "  (' April', 0.0012),\n",
       "  (' may', 0.0054)],\n",
       " [('.', 0.8685),\n",
       "  (',', 0.0226),\n",
       "  (' since', 0.0052),\n",
       "  ('.', 0.8685),\n",
       "  ('\\n', 0.0827),\n",
       "  (' because', 0.0075)],\n",
       " [('\\n', 0.9843),\n",
       "  (' She', 0.0007),\n",
       "  (' So', 0.0015),\n",
       "  ('\\n', 0.9843),\n",
       "  (' ', 0.0052),\n",
       "  (' In', 0.001)],\n",
       " [('In', 0.2022),\n",
       "  ('So', 0.1476),\n",
       "  ('Therefore', 0.0765),\n",
       "  ('She', 0.0584),\n",
       "  ('The', 0.1467),\n",
       "  ('In', 0.2022)],\n",
       " [(' total', 0.8519),\n",
       "  (' both', 0.0057),\n",
       "  (' May', 0.0097),\n",
       "  (' total', 0.8519),\n",
       "  (' the', 0.0058),\n",
       "  (' April', 0.1015)],\n",
       " [(',', 0.8402),\n",
       "  (',', 0.8402),\n",
       "  (' Nat', 0.0197),\n",
       "  (' in', 0.0042),\n",
       "  (' she', 0.114),\n",
       "  (' then', 0.0114)],\n",
       " [(' Nat', 0.2935),\n",
       "  (' Nat', 0.2935),\n",
       "  (' in', 0.0108),\n",
       "  (' 48', 0.0141),\n",
       "  (' she', 0.5699),\n",
       "  (' then', 0.0726)],\n",
       " [('alia', 0.9996),\n",
       "  ('lia', 0.0001),\n",
       "  ('elia', 0.0002),\n",
       "  ('alias', 0.0),\n",
       "  ('ala', 0.0),\n",
       "  ('alia', 0.9996)],\n",
       " [(' sold', 0.9757),\n",
       "  (' sells', 0.0026),\n",
       "  (' sold', 0.9757),\n",
       "  (' has', 0.0071),\n",
       "  (' will', 0.0025),\n",
       "  (' then', 0.0048)],\n",
       " [(' 48', 0.9477),\n",
       "  (' clips', 0.0153),\n",
       "  (' 24', 0.0158),\n",
       "  (' a', 0.0032),\n",
       "  (' 48', 0.9477),\n",
       "  (' 72', 0.0048)],\n",
       " [(' +', 0.6164),\n",
       "  ('+', 0.3547),\n",
       "  (' +', 0.6164),\n",
       "  (' clips', 0.0236),\n",
       "  (' in', 0.0019),\n",
       "  (' April', 0.0005)],\n",
       " [(' 24', 0.992),\n",
       "  (' 12', 0.0002),\n",
       "  ('24', 0.0065),\n",
       "  (' 24', 0.992),\n",
       "  (' 2', 0.0002),\n",
       "  (' 48', 0.0002)],\n",
       " [(' =', 0.9853),\n",
       "  (',', 0.0002),\n",
       "  (' clips', 0.0072),\n",
       "  (' ', 0.0002),\n",
       "  ('=', 0.0066),\n",
       "  (' =', 0.9853)],\n",
       " [(' 72', 0.9951),\n",
       "  (' ', 0.0002),\n",
       "  (' 70', 0.0003),\n",
       "  ('72', 0.0026),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9951)],\n",
       " [(' clips', 0.9888),\n",
       "  (' clip', 0.001),\n",
       "  ('.', 0.0012),\n",
       "  (' clips', 0.9888),\n",
       "  ('\\n', 0.003),\n",
       "  (' total', 0.0018)],\n",
       " [(' in', 0.359),\n",
       "  ('.', 0.3778),\n",
       "  (' altogether', 0.0955),\n",
       "  ('\\n', 0.1112),\n",
       "  (' in', 0.359),\n",
       "  (' over', 0.0129)],\n",
       " [(' April', 0.8876),\n",
       "  (' both', 0.0564),\n",
       "  (' May', 0.0049),\n",
       "  (' total', 0.0139),\n",
       "  (' the', 0.0216),\n",
       "  (' April', 0.8876)],\n",
       " [(' and', 0.9928),\n",
       "  (',', 0.0004),\n",
       "  (' &', 0.0024),\n",
       "  (' +', 0.0024),\n",
       "  (' and', 0.9928),\n",
       "  (' plus', 0.0005)],\n",
       " [(' May', 0.9962),\n",
       "  (' May', 0.9962),\n",
       "  (' June', 0.0001),\n",
       "  (' in', 0.0023),\n",
       "  (' may', 0.0005),\n",
       "  (' then', 0.0001)],\n",
       " [('.', 0.8727),\n",
       "  ('.', 0.8727),\n",
       "  (' altogether', 0.004),\n",
       "  (' together', 0.0127),\n",
       "  ('\\n', 0.0838),\n",
       "  (' combined', 0.0231)],\n",
       " [('\\n', 0.995),\n",
       "  ('\\n', 0.995),\n",
       "  (' ', 0.0035),\n",
       "  (' The', 0.0005),\n",
       "  ('  ', 0.0002),\n",
       "  ('\\n\\n', 0.0002)],\n",
       " [('The', 0.9803),\n",
       "  ('\\n', 0.0085),\n",
       "  ('So', 0.001),\n",
       "  ('Therefore', 0.0009),\n",
       "  ('Answer', 0.0008),\n",
       "  ('The', 0.9803)],\n",
       " [(' answer', 0.9972),\n",
       "  (' number', 0.0003),\n",
       "  (' Answer', 0.0006),\n",
       "  (' total', 0.0005),\n",
       "  (' an', 0.0002),\n",
       "  (' answer', 0.9972)],\n",
       " [(' is', 0.999),\n",
       "  (':', 0.0001),\n",
       "  (' to', 0.0001),\n",
       "  (' in', 0.0001),\n",
       "  (' is', 0.999),\n",
       "  (' 72', 0.0004)],\n",
       " [(' 72', 0.9981),\n",
       "  (' 24', 0.0001),\n",
       "  (' ', 0.0004),\n",
       "  (' 48', 0.0001),\n",
       "  (' 7', 0.0001),\n",
       "  (' 72', 0.9981)],\n",
       " [('\\n', 0.9577),\n",
       "  ('.', 0.0077),\n",
       "  ('\\n', 0.9577),\n",
       "  (' ', 0.0018),\n",
       "  ('<|endoftext|>', 0.0083),\n",
       "  ('\\n\\n', 0.0214)],\n",
       " [('\\n', 0.973),\n",
       "  ('``', 0.0018),\n",
       "  ('\\n', 0.973),\n",
       "  (' ', 0.0083),\n",
       "  ('\"\"\"', 0.0046),\n",
       "  (\"''\", 0.0028)]]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_per_step_probs[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5e94e71a-b540-435b-96cb-42453c697a5a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Model output 0: \\nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\\nThe answer is 72\\n'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_predictions[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "3e6b0c5f-0e89-48d2-b56a-18bdefe74b60",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'▁If': 1.0},\n",
       " OrderedDict([('▁Nat', 0.6018),\n",
       "              ('▁in', 0.0261),\n",
       "              ('▁we', 0.0125),\n",
       "              ('▁48', 0.0335),\n",
       "              ('▁she', 0.2588)]),\n",
       " OrderedDict([('alia', 0.9968), ('ilia', 0.0001), ('▁sold', 0.0013)]),\n",
       " OrderedDict([('▁sold', 0.9573), ('▁is', 0.0022), ('▁had', 0.0024)]),\n",
       " OrderedDict([('▁48', 0.4455),\n",
       "              ('▁clips', 0.3353),\n",
       "              ('▁half', 0.1318),\n",
       "              ('▁to', 0.0351),\n",
       "              ('▁twice', 0.0058)]),\n",
       " OrderedDict([('▁clips', 0.9248),\n",
       "              ('▁friends', 0.0062),\n",
       "              ('▁to', 0.0058),\n",
       "              ('▁of', 0.0109),\n",
       "              ('▁in', 0.0345)]),\n",
       " OrderedDict([('▁in', 0.7597),\n",
       "              (',', 0.0206),\n",
       "              ('▁to', 0.2018),\n",
       "              ('▁and', 0.0031),\n",
       "              ('▁then', 0.0028)]),\n",
       " OrderedDict([('▁April', 0.9759),\n",
       "              ('▁May', 0.0039),\n",
       "              ('▁total', 0.0025),\n",
       "              ('▁the', 0.0067)]),\n",
       " OrderedDict([(',', 0.7709),\n",
       "              ('▁to', 0.0033),\n",
       "              ('▁and', 0.1443),\n",
       "              ('▁she', 0.0074),\n",
       "              ('▁then', 0.0602)]),\n",
       " OrderedDict([('▁she', 0.1407),\n",
       "              ('▁in', 0.0182),\n",
       "              ('▁and', 0.2126),\n",
       "              ('▁that', 0.0235),\n",
       "              ('▁then', 0.5304)]),\n",
       " OrderedDict([('▁sold', 0.815),\n",
       "              ('▁must', 0.0433),\n",
       "              ('▁will', 0.0162),\n",
       "              ('▁would', 0.0389),\n",
       "              ('▁then', 0.0111)]),\n",
       " OrderedDict([('▁48', 0.556),\n",
       "              ('▁24', 0.1033),\n",
       "              ('▁half', 0.1671),\n",
       "              ('▁1', 0.0266),\n",
       "              ('▁(', 0.0178)]),\n",
       " {'/2': 1.0},\n",
       " OrderedDict([('▁=', 0.8517),\n",
       "              ('▁clips', 0.0251),\n",
       "              ('=', 0.099),\n",
       "              ('▁in', 0.0068),\n",
       "              ('▁or', 0.0085)]),\n",
       " OrderedDict([('▁24', 0.9896),\n",
       "              ('▁12', 0.0005),\n",
       "              ('24', 0.007),\n",
       "              ('▁', 0.0002),\n",
       "              ('▁2', 0.0004)]),\n",
       " OrderedDict([('▁clips', 0.8775),\n",
       "              ('▁less', 0.0026),\n",
       "              ('▁in', 0.106),\n",
       "              ('▁more', 0.0026)]),\n",
       " OrderedDict([('▁in', 0.9914),\n",
       "              ('▁during', 0.0008),\n",
       "              ('▁less', 0.001),\n",
       "              ('▁the', 0.0011),\n",
       "              ('▁to', 0.0013)]),\n",
       " OrderedDict([('▁May', 0.9877),\n",
       "              ('▁the', 0.0026),\n",
       "              ('▁March', 0.0006),\n",
       "              ('▁April', 0.0012),\n",
       "              ('▁may', 0.0054)]),\n",
       " {'.': 1.0},\n",
       " {'▁In': 1.0},\n",
       " OrderedDict([('▁total', 0.8519),\n",
       "              ('▁both', 0.0057),\n",
       "              ('▁May', 0.0097),\n",
       "              ('▁the', 0.0058),\n",
       "              ('▁April', 0.1015)]),\n",
       " OrderedDict([(',', 0.8402),\n",
       "              ('▁Nat', 0.0197),\n",
       "              ('▁in', 0.0042),\n",
       "              ('▁she', 0.114),\n",
       "              ('▁then', 0.0114)]),\n",
       " OrderedDict([('▁Nat', 0.2935),\n",
       "              ('▁in', 0.0108),\n",
       "              ('▁48', 0.0141),\n",
       "              ('▁she', 0.5699),\n",
       "              ('▁then', 0.0726)]),\n",
       " OrderedDict([('alia', 0.9996), ('elia', 0.0002)]),\n",
       " OrderedDict([('▁sold', 0.9757),\n",
       "              ('▁has', 0.0071),\n",
       "              ('▁will', 0.0025),\n",
       "              ('▁then', 0.0048)]),\n",
       " OrderedDict([('▁48', 0.9477),\n",
       "              ('▁clips', 0.0153),\n",
       "              ('▁24', 0.0158),\n",
       "              ('▁72', 0.0048)]),\n",
       " OrderedDict([('▁+', 0.6164),\n",
       "              ('+', 0.3547),\n",
       "              ('▁clips', 0.0236),\n",
       "              ('▁in', 0.0019),\n",
       "              ('▁April', 0.0005)]),\n",
       " OrderedDict([('▁24', 0.992),\n",
       "              ('▁12', 0.0002),\n",
       "              ('24', 0.0065),\n",
       "              ('▁2', 0.0002),\n",
       "              ('▁48', 0.0002)]),\n",
       " OrderedDict([('▁=', 0.9853),\n",
       "              (',', 0.0002),\n",
       "              ('▁clips', 0.0072),\n",
       "              ('▁', 0.0002),\n",
       "              ('=', 0.0066)]),\n",
       " OrderedDict([('▁72', 0.9951),\n",
       "              ('▁', 0.0002),\n",
       "              ('▁70', 0.0003),\n",
       "              ('72', 0.0026),\n",
       "              ('▁7', 0.0001)]),\n",
       " OrderedDict([('▁clips', 0.9888),\n",
       "              ('▁clip', 0.001),\n",
       "              ('.', 0.0012),\n",
       "              ('▁total', 0.0018)]),\n",
       " OrderedDict([('▁in', 0.359),\n",
       "              ('.', 0.3778),\n",
       "              ('▁altogether', 0.0955),\n",
       "              ('▁over', 0.0129)]),\n",
       " OrderedDict([('▁April', 0.8876),\n",
       "              ('▁both', 0.0564),\n",
       "              ('▁May', 0.0049),\n",
       "              ('▁total', 0.0139),\n",
       "              ('▁the', 0.0216)]),\n",
       " OrderedDict([('▁and', 0.9928),\n",
       "              (',', 0.0004),\n",
       "              ('▁+', 0.0024),\n",
       "              ('▁plus', 0.0005)]),\n",
       " OrderedDict([('▁May', 0.9962),\n",
       "              ('▁June', 0.0001),\n",
       "              ('▁in', 0.0023),\n",
       "              ('▁may', 0.0005),\n",
       "              ('▁then', 0.0001)]),\n",
       " {'.': 1.0},\n",
       " {'▁The': 1.0},\n",
       " OrderedDict([('▁answer', 0.9972),\n",
       "              ('▁number', 0.0003),\n",
       "              ('▁Answer', 0.0006),\n",
       "              ('▁total', 0.0005),\n",
       "              ('▁an', 0.0002)]),\n",
       " OrderedDict([('▁is', 0.999),\n",
       "              (':', 0.0001),\n",
       "              ('▁to', 0.0001),\n",
       "              ('▁in', 0.0001),\n",
       "              ('▁72', 0.0004)]),\n",
       " OrderedDict([('▁72', 0.9981),\n",
       "              ('▁24', 0.0001),\n",
       "              ('▁', 0.0004),\n",
       "              ('▁48', 0.0001),\n",
       "              ('▁7', 0.0001)])]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformed_step_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "f195768e-1651-43d9-815a-0ed6c2da2435",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6018"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformed_step_probs[1]['▁Nat']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "d1c738a1-12f9-4432-9f70-f26c69773927",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "▁total 0.8519\n",
      "▁both 0.0057\n",
      "▁May 0.0097\n",
      "▁the 0.0058\n",
      "▁April 0.1015\n"
     ]
    }
   ],
   "source": [
    "# Show every candidate token and its probability at one aligned position\n",
    "idx = 20\n",
    "for k, v in transformed_step_probs[idx].items():\n",
    "    print(k, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "ac794e89-587d-4da6-ad8e-7aa2545af8b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7473"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(codex_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "7cf8436d-a6c6-4ddd-9333-4852c616a626",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7474"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(codex_questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "9072ed48-3c11-41b7-979b-9934c40dcd8a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [58:01<00:00,  2.15it/s]\n"
     ]
    }
   ],
   "source": [
    "# Align every originally-correct Codex answer; answers whose alignment is rejected are relabeled 0.\n",
    "codex_dtw_transformed_step_probs, codex_updated_labels = [], []\n",
    "total_case = total_correct = modified = 0\n",
    "for qi, q in tqdm(enumerate(codex_questions), total=len(codex_questions)):\n",
    "    transformed, updated_labels = [], []\n",
    "    for ai, (pred, prob, label) in enumerate(\n",
    "        zip(codex_predictions[qi], codex_per_step_probs[qi], codex_prediction_labels[qi])\n",
    "    ):\n",
    "        total_case += 1\n",
    "        if label != 1:\n",
    "            # answers not labeled correct are kept as placeholders only\n",
    "            updated_labels.append(0)\n",
    "            transformed.append(None)\n",
    "            continue\n",
    "        total_correct += 1\n",
    "        ret_code, ret = transform_step_probs(qi, ai, vocab, pred, prob)\n",
    "        if ret_code != -1:\n",
    "            updated_labels.append(1)\n",
    "            transformed.append(ret)\n",
    "        else:\n",
    "            # transform rejected this answer: count it and flip its label to 0\n",
    "            modified += 1\n",
    "            updated_labels.append(0)\n",
    "            transformed.append(None)\n",
    "    codex_dtw_transformed_step_probs.append(transformed)\n",
    "    codex_updated_labels.append(updated_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "b16fb2ac-e0d4-46c1-a380-297a4f60aa96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total prediction 323810, original labeled corred 207619, modified 484\n"
     ]
    }
   ],
   "source": [
    "# Summary of the batch run: answers seen, answers originally labeled correct, answers relabeled\n",
    "print('total prediction %d, original labeled correct %d, modified %d' % (total_case, total_correct, modified))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "b3ea4802-ea0d-4bf4-a28f-cdd37b12c8dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use context managers so both files are flushed and closed as soon as they are written\n",
    "with open('../processed_data/codex_dtw_transformed_step_probs.pkl', 'wb') as f:\n",
    "    pickle.dump(codex_dtw_transformed_step_probs, f)\n",
    "with open('../processed_data/codex_updated_labels.pkl', 'wb') as f:\n",
    "    pickle.dump(codex_updated_labels, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f4b5a9-bfd5-4119-902b-38b184d54e33",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "5c1157c6-de64-44f7-b982-fcce1c4fb497",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Model output 1: \\nJasper buys 2 pounds of cheddar cheese for $10, meaning that 1 pound of cheddar cheese costs $10/2 = $5\\nJasper buys a pound of cream cheese that costs half the price of cheddar cheese, meaning that 1 pound of cream cheese costs $5/2 = $2.50\\nJasper buys a pack of cold cuts that cost twice the price of cheddar cheese, meaning that 1 pack of cold cuts costs $5*2 = $10\\nJasper buys 2 pounds of cheddar cheese, 1 pound of cream cheese, and 1 pack of cold cuts.\\nTogether, he spends 2*$5 + $2.50 + $10 = $22.50\\nThe answer is 22.5\\n'"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_predictions[1][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "bf17a727-a6f2-4b09-94fa-b905f15efd63",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(' step', 1.0),\n",
       "  ('!', 0.0),\n",
       "  (' stepped', 0.0),\n",
       "  (' stepping', 0.0),\n",
       "  (' step', 1.0),\n",
       "  (' steps', 0.0)],\n",
       " [('\\n', 0.9877),\n",
       "  ('.', 0.0066),\n",
       "  ('\\n', 0.9877),\n",
       "  (' ', 0.0007),\n",
       "  (':', 0.0027),\n",
       "  ('\\n\\n', 0.0007)],\n",
       " [('After', 0.2227),\n",
       "  ('We', 0.0602),\n",
       "  ('If', 0.0648),\n",
       "  ('After', 0.2227),\n",
       "  ('The', 0.2784),\n",
       "  ('Let', 0.0527)],\n",
       " [(' the', 0.8436),\n",
       "  (' each', 0.0086),\n",
       "  (' traveling', 0.01),\n",
       "  (' turn', 0.0138),\n",
       "  (' the', 0.8436),\n",
       "  (' 1', 0.049)],\n",
       " [(' first', 0.4074),\n",
       "  (' third', 0.0225),\n",
       "  (' 1', 0.4435),\n",
       "  (' 2', 0.0232),\n",
       "  (' 3', 0.0537),\n",
       "  (' first', 0.4074)],\n",
       " [(' turn', 0.8979),\n",
       "  (' three', 0.0075),\n",
       "  (' turn', 0.8979),\n",
       "  (' and', 0.0113),\n",
       "  (' two', 0.0159),\n",
       "  (' right', 0.0446)],\n",
       " [(',', 0.8022),\n",
       "  (',', 0.8022),\n",
       "  (' the', 0.1396),\n",
       "  (' of', 0.0051),\n",
       "  (' it', 0.0252),\n",
       "  (' we', 0.0072)],\n",
       " [(' the', 0.8069),\n",
       "  (' the', 0.8069),\n",
       "  (' it', 0.1321),\n",
       "  (' we', 0.0212),\n",
       "  (' there', 0.0043),\n",
       "  (' 5', 0.0093)],\n",
       " [(' car', 0.9658),\n",
       "  (' car', 0.9658),\n",
       "  (' total', 0.006),\n",
       "  (' vehicle', 0.0017),\n",
       "  (' driver', 0.0023),\n",
       "  (' distance', 0.0134)],\n",
       " [(' has', 0.099),\n",
       "  (' drives', 0.0424),\n",
       "  (' traveled', 0.167),\n",
       "  (' travels', 0.4661),\n",
       "  (' travelled', 0.0769),\n",
       "  (' has', 0.099)],\n",
       " [(' traveled', 0.44),\n",
       "  (' traveled', 0.44),\n",
       "  (' travelled', 0.2797),\n",
       "  (' to', 0.0521),\n",
       "  (' gone', 0.0309),\n",
       "  (' driven', 0.1188)],\n",
       " [(' 5', 0.8729),\n",
       "  (' a', 0.0764),\n",
       "  (' for', 0.007),\n",
       "  (' 1', 0.0076),\n",
       "  (' 5', 0.8729),\n",
       "  (' through', 0.0046)],\n",
       " [(' meters', 0.8521),\n",
       "  ('+', 0.0062),\n",
       "  (' meters', 0.8521),\n",
       "  (' +', 0.0075),\n",
       "  (' m', 0.0433),\n",
       "  ('m', 0.0723)],\n",
       " [('.', 0.6938),\n",
       "  (',', 0.072),\n",
       "  ('.', 0.6938),\n",
       "  ('\\n', 0.137),\n",
       "  (' in', 0.012),\n",
       "  (' and', 0.0141)],\n",
       " [('\\n', 0.9046),\n",
       "  ('\\n', 0.9046),\n",
       "  (' ', 0.0136),\n",
       "  (' After', 0.0477),\n",
       "  (' The', 0.0033),\n",
       "  (' This', 0.0034)],\n",
       " [('After', 0.8982),\n",
       "  ('And', 0.0069),\n",
       "  ('After', 0.8982),\n",
       "  ('The', 0.0145),\n",
       "  ('Then', 0.0164),\n",
       "  ('In', 0.0059)],\n",
       " [(' the', 0.9865),\n",
       "  (' a', 0.0008),\n",
       "  (' the', 0.9865),\n",
       "  (' that', 0.0009),\n",
       "  (' 2', 0.0026),\n",
       "  (' two', 0.0037)],\n",
       " [(' second', 0.9026),\n",
       "  (' second', 0.9026),\n",
       "  (' next', 0.0026),\n",
       "  (' third', 0.0006),\n",
       "  (' 2', 0.0875),\n",
       "  (' first', 0.0045)],\n",
       " [(' turn', 0.9834),\n",
       "  (',', 0.0118),\n",
       "  (' turn', 0.9834),\n",
       "  (' and', 0.0023),\n",
       "  (' it', 0.0006),\n",
       "  (' right', 0.0004)],\n",
       " [(',', 0.9721),\n",
       "  (',', 0.9721),\n",
       "  (' the', 0.0055),\n",
       "  (' and', 0.0007),\n",
       "  (' it', 0.0192),\n",
       "  (' (', 0.0005)],\n",
       " [(' the', 0.6598),\n",
       "  (' the', 0.6598),\n",
       "  (' it', 0.3194),\n",
       "  (' we', 0.0033),\n",
       "  (' 5', 0.0012),\n",
       "  (' 8', 0.0043)],\n",
       " [(' car', 0.9919),\n",
       "  (' car', 0.9919),\n",
       "  (' second', 0.0003),\n",
       "  (' total', 0.0031),\n",
       "  (' distance', 0.0011),\n",
       "  (' first', 0.0005)],\n",
       " [(' has', 0.9585),\n",
       "  (' traveled', 0.0051),\n",
       "  (' travels', 0.0159),\n",
       "  (' is', 0.0042),\n",
       "  (' has', 0.9585),\n",
       "  (' will', 0.0041)],\n",
       " [(' traveled', 0.9669),\n",
       "  (' traveled', 0.9669),\n",
       "  (' travelled', 0.0064),\n",
       "  (' gone', 0.0045),\n",
       "  (' now', 0.0036),\n",
       "  (' driven', 0.003)],\n",
       " [(' 8', 0.4184),\n",
       "  (' 13', 0.0502),\n",
       "  (' a', 0.074),\n",
       "  (' an', 0.0492),\n",
       "  (' 5', 0.3238),\n",
       "  (' 8', 0.4184)],\n",
       " [(' meters', 0.6633),\n",
       "  ('+', 0.0372),\n",
       "  (' meters', 0.6633),\n",
       "  (' +', 0.045),\n",
       "  (' additional', 0.0331),\n",
       "  (' more', 0.1909)],\n",
       " [('.', 0.695),\n",
       "  (',', 0.1211),\n",
       "  ('.', 0.695),\n",
       "  (' +', 0.012),\n",
       "  (' in', 0.0171),\n",
       "  (' more', 0.062)],\n",
       " [('\\n', 0.9839),\n",
       "  (' So', 0.0017),\n",
       "  ('\\n', 0.9839),\n",
       "  (' ', 0.0041),\n",
       "  (' The', 0.0011),\n",
       "  (' This', 0.0021)],\n",
       " [('After', 0.5798),\n",
       "  ('This', 0.0303),\n",
       "  ('So', 0.0524),\n",
       "  ('Therefore', 0.0199),\n",
       "  ('After', 0.5798),\n",
       "  ('The', 0.076)],\n",
       " [(' the', 0.9646),\n",
       "  (' both', 0.0071),\n",
       "  (' three', 0.0031),\n",
       "  (' the', 0.9646),\n",
       "  (' two', 0.0044),\n",
       "  (' these', 0.0049)],\n",
       " [(' third', 0.9027),\n",
       "  (' second', 0.011),\n",
       "  (' third', 0.9027),\n",
       "  (' 3', 0.0329),\n",
       "  (' first', 0.0285),\n",
       "  (' two', 0.0058)],\n",
       " [(' turn', 0.9889),\n",
       "  (',', 0.0026),\n",
       "  (' turn', 0.9889),\n",
       "  (' and', 0.0068),\n",
       "  (' tur', 0.0002),\n",
       "  (' right', 0.0002)],\n",
       " [(',', 0.9752),\n",
       "  (',', 0.9752),\n",
       "  (' the', 0.0108),\n",
       "  (' and', 0.004),\n",
       "  (' it', 0.0041),\n",
       "  (' (', 0.0011)],\n",
       " [(' the', 0.9212),\n",
       "  (' the', 0.9212),\n",
       "  (' it', 0.0343),\n",
       "  (' we', 0.0215),\n",
       "  (' if', 0.0021),\n",
       "  (' there', 0.0018)],\n",
       " [(' car', 0.9928),\n",
       "  (' car', 0.9928),\n",
       "  (' sum', 0.0003),\n",
       "  (' total', 0.0028),\n",
       "  (' distance', 0.0009),\n",
       "  (' remaining', 0.0002)],\n",
       " [(' has', 0.8906),\n",
       "  (' must', 0.0245),\n",
       "  (' traveled', 0.0132),\n",
       "  (' travels', 0.0261),\n",
       "  (' is', 0.0102),\n",
       "  (' has', 0.8906)],\n",
       " [(' traveled', 0.9641),\n",
       "  (' traveled', 0.9641),\n",
       "  (' travelled', 0.0062),\n",
       "  (' to', 0.0034),\n",
       "  (' gone', 0.0036),\n",
       "  (' driven', 0.0076)],\n",
       " [(' a', 0.213),\n",
       "  (' x', 0.0799),\n",
       "  (' 23', 0.0705),\n",
       "  (' a', 0.213),\n",
       "  (' some', 0.0898),\n",
       "  (' 5', 0.0708)],\n",
       " [(' little', 0.6594),\n",
       "  (' little', 0.6594),\n",
       "  (' certain', 0.0116),\n",
       "  (' further', 0.0204),\n",
       "  (' total', 0.1435),\n",
       "  (' distance', 0.0856)],\n",
       " [(' further', 0.8368),\n",
       "  (' bit', 0.0317),\n",
       "  (' farther', 0.0126),\n",
       "  (' further', 0.8368),\n",
       "  (' longer', 0.0027),\n",
       "  (' more', 0.0977)],\n",
       " [('.', 0.3624),\n",
       "  (',', 0.25),\n",
       "  ('.', 0.3624),\n",
       "  (' and', 0.1033),\n",
       "  (' (', 0.0227),\n",
       "  (' than', 0.1234)],\n",
       " [('\\n', 0.8339),\n",
       "  ('\\n', 0.8339),\n",
       "  (' ', 0.0147),\n",
       "  (' The', 0.0091),\n",
       "  (' Let', 0.0417),\n",
       "  (' We', 0.0254)],\n",
       " [('After', 0.1896),\n",
       "  ('We', 0.0891),\n",
       "  ('If', 0.0541),\n",
       "  ('After', 0.1896),\n",
       "  ('The', 0.1554),\n",
       "  ('Let', 0.0548)],\n",
       " [(' the', 0.8839),\n",
       "  (' traveling', 0.0053),\n",
       "  (' the', 0.8839),\n",
       "  (' all', 0.0502),\n",
       "  (' 4', 0.0063),\n",
       "  (' these', 0.0058)],\n",
       " [(' fourth', 0.7284),\n",
       "  (' third', 0.1024),\n",
       "  (' 3', 0.0104),\n",
       "  (' fourth', 0.7284),\n",
       "  (' 4', 0.0916),\n",
       "  (' first', 0.0155)],\n",
       " [(' turn', 0.963),\n",
       "  (',', 0.0026),\n",
       "  (' turn', 0.963),\n",
       "  (' and', 0.0308),\n",
       "  (' (', 0.0016),\n",
       "  (' right', 0.0004)],\n",
       " [(',', 0.9694),\n",
       "  (',', 0.9694),\n",
       "  (' the', 0.0131),\n",
       "  (' and', 0.0043),\n",
       "  (' it', 0.0067),\n",
       "  (' (', 0.0033)],\n",
       " [(' the', 0.8405),\n",
       "  (' the', 0.8405),\n",
       "  (' it', 0.1384),\n",
       "  (' we', 0.0053),\n",
       "  (' they', 0.001),\n",
       "  (' which', 0.0025)],\n",
       " [(' car', 0.9886),\n",
       "  (' car', 0.9886),\n",
       "  (' tunnel', 0.0011),\n",
       "  (' total', 0.0038),\n",
       "  (' distance', 0.0014),\n",
       "  (' fourth', 0.0005)],\n",
       " [(' exits', 0.2523),\n",
       "  (' travels', 0.0246),\n",
       "  (' exits', 0.2523),\n",
       "  (' is', 0.0274),\n",
       "  (' immediately', 0.2233),\n",
       "  (' has', 0.4041)],\n",
       " [(' the', 0.943),\n",
       "  (',', 0.0084),\n",
       "  ('.', 0.0265),\n",
       "  (' the', 0.943),\n",
       "  (' and', 0.005),\n",
       "  (' immediately', 0.0077)],\n",
       " [(' tunnel', 0.9888),\n",
       "  (' turn', 0.0008),\n",
       "  (' tunnel', 0.9888),\n",
       "  (' total', 0.0002),\n",
       "  (' ring', 0.0087),\n",
       "  (' tun', 0.0002)],\n",
       " [('.', 0.7942),\n",
       "  (',', 0.0889),\n",
       "  ('.', 0.7942),\n",
       "  ('\\n', 0.0102),\n",
       "  (' and', 0.0426),\n",
       "  (' immediately', 0.02)],\n",
       " [('\\n', 0.9687),\n",
       "  (' So', 0.0042),\n",
       "  ('\\n', 0.9687),\n",
       "  (' ', 0.0045),\n",
       "  (' The', 0.0028),\n",
       "  (' This', 0.0037)],\n",
       " [('The', 0.2482),\n",
       "  ('We', 0.0659),\n",
       "  ('If', 0.0613),\n",
       "  ('So', 0.0734),\n",
       "  ('After', 0.0812),\n",
       "  ('The', 0.2482)],\n",
       " [(' total', 0.3567),\n",
       "  (' car', 0.3546),\n",
       "  (' sum', 0.0332),\n",
       "  (' total', 0.3567),\n",
       "  (' distance', 0.0671),\n",
       "  (' fourth', 0.0327)],\n",
       " [(' distance', 0.704),\n",
       "  (' number', 0.0899),\n",
       "  (' amount', 0.03),\n",
       "  (' of', 0.017),\n",
       "  (' length', 0.0968),\n",
       "  (' distance', 0.704)],\n",
       " [(' traveled', 0.4364),\n",
       "  (' around', 0.0487),\n",
       "  (' traveled', 0.4364),\n",
       "  (' the', 0.2463),\n",
       "  (' of', 0.0608),\n",
       "  (' that', 0.0497)],\n",
       " [(' is', 0.2895),\n",
       "  (' around', 0.1977),\n",
       "  (' in', 0.0531),\n",
       "  (' is', 0.2895),\n",
       "  (' by', 0.2961),\n",
       "  (' after', 0.0427)],\n",
       " [(' 5', 0.421),\n",
       "  (' 23', 0.4397),\n",
       "  (' the', 0.0613),\n",
       "  (' therefore', 0.0081),\n",
       "  (' equal', 0.007),\n",
       "  (' 5', 0.421)],\n",
       " [('+', 0.4337),\n",
       "  ('+', 0.4337),\n",
       "  (' meters', 0.0172),\n",
       "  (' +', 0.5388),\n",
       "  (' m', 0.0013),\n",
       "  ('m', 0.0053)],\n",
       " [('8', 0.9956),\n",
       "  ('2', 0.0001),\n",
       "  ('3', 0.0001),\n",
       "  ('5', 0.0015),\n",
       "  ('8', 0.9956),\n",
       "  (' 8', 0.0022)],\n",
       " [('+', 0.8773),\n",
       "  ('+', 0.8773),\n",
       "  (' +', 0.0048),\n",
       "  ('=', 0.0528),\n",
       "  ('+(', 0.0246),\n",
       "  (' =', 0.0294)],\n",
       " [('x', 0.3074),\n",
       "  ('?', 0.0622),\n",
       "  ('A', 0.0175),\n",
       "  ('X', 0.0468),\n",
       "  ('a', 0.2171),\n",
       "  ('x', 0.3074)],\n",
       " [('+', 0.9733),\n",
       "  ('+', 0.9733),\n",
       "  (' +', 0.0022),\n",
       "  ('=', 0.0053),\n",
       "  ('+(', 0.0078),\n",
       "  (' =', 0.0047)],\n",
       " [('y', 0.2355),\n",
       "  ('1', 0.1018),\n",
       "  ('4', 0.0853),\n",
       "  ('23', 0.0724),\n",
       "  ('5', 0.0868),\n",
       "  ('y', 0.2355)],\n",
       " [(' =', 0.4029),\n",
       "  (',', 0.0553),\n",
       "  ('\\n', 0.0468),\n",
       "  ('=', 0.2697),\n",
       "  (' =', 0.4029),\n",
       "  (' where', 0.1437)],\n",
       " [(' 23', 0.9786),\n",
       "  ('23', 0.012),\n",
       "  (' 23', 0.9786),\n",
       "  (' (', 0.0004),\n",
       "  (' 4', 0.0005),\n",
       "  (' 5', 0.0026)],\n",
       " [(' where', 0.1655),\n",
       "  (' meters', 0.1208),\n",
       "  (',', 0.1268),\n",
       "  ('.', 0.158),\n",
       "  ('\\n', 0.3792),\n",
       "  (' where', 0.1655)],\n",
       " [(' x', 0.9035),\n",
       "  (' X', 0.0065),\n",
       "  (' x', 0.9035),\n",
       "  (' the', 0.0096),\n",
       "  (' y', 0.0509),\n",
       "  (' 5', 0.0055)],\n",
       " [(' is', 0.6525),\n",
       "  ('+', 0.0165),\n",
       "  (' and', 0.2144),\n",
       "  (' is', 0.6525),\n",
       "  (' represents', 0.045),\n",
       "  (' =', 0.0228)],\n",
       " [(' the', 0.9091),\n",
       "  (' a', 0.0137),\n",
       "  (' the', 0.9091),\n",
       "  (' distance', 0.0189),\n",
       "  (' how', 0.0181),\n",
       "  (' after', 0.0073)],\n",
       " [(' distance', 0.8755),\n",
       "  (' little', 0.0094),\n",
       "  (' additional', 0.0118),\n",
       "  (' length', 0.0189),\n",
       "  (' distance', 0.8755),\n",
       "  (' unknown', 0.0222)],\n",
       " [(' traveled', 0.645),\n",
       "  (' traveled', 0.645),\n",
       "  (' travelled', 0.0199),\n",
       "  (' the', 0.0321),\n",
       "  (' covered', 0.0178),\n",
       "  (' after', 0.2074)],\n",
       " [(' after', 0.881),\n",
       "  (' between', 0.0159),\n",
       "  (' in', 0.0372),\n",
       "  (' by', 0.0102),\n",
       "  (' after', 0.881),\n",
       "  (' before', 0.0094)],\n",
       " [(' the', 0.9601),\n",
       "  (' turn', 0.0171),\n",
       "  (' third', 0.0075),\n",
       "  (' the', 0.9601),\n",
       "  (' 3', 0.0095),\n",
       "  (' turning', 0.0024)],\n",
       " [(' third', 0.8003),\n",
       "  (' second', 0.0011),\n",
       "  (' third', 0.8003),\n",
       "  (' 3', 0.1904),\n",
       "  (' fourth', 0.0008),\n",
       "  (' first', 0.003)],\n",
       " [(' turn', 0.9877),\n",
       "  (',', 0.0003),\n",
       "  (' turn', 0.9877),\n",
       "  (' left', 0.0004),\n",
       "  (' and', 0.0023),\n",
       "  (' right', 0.0065)],\n",
       " [(' and', 0.7867),\n",
       "  (',', 0.1439),\n",
       "  ('.', 0.0466),\n",
       "  ('\\n', 0.0107),\n",
       "  (' and', 0.7867),\n",
       "  (' (', 0.002)],\n",
       " [(' y', 0.9936),\n",
       "  (' x', 0.0007),\n",
       "  (' the', 0.001),\n",
       "  (' is', 0.0007),\n",
       "  (' y', 0.9936),\n",
       "  (' where', 0.0006)],\n",
       " [(' is', 0.9715),\n",
       "  (' the', 0.0066),\n",
       "  (' is', 0.9715),\n",
       "  (' represents', 0.0053),\n",
       "  (' after', 0.0046),\n",
       "  (' =', 0.0027)],\n",
       " [(' the', 0.9681),\n",
       "  (' a', 0.006),\n",
       "  (' the', 0.9681),\n",
       "  (' distance', 0.0089),\n",
       "  (' 0', 0.0023),\n",
       "  (' how', 0.0034)],\n",
       " [(' distance', 0.9647),\n",
       "  (' final', 0.002),\n",
       "  (' total', 0.0041),\n",
       "  (' length', 0.0094),\n",
       "  (' distance', 0.9647),\n",
       "  (' remaining', 0.002)],\n",
       " [(' traveled', 0.8817),\n",
       "  (' traveled', 0.8817),\n",
       "  (' the', 0.0116),\n",
       "  (' to', 0.0186),\n",
       "  (' from', 0.0134),\n",
       "  (' after', 0.0188)],\n",
       " [(' after', 0.8213),\n",
       "  (' to', 0.0246),\n",
       "  (' in', 0.0177),\n",
       "  (' immediately', 0.0159),\n",
       "  (' after', 0.8213),\n",
       "  (' before', 0.0211)],\n",
       " [(' the', 0.9799),\n",
       "  (' making', 0.0016),\n",
       "  (' the', 0.9799),\n",
       "  (' exiting', 0.0092),\n",
       "  (' fourth', 0.0021),\n",
       "  (' turning', 0.0024)],\n",
       " [(' fourth', 0.932),\n",
       "  (' third', 0.0117),\n",
       "  (' fourth', 0.932),\n",
       "  (' 4', 0.0325),\n",
       "  (' forth', 0.006),\n",
       "  (' last', 0.0043)],\n",
       " [(' turn', 0.9717),\n",
       "  (',', 0.001),\n",
       "  (' turn', 0.9717),\n",
       "  ('.', 0.0107),\n",
       "  (' and', 0.0091),\n",
       "  (' (', 0.0031)],\n",
       " [('.', 0.8814),\n",
       "  (',', 0.0115),\n",
       "  ('.', 0.8814),\n",
       "  ('\\n', 0.0556),\n",
       "  (' to', 0.0073),\n",
       "  (' (', 0.0097)],\n",
       " [('\\n', 0.9609),\n",
       "  (' So', 0.002),\n",
       "  ('\\n', 0.9609),\n",
       "  (' ', 0.0079),\n",
       "  (' Since', 0.0031),\n",
       "  (' We', 0.0034)],\n",
       " [('We', 0.1298),\n",
       "  ('We', 0.1298),\n",
       "  ('So', 0.0722),\n",
       "  ('Therefore', 0.0815),\n",
       "  ('The', 0.0742),\n",
       "  ('Since', 0.1)],\n",
       " [(' can', 0.2106),\n",
       "  (' are', 0.0522),\n",
       "  (' have', 0.0237),\n",
       "  (' can', 0.2106),\n",
       "  (' also', 0.0373),\n",
       "  (' know', 0.5421)],\n",
       " [(' solve', 0.1084),\n",
       "  (' substitute', 0.0976),\n",
       "  (' rewrite', 0.0932),\n",
       "  (' simplify', 0.0971),\n",
       "  (' write', 0.0556),\n",
       "  (' solve', 0.1084)],\n",
       " [(' for', 0.6128),\n",
       "  (' x', 0.0064),\n",
       "  (' the', 0.0823),\n",
       "  (' for', 0.6128),\n",
       "  (' by', 0.0095),\n",
       "  (' this', 0.253)],\n",
       " [(' y', 0.1778),\n",
       "  (' both', 0.0044),\n",
       "  (' x', 0.738),\n",
       "  (' the', 0.0569),\n",
       "  (' y', 0.1778),\n",
       "  (' this', 0.0042)],\n",
       " [(' by', 0.3364),\n",
       "  (' using', 0.0832),\n",
       "  (':', 0.0681),\n",
       "  (' in', 0.0649),\n",
       "  (' as', 0.0796),\n",
       "  (' by', 0.3364)],\n",
       " [(' substit', 0.095),\n",
       "  (' noting', 0.0588),\n",
       "  (' using', 0.0684),\n",
       "  (' realizing', 0.078),\n",
       "  (' substit', 0.095),\n",
       "  (' subtract', 0.1014)],\n",
       " [('uting', 0.9996),\n",
       "  ('uating', 0.0001),\n",
       "  ('uting', 0.9996),\n",
       "  ('izing', 0.0001),\n",
       "  ('uing', 0.0001),\n",
       "  ('u', 0.0)],\n",
       " [(' in', 0.1885),\n",
       "  (' x', 0.2447),\n",
       "  (' the', 0.1475),\n",
       "  (' in', 0.1885),\n",
       "  (' y', 0.0598),\n",
       "  (' 5', 0.0977)],\n",
       " [(' the', 0.3178),\n",
       "  (' x', 0.1839),\n",
       "  (' the', 0.3178),\n",
       "  (' for', 0.1001),\n",
       "  (' 5', 0.1138),\n",
       "  (' our', 0.0599)],\n",
       " [(' values', 0.2491),\n",
       "  (' equation', 0.0437),\n",
       "  (' known', 0.074),\n",
       "  (' value', 0.1427),\n",
       "  (' total', 0.0415),\n",
       "  (' values', 0.2491)],\n",
       " [(' of', 0.167),\n",
       "  (' of', 0.167),\n",
       "  (' that', 0.0192),\n",
       "  (' for', 0.4064),\n",
       "  (' we', 0.267),\n",
       "  (' from', 0.0328)],\n",
       " [(' the', 0.2869),\n",
       "  (' x', 0.1024),\n",
       "  (' the', 0.2869),\n",
       "  (' y', 0.0056),\n",
       "  (' all', 0.0101),\n",
       "  (' 5', 0.5449)],\n",
       " [(' other', 0.3718),\n",
       "  (' distances', 0.0902),\n",
       "  (' known', 0.0468),\n",
       "  (' previous', 0.0786),\n",
       "  (' other', 0.3718),\n",
       "  (' first', 0.2601)],\n",
       " [(' variables', 0.3862),\n",
       "  (' three', 0.0853),\n",
       "  (' distances', 0.2386),\n",
       "  (' terms', 0.0775),\n",
       "  (' 3', 0.0314),\n",
       "  (' variables', 0.3862)],\n",
       " [(':', 0.2305),\n",
       "  (',', 0.1008),\n",
       "  ('.', 0.3681),\n",
       "  (':', 0.2305),\n",
       "  (' to', 0.0392),\n",
       "  (' and', 0.063)],\n",
       " [(' y', 0.1282),\n",
       "  ('\\n', 0.5155),\n",
       "  (' x', 0.0161),\n",
       "  (' 23', 0.1611),\n",
       "  (' y', 0.1282),\n",
       "  (' 5', 0.1515)],\n",
       " [(' =', 0.7955),\n",
       "  ('+', 0.0096),\n",
       "  (' +', 0.0081),\n",
       "  ('=(', 0.0109),\n",
       "  ('=', 0.1711),\n",
       "  (' =', 0.7955)],\n",
       " [(' 23', 0.9339),\n",
       "  (' x', 0.0041),\n",
       "  (' 23', 0.9339),\n",
       "  (' (', 0.0172),\n",
       "  (' -', 0.0021),\n",
       "  (' 5', 0.0261)],\n",
       " [(' -', 0.5007),\n",
       "  ('-', 0.2933),\n",
       "  ('-(', 0.1906),\n",
       "  (' -', 0.5007),\n",
       "  (' –', 0.0067),\n",
       "  (' −', 0.0036)],\n",
       " [(' (', 0.3551),\n",
       "  (' 13', 0.098),\n",
       "  ('5', 0.0064),\n",
       "  (' (', 0.3551),\n",
       "  (' 5', 0.5026),\n",
       "  (' 8', 0.0229)],\n",
       " [('5', 0.9571),\n",
       "  ('13', 0.0034),\n",
       "  ('5', 0.9571),\n",
       "  ('8', 0.0255),\n",
       "  (' 5', 0.0059),\n",
       "  ('x', 0.0058)],\n",
       " [('+', 0.7453),\n",
       "  ('+', 0.7453),\n",
       "  ('-', 0.0001),\n",
       "  (' +', 0.254),\n",
       "  (' -', 0.0004),\n",
       "  (')', 0.0001)],\n",
       " [('8', 0.9984),\n",
       "  ('3', 0.0001),\n",
       "  ('8', 0.9984),\n",
       "  (' 8', 0.0005),\n",
       "  ('x', 0.0007),\n",
       "  ('y', 0.0001)],\n",
       " [('+', 0.985),\n",
       "  ('+', 0.985),\n",
       "  ('-', 0.0005),\n",
       "  (')-', 0.0014),\n",
       "  (' +', 0.0013),\n",
       "  (')', 0.0111)],\n",
       " [('x', 0.9954),\n",
       "  ('3', 0.0002),\n",
       "  ('X', 0.0004),\n",
       "  ('x', 0.9954),\n",
       "  ('y', 0.0021),\n",
       "  ('z', 0.0006)],\n",
       " [(').', 0.2425),\n",
       "  (');', 0.0017),\n",
       "  (')=', 0.0041),\n",
       "  (').', 0.2425),\n",
       "  (')', 0.7395),\n",
       "  ('),', 0.0112)],\n",
       " [('\\n', 0.9478),\n",
       "  ('\\n', 0.9478),\n",
       "  (' ', 0.0134),\n",
       "  (' Since', 0.0028),\n",
       "  (' This', 0.0028),\n",
       "  (' We', 0.005)],\n",
       " [('The', 0.1393),\n",
       "  ('We', 0.1912),\n",
       "  ('So', 0.0601),\n",
       "  ('The', 0.1393),\n",
       "  ('Since', 0.0645),\n",
       "  ('Then', 0.0711)],\n",
       " [(' question', 0.093),\n",
       "  (' car', 0.1413),\n",
       "  (' question', 0.093),\n",
       "  (' total', 0.1409),\n",
       "  (' answer', 0.2385),\n",
       "  (' distance', 0.1962)],\n",
       " [(' is', 0.2098),\n",
       "  (' asked', 0.0267),\n",
       "  (' is', 0.2098),\n",
       "  (' wants', 0.0182),\n",
       "  (' then', 0.0137),\n",
       "  (' asks', 0.6475)],\n",
       " [(' asking', 0.7364),\n",
       "  (',', 0.0365),\n",
       "  (':', 0.02),\n",
       "  (' to', 0.0306),\n",
       "  (' asking', 0.7364),\n",
       "  (' how', 0.0643)],\n",
       " [(' how', 0.1263),\n",
       "  (' for', 0.518),\n",
       "  (' us', 0.2094),\n",
       "  (' about', 0.0395),\n",
       "  (' what', 0.037),\n",
       "  (' how', 0.1263)],\n",
       " [(' far', 0.8653),\n",
       "  (' far', 0.8653),\n",
       "  (' to', 0.0044),\n",
       "  (' many', 0.0147),\n",
       "  (' much', 0.0711),\n",
       "  (' long', 0.0394)],\n",
       " [(' the', 0.8024),\n",
       "  (' x', 0.0229),\n",
       "  (' the', 0.8024),\n",
       "  (' it', 0.0282),\n",
       "  (' was', 0.0095),\n",
       "  (' did', 0.1032)],\n",
       " [(' car', 0.9861),\n",
       "  (' car', 0.9861),\n",
       "  (' third', 0.0005),\n",
       "  (' vehicle', 0.0012),\n",
       "  (' distance', 0.0089),\n",
       "  (' fourth', 0.0004)],\n",
       " [(' traveled', 0.4952),\n",
       "  (' must', 0.025),\n",
       "  (' traveled', 0.4952),\n",
       "  (' travels', 0.1003),\n",
       "  (' has', 0.1575),\n",
       "  (' had', 0.1539)],\n",
       " [(' after', 0.9573),\n",
       "  (' between', 0.0071),\n",
       "  (' x', 0.0026),\n",
       "  (' in', 0.0065),\n",
       "  (' from', 0.0044),\n",
       "  (' after', 0.9573)],\n",
       " [(' the', 0.9781),\n",
       "  (' turn', 0.0036),\n",
       "  (' the', 0.9781),\n",
       "  (' it', 0.0038),\n",
       "  (' turning', 0.0024),\n",
       "  (' its', 0.004)],\n",
       " [(' third', 0.941),\n",
       "  (' second', 0.0007),\n",
       "  (' third', 0.941),\n",
       "  (' 3', 0.0459),\n",
       "  (' fourth', 0.007),\n",
       "  (' 4', 0.0011)],\n",
       " [(' turn', 0.9963),\n",
       "  (',', 0.0004),\n",
       "  (' turn', 0.9963),\n",
       "  (' and', 0.0005),\n",
       "  (' tun', 0.0001),\n",
       "  (' right', 0.0008)],\n",
       " [(',', 0.5685),\n",
       "  (',', 0.5685),\n",
       "  ('.', 0.2588),\n",
       "  (' (', 0.0392),\n",
       "  (' so', 0.0645),\n",
       "  (' which', 0.0207)],\n",
       " [(' so', 0.6289),\n",
       "  (' x', 0.0403),\n",
       "  (' meaning', 0.0112),\n",
       "  (' or', 0.044),\n",
       "  (' so', 0.6289),\n",
       "  (' which', 0.1978)],\n",
       " [(' we', 0.6957),\n",
       "  (' let', 0.0262),\n",
       "  (' x', 0.0499),\n",
       "  (' the', 0.105),\n",
       "  (' we', 0.6957),\n",
       "  (' our', 0.0315)],\n",
       " [(' want', 0.2182),\n",
       "  (' are', 0.1039),\n",
       "  (' can', 0.273),\n",
       "  (' need', 0.1273),\n",
       "  (' want', 0.2182),\n",
       "  (' solve', 0.0346)],\n",
       " [(' to', 0.9021),\n",
       "  (' x', 0.036),\n",
       "  (' the', 0.0489),\n",
       "  (' an', 0.0022),\n",
       "  (' to', 0.9021),\n",
       "  (' our', 0.0015)],\n",
       " [(' solve', 0.5754),\n",
       "  (' find', 0.1591),\n",
       "  (' calculate', 0.0121),\n",
       "  (' isolate', 0.0178),\n",
       "  (' know', 0.1751),\n",
       "  (' solve', 0.5754)],\n",
       " [(' for', 0.9764),\n",
       "  (' x', 0.0039),\n",
       "  (' the', 0.0064),\n",
       "  (' for', 0.9764),\n",
       "  (' this', 0.006),\n",
       "  (' only', 0.0017)],\n",
       " [(' x', 0.931),\n",
       "  (' x', 0.931),\n",
       "  (' the', 0.0456),\n",
       "  (' y', 0.0058),\n",
       "  (' \"', 0.0023),\n",
       "  (' variable', 0.0028)],\n",
       " [('.', 0.702),\n",
       "  (',', 0.048),\n",
       "  ('.', 0.702),\n",
       "  (':', 0.1007),\n",
       "  (' in', 0.0634),\n",
       "  (' and', 0.0146)],\n",
       " [('\\n', 0.7637),\n",
       "  (' To', 0.018),\n",
       "  ('\\n', 0.7637),\n",
       "  (' x', 0.0108),\n",
       "  (' ', 0.026),\n",
       "  (' We', 0.0679)],\n",
       " [('We', 0.2527),\n",
       "  ('We', 0.2527),\n",
       "  ('To', 0.0463),\n",
       "  ('The', 0.0722),\n",
       "  ('Sub', 0.0421),\n",
       "  ('x', 0.1457)],\n",
       " [(' can', 0.6018),\n",
       "  (' substitute', 0.0169),\n",
       "  (' have', 0.0487),\n",
       "  (' can', 0.6018),\n",
       "  (' do', 0.0157),\n",
       "  (' know', 0.1714)],\n",
       " [(' substitute', 0.2195),\n",
       "  (' substitute', 0.2195),\n",
       "  (' do', 0.1789),\n",
       "  (' plug', 0.0407),\n",
       "  (' use', 0.0524),\n",
       "  (' solve', 0.2618)],\n",
       " [(' in', 0.357),\n",
       "  (' 23', 0.0394),\n",
       "  (' the', 0.1536),\n",
       "  (' in', 0.357),\n",
       "  (' for', 0.0472),\n",
       "  (' y', 0.2981)],\n",
       " [(' the', 0.4941),\n",
       "  (' 23', 0.0327),\n",
       "  (' the', 0.4941),\n",
       "  (' for', 0.1424),\n",
       "  (' y', 0.2017),\n",
       "  (' our', 0.0676)],\n",
       " [(' value', 0.7311),\n",
       "  (' equation', 0.0315),\n",
       "  (' value', 0.7311),\n",
       "  (' y', 0.0258),\n",
       "  (' values', 0.0664),\n",
       "  (' expression', 0.0656)],\n",
       " [(' of', 0.7701),\n",
       "  (' of', 0.7701),\n",
       "  (' that', 0.0033),\n",
       "  (' for', 0.1625),\n",
       "  (' y', 0.0073),\n",
       "  (' we', 0.0479)],\n",
       " [(' y', 0.9863),\n",
       "  (' x', 0.0029),\n",
       "  (' 23', 0.0027),\n",
       "  (' the', 0.0045),\n",
       "  (' y', 0.9863),\n",
       "  (' 5', 0.0006)],\n",
       " [(' to', 0.2515),\n",
       "  (' to', 0.2515),\n",
       "  (' in', 0.0647),\n",
       "  (' and', 0.0819),\n",
       "  (' that', 0.0636),\n",
       "  (' into', 0.2117)],\n",
       " [(' solve', 0.4924),\n",
       "  (' find', 0.0849),\n",
       "  (' the', 0.0234),\n",
       "  (' simplify', 0.0171),\n",
       "  (' get', 0.2687),\n",
       "  (' solve', 0.4924)],\n",
       " [(' for', 0.9323),\n",
       "  (' x', 0.0046),\n",
       "  (':', 0.0398),\n",
       "  (' the', 0.0072),\n",
       "  (' for', 0.9323),\n",
       "  (' this', 0.0061)],\n",
       " [(' x', 0.9791),\n",
       "  (' x', 0.9791),\n",
       "  (' the', 0.0071),\n",
       "  (' y', 0.0049),\n",
       "  (' it', 0.0025),\n",
       "  (' this', 0.0028)],\n",
       " [(':', 0.7158),\n",
       "  (',', 0.0425),\n",
       "  ('.', 0.1273),\n",
       "  (':', 0.7158),\n",
       "  (' in', 0.0183),\n",
       "  (' by', 0.0264)],\n",
       " [(' 23', 0.1438),\n",
       "  ('\\n', 0.1427),\n",
       "  (' x', 0.6566),\n",
       "  (' ', 0.0113),\n",
       "  (' 23', 0.1438),\n",
       "  (' 5', 0.0211)],\n",
       " [(' -', 0.7008),\n",
       "  ('-', 0.0352),\n",
       "  ('=', 0.0034),\n",
       "  ('-(', 0.0548),\n",
       "  (' -', 0.7008),\n",
       "  (' =', 0.2015)],\n",
       " [(' (', 0.9671),\n",
       "  (' 13', 0.0058),\n",
       "  (' x', 0.0021),\n",
       "  (' (', 0.9671),\n",
       "  (' 5', 0.0183),\n",
       "  ('(', 0.0032)],\n",
       " [('5', 0.9935),\n",
       "  ('13', 0.0021),\n",
       "  ('5', 0.9935),\n",
       "  ('8', 0.001),\n",
       "  (' 5', 0.0006),\n",
       "  ('x', 0.0019)],\n",
       " [('+', 0.9697),\n",
       "  ('+', 0.9697),\n",
       "  ('-', 0.0001),\n",
       "  (' +', 0.0301),\n",
       "  ('+(', 0.0),\n",
       "  (' -', 0.0)],\n",
       " [('8', 0.9992),\n",
       "  ('8', 0.9992),\n",
       "  ('9', 0.0),\n",
       "  (' 8', 0.0002),\n",
       "  ('x', 0.0004),\n",
       "  ('y', 0.0001)],\n",
       " [('+', 0.9139),\n",
       "  ('+', 0.9139),\n",
       "  (' +', 0.0048),\n",
       "  ('+(', 0.0656),\n",
       "  (')+', 0.0004),\n",
       "  (')', 0.0146)],\n",
       " [('x', 0.9886),\n",
       "  ('23', 0.0002),\n",
       "  (' x', 0.0003),\n",
       "  ('X', 0.0002),\n",
       "  ('x', 0.9886),\n",
       "  ('y', 0.0101)],\n",
       " [(')', 0.9872),\n",
       "  ('+', 0.0012),\n",
       "  ('-', 0.0005),\n",
       "  (')-', 0.0014),\n",
       "  (')=', 0.0081),\n",
       "  (')', 0.9872)],\n",
       " [(' =', 0.9795),\n",
       "  (' +', 0.0017),\n",
       "  ('\\n', 0.0037),\n",
       "  (' is', 0.0004),\n",
       "  (' -', 0.0121),\n",
       "  (' =', 0.9795)],\n",
       " [(' y', 0.6563),\n",
       "  (' x', 0.0576),\n",
       "  (' 23', 0.2264),\n",
       "  (' y', 0.6563),\n",
       "  (' 5', 0.0209),\n",
       "  (' 0', 0.0103)],\n",
       " [('.', 0.425),\n",
       "  (',', 0.0776),\n",
       "  ('.', 0.425),\n",
       "  ('\\n', 0.3216),\n",
       "  (' and', 0.0196),\n",
       "  (' =', 0.0368)],\n",
       " [('\\n', 0.9272),\n",
       "  ('\\n', 0.9272),\n",
       "  (' ', 0.0107),\n",
       "  (' Then', 0.0134),\n",
       "  (' This', 0.0054),\n",
       "  (' We', 0.0069)],\n",
       " [('We', 0.1433),\n",
       "  ('We', 0.1433),\n",
       "  ('23', 0.1189),\n",
       "  ('The', 0.0496),\n",
       "  ('Then', 0.115),\n",
       "  ('x', 0.0496)],\n",
       " [(' can', 0.5705),\n",
       "  (' can', 0.5705),\n",
       "  (' get', 0.0256),\n",
       "  (' know', 0.1382),\n",
       "  (' want', 0.0204),\n",
       "  (' then', 0.0517)],\n",
       " [(' then', 0.1795),\n",
       "  (' simplify', 0.1327),\n",
       "  (' subtract', 0.0762),\n",
       "  (' rearr', 0.1021),\n",
       "  (' then', 0.1795),\n",
       "  (' solve', 0.0716)],\n",
       " [(' solve', 0.1953),\n",
       "  (' simplify', 0.0896),\n",
       "  (' subtract', 0.1409),\n",
       "  (' rearr', 0.1427),\n",
       "  (' add', 0.0794),\n",
       "  (' solve', 0.1953)],\n",
       " [(' for', 0.845),\n",
       "  (':', 0.0108),\n",
       "  (' the', 0.0402),\n",
       "  (' for', 0.845),\n",
       "  (' by', 0.0119),\n",
       "  (' this', 0.0506)],\n",
       " [(' x', 0.9571),\n",
       "  (' x', 0.9571),\n",
       "  (' 23', 0.0049),\n",
       "  (' the', 0.0164),\n",
       "  (' y', 0.0139),\n",
       "  (' (', 0.0013)],\n",
       " [(' by', 0.4551),\n",
       "  (' using', 0.0438),\n",
       "  (':', 0.3335),\n",
       "  (' to', 0.0251),\n",
       "  (' in', 0.0222),\n",
       "  (' by', 0.4551)],\n",
       " [(' subtract', 0.3846),\n",
       "  (' subtract', 0.3846),\n",
       "  (' rearr', 0.0908),\n",
       "  (' moving', 0.0774),\n",
       "  (' adding', 0.1186),\n",
       "  (' isol', 0.0519)],\n",
       " [('ing', 0.999),\n",
       "  (' 23', 0.0003),\n",
       "  (' the', 0.0001),\n",
       "  ('ing', 0.999),\n",
       "  (' y', 0.0001),\n",
       "  (' 5', 0.0001)],\n",
       " [(' 8', 0.2156),\n",
       "  (' 23', 0.2065),\n",
       "  (' the', 0.0544),\n",
       "  (' y', 0.0567),\n",
       "  (' 5', 0.27),\n",
       "  (' 8', 0.2156)],\n",
       " [(' and', 0.464),\n",
       "  ('+', 0.0313),\n",
       "  (',', 0.134),\n",
       "  (' +', 0.008),\n",
       "  (' and', 0.464),\n",
       "  (' from', 0.3313)],\n",
       " [(' 5', 0.9001),\n",
       "  (' 23', 0.0338),\n",
       "  (' subtract', 0.0127),\n",
       "  (' adding', 0.0227),\n",
       "  (' 5', 0.9001),\n",
       "  (' then', 0.0147)],\n",
       " [(' from', 0.8415),\n",
       "  (',', 0.0183),\n",
       "  (':', 0.0257),\n",
       "  (' and', 0.0449),\n",
       "  (' on', 0.0378),\n",
       "  (' from', 0.8415)],\n",
       " [(' both', 0.8638),\n",
       "  (' both', 0.8638),\n",
       "  (' each', 0.0805),\n",
       "  (' either', 0.0015),\n",
       "  (' 23', 0.0314),\n",
       "  (' the', 0.0186)],\n",
       " [(' sides', 0.9964),\n",
       "  (' side', 0.0022),\n",
       "  (' the', 0.0004),\n",
       "  (' sides', 0.9964),\n",
       "  (' ends', 0.0002),\n",
       "  ('s', 0.0002)],\n",
       " [(' of', 0.3453),\n",
       "  (',', 0.0728),\n",
       "  ('.', 0.047),\n",
       "  (':', 0.3364),\n",
       "  (' of', 0.3453),\n",
       "  (' and', 0.1431)],\n",
       " [(' the', 0.9647),\n",
       "  (' 23', 0.0012),\n",
       "  (' the', 0.9647),\n",
       "  (' that', 0.0034),\n",
       "  (' this', 0.0268),\n",
       "  (' our', 0.0015)],\n",
       " [(' equation', 0.9578),\n",
       "  (' equality', 0.0068),\n",
       "  (' equation', 0.9578),\n",
       "  (' above', 0.0095),\n",
       "  (' equal', 0.0073),\n",
       "  (' expression', 0.0044)],\n",
       " [('.', 0.1364),\n",
       "  (',', 0.0965),\n",
       "  ('.', 0.1364),\n",
       "  (':', 0.5292),\n",
       "  (' to', 0.0398),\n",
       "  (' and', 0.1595)],\n",
       " [('\\n', 0.8473),\n",
       "  ('\\n', 0.8473),\n",
       "  (' 23', 0.0178),\n",
       "  (' The', 0.0117),\n",
       "  (' This', 0.0363),\n",
       "  (' We', 0.0185)],\n",
       " [('23', 0.2559),\n",
       "  ('We', 0.0908),\n",
       "  ('This', 0.1192),\n",
       "  ('23', 0.2559),\n",
       "  ('The', 0.1235),\n",
       "  ('x', 0.0886)],\n",
       " [(' -', 0.8409),\n",
       "  ('-', 0.1065),\n",
       "  ('-(', 0.0399),\n",
       "  (' -', 0.8409),\n",
       "  (' –', 0.0025),\n",
       "  (' =', 0.0065)],\n",
       " [(' (', 0.6173),\n",
       "  (' 13', 0.1093),\n",
       "  (' y', 0.005),\n",
       "  (' (', 0.6173),\n",
       "  (' 5', 0.1315),\n",
       "  (' 8', 0.118)],\n",
       " [('5', 0.9579),\n",
       "  ('13', 0.0063),\n",
       "  ('5', 0.9579),\n",
       "  ('8', 0.0309),\n",
       "  ('x', 0.0017),\n",
       "  ('y', 0.0008)],\n",
       " [('+', 0.9656),\n",
       "  ('+', 0.9656),\n",
       "  (')-', 0.0001),\n",
       "  (' +', 0.0328),\n",
       "  (' -', 0.0002),\n",
       "  (')', 0.0011)],\n",
       " [('8', 0.9994),\n",
       "  ('5', 0.0001),\n",
       "  ('8', 0.9994),\n",
       "  (' 8', 0.0001),\n",
       "  ('x', 0.0002),\n",
       "  ('y', 0.0)],\n",
       " [('+', 0.9298),\n",
       "  ('+', 0.9298),\n",
       "  (')-', 0.0013),\n",
       "  (' +', 0.0012),\n",
       "  (')+', 0.0003),\n",
       "  (')', 0.0667)],\n",
       " [('x', 0.9992),\n",
       "  ('23', 0.0),\n",
       "  (' x', 0.0001),\n",
       "  ('X', 0.0002),\n",
       "  ('x', 0.9992),\n",
       "  ('y', 0.0002)],\n",
       " [(')', 0.9915),\n",
       "  ('+', 0.0003),\n",
       "  ('-', 0.0004),\n",
       "  (')-', 0.0035),\n",
       "  (')=', 0.0035),\n",
       "  (')', 0.9915)],\n",
       " [(' =', 0.9013),\n",
       "  (' +', 0.0016),\n",
       "  ('\\n', 0.0014),\n",
       "  (' becomes', 0.0025),\n",
       "  (' -', 0.088),\n",
       "  (' =', 0.9013)],\n",
       " [(' y', 0.8861),\n",
       "  (' x', 0.0077),\n",
       "  (' 23', 0.0827),\n",
       "  (' y', 0.8861),\n",
       "  (' (', 0.0035),\n",
       "  (' 5', 0.0089)],\n",
       " [('\\n', 0.6231),\n",
       "  (',', 0.0302),\n",
       "  ('.', 0.141),\n",
       "  (' -->', 0.0125),\n",
       "  ('\\n', 0.6231),\n",
       "  (' becomes', 0.0689)],\n",
       " [('23', 0.8027),\n",
       "  ('-', 0.044),\n",
       "  ('18', 0.0258),\n",
       "  ('23', 0.8027),\n",
       "  ('8', 0.011),\n",
       "  ('x', 0.0173)],\n",
       " [(' -', 0.9257),\n",
       "  ('-', 0.0522),\n",
       "  ('-(', 0.0176),\n",
       "  (' -', 0.9257),\n",
       "  (' –', 0.0017),\n",
       "  (' =', 0.0015)],\n",
       " [(' 13', 0.4635),\n",
       "  ('13', 0.0066),\n",
       "  (' 13', 0.4635),\n",
       "  (' (', 0.3106),\n",
       "  (' 5', 0.078),\n",
       "  (' 8', 0.1106)],\n",
       " [(' -', 0.9605),\n",
       "  ('-', 0.0099),\n",
       "  (' +', 0.0027),\n",
       "  (' -', 0.9605),\n",
       "  (' –', 0.0009),\n",
       "  (' =', 0.0245)],\n",
       " [(' x', 0.9474),\n",
       "  (' X', 0.0008),\n",
       "  (' x', 0.9474),\n",
       "  (' (', 0.0023),\n",
       "  (' 5', 0.0005),\n",
       "  ('x', 0.0474)],\n",
       " [(' =', 0.9955),\n",
       "  ('\\n', 0.0002),\n",
       "  (' ', 0.0005),\n",
       "  ('=', 0.0027),\n",
       "  (' -', 0.0008),\n",
       "  (' =', 0.9955)],\n",
       " [(' y', 0.9939),\n",
       "  (' x', 0.0015),\n",
       "  (' 23', 0.0011),\n",
       "  (' y', 0.9939),\n",
       "  (' -', 0.0014),\n",
       "  ('y', 0.0009)],\n",
       " [('\\n', 0.9722),\n",
       "  (',', 0.0025),\n",
       "  ('.', 0.0122),\n",
       "  ('\\n', 0.9722),\n",
       "  (' ', 0.0021),\n",
       "  (' (', 0.0037)],\n",
       " [('10', 0.5741),\n",
       "  ('-', 0.0156),\n",
       "  ('23', 0.2795),\n",
       "  ('x', 0.057),\n",
       "  ('y', 0.0252),\n",
       "  ('10', 0.5741)],\n",
       " [(' -', 0.9652),\n",
       "  ('-', 0.0294),\n",
       "  (' +', 0.0002),\n",
       "  (' -', 0.9652),\n",
       "  (' –', 0.0009),\n",
       "  (' =', 0.0037)],\n",
       " [(' x', 0.9868),\n",
       "  (' X', 0.0008),\n",
       "  (' 13', 0.0002),\n",
       "  (' x', 0.9868),\n",
       "  (' y', 0.0007),\n",
       "  ('x', 0.011)],\n",
       " [(' =', 0.9976),\n",
       "  ('\\n', 0.0001),\n",
       "  (' ', 0.0003),\n",
       "  ('=', 0.001),\n",
       "  (' -', 0.0008),\n",
       "  (' =', 0.9976)],\n",
       " [(' y', 0.9972),\n",
       "  (' x', 0.0004),\n",
       "  (' 23', 0.0007),\n",
       "  (' y', 0.9972),\n",
       "  (' -', 0.0004),\n",
       "  ('y', 0.0007)],\n",
       " [('\\n', 0.9444),\n",
       "  (',', 0.0022),\n",
       "  ('.', 0.0465),\n",
       "  ('\\n', 0.9444),\n",
       "  (' ', 0.0013),\n",
       "  (' (', 0.0012)],\n",
       " [('10', 0.1753),\n",
       "  ('We', 0.1768),\n",
       "  ('Then', 0.051),\n",
       "  ('Sub', 0.0547),\n",
       "  ('x', 0.1505),\n",
       "  ('10', 0.1753)],\n",
       " [(' -', 0.7789),\n",
       "  ('-', 0.0207),\n",
       "  (' +', 0.0028),\n",
       "  ('=', 0.0008),\n",
       "  (' -', 0.7789),\n",
       "  (' =', 0.1937)],\n",
       " [(' x', 0.8971),\n",
       "  (' x', 0.8971),\n",
       "  (' 23', 0.0032),\n",
       "  (' y', 0.0767),\n",
       "  (' (', 0.0079),\n",
       "  ('x', 0.005)],\n",
       " [(' =', 0.451),\n",
       "  (' +', 0.0327),\n",
       "  ('=', 0.001),\n",
       "  (' is', 0.0037),\n",
       "  (' -', 0.5062),\n",
       "  (' =', 0.451)],\n",
       " [(' 23', 0.9367),\n",
       "  (' x', 0.0029),\n",
       "  (' 23', 0.9367),\n",
       "  (' y', 0.0213),\n",
       "  (' (', 0.0287),\n",
       "  (' -', 0.0029)],\n",
       " [(' -', 0.9756),\n",
       "  ('-', 0.0106),\n",
       "  (' +', 0.0001),\n",
       "  ('-(', 0.013),\n",
       "  (' -', 0.9756),\n",
       "  (' –', 0.0004)],\n",
       " [(' (', 0.8449),\n",
       "  ('13', 0.0019),\n",
       "  (' 13', 0.1409),\n",
       "  (' (', 0.8449),\n",
       "  (' 5', 0.0067),\n",
       "  ('(', 0.0022)],\n",
       " [('5', 0.9941),\n",
       "  ('13', 0.002),\n",
       "  ('5', 0.9941),\n",
       "  ('8', 0.003),\n",
       "  (' 5', 0.0003),\n",
       "  ('x', 0.0003)],\n",
       " [('+', 0.9582),\n",
       "  ('+', 0.9582),\n",
       "  ('-', 0.0),\n",
       "  (' +', 0.0417),\n",
       "  (' -', 0.0001),\n",
       "  (')', 0.0)],\n",
       " [('8', 0.9993),\n",
       "  ('13', 0.0),\n",
       "  ('8', 0.9993),\n",
       "  (' 8', 0.0001),\n",
       "  ('x', 0.0004),\n",
       "  ('y', 0.0)],\n",
       " [('+', 0.9924),\n",
       "  ('+', 0.9924),\n",
       "  (')-', 0.0002),\n",
       "  (' +', 0.0017),\n",
       "  (').', 0.0001),\n",
       "  (')', 0.0053)],\n",
       " [('x', 0.999),\n",
       "  (' x', 0.0001),\n",
       "  ('X', 0.0001),\n",
       "  ('x', 0.999),\n",
       "  ('y', 0.0007),\n",
       "  ('10', 0.0001)],\n",
       " [(')', 0.9464),\n",
       "  (' )', 0.0001),\n",
       "  (');', 0.0002),\n",
       "  (').', 0.0487),\n",
       "  (')', 0.9464),\n",
       "  ('),', 0.0041)],\n",
       " [('\\n', 0.9888),\n",
       "  ('\\n', 0.9888),\n",
       "  (' ', 0.002),\n",
       "  (' (', 0.0017),\n",
       "  (' by', 0.0006),\n",
       "  (' =', 0.0017)],\n",
       " [('10', 0.7605),\n",
       "  ('We', 0.0373),\n",
       "  ('-', 0.0221),\n",
       "  ('23', 0.0184),\n",
       "  ('x', 0.0439),\n",
       "  ('10', 0.7605)],\n",
       " [(' -', 0.8994),\n",
       "  ('-', 0.0182),\n",
       "  (' +', 0.0024),\n",
       "  ('=', 0.0008),\n",
       "  (' -', 0.8994),\n",
       "  (' =', 0.0776)],\n",
       " [(' x', 0.9789),\n",
       "  (' x', 0.9789),\n",
       "  (' 23', 0.0061),\n",
       "  (' (', 0.0031),\n",
       "  (' 5', 0.0013),\n",
       "  ('x', 0.0063)],\n",
       " [(' =', 0.9783),\n",
       "  (' +', 0.003),\n",
       "  (' ', 0.0002),\n",
       "  ('=', 0.0011),\n",
       "  (' -', 0.0172),\n",
       "  (' =', 0.9783)],\n",
       " [(' 23', 0.9636),\n",
       "  (' x', 0.0013),\n",
       "  (' 23', 0.9636),\n",
       "  (' y', 0.0266),\n",
       "  (' (', 0.0024),\n",
       "  (' -', 0.0024)],\n",
       " [(' -', 0.9872),\n",
       "  ('-', 0.0076),\n",
       "  ('\\n', 0.0001),\n",
       "  ('-(', 0.0044),\n",
       "  (' -', 0.9872),\n",
       "  (' –', 0.0003)],\n",
       " [(' 13', 0.757),\n",
       "  ('13', 0.0078),\n",
       "  (' 13', 0.757),\n",
       "  (' y', 0.007),\n",
       "  (' (', 0.1683),\n",
       "  (' 5', 0.0453)],\n",
       " [(' -', 0.9956),\n",
       "  ('-', 0.0036),\n",
       "  (' +', 0.0003),\n",
       "  (' -', 0.9956),\n",
       "  (' –', 0.0002),\n",
       "  (' =', 0.0001)],\n",
       " [(' x', 0.9803),\n",
       "  (' X', 0.0001),\n",
       "  (' x', 0.9803),\n",
       "  (' y', 0.0003),\n",
       "  (' (', 0.0001),\n",
       "  ('x', 0.0189)],\n",
       " [('\\n', 0.9857),\n",
       "  ('.', 0.0038),\n",
       "  ('\\n', 0.9857),\n",
       "  (' ', 0.0022),\n",
       "  (' (', 0.0031),\n",
       "  (' =', 0.0008)],\n",
       " [('x', 0.0569),\n",
       "  ('-', 0.0483),\n",
       "  ('0', 0.0155),\n",
       "  ('2', 0.0908),\n",
       "  ('x', 0.0569),\n",
       "  ('10', 0.687)],\n",
       " [(' =', 0.5776),\n",
       "  ('-', 0.0132),\n",
       "  (' +', 0.0759),\n",
       "  ('=', 0.0079),\n",
       "  (' -', 0.3127),\n",
       "  (' =', 0.5776)],\n",
       " [(' 10', 0.7712),\n",
       "  (' 13', 0.0288),\n",
       "  (' 23', 0.0912),\n",
       "  (' y', 0.0188),\n",
       "  (' 5', 0.0153),\n",
       "  (' 10', 0.7712)],\n",
       " [(' -', 0.3911),\n",
       "  ('-', 0.0162),\n",
       "  ('.', 0.0324),\n",
       "  ('/', 0.0106),\n",
       "  ('\\n', 0.5205),\n",
       "  (' -', 0.3911)],\n",
       " [(' (', 0.4383),\n",
       "  (' x', 0.1069),\n",
       "  (' 23', 0.2024),\n",
       "  (' y', 0.2014),\n",
       "  (' (', 0.4383),\n",
       "  (' 10', 0.0316)],\n",
       " [('23', 0.9457),\n",
       "  ('23', 0.9457),\n",
       "  (' 23', 0.0033),\n",
       "  ('x', 0.0018),\n",
       "  ('y', 0.0031),\n",
       "  ('10', 0.0436)],\n",
       " [(' -', 0.8661),\n",
       "  ('-', 0.1329),\n",
       "  ('-(', 0.0004),\n",
       "  (' -', 0.8661),\n",
       "  (' –', 0.0003),\n",
       "  (' −', 0.0001)],\n",
       " [(' 13', 0.9517),\n",
       "  ('13', 0.0106),\n",
       "  (' 13', 0.9517),\n",
       "  (' x', 0.001),\n",
       "  (' (', 0.0318),\n",
       "  (' 5', 0.0025)],\n",
       " [(' -', 0.8883),\n",
       "  ('-', 0.0013),\n",
       "  (' +', 0.0005),\n",
       "  (' -', 0.8883),\n",
       "  (').', 0.0004),\n",
       "  (')', 0.1082)],\n",
       " [(' x', 0.9784),\n",
       "  (' X', 0.0002),\n",
       "  (' x', 0.9784),\n",
       "  (' y', 0.0002),\n",
       "  (' 10', 0.0002),\n",
       "  ('x', 0.0208)],\n",
       " [(')', 0.9846),\n",
       "  (' )', 0.0006),\n",
       "  (');', 0.0001),\n",
       "  (').', 0.0112),\n",
       "  (')', 0.9846),\n",
       "  ('),', 0.003)],\n",
       " [('\\n', 0.9906),\n",
       "  ('\\n', 0.9906),\n",
       "  (' ', 0.0014),\n",
       "  (' by', 0.0005),\n",
       "  (' -', 0.0005),\n",
       "  (' =', 0.0024)],\n",
       " [('x', 0.8901),\n",
       "  ('We', 0.0074),\n",
       "  ('2', 0.0539),\n",
       "  ('The', 0.0046),\n",
       "  ('x', 0.8901),\n",
       "  ('10', 0.0108)],\n",
       " [(' =', 0.9738),\n",
       "  ('+', 0.0003),\n",
       "  (' +', 0.019),\n",
       "  ('=', 0.0012),\n",
       "  (' -', 0.0052),\n",
       "  (' =', 0.9738)],\n",
       " [(' 10', 0.966),\n",
       "  (' x', 0.0062),\n",
       "  (' 23', 0.0039),\n",
       "  (' (', 0.002),\n",
       "  (' -', 0.0124),\n",
       "  (' 10', 0.966)],\n",
       " [(' -', 0.9876),\n",
       "  ('-', 0.0018),\n",
       "  (' +', 0.005),\n",
       "  ('\\n', 0.0038),\n",
       "  ('-(', 0.0006),\n",
       "  (' -', 0.9876)],\n",
       " [(' 10', 0.4277),\n",
       "  (' 13', 0.0184),\n",
       "  (' x', 0.0041),\n",
       "  (' 23', 0.1932),\n",
       "  (' (', 0.3455),\n",
       "  (' 10', 0.4277)],\n",
       " [('\\n', 0.8149),\n",
       "  (' +', 0.1008),\n",
       "  ('\\n', 0.8149),\n",
       "  (' ', 0.0019),\n",
       "  (' -', 0.0562),\n",
       "  (' =', 0.0226)],\n",
       " [('x', 0.9631),\n",
       "  ('0', 0.0029),\n",
       "  ('Therefore', 0.003),\n",
       "  ('The', 0.0121),\n",
       "  ('x', 0.9631),\n",
       "  ('10', 0.0033)]]"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "codex_per_step_probs[32][44]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
