import os
import torch
import tqdm
import pandas as pd

from transformer_lens import HookedTransformer

from data_utils import read_jsonl, save_jsonl, process_dataset

# This script only runs forward passes, so gradient tracking is switched off
# globally for the whole process.
torch.set_grad_enabled(False)

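# Vocabulary ids of the number tokens for each supported model, keyed by model name.
# The position of an id in its list is used as the number it stands for
# (run_inference indexes the gathered logits by that value). Lists with exactly
# 10 entries mark models handled one token per digit (see
# use_single_digit_tokenization); the longer lists presumably enumerate the
# multi-digit numbers that are a single token -- verify against each tokenizer.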
DIGIT_IDS_DICT = {
    "meta-llama/Llama-2-7b-hf": [29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929],
    "mistralai/Mistral-7B-v0.1": [28734, 28740, 28750, 28770, 28781, 28782, 28784, 28787, 28783, 28774],
    "google/gemma-2-9b": [235276, 235274, 235284, 235304, 235310, 235308, 235318, 235324, 235321, 235315],
    "Qwen/Qwen2.5-7B": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
    "meta-llama/Meta-Llama-3-8B": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 605, 806, 717, 1032, 975, 868, 845, 1114, 972, 777, 508, 1691, 1313, 1419, 1187, 914, 1627, 1544, 1591, 1682, 966, 2148, 843, 1644, 1958, 1758, 1927, 1806, 1987, 2137, 1272, 3174, 2983, 3391, 2096, 1774, 2790, 2618, 2166, 2491, 1135, 3971, 4103, 4331, 4370, 2131, 3487, 3226, 2970, 2946, 1399, 5547, 5538, 5495, 1227, 2397, 2287, 3080, 2614, 3076, 2031, 6028, 5332, 5958, 5728, 2075, 4767, 2813, 2495, 4643, 1490, 5932, 6086, 6069, 5833, 5313, 4218, 4044, 2421, 4578, 1954, 5925, 6083, 6365, 6281, 2721, 4161, 3534, 3264, 1484, 1041, 4645, 4278, 6889, 6849, 6550, 7461, 7699, 6640, 7743, 5120, 5037, 7261, 8190, 8011, 7322, 8027, 8546, 8899, 9079, 4364, 7994, 8259, 4513, 8874, 6549, 9390, 6804, 4386, 9748, 5894, 9263, 9413, 9423, 9565, 8878, 9795, 10148, 10350, 10125, 6860, 9335, 10239, 10290, 8929, 9591, 10465, 10288, 10410, 10161, 3965, 9690, 9756, 9800, 10559, 9992, 10132, 10895, 11286, 11068, 6330, 10718, 10674, 9892, 10513, 10680, 11247, 11515, 8953, 11739, 8258, 11123, 10861, 11908, 11771, 10005, 10967, 11242, 11256, 11128, 5245, 10562, 10828, 10750, 10336, 9741, 9714, 9674, 9367, 9378, 7028, 7529, 5926, 7285, 6393, 6280, 5162, 4468, 3753, 2550, 1049, 679, 2366, 9639, 7854, 10866, 11056, 12060, 12171, 12652, 8848, 11483, 11227, 11702, 11584, 12112, 12463, 13460, 13302, 13762, 8610, 12425, 9716, 12533, 10697, 11057, 14057, 14206, 14261, 14378, 9870, 12245, 12338, 12994, 11727, 12422, 14087, 14590, 13895, 14815, 8273, 13341, 12754, 14052, 13719, 13078, 14205, 14125, 14185, 14735, 5154, 13860, 12326, 14022, 12375, 3192, 4146, 15574, 15966, 15537, 11387, 15602, 14274, 15666, 12815, 14374, 15999, 16567, 16332, 16955, 10914, 15828, 15741, 15451, 16590, 14417, 16660, 16367, 16949, 17267, 11209, 15282, 16544, 16085, 17058, 15935, 17361, 17897, 15287, 17212, 13754, 17335, 16443, 17313, 17168, 16780, 17408, 18163, 17690, 15531, 3101, 12405, 13121, 13236, 12166, 13364, 12879, 14777, 14498, 15500, 12226, 
15134, 13384, 15231, 16104, 15189, 15340, 16718, 17592, 16874, 9588, 14423, 15805, 15726, 16723, 15257, 17470, 13817, 16884, 18196, 10568, 16707, 17079, 8765, 17153, 16596, 17014, 17609, 18633, 17887, 13679, 16546, 17590, 16522, 17451, 12901, 18061, 17678, 19746, 18634, 8652, 18113, 16482, 17228, 18384, 17306, 18349, 18520, 17112, 19192, 6843, 18277, 18509, 18199, 15951, 12676, 18044, 18775, 19057, 19929, 14648, 18650, 17662, 18017, 18265, 12935, 18322, 10898, 19166, 19867, 13897, 19162, 18781, 19230, 12910, 18695, 16481, 20062, 19081, 20422, 15515, 19631, 19695, 18252, 20077, 19498, 19615, 20698, 19838, 18572, 3443, 10841, 16496, 13074, 7507, 16408, 17264, 18501, 18058, 12378, 14487, 17337, 17574, 19288, 17448, 18136, 17763, 19561, 19770, 19391, 12819, 18245, 16460, 19711, 18517, 17837, 20363, 20465, 19140, 16371, 14245, 19852, 16739, 20153, 20165, 19305, 21299, 18318, 20596, 20963, 14868, 18495, 20502, 17147, 14870, 19697, 20385, 20800, 19956, 21125, 10617, 20360, 21098, 20235, 20555, 20325, 10961, 21675, 21209, 22094, 16551, 19608, 20911, 21290, 21033, 19988, 21404, 20419, 20304, 21330, 17711, 20617, 21757, 21505, 21358, 19799, 22191, 21144, 22086, 21848, 11738, 21235, 21984, 21884, 20339, 19773, 21511, 22184, 21310, 22418, 18518, 21824, 21776, 22741, 22054, 21038, 19447, 22640, 21962, 18162, 2636, 14408, 17824, 17735, 18048, 17786, 19673, 20068, 19869, 12448, 15633, 18625, 8358, 21164, 20998, 19633, 20571, 22507, 21312, 21851, 15830, 20767, 20936, 21123, 21177, 18415, 22593, 22369, 21458, 21618, 17252, 20823, 20711, 21876, 22467, 20618, 21600, 19038, 22600, 23033, 17048, 22058, 21791, 19642, 21239, 20749, 22048, 23215, 22287, 22782, 13506, 21860, 21478, 22663, 22303, 14148, 20866, 23906, 22895, 22424, 17698, 20460, 19242, 21789, 22210, 20943, 23477, 19282, 22049, 23642, 18712, 22005, 22468, 22529, 23402, 21228, 20758, 23411, 22915, 24847, 18216, 23864, 23670, 23493, 23816, 21535, 22345, 22159, 20691, 22905, 20615, 24380, 20128, 22608, 23428, 22754, 24515, 
24574, 21856, 21944, 5067, 18262, 20224, 21006, 20354, 19666, 20213, 21996, 19944, 21138, 17608, 20973, 21018, 22922, 22638, 21385, 21379, 21717, 21985, 23388, 17416, 22488, 19808, 22801, 23000, 15894, 22385, 23103, 23574, 24239, 18660, 21729, 20775, 23736, 24307, 22276, 22422, 21788, 24495, 23079, 14033, 23525, 22266, 22956, 21975, 22926, 22642, 22644, 23802, 24734, 13655, 23409, 23181, 21598, 21969, 15573, 20744, 23480, 23654, 25090, 19274, 24132, 24199, 24491, 23888, 23467, 10943, 19774, 24427, 25289, 21218, 23403, 22768, 24938, 25513, 21129, 24187, 24375, 17458, 25136, 17814, 25091, 25178, 24887, 24313, 23717, 22347, 21897, 23292, 25458, 21741, 25168, 25073, 25298, 25392, 24394, 23578, 25388, 25169, 23459, 7007, 19597, 20253, 20436, 21949, 21469, 22457, 18770, 21295, 22874, 19027, 22375, 22708, 22977, 23193, 22744, 23929, 25150, 21982, 24758, 13104, 20873, 23024, 24388, 24735, 23309, 24430, 23486, 24054, 22194, 20785, 24626, 24289, 24865, 24438, 24939, 23969, 22039, 25527, 25809, 21112, 25021, 25560, 26260, 23800, 23901, 25594, 23619, 20338, 25541, 11711, 23986, 23644, 25504, 23952, 23532, 24456, 23776, 25302, 26439, 19104, 25110, 24376, 26083, 24402, 22240, 25358, 23275, 17521, 24619, 20772, 24876, 23624, 23267, 24472, 22908, 23823, 15831, 23592, 25659, 19423, 21893, 23833, 26008, 22148, 22539, 25251, 23171, 24216, 16474, 22876, 26234, 24763, 24531, 25926, 25808, 24832, 25314, 26519, 23987, 4728, 17973, 13135, 20899, 20417, 21032, 22397, 23178, 11770, 21474, 19232, 22588, 19270, 24288, 25498, 23582, 23713, 25528, 23141, 18831, 18248, 23282, 23105, 23848, 25016, 22091, 23038, 24920, 22716, 26218, 21221, 25009, 23879, 22904, 26223, 23424, 25192, 26244, 24250, 25465, 19899, 25496, 25377, 23996, 24344, 24650, 26563, 25125, 24951, 26537, 16217, 24866, 24571, 25724, 25515, 22869, 25505, 20907, 23805, 24061, 18670, 24963, 24071, 26051, 19355, 24678, 22455, 26013, 25862, 26497, 22440, 25665, 25303, 25747, 25822, 17419, 24870, 23873, 25890, 25622, 19272, 25339, 23213, 
24902, 25962, 19445, 25399, 26058, 12251, 25354, 21381, 24962, 24110, 26088, 26227, 25238, 24542, 24777, 24809, 22889, 7467, 19319, 21026, 23305, 22777, 22393, 22224, 23505, 23629, 21278, 21056, 17000, 22750, 24331, 24579, 22387, 24487, 24391, 25828, 24337, 18485, 22536, 20275, 22614, 23890, 21910, 26026, 26437, 25001, 25344, 19306, 25717, 25401, 25806, 24347, 26970, 25612, 21936, 25454, 26164, 21251, 21322, 20249, 26576, 25687, 24599, 26491, 26511, 26979, 24680, 15862, 24989, 24597, 25326, 25741, 25875, 26067, 27341, 27079, 26328, 16415, 26114, 26366, 26087, 26281, 24837, 25285, 27134, 23386, 24792, 21133, 25693, 24425, 24471, 26007, 24609, 25208, 26409, 17272, 25476, 19068, 25643, 25873, 24742, 23812, 24961, 27468, 22207, 24538, 25350, 19146, 24606, 22992, 24242, 22897, 22101, 23031, 22694, 19416, 5500],
    "microsoft/phi-4": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 605, 806, 717, 1032, 975, 868, 845, 1114, 972, 777, 508, 1691, 1313, 1419, 1187, 914, 1627, 1544, 1591, 1682, 966, 2148, 843, 1644, 1958, 1758, 1927, 1806, 1987, 2137, 1272, 3174, 2983, 3391, 2096, 1774, 2790, 2618, 2166, 2491, 1135, 3971, 4103, 4331, 4370, 2131, 3487, 3226, 2970, 2946, 1399, 5547, 5538, 5495, 1227, 2397, 2287, 3080, 2614, 3076, 2031, 6028, 5332, 5958, 5728, 2075, 4767, 2813, 2495, 4643, 1490, 5932, 6086, 6069, 5833, 5313, 4218, 4044, 2421, 4578, 1954, 5925, 6083, 6365, 6281, 2721, 4161, 3534, 3264, 1484, 1041, 4645, 4278, 6889, 6849, 6550, 7461, 7699, 6640, 7743, 5120, 5037, 7261, 8190, 8011, 7322, 8027, 8546, 8899, 9079, 4364, 7994, 8259, 4513, 8874, 6549, 9390, 6804, 4386, 9748, 5894, 9263, 9413, 9423, 9565, 8878, 9795, 10148, 10350, 10125, 6860, 9335, 10239, 10290, 8929, 9591, 10465, 10288, 10410, 10161, 3965, 9690, 9756, 9800, 10559, 9992, 10132, 10895, 11286, 11068, 6330, 10718, 10674, 9892, 10513, 10680, 11247, 11515, 8953, 11739, 8258, 11123, 10861, 11908, 11771, 10005, 10967, 11242, 11256, 11128, 5245, 10562, 10828, 10750, 10336, 9741, 9714, 9674, 9367, 9378, 7028, 7529, 5926, 7285, 6393, 6280, 5162, 4468, 3753, 2550, 1049, 679, 2366, 9639, 7854, 10866, 11056, 12060, 12171, 12652, 8848, 11483, 11227, 11702, 11584, 12112, 12463, 13460, 13302, 13762, 8610, 12425, 9716, 12533, 10697, 11057, 14057, 14206, 14261, 14378, 9870, 12245, 12338, 12994, 11727, 12422, 14087, 14590, 13895, 14815, 8273, 13341, 12754, 14052, 13719, 13078, 14205, 14125, 14185, 14735, 5154, 13860, 12326, 14022, 12375, 3192, 4146, 15574, 15966, 15537, 11387, 15602, 14274, 15666, 12815, 14374, 15999, 16567, 16332, 16955, 10914, 15828, 15741, 15451, 16590, 14417, 16660, 16367, 16949, 17267, 11209, 15282, 16544, 16085, 17058, 15935, 17361, 17897, 15287, 17212, 13754, 17335, 16443, 17313, 17168, 16780, 17408, 18163, 17690, 15531, 3101, 12405, 13121, 13236, 12166, 13364, 12879, 14777, 14498, 15500, 12226, 15134, 
13384, 15231, 16104, 15189, 15340, 16718, 17592, 16874, 9588, 14423, 15805, 15726, 16723, 15257, 17470, 13817, 16884, 18196, 10568, 16707, 17079, 8765, 17153, 16596, 17014, 17609, 18633, 17887, 13679, 16546, 17590, 16522, 17451, 12901, 18061, 17678, 19746, 18634, 8652, 18113, 16482, 17228, 18384, 17306, 18349, 18520, 17112, 19192, 6843, 18277, 18509, 18199, 15951, 12676, 18044, 18775, 19057, 19929, 14648, 18650, 17662, 18017, 18265, 12935, 18322, 10898, 19166, 19867, 13897, 19162, 18781, 19230, 12910, 18695, 16481, 20062, 19081, 20422, 15515, 19631, 19695, 18252, 20077, 19498, 19615, 20698, 19838, 18572, 3443, 10841, 16496, 13074, 7507, 16408, 17264, 18501, 18058, 12378, 14487, 17337, 17574, 19288, 17448, 18136, 17763, 19561, 19770, 19391, 12819, 18245, 16460, 19711, 18517, 17837, 20363, 20465, 19140, 16371, 14245, 19852, 16739, 20153, 20165, 19305, 21299, 18318, 20596, 20963, 14868, 18495, 20502, 17147, 14870, 19697, 20385, 20800, 19956, 21125, 10617, 20360, 21098, 20235, 20555, 20325, 10961, 21675, 21209, 22094, 16551, 19608, 20911, 21290, 21033, 19988, 21404, 20419, 20304, 21330, 17711, 20617, 21757, 21505, 21358, 19799, 22191, 21144, 22086, 21848, 11738, 21235, 21984, 21884, 20339, 19773, 21511, 22184, 21310, 22418, 18518, 21824, 21776, 22741, 22054, 21038, 19447, 22640, 21962, 18162, 2636, 14408, 17824, 17735, 18048, 17786, 19673, 20068, 19869, 12448, 15633, 18625, 8358, 21164, 20998, 19633, 20571, 22507, 21312, 21851, 15830, 20767, 20936, 21123, 21177, 18415, 22593, 22369, 21458, 21618, 17252, 20823, 20711, 21876, 22467, 20618, 21600, 19038, 22600, 23033, 17048, 22058, 21791, 19642, 21239, 20749, 22048, 23215, 22287, 22782, 13506, 21860, 21478, 22663, 22303, 14148, 20866, 23906, 22895, 22424, 17698, 20460, 19242, 21789, 22210, 20943, 23477, 19282, 22049, 23642, 18712, 22005, 22468, 22529, 23402, 21228, 20758, 23411, 22915, 24847, 18216, 23864, 23670, 23493, 23816, 21535, 22345, 22159, 20691, 22905, 20615, 24380, 20128, 22608, 23428, 22754, 24515, 24574, 
21856, 21944, 5067, 18262, 20224, 21006, 20354, 19666, 20213, 21996, 19944, 21138, 17608, 20973, 21018, 22922, 22638, 21385, 21379, 21717, 21985, 23388, 17416, 22488, 19808, 22801, 23000, 15894, 22385, 23103, 23574, 24239, 18660, 21729, 20775, 23736, 24307, 22276, 22422, 21788, 24495, 23079, 14033, 23525, 22266, 22956, 21975, 22926, 22642, 22644, 23802, 24734, 13655, 23409, 23181, 21598, 21969, 15573, 20744, 23480, 23654, 25090, 19274, 24132, 24199, 24491, 23888, 23467, 10943, 19774, 24427, 25289, 21218, 23403, 22768, 24938, 25513, 21129, 24187, 24375, 17458, 25136, 17814, 25091, 25178, 24887, 24313, 23717, 22347, 21897, 23292, 25458, 21741, 25168, 25073, 25298, 25392, 24394, 23578, 25388, 25169, 23459, 7007, 19597, 20253, 20436, 21949, 21469, 22457, 18770, 21295, 22874, 19027, 22375, 22708, 22977, 23193, 22744, 23929, 25150, 21982, 24758, 13104, 20873, 23024, 24388, 24735, 23309, 24430, 23486, 24054, 22194, 20785, 24626, 24289, 24865, 24438, 24939, 23969, 22039, 25527, 25809, 21112, 25021, 25560, 26260, 23800, 23901, 25594, 23619, 20338, 25541, 11711, 23986, 23644, 25504, 23952, 23532, 24456, 23776, 25302, 26439, 19104, 25110, 24376, 26083, 24402, 22240, 25358, 23275, 17521, 24619, 20772, 24876, 23624, 23267, 24472, 22908, 23823, 15831, 23592, 25659, 19423, 21893, 23833, 26008, 22148, 22539, 25251, 23171, 24216, 16474, 22876, 26234, 24763, 24531, 25926, 25808, 24832, 25314, 26519, 23987, 4728, 17973, 13135, 20899, 20417, 21032, 22397, 23178, 11770, 21474, 19232, 22588, 19270, 24288, 25498, 23582, 23713, 25528, 23141, 18831, 18248, 23282, 23105, 23848, 25016, 22091, 23038, 24920, 22716, 26218, 21221, 25009, 23879, 22904, 26223, 23424, 25192, 26244, 24250, 25465, 19899, 25496, 25377, 23996, 24344, 24650, 26563, 25125, 24951, 26537, 16217, 24866, 24571, 25724, 25515, 22869, 25505, 20907, 23805, 24061, 18670, 24963, 24071, 26051, 19355, 24678, 22455, 26013, 25862, 26497, 22440, 25665, 25303, 25747, 25822, 17419, 24870, 23873, 25890, 25622, 19272, 25339, 23213, 24902, 
25962, 19445, 25399, 26058, 12251, 25354, 21381, 24962, 24110, 26088, 26227, 25238, 24542, 24777, 24809, 22889, 7467, 19319, 21026, 23305, 22777, 22393, 22224, 23505, 23629, 21278, 21056, 17000, 22750, 24331, 24579, 22387, 24487, 24391, 25828, 24337, 18485, 22536, 20275, 22614, 23890, 21910, 26026, 26437, 25001, 25344, 19306, 25717, 25401, 25806, 24347, 26970, 25612, 21936, 25454, 26164, 21251, 21322, 20249, 26576, 25687, 24599, 26491, 26511, 26979, 24680, 15862, 24989, 24597, 25326, 25741, 25875, 26067, 27341, 27079, 26328, 16415, 26114, 26366, 26087, 26281, 24837, 25285, 27134, 23386, 24792, 21133, 25693, 24425, 24471, 26007, 24609, 25208, 26409, 17272, 25476, 19068, 25643, 25873, 24742, 23812, 24961, 27468, 22207, 24538, 25350, 19146, 24606, 22992, 24242, 22897, 22101, 23031, 22694, 19416, 5500],
    "EleutherAI/pythia-6.9b-deduped": [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 740, 883, 805, 1012, 1047, 1010, 1036, 1166, 1093, 746, 938, 1797, 1423, 1508, 1348, 1099, 1731, 1630, 1619, 1717, 1229, 2405, 1237, 1610, 1706, 1671, 1812, 1787, 1839, 1867, 1449, 3156, 2945, 3079, 2031, 1857, 2950, 2504, 2385, 2537, 1235, 3712, 3583, 3357, 3439, 2417, 3208, 3011, 3680, 3046, 1549, 3832, 3763, 3571, 1540, 2082, 2526, 2251, 2358, 2090, 1967, 3677, 3547, 3655, 3566, 1976, 3121, 2357, 3141, 2787, 1438, 3593, 3507, 3245, 2759, 2227, 2691, 2597, 2055, 2511, 2270, 4739, 4529, 4590, 3953, 2222, 4196, 4148, 4185, 1525],
    "EleutherAI/pythia-12b-deduped": [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 740, 883, 805, 1012, 1047, 1010, 1036, 1166, 1093, 746, 938, 1797, 1423, 1508, 1348, 1099, 1731, 1630, 1619, 1717, 1229, 2405, 1237, 1610, 1706, 1671, 1812, 1787, 1839, 1867, 1449, 3156, 2945, 3079, 2031, 1857, 2950, 2504, 2385, 2537, 1235, 3712, 3583, 3357, 3439, 2417, 3208, 3011, 3680, 3046, 1549, 3832, 3763, 3571, 1540, 2082, 2526, 2251, 2358, 2090, 1967, 3677, 3547, 3655, 3566, 1976, 3121, 2357, 3141, 2787, 1438, 3593, 3507, 3245, 2759, 2227, 2691, 2597, 2055, 2511, 2270, 4739, 4529, 4590, 3953, 2222, 4196, 4148, 4185, 1525],
}
# Maps an nmax value to its number of decimal digits (9 -> 1, 99 -> 2, 999 -> 3).
# NOTE(review): nothing in the visible code reads this table -- confirm whether
# it is still needed.
NMAX_TO_LEN = {9: 1, 99: 2, 999: 3}

def use_single_digit_tokenization(model_name):
    """Tell whether ``model_name`` is decoded one token per digit.

    A model counts as single-digit when its entry in ``DIGIT_IDS_DICT`` holds
    exactly ten token ids. Raises ``KeyError`` if the model has no entry.
    """
    number_token_ids = DIGIT_IDS_DICT[model_name]
    return len(number_token_ids) == 10

def _score_target(model, prompt, target, digit_ids, per_digit):
    """Teacher-force ``target`` after ``prompt`` and read the number-token logits.

    Args:
        model: Object whose ``forward(input=<str>, return_type="logits")``
            returns logits indexed as ``[batch, position, vocab]``.
        prompt: Prompt text that precedes the answer.
        target: Expected answer as a string of decimal digits.
        digit_ids: Vocabulary ids of the number tokens; the id at list
            position ``i`` is treated as the number ``i``.
        per_digit: ``True`` when every character of ``target`` is scored as
            its own token, ``False`` when the whole target is scored as a
            single token.

    Returns:
        A ``(prediction, target_logits, all_logits)`` tuple: the greedy
        prediction restricted to the number tokens (as a string), the logit of
        the expected number at each decoding step, and the full list of
        number-token logits at each decoding step.
    """
    logits = model.forward(input=prompt + target, return_type="logits")
    n_steps = len(target) if per_digit else 1

    predicted_parts = []
    target_logits, all_logits = [], []
    for step in range(n_steps):
        # Logits at position p score the token at p + 1, so the answer token
        # that sits (n_steps - step) places from the end is read one position
        # earlier. Assumes the answer occupies exactly the last n_steps tokens.
        pos = step - n_steps - 1
        digit_logits = logits[:, pos, :].squeeze(0)[digit_ids]
        predicted_parts.append(str(digit_logits.argmax().item()))

        # Per-digit decoding scores one character per step. Single-token
        # decoding must look up the WHOLE number: using only its first
        # character would report the logit of e.g. "1" for the target "123".
        target_idx = int(target[step]) if per_digit else int(target)
        if 0 <= target_idx < len(digit_ids):
            target_logits.append(digit_logits[target_idx].item())
        else:
            # The target has no single-token entry in the table, so there is
            # no meaningful logit to report for it.
            target_logits.append(float("nan"))
        all_logits.append(digit_logits.tolist())

    return "".join(predicted_parts), target_logits, all_logits


def run_inference(model, datapoints, nmax, model_name):
    """Score every datapoint against its contrast and its normal answer.

    For each datapoint the model is run twice on ``dp["contrast_input"]``:
    once followed by ``dp["contrast_output"]`` and once followed by
    ``dp["normal_output"]``. Predictions are the argmax over the number tokens
    registered for ``model_name`` in ``DIGIT_IDS_DICT``.

    Each datapoint dict is updated in place with:
        ``contrast_pred``: predicted string for the contrast pass.
        ``contrast_output_logit``: per-step logit of the contrast answer.
        ``all_digit_logit``: per-step number-token logits of the contrast pass.
        ``normal_output_logit``: per-step logit of the normal answer.

    Args:
        model: Model exposing ``forward(input=<str>, return_type="logits")``.
        datapoints: Sized sequence of dicts with the keys ``contrast_input``,
            ``contrast_output`` and ``normal_output``.
        nmax: Unused; kept so existing callers keep working.
        model_name: Key into ``DIGIT_IDS_DICT``.

    Returns:
        ``(datapoints, cf_acc, orig_acc)`` where ``cf_acc`` is the fraction of
        datapoints whose prediction equals the contrast answer and
        ``orig_acc`` the fraction whose prediction equals the normal answer.
        Both are NaN when ``datapoints`` is empty.
    """
    # Both depend only on the model, so resolve them once instead of per datapoint.
    digit_ids = DIGIT_IDS_DICT[model_name]
    per_digit = use_single_digit_tokenization(model_name)

    cf_n_correct = 0
    orig_n_correct = 0

    for dp in tqdm.tqdm(datapoints):
        assert len(dp["contrast_output"]) == len(dp["normal_output"])

        # part 1: does the model decode the contrast output
        contrast_pred, contrast_output_logit, all_digit_logit = _score_target(
            model, dp["contrast_input"], dp["contrast_output"], digit_ids, per_digit
        )
        cf_n_correct += int(contrast_pred == dp["contrast_output"])
        dp["contrast_pred"] = contrast_pred
        dp["contrast_output_logit"] = contrast_output_logit
        dp["all_digit_logit"] = all_digit_logit

        # part 2: does the model decode the normal output (same prompt)
        normal_pred, normal_output_logit, _ = _score_target(
            model, dp["contrast_input"], dp["normal_output"], digit_ids, per_digit
        )
        orig_n_correct += int(normal_pred == dp["normal_output"])
        dp["normal_output_logit"] = normal_output_logit

    n_datapoints = len(datapoints)
    if n_datapoints == 0:
        # No samples: an accuracy is undefined, so report NaN rather than
        # dividing by zero or inventing a 0.0.
        return datapoints, float("nan"), float("nan")

    cf_acc = cf_n_correct / n_datapoints
    orig_acc = orig_n_correct / n_datapoints

    return datapoints, cf_acc, orig_acc

def main():
    """Sweep models, settings, nmax, offsets and ICL sizes; log accuracies.

    Every configuration appends one row to a results table, and the whole
    table is rewritten to ``results.csv`` after each row. Configurations with
    ``nmax == 9`` and ``abs(offset) > 2`` are skipped. Each model is deleted
    and the CUDA cache is emptied before the next model is loaded.
    """
    model_names = ["google/gemma-2-9b", "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b-hf"]
    # model_names = ["Qwen/Qwen2.5-7B", "meta-llama/Meta-Llama-3-8B", "microsoft/phi-4"]
    # model_names = ["EleutherAI/pythia-6.9b-deduped", "EleutherAI/pythia-12b-deduped"]
    fp16_models = ("microsoft/phi-4", "EleutherAI/pythia-12b-deduped")

    results = pd.DataFrame(
        columns=["model_name", "setting", "nmax", "offset", "n_icl_example", "cf_acc", "orig_acc"]
    )
    results_path = "results.csv"

    for model_name in model_names:
        if model_name in fp16_models:
            dtype = "fp16"
        else:
            dtype = "fp32"
        model = HookedTransformer.from_pretrained(model_name, device="cuda", dtype=dtype)

        for setting in ("normal", "setting1", "setting2"):
            for nmax in (9, 99, 999):
                for offset in range(-10, 11):
                    # The skip rule depends only on (nmax, offset), so it is
                    # checked once before iterating over the ICL sizes.
                    if nmax == 9 and abs(offset) > 2:
                        continue

                    for n_icl_examples in (2, 4, 8, 16, 32):
                        filename = f"../data/addition/{setting}/addition_nmax{nmax}_offset{offset}.jsonl"
                        raw_records = read_jsonl(filename)
                        prompts = process_dataset(raw_records, n_icl_examples=n_icl_examples, offset=offset)

                        _, cf_acc, orig_acc = run_inference(model, prompts, nmax, model_name)
                        print(f"model_name: {model_name}, setting: {setting}, nmax: {nmax}, offset: {offset}, n_icl_examples: {n_icl_examples}")
                        print(f"cf_acc: {cf_acc}, orig_acc: {orig_acc}")

                        # Append the row, then persist the whole table so a
                        # crash keeps everything finished so far.
                        results.loc[len(results)] = [
                            model_name, setting, nmax, offset, n_icl_examples, cf_acc, orig_acc,
                        ]
                        results.to_csv(results_path)

        # Drop the model and release cached GPU memory before the next load.
        del model
        torch.cuda.empty_cache()

# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()