{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|████████████████| 19/19 [00:29<00:00,  1.53s/it]\n",
      "Loading checkpoint shards: 100%|████████████████| 19/19 [00:28<00:00,  1.52s/it]\n",
      "====CUT Config====\n",
      "model_name_or_path=mistralai/Mixtral-8x7B-Instruct-v0.1\n",
      "module_str={model_name}.model.layers[{layer_id}]\n",
      "output_dir=models/mixtral_cut_0\n",
      "retain_corpora=['wikitext', 'wikitext']\n",
      "forget_corpora=['bio-forget-corpus', 'cyber-forget-corpus']\n",
      "alpha=[1600.0, 1600.0]\n",
      "steering_coeffs=300,300\n",
      "lr=5e-05\n",
      "min_len=200\n",
      "max_len=2000\n",
      "batch_size=2\n",
      "max_num_batches=400\n",
      "layer_id=7\n",
      "layer_ids=[5, 6, 7]\n",
      "param_ids=[7]\n",
      "seed=42\n",
      "steering_coeff_list=[300.0, 300.0]\n",
      "=====\n",
      "/data/long_phan/anaconda3/lib/python3.10/site-packages/transformers/optimization.py:429: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "======= Epoch 0 =======\n",
      "  0%|                                                   | 0/400 [00:00<?, ?it/s]2024-04-16 00:20:46.213991: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-04-16 00:20:47.131226: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 512, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 26.25 | unlearn_loss: 26.25 | retain_loss: 0 | param_change: 2.459e-06\n",
      "  0%|                                           | 1/400 [00:07<50:40,  7.62s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 768, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 25 | unlearn_loss: 25 | retain_loss: 0.01208 | param_change: 1.025e-05\n",
      "  0%|▏                                          | 2/400 [00:09<28:14,  4.26s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.09082 | param_change: 0.0001221\n",
      "  1%|▎                                          | 3/400 [00:10<19:04,  2.88s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.4336 | param_change: 0.0004025\n",
      "  1%|▍                                          | 4/400 [00:12<15:03,  2.28s/it]loss: 26.5 | unlearn_loss: 26.25 | retain_loss: 0.2412 | param_change: 0.0001411\n",
      "  1%|▌                                          | 5/400 [00:13<12:43,  1.93s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2617 | param_change: 0.0002003\n",
      "  2%|▋                                          | 6/400 [00:14<11:22,  1.73s/it]loss: 26.25 | unlearn_loss: 26.25 | retain_loss: 0.01355 | param_change: 1.508e-05\n",
      "  2%|▊                                          | 7/400 [00:15<10:08,  1.55s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 671, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 25.38 | unlearn_loss: 25.38 | retain_loss: 0.0177 | param_change: 2.813e-05\n",
      "  2%|▊                                          | 8/400 [00:17<09:22,  1.44s/it]loss: 26.25 | unlearn_loss: 26.25 | retain_loss: 0.02271 | param_change: 1.049e-05\n",
      "  2%|▉                                          | 9/400 [00:18<08:54,  1.37s/it]loss: 25 | unlearn_loss: 25 | retain_loss: 0.01868 | param_change: 7.54e-06\n",
      "  2%|█                                         | 10/400 [00:19<08:44,  1.35s/it]loss: 26.25 | unlearn_loss: 26.25 | retain_loss: 0.02344 | param_change: 1.466e-05\n",
      "  3%|█▏                                        | 11/400 [00:20<08:31,  1.32s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 577, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 25.88 | unlearn_loss: 25.88 | retain_loss: 0.03271 | param_change: 1.788e-05\n",
      "  3%|█▎                                        | 12/400 [00:22<08:06,  1.25s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1729 | param_change: 5.96e-05\n",
      "  3%|█▎                                        | 13/400 [00:23<07:57,  1.23s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.08838 | param_change: 3.481e-05\n",
      "  4%|█▍                                        | 14/400 [00:24<08:02,  1.25s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1865 | param_change: 9.656e-06\n",
      "  4%|█▌                                        | 15/400 [00:25<07:55,  1.24s/it]loss: 25.75 | unlearn_loss: 25.75 | retain_loss: 0.04761 | param_change: 6.855e-06\n",
      "  4%|█▋                                        | 16/400 [00:26<07:57,  1.24s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1177 | param_change: 6.676e-06\n",
      "  4%|█▊                                        | 17/400 [00:28<08:00,  1.26s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.09863 | param_change: 7.182e-06\n",
      "  4%|█▉                                        | 18/400 [00:29<08:05,  1.27s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1494 | param_change: 1.466e-05\n",
      "  5%|█▉                                        | 19/400 [00:30<07:55,  1.25s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1377 | param_change: 1.609e-05\n",
      "  5%|██                                        | 20/400 [00:32<07:57,  1.26s/it]loss: 26.25 | unlearn_loss: 26.25 | retain_loss: 0.04443 | param_change: 9.477e-06\n",
      "  5%|██▏                                       | 21/400 [00:33<07:52,  1.25s/it]loss: 25.75 | unlearn_loss: 25.62 | retain_loss: 0.07422 | param_change: 8.464e-06\n",
      "  6%|██▎                                       | 22/400 [00:34<08:03,  1.28s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1143 | param_change: 1.836e-05\n",
      "  6%|██▍                                       | 23/400 [00:35<07:50,  1.25s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1123 | param_change: 1.067e-05\n",
      "  6%|██▌                                       | 24/400 [00:37<07:53,  1.26s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.0835 | param_change: 2.193e-05\n",
      "  6%|██▋                                       | 25/400 [00:38<07:37,  1.22s/it]loss: 25 | unlearn_loss: 25 | retain_loss: 0.0437 | param_change: 1.371e-05\n",
      "  6%|██▋                                       | 26/400 [00:39<07:32,  1.21s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1611 | param_change: 1.299e-05\n",
      "  7%|██▊                                       | 27/400 [00:40<07:18,  1.17s/it]loss: 26.12 | unlearn_loss: 26 | retain_loss: 0.08105 | param_change: 1.377e-05\n",
      "  7%|██▉                                       | 28/400 [00:41<07:22,  1.19s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.293 | param_change: 1.216e-05\n",
      "  7%|███                                       | 29/400 [00:42<07:12,  1.17s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2969 | param_change: 1.007e-05\n",
      "  8%|███▏                                      | 30/400 [00:44<07:17,  1.18s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1367 | param_change: 1.156e-05\n",
      "  8%|███▎                                      | 31/400 [00:45<07:21,  1.20s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1162 | param_change: 9.418e-06\n",
      "  8%|███▎                                      | 32/400 [00:46<07:33,  1.23s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1562 | param_change: 2.36e-05\n",
      "  8%|███▍                                      | 33/400 [00:47<07:18,  1.19s/it]loss: 25.75 | unlearn_loss: 25.62 | retain_loss: 0.07324 | param_change: 1.317e-05\n",
      "  8%|███▌                                      | 34/400 [00:48<07:16,  1.19s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1182 | param_change: 2.539e-05\n",
      "  9%|███▋                                      | 35/400 [00:49<07:07,  1.17s/it]loss: 25.25 | unlearn_loss: 25.12 | retain_loss: 0.06445 | param_change: 1.442e-05\n",
      "  9%|███▊                                      | 36/400 [00:51<07:11,  1.18s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1177 | param_change: 8.166e-06\n",
      "  9%|███▉                                      | 37/400 [00:52<07:22,  1.22s/it]loss: 25.75 | unlearn_loss: 25.62 | retain_loss: 0.08301 | param_change: 1.138e-05\n",
      " 10%|███▉                                      | 38/400 [00:53<07:31,  1.25s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.07666 | param_change: 1.27e-05\n",
      " 10%|████                                      | 39/400 [00:54<07:20,  1.22s/it]loss: 25 | unlearn_loss: 25 | retain_loss: 0.05542 | param_change: 1.293e-05\n",
      " 10%|████▏                                     | 40/400 [00:56<07:24,  1.23s/it]loss: 26.12 | unlearn_loss: 26.12 | retain_loss: 0.05103 | param_change: 1.025e-05\n",
      " 10%|████▎                                     | 41/400 [00:57<07:30,  1.25s/it]loss: 25 | unlearn_loss: 25 | retain_loss: 0.0376 | param_change: 1.478e-05\n",
      " 10%|████▍                                     | 42/400 [00:58<07:41,  1.29s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.3105 | param_change: 1.979e-05\n",
      " 11%|████▌                                     | 43/400 [01:00<07:38,  1.28s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2695 | param_change: 1.538e-05\n",
      " 11%|████▌                                     | 44/400 [01:01<07:43,  1.30s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2246 | param_change: 2.182e-05\n",
      " 11%|████▋                                     | 45/400 [01:02<07:21,  1.24s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1641 | param_change: 1.317e-05\n",
      " 12%|████▊                                     | 46/400 [01:03<07:19,  1.24s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3379 | param_change: 4.005e-05\n",
      " 12%|████▉                                     | 47/400 [01:05<07:09,  1.22s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 236, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 31.12 | unlearn_loss: 30.88 | retain_loss: 0.2812 | param_change: 3.839e-05\n",
      " 12%|█████                                     | 48/400 [01:06<06:46,  1.16s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.08252 | param_change: 3.529e-05\n",
      " 12%|█████▏                                    | 49/400 [01:07<06:54,  1.18s/it]loss: 25 | unlearn_loss: 25 | retain_loss: 0.04663 | param_change: 1.317e-05\n",
      " 12%|█████▎                                    | 50/400 [01:08<07:04,  1.21s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.1973 | param_change: 2.122e-05\n",
      " 13%|█████▎                                    | 51/400 [01:09<07:03,  1.21s/it]loss: 25.25 | unlearn_loss: 25.12 | retain_loss: 0.1357 | param_change: 1.502e-05\n",
      " 13%|█████▍                                    | 52/400 [01:11<07:11,  1.24s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.4668 | param_change: 2.11e-05\n",
      " 13%|█████▌                                    | 53/400 [01:12<07:15,  1.25s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3203 | param_change: 1.52e-05\n",
      " 14%|█████▋                                    | 54/400 [01:13<07:26,  1.29s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1445 | param_change: 2.778e-05\n",
      " 14%|█████▊                                    | 55/400 [01:14<07:18,  1.27s/it]loss: 25.25 | unlearn_loss: 25.12 | retain_loss: 0.1182 | param_change: 2.432e-05\n",
      " 14%|█████▉                                    | 56/400 [01:16<07:24,  1.29s/it]loss: 26.88 | unlearn_loss: 26.12 | retain_loss: 0.7383 | param_change: 4.363e-05\n",
      " 14%|█████▉                                    | 57/400 [01:17<07:13,  1.26s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 317, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 29.25 | unlearn_loss: 28.75 | retain_loss: 0.4707 | param_change: 1.86e-05\n",
      " 14%|██████                                    | 58/400 [01:18<06:51,  1.20s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1084 | param_change: 1.895e-05\n",
      " 15%|██████▏                                   | 59/400 [01:19<06:55,  1.22s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1055 | param_change: 2.182e-05\n",
      " 15%|██████▎                                   | 60/400 [01:21<07:27,  1.32s/it]loss: 26.38 | unlearn_loss: 26.25 | retain_loss: 0.1387 | param_change: 2.956e-05\n",
      " 15%|██████▍                                   | 61/400 [01:22<07:17,  1.29s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1465 | param_change: 3.791e-05\n",
      " 16%|██████▌                                   | 62/400 [01:23<07:19,  1.30s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2363 | param_change: 2.408e-05\n",
      " 16%|██████▌                                   | 63/400 [01:25<07:11,  1.28s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.1875 | param_change: 1.335e-05\n",
      " 16%|██████▋                                   | 64/400 [01:26<07:13,  1.29s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.4785 | param_change: 3.91e-05\n",
      " 16%|██████▊                                   | 65/400 [01:27<06:59,  1.25s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.3301 | param_change: 2.062e-05\n",
      " 16%|██████▉                                   | 66/400 [01:28<06:57,  1.25s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.457 | param_change: 1.562e-05\n",
      " 17%|███████                                   | 67/400 [01:30<07:09,  1.29s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4395 | param_change: 1.436e-05\n",
      " 17%|███████▏                                  | 68/400 [01:31<07:24,  1.34s/it]loss: 26.88 | unlearn_loss: 26.12 | retain_loss: 0.7031 | param_change: 2.193e-05\n",
      " 17%|███████▏                                  | 69/400 [01:33<07:30,  1.36s/it]loss: 26.25 | unlearn_loss: 25.75 | retain_loss: 0.5391 | param_change: 1.442e-05\n",
      " 18%|███████▎                                  | 70/400 [01:34<07:38,  1.39s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3965 | param_change: 2.73e-05\n",
      " 18%|███████▍                                  | 71/400 [01:35<07:36,  1.39s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.3027 | param_change: 2.396e-05\n",
      " 18%|███████▌                                  | 72/400 [01:37<07:30,  1.37s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.5195 | param_change: 1.21e-05\n",
      " 18%|███████▋                                  | 73/400 [01:38<07:31,  1.38s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4922 | param_change: 1.061e-05\n",
      " 18%|███████▊                                  | 74/400 [01:40<07:41,  1.42s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.5234 | param_change: 2.67e-05\n",
      " 19%|███████▉                                  | 75/400 [01:41<07:24,  1.37s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.4082 | param_change: 1.669e-05\n",
      " 19%|███████▉                                  | 76/400 [01:42<07:29,  1.39s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.3066 | param_change: 4.911e-05\n",
      " 19%|████████                                  | 77/400 [01:44<07:16,  1.35s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2031 | param_change: 3.123e-05\n",
      " 20%|████████▏                                 | 78/400 [01:45<07:14,  1.35s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2852 | param_change: 2.086e-05\n",
      " 20%|████████▎                                 | 79/400 [01:46<07:09,  1.34s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2227 | param_change: 1.645e-05\n",
      " 20%|████████▍                                 | 80/400 [01:48<07:11,  1.35s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1445 | param_change: 2.766e-05\n",
      " 20%|████████▌                                 | 81/400 [01:49<07:02,  1.33s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.09033 | param_change: 1.454e-05\n",
      " 20%|████████▌                                 | 82/400 [01:50<07:05,  1.34s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2393 | param_change: 6.151e-05\n",
      " 21%|████████▋                                 | 83/400 [01:52<06:55,  1.31s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 615, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 25.88 | unlearn_loss: 25.62 | retain_loss: 0.2246 | param_change: 6.914e-05\n",
      " 21%|████████▊                                 | 84/400 [01:53<06:50,  1.30s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1016 | param_change: 4.029e-05\n",
      " 21%|████████▉                                 | 85/400 [01:54<06:39,  1.27s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.07812 | param_change: 2.956e-05\n",
      " 22%|█████████                                 | 86/400 [01:55<06:40,  1.28s/it]loss: 28.62 | unlearn_loss: 26.12 | retain_loss: 2.469 | param_change: 9.108e-05\n",
      " 22%|█████████▏                                | 87/400 [01:57<06:47,  1.30s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 296, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 30.25 | unlearn_loss: 29.25 | retain_loss: 1.016 | param_change: 4.625e-05\n",
      " 22%|█████████▏                                | 88/400 [01:58<06:37,  1.27s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2715 | param_change: 4.458e-05\n",
      " 22%|█████████▎                                | 89/400 [01:59<06:32,  1.26s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2695 | param_change: 3.29e-05\n",
      " 22%|█████████▍                                | 90/400 [02:00<06:33,  1.27s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.3086 | param_change: 8.774e-05\n",
      " 23%|█████████▌                                | 91/400 [02:02<06:18,  1.22s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.2891 | param_change: 0.0003872\n",
      " 23%|█████████▋                                | 92/400 [02:03<06:18,  1.23s/it]loss: 26.75 | unlearn_loss: 26.12 | retain_loss: 0.6055 | param_change: 6.962e-05\n",
      " 23%|█████████▊                                | 93/400 [02:04<06:28,  1.27s/it]loss: 26.25 | unlearn_loss: 25.62 | retain_loss: 0.5703 | param_change: 7.2e-05\n",
      " 24%|█████████▊                                | 94/400 [02:06<06:37,  1.30s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.4062 | param_change: 5.651e-05\n",
      " 24%|█████████▉                                | 95/400 [02:07<06:36,  1.30s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.3281 | param_change: 4.029e-05\n",
      " 24%|██████████                                | 96/400 [02:08<06:40,  1.32s/it]loss: 26.75 | unlearn_loss: 26.12 | retain_loss: 0.6016 | param_change: 0.000186\n",
      " 24%|██████████▏                               | 97/400 [02:09<06:34,  1.30s/it]loss: 25.88 | unlearn_loss: 25.5 | retain_loss: 0.3457 | param_change: 9.108e-05\n",
      " 24%|██████████▎                               | 98/400 [02:11<06:46,  1.34s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.1895 | param_change: 1.931e-05\n",
      " 25%|██████████▍                               | 99/400 [02:12<06:29,  1.30s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1621 | param_change: 1.669e-05\n",
      " 25%|██████████▎                              | 100/400 [02:13<06:27,  1.29s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.1982 | param_change: 2.563e-05\n",
      " 25%|██████████▎                              | 101/400 [02:15<06:20,  1.27s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1553 | param_change: 1.657e-05\n",
      " 26%|██████████▍                              | 102/400 [02:16<06:25,  1.29s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2363 | param_change: 2.98e-05\n",
      " 26%|██████████▌                              | 103/400 [02:17<06:19,  1.28s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1445 | param_change: 1.943e-05\n",
      " 26%|██████████▋                              | 104/400 [02:19<06:29,  1.32s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.4883 | param_change: 2.134e-05\n",
      " 26%|██████████▊                              | 105/400 [02:20<06:15,  1.27s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4395 | param_change: 1.52e-05\n",
      " 26%|██████████▊                              | 106/400 [02:21<06:16,  1.28s/it]loss: 26.75 | unlearn_loss: 26.12 | retain_loss: 0.6367 | param_change: 5.031e-05\n",
      " 27%|██████████▉                              | 107/400 [02:22<06:02,  1.24s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4707 | param_change: 3.076e-05\n",
      " 27%|███████████                              | 108/400 [02:23<06:04,  1.25s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3887 | param_change: 2.36e-05\n",
      " 27%|███████████▏                             | 109/400 [02:25<06:00,  1.24s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.3555 | param_change: 1.8e-05\n",
      " 28%|███████████▎                             | 110/400 [02:26<06:02,  1.25s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1387 | param_change: 2.301e-05\n",
      " 28%|███████████▍                             | 111/400 [02:27<05:50,  1.21s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 565, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 26 | unlearn_loss: 25.88 | retain_loss: 0.1177 | param_change: 2.182e-05\n",
      " 28%|███████████▍                             | 112/400 [02:28<05:39,  1.18s/it]loss: 27.12 | unlearn_loss: 26.12 | retain_loss: 0.9688 | param_change: 7.391e-05\n",
      " 28%|███████████▌                             | 113/400 [02:29<05:38,  1.18s/it]loss: 25.62 | unlearn_loss: 25 | retain_loss: 0.6328 | param_change: 3.362e-05\n",
      " 28%|███████████▋                             | 114/400 [02:31<05:43,  1.20s/it]loss: 26.75 | unlearn_loss: 26.12 | retain_loss: 0.6602 | param_change: 4.387e-05\n",
      " 29%|███████████▊                             | 115/400 [02:32<05:37,  1.19s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.5078 | param_change: 2.706e-05\n",
      " 29%|███████████▉                             | 116/400 [02:33<05:45,  1.22s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3809 | param_change: 6.485e-05\n",
      " 29%|███████████▉                             | 117/400 [02:34<05:53,  1.25s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.3672 | param_change: 4.625e-05\n",
      " 30%|████████████                             | 118/400 [02:36<06:03,  1.29s/it]loss: 26.12 | unlearn_loss: 26 | retain_loss: 0.1006 | param_change: 3.123e-05\n",
      " 30%|████████████▏                            | 119/400 [02:37<06:03,  1.29s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.08643 | param_change: 2.444e-05\n",
      " 30%|████████████▎                            | 120/400 [02:38<06:11,  1.33s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3652 | param_change: 5.937e-05\n",
      " 30%|████████████▍                            | 121/400 [02:40<06:09,  1.33s/it]loss: 25.5 | unlearn_loss: 25.25 | retain_loss: 0.3086 | param_change: 4.506e-05\n",
      " 30%|████████████▌                            | 122/400 [02:41<06:13,  1.34s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1167 | param_change: 3.839e-05\n",
      " 31%|████████████▌                            | 123/400 [02:42<06:07,  1.33s/it]loss: 25.38 | unlearn_loss: 25.25 | retain_loss: 0.06982 | param_change: 2.575e-05\n",
      " 31%|████████████▋                            | 124/400 [02:44<06:12,  1.35s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2441 | param_change: 7.153e-05\n",
      " 31%|████████████▊                            | 125/400 [02:45<06:00,  1.31s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.2178 | param_change: 3.791e-05\n",
      " 32%|████████████▉                            | 126/400 [02:46<06:01,  1.32s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.1953 | param_change: 9.918e-05\n",
      " 32%|█████████████                            | 127/400 [02:48<05:45,  1.27s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.2168 | param_change: 0.0001383\n",
      " 32%|█████████████                            | 128/400 [02:49<05:40,  1.25s/it]loss: 26.12 | unlearn_loss: 26.12 | retain_loss: 0.03613 | param_change: 3.099e-05\n",
      " 32%|█████████████▏                           | 129/400 [02:50<05:38,  1.25s/it]loss: 25 | unlearn_loss: 25 | retain_loss: 0.05542 | param_change: 4.792e-05\n",
      " 32%|█████████████▎                           | 130/400 [02:51<05:41,  1.27s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1504 | param_change: 4.697e-05\n",
      " 33%|█████████████▍                           | 131/400 [02:53<05:43,  1.28s/it]loss: 25 | unlearn_loss: 24.88 | retain_loss: 0.09766 | param_change: 3.6e-05\n",
      " 33%|█████████████▌                           | 132/400 [02:54<05:49,  1.31s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2471 | param_change: 6.437e-05\n",
      " 33%|█████████████▋                           | 133/400 [02:55<05:59,  1.34s/it]loss: 25.25 | unlearn_loss: 25 | retain_loss: 0.209 | param_change: 4.888e-05\n",
      " 34%|█████████████▋                           | 134/400 [02:57<05:52,  1.33s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2109 | param_change: 7.343e-05\n",
      " 34%|█████████████▊                           | 135/400 [02:58<05:42,  1.29s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1494 | param_change: 5.579e-05\n",
      " 34%|█████████████▉                           | 136/400 [02:59<05:38,  1.28s/it]loss: 26.5 | unlearn_loss: 26 | retain_loss: 0.5 | param_change: 3.958e-05\n",
      " 34%|██████████████                           | 137/400 [03:00<05:37,  1.28s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4668 | param_change: 3.672e-05\n",
      " 34%|██████████████▏                          | 138/400 [03:02<05:35,  1.28s/it]loss: 26.75 | unlearn_loss: 26.12 | retain_loss: 0.5625 | param_change: 4.792e-05\n",
      " 35%|██████████████▏                          | 139/400 [03:03<05:22,  1.24s/it]loss: 25.88 | unlearn_loss: 25.5 | retain_loss: 0.4023 | param_change: 1.943e-05\n",
      " 35%|██████████████▎                          | 140/400 [03:04<05:23,  1.24s/it]loss: 26.12 | unlearn_loss: 26 | retain_loss: 0.1787 | param_change: 8.249e-05\n",
      " 35%|██████████████▍                          | 141/400 [03:05<05:13,  1.21s/it]loss: 25 | unlearn_loss: 24.88 | retain_loss: 0.1138 | param_change: 4.983e-05\n",
      " 36%|██████████████▌                          | 142/400 [03:07<05:14,  1.22s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3789 | param_change: 4.244e-05\n",
      " 36%|██████████████▋                          | 143/400 [03:08<05:12,  1.22s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.3242 | param_change: 4.435e-05\n",
      " 36%|██████████████▊                          | 144/400 [03:09<05:15,  1.23s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.3359 | param_change: 7.629e-05\n",
      " 36%|██████████████▊                          | 145/400 [03:10<05:09,  1.21s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.249 | param_change: 4.53e-05\n",
      " 36%|██████████████▉                          | 146/400 [03:11<05:10,  1.22s/it]loss: 26.12 | unlearn_loss: 26 | retain_loss: 0.1699 | param_change: 4.244e-05\n",
      " 37%|███████████████                          | 147/400 [03:13<05:12,  1.23s/it]loss: 25 | unlearn_loss: 24.88 | retain_loss: 0.1758 | param_change: 2.992e-05\n",
      " 37%|███████████████▏                         | 148/400 [03:14<05:13,  1.24s/it]loss: 26.5 | unlearn_loss: 26.12 | retain_loss: 0.4004 | param_change: 7.057e-05\n",
      " 37%|███████████████▎                         | 149/400 [03:15<05:20,  1.28s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.3008 | param_change: 3.91e-05\n",
      " 38%|███████████████▍                         | 150/400 [03:17<05:27,  1.31s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2393 | param_change: 8.202e-05\n",
      " 38%|███████████████▍                         | 151/400 [03:18<05:22,  1.30s/it]loss: 25.25 | unlearn_loss: 25.12 | retain_loss: 0.1797 | param_change: 5.364e-05\n",
      " 38%|███████████████▌                         | 152/400 [03:19<05:24,  1.31s/it]loss: 26.12 | unlearn_loss: 26 | retain_loss: 0.165 | param_change: 2.956e-05\n",
      " 38%|███████████████▋                         | 153/400 [03:20<05:15,  1.28s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 765, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.1641 | param_change: 3.099e-05\n",
      " 38%|███████████████▊                         | 154/400 [03:22<05:15,  1.28s/it]loss: 26.62 | unlearn_loss: 26 | retain_loss: 0.6211 | param_change: 4.339e-05\n",
      " 39%|███████████████▉                         | 155/400 [03:23<05:15,  1.29s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 668, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 25.88 | unlearn_loss: 25.38 | retain_loss: 0.5 | param_change: 3.409e-05\n",
      " 39%|███████████████▉                         | 156/400 [03:24<05:15,  1.29s/it]loss: 26.75 | unlearn_loss: 26.12 | retain_loss: 0.6875 | param_change: 6.628e-05\n",
      " 39%|████████████████                         | 157/400 [03:26<05:11,  1.28s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 408, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 27.88 | unlearn_loss: 27.38 | retain_loss: 0.498 | param_change: 4.864e-05\n",
      " 40%|████████████████▏                        | 158/400 [03:27<05:00,  1.24s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.5078 | param_change: 6.819e-05\n",
      " 40%|████████████████▎                        | 159/400 [03:28<04:58,  1.24s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4492 | param_change: 4.411e-05\n",
      " 40%|████████████████▍                        | 160/400 [03:29<05:04,  1.27s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.4004 | param_change: 6.056e-05\n",
      " 40%|████████████████▌                        | 161/400 [03:31<04:59,  1.25s/it]loss: 25.25 | unlearn_loss: 24.88 | retain_loss: 0.3867 | param_change: 6.151e-05\n",
      " 40%|████████████████▌                        | 162/400 [03:32<04:59,  1.26s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.2539 | param_change: 7.629e-05\n",
      " 41%|████████████████▋                        | 163/400 [03:33<04:57,  1.26s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.2305 | param_change: 5.198e-05\n",
      " 41%|████████████████▊                        | 164/400 [03:34<04:59,  1.27s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.5156 | param_change: 4.816e-05\n",
      " 41%|████████████████▉                        | 165/400 [03:36<04:53,  1.25s/it]loss: 25.38 | unlearn_loss: 24.88 | retain_loss: 0.4609 | param_change: 3.362e-05\n",
      " 42%|█████████████████                        | 166/400 [03:37<04:55,  1.26s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.3906 | param_change: 6.485e-05\n",
      " 42%|█████████████████                        | 167/400 [03:38<04:52,  1.25s/it]loss: 25.25 | unlearn_loss: 24.88 | retain_loss: 0.3359 | param_change: 4.315e-05\n",
      " 42%|█████████████████▏                       | 168/400 [03:39<04:54,  1.27s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.3145 | param_change: 5.078e-05\n",
      " 42%|█████████████████▎                       | 169/400 [03:41<04:46,  1.24s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.207 | param_change: 2.575e-05\n",
      " 42%|█████████████████▍                       | 170/400 [03:42<04:46,  1.25s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2256 | param_change: 5.937e-05\n",
      " 43%|█████████████████▌                       | 171/400 [03:43<04:45,  1.25s/it]loss: 25.38 | unlearn_loss: 25.25 | retain_loss: 0.09863 | param_change: 2.67e-05\n",
      " 43%|█████████████████▋                       | 172/400 [03:44<04:44,  1.25s/it]loss: 26.25 | unlearn_loss: 26.12 | retain_loss: 0.1719 | param_change: 3.171e-05\n",
      " 43%|█████████████████▋                       | 173/400 [03:46<04:42,  1.24s/it]loss: 25 | unlearn_loss: 24.88 | retain_loss: 0.1523 | param_change: 2.325e-05\n",
      " 44%|█████████████████▊                       | 174/400 [03:47<04:43,  1.25s/it]loss: 26.25 | unlearn_loss: 25.88 | retain_loss: 0.4082 | param_change: 0.000145\n",
      " 44%|█████████████████▉                       | 175/400 [03:48<04:38,  1.24s/it]/data/long_phan/wmdp/wmdp/wmdp/cut/unlearn.py:68: UserWarning: Using a target size (torch.Size([1, 1, 4096])) that is different to the input size (torch.Size([2, 425, 4096])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  unlearn_loss = torch.nn.functional.mse_loss(\n",
      "loss: 27.25 | unlearn_loss: 27 | retain_loss: 0.2988 | param_change: 7.582e-05\n",
      " 44%|██████████████████                       | 176/400 [03:49<04:27,  1.19s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.1904 | param_change: 2.944e-05\n",
      " 44%|██████████████████▏                      | 177/400 [03:50<04:25,  1.19s/it]loss: 24.88 | unlearn_loss: 24.75 | retain_loss: 0.1709 | param_change: 2.861e-05\n",
      " 44%|██████████████████▏                      | 178/400 [03:52<04:34,  1.23s/it]loss: 26.62 | unlearn_loss: 26.12 | retain_loss: 0.4727 | param_change: 5.126e-05\n",
      " 45%|██████████████████▎                      | 179/400 [03:53<04:30,  1.22s/it]loss: 25.38 | unlearn_loss: 24.88 | retain_loss: 0.4766 | param_change: 4.268e-05\n",
      " 45%|██████████████████▍                      | 180/400 [03:54<04:32,  1.24s/it]loss: 26.25 | unlearn_loss: 25.88 | retain_loss: 0.3301 | param_change: 7.534e-05\n",
      " 45%|██████████████████▌                      | 181/400 [03:55<04:32,  1.24s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.2539 | param_change: 6.151e-05\n",
      " 46%|██████████████████▋                      | 182/400 [03:57<04:34,  1.26s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.4336 | param_change: 0.0001612\n",
      " 46%|██████████████████▊                      | 183/400 [03:58<04:32,  1.26s/it]loss: 25 | unlearn_loss: 24.75 | retain_loss: 0.3066 | param_change: 9.394e-05\n",
      " 46%|██████████████████▊                      | 184/400 [03:59<04:36,  1.28s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.2891 | param_change: 4.816e-05\n",
      " 46%|██████████████████▉                      | 185/400 [04:01<04:35,  1.28s/it]loss: 25 | unlearn_loss: 24.75 | retain_loss: 0.2617 | param_change: 5.603e-05\n",
      " 46%|███████████████████                      | 186/400 [04:02<04:38,  1.30s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.3867 | param_change: 9.871e-05\n",
      " 47%|███████████████████▏                     | 187/400 [04:03<04:35,  1.29s/it]loss: 25 | unlearn_loss: 24.75 | retain_loss: 0.3008 | param_change: 7.772e-05\n",
      " 47%|███████████████████▎                     | 188/400 [04:05<04:36,  1.30s/it]loss: 26.5 | unlearn_loss: 26 | retain_loss: 0.4609 | param_change: 3.91e-05\n",
      " 47%|███████████████████▎                     | 189/400 [04:06<04:30,  1.28s/it]loss: 25.38 | unlearn_loss: 25 | retain_loss: 0.3887 | param_change: 3.099e-05\n",
      " 48%|███████████████████▍                     | 190/400 [04:07<04:31,  1.29s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.3281 | param_change: 9.537e-05\n",
      " 48%|███████████████████▌                     | 191/400 [04:08<04:23,  1.26s/it]loss: 25.38 | unlearn_loss: 25.12 | retain_loss: 0.2539 | param_change: 5.913e-05\n",
      " 48%|███████████████████▋                     | 192/400 [04:10<04:23,  1.27s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.3027 | param_change: 0.0001082\n",
      " 48%|███████████████████▊                     | 193/400 [04:11<04:17,  1.24s/it]loss: 25.12 | unlearn_loss: 25 | retain_loss: 0.166 | param_change: 5.15e-05\n",
      " 48%|███████████████████▉                     | 194/400 [04:12<04:21,  1.27s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.4297 | param_change: 0.0001745\n",
      " 49%|███████████████████▉                     | 195/400 [04:13<04:15,  1.25s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.2471 | param_change: 9.06e-05\n",
      " 49%|████████████████████                     | 196/400 [04:15<04:20,  1.27s/it]loss: 26.38 | unlearn_loss: 25.88 | retain_loss: 0.5234 | param_change: 7.439e-05\n",
      " 49%|████████████████████▏                    | 197/400 [04:16<04:11,  1.24s/it]loss: 25.25 | unlearn_loss: 24.88 | retain_loss: 0.3809 | param_change: 6.104e-05\n",
      " 50%|████████████████████▎                    | 198/400 [04:17<04:05,  1.22s/it]loss: 26.62 | unlearn_loss: 26 | retain_loss: 0.6016 | param_change: 4.387e-05\n",
      " 50%|████████████████████▍                    | 199/400 [04:18<03:59,  1.19s/it]loss: 25.25 | unlearn_loss: 24.88 | retain_loss: 0.4082 | param_change: 2.849e-05\n",
      " 50%|████████████████████▌                    | 200/400 [04:19<04:07,  1.24s/it]loss: 26.62 | unlearn_loss: 25.88 | retain_loss: 0.8008 | param_change: 4.244e-05\n",
      " 50%|████████████████████▌                    | 201/400 [04:21<04:04,  1.23s/it]loss: 26.25 | unlearn_loss: 25.62 | retain_loss: 0.625 | param_change: 2.265e-05\n",
      " 50%|████████████████████▋                    | 202/400 [04:22<04:21,  1.32s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.375 | param_change: 3.791e-05\n",
      " 51%|████████████████████▊                    | 203/400 [04:23<04:13,  1.29s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.3027 | param_change: 2.921e-05\n",
      " 51%|████████████████████▉                    | 204/400 [04:25<04:11,  1.28s/it]loss: 26.12 | unlearn_loss: 26 | retain_loss: 0.1553 | param_change: 2.623e-05\n",
      " 51%|█████████████████████                    | 205/400 [04:26<04:13,  1.30s/it]loss: 25.38 | unlearn_loss: 25.25 | retain_loss: 0.1553 | param_change: 1.967e-05\n",
      " 52%|█████████████████████                    | 206/400 [04:27<04:20,  1.34s/it]loss: 26.25 | unlearn_loss: 25.88 | retain_loss: 0.377 | param_change: 7.248e-05\n",
      " 52%|█████████████████████▏                   | 207/400 [04:29<04:06,  1.28s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.2246 | param_change: 8.392e-05\n",
      " 52%|█████████████████████▎                   | 208/400 [04:30<04:04,  1.27s/it]loss: 26.38 | unlearn_loss: 26 | retain_loss: 0.3613 | param_change: 4.292e-05\n",
      " 52%|█████████████████████▍                   | 209/400 [04:31<03:55,  1.23s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.2285 | param_change: 2.348e-05\n",
      " 52%|█████████████████████▌                   | 210/400 [04:32<03:54,  1.23s/it]loss: 26.38 | unlearn_loss: 25.88 | retain_loss: 0.4941 | param_change: 6.914e-05\n",
      " 53%|█████████████████████▋                   | 211/400 [04:33<03:54,  1.24s/it]loss: 27 | unlearn_loss: 26.62 | retain_loss: 0.3965 | param_change: 4.387e-05\n",
      " 53%|█████████████████████▋                   | 212/400 [04:35<04:05,  1.31s/it]loss: 26.25 | unlearn_loss: 25.88 | retain_loss: 0.4219 | param_change: 5.865e-05\n",
      " 53%|█████████████████████▊                   | 213/400 [04:36<03:55,  1.26s/it]loss: 26.25 | unlearn_loss: 26 | retain_loss: 0.2617 | param_change: 4.244e-05\n",
      " 54%|█████████████████████▉                   | 214/400 [04:37<03:57,  1.27s/it]loss: 26.12 | unlearn_loss: 25.88 | retain_loss: 0.2832 | param_change: 6.247e-05\n",
      " 54%|██████████████████████                   | 215/400 [04:38<03:48,  1.24s/it]loss: 26.75 | unlearn_loss: 26.5 | retain_loss: 0.2461 | param_change: 7.248e-05\n",
      " 54%|██████████████████████▏                  | 216/400 [04:40<04:01,  1.31s/it]loss: 26.5 | unlearn_loss: 25.88 | retain_loss: 0.6641 | param_change: 4.101e-05\n",
      " 54%|██████████████████████▏                  | 217/400 [04:41<03:53,  1.28s/it]loss: 26.75 | unlearn_loss: 26.38 | retain_loss: 0.4219 | param_change: 3.016e-05\n",
      " 55%|██████████████████████▎                  | 218/400 [04:43<04:12,  1.39s/it]loss: 26.38 | unlearn_loss: 26.12 | retain_loss: 0.3066 | param_change: 2.658e-05\n",
      " 55%|██████████████████████▍                  | 219/400 [04:44<04:03,  1.34s/it]loss: 26.62 | unlearn_loss: 26.25 | retain_loss: 0.3281 | param_change: 2.825e-05\n",
      " 55%|██████████████████████▌                  | 220/400 [04:45<04:02,  1.35s/it]loss: 26.12 | unlearn_loss: 25.75 | retain_loss: 0.4023 | param_change: 0.0001163\n",
      " 55%|██████████████████████▋                  | 221/400 [04:47<03:49,  1.28s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.3105 | param_change: 5.651e-05\n",
      " 56%|██████████████████████▊                  | 222/400 [04:48<03:49,  1.29s/it]loss: 26.25 | unlearn_loss: 25.62 | retain_loss: 0.5664 | param_change: 0.0001507\n",
      " 56%|██████████████████████▊                  | 223/400 [04:49<03:39,  1.24s/it]loss: 25.25 | unlearn_loss: 24.88 | retain_loss: 0.3477 | param_change: 8.869e-05\n",
      " 56%|██████████████████████▉                  | 224/400 [04:50<03:38,  1.24s/it]loss: 26.62 | unlearn_loss: 25.75 | retain_loss: 0.9102 | param_change: 6.628e-05\n",
      " 56%|███████████████████████                  | 225/400 [04:52<03:42,  1.27s/it]loss: 25.62 | unlearn_loss: 24.88 | retain_loss: 0.7344 | param_change: 7.01e-05\n",
      " 56%|███████████████████████▏                 | 226/400 [04:53<03:52,  1.33s/it]loss: 26.12 | unlearn_loss: 25.75 | retain_loss: 0.377 | param_change: 0.0001554\n",
      " 57%|███████████████████████▎                 | 227/400 [04:54<03:42,  1.29s/it]loss: 25.12 | unlearn_loss: 24.88 | retain_loss: 0.3066 | param_change: 0.0002422\n",
      " 57%|███████████████████████▎                 | 228/400 [04:55<03:40,  1.28s/it]loss: 26 | unlearn_loss: 25.75 | retain_loss: 0.2393 | param_change: 0.0001411\n",
      " 57%|███████████████████████▍                 | 229/400 [04:57<03:36,  1.27s/it]loss: 25.5 | unlearn_loss: 25.25 | retain_loss: 0.2275 | param_change: 0.000186\n",
      " 57%|███████████████████████▌                 | 230/400 [04:58<03:40,  1.29s/it]loss: 26 | unlearn_loss: 25.5 | retain_loss: 0.5547 | param_change: 0.0001822\n",
      " 58%|███████████████████████▋                 | 231/400 [04:59<03:32,  1.26s/it]loss: 25.25 | unlearn_loss: 24.75 | retain_loss: 0.4648 | param_change: 0.0002232\n",
      " 58%|███████████████████████▊                 | 232/400 [05:01<03:32,  1.26s/it]loss: 26.5 | unlearn_loss: 25.88 | retain_loss: 0.6016 | param_change: 0.0001373\n",
      " 58%|███████████████████████▉                 | 233/400 [05:02<03:27,  1.24s/it]loss: 25.12 | unlearn_loss: 24.75 | retain_loss: 0.3164 | param_change: 0.0001383\n",
      " 58%|███████████████████████▉                 | 234/400 [05:03<03:30,  1.27s/it]loss: 26.62 | unlearn_loss: 25.88 | retain_loss: 0.7812 | param_change: 0.0003548\n",
      " 59%|████████████████████████                 | 235/400 [05:04<03:24,  1.24s/it]loss: 25.38 | unlearn_loss: 24.75 | retain_loss: 0.582 | param_change: 0.0005608\n",
      " 59%|████████████████████████▏                | 236/400 [05:06<03:26,  1.26s/it]loss: 26.62 | unlearn_loss: 26 | retain_loss: 0.6211 | param_change: 0.0004196\n",
      " 59%|████████████████████████▎                | 237/400 [05:07<03:20,  1.23s/it]loss: 25.5 | unlearn_loss: 24.88 | retain_loss: 0.6719 | param_change: 0.0007591\n",
      " 60%|████████████████████████▍                | 238/400 [05:08<03:19,  1.23s/it]loss: 26.62 | unlearn_loss: 25.88 | retain_loss: 0.7891 | param_change: 0.0005341\n",
      " 60%|████████████████████████▍                | 239/400 [05:09<03:12,  1.19s/it]loss: 25.5 | unlearn_loss: 24.88 | retain_loss: 0.5898 | param_change: 0.0004025\n",
      " 60%|████████████████████████▌                | 240/400 [05:10<03:13,  1.21s/it]loss: 26.75 | unlearn_loss: 26 | retain_loss: 0.7461 | param_change: 0.0001035\n",
      " 60%|████████████████████████▋                | 241/400 [05:11<03:09,  1.19s/it]loss: 25.5 | unlearn_loss: 25 | retain_loss: 0.4395 | param_change: 4.506e-05\n",
      " 60%|████████████████████████▊                | 242/400 [05:13<03:10,  1.21s/it]loss: 26.12 | unlearn_loss: 25.62 | retain_loss: 0.5312 | param_change: 0.0002041\n",
      " 61%|████████████████████████▉                | 243/400 [05:14<03:07,  1.20s/it]loss: 25.5 | unlearn_loss: 24.88 | retain_loss: 0.5938 | param_change: 0.000186\n",
      " 61%|█████████████████████████                | 244/400 [05:15<03:08,  1.21s/it]loss: 26.25 | unlearn_loss: 25.5 | retain_loss: 0.707 | param_change: 0.0002289\n",
      " 61%|█████████████████████████                | 245/400 [05:16<03:03,  1.18s/it]loss: 25.12 | unlearn_loss: 24.62 | retain_loss: 0.5078 | param_change: 0.0001841\n",
      " 62%|█████████████████████████▏               | 246/400 [05:17<03:04,  1.20s/it]loss: 26.12 | unlearn_loss: 25.5 | retain_loss: 0.6172 | param_change: 0.0001163\n",
      " 62%|█████████████████████████▎               | 247/400 [05:19<03:03,  1.20s/it]loss: 25.25 | unlearn_loss: 24.75 | retain_loss: 0.5156 | param_change: 6.485e-05\n",
      " 62%|█████████████████████████▍               | 248/400 [05:20<03:08,  1.24s/it]loss: 25.75 | unlearn_loss: 25.38 | retain_loss: 0.3594 | param_change: 9.775e-05\n",
      " 62%|█████████████████████████▌               | 249/400 [05:21<03:06,  1.24s/it]loss: 25 | unlearn_loss: 24.75 | retain_loss: 0.2852 | param_change: 7.439e-05\n",
      " 62%|█████████████████████████▋               | 250/400 [05:23<03:08,  1.26s/it]loss: 26.5 | unlearn_loss: 25.88 | retain_loss: 0.5781 | param_change: 8.202e-05\n",
      " 63%|█████████████████████████▋               | 251/400 [05:24<03:06,  1.25s/it]loss: 25.12 | unlearn_loss: 24.75 | retain_loss: 0.3965 | param_change: 6.485e-05\n",
      " 63%|█████████████████████████▊               | 252/400 [05:25<03:08,  1.28s/it]loss: 26.25 | unlearn_loss: 25.88 | retain_loss: 0.3945 | param_change: 6.962e-05\n",
      " 63%|█████████████████████████▉               | 253/400 [05:26<03:10,  1.30s/it]loss: 25 | unlearn_loss: 24.62 | retain_loss: 0.3457 | param_change: 6.342e-05\n",
      " 64%|██████████████████████████               | 254/400 [05:28<03:18,  1.36s/it]loss: 26.25 | unlearn_loss: 25.5 | retain_loss: 0.75 | param_change: 0.0001717\n",
      " 64%|██████████████████████████▏              | 255/400 [05:29<03:06,  1.29s/it]loss: 26.38 | unlearn_loss: 25.75 | retain_loss: 0.5898 | param_change: 0.0001345\n",
      " 64%|██████████████████████████▏              | 256/400 [05:30<03:03,  1.28s/it]loss: 26.62 | unlearn_loss: 25.62 | retain_loss: 1.023 | param_change: 0.0004482\n",
      " 64%|██████████████████████████▎              | 257/400 [05:31<02:55,  1.23s/it]loss: 25.62 | unlearn_loss: 25.12 | retain_loss: 0.4785 | param_change: 0.0002384\n",
      " 64%|██████████████████████████▍              | 258/400 [05:33<02:55,  1.23s/it]loss: 28.38 | unlearn_loss: 25.25 | retain_loss: 3.094 | param_change: 0.0001717\n",
      " 65%|██████████████████████████▌              | 259/400 [05:34<02:49,  1.20s/it]loss: 25.62 | unlearn_loss: 24.38 | retain_loss: 1.297 | param_change: 0.0001154\n",
      " 65%|██████████████████████████▋              | 260/400 [05:35<02:51,  1.22s/it]loss: 26.5 | unlearn_loss: 25.75 | retain_loss: 0.8008 | param_change: 0.000124\n",
      " 65%|██████████████████████████▊              | 261/400 [05:36<02:56,  1.27s/it]loss: 25.62 | unlearn_loss: 24.88 | retain_loss: 0.7812 | param_change: 0.0001411\n",
      " 66%|██████████████████████████▊              | 262/400 [05:38<03:04,  1.33s/it]loss: 26.25 | unlearn_loss: 25 | retain_loss: 1.195 | param_change: 0.0001354\n",
      " 66%|██████████████████████████▉              | 263/400 [05:39<02:59,  1.31s/it]loss: 25 | unlearn_loss: 23.88 | retain_loss: 1.062 | param_change: 0.0001869\n",
      " 66%|███████████████████████████              | 264/400 [05:41<02:57,  1.31s/it]loss: 25.62 | unlearn_loss: 25.12 | retain_loss: 0.5508 | param_change: 0.000114\n",
      " 66%|███████████████████████████▏             | 265/400 [05:42<02:50,  1.26s/it]loss: 24.5 | unlearn_loss: 24 | retain_loss: 0.4824 | param_change: 0.0001764\n",
      " 66%|███████████████████████████▎             | 266/400 [05:43<02:51,  1.28s/it]loss: 25.12 | unlearn_loss: 24.5 | retain_loss: 0.6797 | param_change: 0.0001783\n",
      " 67%|███████████████████████████▎             | 267/400 [05:44<02:48,  1.27s/it]loss: 24.25 | unlearn_loss: 23.62 | retain_loss: 0.5898 | param_change: 0.0002823\n",
      " 67%|███████████████████████████▍             | 268/400 [05:46<02:53,  1.32s/it]loss: 25.25 | unlearn_loss: 24.62 | retain_loss: 0.5938 | param_change: 0.000144\n",
      " 67%|███████████████████████████▌             | 269/400 [05:47<02:46,  1.27s/it]loss: 23 | unlearn_loss: 22.5 | retain_loss: 0.5312 | param_change: 0.0004768\n",
      " 68%|███████████████████████████▋             | 270/400 [05:48<02:46,  1.28s/it]loss: 26.25 | unlearn_loss: 25.25 | retain_loss: 1.047 | param_change: 0.0002823\n",
      " 68%|███████████████████████████▊             | 271/400 [05:49<02:47,  1.30s/it]loss: 21.88 | unlearn_loss: 21.12 | retain_loss: 0.7031 | param_change: 0.0003109\n",
      " 68%|███████████████████████████▉             | 272/400 [05:51<02:53,  1.36s/it]loss: 25.5 | unlearn_loss: 24.75 | retain_loss: 0.7266 | param_change: 0.0001926\n",
      " 68%|███████████████████████████▉             | 273/400 [05:52<02:46,  1.31s/it]loss: 20.62 | unlearn_loss: 20 | retain_loss: 0.5664 | param_change: 0.0002184\n",
      " 68%|████████████████████████████             | 274/400 [05:54<02:56,  1.40s/it]loss: 25.62 | unlearn_loss: 24.75 | retain_loss: 0.8477 | param_change: 0.0003853\n",
      " 69%|████████████████████████████▏            | 275/400 [05:55<02:49,  1.36s/it]loss: 19.25 | unlearn_loss: 18.75 | retain_loss: 0.5625 | param_change: 0.0003052\n",
      " 69%|████████████████████████████▎            | 276/400 [05:56<02:46,  1.34s/it]loss: 25.25 | unlearn_loss: 23.75 | retain_loss: 1.516 | param_change: 0.0004425\n",
      " 69%|████████████████████████████▍            | 277/400 [05:58<02:38,  1.29s/it]loss: 17.88 | unlearn_loss: 16.88 | retain_loss: 1 | param_change: 0.000309\n",
      " 70%|████████████████████████████▍            | 278/400 [05:59<02:37,  1.29s/it]loss: 25.5 | unlearn_loss: 24.62 | retain_loss: 0.8672 | param_change: 0.0002937\n",
      " 70%|████████████████████████████▌            | 279/400 [06:00<02:34,  1.27s/it]loss: 16.5 | unlearn_loss: 15.62 | retain_loss: 0.8242 | param_change: 0.0002804\n",
      " 70%|████████████████████████████▋            | 280/400 [06:01<02:35,  1.29s/it]loss: 23.75 | unlearn_loss: 22.75 | retain_loss: 0.9961 | param_change: 0.0003529\n",
      " 70%|████████████████████████████▊            | 281/400 [06:03<02:32,  1.28s/it]loss: 15 | unlearn_loss: 14.12 | retain_loss: 0.9023 | param_change: 0.0002394\n",
      " 70%|████████████████████████████▉            | 282/400 [06:04<02:34,  1.31s/it]loss: 20.62 | unlearn_loss: 19.38 | retain_loss: 1.305 | param_change: 0.000349\n",
      " 71%|█████████████████████████████            | 283/400 [06:05<02:31,  1.30s/it]loss: 19.75 | unlearn_loss: 18.75 | retain_loss: 0.9688 | param_change: 0.0004673\n",
      " 71%|█████████████████████████████            | 284/400 [06:07<02:30,  1.30s/it]loss: 19.88 | unlearn_loss: 19 | retain_loss: 0.8281 | param_change: 0.0003147\n",
      " 71%|█████████████████████████████▏           | 285/400 [06:08<02:33,  1.34s/it]loss: 13.38 | unlearn_loss: 12.69 | retain_loss: 0.7188 | param_change: 0.0002613\n",
      " 72%|█████████████████████████████▎           | 286/400 [06:10<02:40,  1.41s/it]loss: 19.38 | unlearn_loss: 18.12 | retain_loss: 1.281 | param_change: 0.0003872\n",
      " 72%|█████████████████████████████▍           | 287/400 [06:11<02:33,  1.36s/it]loss: 12.81 | unlearn_loss: 11.75 | retain_loss: 1.07 | param_change: 0.0001659\n",
      " 72%|█████████████████████████████▌           | 288/400 [06:12<02:30,  1.35s/it]loss: 22.12 | unlearn_loss: 21.12 | retain_loss: 1 | param_change: 0.0005989\n",
      " 72%|█████████████████████████████▌           | 289/400 [06:13<02:28,  1.34s/it]loss: 13.19 | unlearn_loss: 12.25 | retain_loss: 0.9336 | param_change: 0.0001755\n",
      " 72%|█████████████████████████████▋           | 290/400 [06:15<02:30,  1.37s/it]loss: 16.62 | unlearn_loss: 15.81 | retain_loss: 0.793 | param_change: 0.0002728\n",
      " 73%|█████████████████████████████▊           | 291/400 [06:16<02:24,  1.32s/it]loss: 15.44 | unlearn_loss: 14.75 | retain_loss: 0.6953 | param_change: 0.0004883\n",
      " 73%|█████████████████████████████▉           | 292/400 [06:17<02:22,  1.32s/it]loss: 15.69 | unlearn_loss: 14.69 | retain_loss: 1.016 | param_change: 0.0003605\n",
      " 73%|██████████████████████████████           | 293/400 [06:19<02:15,  1.26s/it]loss: 10.75 | unlearn_loss: 10 | retain_loss: 0.7188 | param_change: 0.00037\n",
      " 74%|██████████████████████████████▏          | 294/400 [06:20<02:12,  1.25s/it]loss: 18.25 | unlearn_loss: 15.38 | retain_loss: 2.844 | param_change: 0.0007172\n",
      " 74%|██████████████████████████████▏          | 295/400 [06:21<02:06,  1.21s/it]loss: 10.75 | unlearn_loss: 9.062 | retain_loss: 1.695 | param_change: 0.0002041\n",
      " 74%|██████████████████████████████▎          | 296/400 [06:22<02:06,  1.21s/it]loss: 15.94 | unlearn_loss: 13.12 | retain_loss: 2.812 | param_change: 0.0007896\n",
      " 74%|██████████████████████████████▍          | 297/400 [06:23<02:02,  1.19s/it]loss: 9.688 | unlearn_loss: 8.75 | retain_loss: 0.9531 | param_change: 0.0001993\n",
      " 74%|██████████████████████████████▌          | 298/400 [06:24<02:03,  1.21s/it]loss: 13.38 | unlearn_loss: 12.75 | retain_loss: 0.6094 | param_change: 0.0002766\n",
      " 75%|██████████████████████████████▋          | 299/400 [06:26<02:00,  1.19s/it]loss: 11.31 | unlearn_loss: 10.94 | retain_loss: 0.375 | param_change: 0.0006981\n",
      " 75%|██████████████████████████████▊          | 300/400 [06:27<02:00,  1.21s/it]loss: 13.19 | unlearn_loss: 12.25 | retain_loss: 0.9141 | param_change: 0.0002613\n",
      " 75%|██████████████████████████████▊          | 301/400 [06:28<01:57,  1.18s/it]loss: 11.12 | unlearn_loss: 10.5 | retain_loss: 0.5977 | param_change: 0.000576\n",
      " 76%|██████████████████████████████▉          | 302/400 [06:29<01:57,  1.20s/it]loss: 13.12 | unlearn_loss: 12.19 | retain_loss: 0.9453 | param_change: 0.0002441\n",
      " 76%|███████████████████████████████          | 303/400 [06:30<01:56,  1.20s/it]loss: 10.56 | unlearn_loss: 9.688 | retain_loss: 0.8867 | param_change: 0.000412\n",
      " 76%|███████████████████████████████▏         | 304/400 [06:32<01:58,  1.24s/it]loss: 12.75 | unlearn_loss: 12 | retain_loss: 0.7344 | param_change: 0.0002422\n",
      " 76%|███████████████████████████████▎         | 305/400 [06:33<01:58,  1.24s/it]loss: 8.938 | unlearn_loss: 8.25 | retain_loss: 0.6953 | param_change: 0.0001178\n",
      " 76%|███████████████████████████████▎         | 306/400 [06:34<01:58,  1.26s/it]loss: 13.56 | unlearn_loss: 12.81 | retain_loss: 0.7656 | param_change: 0.0003719\n",
      " 77%|███████████████████████████████▍         | 307/400 [06:36<01:55,  1.24s/it]loss: 9.375 | unlearn_loss: 8.75 | retain_loss: 0.6328 | param_change: 0.0001717\n",
      " 77%|███████████████████████████████▌         | 308/400 [06:37<01:55,  1.25s/it]loss: 11.81 | unlearn_loss: 11.31 | retain_loss: 0.4922 | param_change: 0.0002956\n",
      " 77%|███████████████████████████████▋         | 309/400 [06:38<01:52,  1.23s/it]loss: 8.625 | unlearn_loss: 8.188 | retain_loss: 0.457 | param_change: 0.000185\n",
      " 78%|███████████████████████████████▊         | 310/400 [06:39<01:55,  1.28s/it]loss: 11.81 | unlearn_loss: 10.19 | retain_loss: 1.625 | param_change: 0.00193\n",
      " 78%|███████████████████████████████▉         | 311/400 [06:41<01:50,  1.24s/it]loss: 10.88 | unlearn_loss: 9.625 | retain_loss: 1.273 | param_change: 0.0005112\n",
      " 78%|███████████████████████████████▉         | 312/400 [06:42<01:49,  1.25s/it]loss: 11.69 | unlearn_loss: 10.69 | retain_loss: 0.9844 | param_change: 0.0003929\n",
      " 78%|████████████████████████████████         | 313/400 [06:43<01:47,  1.24s/it]loss: 8.062 | unlearn_loss: 7.281 | retain_loss: 0.793 | param_change: 0.0001841\n",
      " 78%|████████████████████████████████▏        | 314/400 [06:44<01:48,  1.26s/it]loss: 12.69 | unlearn_loss: 11.62 | retain_loss: 1.039 | param_change: 0.0005417\n",
      " 79%|████████████████████████████████▎        | 315/400 [06:45<01:42,  1.21s/it]loss: 8.25 | unlearn_loss: 7.344 | retain_loss: 0.9297 | param_change: 0.0001326\n",
      " 79%|████████████████████████████████▍        | 316/400 [06:47<01:41,  1.20s/it]loss: 13.44 | unlearn_loss: 9.625 | retain_loss: 3.828 | param_change: 0.007721\n",
      " 79%|████████████████████████████████▍        | 317/400 [06:48<01:36,  1.16s/it]loss: 9.125 | unlearn_loss: 8.438 | retain_loss: 0.7188 | param_change: 0.0001822\n",
      " 80%|████████████████████████████████▌        | 318/400 [06:49<01:45,  1.29s/it]loss: 10.62 | unlearn_loss: 9.562 | retain_loss: 1.039 | param_change: 0.0002956\n",
      " 80%|████████████████████████████████▋        | 319/400 [06:50<01:40,  1.24s/it]loss: 7.406 | unlearn_loss: 6.562 | retain_loss: 0.8359 | param_change: 0.0002289\n",
      " 80%|████████████████████████████████▊        | 320/400 [06:52<01:39,  1.24s/it]loss: 10.44 | unlearn_loss: 9.438 | retain_loss: 1 | param_change: 0.0001698\n",
      " 80%|████████████████████████████████▉        | 321/400 [06:53<01:33,  1.18s/it]loss: 7.812 | unlearn_loss: 7 | retain_loss: 0.8008 | param_change: 0.000113\n",
      " 80%|█████████████████████████████████        | 322/400 [06:54<01:35,  1.22s/it]loss: 10.25 | unlearn_loss: 9.25 | retain_loss: 1 | param_change: 0.000248\n",
      " 81%|█████████████████████████████████        | 323/400 [06:55<01:30,  1.18s/it]loss: 9.312 | unlearn_loss: 8.375 | retain_loss: 0.9414 | param_change: 0.0001354\n",
      " 81%|█████████████████████████████████▏       | 324/400 [06:57<01:35,  1.26s/it]loss: 12.19 | unlearn_loss: 10.06 | retain_loss: 2.141 | param_change: 0.0007935\n",
      " 81%|█████████████████████████████████▎       | 325/400 [06:58<01:32,  1.23s/it]loss: 8.25 | unlearn_loss: 6.781 | retain_loss: 1.453 | param_change: 0.0003548\n",
      " 82%|█████████████████████████████████▍       | 326/400 [06:59<01:33,  1.27s/it]loss: 11.5 | unlearn_loss: 9.375 | retain_loss: 2.156 | param_change: 0.0004864\n",
      " 82%|█████████████████████████████████▌       | 327/400 [07:00<01:29,  1.23s/it]loss: 8 | unlearn_loss: 6.312 | retain_loss: 1.711 | param_change: 0.000349\n",
      " 82%|█████████████████████████████████▌       | 328/400 [07:01<01:29,  1.24s/it]loss: 11.88 | unlearn_loss: 8.812 | retain_loss: 3.094 | param_change: 0.0005798\n",
      " 82%|█████████████████████████████████▋       | 329/400 [07:03<01:25,  1.20s/it]loss: 8.75 | unlearn_loss: 6.156 | retain_loss: 2.594 | param_change: 0.0003147\n",
      " 82%|█████████████████████████████████▊       | 330/400 [07:04<01:25,  1.22s/it]loss: 12.81 | unlearn_loss: 12 | retain_loss: 0.8125 | param_change: 0.001305\n",
      " 83%|█████████████████████████████████▉       | 331/400 [07:05<01:22,  1.20s/it]loss: 6.719 | unlearn_loss: 5.938 | retain_loss: 0.7734 | param_change: 9.537e-05\n",
      " 83%|██████████████████████████████████       | 332/400 [07:06<01:22,  1.21s/it]loss: 9 | unlearn_loss: 8.562 | retain_loss: 0.4141 | param_change: 0.0001054\n",
      " 83%|██████████████████████████████████▏      | 333/400 [07:08<01:23,  1.24s/it]loss: 6.875 | unlearn_loss: 6.5 | retain_loss: 0.3867 | param_change: 0.000186\n",
      " 84%|██████████████████████████████████▏      | 334/400 [07:09<01:27,  1.32s/it]loss: 9.25 | unlearn_loss: 8.312 | retain_loss: 0.9414 | param_change: 0.0002804\n",
      " 84%|██████████████████████████████████▎      | 335/400 [07:10<01:22,  1.27s/it]loss: 7.125 | unlearn_loss: 6.25 | retain_loss: 0.8672 | param_change: 0.0002918\n",
      " 84%|██████████████████████████████████▍      | 336/400 [07:11<01:21,  1.28s/it]loss: 9.25 | unlearn_loss: 8.125 | retain_loss: 1.094 | param_change: 0.0002193\n",
      " 84%|██████████████████████████████████▌      | 337/400 [07:13<01:20,  1.27s/it]loss: 6.625 | unlearn_loss: 5.719 | retain_loss: 0.8984 | param_change: 0.0001707\n",
      " 84%|██████████████████████████████████▋      | 338/400 [07:14<01:23,  1.34s/it]loss: 10.69 | unlearn_loss: 9.625 | retain_loss: 1.047 | param_change: 0.0007172\n",
      " 85%|██████████████████████████████████▋      | 339/400 [07:15<01:18,  1.29s/it]loss: 6.406 | unlearn_loss: 5.656 | retain_loss: 0.7578 | param_change: 0.0002155\n",
      " 85%|██████████████████████████████████▊      | 340/400 [07:17<01:17,  1.30s/it]loss: 9.375 | unlearn_loss: 7.906 | retain_loss: 1.5 | param_change: 0.0007668\n",
      " 85%|██████████████████████████████████▉      | 341/400 [07:18<01:16,  1.30s/it]loss: 7.469 | unlearn_loss: 6.5 | retain_loss: 0.9648 | param_change: 0.0002174\n",
      " 86%|███████████████████████████████████      | 342/400 [07:19<01:17,  1.33s/it]loss: 9 | unlearn_loss: 8.125 | retain_loss: 0.8555 | param_change: 0.0003414\n",
      " 86%|███████████████████████████████████▏     | 343/400 [07:21<01:12,  1.28s/it]loss: 6 | unlearn_loss: 5.406 | retain_loss: 0.5781 | param_change: 0.0001917\n",
      " 86%|███████████████████████████████████▎     | 344/400 [07:22<01:11,  1.27s/it]loss: 8.438 | unlearn_loss: 7.656 | retain_loss: 0.7891 | param_change: 0.0001926\n",
      " 86%|███████████████████████████████████▎     | 345/400 [07:23<01:10,  1.29s/it]loss: 8.438 | unlearn_loss: 7.781 | retain_loss: 0.6406 | param_change: 0.0001345\n",
      " 86%|███████████████████████████████████▍     | 346/400 [07:25<01:12,  1.35s/it]loss: 9.188 | unlearn_loss: 8.812 | retain_loss: 0.4004 | param_change: 0.0002842\n",
      " 87%|███████████████████████████████████▌     | 347/400 [07:26<01:10,  1.33s/it]loss: 5.469 | unlearn_loss: 5.125 | retain_loss: 0.334 | param_change: 9.537e-05\n",
      " 87%|███████████████████████████████████▋     | 348/400 [07:28<01:13,  1.42s/it]loss: 8.438 | unlearn_loss: 7.875 | retain_loss: 0.5391 | param_change: 0.0001268\n",
      " 87%|███████████████████████████████████▊     | 349/400 [07:29<01:09,  1.36s/it]loss: 5.688 | unlearn_loss: 5.188 | retain_loss: 0.5156 | param_change: 9.203e-05\n",
      " 88%|███████████████████████████████████▉     | 350/400 [07:30<01:07,  1.36s/it]loss: 10.31 | unlearn_loss: 8.625 | retain_loss: 1.672 | param_change: 0.004303\n",
      " 88%|███████████████████████████████████▉     | 351/400 [07:31<01:05,  1.33s/it]loss: 8.375 | unlearn_loss: 7.812 | retain_loss: 0.5859 | param_change: 0.0003452\n",
      " 88%|████████████████████████████████████     | 352/400 [07:33<01:05,  1.36s/it]loss: 8.812 | unlearn_loss: 8.125 | retain_loss: 0.6797 | param_change: 0.0001841\n",
      " 88%|████████████████████████████████████▏    | 353/400 [07:34<01:01,  1.31s/it]loss: 6.062 | unlearn_loss: 5.375 | retain_loss: 0.7031 | param_change: 0.0001488\n",
      " 88%|████████████████████████████████████▎    | 354/400 [07:35<01:00,  1.32s/it]loss: 8.438 | unlearn_loss: 7.781 | retain_loss: 0.6641 | param_change: 0.0002108\n",
      " 89%|████████████████████████████████████▍    | 355/400 [07:37<00:58,  1.30s/it]loss: 6.125 | unlearn_loss: 5.5 | retain_loss: 0.6289 | param_change: 0.0001144\n",
      " 89%|████████████████████████████████████▍    | 356/400 [07:38<00:58,  1.33s/it]loss: 8.125 | unlearn_loss: 7.469 | retain_loss: 0.6445 | param_change: 0.0001316\n",
      " 89%|████████████████████████████████████▌    | 357/400 [07:39<00:56,  1.30s/it]loss: 6.344 | unlearn_loss: 5.719 | retain_loss: 0.6211 | param_change: 0.0001216\n",
      " 90%|████████████████████████████████████▋    | 358/400 [07:41<00:55,  1.31s/it]loss: 8 | unlearn_loss: 7.344 | retain_loss: 0.6445 | param_change: 0.0001183\n",
      " 90%|████████████████████████████████████▊    | 359/400 [07:42<00:52,  1.28s/it]loss: 5.594 | unlearn_loss: 5 | retain_loss: 0.6055 | param_change: 0.0001035\n",
      " 90%|████████████████████████████████████▉    | 360/400 [07:43<00:52,  1.31s/it]loss: 8.312 | unlearn_loss: 7.438 | retain_loss: 0.8672 | param_change: 0.0001192\n",
      " 90%|█████████████████████████████████████    | 361/400 [07:44<00:50,  1.29s/it]loss: 6.062 | unlearn_loss: 5.25 | retain_loss: 0.8164 | param_change: 0.0001411\n",
      " 90%|█████████████████████████████████████    | 362/400 [07:46<00:49,  1.31s/it]loss: 8.312 | unlearn_loss: 7.438 | retain_loss: 0.8594 | param_change: 0.000124\n",
      " 91%|█████████████████████████████████████▏   | 363/400 [07:47<00:47,  1.29s/it]loss: 6.719 | unlearn_loss: 5.906 | retain_loss: 0.8008 | param_change: 0.0001535\n",
      " 91%|█████████████████████████████████████▎   | 364/400 [07:48<00:46,  1.29s/it]loss: 8.688 | unlearn_loss: 7.594 | retain_loss: 1.07 | param_change: 0.0003624\n",
      " 91%|█████████████████████████████████████▍   | 365/400 [07:50<00:44,  1.27s/it]loss: 6.094 | unlearn_loss: 5.156 | retain_loss: 0.9453 | param_change: 0.0001965\n",
      " 92%|█████████████████████████████████████▌   | 366/400 [07:51<00:44,  1.30s/it]loss: 9.312 | unlearn_loss: 8.5 | retain_loss: 0.8125 | param_change: 0.0002937\n",
      " 92%|█████████████████████████████████████▌   | 367/400 [07:52<00:41,  1.26s/it]loss: 5.75 | unlearn_loss: 4.969 | retain_loss: 0.7734 | param_change: 0.0001469\n",
      " 92%|█████████████████████████████████████▋   | 368/400 [07:53<00:40,  1.26s/it]loss: 7.531 | unlearn_loss: 7.094 | retain_loss: 0.4277 | param_change: 0.0001149\n",
      " 92%|█████████████████████████████████████▊   | 369/400 [07:55<00:38,  1.23s/it]loss: 5.656 | unlearn_loss: 5.219 | retain_loss: 0.4277 | param_change: 9.584e-05\n",
      " 92%|█████████████████████████████████████▉   | 370/400 [07:56<00:36,  1.23s/it]loss: 10.38 | unlearn_loss: 9.75 | retain_loss: 0.6328 | param_change: 0.000349\n",
      " 93%|██████████████████████████████████████   | 371/400 [07:57<00:35,  1.23s/it]loss: 6.75 | unlearn_loss: 6.125 | retain_loss: 0.6172 | param_change: 0.0001431\n",
      " 93%|██████████████████████████████████████▏  | 372/400 [07:58<00:35,  1.26s/it]loss: 9.062 | unlearn_loss: 8.688 | retain_loss: 0.3574 | param_change: 0.0002899\n",
      " 93%|██████████████████████████████████████▏  | 373/400 [08:00<00:33,  1.25s/it]loss: 5.312 | unlearn_loss: 4.969 | retain_loss: 0.3535 | param_change: 7.343e-05\n",
      " 94%|██████████████████████████████████████▎  | 374/400 [08:01<00:33,  1.28s/it]loss: 7.719 | unlearn_loss: 7.219 | retain_loss: 0.498 | param_change: 0.0001106\n",
      " 94%|██████████████████████████████████████▍  | 375/400 [08:02<00:31,  1.26s/it]loss: 5.938 | unlearn_loss: 5.469 | retain_loss: 0.4824 | param_change: 0.0001154\n",
      " 94%|██████████████████████████████████████▌  | 376/400 [08:03<00:30,  1.26s/it]loss: 8.375 | unlearn_loss: 7.844 | retain_loss: 0.5117 | param_change: 0.0001764\n",
      " 94%|██████████████████████████████████████▋  | 377/400 [08:05<00:28,  1.23s/it]loss: 5.031 | unlearn_loss: 4.531 | retain_loss: 0.498 | param_change: 9.775e-05\n",
      " 94%|██████████████████████████████████████▋  | 378/400 [08:06<00:27,  1.25s/it]loss: 7.688 | unlearn_loss: 7.156 | retain_loss: 0.5195 | param_change: 0.0001307\n",
      " 95%|██████████████████████████████████████▊  | 379/400 [08:07<00:26,  1.26s/it]loss: 5 | unlearn_loss: 4.5 | retain_loss: 0.4922 | param_change: 4.935e-05\n",
      " 95%|██████████████████████████████████████▉  | 380/400 [08:09<00:26,  1.31s/it]loss: 7.281 | unlearn_loss: 6.812 | retain_loss: 0.4668 | param_change: 0.0001011\n",
      " 95%|███████████████████████████████████████  | 381/400 [08:10<00:24,  1.30s/it]loss: 5.531 | unlearn_loss: 5.094 | retain_loss: 0.4395 | param_change: 8.869e-05\n",
      " 96%|███████████████████████████████████████▏ | 382/400 [08:11<00:23,  1.33s/it]loss: 7.5 | unlearn_loss: 6.906 | retain_loss: 0.5781 | param_change: 7.486e-05\n",
      " 96%|███████████████████████████████████████▎ | 383/400 [08:12<00:22,  1.30s/it]loss: 5.188 | unlearn_loss: 4.656 | retain_loss: 0.5469 | param_change: 8.345e-05\n",
      " 96%|███████████████████████████████████████▎ | 384/400 [08:14<00:21,  1.32s/it]loss: 8.188 | unlearn_loss: 7.344 | retain_loss: 0.8359 | param_change: 0.0001326\n",
      " 96%|███████████████████████████████████████▍ | 385/400 [08:15<00:18,  1.27s/it]loss: 5.531 | unlearn_loss: 4.75 | retain_loss: 0.793 | param_change: 0.0001016\n",
      " 96%|███████████████████████████████████████▌ | 386/400 [08:16<00:17,  1.26s/it]loss: 8.062 | unlearn_loss: 7.469 | retain_loss: 0.5977 | param_change: 0.0001431\n",
      " 97%|███████████████████████████████████████▋ | 387/400 [08:17<00:16,  1.27s/it]loss: 5.625 | unlearn_loss: 5.062 | retain_loss: 0.5547 | param_change: 9.394e-05\n",
      " 97%|███████████████████████████████████████▊ | 388/400 [08:19<00:15,  1.32s/it]loss: 8.5 | unlearn_loss: 7.875 | retain_loss: 0.6328 | param_change: 0.0002499\n",
      " 97%|███████████████████████████████████████▊ | 389/400 [08:20<00:14,  1.28s/it]loss: 5.688 | unlearn_loss: 5.031 | retain_loss: 0.6641 | param_change: 0.0001669\n",
      " 98%|███████████████████████████████████████▉ | 390/400 [08:21<00:12,  1.28s/it]loss: 12.44 | unlearn_loss: 11.81 | retain_loss: 0.6133 | param_change: 0.0002975\n",
      " 98%|████████████████████████████████████████ | 391/400 [08:23<00:11,  1.25s/it]loss: 5.625 | unlearn_loss: 5 | retain_loss: 0.6094 | param_change: 0.0001469\n",
      " 98%|████████████████████████████████████████▏| 392/400 [08:24<00:10,  1.25s/it]loss: 8.938 | unlearn_loss: 8.062 | retain_loss: 0.8594 | param_change: 0.0002241\n",
      " 98%|████████████████████████████████████████▎| 393/400 [08:25<00:08,  1.25s/it]loss: 5.406 | unlearn_loss: 4.625 | retain_loss: 0.7891 | param_change: 0.0001554\n",
      " 98%|████████████████████████████████████████▍| 394/400 [08:27<00:07,  1.32s/it]loss: 8.312 | unlearn_loss: 7.375 | retain_loss: 0.957 | param_change: 0.0002556\n",
      " 99%|████████████████████████████████████████▍| 395/400 [08:28<00:06,  1.25s/it]loss: 5.5 | unlearn_loss: 4.656 | retain_loss: 0.8359 | param_change: 0.0001774\n",
      " 99%|████████████████████████████████████████▌| 396/400 [08:29<00:05,  1.25s/it]loss: 8.062 | unlearn_loss: 7.188 | retain_loss: 0.8867 | param_change: 0.0001869\n",
      " 99%|████████████████████████████████████████▋| 397/400 [08:30<00:03,  1.24s/it]loss: 5.312 | unlearn_loss: 4.469 | retain_loss: 0.8438 | param_change: 0.0001507\n",
      "100%|████████████████████████████████████████▊| 398/400 [08:31<00:02,  1.25s/it]loss: 7.625 | unlearn_loss: 6.812 | retain_loss: 0.8008 | param_change: 0.0001373\n",
      "100%|████████████████████████████████████████▉| 399/400 [08:33<00:01,  1.23s/it]loss: 5.594 | unlearn_loss: 4.781 | retain_loss: 0.8008 | param_change: 0.000123\n",
      "100%|█████████████████████████████████████████| 400/400 [08:34<00:00,  1.29s/it]\n",
      "Saved model to models/mixtral_cut_0\n"
     ]
    }
   ],
   "source": [
    "# best: launch rmu.unlearn on Mixtral-8x7B-Instruct with the settings listed below.\n",
    "# NOTE(review): this cell's saved output reports models/mixtral_cut_0 as the output\n",
    "# directory, so it appears to predate the current --output_dir; re-run to refresh it.\n",
    "import os\n",
    "\n",
    "# Restrict CUDA to GPUs 0-5 for the command launched below.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(str(gpu) for gpu in range(6))\n",
    "\n",
    "# One flag per line; IPython joins the backslash-continued lines into a single shell command.\n",
    "!python3 -m rmu.unlearn \\\n",
    "    --model_name mistralai/Mixtral-8x7B-Instruct-v0.1 \\\n",
    "    --batch_size 2 \\\n",
    "    --param_ids 7 \\\n",
    "    --max_num_batches 400 \\\n",
    "    --retain_corpora wikitext,wikitext \\\n",
    "    --forget_corpora bio-forget-corpus,cyber-forget-corpus \\\n",
    "    --steering_coeffs 300,300 \\\n",
    "    --alpha 1600,1600 \\\n",
    "    --min_len 200 \\\n",
    "    --lr 5e-5 \\\n",
    "    --seed 42 \\\n",
    "    --output_dir models/mixtral_rmu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-04-16 00:31:20.981214: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-04-16 00:31:21.855576: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "2024-04-16:00:31:25,619 INFO     [__main__.py:251] Verbosity set to INFO\n",
      "2024-04-16:00:31:31,813 INFO     [__main__.py:335] Selected Tasks: ['mmlu', 'wmdp']\n",
      "2024-04-16:00:31:31,820 INFO     [evaluator.py:131] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234\n",
      "2024-04-16:00:31:31,820 INFO     [evaluator.py:177] Initializing hf model, with arguments: {'pretrained': 'models/mixtral_cut_0', 'parallelize': True}\n",
      "2024-04-16:00:31:33,169 WARNING  [logging.py:61] Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
      "Loading checkpoint shards:  32%|█████▎           | 6/19 [00:10<00:22,  1.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|████████████████| 19/19 [00:32<00:00,  1.73s/it]\n",
      "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n",
      "/data/long_phan/anaconda3/lib/python3.10/site-packages/datasets/load.py:1429: FutureWarning: The repository for hails/mmlu_no_train contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/hails/mmlu_no_train\n",
      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
      "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
      "  warnings.warn(\n",
      "2024-04-16:00:33:38,916 WARNING  [task.py:322] [Task: wmdp_bio] has_training_docs and has_validation_docs are False, using test_docs as fewshot_docs but this is not recommended.\n",
      "2024-04-16:00:33:38,916 WARNING  [task.py:322] [Task: wmdp_bio] has_training_docs and has_validation_docs are False, using test_docs as fewshot_docs but this is not recommended.\n",
      "2024-04-16:00:33:40,190 WARNING  [task.py:322] [Task: wmdp_chem] has_training_docs and has_validation_docs are False, using test_docs as fewshot_docs but this is not recommended.\n",
      "2024-04-16:00:33:40,190 WARNING  [task.py:322] [Task: wmdp_chem] has_training_docs and has_validation_docs are False, using test_docs as fewshot_docs but this is not recommended.\n",
      "2024-04-16:00:33:41,604 WARNING  [task.py:322] [Task: wmdp_cyber] has_training_docs and has_validation_docs are False, using test_docs as fewshot_docs but this is not recommended.\n",
      "2024-04-16:00:33:41,604 WARNING  [task.py:322] [Task: wmdp_cyber] has_training_docs and has_validation_docs are False, using test_docs as fewshot_docs but this is not recommended.\n",
      "2024-04-16:00:33:41,670 INFO     [task.py:395] Building contexts for wmdp_cyber on rank 0...\n",
      "100%|██████████████████████████████████████| 2225/2225 [00:02<00:00, 828.53it/s]\n",
      "2024-04-16:00:33:44,415 INFO     [task.py:395] Building contexts for wmdp_chem on rank 0...\n",
      "100%|████████████████████████████████████████| 412/412 [00:00<00:00, 832.64it/s]\n",
      "2024-04-16:00:33:44,922 INFO     [task.py:395] Building contexts for wmdp_bio on rank 0...\n",
      "100%|██████████████████████████████████████| 1243/1243 [00:01<00:00, 831.75it/s]\n",
      "2024-04-16:00:33:46,452 INFO     [task.py:395] Building contexts for mmlu_international_law on rank 0...\n",
      "100%|████████████████████████████████████████| 121/121 [00:00<00:00, 824.73it/s]\n",
      "2024-04-16:00:33:46,603 INFO     [task.py:395] Building contexts for mmlu_high_school_world_history on rank 0...\n",
      "100%|████████████████████████████████████████| 237/237 [00:00<00:00, 825.72it/s]\n",
      "2024-04-16:00:33:46,899 INFO     [task.py:395] Building contexts for mmlu_philosophy on rank 0...\n",
      "100%|████████████████████████████████████████| 311/311 [00:00<00:00, 822.08it/s]\n",
      "2024-04-16:00:33:47,289 INFO     [task.py:395] Building contexts for mmlu_logical_fallacies on rank 0...\n",
      "100%|████████████████████████████████████████| 163/163 [00:00<00:00, 823.91it/s]\n",
      "2024-04-16:00:33:47,495 INFO     [task.py:395] Building contexts for mmlu_high_school_european_history on rank 0...\n",
      "100%|████████████████████████████████████████| 165/165 [00:00<00:00, 824.13it/s]\n",
      "2024-04-16:00:33:47,702 INFO     [task.py:395] Building contexts for mmlu_moral_scenarios on rank 0...\n",
      "100%|████████████████████████████████████████| 895/895 [00:01<00:00, 656.39it/s]\n",
      "2024-04-16:00:33:49,095 INFO     [task.py:395] Building contexts for mmlu_professional_law on rank 0...\n",
      "100%|██████████████████████████████████████| 1534/1534 [00:01<00:00, 829.58it/s]\n",
      "2024-04-16:00:33:50,998 INFO     [task.py:395] Building contexts for mmlu_high_school_us_history on rank 0...\n",
      "100%|████████████████████████████████████████| 204/204 [00:00<00:00, 825.86it/s]\n",
      "2024-04-16:00:33:51,253 INFO     [task.py:395] Building contexts for mmlu_formal_logic on rank 0...\n",
      "100%|████████████████████████████████████████| 126/126 [00:00<00:00, 833.17it/s]\n",
      "2024-04-16:00:33:51,410 INFO     [task.py:395] Building contexts for mmlu_jurisprudence on rank 0...\n",
      "100%|████████████████████████████████████████| 108/108 [00:00<00:00, 828.25it/s]\n",
      "2024-04-16:00:33:51,545 INFO     [task.py:395] Building contexts for mmlu_moral_disputes on rank 0...\n",
      "100%|████████████████████████████████████████| 346/346 [00:00<00:00, 832.81it/s]\n",
      "2024-04-16:00:33:51,972 INFO     [task.py:395] Building contexts for mmlu_prehistory on rank 0...\n",
      "100%|████████████████████████████████████████| 324/324 [00:00<00:00, 834.71it/s]\n",
      "2024-04-16:00:33:52,372 INFO     [task.py:395] Building contexts for mmlu_world_religions on rank 0...\n",
      "100%|████████████████████████████████████████| 171/171 [00:00<00:00, 825.62it/s]\n",
      "2024-04-16:00:33:52,586 INFO     [task.py:395] Building contexts for mmlu_public_relations on rank 0...\n",
      "100%|████████████████████████████████████████| 110/110 [00:00<00:00, 825.95it/s]\n",
      "2024-04-16:00:33:52,724 INFO     [task.py:395] Building contexts for mmlu_high_school_macroeconomics on rank 0...\n",
      "100%|████████████████████████████████████████| 390/390 [00:00<00:00, 827.29it/s]\n",
      "2024-04-16:00:33:53,208 INFO     [task.py:395] Building contexts for mmlu_professional_psychology on rank 0...\n",
      "100%|████████████████████████████████████████| 612/612 [00:00<00:00, 823.68it/s]\n",
      "2024-04-16:00:33:53,974 INFO     [task.py:395] Building contexts for mmlu_high_school_microeconomics on rank 0...\n",
      "100%|████████████████████████████████████████| 238/238 [00:00<00:00, 827.88it/s]\n",
      "2024-04-16:00:33:54,270 INFO     [task.py:395] Building contexts for mmlu_sociology on rank 0...\n",
      "100%|████████████████████████████████████████| 201/201 [00:00<00:00, 820.75it/s]\n",
      "2024-04-16:00:33:54,523 INFO     [task.py:395] Building contexts for mmlu_high_school_geography on rank 0...\n",
      "100%|████████████████████████████████████████| 198/198 [00:00<00:00, 824.23it/s]\n",
      "2024-04-16:00:33:54,771 INFO     [task.py:395] Building contexts for mmlu_econometrics on rank 0...\n",
      "100%|████████████████████████████████████████| 114/114 [00:00<00:00, 823.19it/s]\n",
      "2024-04-16:00:33:54,914 INFO     [task.py:395] Building contexts for mmlu_us_foreign_policy on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 830.23it/s]\n",
      "2024-04-16:00:33:55,039 INFO     [task.py:395] Building contexts for mmlu_human_sexuality on rank 0...\n",
      "100%|████████████████████████████████████████| 131/131 [00:00<00:00, 823.99it/s]\n",
      "2024-04-16:00:33:55,203 INFO     [task.py:395] Building contexts for mmlu_security_studies on rank 0...\n",
      "100%|████████████████████████████████████████| 245/245 [00:00<00:00, 823.94it/s]\n",
      "2024-04-16:00:33:55,510 INFO     [task.py:395] Building contexts for mmlu_high_school_psychology on rank 0...\n",
      "100%|████████████████████████████████████████| 545/545 [00:00<00:00, 829.91it/s]\n",
      "2024-04-16:00:33:56,186 INFO     [task.py:395] Building contexts for mmlu_high_school_government_and_politics on rank 0...\n",
      "100%|████████████████████████████████████████| 193/193 [00:00<00:00, 825.00it/s]\n",
      "2024-04-16:00:33:56,427 INFO     [task.py:395] Building contexts for mmlu_human_aging on rank 0...\n",
      "100%|████████████████████████████████████████| 223/223 [00:00<00:00, 829.99it/s]\n",
      "2024-04-16:00:33:56,704 INFO     [task.py:395] Building contexts for mmlu_global_facts on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 818.95it/s]\n",
      "2024-04-16:00:33:56,830 INFO     [task.py:395] Building contexts for mmlu_medical_genetics on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 826.13it/s]\n",
      "2024-04-16:00:33:56,955 INFO     [task.py:395] Building contexts for mmlu_virology on rank 0...\n",
      "100%|████████████████████████████████████████| 166/166 [00:00<00:00, 824.32it/s]\n",
      "2024-04-16:00:33:57,163 INFO     [task.py:395] Building contexts for mmlu_professional_medicine on rank 0...\n",
      "100%|████████████████████████████████████████| 272/272 [00:00<00:00, 822.97it/s]\n",
      "2024-04-16:00:33:57,504 INFO     [task.py:395] Building contexts for mmlu_miscellaneous on rank 0...\n",
      "100%|████████████████████████████████████████| 783/783 [00:00<00:00, 830.21it/s]\n",
      "2024-04-16:00:33:58,473 INFO     [task.py:395] Building contexts for mmlu_business_ethics on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 816.75it/s]\n",
      "2024-04-16:00:33:58,602 INFO     [task.py:395] Building contexts for mmlu_professional_accounting on rank 0...\n",
      "100%|████████████████████████████████████████| 282/282 [00:00<00:00, 823.26it/s]\n",
      "2024-04-16:00:33:58,954 INFO     [task.py:395] Building contexts for mmlu_nutrition on rank 0...\n",
      "100%|████████████████████████████████████████| 306/306 [00:00<00:00, 827.44it/s]\n",
      "2024-04-16:00:33:59,335 INFO     [task.py:395] Building contexts for mmlu_management on rank 0...\n",
      "100%|████████████████████████████████████████| 103/103 [00:00<00:00, 823.39it/s]\n",
      "2024-04-16:00:33:59,464 INFO     [task.py:395] Building contexts for mmlu_clinical_knowledge on rank 0...\n",
      "100%|████████████████████████████████████████| 265/265 [00:00<00:00, 824.93it/s]\n",
      "2024-04-16:00:33:59,796 INFO     [task.py:395] Building contexts for mmlu_college_medicine on rank 0...\n",
      "100%|████████████████████████████████████████| 173/173 [00:00<00:00, 827.09it/s]\n",
      "2024-04-16:00:34:00,011 INFO     [task.py:395] Building contexts for mmlu_marketing on rank 0...\n",
      "100%|████████████████████████████████████████| 234/234 [00:00<00:00, 819.00it/s]\n",
      "2024-04-16:00:34:00,305 INFO     [task.py:395] Building contexts for mmlu_college_chemistry on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 820.67it/s]\n",
      "2024-04-16:00:34:00,431 INFO     [task.py:395] Building contexts for mmlu_machine_learning on rank 0...\n",
      "100%|████████████████████████████████████████| 112/112 [00:00<00:00, 820.45it/s]\n",
      "2024-04-16:00:34:00,573 INFO     [task.py:395] Building contexts for mmlu_high_school_computer_science on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 833.88it/s]\n",
      "2024-04-16:00:34:00,698 INFO     [task.py:395] Building contexts for mmlu_high_school_physics on rank 0...\n",
      "100%|████████████████████████████████████████| 151/151 [00:00<00:00, 833.61it/s]\n",
      "2024-04-16:00:34:00,886 INFO     [task.py:395] Building contexts for mmlu_conceptual_physics on rank 0...\n",
      "100%|████████████████████████████████████████| 235/235 [00:00<00:00, 839.96it/s]\n",
      "2024-04-16:00:34:01,174 INFO     [task.py:395] Building contexts for mmlu_high_school_statistics on rank 0...\n",
      "100%|████████████████████████████████████████| 216/216 [00:00<00:00, 349.27it/s]\n",
      "2024-04-16:00:34:01,800 INFO     [task.py:395] Building contexts for mmlu_college_mathematics on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 817.01it/s]\n",
      "2024-04-16:00:34:01,929 INFO     [task.py:395] Building contexts for mmlu_high_school_biology on rank 0...\n",
      "100%|████████████████████████████████████████| 310/310 [00:00<00:00, 825.47it/s]\n",
      "2024-04-16:00:34:02,317 INFO     [task.py:395] Building contexts for mmlu_high_school_mathematics on rank 0...\n",
      "100%|████████████████████████████████████████| 270/270 [00:00<00:00, 829.33it/s]\n",
      "2024-04-16:00:34:02,652 INFO     [task.py:395] Building contexts for mmlu_elementary_mathematics on rank 0...\n",
      "100%|████████████████████████████████████████| 378/378 [00:00<00:00, 831.81it/s]\n",
      "2024-04-16:00:34:03,120 INFO     [task.py:395] Building contexts for mmlu_college_physics on rank 0...\n",
      "100%|████████████████████████████████████████| 102/102 [00:00<00:00, 823.93it/s]\n",
      "2024-04-16:00:34:03,248 INFO     [task.py:395] Building contexts for mmlu_astronomy on rank 0...\n",
      "100%|████████████████████████████████████████| 152/152 [00:00<00:00, 830.15it/s]\n",
      "2024-04-16:00:34:03,436 INFO     [task.py:395] Building contexts for mmlu_college_computer_science on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 830.85it/s]\n",
      "2024-04-16:00:34:03,561 INFO     [task.py:395] Building contexts for mmlu_high_school_chemistry on rank 0...\n",
      "100%|████████████████████████████████████████| 203/203 [00:00<00:00, 832.98it/s]\n",
      "2024-04-16:00:34:03,812 INFO     [task.py:395] Building contexts for mmlu_computer_security on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 829.32it/s]\n",
      "2024-04-16:00:34:03,937 INFO     [task.py:395] Building contexts for mmlu_anatomy on rank 0...\n",
      "100%|████████████████████████████████████████| 135/135 [00:00<00:00, 833.04it/s]\n",
      "2024-04-16:00:34:04,105 INFO     [task.py:395] Building contexts for mmlu_college_biology on rank 0...\n",
      "100%|████████████████████████████████████████| 144/144 [00:00<00:00, 829.15it/s]\n",
      "2024-04-16:00:34:04,285 INFO     [task.py:395] Building contexts for mmlu_abstract_algebra on rank 0...\n",
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 824.72it/s]\n",
      "2024-04-16:00:34:04,410 INFO     [task.py:395] Building contexts for mmlu_electrical_engineering on rank 0...\n",
      "100%|████████████████████████████████████████| 145/145 [00:00<00:00, 826.52it/s]\n",
      "2024-04-16:00:34:04,591 INFO     [evaluator.py:379] Running loglikelihood requests\n",
      "Running loglikelihood requests: 100%|████| 71688/71688 [10:37<00:00, 112.50it/s]\n",
      "hf (pretrained=models/mixtral_cut_0,parallelize=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 32\n",
      "|                 Tasks                 |Version|Filter|n-shot|Metric|Value |   |Stderr|\n",
      "|---------------------------------------|-------|------|-----:|------|-----:|---|-----:|\n",
      "|wmdp                                   |N/A    |none  |     0|acc   |0.3309|±  |0.0075|\n",
      "| - wmdp_bio                            |      0|none  |     0|acc   |0.3170|±  |0.0132|\n",
      "| - wmdp_chem                           |      0|none  |     0|acc   |0.5340|±  |0.0246|\n",
      "| - wmdp_cyber                          |      0|none  |     0|acc   |0.3011|±  |0.0097|\n",
      "|mmlu                                   |N/A    |none  |     0|acc   |0.6661|±  |0.0037|\n",
      "| - humanities                          |N/A    |none  |     0|acc   |0.6123|±  |0.0064|\n",
      "|  - formal_logic                       |      0|none  |     0|acc   |0.5000|±  |0.0447|\n",
      "|  - high_school_european_history       |      0|none  |     0|acc   |0.7939|±  |0.0316|\n",
      "|  - high_school_us_history             |      0|none  |     0|acc   |0.8627|±  |0.0242|\n",
      "|  - high_school_world_history          |      0|none  |     0|acc   |0.8819|±  |0.0210|\n",
      "|  - international_law                  |      0|none  |     0|acc   |0.8595|±  |0.0317|\n",
      "|  - jurisprudence                      |      0|none  |     0|acc   |0.7963|±  |0.0389|\n",
      "|  - logical_fallacies                  |      0|none  |     0|acc   |0.7791|±  |0.0326|\n",
      "|  - moral_disputes                     |      0|none  |     0|acc   |0.8006|±  |0.0215|\n",
      "|  - moral_scenarios                    |      0|none  |     0|acc   |0.2827|±  |0.0151|\n",
      "|  - philosophy                         |      0|none  |     0|acc   |0.7492|±  |0.0246|\n",
      "|  - prehistory                         |      0|none  |     0|acc   |0.8179|±  |0.0215|\n",
      "|  - professional_law                   |      0|none  |     0|acc   |0.5267|±  |0.0128|\n",
      "|  - world_religions                    |      0|none  |     0|acc   |0.8713|±  |0.0257|\n",
      "| - other                               |N/A    |none  |     0|acc   |0.7039|±  |0.0077|\n",
      "|  - business_ethics                    |      0|none  |     0|acc   |0.6800|±  |0.0469|\n",
      "|  - clinical_knowledge                 |      0|none  |     0|acc   |0.7283|±  |0.0274|\n",
      "|  - college_medicine                   |      0|none  |     0|acc   |0.6416|±  |0.0366|\n",
      "|  - global_facts                       |      0|none  |     0|acc   |0.3700|±  |0.0485|\n",
      "|  - human_aging                        |      0|none  |     0|acc   |0.6996|±  |0.0308|\n",
      "|  - management                         |      0|none  |     0|acc   |0.8252|±  |0.0376|\n",
      "|  - marketing                          |      0|none  |     0|acc   |0.9103|±  |0.0187|\n",
      "|  - medical_genetics                   |      0|none  |     0|acc   |0.6100|±  |0.0490|\n",
      "|  - miscellaneous                      |      0|none  |     0|acc   |0.8608|±  |0.0124|\n",
      "|  - nutrition                          |      0|none  |     0|acc   |0.6961|±  |0.0263|\n",
      "|  - professional_accounting            |      0|none  |     0|acc   |0.5567|±  |0.0296|\n",
      "|  - professional_medicine              |      0|none  |     0|acc   |0.6029|±  |0.0297|\n",
      "|  - virology                           |      0|none  |     0|acc   |0.3313|±  |0.0366|\n",
      "| - social_sciences                     |N/A    |none  |     0|acc   |0.7920|±  |0.0072|\n",
      "|  - econometrics                       |      0|none  |     0|acc   |0.5526|±  |0.0468|\n",
      "|  - high_school_geography              |      0|none  |     0|acc   |0.8636|±  |0.0245|\n",
      "|  - high_school_government_and_politics|      0|none  |     0|acc   |0.9482|±  |0.0160|\n",
      "|  - high_school_macroeconomics         |      0|none  |     0|acc   |0.7205|±  |0.0228|\n",
      "|  - high_school_microeconomics         |      0|none  |     0|acc   |0.7647|±  |0.0276|\n",
      "|  - high_school_psychology             |      0|none  |     0|acc   |0.8679|±  |0.0145|\n",
      "|  - human_sexuality                    |      0|none  |     0|acc   |0.7634|±  |0.0373|\n",
      "|  - professional_psychology            |      0|none  |     0|acc   |0.7451|±  |0.0176|\n",
      "|  - public_relations                   |      0|none  |     0|acc   |0.7000|±  |0.0439|\n",
      "|  - security_studies                   |      0|none  |     0|acc   |0.7551|±  |0.0275|\n",
      "|  - sociology                          |      0|none  |     0|acc   |0.8756|±  |0.0233|\n",
      "|  - us_foreign_policy                  |      0|none  |     0|acc   |0.9000|±  |0.0302|\n",
      "| - stem                                |N/A    |none  |     0|acc   |0.5864|±  |0.0084|\n",
      "|  - abstract_algebra                   |      0|none  |     0|acc   |0.3800|±  |0.0488|\n",
      "|  - anatomy                            |      0|none  |     0|acc   |0.7111|±  |0.0392|\n",
      "|  - astronomy                          |      0|none  |     0|acc   |0.7895|±  |0.0332|\n",
      "|  - college_biology                    |      0|none  |     0|acc   |0.7917|±  |0.0340|\n",
      "|  - college_chemistry                  |      0|none  |     0|acc   |0.5000|±  |0.0503|\n",
      "|  - college_computer_science           |      0|none  |     0|acc   |0.6600|±  |0.0476|\n",
      "|  - college_mathematics                |      0|none  |     0|acc   |0.3900|±  |0.0490|\n",
      "|  - college_physics                    |      0|none  |     0|acc   |0.4412|±  |0.0494|\n",
      "|  - computer_security                  |      0|none  |     0|acc   |0.5500|±  |0.0500|\n",
      "|  - conceptual_physics                 |      0|none  |     0|acc   |0.6511|±  |0.0312|\n",
      "|  - electrical_engineering             |      0|none  |     0|acc   |0.6552|±  |0.0396|\n",
      "|  - elementary_mathematics             |      0|none  |     0|acc   |0.4894|±  |0.0257|\n",
      "|  - high_school_biology                |      0|none  |     0|acc   |0.7903|±  |0.0232|\n",
      "|  - high_school_chemistry              |      0|none  |     0|acc   |0.5616|±  |0.0349|\n",
      "|  - high_school_computer_science       |      0|none  |     0|acc   |0.7300|±  |0.0446|\n",
      "|  - high_school_mathematics            |      0|none  |     0|acc   |0.3630|±  |0.0293|\n",
      "|  - high_school_physics                |      0|none  |     0|acc   |0.4636|±  |0.0407|\n",
      "|  - high_school_statistics             |      0|none  |     0|acc   |0.5880|±  |0.0336|\n",
      "|  - machine_learning                   |      0|none  |     0|acc   |0.5893|±  |0.0467|\n",
      "\n",
      "|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|\n",
      "|------------------|-------|------|-----:|------|-----:|---|-----:|\n",
      "|wmdp              |N/A    |none  |     0|acc   |0.3309|±  |0.0075|\n",
      "|mmlu              |N/A    |none  |     0|acc   |0.6661|±  |0.0037|\n",
      "| - humanities     |N/A    |none  |     0|acc   |0.6123|±  |0.0064|\n",
      "| - other          |N/A    |none  |     0|acc   |0.7039|±  |0.0077|\n",
      "| - social_sciences|N/A    |none  |     0|acc   |0.7920|±  |0.0072|\n",
      "| - stem           |N/A    |none  |     0|acc   |0.5864|±  |0.0084|\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the checkpoint stored at models/mixtral_rmu on the wmdp and mmlu\n",
    "# task groups through the lm-eval CLI (hf backend, parallelize=True, batch size 32).\n",
    "!lm-eval \\\n",
    "    --model hf \\\n",
    "    --model_args pretrained=models/mixtral_rmu,parallelize=True \\\n",
    "    --tasks wmdp,mmlu \\\n",
    "    --batch_size 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
