{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Diagnostic Task Evals\n",
    "\n",
    "This notebook runs the diagnostic task evals for the models reported in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "def run_model_evals(tasks_and_models, model_class):\n",
    "    # Loop through the list and run the eval script with each parameter\n",
    "    for task in tasks_and_models.keys():\n",
    "        models = tasks_and_models[task]\n",
    "        for model, seed in models:\n",
    "            command = f\"python diagnostic_task_eval.py {task} {model} {model_class} --random_seed {seed} --hard_delete\"\n",
    "            # Execute the command using subprocess.run\n",
    "            subprocess.run(command, shell=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### T5 Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model...\n",
      "vowel_removal\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/vowel_removal/T5/t5_simple_vowel_removal_seed121/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading vowel_removal test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:41<00:00,  5.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 5.142786203714422e-05\n",
      "Eval percent deleted tokens: 0.0\n",
      "Eval new sequence length: 64.0\n",
      "Eval token accuracy: 0.999987410263958\n",
      "Eval sequence accuracy: 0.99934375\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "contextual_vowel_removal\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/contextual_vowel_removal/T5/t5_contextual_vowel_removal_seed11/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading contextual_vowel_removal test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:43<00:00,  5.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 5.095320295799866e-05\n",
      "Eval percent deleted tokens: 0.0\n",
      "Eval new sequence length: 64.0\n",
      "Eval token accuracy: 0.999982861869418\n",
      "Eval sequence accuracy: 0.99909375\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "merge_ABC\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/merge_ABC/T5/t5_merge_ABC_seed241/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading merge_ABC test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:42<00:00,  5.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 7.225487371078999e-05\n",
      "Eval percent deleted tokens: 0.0\n",
      "Eval new sequence length: 64.0\n",
      "Eval token accuracy: 0.9999715122233978\n",
      "Eval sequence accuracy: 0.9984375\n",
      "Examples evaluated: 32000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# List of parameters you want to loop through\n",
    "tasks_and_models = {\n",
    "    \"vowel_removal\": [(\"t5_simple_vowel_removal\", 121)],\n",
    "    \"contextual_vowel_removal\": [(\"t5_contextual_vowel_removal\", 11)],\n",
    "    \"merge_ABC\": [(\"t5_merge_ABC\", 241)],\n",
    "}\n",
    "\n",
    "run_model_evals(tasks_and_models, \"T5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MrT5 Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model...\n",
      "vowel_removal\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/vowel_removal/MrT5/mrt5_simple_vowel_removal_1e-4_seed429/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading vowel_removal test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:42<00:00,  5.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 4.533697633360134e-05\n",
      "Eval percent deleted tokens: 18.58466796875\n",
      "Eval new sequence length: 59.46\n",
      "Eval token accuracy: 0.9999850047236116\n",
      "Eval sequence accuracy: 0.99921875\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "vowel_removal\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/vowel_removal/MrT5/mrt5_simple_vowel_removal_1e-3_seed93/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading vowel_removal test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:42<00:00,  5.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 5.423742860853053e-05\n",
      "Eval percent deleted tokens: 20.14716796875\n",
      "Eval new sequence length: 58.46\n",
      "Eval token accuracy: 0.9999838014902892\n",
      "Eval sequence accuracy: 0.9991875\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "contextual_vowel_removal\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/contextual_vowel_removal/MrT5/mrt5_contextual_vowel_removal_1e-2_seed934/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading contextual_vowel_removal test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:43<00:00,  5.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 0.00013472470228134624\n",
      "Eval percent deleted tokens: 18.96748046875\n",
      "Eval new sequence length: 60.56\n",
      "Eval token accuracy: 0.99994624392371\n",
      "Eval sequence accuracy: 0.99715625\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "contextual_vowel_removal\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/contextual_vowel_removal/MrT5/mrt5_contextual_vowel_removal_1e-3_seed510/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading contextual_vowel_removal test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:42<00:00,  5.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 0.00010874113667932761\n",
      "Eval percent deleted tokens: 1.5625\n",
      "Eval new sequence length: 63.0\n",
      "Eval token accuracy: 0.9999592178636283\n",
      "Eval sequence accuracy: 0.99784375\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "merge_ABC\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/merge_ABC/MrT5/mrt5_merge_ABC_1e-2_seed14/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading merge_ABC test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:42<00:00,  5.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 0.00025213346126292893\n",
      "Eval percent deleted tokens: 10.146875\n",
      "Eval new sequence length: 61.996\n",
      "Eval token accuracy: 0.9998968378415278\n",
      "Eval sequence accuracy: 0.994375\n",
      "Examples evaluated: 32000\n",
      "\n",
      "Loading model...\n",
      "merge_ABC\n",
      "Path: /nlp/scr3/nlp/llms-in-llms/mrt5/models/merge_ABC/MrT5/mrt5_merge_ABC_1.5e-2_seed123/checkpoints/checkpoint-30000\n",
      "\n",
      "Loading merge_ABC test dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 249/250 [00:42<00:00,  5.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval cross entropy loss: 0.0004563363664965436\n",
      "Eval percent deleted tokens: 17.371142578125\n",
      "Eval new sequence length: 60.996\n",
      "Eval token accuracy: 0.9998471618873562\n",
      "Eval sequence accuracy: 0.991875\n",
      "Examples evaluated: 32000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# List of parameters you want to loop through\n",
    "tasks_and_models = {\n",
    "    \"vowel_removal\":\n",
    "      [\n",
    "       (\"mrt5_simple_vowel_removal_1e-4\", 429),\n",
    "       (\"mrt5_simple_vowel_removal_1e-3\", 93),\n",
    "      ],\n",
    "    \"contextual_vowel_removal\":\n",
    "      [\n",
    "        (\"mrt5_contextual_vowel_removal_1e-2\", 934),\n",
    "        (\"mrt5_contextual_vowel_removal_1e-3\", 510),\n",
    "      ],\n",
    "    \"merge_ABC\":\n",
    "      [\n",
    "        (\"mrt5_merge_ABC_1e-2\", 14),\n",
    "        (\"mrt5_merge_ABC_1.5e-2\", 123)\n",
    "      ],\n",
    "}\n",
    "\n",
    "run_model_evals(tasks_and_models, \"MrT5\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "charlm-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
