{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "        <script type=\"text/javascript\">\n",
       "        window.PlotlyConfig = {MathJaxConfig: 'local'};\n",
       "        if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}\n",
       "        if (typeof require !== 'undefined') {\n",
       "        require.undef(\"plotly\");\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': ['https://cdn.plot.ly/plotly-2.27.0.min']\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(Plotly) {\n",
       "            window._Plotly = Plotly;\n",
       "        });\n",
       "        }\n",
       "        </script>\n",
       "        "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Project bootstrap; the argument is a path three directories up from this notebook.\n",
    "# NOTE(review): presumably this makes the repository packages importable - confirm in lib_project.notebook.\n",
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "# Reload edited modules automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-01-14 22:43:41,561] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from IPython.display import display, Markdown as md\n",
    "import pandas as pd\n",
    "from transformers import (\n",
    "    Trainer,\n",
    "    # TrainingArguments,\n",
    "    DataCollatorForLanguageModeling,\n",
    ")\n",
    "from datasets import Dataset, DatasetDict\n",
    "\n",
    "from lib_llm.models.load import (\n",
    "    ModelConfig,\n",
    "    load_tokenizer,\n",
    "    load_model_tokenizer,\n",
    ")\n",
    "from lib_llm.training import (\n",
    "    TrainingConfig,\n",
    "    train,\n",
    "    OptimizerConfig,\n",
    "    TrainingArguments,\n",
    ")\n",
    "from lib_project.experiment import ExperimentID\n",
    "from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics\n",
    "from data.synthetic_strings.random import (\n",
    "    RandomStringConfig,\n",
    "    generate_random_strings,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d635f657-388a-457b-8fbb-2f66ebfcda69",
   "metadata": {},
   "source": [
    "# Reimplementing Memorization\n",
    "\n",
    "We are troubleshooting why Pythia models other than the 1B-parameter one no longer train properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f8061296-a3a0-4b16-830e-da7299eb42c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Pythia-70m checkpoint and its tokenizer through the project loader.\n",
    "tokenizer_type = \"pythia\"\n",
    "model_config = ModelConfig(\n",
    "    model_id=\"EleutherAI/pythia-70m\",\n",
    "    # base_dir=\"/home/exp/base_models\",\n",
    ")\n",
    "model, tokenizer = load_model_tokenizer(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d9e3e60-dd13-4d25-98d2-88b2b1bfb82a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Settings for the synthetic random-string data, then generate it with the loaded tokenizer.\n",
    "# NOTE(review): field meanings are defined by RandomStringConfig - confirm units there.\n",
    "conf = RandomStringConfig(\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    alphabet_size=26,\n",
    "    num_tokens=1024,\n",
    "    num_partitions=1,\n",
    "    seed_id=0,\n",
    ")\n",
    "data = generate_random_strings(conf, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b92cc8f4-3bc0-4cd0-bb7d-8bcfce60739c",
   "metadata": {},
   "source": [
    "## Minimal project code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "467871d5-929a-4de7-b18c-5ca7e0034ebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a one-example dataset from the first generated string.\n",
    "# Only element 0 of every per-entry id / mask list is kept.\n",
    "# NOTE(review): assumes each entry of `string` tokenizes to a single id - confirm.\n",
    "string = data.tokens[0]\n",
    "encoding = tokenizer(string)\n",
    "first_input_ids = [ids[0] for ids in encoding.input_ids]\n",
    "first_attention_mask = [mask[0] for mask in encoding.attention_mask]\n",
    "dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": [first_input_ids],\n",
    "        \"attention_mask\": [first_attention_mask],\n",
    "    }\n",
    ")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3765a38d-4785-484f-b977-bbb584eb9398",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: train with the Hugging Face `Trainer` directly, bypassing the project's `train` wrapper.\n",
    "# mlm=False selects the causal-LM collation mode of DataCollatorForLanguageModeling.\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./test_output\",\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=100,\n",
    "    # Log and evaluate once per epoch. `logging_steps` was dropped because it is\n",
    "    # only consulted for a step-based logging strategy, and an evaluation strategy\n",
    "    # is now set so that `eval_dataset` is actually used.\n",
    "    logging_strategy=\"epoch\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    ")\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    "    eval_dataset=dataset,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "392c2569-6735-49cc-bd93-49d5a6ec85c5",
   "metadata": {},
   "source": [
    "## With training code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c2c288-6517-4bb5-a5de-286dd738b354",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Without eval callbacks\n",
    "# Reload the pretrained checkpoint first: earlier cells train this same `model`\n",
    "# object, so without a reload this run would start from already fine-tuned weights.\n",
    "model, _unused_tokenizer = load_model_tokenizer(\n",
    "    ModelConfig(model_id=\"EleutherAI/pythia-70m\"),\n",
    ")\n",
    "model.to(\"cuda:0\")\n",
    "training_res = train(\n",
    "    ExperimentID(\"mem_reimplementation\", \"training_code\", 0),\n",
    "    (\"pyt-70m\", model),\n",
    "    # The same single-example dataset serves as both train and test split.\n",
    "    (\"testdata\", DatasetDict({\n",
    "        \"train\": dataset,\n",
    "        \"test\": dataset,\n",
    "    })),\n",
    "    config=TrainingConfig(\n",
    "        seed=483,\n",
    "        optimizer=OptimizerConfig(\n",
    "            learning_rate=5e-5,\n",
    "        ),\n",
    "        args=TrainingArguments(\n",
    "            num_train_epochs=100,\n",
    "            # `eval_steps` applies to a step-based evaluation schedule, so request one\n",
    "            # explicitly (as the other `train` experiment in this notebook does).\n",
    "            evaluation_strategy=\"steps\",\n",
    "            eval_steps=1,\n",
    "            deepspeed=\"experiments/memorization_dynamics/ds_memorization_config.json\",\n",
    "        ),\n",
    "    ),\n",
    "    tokenizer=tokenizer,\n",
    "    data_already_preprocessed=True,\n",
    "    run_callbacks_initially=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22ae9428-b434-4afd-b8d0-7b6031ab3fc2",
   "metadata": {},
   "source": [
    "## With default dataset construction code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63ca43cb-a0ea-4475-af11-5088afcdf8c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset as produced by the project's own construction code.\n",
    "default_dataset = data.dataset()\n",
    "# Without eval callbacks\n",
    "# Reload the pretrained checkpoint first: earlier cells train this same `model`\n",
    "# object, so without a reload this run would start from already fine-tuned weights.\n",
    "model, _unused_tokenizer = load_model_tokenizer(\n",
    "    ModelConfig(model_id=\"EleutherAI/pythia-70m\"),\n",
    ")\n",
    "model.to(\"cuda:0\")\n",
    "training_res = train(\n",
    "    ExperimentID(\"mem_reimplementation\", \"with_default_dataset\", 0),\n",
    "    (\"pyt-70m\", model),\n",
    "    (\"testdata\", default_dataset),\n",
    "    config=TrainingConfig(\n",
    "        seed=483,\n",
    "        optimizer=OptimizerConfig(\n",
    "            learning_rate=5e-5,\n",
    "        ),\n",
    "        args=TrainingArguments(\n",
    "            num_train_epochs=100,\n",
    "            # Evaluate after every optimizer step; log the training loss every 10 steps.\n",
    "            evaluation_strategy=\"steps\",\n",
    "            eval_steps=1,\n",
    "            logging_steps=10,\n",
    "            deepspeed=\"experiments/memorization_dynamics/ds_memorization_config.json\",\n",
    "        ),\n",
    "    ),\n",
    "    tokenizer=tokenizer,\n",
    "    data_already_preprocessed=True,\n",
    "    run_callbacks_initially=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cce4ea88-a31f-4806-bd5c-7260b9f052fe",
   "metadata": {},
   "source": [
    "## With eval callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d26f498f-b1f9-4e8d-b4d7-3a9b8450466b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mANONYMOUS\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.2 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.12"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/ANONYMOUS/llm-1/work/ANONYMOUS/llm_memorization/src/wandb/run-20240114_224407-1pw5ca9n</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/ANONYMOUS/llms/runs/1pw5ca9n' target=\"_blank\">['', '', 'sid_-1', 'training', 'pyt-70m', 'testdata']</a></strong> to <a href='https://wandb.ai/ANONYMOUS/llms' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/ANONYMOUS/llms' target=\"_blank\">https://wandb.ai/ANONYMOUS/llms</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/ANONYMOUS/llms/runs/1pw5ca9n' target=\"_blank\">https://wandb.ai/ANONYMOUS/llms/runs/1pw5ca9n</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training_args TrainingArguments(\n",
      "_n_gpu=2,\n",
      "adafactor=False,\n",
      "adam_beta1=0.9,\n",
      "adam_beta2=0.999,\n",
      "adam_epsilon=1e-08,\n",
      "auto_find_batch_size=False,\n",
      "bf16=False,\n",
      "bf16_full_eval=False,\n",
      "data_seed=None,\n",
      "dataloader_drop_last=False,\n",
      "dataloader_num_workers=0,\n",
      "dataloader_pin_memory=True,\n",
      "ddp_backend=None,\n",
      "ddp_broadcast_buffers=None,\n",
      "ddp_bucket_cap_mb=None,\n",
      "ddp_find_unused_parameters=None,\n",
      "ddp_timeout=1800,\n",
      "debug=[],\n",
      "deepspeed=experiments/memorization_dynamics/ds_memorization_config.json,\n",
      "disable_tqdm=False,\n",
      "dispatch_batches=None,\n",
      "do_eval=True,\n",
      "do_predict=False,\n",
      "do_train=False,\n",
      "eval_accumulation_steps=None,\n",
      "eval_delay=0,\n",
      "eval_steps=1,\n",
      "evaluation_strategy=IntervalStrategy.STEPS,\n",
      "fp16=False,\n",
      "fp16_backend=auto,\n",
      "fp16_full_eval=False,\n",
      "fp16_opt_level=O1,\n",
      "fsdp=[],\n",
      "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},\n",
      "fsdp_min_num_params=0,\n",
      "fsdp_transformer_layer_cls_to_wrap=None,\n",
      "full_determinism=False,\n",
      "generation_config=None,\n",
      "gradient_accumulation_steps=1,\n",
      "gradient_checkpointing=False,\n",
      "gradient_checkpointing_kwargs=None,\n",
      "greater_is_better=None,\n",
      "group_by_length=False,\n",
      "half_precision_backend=auto,\n",
      "hub_always_push=False,\n",
      "hub_model_id=None,\n",
      "hub_private_repo=False,\n",
      "hub_strategy=HubStrategy.EVERY_SAVE,\n",
      "hub_token=<HUB_TOKEN>,\n",
      "ignore_data_skip=False,\n",
      "include_inputs_for_metrics=False,\n",
      "include_tokens_per_second=False,\n",
      "jit_mode_eval=False,\n",
      "label_names=None,\n",
      "label_smoothing_factor=0.0,\n",
      "learning_rate=5e-05,\n",
      "length_column_name=length,\n",
      "load_best_model_at_end=False,\n",
      "local_rank=0,\n",
      "log_level=passive,\n",
      "log_level_replica=warning,\n",
      "log_on_each_node=True,\n",
      "logging_dir=../artifacts/sid_-1/training/pyt-70m/testdata/runs,\n",
      "logging_first_step=False,\n",
      "logging_nan_inf_filter=True,\n",
      "logging_steps=10,\n",
      "logging_strategy=IntervalStrategy.STEPS,\n",
      "lr_scheduler_type=SchedulerType.LINEAR,\n",
      "max_grad_norm=1.0,\n",
      "max_steps=-1,\n",
      "metric_for_best_model=None,\n",
      "mp_parameters=,\n",
      "neftune_noise_alpha=None,\n",
      "no_cuda=False,\n",
      "num_train_epochs=100,\n",
      "optim=OptimizerNames.ADAMW_TORCH,\n",
      "optim_args=None,\n",
      "output_dir=../artifacts/sid_-1/training/pyt-70m/testdata,\n",
      "overwrite_output_dir=False,\n",
      "past_index=-1,\n",
      "per_device_eval_batch_size=8,\n",
      "per_device_train_batch_size=8,\n",
      "prediction_loss_only=False,\n",
      "push_to_hub=False,\n",
      "push_to_hub_model_id=None,\n",
      "push_to_hub_organization=None,\n",
      "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
      "ray_scope=last,\n",
      "remove_unused_columns=True,\n",
      "report_to=['wandb'],\n",
      "resume_from_checkpoint=None,\n",
      "run_name=['', '', 'sid_-1', 'training', 'pyt-70m', 'testdata'],\n",
      "save_on_each_node=False,\n",
      "save_safetensors=True,\n",
      "save_steps=500,\n",
      "save_strategy=IntervalStrategy.STEPS,\n",
      "save_total_limit=None,\n",
      "seed=42,\n",
      "skip_memory_metrics=True,\n",
      "split_batches=False,\n",
      "tf32=None,\n",
      "torch_coANONYMOUSle=False,\n",
      "torch_coANONYMOUSle_backend=None,\n",
      "torch_coANONYMOUSle_mode=None,\n",
      "torchdynamo=None,\n",
      "tpu_metrics_debug=False,\n",
      "tpu_num_cores=None,\n",
      "use_cpu=False,\n",
      "use_ipex=False,\n",
      "use_legacy_prediction_loop=False,\n",
      "use_mps_device=False,\n",
      "warmup_ratio=0.0,\n",
      "warmup_steps=0,\n",
      "weight_decay=0.0,\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='28' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 28/100 00:43 < 02:00, 0.60 it/s, Epoch 27/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.652994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.521860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.335954</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.314292</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.281249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.264071</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.255247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.235761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.223027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.216608</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.202103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.183030</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.169605</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.152917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.126855</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.099688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.067245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>3.023463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>3.511600</td>\n",
       "      <td>2.965610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.898153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.851792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.789148</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.745824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.660734</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.562968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>3.120700</td>\n",
       "      <td>2.545372</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1' max='1' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1/1 : < :]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n",
      "/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:68: UserWarning:\n",
      "\n",
      "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 11\u001b[0m\n\u001b[1;32m      3\u001b[0m model\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda:0\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      5\u001b[0m memorization_task \u001b[38;5;241m=\u001b[39m memorization_dynamics_metrics(\n\u001b[1;32m      6\u001b[0m     data\u001b[38;5;241m.\u001b[39malphabet_tokens,\n\u001b[1;32m      7\u001b[0m     data\u001b[38;5;241m.\u001b[39malphabet_token_ids,\n\u001b[1;32m      8\u001b[0m     default_dataset[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m      9\u001b[0m )\n\u001b[0;32m---> 11\u001b[0m training_res \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     12\u001b[0m \u001b[43m    \u001b[49m\u001b[43mExperimentID\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmem_reimplementation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mwith_eval_callbacks\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     13\u001b[0m \u001b[43m    \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpyt-70m\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     14\u001b[0m \u001b[43m    \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtestdata\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdefault_dataset\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     15\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTrainingConfig\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     16\u001b[0m \u001b[43m        \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m483\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     17\u001b[0m \u001b[43m        \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mOptimizerConfig\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     18\u001b[0m \u001b[43m            \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5e-5\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     19\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     20\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTrainingArguments\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     21\u001b[0m \u001b[43m            \u001b[49m\u001b[43mnum_train_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     22\u001b[0m \u001b[43m            \u001b[49m\u001b[43mevaluation_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msteps\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     23\u001b[0m \u001b[43m            \u001b[49m\u001b[43meval_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     24\u001b[0m \u001b[43m            \u001b[49m\u001b[43mlogging_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     25\u001b[0m \u001b[43m            
\u001b[49m\u001b[43mdeepspeed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mexperiments/memorization_dynamics/ds_memorization_config.json\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     26\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     27\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     28\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mmemorization_task\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     29\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     30\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata_already_preprocessed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     31\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrun_callbacks_initially\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     32\u001b[0m \u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/ANONYMOUS/llm-1/work/ANONYMOUS/llm_memorization/src/lib_llm/lib_llm/training/train.py:136\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(task_id, model_info, dataset_info, config, tokenizer, callbacks, run_callbacks_initially, data_already_preprocessed, set_subdirs, output_dir)\u001b[0m\n\u001b[1;32m    134\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m dataset \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    135\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMust provide a dataset to train on\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 136\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    137\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    138\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    139\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    140\u001b[0m \u001b[43m        \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    141\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtask_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtask_description\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    142\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    143\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    144\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrun_callbacks_initially\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_callbacks_initially\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    145\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdata_already_preprocessed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_already_preprocessed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    146\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    147\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    148\u001b[0m     model \u001b[38;5;241m=\u001b[39m _load_model(output_dir, config\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mnum_train_epochs)\n",
      "File \u001b[0;32m/ANONYMOUS/llm-1/work/ANONYMOUS/llm_memorization/src/lib_llm/lib_llm/training/train.py:238\u001b[0m, in \u001b[0;36m_train\u001b[0;34m(model, dataset, tokenizer, config, task_id, output_dir, callbacks, run_callbacks_initially, data_already_preprocessed)\u001b[0m\n\u001b[1;32m    229\u001b[0m         callback\u001b[38;5;241m.\u001b[39mon_evaluate(\n\u001b[1;32m    230\u001b[0m             trainer\u001b[38;5;241m.\u001b[39margs,\n\u001b[1;32m    231\u001b[0m             trainer_state,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    234\u001b[0m             tokenizer\u001b[38;5;241m=\u001b[39mtokenizer,\n\u001b[1;32m    235\u001b[0m         )\n\u001b[1;32m    237\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 238\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    239\u001b[0m \u001b[43m    \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m    240\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresume_path\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresume_path\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\n\u001b[1;32m    241\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    242\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    
244\u001b[0m end_state \u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mstate\n\u001b[1;32m    245\u001b[0m eval_steps \u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39meval_steps\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/transformers/trainer.py:1555\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1553\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   1554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1555\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1556\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1557\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1558\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1559\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1560\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/transformers/trainer.py:1922\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   1919\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m steps_skipped) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m   1920\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 1922\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1923\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1924\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/transformers/trainer.py:2271\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2269\u001b[0m         metrics\u001b[38;5;241m.\u001b[39mupdate(dataset_metrics)\n\u001b[1;32m   2270\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2271\u001b[0m     metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2272\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_report_to_hp_search(trial, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step, metrics)\n\u001b[1;32m   2274\u001b[0m \u001b[38;5;66;03m# Run delayed LR scheduler now that metrics are populated\u001b[39;00m\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/transformers/trainer.py:3039\u001b[0m, in \u001b[0;36mTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m   3035\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m DebugOption\u001b[38;5;241m.\u001b[39mTPU_METRICS_DEBUG \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdebug:\n\u001b[1;32m   3036\u001b[0m     \u001b[38;5;66;03m# tpu-comment: Logging debug metrics for PyTorch/XLA (coANONYMOUSle, execute times, ops, etc.)\u001b[39;00m\n\u001b[1;32m   3037\u001b[0m     xm\u001b[38;5;241m.\u001b[39mmaster_print(met\u001b[38;5;241m.\u001b[39mmetrics_report())\n\u001b[0;32m-> 3039\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcallback_handler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontrol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmetrics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3041\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_memory_tracker\u001b[38;5;241m.\u001b[39mstop_and_update_metrics(output\u001b[38;5;241m.\u001b[39mmetrics)\n\u001b[1;32m   3043\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mmetrics\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/transformers/trainer_callback.py:389\u001b[0m, in \u001b[0;36mCallbackHandler.on_evaluate\u001b[0;34m(self, args, state, control, metrics)\u001b[0m\n\u001b[1;32m    387\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mon_evaluate\u001b[39m(\u001b[38;5;28mself\u001b[39m, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):\n\u001b[1;32m    388\u001b[0m     control\u001b[38;5;241m.\u001b[39mshould_evaluate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m--> 389\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall_event\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mon_evaluate\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontrol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetrics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetrics\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/transformers/trainer_callback.py:407\u001b[0m, in \u001b[0;36mCallbackHandler.call_event\u001b[0;34m(self, event, args, state, control, **kwargs)\u001b[0m\n\u001b[1;32m    405\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcall_event\u001b[39m(\u001b[38;5;28mself\u001b[39m, event, args, state, control, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    406\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m callback \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallbacks:\n\u001b[0;32m--> 407\u001b[0m         result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mevent\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    408\u001b[0m \u001b[43m            \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    409\u001b[0m \u001b[43m            \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    410\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcontrol\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    411\u001b[0m \u001b[43m            \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    412\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    413\u001b[0m \u001b[43m            
\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    414\u001b[0m \u001b[43m            \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    415\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    416\u001b[0m \u001b[43m            \u001b[49m\u001b[43meval_dataloader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meval_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    417\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    418\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    419\u001b[0m         \u001b[38;5;66;03m# A Callback can skip the return of `control` if it doesn't change it.\u001b[39;00m\n\u001b[1;32m    420\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m/ANONYMOUS/llm-1/work/ANONYMOUS/llm_memorization/src/lib_llm/lib_llm/eval/metrics/eval_tasks.py:52\u001b[0m, in \u001b[0;36mEvaluationTask.on_evaluate\u001b[0;34m(self, args, state, control, **kwargs)\u001b[0m\n\u001b[1;32m     46\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m     47\u001b[0m eval_args \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m     48\u001b[0m     arg_name: kwargs[arg_name]\n\u001b[1;32m     49\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m arg_name \u001b[38;5;129;01min\u001b[39;00m kwargs\n\u001b[1;32m     50\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m arg_name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluate_args\n\u001b[1;32m     51\u001b[0m }\n\u001b[0;32m---> 52\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43meval_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     53\u001b[0m epoch \u001b[38;5;241m=\u001b[39m state\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;28;01mif\u001b[39;00m state\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m     54\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstep_results\u001b[38;5;241m.\u001b[39mappend((epoch, res))\n",
      "File \u001b[0;32m/ANONYMOUS/llm-1/work/ANONYMOUS/llm_memorization/src/lib_llm/lib_llm/eval/metrics/eval_tasks.py:147\u001b[0m, in \u001b[0;36mSequenceEvaluationTask.evaluate\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m    142\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m metric_name, metric \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmetrics\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m    143\u001b[0m     metric_args \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m    144\u001b[0m         arg_name: produce_value(arg_name)\u001b[38;5;241m.\u001b[39mto(ACCELERATOR)\n\u001b[1;32m    145\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m arg_name \u001b[38;5;129;01min\u001b[39;00m metric\u001b[38;5;241m.\u001b[39mrequired_args\n\u001b[1;32m    146\u001b[0m     }\n\u001b[0;32m--> 147\u001b[0m     \u001b[43mmetric\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmetric_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    149\u001b[0m     result \u001b[38;5;241m=\u001b[39m metric\u001b[38;5;241m.\u001b[39mcompute()\n\u001b[1;32m    150\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(result, torch\u001b[38;5;241m.\u001b[39mTensor)\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torchmetrics/metric.py:457\u001b[0m, in \u001b[0;36mMetric._wrap_update.<locals>.wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    455\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_grad):\n\u001b[1;32m    456\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 457\u001b[0m         \u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    458\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m    459\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpected all tensors to be on\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(err):\n",
      "File \u001b[0;32m/ANONYMOUS/llm-1/work/ANONYMOUS/llm_memorization/src/lib_llm/lib_llm/eval/metrics/token_level.py:144\u001b[0m, in \u001b[0;36mLossMetric.update\u001b[0;34m(self, logits)\u001b[0m\n\u001b[1;32m    143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mupdate\u001b[39m(\u001b[38;5;28mself\u001b[39m, logits: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 144\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    145\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mswapaxes\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    146\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtarget_sequences\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    147\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;66;43;03m# Get a loss value for each element\u001b[39;49;00m\n\u001b[1;32m    148\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnone\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    149\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    150\u001b[0m     \u001b[38;5;66;03m# 
loss = loss.mean(dim=-1).reshape(-1)\u001b[39;00m\n\u001b[1;32m    151\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlosses\u001b[38;5;241m.\u001b[39mappend(loss)\n",
      "File \u001b[0;32m/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/torch/nn/functional.py:3053\u001b[0m, in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[1;32m   3051\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   3052\u001b[0m     reduction \u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[0;32m-> 3053\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f3a81de25d0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f3a780d76d0, execution_count=5 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7f3a780d7250, raw_cell=\"default_dataset = data.dataset()\n",
      "# Without eval ca..\" store_history=True silent=False shell_futures=True cell_id=d26f498f-b1f9-4e8d-b4d7-3a9b8450466b> result=None>,),kwargs {}:\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "_WandbInit._pause_backend() takes 1 positional argument but 2 were given",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;31mTypeError\u001b[0m: _WandbInit._pause_backend() takes 1 positional argument but 2 were given"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Thread WriterThread:\n",
      "Traceback (most recent call last):\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/internal_util.py\", line 49, in run\n",
      "    self._run()\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/internal_util.py\", line 100, in _run\n",
      "    self._process(record)\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/internal.py\", line 380, in _process\n",
      "    self._wm.write(record)\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/writer.py\", line 154, in write\n",
      "    write_handler(record)\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/writer.py\", line 135, in _write\n",
      "    self._write_record(record)\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/writer.py\", line 109, in _write_record\n",
      "    ret = self._ds.write(record)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/datastore.py\", line 289, in write\n",
      "    ret = self._write_data(s)\n",
      "          ^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/ANONYMOUS/venvs/nobackup/ANONYMOUS/pypoetry_cache/virtualenvs/llm-memorization-aIMa-l9K-py3.11/lib/python3.11/site-packages/wandb/sdk/internal/datastore.py\", line 267, in _write_data\n",
      "    os.fsync(self._fp.fileno())\n",
      "OSError: [Errno 116] Stale file handle\n",
      "wandb: ERROR Internal wandb error: file data was not synced\n"
     ]
    }
   ],
   "source": [
    "default_dataset = data.dataset()\n",
    "# With eval callbacks: memorization_task runs at every evaluation (eval_steps=1)\n",
    "model.to(\"cuda:0\")\n",
    "\n",
    "memorization_task = memorization_dynamics_metrics(\n",
    "    data.alphabet_tokens,\n",
    "    data.alphabet_token_ids,\n",
    "    default_dataset[\"test\"],\n",
    ")\n",
    "\n",
    "training_res = train(\n",
    "    ExperimentID(\"mem_reimplementation\", \"with_eval_callbacks\", 0),\n",
    "    (\"pyt-70m\", model),\n",
    "    (\"testdata\", default_dataset),\n",
    "    config=TrainingConfig( \n",
    "        seed=483,\n",
    "        optimizer=OptimizerConfig(\n",
    "            learning_rate=5e-5,\n",
    "        ),\n",
    "        args=TrainingArguments(\n",
    "            num_train_epochs=100,\n",
    "            evaluation_strategy=\"steps\",\n",
    "            eval_steps=1,\n",
    "            logging_steps=10,\n",
    "            deepspeed=\"experiments/memorization_dynamics/ds_memorization_config.json\",\n",
    "        ),\n",
    "    ),\n",
    "    callbacks=[memorization_task],\n",
    "    tokenizer=tokenizer,\n",
    "    data_already_preprocessed=True,\n",
    "    run_callbacks_initially=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7c4ed4-3a70-4eec-88e1-d024dadf9ce6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
