{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:9\"\n",
    "\n",
    "model_name = \"google/gemma-2-9b\"\n",
    "# model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "# model_name = \"mistralai/Mistral-7B-v0.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DIGIT_IDS_DICT = {\n",
    "    \"meta-llama/Llama-2-7b-hf\": [29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929],\n",
    "    \"mistralai/Mistral-7B-v0.1\": [28734, 28740, 28750, 28770, 28781, 28782, 28784, 28787, 28783, 28774],\n",
    "    \"google/gemma-2-9b\": [235276, 235274, 235284, 235304, 235310, 235308, 235318, 235324, 235321, 235315],\n",
    "    \"Qwen/Qwen2.5-7B\": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24],\n",
    "    \"meta-llama/Meta-Llama-3-8B\": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 605, 806, 717, 1032, 975, 868, 845, 1114, 972, 777, 508, 1691, 1313, 1419, 1187, 914, 1627, 1544, 1591, 1682, 966, 2148, 843, 1644, 1958, 1758, 1927, 1806, 1987, 2137, 1272, 3174, 2983, 3391, 2096, 1774, 2790, 2618, 2166, 2491, 1135, 3971, 4103, 4331, 4370, 2131, 3487, 3226, 2970, 2946, 1399, 5547, 5538, 5495, 1227, 2397, 2287, 3080, 2614, 3076, 2031, 6028, 5332, 5958, 5728, 2075, 4767, 2813, 2495, 4643, 1490, 5932, 6086, 6069, 5833, 5313, 4218, 4044, 2421, 4578, 1954, 5925, 6083, 6365, 6281, 2721, 4161, 3534, 3264, 1484, 1041, 4645, 4278, 6889, 6849, 6550, 7461, 7699, 6640, 7743, 5120, 5037, 7261, 8190, 8011, 7322, 8027, 8546, 8899, 9079, 4364, 7994, 8259, 4513, 8874, 6549, 9390, 6804, 4386, 9748, 5894, 9263, 9413, 9423, 9565, 8878, 9795, 10148, 10350, 10125, 6860, 9335, 10239, 10290, 8929, 9591, 10465, 10288, 10410, 10161, 3965, 9690, 9756, 9800, 10559, 9992, 10132, 10895, 11286, 11068, 6330, 10718, 10674, 9892, 10513, 10680, 11247, 11515, 8953, 11739, 8258, 11123, 10861, 11908, 11771, 10005, 10967, 11242, 11256, 11128, 5245, 10562, 10828, 10750, 10336, 9741, 9714, 9674, 9367, 9378, 7028, 7529, 5926, 7285, 6393, 6280, 5162, 4468, 3753, 2550, 1049, 679, 2366, 9639, 7854, 10866, 11056, 12060, 12171, 12652, 8848, 11483, 11227, 11702, 11584, 12112, 12463, 13460, 13302, 13762, 8610, 12425, 9716, 12533, 10697, 11057, 14057, 14206, 14261, 14378, 9870, 12245, 12338, 12994, 11727, 12422, 14087, 14590, 13895, 14815, 8273, 13341, 12754, 14052, 13719, 13078, 14205, 14125, 14185, 14735, 5154, 13860, 12326, 14022, 12375, 3192, 4146, 15574, 15966, 15537, 11387, 15602, 14274, 15666, 12815, 14374, 15999, 16567, 16332, 16955, 10914, 15828, 15741, 15451, 16590, 14417, 16660, 16367, 16949, 17267, 11209, 15282, 16544, 16085, 17058, 15935, 17361, 17897, 15287, 17212, 13754, 17335, 16443, 17313, 17168, 16780, 17408, 18163, 17690, 15531, 3101, 12405, 13121, 13236, 12166, 13364, 12879, 14777, 14498, 15500, 12226, 15134, 13384, 15231, 16104, 15189, 15340, 16718, 17592, 16874, 9588, 14423, 15805, 15726, 16723, 15257, 17470, 13817, 16884, 18196, 10568, 16707, 17079, 8765, 17153, 16596, 17014, 17609, 18633, 17887, 13679, 16546, 17590, 16522, 17451, 12901, 18061, 17678, 19746, 18634, 8652, 18113, 16482, 17228, 18384, 17306, 18349, 18520, 17112, 19192, 6843, 18277, 18509, 18199, 15951, 12676, 18044, 18775, 19057, 19929, 14648, 18650, 17662, 18017, 18265, 12935, 18322, 10898, 19166, 19867, 13897, 19162, 18781, 19230, 12910, 18695, 16481, 20062, 19081, 20422, 15515, 19631, 19695, 18252, 20077, 19498, 19615, 20698, 19838, 18572, 3443, 10841, 16496, 13074, 7507, 16408, 17264, 18501, 18058, 12378, 14487, 17337, 17574, 19288, 17448, 18136, 17763, 19561, 19770, 19391, 12819, 18245, 16460, 19711, 18517, 17837, 20363, 20465, 19140, 16371, 14245, 19852, 16739, 20153, 20165, 19305, 21299, 18318, 20596, 20963, 14868, 18495, 20502, 17147, 14870, 19697, 20385, 20800, 19956, 21125, 10617, 20360, 21098, 20235, 20555, 20325, 10961, 21675, 21209, 22094, 16551, 19608, 20911, 21290, 21033, 19988, 21404, 20419, 20304, 21330, 17711, 20617, 21757, 21505, 21358, 19799, 22191, 21144, 22086, 21848, 11738, 21235, 21984, 21884, 20339, 19773, 21511, 22184, 21310, 22418, 18518, 21824, 21776, 22741, 22054, 21038, 19447, 22640, 21962, 18162, 2636, 14408, 17824, 17735, 18048, 17786, 19673, 20068, 19869, 12448, 15633, 18625, 8358, 21164, 20998, 19633, 20571, 22507, 21312, 21851, 15830, 20767, 20936, 21123, 21177, 18415, 22593, 22369, 21458, 21618, 17252, 20823, 20711, 21876, 22467, 20618, 21600, 19038, 22600, 23033, 17048, 22058, 21791, 19642, 21239, 20749, 22048, 23215, 22287, 22782, 13506, 21860, 21478, 22663, 22303, 14148, 20866, 23906, 22895, 22424, 17698, 20460, 19242, 21789, 22210, 20943, 23477, 19282, 22049, 23642, 18712, 22005, 22468, 22529, 23402, 21228, 20758, 23411, 22915, 24847, 18216, 23864, 23670, 23493, 23816, 21535, 22345, 22159, 20691, 22905, 20615, 24380, 20128, 22608, 23428, 22754, 24515, 24574, 21856, 21944, 5067, 18262, 20224, 21006, 20354, 19666, 20213, 21996, 19944, 21138, 17608, 20973, 21018, 22922, 22638, 21385, 21379, 21717, 21985, 23388, 17416, 22488, 19808, 22801, 23000, 15894, 22385, 23103, 23574, 24239, 18660, 21729, 20775, 23736, 24307, 22276, 22422, 21788, 24495, 23079, 14033, 23525, 22266, 22956, 21975, 22926, 22642, 22644, 23802, 24734, 13655, 23409, 23181, 21598, 21969, 15573, 20744, 23480, 23654, 25090, 19274, 24132, 24199, 24491, 23888, 23467, 10943, 19774, 24427, 25289, 21218, 23403, 22768, 24938, 25513, 21129, 24187, 24375, 17458, 25136, 17814, 25091, 25178, 24887, 24313, 23717, 22347, 21897, 23292, 25458, 21741, 25168, 25073, 25298, 25392, 24394, 23578, 25388, 25169, 23459, 7007, 19597, 20253, 20436, 21949, 21469, 22457, 18770, 21295, 22874, 19027, 22375, 22708, 22977, 23193, 22744, 23929, 25150, 21982, 24758, 13104, 20873, 23024, 24388, 24735, 23309, 24430, 23486, 24054, 22194, 20785, 24626, 24289, 24865, 24438, 24939, 23969, 22039, 25527, 25809, 21112, 25021, 25560, 26260, 23800, 23901, 25594, 23619, 20338, 25541, 11711, 23986, 23644, 25504, 23952, 23532, 24456, 23776, 25302, 26439, 19104, 25110, 24376, 26083, 24402, 22240, 25358, 23275, 17521, 24619, 20772, 24876, 23624, 23267, 24472, 22908, 23823, 15831, 23592, 25659, 19423, 21893, 23833, 26008, 22148, 22539, 25251, 23171, 24216, 16474, 22876, 26234, 24763, 24531, 25926, 25808, 24832, 25314, 26519, 23987, 4728, 17973, 13135, 20899, 20417, 21032, 22397, 23178, 11770, 21474, 19232, 22588, 19270, 24288, 25498, 23582, 23713, 25528, 23141, 18831, 18248, 23282, 23105, 23848, 25016, 22091, 23038, 24920, 22716, 26218, 21221, 25009, 23879, 22904, 26223, 23424, 25192, 26244, 24250, 25465, 19899, 25496, 25377, 23996, 24344, 24650, 26563, 25125, 24951, 26537, 16217, 24866, 24571, 25724, 25515, 22869, 25505, 20907, 23805, 24061, 18670, 24963, 24071, 26051, 19355, 24678, 22455, 26013, 25862, 26497, 22440, 25665, 25303, 25747, 25822, 17419, 24870, 23873, 25890, 25622, 19272, 25339, 23213, 24902, 25962, 19445, 25399, 26058, 12251, 25354, 21381, 24962, 24110, 26088, 26227, 25238, 24542, 24777, 24809, 22889, 7467, 19319, 21026, 23305, 22777, 22393, 22224, 23505, 23629, 21278, 21056, 17000, 22750, 24331, 24579, 22387, 24487, 24391, 25828, 24337, 18485, 22536, 20275, 22614, 23890, 21910, 26026, 26437, 25001, 25344, 19306, 25717, 25401, 25806, 24347, 26970, 25612, 21936, 25454, 26164, 21251, 21322, 20249, 26576, 25687, 24599, 26491, 26511, 26979, 24680, 15862, 24989, 24597, 25326, 25741, 25875, 26067, 27341, 27079, 26328, 16415, 26114, 26366, 26087, 26281, 24837, 25285, 27134, 23386, 24792, 21133, 25693, 24425, 24471, 26007, 24609, 25208, 26409, 17272, 25476, 19068, 25643, 25873, 24742, 23812, 24961, 27468, 22207, 24538, 25350, 19146, 24606, 22992, 24242, 22897, 22101, 23031, 22694, 19416, 5500],\n",
    "    \"microsoft/phi-4\": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 605, 806, 717, 1032, 975, 868, 845, 1114, 972, 777, 508, 1691, 1313, 1419, 1187, 914, 1627, 1544, 1591, 1682, 966, 2148, 843, 1644, 1958, 1758, 1927, 1806, 1987, 2137, 1272, 3174, 2983, 3391, 2096, 1774, 2790, 2618, 2166, 2491, 1135, 3971, 4103, 4331, 4370, 2131, 3487, 3226, 2970, 2946, 1399, 5547, 5538, 5495, 1227, 2397, 2287, 3080, 2614, 3076, 2031, 6028, 5332, 5958, 5728, 2075, 4767, 2813, 2495, 4643, 1490, 5932, 6086, 6069, 5833, 5313, 4218, 4044, 2421, 4578, 1954, 5925, 6083, 6365, 6281, 2721, 4161, 3534, 3264, 1484, 1041, 4645, 4278, 6889, 6849, 6550, 7461, 7699, 6640, 7743, 5120, 5037, 7261, 8190, 8011, 7322, 8027, 8546, 8899, 9079, 4364, 7994, 8259, 4513, 8874, 6549, 9390, 6804, 4386, 9748, 5894, 9263, 9413, 9423, 9565, 8878, 9795, 10148, 10350, 10125, 6860, 9335, 10239, 10290, 8929, 9591, 10465, 10288, 10410, 10161, 3965, 9690, 9756, 9800, 10559, 9992, 10132, 10895, 11286, 11068, 6330, 10718, 10674, 9892, 10513, 10680, 11247, 11515, 8953, 11739, 8258, 11123, 10861, 11908, 11771, 10005, 10967, 11242, 11256, 11128, 5245, 10562, 10828, 10750, 10336, 9741, 9714, 9674, 9367, 9378, 7028, 7529, 5926, 7285, 6393, 6280, 5162, 4468, 3753, 2550, 1049, 679, 2366, 9639, 7854, 10866, 11056, 12060, 12171, 12652, 8848, 11483, 11227, 11702, 11584, 12112, 12463, 13460, 13302, 13762, 8610, 12425, 9716, 12533, 10697, 11057, 14057, 14206, 14261, 14378, 9870, 12245, 12338, 12994, 11727, 12422, 14087, 14590, 13895, 14815, 8273, 13341, 12754, 14052, 13719, 13078, 14205, 14125, 14185, 14735, 5154, 13860, 12326, 14022, 12375, 3192, 4146, 15574, 15966, 15537, 11387, 15602, 14274, 15666, 12815, 14374, 15999, 16567, 16332, 16955, 10914, 15828, 15741, 15451, 16590, 14417, 16660, 16367, 16949, 17267, 11209, 15282, 16544, 16085, 17058, 15935, 17361, 17897, 15287, 17212, 13754, 17335, 16443, 17313, 17168, 16780, 17408, 18163, 17690, 15531, 3101, 12405, 13121, 13236, 12166, 13364, 12879, 14777, 14498, 15500, 12226, 15134, 13384, 15231, 16104, 15189, 15340, 16718, 17592, 16874, 9588, 14423, 15805, 15726, 16723, 15257, 17470, 13817, 16884, 18196, 10568, 16707, 17079, 8765, 17153, 16596, 17014, 17609, 18633, 17887, 13679, 16546, 17590, 16522, 17451, 12901, 18061, 17678, 19746, 18634, 8652, 18113, 16482, 17228, 18384, 17306, 18349, 18520, 17112, 19192, 6843, 18277, 18509, 18199, 15951, 12676, 18044, 18775, 19057, 19929, 14648, 18650, 17662, 18017, 18265, 12935, 18322, 10898, 19166, 19867, 13897, 19162, 18781, 19230, 12910, 18695, 16481, 20062, 19081, 20422, 15515, 19631, 19695, 18252, 20077, 19498, 19615, 20698, 19838, 18572, 3443, 10841, 16496, 13074, 7507, 16408, 17264, 18501, 18058, 12378, 14487, 17337, 17574, 19288, 17448, 18136, 17763, 19561, 19770, 19391, 12819, 18245, 16460, 19711, 18517, 17837, 20363, 20465, 19140, 16371, 14245, 19852, 16739, 20153, 20165, 19305, 21299, 18318, 20596, 20963, 14868, 18495, 20502, 17147, 14870, 19697, 20385, 20800, 19956, 21125, 10617, 20360, 21098, 20235, 20555, 20325, 10961, 21675, 21209, 22094, 16551, 19608, 20911, 21290, 21033, 19988, 21404, 20419, 20304, 21330, 17711, 20617, 21757, 21505, 21358, 19799, 22191, 21144, 22086, 21848, 11738, 21235, 21984, 21884, 20339, 19773, 21511, 22184, 21310, 22418, 18518, 21824, 21776, 22741, 22054, 21038, 19447, 22640, 21962, 18162, 2636, 14408, 17824, 17735, 18048, 17786, 19673, 20068, 19869, 12448, 15633, 18625, 8358, 21164, 20998, 19633, 20571, 22507, 21312, 21851, 15830, 20767, 20936, 21123, 21177, 18415, 22593, 22369, 21458, 21618, 17252, 20823, 20711, 21876, 22467, 20618, 21600, 19038, 22600, 23033, 17048, 22058, 21791, 19642, 21239, 20749, 22048, 23215, 22287, 22782, 13506, 21860, 21478, 22663, 22303, 14148, 20866, 23906, 22895, 22424, 17698, 20460, 19242, 21789, 22210, 20943, 23477, 19282, 22049, 23642, 18712, 22005, 22468, 22529, 23402, 21228, 20758, 23411, 22915, 24847, 18216, 23864, 23670, 23493, 23816, 21535, 22345, 22159, 20691, 22905, 20615, 24380, 20128, 22608, 23428, 22754, 24515, 24574, 21856, 21944, 5067, 18262, 20224, 21006, 20354, 19666, 20213, 21996, 19944, 21138, 17608, 20973, 21018, 22922, 22638, 21385, 21379, 21717, 21985, 23388, 17416, 22488, 19808, 22801, 23000, 15894, 22385, 23103, 23574, 24239, 18660, 21729, 20775, 23736, 24307, 22276, 22422, 21788, 24495, 23079, 14033, 23525, 22266, 22956, 21975, 22926, 22642, 22644, 23802, 24734, 13655, 23409, 23181, 21598, 21969, 15573, 20744, 23480, 23654, 25090, 19274, 24132, 24199, 24491, 23888, 23467, 10943, 19774, 24427, 25289, 21218, 23403, 22768, 24938, 25513, 21129, 24187, 24375, 17458, 25136, 17814, 25091, 25178, 24887, 24313, 23717, 22347, 21897, 23292, 25458, 21741, 25168, 25073, 25298, 25392, 24394, 23578, 25388, 25169, 23459, 7007, 19597, 20253, 20436, 21949, 21469, 22457, 18770, 21295, 22874, 19027, 22375, 22708, 22977, 23193, 22744, 23929, 25150, 21982, 24758, 13104, 20873, 23024, 24388, 24735, 23309, 24430, 23486, 24054, 22194, 20785, 24626, 24289, 24865, 24438, 24939, 23969, 22039, 25527, 25809, 21112, 25021, 25560, 26260, 23800, 23901, 25594, 23619, 20338, 25541, 11711, 23986, 23644, 25504, 23952, 23532, 24456, 23776, 25302, 26439, 19104, 25110, 24376, 26083, 24402, 22240, 25358, 23275, 17521, 24619, 20772, 24876, 23624, 23267, 24472, 22908, 23823, 15831, 23592, 25659, 19423, 21893, 23833, 26008, 22148, 22539, 25251, 23171, 24216, 16474, 22876, 26234, 24763, 24531, 25926, 25808, 24832, 25314, 26519, 23987, 4728, 17973, 13135, 20899, 20417, 21032, 22397, 23178, 11770, 21474, 19232, 22588, 19270, 24288, 25498, 23582, 23713, 25528, 23141, 18831, 18248, 23282, 23105, 23848, 25016, 22091, 23038, 24920, 22716, 26218, 21221, 25009, 23879, 22904, 26223, 23424, 25192, 26244, 24250, 25465, 19899, 25496, 25377, 23996, 24344, 24650, 26563, 25125, 24951, 26537, 16217, 24866, 24571, 25724, 25515, 22869, 25505, 20907, 23805, 24061, 18670, 24963, 24071, 26051, 19355, 24678, 22455, 26013, 25862, 26497, 22440, 25665, 25303, 25747, 25822, 17419, 24870, 23873, 25890, 25622, 19272, 25339, 23213, 24902, 25962, 19445, 25399, 26058, 12251, 25354, 21381, 24962, 24110, 26088, 26227, 25238, 24542, 24777, 24809, 22889, 7467, 19319, 21026, 23305, 22777, 22393, 22224, 23505, 23629, 21278, 21056, 17000, 22750, 24331, 24579, 22387, 24487, 24391, 25828, 24337, 18485, 22536, 20275, 22614, 23890, 21910, 26026, 26437, 25001, 25344, 19306, 25717, 25401, 25806, 24347, 26970, 25612, 21936, 25454, 26164, 21251, 21322, 20249, 26576, 25687, 24599, 26491, 26511, 26979, 24680, 15862, 24989, 24597, 25326, 25741, 25875, 26067, 27341, 27079, 26328, 16415, 26114, 26366, 26087, 26281, 24837, 25285, 27134, 23386, 24792, 21133, 25693, 24425, 24471, 26007, 24609, 25208, 26409, 17272, 25476, 19068, 25643, 25873, 24742, 23812, 24961, 27468, 22207, 24538, 25350, 19146, 24606, 22992, 24242, 22897, 22101, 23031, 22694, 19416, 5500],\n",
    "    \"EleutherAI/pythia-6.9b-deduped\": [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 740, 883, 805, 1012, 1047, 1010, 1036, 1166, 1093, 746, 938, 1797, 1423, 1508, 1348, 1099, 1731, 1630, 1619, 1717, 1229, 2405, 1237, 1610, 1706, 1671, 1812, 1787, 1839, 1867, 1449, 3156, 2945, 3079, 2031, 1857, 2950, 2504, 2385, 2537, 1235, 3712, 3583, 3357, 3439, 2417, 3208, 3011, 3680, 3046, 1549, 3832, 3763, 3571, 1540, 2082, 2526, 2251, 2358, 2090, 1967, 3677, 3547, 3655, 3566, 1976, 3121, 2357, 3141, 2787, 1438, 3593, 3507, 3245, 2759, 2227, 2691, 2597, 2055, 2511, 2270, 4739, 4529, 4590, 3953, 2222, 4196, 4148, 4185, 1525],\n",
    "    \"EleutherAI/pythia-12b-deduped\": [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 740, 883, 805, 1012, 1047, 1010, 1036, 1166, 1093, 746, 938, 1797, 1423, 1508, 1348, 1099, 1731, 1630, 1619, 1717, 1229, 2405, 1237, 1610, 1706, 1671, 1812, 1787, 1839, 1867, 1449, 3156, 2945, 3079, 2031, 1857, 2950, 2504, 2385, 2537, 1235, 3712, 3583, 3357, 3439, 2417, 3208, 3011, 3680, 3046, 1549, 3832, 3763, 3571, 1540, 2082, 2526, 2251, 2358, 2090, 1967, 3677, 3547, 3655, 3566, 1976, 3121, 2357, 3141, 2787, 1438, 3593, 3507, 3245, 2759, 2227, 2691, 2597, 2055, 2511, 2270, 4739, 4529, 4590, 3953, 2222, 4196, 4148, 4185, 1525],\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using Transformer Lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from transformer_lens import HookedTransformer\n",
    "t.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant Setting center_unembed=False instead.\n",
      "Loading checkpoint shards: 100%|██████████| 8/8 [00:00<00:00, 159.75it/s]\n",
      "WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model google/gemma-2-9b into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "model = HookedTransformer.from_pretrained(model_name, device=device)\n",
    "model.set_ungroup_grouped_query_attention(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n"
     ]
    }
   ],
   "source": [
    "logits = model(input=\"1+1=3\\n2+2=\", return_type=\"logits\")\n",
    "digits = DIGIT_IDS_DICT[model_name]\n",
    "print(logits[0, -1, digits].argmax().item()) # next token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using Huggingface Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"text-generation\", model=model_name, device=device)\n",
    "result = pipe(\"1+1=3\\n2+2=\", max_new_tokens=1, do_sample=False)\n",
    "\n",
    "print(result[0]['generated_text'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fi2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
