{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "51734ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import polars as pl\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "870b483b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, run `model_multitoken_probing.py` once for each ngram value from 1 to 10; it generates the\n",
    "# `results_ngram={ngram}_operand_idx=1.jsonl` files (e.g. in a `multitoken_probes/` directory).\n",
    "# Set `outputs_dir` in the next cell to the directory holding those files, then run this notebook to compute the resulting table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c871fcf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the results_ngram={ngram}_operand_idx=1.jsonl files read by the cell below.\n",
    "outputs_dir = 'multitoken_probes_lumi'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "79c8f0cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Selecting the best hidden state by giving highest **product** of individual piecewise accuracies\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (10, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>pos_idx</th><th>block_idx</th><th>suboperand_idx</th><th>test_acc</th></tr><tr><td>i64</td><td>i64</td><td>i64</td><td>f64</td></tr></thead><tbody><tr><td>-2</td><td>31</td><td>0</td><td>0.0</td></tr><tr><td>-2</td><td>31</td><td>1</td><td>0.002686</td></tr><tr><td>-2</td><td>31</td><td>2</td><td>0.0</td></tr><tr><td>-2</td><td>31</td><td>3</td><td>0.000244</td></tr><tr><td>-2</td><td>31</td><td>4</td><td>0.0</td></tr><tr><td>-2</td><td>31</td><td>5</td><td>0.0</td></tr><tr><td>-2</td><td>31</td><td>6</td><td>0.0</td></tr><tr><td>-2</td><td>31</td><td>7</td><td>0.002441</td></tr><tr><td>-2</td><td>31</td><td>8</td><td>0.004883</td></tr><tr><td>-2</td><td>31</td><td>9</td><td>0.772949</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (10, 4)\n",
       "┌─────────┬───────────┬────────────────┬──────────┐\n",
       "│ pos_idx ┆ block_idx ┆ suboperand_idx ┆ test_acc │\n",
       "│ ---     ┆ ---       ┆ ---            ┆ ---      │\n",
       "│ i64     ┆ i64       ┆ i64            ┆ f64      │\n",
       "╞═════════╪═══════════╪════════════════╪══════════╡\n",
       "│ -2      ┆ 31        ┆ 0              ┆ 0.0      │\n",
       "│ -2      ┆ 31        ┆ 1              ┆ 0.002686 │\n",
       "│ -2      ┆ 31        ┆ 2              ┆ 0.0      │\n",
       "│ -2      ┆ 31        ┆ 3              ┆ 0.000244 │\n",
       "│ -2      ┆ 31        ┆ 4              ┆ 0.0      │\n",
       "│ -2      ┆ 31        ┆ 5              ┆ 0.0      │\n",
       "│ -2      ┆ 31        ┆ 6              ┆ 0.0      │\n",
       "│ -2      ┆ 31        ┆ 7              ┆ 0.002441 │\n",
       "│ -2      ┆ 31        ┆ 8              ┆ 0.004883 │\n",
       "│ -2      ┆ 31        ┆ 9              ┆ 0.772949 │\n",
       "└─────────┴───────────┴────────────────┴──────────┘"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ngram= 1, pos_idx=-2, block_idx=15, estimated acc of full number recovery = 100%\n",
      "ngram= 2, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 91%\n",
      "ngram= 3, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 67%\n",
      "ngram= 4, pos_idx=-2, block_idx=5, estimated acc of full number recovery = 5%\n",
      "ngram= 5, pos_idx=-2, block_idx=6, estimated acc of full number recovery = 0%\n",
      "ngram= 6, pos_idx=-2, block_idx=6, estimated acc of full number recovery = 0%\n",
      "ngram= 7, pos_idx=-2, block_idx=8, estimated acc of full number recovery = 0%\n",
      "ngram= 8, pos_idx=-2, block_idx=8, estimated acc of full number recovery = 0%\n",
      "ngram= 9, pos_idx=-2, block_idx=8, estimated acc of full number recovery = 0%\n",
      "ngram=10, pos_idx=-2, block_idx=31, estimated acc of full number recovery = 0%\n",
      "Individual accuracies:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pos -10</th>\n",
       "      <th>pos -9</th>\n",
       "      <th>pos -8</th>\n",
       "      <th>pos -7</th>\n",
       "      <th>pos -6</th>\n",
       "      <th>pos -5</th>\n",
       "      <th>pos -4</th>\n",
       "      <th>pos -3</th>\n",
       "      <th>pos -2</th>\n",
       "      <th>pos -1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Length 1</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 2</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>91%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 3</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>75%</td>\n",
       "      <td>90%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 4</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>22%</td>\n",
       "      <td>35%</td>\n",
       "      <td>72%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 5</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>20%</td>\n",
       "      <td>20%</td>\n",
       "      <td>13%</td>\n",
       "      <td>70%</td>\n",
       "      <td>99%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 6</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>14%</td>\n",
       "      <td>3%</td>\n",
       "      <td>15%</td>\n",
       "      <td>20%</td>\n",
       "      <td>69%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 7</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>9%</td>\n",
       "      <td>3%</td>\n",
       "      <td>3%</td>\n",
       "      <td>6%</td>\n",
       "      <td>24%</td>\n",
       "      <td>60%</td>\n",
       "      <td>99%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 8</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>8%</td>\n",
       "      <td>3%</td>\n",
       "      <td>2%</td>\n",
       "      <td>4%</td>\n",
       "      <td>7%</td>\n",
       "      <td>20%</td>\n",
       "      <td>63%</td>\n",
       "      <td>99%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 9</th>\n",
       "      <td></td>\n",
       "      <td>6%</td>\n",
       "      <td>1%</td>\n",
       "      <td>1%</td>\n",
       "      <td>2%</td>\n",
       "      <td>4%</td>\n",
       "      <td>4%</td>\n",
       "      <td>17%</td>\n",
       "      <td>61%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 10</th>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>77%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          pos -10 pos -9 pos -8 pos -7 pos -6 pos -5 pos -4 pos -3 pos -2  \\\n",
       "Length 1                                                                    \n",
       "Length 2                                                              91%   \n",
       "Length 3                                                       75%    90%   \n",
       "Length 4                                                22%    35%    72%   \n",
       "Length 5                                         20%    20%    13%    70%   \n",
       "Length 6                                  14%     3%    15%    20%    69%   \n",
       "Length 7                            9%     3%     3%     6%    24%    60%   \n",
       "Length 8                     8%     3%     2%     4%     7%    20%    63%   \n",
       "Length 9              6%     1%     1%     2%     4%     4%    17%    61%   \n",
       "Length 10      0%     0%     0%     0%     0%     0%     0%     0%     0%   \n",
       "\n",
       "          pos -1  \n",
       "Length 1    100%  \n",
       "Length 2    100%  \n",
       "Length 3    100%  \n",
       "Length 4    100%  \n",
       "Length 5     99%  \n",
       "Length 6    100%  \n",
       "Length 7     99%  \n",
       "Length 8     99%  \n",
       "Length 9    100%  \n",
       "Length 10    77%  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Selecting the best hidden state by giving highest **avg** of individual piecewise accuracies\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (10, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>pos_idx</th><th>block_idx</th><th>suboperand_idx</th><th>test_acc</th></tr><tr><td>i64</td><td>i64</td><td>i64</td><td>f64</td></tr></thead><tbody><tr><td>-2</td><td>2</td><td>0</td><td>0.005127</td></tr><tr><td>-2</td><td>2</td><td>1</td><td>0.0</td></tr><tr><td>-2</td><td>2</td><td>2</td><td>0.0</td></tr><tr><td>-2</td><td>2</td><td>3</td><td>0.0</td></tr><tr><td>-2</td><td>2</td><td>4</td><td>0.000732</td></tr><tr><td>-2</td><td>2</td><td>5</td><td>0.0</td></tr><tr><td>-2</td><td>2</td><td>6</td><td>0.004883</td></tr><tr><td>-2</td><td>2</td><td>7</td><td>0.557129</td></tr><tr><td>-2</td><td>2</td><td>8</td><td>0.881836</td></tr><tr><td>-2</td><td>2</td><td>9</td><td>0.99585</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (10, 4)\n",
       "┌─────────┬───────────┬────────────────┬──────────┐\n",
       "│ pos_idx ┆ block_idx ┆ suboperand_idx ┆ test_acc │\n",
       "│ ---     ┆ ---       ┆ ---            ┆ ---      │\n",
       "│ i64     ┆ i64       ┆ i64            ┆ f64      │\n",
       "╞═════════╪═══════════╪════════════════╪══════════╡\n",
       "│ -2      ┆ 2         ┆ 0              ┆ 0.005127 │\n",
       "│ -2      ┆ 2         ┆ 1              ┆ 0.0      │\n",
       "│ -2      ┆ 2         ┆ 2              ┆ 0.0      │\n",
       "│ -2      ┆ 2         ┆ 3              ┆ 0.0      │\n",
       "│ -2      ┆ 2         ┆ 4              ┆ 0.000732 │\n",
       "│ -2      ┆ 2         ┆ 5              ┆ 0.0      │\n",
       "│ -2      ┆ 2         ┆ 6              ┆ 0.004883 │\n",
       "│ -2      ┆ 2         ┆ 7              ┆ 0.557129 │\n",
       "│ -2      ┆ 2         ┆ 8              ┆ 0.881836 │\n",
       "│ -2      ┆ 2         ┆ 9              ┆ 0.99585  │\n",
       "└─────────┴───────────┴────────────────┴──────────┘"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ngram= 1, pos_idx=-2, block_idx=17, estimated acc of full number recovery = 100%\n",
      "ngram= 2, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 91%\n",
      "ngram= 3, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 67%\n",
      "ngram= 4, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 1%\n",
      "ngram= 5, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%\n",
      "ngram= 6, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%\n",
      "ngram= 7, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%\n",
      "ngram= 8, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%\n",
      "ngram= 9, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%\n",
      "ngram=10, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%\n",
      "Individual accuracies:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pos -10</th>\n",
       "      <th>pos -9</th>\n",
       "      <th>pos -8</th>\n",
       "      <th>pos -7</th>\n",
       "      <th>pos -6</th>\n",
       "      <th>pos -5</th>\n",
       "      <th>pos -4</th>\n",
       "      <th>pos -3</th>\n",
       "      <th>pos -2</th>\n",
       "      <th>pos -1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Length 1</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 2</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>91%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 3</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>75%</td>\n",
       "      <td>90%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 4</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>2%</td>\n",
       "      <td>51%</td>\n",
       "      <td>90%</td>\n",
       "      <td>98%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 5</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>0%</td>\n",
       "      <td>1%</td>\n",
       "      <td>56%</td>\n",
       "      <td>89%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 6</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>58%</td>\n",
       "      <td>88%</td>\n",
       "      <td>99%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 7</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>1%</td>\n",
       "      <td>60%</td>\n",
       "      <td>88%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 8</th>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>58%</td>\n",
       "      <td>89%</td>\n",
       "      <td>99%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 9</th>\n",
       "      <td></td>\n",
       "      <td>0%</td>\n",
       "      <td>1%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>1%</td>\n",
       "      <td>59%</td>\n",
       "      <td>89%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Length 10</th>\n",
       "      <td>1%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>0%</td>\n",
       "      <td>56%</td>\n",
       "      <td>88%</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          pos -10 pos -9 pos -8 pos -7 pos -6 pos -5 pos -4 pos -3 pos -2  \\\n",
       "Length 1                                                                    \n",
       "Length 2                                                              91%   \n",
       "Length 3                                                       75%    90%   \n",
       "Length 4                                                 2%    51%    90%   \n",
       "Length 5                                          0%     1%    56%    89%   \n",
       "Length 6                                   0%     0%     0%    58%    88%   \n",
       "Length 7                            0%     0%     0%     1%    60%    88%   \n",
       "Length 8                     0%     0%     0%     0%     0%    58%    89%   \n",
       "Length 9              0%     1%     0%     0%     0%     1%    59%    89%   \n",
       "Length 10      1%     0%     0%     0%     0%     0%     0%    56%    88%   \n",
       "\n",
       "          pos -1  \n",
       "Length 1    100%  \n",
       "Length 2    100%  \n",
       "Length 3    100%  \n",
       "Length 4     98%  \n",
       "Length 5    100%  \n",
       "Length 6     99%  \n",
       "Length 7    100%  \n",
       "Length 8     99%  \n",
       "Length 9    100%  \n",
       "Length 10   100%  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "MAX_NGRAM = 10  # result files are read for ngram = 1..MAX_NGRAM\n",
    "\n",
    "\n",
    "def select_best_hidden_state(df: pl.DataFrame, aggregation: str) -> tuple[int, int]:\n",
    "    \"\"\"Return the (pos_idx, block_idx) hidden state that scores highest on `best_valid_acc`.\n",
    "\n",
    "    `aggregation` chooses how the per-suboperand `best_valid_acc` values of each hidden state are\n",
    "    combined: \"avg\" takes their mean, \"product\" their product; anything else raises ValueError.\n",
    "    Ties (e.g. many hidden states whose product is 0) are broken by the smallest\n",
    "    (pos_idx, block_idx), so the selection is reproducible across runs.\n",
    "    \"\"\"\n",
    "    if aggregation == \"avg\":\n",
    "        score = pl.col(\"best_valid_acc\").mean()\n",
    "    elif aggregation == \"product\":\n",
    "        score = pl.col(\"best_valid_acc\").product()\n",
    "    else:\n",
    "        raise ValueError(f\"unknown aggregation: {aggregation!r}\")\n",
    "\n",
    "    # polars group_by/sort do not guarantee an order among equal scores, hence the explicit\n",
    "    # tie-breakers; nulls_last stops a group with a null score from being picked as the best.\n",
    "    best = (\n",
    "        df.group_by([\"pos_idx\", \"block_idx\"])\n",
    "        .agg(score)\n",
    "        .sort([\"best_valid_acc\", \"pos_idx\", \"block_idx\"], descending=[True, False, False], nulls_last=True)\n",
    "        .row(0, named=True)\n",
    "    )\n",
    "    return best[\"pos_idx\"], best[\"block_idx\"]\n",
    "\n",
    "\n",
    "for aggregation in [\"product\", \"avg\"]:\n",
    "    print(\"\\n\")\n",
    "    print(f\"Selecting the best hidden state by giving highest **{aggregation}** of individual piecewise accuracies\")\n",
    "\n",
    "    reports = []\n",
    "    # Row ngram-1 holds that ngram's per-suboperand test accuracies, right-aligned so the highest\n",
    "    # suboperand_idx lands in the last column; NaN marks cells that do not exist for that length.\n",
    "    piece_accs = np.full((MAX_NGRAM, MAX_NGRAM), np.nan)\n",
    "\n",
    "    for ngram in range(1, MAX_NGRAM + 1):\n",
    "        df = pl.read_ndjson(f\"{outputs_dir}/results_ngram={ngram}_operand_idx=1.jsonl\")\n",
    "        pos_idx, block_idx = select_best_hidden_state(df, aggregation)\n",
    "\n",
    "        subset = (\n",
    "            df.filter((pl.col(\"pos_idx\") == pos_idx) & (pl.col(\"block_idx\") == block_idx))\n",
    "            .select([\"pos_idx\", \"block_idx\", \"suboperand_idx\", \"test_acc\"])\n",
    "            # Accuracies are written into the table positionally, so order them by suboperand.\n",
    "            .sort(\"suboperand_idx\")\n",
    "        )\n",
    "        # Duplicate or missing probe rows would silently shift accuracies into the wrong columns.\n",
    "        assert subset.height == ngram and subset.get_column(\"suboperand_idx\").n_unique() == ngram, (\n",
    "            f\"expected {ngram} rows with distinct suboperand_idx at pos_idx={pos_idx}, \"\n",
    "            f\"block_idx={block_idx}, got {subset.height}\"\n",
    "        )\n",
    "\n",
    "        # Multiplying piecewise accuracies treats the pieces as independent, hence only an estimate.\n",
    "        full_number_acc = subset.get_column(\"test_acc\").product()\n",
    "        reports.append(f\"ngram={ngram:>2}, pos_idx={pos_idx}, block_idx={block_idx}, estimated acc of full number recovery = {full_number_acc:.0%}\")\n",
    "\n",
    "        piece_accs[ngram - 1, MAX_NGRAM - ngram:] = subset.get_column(\"test_acc\").to_numpy()\n",
    "\n",
    "        if ngram == MAX_NGRAM:\n",
    "            display(subset)\n",
    "\n",
    "    for report in reports:\n",
    "        print(report)\n",
    "\n",
    "    print(\"Individual accuracies:\")\n",
    "\n",
    "    # The array is square, so its orientation cannot be inferred from the shape; state it explicitly.\n",
    "    accs_table = pl.DataFrame(\n",
    "        piece_accs, schema=[f\"pos {i - MAX_NGRAM}\" for i in range(MAX_NGRAM)], orient=\"row\"\n",
    "    ).to_pandas()\n",
    "    accs_table = accs_table.map(lambda x: \"\" if np.isnan(x) else f\"{x:.0%}\")\n",
    "    accs_table.index = [f\"Length {i + 1}\" for i in range(MAX_NGRAM)]\n",
    "    display(accs_table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
