{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff434aed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import polars as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83bceebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "nondecodable_tokens = {\n",
    "    \"olmo1b\": [0,12,15,16,18,27,32,63,74,75,103,112,117,156,158,163,175,191,193,194,209,212,214,228,229,230,233,235,246,248,249,250,251,252,253,254,263,275,279,286,287,289,295,298,300,301,304,307,311,312,313,314,315,321,322,324,328,330,334,340,344,358,360,364,365,367,375,377,379,380,385,386,387,389,391,394,396,398,401,407,412,415,419,420,433,436,450,451,456,457,458,461,466,467,482,483,485,489,490,492,505,514,516,521,524,528,530,538,542,545,548,549,550,551,552,554,555,556,557,560,564,566,567,568,570,576,577,578,579,580,582,583,585,595,596,597,600,601,605,606,608,609,612,614,615,621,622,625,626,630,634,639,640,641,644,645,649,650,652,654,655,657,658,670,672,675,680,685,687,689,690,694,695,697,712,714,716,717,719,723,724,729,730,734,736,737,740,747,748,750,766,767,769,772,773,776,777,785,786,787,789,791,793,798,799,802,803,805,811,817,820,824,826,828,830,834,836,838,841,842,843,844,845,846,847,849,851,855,856,857,859,860,866,867,869,870,875,876,878,879,880,885,886,888,889,893,894,897,900,901,908,910,912,913,914,915,919,920,922,923,925,929,931,934,936,937,940,942,944,946,949,950,951,953,955,957,958,959,961,963,964,966,968,969,970,972,974,975,976,977,978,979,980,981,985,986,988,989,991,994,995,998,999],\n",
    "    \"olmo7b\": [0,4,9,12,90,145,153,156,159,168,187,191,229,251,255,286,287,292,334,342,344,375,376,411,415,418,432,444,450,460,470,472,473,477,478,482,484,485,486,488,508,512,514,520,525,526,551,554,558,566,580,590,592,600,615,618,630,638,653,669,672,675,685,695,705,716,725,734,737,742,744,745,746,748,749,759,762,772,776,777,782,785,788,830,831,841,856,868,875,877,888,889,899,900,922,951,953,954,960,962,969,972,976,977,980,988,999],\n",
    "    \"olmo13b\": [0,2,7,13,15,16,22,27,33,42,51,57,62,90,92,98,105,120,123,126,127,128,171,179,191,197,202,209,234,239,249,257,259,260,263,265,266,267,274,275,276,277,286,287,296,299,304,313,315,320,324,325,327,328,329,330,331,334,340,341,344,347,348,351,359,360,364,367,372,375,377,380,383,385,386,389,390,392,395,411,412,415,418,419,421,429,430,431,433,435,436,441,442,445,447,451,452,455,456,457,471,473,478,485,498,500,505,520,524,525,526,529,530,539,541,545,557,559,567,574,575,580,590,593,594,605,606,613,621,626,630,635,636,641,642,645,646,649,650,655,656,661,670,674,678,680,695,698,701,706,724,725,727,729,730,734,739,748,750,754,756,758,759,762,767,786,787,788,792,799,810,811,814,816,822,830,831,834,838,842,843,844,847,851,852,856,870,874,875,887,888,889,891,892,893,896,899,900,901,907,909,916,924,925,935,936,938,944,946,948,955,957,958,960,966,967,968,969,971,977,981,982,986,987,990,993,994,995,996,999],\n",
    "    \"olmo32b\": [0,1,2,4,7,8,9,10,12,13,14,15,16,17,18,19,20,23,24,26,27,28,31,32,33,34,35,36,38,39,40,42,43,44,45,46,47,48,49,51,52,53,54,55,56,57,58,59,60,61,63,64,65,67,68,69,70,71,73,74,75,76,77,78,79,80,81,83,84,85,86,87,88,89,90,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,202,203,204,205,206,207,208,209,210,211,212,213,214,215,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,244,245,246,247,248,249,250,251,252,253,254,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,331,332,333,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,360,361,362,363,364,365,366,367,368,369,371,372,373,374,375,376,377,378,379,380,381,382,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,478,479,480,481,482,483,484,485,486,487,488,489,490,491,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,512,513,514,515,516,517,518,520,521,522,523,524,525,526,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610,611,612,613,615,616,618,620,621,622,623,624,625,626,627,628,629,630,631,632,633,634,635,636,637,638,639,640,641,642,643,644,645,646,647,648,649,650,651,652,653,654,655,656,657,658,659,660,661,662,663,664,665,666,667,668,669,670,671,673,674,675,676,677,678,679,680,681,682,683,684,685,686,687,688,689,690,691,692,693,694,695,696,698,699,700,701,703,704,705,706,707,708,709,710,711,712,713,714,715,716,717,718,720,721,722,723,724,725,726,728,729,730,731,732,733,734,735,736,737,739,740,741,742,743,744,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767,768,769,770,771,772,773,775,776,777,778,779,781,782,783,784,785,787,788,789,790,791,792,793,794,795,796,797,798,799,800,801,802,803,804,805,806,807,808,809,810,811,812,813,814,815,816,817,818,819,820,821,822,823,824,825,826,827,828,829,830,831,832,833,834,835,836,837,838,839,840,841,843,844,845,846,847,848,849,850,851,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867,868,869,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894,895,896,897,898,899,900,901,902,903,905,906,907,908,909,910,911,912,913,914,915,916,917,918,919,920,921,922,923,924,925,926,927,928,929,930,931,932,933,934,935,936,937,938,939,940,941,942,943,945,946,947,948,949,950,951,952,953,954,955,956,957,958,959,960,961,962,963,964,965,966,967,968,969,970,971,972,973,974,975,976,977,978,979,980,981,982,983,984,985,986,987,988,989,990,991,992,993,994,995,996,997,998,999],\n",
    "    \"llama1b\": [0,4,977,999],\n",
    "    \"shuffled_llama1b\": [0,4,977,999],\n",
    "    \"sgn_llama1b\": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,512,513,514,515,516,517,518,519,520,521,522,523,524,525,526,527,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610,611,612,613,614,615,616,617,618,619,620,621,622,623,624,625,626,627,628,629,630,631,632,633,634,635,636,637,638,639,640,641,642,643,644,645,646,647,648,649,650,651,652,653,654,655,656,657,658,659,660,661,662,663,664,665,666,667,668,669,670,671,672,673,674,675,676,677,678,679,680,681,682,683,684,685,686,687,688,689,690,691,692,693,694,695,696,697,698,699,700,701,702,703,704,705,706,707,708,709,710,711,712,713,714,715,716,717,718,719,720,721,722,723,724,725,726,727,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767,768,769,770,771,772,773,774,775,776,777,778,779,780,781,782,783,784,785,786,787,788,789,790,791,792,793,794,795,796,797,798,799,800,801,802,803,804,805,806,807,808,809,811,812,813,814,815,816,817,818,819,820,821,822,823,824,825,826,827,828,829,830,831,832,833,834,835,836,837,838,839,840,841,842,843,844,845,846,847,848,849,850,851,852,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867,868,869,870,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894,895,896,897,898,899,900,901,902,903,904,905,906,907,908,909,910,911,912,913,914,915,916,917,918,919,920,921,922,923,924,925,926,927,928,929,930,931,932,933,934,935,936,937,938,939,940,941,942,943,944,945,946,947,948,949,950,951,952,953,954,955,956,957,958,959,960,961,962,963,964,965,966,967,968,969,970,971,972,973,974,975,976,977,978,979,980,981,982,983,984,985,986,987,988,989,990,991,992,993,994,995,996,997,998,999],\n",
    "    \"llama3b\": [0,9,15,125,444,672,695,838,909,985,999],\n",
    "    \"llama8b\": [4,9,63,90,142,241,252,269,289,358,390,392,394,400,408,474,477,498,500,558,608,609,642,672,698,741,764,768,769,789,835,837,847,852,878,910,918,922,936,943,957,964,969,989,999],\n",
    "    \"llama70b\": [0,999],\n",
    "    \"phi\": [0,4,7,9,10,12,15,25,54,57,58,62,63,65,66,67,72,74,75,77,78,83,84,85,87,90,91,92,93,94,98,110,117,123,125,126,128,129,132,133,135,139,143,149,156,158,168,179,191,205,233,307,358,359,365,375,401,402,461,545,733,877,900,946,954,989,999],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec83035",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfs_errors = {\n",
    "    \"olmo1b\": pl.read_csv(\"error-csvs/errors_addition_olmo-1b.csv\"),\n",
    "    \"olmo7b\": pl.read_csv(\"error-csvs/errors_plus_0_allenai-OLMo-2-1124-7B.csv\"),\n",
    "    \"olmo13b\": pl.read_csv(\"error-csvs/errors_plus_0_allenai-OLMo-2-1124-13B.csv\"),\n",
    "\n",
    "    \"llama1b\": pl.read_csv(\"error-csvs/llama3.2-1b_errors.csv\"),\n",
    "    \"llama3b\": pl.read_csv(\"error-csvs/errors_plus_0_meta-llama-Llama-3.2-3B.csv\"),\n",
    "    \"llama8b\": pl.read_csv(\"error-csvs/errors_plus_0_meta-llama-Llama-3.1-8B.csv\"),\n",
    "}\n",
    "\n",
    "dfs_errors_minus = {\n",
    "    \"olmo1b\": pl.read_csv(\"error-csvs/errors_subtraction_olmo-1b.csv\"),\n",
    "    \"olmo7b\": pl.read_csv(\"error-csvs/errors_minus_0_allenai-OLMo-2-1124-7B.csv\"),\n",
    "    \"olmo13b\": pl.read_csv(\"error-csvs/errors_minus_0_allenai-OLMo-2-1124-13B.csv\"),\n",
    "\n",
    "    \"llama1b\": pl.read_csv(\"error-csvs/errors_subtraction_llama3.2-1b.csv\"),\n",
    "    \"llama3b\": pl.read_csv(\"error-csvs/errors_minus_0_meta-llama-Llama-3.2-3B.csv\"),\n",
    "    \"llama8b\": pl.read_csv(\"error-csvs/errors_minus_0_meta-llama-Llama-3.1-8B.csv\"),\n",
    "}\n",
    "\n",
    "\n",
    "def make_full_table_plus(df_errors: pl.DataFrame, nondecodables: list[int]) -> pl.DataFrame:\n",
    "    all_combinations = list(itertools.product(range(499), range(499)))\n",
    "    errs = {(row[\"x1\"], row[\"x2\"]): row[\"predicted\"] for row in df_errors.to_dicts()}\n",
    "    prediction = [errs.get((x1, x2), x1 + x2) for x1, x2 in all_combinations]\n",
    "    df_full = pl.DataFrame({\n",
    "        \"x1\": [x1 for x1, _ in all_combinations],\n",
    "        \"x2\": [x2 for _, x2 in all_combinations],\n",
    "        \"predicted\": prediction\n",
    "    })\n",
    "    df_full = df_full.with_columns(\n",
    "        (pl.col(\"x1\") + pl.col(\"x2\")).alias(\"expected\")\n",
    "    )\n",
    "    df_full = df_full.with_columns(\n",
    "        is_correct = pl.col(\"predicted\") == pl.col(\"expected\"),\n",
    "        has_nondecodable_value = pl.col(\"x1\").is_in(nondecodables) | pl.col(\"x2\").is_in(nondecodables) #| pl.col(\"expected\").is_in(nondecodables),\n",
    "    )\n",
    "    return df_full\n",
    "\n",
    "def make_full_table_minus(df_errors: pl.DataFrame, nondecodables: list[int]) -> pl.DataFrame:\n",
    "    all_combinations = [(x, y) for x in range(1, 999) for y in range(x)]\n",
    "    errs = {(row[\"x1\"], row[\"x2\"]): row[\"predicted\"] for row in df_errors.to_dicts()}\n",
    "    prediction = [errs.get((x1, x2), x1 - x2) for x1, x2 in all_combinations]\n",
    "    df_full = pl.DataFrame({\n",
    "        \"x1\": [x1 for x1, _ in all_combinations],\n",
    "        \"x2\": [x2 for _, x2 in all_combinations],\n",
    "        \"predicted\": prediction\n",
    "    })\n",
    "    df_full = df_full.with_columns(\n",
    "        (pl.col(\"x1\") - pl.col(\"x2\")).alias(\"expected\")\n",
    "    )\n",
    "    df_full = df_full.with_columns(\n",
    "        is_correct = pl.col(\"predicted\") == pl.col(\"expected\"),\n",
    "        has_nondecodable_value = pl.col(\"x1\").is_in(nondecodables) | pl.col(\"x2\").is_in(nondecodables) #| pl.col(\"expected\").is_in(nondecodables),\n",
    "    )\n",
    "    return df_full"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c463b28b",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_tables_plus = {\n",
    "    nick: make_full_table_plus(df_errors, nondecodable_tokens[nick])\n",
    "    for nick, df_errors in dfs_errors.items()\n",
    "}\n",
    "\n",
    "pred_tables_minus = {\n",
    "    nick: make_full_table_minus(df_errors, nondecodable_tokens[nick])\n",
    "    for nick, df_errors in dfs_errors_minus.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "532b2ae2",
   "metadata": {},
   "source": [
    "### Addition Error Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d61d78d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "pl.DataFrame({\n",
    "    \"model\": list(pred_tables_plus.keys()),\n",
    "    \"error_rate\": [\"{:.2%}\".format((1 - df.get_column(\"is_correct\").mean())) for df in pred_tables_plus.values()],\n",
    "}).transpose(include_header=True).to_pandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0be155d4",
   "metadata": {},
   "source": [
    "### Subtraction Error Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dbe0040",
   "metadata": {},
   "outputs": [],
   "source": [
    "pl.DataFrame({\n",
    "    \"model\": list(pred_tables_minus.keys()),\n",
    "    \"error_rate\": [\"{:.2%}\".format((1 - df.get_column(\"is_correct\").mean())) for df in pred_tables_minus.values()],\n",
    "}).transpose(include_header=True).to_pandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8b5e2bb",
   "metadata": {},
   "source": [
    "### Addition Error Rate - given either decodable or non-decodable tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "718867af",
   "metadata": {},
   "outputs": [],
   "source": [
    "pl.DataFrame({\n",
    "    \"model\": list(pred_tables_plus.keys()),\n",
    "    \"Error Rate given decodable inputs\": [\n",
    "        \"{:.2%}\".format(1 - df.filter(~pl.col(\"has_nondecodable_value\")).get_column(\"is_correct\").mean())\n",
    "        for df in pred_tables_plus.values()\n",
    "    ],\n",
    "    \"Error Rate given non-decodable inputs\": [\n",
    "        \"{:.2%}\".format(1 - df.filter(pl.col(\"has_nondecodable_value\")).get_column(\"is_correct\").mean())\n",
    "        for df in pred_tables_plus.values()\n",
    "    ],\n",
    "    \"Decodable inputs (population size)\": [\n",
    "        len(df.filter(~pl.col(\"has_nondecodable_value\")))\n",
    "        for df in pred_tables_plus.values()\n",
    "    ],\n",
    "    \"Non-decodable inputs (population size)\": [\n",
    "        len(df.filter(pl.col(\"has_nondecodable_value\")))\n",
    "        for df in pred_tables_plus.values()\n",
    "    ],\n",
    "}).transpose(include_header=True).to_pandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "875b2fc9",
   "metadata": {},
   "source": [
    "### Subtraction Error Rate - given either decodable or non-decodable tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86219c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "pl.DataFrame({\n",
    "    \"model\": list(pred_tables_minus.keys()),\n",
    "    \"Error Rate given decodable inputs\": [\n",
    "        \"{:.2%}\".format(1 - df.filter(~pl.col(\"has_nondecodable_value\")).get_column(\"is_correct\").mean())\n",
    "        for df in pred_tables_minus.values()\n",
    "    ],\n",
    "    \"Error Rate given non-decodable inputs\": [\n",
    "        \"{:.2%}\".format(1 - df.filter(pl.col(\"has_nondecodable_value\")).get_column(\"is_correct\").mean())\n",
    "        for df in pred_tables_minus.values()\n",
    "    ],\n",
    "    \"Decodable inputs (population size)\": [\n",
    "        len(df.filter(~pl.col(\"has_nondecodable_value\")))\n",
    "        for df in pred_tables_minus.values()\n",
    "    ],\n",
    "    \"Non-decodable inputs (population size)\": [\n",
    "        len(df.filter(pl.col(\"has_nondecodable_value\")))\n",
    "        for df in pred_tables_minus.values()\n",
    "    ],\n",
    "}).transpose(include_header=True).to_pandas()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
