{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ae8b47d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4c096173",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Choose the GPU for this notebook. This must run before torch/CUDA is initialised,\n",
    "# because CUDA_VISIBLE_DEVICES is only honoured at CUDA initialisation time.\n",
    "# setdefault keeps any value exported before launching Jupyter (e.g. on a machine\n",
    "# with fewer GPUs) and falls back to GPU 6 otherwise.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"6\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4facebb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4c3f89fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Download GloVe (one time, ~1 min)\n",
    "# ! wget http://nlp.stanford.edu/data/glove.6B.zip\n",
    "# ! unzip glove.6B.zip\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13b236a4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================\n",
      "SQUAD ANSWER GENERATION WITH GLOVE EMBEDDINGS\n",
      "======================================================================\n",
      "Model: 6L, 300d, 6h\n",
      "Device: cuda\n",
      "======================================================================\n",
      "\n",
      "✓ GloVe embeddings found: glove.6B.300d.txt\n",
      "Loading tokenizer...\n",
      "\n",
      "======================================================================\n",
      "LOADING GLOVE EMBEDDINGS\n",
      "======================================================================\n",
      "Reading GloVe file (this takes ~1 minute)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading GloVe: 100%|█████████████████| 400000/400000 [00:21<00:00, 18609.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✓ Loaded 400,000 GloVe vectors\n",
      "Matching tokenizer vocabulary with GloVe...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Matching: 100%|███████████████████████| 50257/50257 [00:00<00:00, 332747.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✓ Matched 43,058/50,257 tokens (85.7%)\n",
      "======================================================================\n",
      "\n",
      "Loading datasets...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train: 60000, Val: 5928\n",
      "\n",
      "Initializing model...\n",
      "Training for seed 1234\n",
      "Initializing token embeddings with GloVe...\n",
      "✓ Token embeddings initialized with GloVe\n",
      "Total parameters: 21.7M\n",
      "Trainable parameters: 21.7M\n",
      "\n",
      "======================================================================\n",
      "BASELINE (Standard LR)\n",
      "======================================================================\n",
      "\n",
      "Using differential LR: embeddings=0.1x, other=1.0x\n",
      "\n",
      "\n",
      "======================================================================\n",
      "EPOCH 1/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=7.055]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 8.5839\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:18<00:00, 10.70it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:33<00:00,  8.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0256 | Val F1: 0.0266 | Gap: -0.0010 | EM: 0.0100\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the\n",
      "  F1: 0.000\n",
      "✓ SAVED! Best F1: 0.0266\n",
      "\n",
      "======================================================================\n",
      "EPOCH 2/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████████████| 1875/1875 [06:26<00:00,  4.85it/s, loss=6.599]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.5589\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 15.38it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:20<00:00, 14.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1264 | Val F1: 0.0899 | Gap: 0.0365 | EM: 0.0367\n",
      "✓ SAVED! Best F1: 0.0899\n",
      "\n",
      "======================================================================\n",
      "EPOCH 3/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=5.427]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.1364\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:17<00:00, 11.12it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:26<00:00, 11.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1573 | Val F1: 0.1300 | Gap: 0.0273 | EM: 0.0633\n",
      "✓ SAVED! Best F1: 0.1300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 4/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████████████| 1875/1875 [06:56<00:00,  4.50it/s, loss=5.888]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.8278\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:17<00:00, 11.27it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:33<00:00,  8.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1966 | Val F1: 0.1503 | Gap: 0.0463 | EM: 0.0600\n",
      "✓ SAVED! Best F1: 0.1503\n",
      "\n",
      "======================================================================\n",
      "EPOCH 5/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████████████| 1875/1875 [07:04<00:00,  4.42it/s, loss=5.354]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.5469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:14<00:00, 13.67it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:23<00:00, 12.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2248 | Val F1: 0.1925 | Gap: 0.0324 | EM: 0.0800\n",
      "✓ SAVED! Best F1: 0.1925\n",
      "\n",
      "======================================================================\n",
      "EPOCH 6/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████████████| 1875/1875 [07:04<00:00,  4.42it/s, loss=5.051]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.2606\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 15.32it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:22<00:00, 13.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2613 | Val F1: 0.2105 | Gap: 0.0507 | EM: 0.0867\n",
      "✓ SAVED! Best F1: 0.2105\n",
      "\n",
      "======================================================================\n",
      "EPOCH 7/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7: 100%|██████████████████| 1875/1875 [07:06<00:00,  4.39it/s, loss=4.945]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.9125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:11<00:00, 17.28it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3554 | Val F1: 0.2575 | Gap: 0.0979 | EM: 0.1167\n",
      "✓ SAVED! Best F1: 0.2575\n",
      "\n",
      "======================================================================\n",
      "EPOCH 8/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8: 100%|██████████████████| 1875/1875 [07:10<00:00,  4.35it/s, loss=4.309]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.4344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.55it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3711 | Val F1: 0.3208 | Gap: 0.0503 | EM: 0.1600\n",
      "✓ SAVED! Best F1: 0.3208\n",
      "\n",
      "======================================================================\n",
      "EPOCH 9/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9: 100%|██████████████████| 1875/1875 [06:40<00:00,  4.68it/s, loss=3.842]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.9369\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 20.01it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:16<00:00, 18.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4371 | Val F1: 0.3565 | Gap: 0.0805 | EM: 0.1800\n",
      "✓ SAVED! Best F1: 0.3565\n",
      "\n",
      "======================================================================\n",
      "EPOCH 10/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=3.815]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.6026\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5041 | Val F1: 0.3925 | Gap: 0.1116 | EM: 0.1967\n",
      "✓ SAVED! Best F1: 0.3925\n",
      "\n",
      "======================================================================\n",
      "EPOCH 11/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=3.603]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.4145\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.48it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5090 | Val F1: 0.4167 | Gap: 0.0924 | EM: 0.2300\n",
      "✓ SAVED! Best F1: 0.4167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 12/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12: 100%|█████████████████| 1875/1875 [05:33<00:00,  5.62it/s, loss=3.586]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.2826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:17<00:00, 17.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4953 | Val F1: 0.4010 | Gap: 0.0943 | EM: 0.2100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 13/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13: 100%|█████████████████| 1875/1875 [05:11<00:00,  6.02it/s, loss=3.586]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1859\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.14it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5358 | Val F1: 0.4412 | Gap: 0.0946 | EM: 0.2467\n",
      "✓ SAVED! Best F1: 0.4412\n",
      "\n",
      "======================================================================\n",
      "EPOCH 14/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|█████████████████| 1875/1875 [05:40<00:00,  5.50it/s, loss=3.223]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.96it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:16<00:00, 18.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5408 | Val F1: 0.4305 | Gap: 0.1103 | EM: 0.2400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 15/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.783]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0306\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.99it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5245 | Val F1: 0.4553 | Gap: 0.0693 | EM: 0.2667\n",
      "✓ SAVED! Best F1: 0.4553\n",
      "\n",
      "======================================================================\n",
      "EPOCH 16/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16: 100%|█████████████████| 1875/1875 [05:23<00:00,  5.80it/s, loss=3.138]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9674\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 20.15it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5621 | Val F1: 0.4519 | Gap: 0.1102 | EM: 0.2667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 17/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.738]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9066\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.22it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6052 | Val F1: 0.4739 | Gap: 0.1313 | EM: 0.2900\n",
      "✓ SAVED! Best F1: 0.4739\n",
      "\n",
      "======================================================================\n",
      "EPOCH 18/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18: 100%|█████████████████| 1875/1875 [05:24<00:00,  5.79it/s, loss=2.928]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8524\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6317 | Val F1: 0.4563 | Gap: 0.1755 | EM: 0.2700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 19/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.641]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8021\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.01it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6034 | Val F1: 0.4986 | Gap: 0.1048 | EM: 0.2933\n",
      "✓ SAVED! Best F1: 0.4986\n",
      "\n",
      "======================================================================\n",
      "EPOCH 20/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20: 100%|█████████████████| 1875/1875 [05:18<00:00,  5.90it/s, loss=2.516]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7552\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.95it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:17<00:00, 17.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6233 | Val F1: 0.4454 | Gap: 0.1779 | EM: 0.2433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 21/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.726]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7071\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.93it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6759 | Val F1: 0.4600 | Gap: 0.2159 | EM: 0.2733\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: manually suppress the fire\n",
      "  F1: 1.000\n",
      "\n",
      "======================================================================\n",
      "EPOCH 22/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.721]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6705\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.91it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6054 | Val F1: 0.4610 | Gap: 0.1444 | EM: 0.2667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 23/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23: 100%|█████████████████| 1875/1875 [05:17<00:00,  5.91it/s, loss=2.596]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6352\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.10it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6928 | Val F1: 0.5020 | Gap: 0.1908 | EM: 0.3033\n",
      "✓ SAVED! Best F1: 0.5020\n",
      "\n",
      "======================================================================\n",
      "EPOCH 24/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24: 100%|█████████████████| 1875/1875 [05:10<00:00,  6.05it/s, loss=2.696]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5999\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.46it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6841 | Val F1: 0.4826 | Gap: 0.2015 | EM: 0.2933\n",
      "\n",
      "======================================================================\n",
      "EPOCH 25/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25: 100%|█████████████████| 1875/1875 [05:24<00:00,  5.79it/s, loss=2.533]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5649\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.97it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6926 | Val F1: 0.4818 | Gap: 0.2108 | EM: 0.2900\n",
      "\n",
      "======================================================================\n",
      "EPOCH 26/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.732]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5279\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7311 | Val F1: 0.5098 | Gap: 0.2213 | EM: 0.3067\n",
      "✓ SAVED! Best F1: 0.5098\n",
      "\n",
      "======================================================================\n",
      "EPOCH 27/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27: 100%|█████████████████| 1875/1875 [05:22<00:00,  5.81it/s, loss=2.534]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4973\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.76it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7455 | Val F1: 0.4701 | Gap: 0.2754 | EM: 0.2900\n",
      "\n",
      "======================================================================\n",
      "EPOCH 28/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28: 100%|█████████████████| 1875/1875 [06:11<00:00,  5.04it/s, loss=2.455]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4713\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.76it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7068 | Val F1: 0.4992 | Gap: 0.2077 | EM: 0.3133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 29/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|█████████████████| 1875/1875 [07:02<00:00,  4.43it/s, loss=2.476]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.96it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7410 | Val F1: 0.5161 | Gap: 0.2249 | EM: 0.3300\n",
      "✓ SAVED! Best F1: 0.5161\n",
      "\n",
      "======================================================================\n",
      "EPOCH 30/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30: 100%|█████████████████| 1875/1875 [06:57<00:00,  4.49it/s, loss=2.431]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4181\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.77it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7536 | Val F1: 0.4948 | Gap: 0.2588 | EM: 0.3100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 31/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31: 100%|█████████████████| 1875/1875 [07:01<00:00,  4.45it/s, loss=2.490]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.65it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7928 | Val F1: 0.5107 | Gap: 0.2821 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 32/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32: 100%|█████████████████| 1875/1875 [07:01<00:00,  4.45it/s, loss=2.760]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3680\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7822 | Val F1: 0.5083 | Gap: 0.2739 | EM: 0.3233\n",
      "\n",
      "======================================================================\n",
      "EPOCH 33/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33: 100%|█████████████████| 1875/1875 [07:04<00:00,  4.42it/s, loss=2.447]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3437\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.49it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7907 | Val F1: 0.5065 | Gap: 0.2842 | EM: 0.3133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 34/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34: 100%|█████████████████| 1875/1875 [06:03<00:00,  5.16it/s, loss=2.565]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3222\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.23it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8064 | Val F1: 0.5083 | Gap: 0.2981 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 35/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.226]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2999\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.57it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7960 | Val F1: 0.5452 | Gap: 0.2508 | EM: 0.3533\n",
      "✓ SAVED! Best F1: 0.5452\n",
      "\n",
      "======================================================================\n",
      "EPOCH 36/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.166]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2821\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7973 | Val F1: 0.5263 | Gap: 0.2710 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 37/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.249]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2653\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.22it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8197 | Val F1: 0.5117 | Gap: 0.3080 | EM: 0.3067\n",
      "\n",
      "======================================================================\n",
      "EPOCH 38/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.335]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2430\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.83it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8207 | Val F1: 0.5028 | Gap: 0.3179 | EM: 0.3100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 39/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.425]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2268\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.15it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8300 | Val F1: 0.5277 | Gap: 0.3023 | EM: 0.3100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 40/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.109]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2088\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.07it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8264 | Val F1: 0.4961 | Gap: 0.3303 | EM: 0.3033\n",
      "\n",
      "======================================================================\n",
      "EPOCH 41/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.343]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.33it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8300 | Val F1: 0.5360 | Gap: 0.2940 | EM: 0.3367\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 42/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.195]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1765\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.22it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8679 | Val F1: 0.5065 | Gap: 0.3615 | EM: 0.3100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 43/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.100]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1606\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8327 | Val F1: 0.5242 | Gap: 0.3085 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 44/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.040]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1481\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.39it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8323 | Val F1: 0.5256 | Gap: 0.3067 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 45/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.061]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.83it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8715 | Val F1: 0.5424 | Gap: 0.3290 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 46/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.211]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.42it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8778 | Val F1: 0.5240 | Gap: 0.3538 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 47/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.906]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1036\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.71it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8763 | Val F1: 0.5153 | Gap: 0.3610 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 48/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.264]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0919\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8822 | Val F1: 0.5631 | Gap: 0.3192 | EM: 0.3333\n",
      "✓ SAVED! Best F1: 0.5631\n",
      "\n",
      "======================================================================\n",
      "EPOCH 49/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.076]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0797\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.50it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8435 | Val F1: 0.5335 | Gap: 0.3100 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 50/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.011]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0692\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.71it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8656 | Val F1: 0.5534 | Gap: 0.3122 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 51/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.017]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0537\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.72it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8778 | Val F1: 0.5424 | Gap: 0.3353 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 52/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.093]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0462\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.25it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8792 | Val F1: 0.5478 | Gap: 0.3314 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 53/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.902]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0346\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.18it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8819 | Val F1: 0.5543 | Gap: 0.3276 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 54/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.269]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0214\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.76it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8999 | Val F1: 0.5251 | Gap: 0.3748 | EM: 0.3100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 55/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.011]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0124\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.61it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9100 | Val F1: 0.5297 | Gap: 0.3803 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 56/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.989]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.51it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9060 | Val F1: 0.5550 | Gap: 0.3511 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 57/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.073]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.11it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8946 | Val F1: 0.5260 | Gap: 0.3686 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 58/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.810]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9820\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 33.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9043 | Val F1: 0.5431 | Gap: 0.3612 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 59/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.880]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9718\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.87it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 33.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9141 | Val F1: 0.5681 | Gap: 0.3460 | EM: 0.3400\n",
      "✓ SAVED! Best F1: 0.5681\n",
      "\n",
      "======================================================================\n",
      "EPOCH 60/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.984]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.95it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 34.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9022 | Val F1: 0.5199 | Gap: 0.3823 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 61/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=1.953]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:05<00:00, 34.19it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 33.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9350 | Val F1: 0.5409 | Gap: 0.3940 | EM: 0.3167\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 62/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=1.971]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9463\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.19it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9210 | Val F1: 0.5559 | Gap: 0.3651 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 63/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.068]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9372\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.65it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9185 | Val F1: 0.5394 | Gap: 0.3791 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 64/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.908]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9321\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.68it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9288 | Val F1: 0.5488 | Gap: 0.3800 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 65/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=1.984]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.95it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9562 | Val F1: 0.5688 | Gap: 0.3874 | EM: 0.3533\n",
      "✓ SAVED! Best F1: 0.5688\n",
      "\n",
      "======================================================================\n",
      "EPOCH 66/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.901]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9162\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.56it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9218 | Val F1: 0.5439 | Gap: 0.3779 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 67/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.907]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9070\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.88it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9469 | Val F1: 0.5581 | Gap: 0.3888 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 68/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=1.881]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 33.29it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9386 | Val F1: 0.5681 | Gap: 0.3706 | EM: 0.3867\n",
      "\n",
      "======================================================================\n",
      "EPOCH 69/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.895]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.11it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9504 | Val F1: 0.5575 | Gap: 0.3929 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 70/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.998]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8841\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9481 | Val F1: 0.5447 | Gap: 0.4033 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS\n",
      "======================================================================\n",
      "Best Val F1: 56.9%\n",
      "Final Val F1: 54.5%\n",
      "Final EM: 34.0%\n",
      "Train-Val Gap: 0.4033\n",
      "Training for seed 1235\n",
      "Initializing token embeddings with GloVe...\n",
      "✓ Token embeddings initialized with GloVe\n",
      "Total parameters: 21.7M\n",
      "Trainable parameters: 21.7M\n",
      "\n",
      "======================================================================\n",
      "BASELINE (Standard LR)\n",
      "======================================================================\n",
      "\n",
      "Using differential LR: embeddings=0.1x, other=1.0x\n",
      "\n",
      "\n",
      "======================================================================\n",
      "EPOCH 1/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=7.098]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 8.5269\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:21<00:00,  9.44it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:33<00:00,  9.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0259 | Val F1: 0.0261 | Gap: -0.0002 | EM: 0.0100\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the\n",
      "  F1: 0.000\n",
      "✓ SAVED! Best F1: 0.0261\n",
      "\n",
      "======================================================================\n",
      "EPOCH 2/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=6.070]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.5654\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 14.88it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:22<00:00, 13.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1066 | Val F1: 0.0772 | Gap: 0.0294 | EM: 0.0300\n",
      "✓ SAVED! Best F1: 0.0772\n",
      "\n",
      "======================================================================\n",
      "EPOCH 3/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=6.463]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.1441\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 13.19it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:24<00:00, 12.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1420 | Val F1: 0.1133 | Gap: 0.0287 | EM: 0.0333\n",
      "✓ SAVED! Best F1: 0.1133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 4/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=5.966]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.8417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 12.72it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:23<00:00, 12.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1987 | Val F1: 0.1453 | Gap: 0.0534 | EM: 0.0467\n",
      "✓ SAVED! Best F1: 0.1453\n",
      "\n",
      "======================================================================\n",
      "EPOCH 5/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=5.361]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.5649\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.41it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:18<00:00, 16.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2304 | Val F1: 0.1677 | Gap: 0.0626 | EM: 0.0600\n",
      "✓ SAVED! Best F1: 0.1677\n",
      "\n",
      "======================================================================\n",
      "EPOCH 6/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=5.342]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.2752\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 14.41it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:22<00:00, 13.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2639 | Val F1: 0.2019 | Gap: 0.0621 | EM: 0.0833\n",
      "✓ SAVED! Best F1: 0.2019\n",
      "\n",
      "======================================================================\n",
      "EPOCH 7/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=5.403]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.9349\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.61it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3506 | Val F1: 0.2799 | Gap: 0.0707 | EM: 0.1267\n",
      "✓ SAVED! Best F1: 0.2799\n",
      "\n",
      "======================================================================\n",
      "EPOCH 8/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=3.877]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.4560\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 22.17it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4224 | Val F1: 0.3355 | Gap: 0.0869 | EM: 0.1633\n",
      "✓ SAVED! Best F1: 0.3355\n",
      "\n",
      "======================================================================\n",
      "EPOCH 9/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=3.740]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.9780\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.57it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4595 | Val F1: 0.3578 | Gap: 0.1017 | EM: 0.1867\n",
      "✓ SAVED! Best F1: 0.3578\n",
      "\n",
      "======================================================================\n",
      "EPOCH 10/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.758]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.6167\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.44it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 23.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4844 | Val F1: 0.3934 | Gap: 0.0910 | EM: 0.2100\n",
      "✓ SAVED! Best F1: 0.3934\n",
      "\n",
      "======================================================================\n",
      "EPOCH 11/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.246]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.4163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.52it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 18.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5201 | Val F1: 0.4021 | Gap: 0.1179 | EM: 0.2200\n",
      "✓ SAVED! Best F1: 0.4021\n",
      "\n",
      "======================================================================\n",
      "EPOCH 12/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.191]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.2879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.69it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5434 | Val F1: 0.4407 | Gap: 0.1027 | EM: 0.2533\n",
      "✓ SAVED! Best F1: 0.4407\n",
      "\n",
      "======================================================================\n",
      "EPOCH 13/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.258]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1884\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.28it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5513 | Val F1: 0.4432 | Gap: 0.1081 | EM: 0.2767\n",
      "✓ SAVED! Best F1: 0.4432\n",
      "\n",
      "======================================================================\n",
      "EPOCH 14/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.196]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1045\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.21it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5741 | Val F1: 0.4207 | Gap: 0.1534 | EM: 0.2200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 15/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.926]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0265\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.45it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6128 | Val F1: 0.4679 | Gap: 0.1449 | EM: 0.2900\n",
      "✓ SAVED! Best F1: 0.4679\n",
      "\n",
      "======================================================================\n",
      "EPOCH 16/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.027]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9647\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.07it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6016 | Val F1: 0.4803 | Gap: 0.1213 | EM: 0.2767\n",
      "✓ SAVED! Best F1: 0.4803\n",
      "\n",
      "======================================================================\n",
      "EPOCH 17/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.005]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9071\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.71it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6253 | Val F1: 0.4771 | Gap: 0.1481 | EM: 0.2800\n",
      "\n",
      "======================================================================\n",
      "EPOCH 18/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.715]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8508\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.86it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6315 | Val F1: 0.4595 | Gap: 0.1720 | EM: 0.2900\n",
      "\n",
      "======================================================================\n",
      "EPOCH 19/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.553]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8046\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.05it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 23.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6628 | Val F1: 0.4811 | Gap: 0.1817 | EM: 0.2900\n",
      "✓ SAVED! Best F1: 0.4811\n",
      "\n",
      "======================================================================\n",
      "EPOCH 20/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.745]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7576\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 20.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6546 | Val F1: 0.4775 | Gap: 0.1771 | EM: 0.2867\n",
      "\n",
      "======================================================================\n",
      "EPOCH 21/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.714]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7126\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.83it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6755 | Val F1: 0.4931 | Gap: 0.1825 | EM: 0.3133\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: fire\n",
      "  F1: 0.500\n",
      "✓ SAVED! Best F1: 0.4931\n",
      "\n",
      "======================================================================\n",
      "EPOCH 22/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.826]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6757\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.30it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7060 | Val F1: 0.5281 | Gap: 0.1778 | EM: 0.3200\n",
      "✓ SAVED! Best F1: 0.5281\n",
      "\n",
      "======================================================================\n",
      "EPOCH 23/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.836]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:05<00:00, 36.72it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6976 | Val F1: 0.4937 | Gap: 0.2039 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 24/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.497]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5971\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.40it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7083 | Val F1: 0.4989 | Gap: 0.2094 | EM: 0.3000\n",
      "\n",
      "======================================================================\n",
      "EPOCH 25/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.536]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5669\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.29it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7074 | Val F1: 0.4789 | Gap: 0.2285 | EM: 0.2967\n",
      "\n",
      "======================================================================\n",
      "EPOCH 26/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.620]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5311\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.78it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7222 | Val F1: 0.5003 | Gap: 0.2219 | EM: 0.3233\n",
      "\n",
      "======================================================================\n",
      "EPOCH 27/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27: 100%|█████████████████| 1875/1875 [05:20<00:00,  5.86it/s, loss=2.771]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7189 | Val F1: 0.5128 | Gap: 0.2061 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 28/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.64it/s, loss=2.741]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4720\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.70it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7300 | Val F1: 0.5345 | Gap: 0.1955 | EM: 0.3367\n",
      "✓ SAVED! Best F1: 0.5345\n",
      "\n",
      "======================================================================\n",
      "EPOCH 29/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=2.387]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.67it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7742 | Val F1: 0.5330 | Gap: 0.2412 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 30/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.63it/s, loss=2.493]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4137\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.95it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7684 | Val F1: 0.5177 | Gap: 0.2507 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 31/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31: 100%|█████████████████| 1875/1875 [06:05<00:00,  5.14it/s, loss=2.473]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3896\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.55it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7785 | Val F1: 0.5326 | Gap: 0.2460 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 32/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.221]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3684\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.17it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7802 | Val F1: 0.5300 | Gap: 0.2503 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 33/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.238]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3411\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.53it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7971 | Val F1: 0.5170 | Gap: 0.2801 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 34/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.279]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3202\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.02it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7583 | Val F1: 0.5248 | Gap: 0.2335 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 35/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.160]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2963\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7861 | Val F1: 0.5373 | Gap: 0.2489 | EM: 0.3433\n",
      "✓ SAVED! Best F1: 0.5373\n",
      "\n",
      "======================================================================\n",
      "EPOCH 36/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.278]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2774\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.05it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7828 | Val F1: 0.5258 | Gap: 0.2570 | EM: 0.3233\n",
      "\n",
      "======================================================================\n",
      "EPOCH 37/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.546]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2569\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.35it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7853 | Val F1: 0.5371 | Gap: 0.2482 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 38/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38: 100%|█████████████████| 1875/1875 [05:16<00:00,  5.93it/s, loss=2.788]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.56it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8100 | Val F1: 0.5254 | Gap: 0.2846 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 39/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=2.207]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.22it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8086 | Val F1: 0.5458 | Gap: 0.2629 | EM: 0.3600\n",
      "✓ SAVED! Best F1: 0.5458\n",
      "\n",
      "======================================================================\n",
      "EPOCH 40/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.211]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2068\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.16it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7946 | Val F1: 0.5398 | Gap: 0.2549 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 41/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.484]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.34it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7970 | Val F1: 0.5356 | Gap: 0.2614 | EM: 0.3467\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 42/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.259]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1705\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.85it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8390 | Val F1: 0.5449 | Gap: 0.2941 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 43/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=2.045]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1553\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.43it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8191 | Val F1: 0.5506 | Gap: 0.2684 | EM: 0.3367\n",
      "✓ SAVED! Best F1: 0.5506\n",
      "\n",
      "======================================================================\n",
      "EPOCH 44/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.387]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1397\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.99it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8267 | Val F1: 0.5484 | Gap: 0.2783 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 45/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=2.085]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.31it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8676 | Val F1: 0.5455 | Gap: 0.3221 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 46/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.255]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1077\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.21it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8677 | Val F1: 0.5487 | Gap: 0.3191 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 47/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.284]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0975\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.76it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8883 | Val F1: 0.5493 | Gap: 0.3390 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 48/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.217]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.76it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8866 | Val F1: 0.5435 | Gap: 0.3431 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 49/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.045]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0691\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.89it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8683 | Val F1: 0.5434 | Gap: 0.3249 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 50/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=1.989]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0569\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8748 | Val F1: 0.5551 | Gap: 0.3196 | EM: 0.3400\n",
      "✓ SAVED! Best F1: 0.5551\n",
      "\n",
      "======================================================================\n",
      "EPOCH 51/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.254]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0497\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.68it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8613 | Val F1: 0.5312 | Gap: 0.3301 | EM: 0.3133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 52/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.64it/s, loss=2.041]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0353\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.89it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8874 | Val F1: 0.5762 | Gap: 0.3112 | EM: 0.3600\n",
      "✓ SAVED! Best F1: 0.5762\n",
      "\n",
      "======================================================================\n",
      "EPOCH 53/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.65it/s, loss=2.047]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0232\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.97it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8757 | Val F1: 0.5360 | Gap: 0.3397 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 54/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=2.105]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0149\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.91it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8935 | Val F1: 0.5603 | Gap: 0.3333 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 55/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.64it/s, loss=1.896]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0042\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.75it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8826 | Val F1: 0.5500 | Gap: 0.3327 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 56/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.64it/s, loss=2.035]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.75it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9122 | Val F1: 0.5640 | Gap: 0.3483 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 57/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.64it/s, loss=2.017]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9123 | Val F1: 0.5543 | Gap: 0.3579 | EM: 0.3600\n",
      "\n",
      "======================================================================\n",
      "EPOCH 58/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=1.896]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9732\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.45it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8951 | Val F1: 0.5693 | Gap: 0.3258 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 59/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.66it/s, loss=2.010]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9649\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.35it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9229 | Val F1: 0.5446 | Gap: 0.3783 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 60/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.65it/s, loss=2.013]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9554\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.82it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8986 | Val F1: 0.5441 | Gap: 0.3545 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 61/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.64it/s, loss=2.021]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9476\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9252 | Val F1: 0.5688 | Gap: 0.3563 | EM: 0.3467\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: only in the fire\n",
      "  F1: 0.333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 62/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.63it/s, loss=1.941]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9399\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.46it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9285 | Val F1: 0.5345 | Gap: 0.3941 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 63/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63: 100%|█████████████████| 1875/1875 [05:37<00:00,  5.56it/s, loss=1.943]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9321\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.52it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9591 | Val F1: 0.5623 | Gap: 0.3968 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 64/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64: 100%|█████████████████| 1875/1875 [05:07<00:00,  6.11it/s, loss=2.059]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9221\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.84it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9175 | Val F1: 0.5538 | Gap: 0.3637 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 65/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.814]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9159\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.13it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9203 | Val F1: 0.5728 | Gap: 0.3474 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 66/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.908]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9077\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.18it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9246 | Val F1: 0.5657 | Gap: 0.3589 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 67/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.068]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8998\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 33.22it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 33.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9206 | Val F1: 0.5690 | Gap: 0.3516 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 68/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.915]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8904\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9207 | Val F1: 0.5653 | Gap: 0.3554 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 69/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.946]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8840\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.95it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9203 | Val F1: 0.5672 | Gap: 0.3531 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 70/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.041]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8788\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.74it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9361 | Val F1: 0.5683 | Gap: 0.3678 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS\n",
      "======================================================================\n",
      "Best Val F1: 57.6%\n",
      "Final Val F1: 56.8%\n",
      "Final EM: 34.7%\n",
      "Train-Val Gap: 0.3678\n",
      "Training for seed 1236\n",
      "Initializing token embeddings with GloVe...\n",
      "✓ Token embeddings initialized with GloVe\n",
      "Total parameters: 21.7M\n",
      "Trainable parameters: 21.7M\n",
      "\n",
      "======================================================================\n",
      "BASELINE (Standard LR)\n",
      "======================================================================\n",
      "\n",
      "Using differential LR: embeddings=0.1x, other=1.0x\n",
      "\n",
      "\n",
      "======================================================================\n",
      "EPOCH 1/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=6.886]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 8.6963\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:24<00:00,  8.26it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:40<00:00,  7.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0310 | Val F1: 0.0203 | Gap: 0.0107 | EM: 0.0100\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the\n",
      "  F1: 0.000\n",
      "✓ SAVED! Best F1: 0.0203\n",
      "\n",
      "======================================================================\n",
      "EPOCH 2/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=6.018]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.5821\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:20<00:00,  9.80it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:33<00:00,  8.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0549 | Val F1: 0.0484 | Gap: 0.0065 | EM: 0.0233\n",
      "✓ SAVED! Best F1: 0.0484\n",
      "\n",
      "======================================================================\n",
      "EPOCH 3/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=5.961]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.1677\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:11<00:00, 18.07it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:19<00:00, 15.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1565 | Val F1: 0.1031 | Gap: 0.0534 | EM: 0.0367\n",
      "✓ SAVED! Best F1: 0.1031\n",
      "\n",
      "======================================================================\n",
      "EPOCH 4/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=5.497]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.8659\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:14<00:00, 13.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:23<00:00, 12.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1889 | Val F1: 0.1426 | Gap: 0.0463 | EM: 0.0400\n",
      "✓ SAVED! Best F1: 0.1426\n",
      "\n",
      "======================================================================\n",
      "EPOCH 5/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████████████| 1875/1875 [05:23<00:00,  5.79it/s, loss=5.693]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.5928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:17<00:00, 11.65it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:28<00:00, 10.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2370 | Val F1: 0.1727 | Gap: 0.0643 | EM: 0.0733\n",
      "✓ SAVED! Best F1: 0.1727\n",
      "\n",
      "======================================================================\n",
      "EPOCH 6/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████████████| 1875/1875 [05:43<00:00,  5.45it/s, loss=5.165]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.3211\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:16<00:00, 11.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:23<00:00, 12.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2664 | Val F1: 0.1931 | Gap: 0.0733 | EM: 0.0967\n",
      "✓ SAVED! Best F1: 0.1931\n",
      "\n",
      "======================================================================\n",
      "EPOCH 7/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=5.254]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.0161\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 14.92it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:18<00:00, 16.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2993 | Val F1: 0.2274 | Gap: 0.0718 | EM: 0.0933\n",
      "✓ SAVED! Best F1: 0.2274\n",
      "\n",
      "======================================================================\n",
      "EPOCH 8/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8: 100%|██████████████████| 1875/1875 [05:11<00:00,  6.02it/s, loss=4.524]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.6169\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 14.48it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:19<00:00, 15.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3848 | Val F1: 0.2958 | Gap: 0.0890 | EM: 0.1400\n",
      "✓ SAVED! Best F1: 0.2958\n",
      "\n",
      "======================================================================\n",
      "EPOCH 9/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=4.350]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.1824\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 22.07it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:18<00:00, 16.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4309 | Val F1: 0.3193 | Gap: 0.1116 | EM: 0.1567\n",
      "✓ SAVED! Best F1: 0.3193\n",
      "\n",
      "======================================================================\n",
      "EPOCH 10/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.236]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.7878\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.48it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4635 | Val F1: 0.3826 | Gap: 0.0809 | EM: 0.2200\n",
      "✓ SAVED! Best F1: 0.3826\n",
      "\n",
      "======================================================================\n",
      "EPOCH 11/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.300]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.4993\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 19.47it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5057 | Val F1: 0.3913 | Gap: 0.1144 | EM: 0.2200\n",
      "✓ SAVED! Best F1: 0.3913\n",
      "\n",
      "======================================================================\n",
      "EPOCH 12/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=3.428]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.3333\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.43it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5159 | Val F1: 0.3904 | Gap: 0.1255 | EM: 0.2100\n",
      "\n",
      "======================================================================\n",
      "EPOCH 13/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=3.265]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.2151\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.11it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5414 | Val F1: 0.4458 | Gap: 0.0956 | EM: 0.2400\n",
      "✓ SAVED! Best F1: 0.4458\n",
      "\n",
      "======================================================================\n",
      "EPOCH 14/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.239]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.26it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5316 | Val F1: 0.4462 | Gap: 0.0854 | EM: 0.2633\n",
      "✓ SAVED! Best F1: 0.4462\n",
      "\n",
      "======================================================================\n",
      "EPOCH 15/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.960]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0457\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.32it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5707 | Val F1: 0.4559 | Gap: 0.1148 | EM: 0.2767\n",
      "✓ SAVED! Best F1: 0.4559\n",
      "\n",
      "======================================================================\n",
      "EPOCH 16/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=3.049]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.01it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5710 | Val F1: 0.4393 | Gap: 0.1317 | EM: 0.2667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 17/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=3.205]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9099\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:16<00:00, 18.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5620 | Val F1: 0.4592 | Gap: 0.1028 | EM: 0.2633\n",
      "✓ SAVED! Best F1: 0.4592\n",
      "\n",
      "======================================================================\n",
      "EPOCH 18/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.903]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8541\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6015 | Val F1: 0.4626 | Gap: 0.1389 | EM: 0.2733\n",
      "✓ SAVED! Best F1: 0.4626\n",
      "\n",
      "======================================================================\n",
      "EPOCH 19/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=3.073]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7948\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.36it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6459 | Val F1: 0.4733 | Gap: 0.1726 | EM: 0.2967\n",
      "✓ SAVED! Best F1: 0.4733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 20/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.952]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.73it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6379 | Val F1: 0.4528 | Gap: 0.1851 | EM: 0.2867\n",
      "\n",
      "======================================================================\n",
      "EPOCH 21/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.756]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7044\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.48it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6727 | Val F1: 0.4793 | Gap: 0.1935 | EM: 0.3100\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "✓ SAVED! Best F1: 0.4793\n",
      "\n",
      "======================================================================\n",
      "EPOCH 22/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.424]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6637\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.48it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6757 | Val F1: 0.4897 | Gap: 0.1860 | EM: 0.2933\n",
      "✓ SAVED! Best F1: 0.4897\n",
      "\n",
      "======================================================================\n",
      "EPOCH 23/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.567]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6243\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.66it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6966 | Val F1: 0.4820 | Gap: 0.2146 | EM: 0.2833\n",
      "\n",
      "======================================================================\n",
      "EPOCH 24/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.989]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5880\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.84it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6569 | Val F1: 0.4865 | Gap: 0.1704 | EM: 0.2933\n",
      "\n",
      "======================================================================\n",
      "EPOCH 25/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.610]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5570\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.23it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 23.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7237 | Val F1: 0.5080 | Gap: 0.2158 | EM: 0.3033\n",
      "✓ SAVED! Best F1: 0.5080\n",
      "\n",
      "======================================================================\n",
      "EPOCH 26/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.463]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5158\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.92it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7315 | Val F1: 0.5411 | Gap: 0.1903 | EM: 0.3300\n",
      "✓ SAVED! Best F1: 0.5411\n",
      "\n",
      "======================================================================\n",
      "EPOCH 27/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.835]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4913\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.12it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7182 | Val F1: 0.5233 | Gap: 0.1949 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 28/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.416]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.06it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7393 | Val F1: 0.5153 | Gap: 0.2240 | EM: 0.3133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 29/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.343]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4343\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7508 | Val F1: 0.5261 | Gap: 0.2248 | EM: 0.3233\n",
      "\n",
      "======================================================================\n",
      "EPOCH 30/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.332]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4073\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.31it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7256 | Val F1: 0.5087 | Gap: 0.2170 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 31/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.525]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3795\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.47it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7338 | Val F1: 0.5234 | Gap: 0.2104 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 32/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32: 100%|█████████████████| 1875/1875 [05:07<00:00,  6.10it/s, loss=2.158]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3556\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 22.00it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7889 | Val F1: 0.5628 | Gap: 0.2262 | EM: 0.3600\n",
      "✓ SAVED! Best F1: 0.5628\n",
      "\n",
      "======================================================================\n",
      "EPOCH 33/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.130]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3352\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8077 | Val F1: 0.5458 | Gap: 0.2619 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 34/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34: 100%|█████████████████| 1875/1875 [05:07<00:00,  6.11it/s, loss=2.372]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3140\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.74it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7949 | Val F1: 0.5689 | Gap: 0.2261 | EM: 0.3533\n",
      "✓ SAVED! Best F1: 0.5689\n",
      "\n",
      "======================================================================\n",
      "EPOCH 35/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35: 100%|█████████████████| 1875/1875 [05:07<00:00,  6.10it/s, loss=2.335]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.36it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7888 | Val F1: 0.5314 | Gap: 0.2573 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 36/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.378]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2723\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.88it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8371 | Val F1: 0.5970 | Gap: 0.2401 | EM: 0.3867\n",
      "✓ SAVED! Best F1: 0.5970\n",
      "\n",
      "======================================================================\n",
      "EPOCH 37/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.345]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2521\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.70it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8304 | Val F1: 0.5415 | Gap: 0.2889 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 38/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.218]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2339\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.54it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8139 | Val F1: 0.5611 | Gap: 0.2527 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 39/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.311]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2173\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8082 | Val F1: 0.5624 | Gap: 0.2458 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 40/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40: 100%|█████████████████| 1875/1875 [05:10<00:00,  6.04it/s, loss=2.316]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2012\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8311 | Val F1: 0.5633 | Gap: 0.2678 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 41/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41: 100%|█████████████████| 1875/1875 [05:09<00:00,  6.05it/s, loss=2.146]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1861\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.04it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8502 | Val F1: 0.5420 | Gap: 0.3081 | EM: 0.3400\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 42/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42: 100%|█████████████████| 1875/1875 [05:20<00:00,  5.85it/s, loss=2.193]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.32it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8547 | Val F1: 0.5549 | Gap: 0.2999 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 43/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43: 100%|█████████████████| 1875/1875 [05:52<00:00,  5.32it/s, loss=2.153]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1558\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.13it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8501 | Val F1: 0.5728 | Gap: 0.2773 | EM: 0.3600\n",
      "\n",
      "======================================================================\n",
      "EPOCH 44/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44: 100%|█████████████████| 1875/1875 [05:28<00:00,  5.71it/s, loss=2.039]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1366\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.32it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8576 | Val F1: 0.5657 | Gap: 0.2920 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 45/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.188]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1226\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8617 | Val F1: 0.5529 | Gap: 0.3088 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 46/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.088]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1096\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.36it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8763 | Val F1: 0.5487 | Gap: 0.3275 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 47/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.131]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.74it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8852 | Val F1: 0.5686 | Gap: 0.3166 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 48/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.140]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0855\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.39it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8821 | Val F1: 0.5512 | Gap: 0.3308 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 49/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.076]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0749\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8557 | Val F1: 0.5864 | Gap: 0.2692 | EM: 0.3700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 50/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.208]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.92it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8784 | Val F1: 0.5767 | Gap: 0.3017 | EM: 0.3767\n",
      "\n",
      "======================================================================\n",
      "EPOCH 51/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.031]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0481\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.87it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8639 | Val F1: 0.5717 | Gap: 0.2922 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 52/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.109]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0400\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.74it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9299 | Val F1: 0.5814 | Gap: 0.3485 | EM: 0.3733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 53/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.945]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0284\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.70it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8719 | Val F1: 0.5503 | Gap: 0.3216 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 54/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.904]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.72it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8972 | Val F1: 0.5542 | Gap: 0.3430 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 55/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.031]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0078\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.51it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9099 | Val F1: 0.5821 | Gap: 0.3278 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 56/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=2.024]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9977\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.37it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 33.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9058 | Val F1: 0.5922 | Gap: 0.3136 | EM: 0.3933\n",
      "\n",
      "======================================================================\n",
      "EPOCH 57/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.080]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9885\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.27it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9267 | Val F1: 0.5792 | Gap: 0.3475 | EM: 0.3700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 58/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.883]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9750\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.15it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9049 | Val F1: 0.5658 | Gap: 0.3391 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 59/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.996]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9696\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9253 | Val F1: 0.5831 | Gap: 0.3422 | EM: 0.3700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 60/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60: 100%|█████████████████| 1875/1875 [06:18<00:00,  4.95it/s, loss=2.047]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.28it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9264 | Val F1: 0.5798 | Gap: 0.3466 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 61/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61: 100%|█████████████████| 1875/1875 [05:30<00:00,  5.67it/s, loss=1.875]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9505\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.98it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9198 | Val F1: 0.5790 | Gap: 0.3407 | EM: 0.3533\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 62/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62: 100%|█████████████████| 1875/1875 [05:09<00:00,  6.06it/s, loss=2.016]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9436\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.53it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9123 | Val F1: 0.5960 | Gap: 0.3163 | EM: 0.3700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 63/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63: 100%|█████████████████| 1875/1875 [06:18<00:00,  4.95it/s, loss=1.986]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.75it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9285 | Val F1: 0.5795 | Gap: 0.3490 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 64/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64: 100%|█████████████████| 1875/1875 [06:19<00:00,  4.94it/s, loss=1.892]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9239\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.90it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9112 | Val F1: 0.5942 | Gap: 0.3171 | EM: 0.3733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 65/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65: 100%|█████████████████| 1875/1875 [05:13<00:00,  5.98it/s, loss=1.875]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9189\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.90it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9059 | Val F1: 0.5722 | Gap: 0.3337 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 66/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66: 100%|█████████████████| 1875/1875 [06:20<00:00,  4.93it/s, loss=1.954]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9098\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.28it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9237 | Val F1: 0.5654 | Gap: 0.3583 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 67/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.895]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9018\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.10it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9400 | Val F1: 0.5762 | Gap: 0.3638 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 68/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.807]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.72it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9594 | Val F1: 0.5812 | Gap: 0.3782 | EM: 0.3867\n",
      "\n",
      "======================================================================\n",
      "EPOCH 69/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=1.881]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8883\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.32it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9506 | Val F1: 0.5817 | Gap: 0.3689 | EM: 0.3800\n",
      "\n",
      "======================================================================\n",
      "EPOCH 70/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.11it/s, loss=1.880]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8830\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.65it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9564 | Val F1: 0.5923 | Gap: 0.3641 | EM: 0.3900\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS\n",
      "======================================================================\n",
      "Best Val F1: 59.7%\n",
      "Final Val F1: 59.2%\n",
      "Final EM: 39.0%\n",
      "Train-Val Gap: 0.3641\n",
      "Training for seed 1237\n",
      "Initializing token embeddings with GloVe...\n",
      "✓ Token embeddings initialized with GloVe\n",
      "Total parameters: 21.7M\n",
      "Trainable parameters: 21.7M\n",
      "\n",
      "======================================================================\n",
      "BASELINE (Standard LR)\n",
      "======================================================================\n",
      "\n",
      "Using differential LR: embeddings=0.1x, other=1.0x\n",
      "\n",
      "\n",
      "======================================================================\n",
      "EPOCH 1/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=6.964]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 8.6714\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:18<00:00, 10.56it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:30<00:00,  9.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0213 | Val F1: 0.0230 | Gap: -0.0018 | EM: 0.0133\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the\n",
      "  F1: 0.000\n",
      "✓ SAVED! Best F1: 0.0230\n",
      "\n",
      "======================================================================\n",
      "EPOCH 2/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=5.802]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.5676\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 13.30it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:27<00:00, 10.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1114 | Val F1: 0.0634 | Gap: 0.0480 | EM: 0.0233\n",
      "✓ SAVED! Best F1: 0.0634\n",
      "\n",
      "======================================================================\n",
      "EPOCH 3/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=5.531]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.1424\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 12.93it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:24<00:00, 12.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1559 | Val F1: 0.1247 | Gap: 0.0312 | EM: 0.0533\n",
      "✓ SAVED! Best F1: 0.1247\n",
      "\n",
      "======================================================================\n",
      "EPOCH 4/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=6.084]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.8299\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 12.88it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:28<00:00, 10.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2181 | Val F1: 0.1658 | Gap: 0.0523 | EM: 0.0633\n",
      "✓ SAVED! Best F1: 0.1658\n",
      "\n",
      "======================================================================\n",
      "EPOCH 5/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=5.736]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.5397\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:14<00:00, 14.26it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:26<00:00, 11.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2269 | Val F1: 0.1743 | Gap: 0.0526 | EM: 0.0700\n",
      "✓ SAVED! Best F1: 0.1743\n",
      "\n",
      "======================================================================\n",
      "EPOCH 6/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████████████| 1875/1875 [05:10<00:00,  6.05it/s, loss=5.207]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.2389\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 13.21it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:18<00:00, 16.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2420 | Val F1: 0.2099 | Gap: 0.0321 | EM: 0.0733\n",
      "✓ SAVED! Best F1: 0.2099\n",
      "\n",
      "======================================================================\n",
      "EPOCH 7/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=4.361]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.8616\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:11<00:00, 17.82it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:19<00:00, 15.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3254 | Val F1: 0.2561 | Gap: 0.0693 | EM: 0.1100\n",
      "✓ SAVED! Best F1: 0.2561\n",
      "\n",
      "======================================================================\n",
      "EPOCH 8/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8: 100%|██████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=3.901]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.3936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.11it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4090 | Val F1: 0.3330 | Gap: 0.0759 | EM: 0.1600\n",
      "✓ SAVED! Best F1: 0.3330\n",
      "\n",
      "======================================================================\n",
      "EPOCH 9/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=3.188]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.9592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 20.54it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4370 | Val F1: 0.3607 | Gap: 0.0763 | EM: 0.1767\n",
      "✓ SAVED! Best F1: 0.3607\n",
      "\n",
      "======================================================================\n",
      "EPOCH 10/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.455]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.6263\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.65it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4797 | Val F1: 0.3950 | Gap: 0.0847 | EM: 0.2067\n",
      "✓ SAVED! Best F1: 0.3950\n",
      "\n",
      "======================================================================\n",
      "EPOCH 11/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11: 100%|█████████████████| 1875/1875 [05:12<00:00,  6.00it/s, loss=3.218]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.4158\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 19.24it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:16<00:00, 18.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5122 | Val F1: 0.4108 | Gap: 0.1014 | EM: 0.2300\n",
      "✓ SAVED! Best F1: 0.4108\n",
      "\n",
      "======================================================================\n",
      "EPOCH 12/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12: 100%|█████████████████| 1875/1875 [06:35<00:00,  4.74it/s, loss=3.450]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.2855\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.33it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5302 | Val F1: 0.4397 | Gap: 0.0904 | EM: 0.2667\n",
      "✓ SAVED! Best F1: 0.4397\n",
      "\n",
      "======================================================================\n",
      "EPOCH 13/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13: 100%|█████████████████| 1875/1875 [06:50<00:00,  4.57it/s, loss=3.342]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1780\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 19.93it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:17<00:00, 16.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5096 | Val F1: 0.4283 | Gap: 0.0813 | EM: 0.2400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 14/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|█████████████████| 1875/1875 [06:47<00:00,  4.60it/s, loss=3.094]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0959\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 19.34it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:16<00:00, 18.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5553 | Val F1: 0.4425 | Gap: 0.1127 | EM: 0.2533\n",
      "✓ SAVED! Best F1: 0.4425\n",
      "\n",
      "======================================================================\n",
      "EPOCH 15/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15: 100%|█████████████████| 1875/1875 [06:48<00:00,  4.59it/s, loss=2.942]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0225\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.19it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5552 | Val F1: 0.4522 | Gap: 0.1030 | EM: 0.2667\n",
      "✓ SAVED! Best F1: 0.4522\n",
      "\n",
      "======================================================================\n",
      "EPOCH 16/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16: 100%|█████████████████| 1875/1875 [06:48<00:00,  4.59it/s, loss=2.928]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9576\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.98it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5876 | Val F1: 0.4694 | Gap: 0.1181 | EM: 0.2833\n",
      "✓ SAVED! Best F1: 0.4694\n",
      "\n",
      "======================================================================\n",
      "EPOCH 17/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17: 100%|█████████████████| 1875/1875 [06:48<00:00,  4.59it/s, loss=3.010]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8977\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:10<00:00, 18.29it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5757 | Val F1: 0.4803 | Gap: 0.0955 | EM: 0.2867\n",
      "✓ SAVED! Best F1: 0.4803\n",
      "\n",
      "======================================================================\n",
      "EPOCH 18/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18: 100%|█████████████████| 1875/1875 [05:42<00:00,  5.47it/s, loss=2.958]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8424\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.69it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5734 | Val F1: 0.4606 | Gap: 0.1127 | EM: 0.2800\n",
      "\n",
      "======================================================================\n",
      "EPOCH 19/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|█████████████████| 1875/1875 [06:42<00:00,  4.65it/s, loss=3.011]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.88it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5904 | Val F1: 0.4624 | Gap: 0.1280 | EM: 0.2667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 20/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20: 100%|█████████████████| 1875/1875 [06:40<00:00,  4.68it/s, loss=2.977]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7448\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6371 | Val F1: 0.4540 | Gap: 0.1831 | EM: 0.2700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 21/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21: 100%|█████████████████| 1875/1875 [06:43<00:00,  4.64it/s, loss=2.562]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6977\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.55it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6633 | Val F1: 0.4969 | Gap: 0.1665 | EM: 0.3100\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: manually suppress the fire\n",
      "  F1: 1.000\n",
      "✓ SAVED! Best F1: 0.4969\n",
      "\n",
      "======================================================================\n",
      "EPOCH 22/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22: 100%|█████████████████| 1875/1875 [06:35<00:00,  4.75it/s, loss=2.438]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6585\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.20it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6899 | Val F1: 0.4765 | Gap: 0.2134 | EM: 0.2833\n",
      "\n",
      "======================================================================\n",
      "EPOCH 23/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23: 100%|█████████████████| 1875/1875 [05:43<00:00,  5.46it/s, loss=2.547]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6275\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.40it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6707 | Val F1: 0.4940 | Gap: 0.1767 | EM: 0.3033\n",
      "\n",
      "======================================================================\n",
      "EPOCH 24/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24: 100%|█████████████████| 1875/1875 [06:46<00:00,  4.62it/s, loss=2.740]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 20.11it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6884 | Val F1: 0.4886 | Gap: 0.1998 | EM: 0.2833\n",
      "\n",
      "======================================================================\n",
      "EPOCH 25/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25: 100%|█████████████████| 1875/1875 [06:45<00:00,  4.62it/s, loss=2.703]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5553\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.88it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7044 | Val F1: 0.4832 | Gap: 0.2212 | EM: 0.2933\n",
      "\n",
      "======================================================================\n",
      "EPOCH 26/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26: 100%|█████████████████| 1875/1875 [06:45<00:00,  4.62it/s, loss=2.782]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5209\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.77it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7343 | Val F1: 0.5168 | Gap: 0.2174 | EM: 0.3300\n",
      "✓ SAVED! Best F1: 0.5168\n",
      "\n",
      "======================================================================\n",
      "EPOCH 27/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27: 100%|█████████████████| 1875/1875 [06:40<00:00,  4.68it/s, loss=2.232]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4927\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 20.30it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7055 | Val F1: 0.5199 | Gap: 0.1857 | EM: 0.3167\n",
      "✓ SAVED! Best F1: 0.5199\n",
      "\n",
      "======================================================================\n",
      "EPOCH 28/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28: 100%|█████████████████| 1875/1875 [06:52<00:00,  4.55it/s, loss=2.404]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4599\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.29it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7314 | Val F1: 0.5154 | Gap: 0.2160 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 29/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=2.526]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4339\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.66it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7337 | Val F1: 0.5072 | Gap: 0.2265 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 30/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30: 100%|█████████████████| 1875/1875 [06:47<00:00,  4.60it/s, loss=2.253]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4083\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.82it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7647 | Val F1: 0.5104 | Gap: 0.2543 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 31/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31: 100%|█████████████████| 1875/1875 [06:47<00:00,  4.60it/s, loss=2.461]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3835\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.79it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7413 | Val F1: 0.5208 | Gap: 0.2206 | EM: 0.3400\n",
      "✓ SAVED! Best F1: 0.5208\n",
      "\n",
      "======================================================================\n",
      "EPOCH 32/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=2.192]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3616\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.97it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7487 | Val F1: 0.5455 | Gap: 0.2032 | EM: 0.3333\n",
      "✓ SAVED! Best F1: 0.5455\n",
      "\n",
      "======================================================================\n",
      "EPOCH 33/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33: 100%|█████████████████| 1875/1875 [06:54<00:00,  4.53it/s, loss=2.500]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3396\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.52it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8058 | Val F1: 0.5345 | Gap: 0.2712 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 34/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34: 100%|█████████████████| 1875/1875 [06:20<00:00,  4.93it/s, loss=2.577]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3166\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.92it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7606 | Val F1: 0.5291 | Gap: 0.2316 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 35/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35: 100%|█████████████████| 1875/1875 [06:50<00:00,  4.57it/s, loss=2.185]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2933\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.90it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7854 | Val F1: 0.5353 | Gap: 0.2501 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 36/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36: 100%|█████████████████| 1875/1875 [06:55<00:00,  4.52it/s, loss=2.356]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2774\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.07it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8212 | Val F1: 0.5467 | Gap: 0.2746 | EM: 0.3433\n",
      "✓ SAVED! Best F1: 0.5467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 37/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37: 100%|█████████████████| 1875/1875 [06:55<00:00,  4.52it/s, loss=2.148]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2595\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.26it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8370 | Val F1: 0.5524 | Gap: 0.2846 | EM: 0.3667\n",
      "✓ SAVED! Best F1: 0.5524\n",
      "\n",
      "======================================================================\n",
      "EPOCH 38/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38: 100%|█████████████████| 1875/1875 [06:53<00:00,  4.53it/s, loss=2.187]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2408\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.09it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8317 | Val F1: 0.5540 | Gap: 0.2777 | EM: 0.3733\n",
      "✓ SAVED! Best F1: 0.5540\n",
      "\n",
      "======================================================================\n",
      "EPOCH 39/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39: 100%|█████████████████| 1875/1875 [06:55<00:00,  4.51it/s, loss=2.149]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.39it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8392 | Val F1: 0.5071 | Gap: 0.3322 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 40/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40: 100%|█████████████████| 1875/1875 [06:54<00:00,  4.52it/s, loss=2.219]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2018\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.71it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8460 | Val F1: 0.5355 | Gap: 0.3106 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 41/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41: 100%|█████████████████| 1875/1875 [06:53<00:00,  4.53it/s, loss=2.149]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1842\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.11it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8412 | Val F1: 0.5592 | Gap: 0.2820 | EM: 0.3600\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "✓ SAVED! Best F1: 0.5592\n",
      "\n",
      "======================================================================\n",
      "EPOCH 42/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42: 100%|█████████████████| 1875/1875 [06:58<00:00,  4.48it/s, loss=1.907]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1724\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.17it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8399 | Val F1: 0.5454 | Gap: 0.2945 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 43/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43: 100%|█████████████████| 1875/1875 [06:55<00:00,  4.51it/s, loss=2.190]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1580\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.47it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8339 | Val F1: 0.5341 | Gap: 0.2998 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 44/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44: 100%|█████████████████| 1875/1875 [06:22<00:00,  4.90it/s, loss=1.970]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.21it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8443 | Val F1: 0.5485 | Gap: 0.2958 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 45/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45: 100%|█████████████████| 1875/1875 [06:38<00:00,  4.70it/s, loss=2.143]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1305\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.45it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8745 | Val F1: 0.5266 | Gap: 0.3479 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 46/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46: 100%|█████████████████| 1875/1875 [06:52<00:00,  4.54it/s, loss=2.115]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1166\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.09it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8675 | Val F1: 0.5474 | Gap: 0.3201 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 47/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=2.063]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.37it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8668 | Val F1: 0.5550 | Gap: 0.3117 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 48/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.55it/s, loss=1.984]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.97it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8522 | Val F1: 0.5643 | Gap: 0.2879 | EM: 0.3633\n",
      "✓ SAVED! Best F1: 0.5643\n",
      "\n",
      "======================================================================\n",
      "EPOCH 49/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=2.120]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.75it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8653 | Val F1: 0.5586 | Gap: 0.3067 | EM: 0.3700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 50/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.55it/s, loss=2.133]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0666\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.35it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8566 | Val F1: 0.5511 | Gap: 0.3055 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 51/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=2.067]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0542\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.09it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8915 | Val F1: 0.5589 | Gap: 0.3327 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 52/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52: 100%|█████████████████| 1875/1875 [06:29<00:00,  4.82it/s, loss=2.172]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0447\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9109 | Val F1: 0.5405 | Gap: 0.3704 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 53/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.025]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0310\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.90it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8761 | Val F1: 0.5541 | Gap: 0.3220 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 54/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.102]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0213\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.79it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8935 | Val F1: 0.5497 | Gap: 0.3438 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 55/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.016]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0099\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.54it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 33.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8819 | Val F1: 0.5576 | Gap: 0.3243 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 56/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=1.882]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0007\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.37it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9069 | Val F1: 0.5639 | Gap: 0.3430 | EM: 0.3600\n",
      "\n",
      "======================================================================\n",
      "EPOCH 57/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=1.991]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.27it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9001 | Val F1: 0.5977 | Gap: 0.3024 | EM: 0.3800\n",
      "✓ SAVED! Best F1: 0.5977\n",
      "\n",
      "======================================================================\n",
      "EPOCH 58/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=1.916]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9809\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.33it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8997 | Val F1: 0.5706 | Gap: 0.3291 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 59/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.037]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9734\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.35it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9207 | Val F1: 0.5650 | Gap: 0.3558 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 60/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=1.891]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9635\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.45it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9309 | Val F1: 0.5747 | Gap: 0.3562 | EM: 0.3733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 61/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.817]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9551\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.78it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9203 | Val F1: 0.5801 | Gap: 0.3402 | EM: 0.3633\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: manually suppress the fire\n",
      "  F1: 1.000\n",
      "\n",
      "======================================================================\n",
      "EPOCH 62/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.884]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.46it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9273 | Val F1: 0.5674 | Gap: 0.3599 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 63/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.881]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9390\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.32it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9178 | Val F1: 0.5950 | Gap: 0.3228 | EM: 0.3900\n",
      "\n",
      "======================================================================\n",
      "EPOCH 64/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.793]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9292\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.66it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9329 | Val F1: 0.5801 | Gap: 0.3528 | EM: 0.3800\n",
      "\n",
      "======================================================================\n",
      "EPOCH 65/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.973]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9210\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.70it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9323 | Val F1: 0.5641 | Gap: 0.3682 | EM: 0.3633\n",
      "\n",
      "======================================================================\n",
      "EPOCH 66/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=2.112]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9124\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.54it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9332 | Val F1: 0.5671 | Gap: 0.3661 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 67/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.13it/s, loss=1.886]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9055\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.14it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 34.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9352 | Val F1: 0.5682 | Gap: 0.3670 | EM: 0.3767\n",
      "\n",
      "======================================================================\n",
      "EPOCH 68/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.961]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8981\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.25it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9452 | Val F1: 0.5580 | Gap: 0.3873 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 69/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=1.874]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.30it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9372 | Val F1: 0.5789 | Gap: 0.3583 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 70/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.076]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8847\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.19it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9448 | Val F1: 0.5785 | Gap: 0.3663 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS\n",
      "======================================================================\n",
      "Best Val F1: 59.8%\n",
      "Final Val F1: 57.9%\n",
      "Final EM: 35.7%\n",
      "Train-Val Gap: 0.3663\n",
      "Training for seed 1238\n",
      "Initializing token embeddings with GloVe...\n",
      "✓ Token embeddings initialized with GloVe\n",
      "Total parameters: 21.7M\n",
      "Trainable parameters: 21.7M\n",
      "\n",
      "======================================================================\n",
      "BASELINE (Standard LR)\n",
      "======================================================================\n",
      "\n",
      "Using differential LR: embeddings=0.1x, other=1.0x\n",
      "\n",
      "\n",
      "======================================================================\n",
      "EPOCH 1/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=6.806]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 8.5689\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 13.01it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:30<00:00,  9.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0286 | Val F1: 0.0340 | Gap: -0.0054 | EM: 0.0200\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the\n",
      "  F1: 0.000\n",
      "✓ SAVED! Best F1: 0.0340\n",
      "\n",
      "======================================================================\n",
      "EPOCH 2/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=6.051]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.5646\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 12.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:24<00:00, 12.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.0749 | Val F1: 0.0794 | Gap: -0.0045 | EM: 0.0200\n",
      "✓ SAVED! Best F1: 0.0794\n",
      "\n",
      "======================================================================\n",
      "EPOCH 3/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=6.294]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 6.1519\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 14.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:23<00:00, 12.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1665 | Val F1: 0.1275 | Gap: 0.0389 | EM: 0.0467\n",
      "✓ SAVED! Best F1: 0.1275\n",
      "\n",
      "======================================================================\n",
      "EPOCH 4/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=6.301]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.8517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:14<00:00, 13.69it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:23<00:00, 12.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.1784 | Val F1: 0.1304 | Gap: 0.0481 | EM: 0.0533\n",
      "✓ SAVED! Best F1: 0.1304\n",
      "\n",
      "======================================================================\n",
      "EPOCH 5/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.14it/s, loss=5.634]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.5862\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:13<00:00, 15.22it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:20<00:00, 14.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2301 | Val F1: 0.1621 | Gap: 0.0680 | EM: 0.0500\n",
      "✓ SAVED! Best F1: 0.1621\n",
      "\n",
      "======================================================================\n",
      "EPOCH 6/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=5.701]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.3205\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:15<00:00, 13.09it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:24<00:00, 12.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.2211 | Val F1: 0.1757 | Gap: 0.0454 | EM: 0.0633\n",
      "✓ SAVED! Best F1: 0.1757\n",
      "\n",
      "======================================================================\n",
      "EPOCH 7/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=4.767]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 5.0229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:11<00:00, 16.71it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:21<00:00, 14.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3214 | Val F1: 0.2109 | Gap: 0.1105 | EM: 0.0900\n",
      "✓ SAVED! Best F1: 0.2109\n",
      "\n",
      "======================================================================\n",
      "EPOCH 8/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=4.626]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.6428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.76it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3990 | Val F1: 0.2841 | Gap: 0.1150 | EM: 0.1267\n",
      "✓ SAVED! Best F1: 0.2841\n",
      "\n",
      "======================================================================\n",
      "EPOCH 9/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9: 100%|██████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=4.428]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 4.1883\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:12<00:00, 16.51it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:21<00:00, 13.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.3820 | Val F1: 0.2816 | Gap: 0.1004 | EM: 0.1133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 10/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=3.462]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.7834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:12<00:00, 15.49it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:18<00:00, 16.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4629 | Val F1: 0.3516 | Gap: 0.1113 | EM: 0.1633\n",
      "✓ SAVED! Best F1: 0.3516\n",
      "\n",
      "======================================================================\n",
      "EPOCH 11/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.508]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.4977\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.45it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4371 | Val F1: 0.4220 | Gap: 0.0151 | EM: 0.2433\n",
      "✓ SAVED! Best F1: 0.4220\n",
      "\n",
      "======================================================================\n",
      "EPOCH 12/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.423]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.3229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.83it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 19.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.4919 | Val F1: 0.3813 | Gap: 0.1106 | EM: 0.2167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 13/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.265]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1980\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:11<00:00, 16.87it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:17<00:00, 17.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5218 | Val F1: 0.4277 | Gap: 0.0940 | EM: 0.2333\n",
      "✓ SAVED! Best F1: 0.4277\n",
      "\n",
      "======================================================================\n",
      "EPOCH 14/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.931]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.1045\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.86it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5532 | Val F1: 0.4129 | Gap: 0.1403 | EM: 0.2433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 15/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.287]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 3.0259\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.93it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5698 | Val F1: 0.4039 | Gap: 0.1659 | EM: 0.2300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 16/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.834]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.9575\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.13it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5887 | Val F1: 0.4849 | Gap: 0.1039 | EM: 0.2833\n",
      "✓ SAVED! Best F1: 0.4849\n",
      "\n",
      "======================================================================\n",
      "EPOCH 17/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.106]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8935\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.80it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 21.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.5961 | Val F1: 0.4622 | Gap: 0.1339 | EM: 0.2767\n",
      "\n",
      "======================================================================\n",
      "EPOCH 18/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=3.218]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.8362\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.25it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6269 | Val F1: 0.4618 | Gap: 0.1651 | EM: 0.2533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 19/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.454]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7837\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:09<00:00, 21.85it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6135 | Val F1: 0.4751 | Gap: 0.1384 | EM: 0.2767\n",
      "\n",
      "======================================================================\n",
      "EPOCH 20/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.510]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.7395\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.23it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:14<00:00, 20.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6353 | Val F1: 0.4920 | Gap: 0.1432 | EM: 0.2967\n",
      "✓ SAVED! Best F1: 0.4920\n",
      "\n",
      "======================================================================\n",
      "EPOCH 21/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.787]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.62it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:16<00:00, 18.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6525 | Val F1: 0.4849 | Gap: 0.1676 | EM: 0.2767\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: the fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 22/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.740]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6515\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:05<00:00, 33.92it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6670 | Val F1: 0.4971 | Gap: 0.1698 | EM: 0.3033\n",
      "✓ SAVED! Best F1: 0.4971\n",
      "\n",
      "======================================================================\n",
      "EPOCH 23/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.847]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.6143\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.66it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 21.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.6978 | Val F1: 0.4710 | Gap: 0.2268 | EM: 0.2800\n",
      "\n",
      "======================================================================\n",
      "EPOCH 24/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24: 100%|█████████████████| 1875/1875 [05:06<00:00,  6.12it/s, loss=2.699]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5778\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.51it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7145 | Val F1: 0.5108 | Gap: 0.2037 | EM: 0.3200\n",
      "✓ SAVED! Best F1: 0.5108\n",
      "\n",
      "======================================================================\n",
      "EPOCH 25/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25: 100%|█████████████████| 1875/1875 [05:05<00:00,  6.13it/s, loss=2.565]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5397\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.27it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 25.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7256 | Val F1: 0.5130 | Gap: 0.2126 | EM: 0.3133\n",
      "✓ SAVED! Best F1: 0.5130\n",
      "\n",
      "======================================================================\n",
      "EPOCH 26/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.604]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.5089\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.55it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7021 | Val F1: 0.5197 | Gap: 0.1825 | EM: 0.3167\n",
      "✓ SAVED! Best F1: 0.5197\n",
      "\n",
      "======================================================================\n",
      "EPOCH 27/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.300]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4792\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.03it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7140 | Val F1: 0.5004 | Gap: 0.2136 | EM: 0.2833\n",
      "\n",
      "======================================================================\n",
      "EPOCH 28/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.446]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4533\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.90it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:13<00:00, 22.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7305 | Val F1: 0.5029 | Gap: 0.2277 | EM: 0.2967\n",
      "\n",
      "======================================================================\n",
      "EPOCH 29/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.359]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.4260\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.60it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7560 | Val F1: 0.5451 | Gap: 0.2109 | EM: 0.3300\n",
      "✓ SAVED! Best F1: 0.5451\n",
      "\n",
      "======================================================================\n",
      "EPOCH 30/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.380]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3982\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.55it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:15<00:00, 18.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7542 | Val F1: 0.5353 | Gap: 0.2189 | EM: 0.3133\n",
      "\n",
      "======================================================================\n",
      "EPOCH 31/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.488]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3726\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 24.04it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7503 | Val F1: 0.5316 | Gap: 0.2186 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 32/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=2.497]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.91it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7951 | Val F1: 0.5430 | Gap: 0.2520 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 33/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.307]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.12it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 27.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7891 | Val F1: 0.5394 | Gap: 0.2496 | EM: 0.3233\n",
      "\n",
      "======================================================================\n",
      "EPOCH 34/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.378]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.3040\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 23.50it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7853 | Val F1: 0.5357 | Gap: 0.2495 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 35/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.139]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2828\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.38it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.7864 | Val F1: 0.5632 | Gap: 0.2232 | EM: 0.3500\n",
      "✓ SAVED! Best F1: 0.5632\n",
      "\n",
      "======================================================================\n",
      "EPOCH 36/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.293]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2659\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.97it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8161 | Val F1: 0.5587 | Gap: 0.2574 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 37/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.304]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2433\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.08it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8249 | Val F1: 0.5331 | Gap: 0.2918 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 38/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.233]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2279\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:05<00:00, 36.13it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 36.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8290 | Val F1: 0.5425 | Gap: 0.2866 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 39/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=2.486]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.2096\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.53it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8070 | Val F1: 0.5556 | Gap: 0.2513 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 40/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.409]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.21it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8130 | Val F1: 0.5388 | Gap: 0.2742 | EM: 0.3200\n",
      "\n",
      "======================================================================\n",
      "EPOCH 41/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.213]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1787\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.97it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8574 | Val F1: 0.5388 | Gap: 0.3186 | EM: 0.3500\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: manually suppress the fire\n",
      "  F1: 1.000\n",
      "\n",
      "======================================================================\n",
      "EPOCH 42/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.280]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1591\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:08<00:00, 22.41it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 23.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8630 | Val F1: 0.5704 | Gap: 0.2926 | EM: 0.3600\n",
      "✓ SAVED! Best F1: 0.5704\n",
      "\n",
      "======================================================================\n",
      "EPOCH 43/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=2.144]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1448\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.68it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8300 | Val F1: 0.5376 | Gap: 0.2924 | EM: 0.3167\n",
      "\n",
      "======================================================================\n",
      "EPOCH 44/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=2.020]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1330\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 27.42it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8600 | Val F1: 0.5458 | Gap: 0.3143 | EM: 0.3300\n",
      "\n",
      "======================================================================\n",
      "EPOCH 45/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.215]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1172\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.85it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 34.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8934 | Val F1: 0.5491 | Gap: 0.3443 | EM: 0.3433\n",
      "\n",
      "======================================================================\n",
      "EPOCH 46/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=2.122]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.1021\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.13it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8807 | Val F1: 0.5829 | Gap: 0.2978 | EM: 0.3667\n",
      "✓ SAVED! Best F1: 0.5829\n",
      "\n",
      "======================================================================\n",
      "EPOCH 47/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.16it/s, loss=2.167]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0904\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 33.24it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 33.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9078 | Val F1: 0.5700 | Gap: 0.3378 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 48/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.062]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.49it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 28.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8624 | Val F1: 0.5637 | Gap: 0.2987 | EM: 0.3733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 49/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=1.937]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.58it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 27.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8730 | Val F1: 0.5658 | Gap: 0.3072 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 50/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=2.159]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0542\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.59it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8921 | Val F1: 0.5462 | Gap: 0.3460 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 51/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=2.027]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.90it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8905 | Val F1: 0.5717 | Gap: 0.3188 | EM: 0.3733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 52/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52: 100%|█████████████████| 1875/1875 [05:59<00:00,  5.22it/s, loss=2.040]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0303\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.12it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9202 | Val F1: 0.5769 | Gap: 0.3433 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 53/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53: 100%|█████████████████| 1875/1875 [06:45<00:00,  4.62it/s, loss=1.917]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0213\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 32.41it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 33.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9014 | Val F1: 0.5593 | Gap: 0.3421 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 54/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54: 100%|█████████████████| 1875/1875 [06:45<00:00,  4.62it/s, loss=1.956]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0099\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 26.65it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 25.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9042 | Val F1: 0.5479 | Gap: 0.3564 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 55/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55: 100%|█████████████████| 1875/1875 [06:44<00:00,  4.64it/s, loss=1.984]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 2.0016\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.33it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8907 | Val F1: 0.5633 | Gap: 0.3274 | EM: 0.3467\n",
      "\n",
      "======================================================================\n",
      "EPOCH 56/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56: 100%|█████████████████| 1875/1875 [06:45<00:00,  4.62it/s, loss=2.099]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9885\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.72it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:12<00:00, 24.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9130 | Val F1: 0.5854 | Gap: 0.3276 | EM: 0.3600\n",
      "✓ SAVED! Best F1: 0.5854\n",
      "\n",
      "======================================================================\n",
      "EPOCH 57/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57: 100%|█████████████████| 1875/1875 [06:47<00:00,  4.60it/s, loss=1.921]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9806\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.89it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.8916 | Val F1: 0.5727 | Gap: 0.3189 | EM: 0.3533\n",
      "\n",
      "======================================================================\n",
      "EPOCH 58/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58: 100%|█████████████████| 1875/1875 [05:53<00:00,  5.30it/s, loss=2.060]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9723\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.52it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9094 | Val F1: 0.5706 | Gap: 0.3388 | EM: 0.3700\n",
      "\n",
      "======================================================================\n",
      "EPOCH 59/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59: 100%|█████████████████| 1875/1875 [06:51<00:00,  4.56it/s, loss=1.967]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9623\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.93it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9178 | Val F1: 0.5735 | Gap: 0.3443 | EM: 0.3667\n",
      "\n",
      "======================================================================\n",
      "EPOCH 60/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60: 100%|█████████████████| 1875/1875 [06:50<00:00,  4.57it/s, loss=2.031]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9507\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.54it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9077 | Val F1: 0.5523 | Gap: 0.3553 | EM: 0.3333\n",
      "\n",
      "======================================================================\n",
      "EPOCH 61/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61: 100%|█████████████████| 1875/1875 [06:48<00:00,  4.59it/s, loss=1.893]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.74it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 33.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9280 | Val F1: 0.5732 | Gap: 0.3547 | EM: 0.3467\n",
      "\n",
      "Sample:\n",
      "  Q: After the operators are warned by the escape of the steam, w...\n",
      "  True: manually suppress the fire\n",
      "  Pred: fire\n",
      "  F1: 0.500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 62/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62: 100%|█████████████████| 1875/1875 [06:49<00:00,  4.58it/s, loss=2.012]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9339\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 31.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9455 | Val F1: 0.5648 | Gap: 0.3807 | EM: 0.3567\n",
      "\n",
      "======================================================================\n",
      "EPOCH 63/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63: 100%|█████████████████| 1875/1875 [06:49<00:00,  4.58it/s, loss=1.946]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9274\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 29.82it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:10<00:00, 29.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9371 | Val F1: 0.5510 | Gap: 0.3861 | EM: 0.3500\n",
      "\n",
      "======================================================================\n",
      "EPOCH 64/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64: 100%|█████████████████| 1875/1875 [06:49<00:00,  4.57it/s, loss=1.796]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 28.95it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9307 | Val F1: 0.5515 | Gap: 0.3792 | EM: 0.3267\n",
      "\n",
      "======================================================================\n",
      "EPOCH 65/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65: 100%|█████████████████| 1875/1875 [06:48<00:00,  4.59it/s, loss=1.962]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9094\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 30.94it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 30.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9426 | Val F1: 0.5915 | Gap: 0.3510 | EM: 0.3833\n",
      "✓ SAVED! Best F1: 0.5915\n",
      "\n",
      "======================================================================\n",
      "EPOCH 66/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66: 100%|█████████████████| 1875/1875 [06:04<00:00,  5.14it/s, loss=1.900]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.9045\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 28.48it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:09<00:00, 32.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9250 | Val F1: 0.5884 | Gap: 0.3365 | EM: 0.3767\n",
      "\n",
      "======================================================================\n",
      "EPOCH 67/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=1.874]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:07<00:00, 25.66it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:11<00:00, 26.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9302 | Val F1: 0.5547 | Gap: 0.3754 | EM: 0.3367\n",
      "\n",
      "======================================================================\n",
      "EPOCH 68/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=1.942]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8894\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:05<00:00, 33.74it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 35.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9464 | Val F1: 0.5606 | Gap: 0.3858 | EM: 0.3400\n",
      "\n",
      "======================================================================\n",
      "EPOCH 69/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69: 100%|█████████████████| 1875/1875 [05:04<00:00,  6.17it/s, loss=1.900]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8831\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:06<00:00, 31.15it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 34.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9417 | Val F1: 0.5767 | Gap: 0.3650 | EM: 0.3733\n",
      "\n",
      "======================================================================\n",
      "EPOCH 70/70\n",
      "======================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70: 100%|█████████████████| 1875/1875 [05:03<00:00,  6.17it/s, loss=1.844]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loss: 1.8731\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████████| 200/200 [00:05<00:00, 33.63it/s]\n",
      "Eval: 100%|███████████████████████████████████| 300/300 [00:08<00:00, 33.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train F1: 0.9415 | Val F1: 0.5912 | Gap: 0.3502 | EM: 0.3600\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS\n",
      "======================================================================\n",
      "Best Val F1: 59.2%\n",
      "Final Val F1: 59.1%\n",
      "Final EM: 36.0%\n",
      "Train-Val Gap: 0.3502\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "SQuAD Answer Generation with GloVe Embeddings + Q/K Hypothesis Testing\n",
    "\n",
    "EXPECTED PERFORMANCE:\n",
    "- With GloVe embeddings: 40-55% F1 ✓ (observed: best val F1 59.2% after 70 epochs)\n",
    "- Training time: ~5-7 min/epoch on one GPU (~6-8 hours for 70 epochs)\n",
    "- Can reach 50%+ with Q/K hypothesis\n",
    "\"\"\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader, Subset\n",
    "from transformers import GPT2Tokenizer\n",
    "import json\n",
    "from collections import Counter\n",
    "import string\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import os\n",
    "import urllib.request\n",
    "import zipfile\n",
    "\n",
    "# Configuration\n",
    "TEST_QK_HYPOTHESIS = False  # Set True after baseline completes\n",
    "QK_LR_MULTIPLIER = 2.5  # Q/K learn 2.5x faster\n",
    "\n",
    "# Optimized for GloVe embeddings\n",
    "D_MODEL = 300  # Match GloVe dimension exactly (glove.6B.300d)\n",
    "N_HEADS = 6  # D_MODEL must be divisible by N_HEADS: 300 / 6 = 50 dims per head\n",
    "N_LAYERS = 6\n",
    "D_FF = 1200  # 4 * D_MODEL\n",
    "MAX_SEQ_LEN = 256\n",
    "MAX_ANSWER_LEN = 50\n",
    "DROPOUT = 0.2\n",
    "BATCH_SIZE = 32\n",
    "ACCUMULATION_STEPS = 2\n",
    "EFFECTIVE_BATCH_SIZE = BATCH_SIZE * ACCUMULATION_STEPS  # 32 * 2 = 64 samples per optimizer step\n",
    "BASE_LR = 5e-4\n",
    "WARMUP_STEPS = 1000\n",
    "NUM_EPOCHS = 70\n",
    "GRAD_CLIP = 0.5\n",
    "WEIGHT_DECAY = 0.05\n",
    "LABEL_SMOOTHING = 0.1\n",
    "TRAIN_SUBSET_SIZE = 60000  # More data with GloVe; 60000 / 32 = 1875 batches per epoch\n",
    "VAL_SUBSET_SIZE = 10000\n",
    "\n",
    "\n",
    "def download_and_extract_glove():\n",
    "    \"\"\"Make sure glove.6B.300d.txt exists locally, downloading/extracting it if needed.\n",
    "\n",
    "    Returns:\n",
    "        Path of the 300d GloVe text file, or None if it could not be\n",
    "        downloaded or extracted (the reason is printed).\n",
    "    \"\"\"\n",
    "    glove_file = 'glove.6B.300d.txt'\n",
    "    \n",
    "    if os.path.exists(glove_file):\n",
    "        print(f\"✓ GloVe embeddings found: {glove_file}\")\n",
    "        return glove_file\n",
    "    \n",
    "    print(\"\\n\" + \"=\"*70)\n",
    "    print(\"DOWNLOADING GLOVE EMBEDDINGS\")\n",
    "    print(\"=\"*70)\n",
    "    \n",
    "    zip_file = 'glove.6B.zip'\n",
    "    \n",
    "    if not os.path.exists(zip_file):\n",
    "        print(\"Downloading GloVe 6B (822MB)... This may take a few minutes\")\n",
    "        url = 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.zip'\n",
    "        # Stream into a temporary '.part' file and move it into place only once\n",
    "        # the download finished; an interrupted download would otherwise leave a\n",
    "        # truncated glove.6B.zip that the exists() check above treats as complete.\n",
    "        partial_file = zip_file + '.part'\n",
    "        \n",
    "        try:\n",
    "            with urllib.request.urlopen(url) as response:\n",
    "                total_size = int(response.headers.get('content-length', 0))\n",
    "                \n",
    "                with open(partial_file, 'wb') as f, tqdm(\n",
    "                    total=total_size, unit='B', unit_scale=True, desc='Downloading'\n",
    "                ) as pbar:\n",
    "                    while True:\n",
    "                        chunk = response.read(8192)\n",
    "                        if not chunk:\n",
    "                            break\n",
    "                        f.write(chunk)\n",
    "                        pbar.update(len(chunk))\n",
    "            \n",
    "            os.replace(partial_file, zip_file)\n",
    "            print(\"✓ Download complete!\")\n",
    "        except Exception as e:\n",
    "            # Best effort: drop the partial file, explain, and let the caller decide.\n",
    "            if os.path.exists(partial_file):\n",
    "                os.remove(partial_file)\n",
    "            print(f\"Download failed: {e}\")\n",
    "            print(\"\\nAlternative: Download manually from:\")\n",
    "            print(\"  https://nlp.stanford.edu/projects/glove/\")\n",
    "            print(\"  or https://huggingface.co/stanfordnlp/glove\")\n",
    "            return None\n",
    "    \n",
    "    # Extract\n",
    "    if os.path.exists(zip_file):\n",
    "        print(\"Extracting GloVe embeddings...\")\n",
    "        try:\n",
    "            with zipfile.ZipFile(zip_file, 'r') as zip_ref:\n",
    "                # Only extract the 300d file we need\n",
    "                zip_ref.extract(glove_file)\n",
    "        except (zipfile.BadZipFile, KeyError) as e:\n",
    "            # Corrupt archive or missing member: report it and fall through to\n",
    "            # the 'not found' message below instead of crashing the notebook.\n",
    "            print(f\"⚠ Could not extract {glove_file} from {zip_file}: {e}\")\n",
    "        else:\n",
    "            print(\"✓ Extraction complete!\")\n",
    "        \n",
    "        # Optionally remove zip to save space\n",
    "        # os.remove(zip_file)\n",
    "    \n",
    "    if os.path.exists(glove_file):\n",
    "        return glove_file\n",
    "    else:\n",
    "        print(\"⚠ GloVe file not found after extraction\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def load_glove_embeddings(glove_file, tokenizer, embedding_dim=300):\n",
    "    \"\"\"Load GloVe and create embedding matrix for GPT-2 tokenizer.\n",
    "\n",
    "    Every tokenizer id below ``tokenizer.vocab_size`` gets one row. A row is\n",
    "    copied from GloVe when the token matches a GloVe word (as-is, lowercased,\n",
    "    or with the BPE markers 'Ġ'/'Ċ' removed and lowercased); otherwise it keeps\n",
    "    a random N(0, 0.1) initialisation.\n",
    "\n",
    "    Args:\n",
    "        glove_file: path to a GloVe text file (one 'word v1 v2 ...' per line).\n",
    "        tokenizer: tokenizer exposing ``vocab_size`` and ``get_vocab()``.\n",
    "        embedding_dim: vector width; must match the GloVe file's dimension.\n",
    "\n",
    "    Returns:\n",
    "        torch.FloatTensor of shape (tokenizer.vocab_size, embedding_dim).\n",
    "    \"\"\"\n",
    "    print(\"\\n\" + \"=\"*70)\n",
    "    print(\"LOADING GLOVE EMBEDDINGS\")\n",
    "    print(\"=\"*70)\n",
    "    \n",
    "    # Load GloVe vectors\n",
    "    print(\"Reading GloVe file (this takes ~1 minute)...\")\n",
    "    glove_vectors = {}\n",
    "    \n",
    "    with open(glove_file, 'r', encoding='utf-8') as f:\n",
    "        # total=400000 only sizes the progress bar (glove.6B has 400k words)\n",
    "        for line in tqdm(f, total=400000, desc=\"Loading GloVe\"):\n",
    "            values = line.rstrip().split(' ')\n",
    "            word = values[0]\n",
    "            vector = np.asarray(values[1:], dtype='float32')\n",
    "            glove_vectors[word] = vector\n",
    "    \n",
    "    print(f\"✓ Loaded {len(glove_vectors):,} GloVe vectors\")\n",
    "    \n",
    "    # Create embedding matrix for tokenizer vocabulary\n",
    "    vocab_size = tokenizer.vocab_size\n",
    "    embedding_matrix = np.random.normal(0, 0.1, (vocab_size, embedding_dim)).astype('float32')\n",
    "    \n",
    "    # Match tokenizer vocab with GloVe\n",
    "    print(\"Matching tokenizer vocabulary with GloVe...\")\n",
    "    matched = 0\n",
    "    \n",
    "    for token, idx in tqdm(tokenizer.get_vocab().items(), desc=\"Matching\"):\n",
    "        # get_vocab() can include added tokens with ids >= vocab_size; they have\n",
    "        # no row in embedding_matrix, so skip them instead of raising IndexError.\n",
    "        if idx >= vocab_size:\n",
    "            continue\n",
    "        \n",
    "        # Strip GPT-2 byte-level BPE markers ('Ġ' = leading space, 'Ċ' = newline)\n",
    "        token_clean = token.replace('Ġ', '').replace('Ċ', '').lower().strip()\n",
    "        \n",
    "        # First candidate found in GloVe wins; unmatched tokens keep their random row.\n",
    "        for candidate in (token, token.lower(), token_clean):\n",
    "            if candidate in glove_vectors:\n",
    "                embedding_matrix[idx] = glove_vectors[candidate]\n",
    "                matched += 1\n",
    "                break\n",
    "    \n",
    "    match_rate = 100 * matched / vocab_size\n",
    "    print(f\"✓ Matched {matched:,}/{vocab_size:,} tokens ({match_rate:.1f}%)\")\n",
    "    print(\"=\"*70 + \"\\n\")\n",
    "    \n",
    "    return torch.FloatTensor(embedding_matrix)\n",
    "\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"Multi-head scaled dot-product attention with separate Q, K, V and output projections.\n",
    "\n",
    "    With ``save_attention=True`` the post-softmax, pre-dropout attention map is\n",
    "    stored (detached) in ``last_attention_weights``.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model, n_heads, dropout=0.1):\n",
    "        super().__init__()\n",
    "        assert d_model % n_heads == 0\n",
    "        self.n_heads = n_heads\n",
    "        self.d_k = d_model // n_heads  # width of each head\n",
    "        \n",
    "        # Projections are created in q, k, v, out order; this fixes their random-init order.\n",
    "        self.q_linear = nn.Linear(d_model, d_model)\n",
    "        self.k_linear = nn.Linear(d_model, d_model)\n",
    "        self.v_linear = nn.Linear(d_model, d_model)\n",
    "        self.out = nn.Linear(d_model, d_model)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.last_attention_weights = None\n",
    "    \n",
    "    def _split_heads(self, proj, tensor, batch):\n",
    "        \"\"\"Apply ``proj`` to ``tensor`` and reshape to (batch, n_heads, seq, d_k).\"\"\"\n",
    "        projected = proj(tensor).view(batch, -1, self.n_heads, self.d_k)\n",
    "        return projected.transpose(1, 2)\n",
    "    \n",
    "    def forward(self, q, k, v, mask=None, save_attention=False):\n",
    "        batch = q.size(0)\n",
    "        \n",
    "        queries = self._split_heads(self.q_linear, q, batch)\n",
    "        keys = self._split_heads(self.k_linear, k, batch)\n",
    "        values = self._split_heads(self.v_linear, v, batch)\n",
    "        \n",
    "        scale = self.d_k ** 0.5\n",
    "        logits = torch.matmul(queries, keys.transpose(-2, -1)) / scale\n",
    "        \n",
    "        # Positions where mask == 0 get a large negative logit, i.e. ~0 weight after softmax.\n",
    "        if mask is not None:\n",
    "            logits = logits.masked_fill(mask == 0, -1e9)\n",
    "        \n",
    "        weights = torch.softmax(logits, dim=-1)\n",
    "        if save_attention:\n",
    "            self.last_attention_weights = weights.detach()\n",
    "        \n",
    "        mixed = torch.matmul(self.dropout(weights), values)\n",
    "        # (batch, n_heads, seq, d_k) -> (batch, seq, n_heads * d_k)\n",
    "        merged = mixed.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.d_k)\n",
    "        return self.out(merged)\n",
    "\n",
    "\n",
    "class DecoderLayer(nn.Module):\n",
    "    \"\"\"Pre-norm transformer block: self-attention (optionally masked), then a GELU\n",
    "    feed-forward, each wrapped in a residual connection.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)\n",
    "        self.ff = nn.Sequential(\n",
    "            nn.Linear(d_model, d_ff),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(d_ff, d_model),\n",
    "            nn.Dropout(dropout)\n",
    "        )\n",
    "        self.norm1 = nn.LayerNorm(d_model)\n",
    "        self.norm2 = nn.LayerNorm(d_model)\n",
    "        \n",
    "    def forward(self, x, mask=None, save_attention=False):\n",
    "        # Pre-norm: normalise once and feed the same tensor as query, key and value\n",
    "        # (calling norm1 three times on the same x yields the same tensor at 3x the cost).\n",
    "        normed = self.norm1(x)\n",
    "        x = x + self.self_attn(normed, normed, normed, mask, save_attention)\n",
    "        x = x + self.ff(self.norm2(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "class GPTAnswerGenerator(nn.Module):\n",
    "    \"\"\"Decoder-only transformer with tied input/output token embeddings.\n",
    "\n",
    "    Token embeddings can be seeded from ``pretrained_embeddings`` (e.g. the GloVe\n",
    "    matrix), which must have shape (vocab_size, d_model). Every other parameter\n",
    "    with two or more dimensions gets scaled Xavier-uniform initialisation.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len, dropout=0.1, pretrained_embeddings=None):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Keep this construction order (token emb, position emb, layers, norm, output):\n",
    "        # it determines how the random number generator is consumed during init.\n",
    "        self.token_embedding = nn.Embedding(vocab_size, d_model)\n",
    "        \n",
    "        if pretrained_embeddings is not None:\n",
    "            print(\"Initializing token embeddings with GloVe...\")\n",
    "            with torch.no_grad():\n",
    "                self.token_embedding.weight.copy_(pretrained_embeddings)\n",
    "            print(\"✓ Token embeddings initialized with GloVe\")\n",
    "        \n",
    "        self.position_embedding = nn.Embedding(max_seq_len, d_model)\n",
    "        self.emb_dropout = nn.Dropout(dropout)\n",
    "        \n",
    "        blocks = []\n",
    "        for _ in range(n_layers):\n",
    "            blocks.append(DecoderLayer(d_model, n_heads, d_ff, dropout))\n",
    "        self.layers = nn.ModuleList(blocks)\n",
    "        \n",
    "        self.norm = nn.LayerNorm(d_model)\n",
    "        self.output = nn.Linear(d_model, vocab_size)\n",
    "        # Tie the output projection to the token embedding matrix.\n",
    "        self.output.weight = self.token_embedding.weight\n",
    "        \n",
    "        self._init_weights()\n",
    "    \n",
    "    def _init_weights(self):\n",
    "        \"\"\"Xavier-init every >=2-D parameter except the (possibly GloVe) token embedding.\n",
    "\n",
    "        named_parameters() de-duplicates shared tensors, so the tied output weight is\n",
    "        reported under 'token_embedding.weight' and is skipped as well.\n",
    "        \"\"\"\n",
    "        gain = 1 / np.sqrt(2)\n",
    "        for param_name, param in self.named_parameters():\n",
    "            if param.dim() > 1 and 'token_embedding' not in param_name:\n",
    "                nn.init.xavier_uniform_(param, gain=gain)\n",
    "    \n",
    "    def forward(self, x, mask=None, save_attention=False):\n",
    "        \"\"\"Map token ids ``x`` (assumed (batch, seq)) to vocabulary logits (batch, seq, vocab_size).\n",
    "\n",
    "        seq must not exceed ``max_seq_len`` (the size of the position-embedding table).\n",
    "        \"\"\"\n",
    "        positions = torch.arange(x.size(1), device=x.device)[None, :]\n",
    "        hidden = self.emb_dropout(self.token_embedding(x) + self.position_embedding(positions))\n",
    "        for block in self.layers:\n",
    "            hidden = block(hidden, mask, save_attention)\n",
    "        return self.output(self.norm(hidden))\n",
    "    \n",
    "    def get_attention_weights(self):\n",
    "        \"\"\"Return each layer's last saved attention map (None where none was saved).\"\"\"\n",
    "        return [block.self_attn.last_attention_weights for block in self.layers]\n",
    "\n",
    "\n",
    "class SQuADDataset(Dataset):\n",
    "    \"\"\"Answerable SQuAD v2 questions formatted for causal-LM answer generation.\n",
    "\n",
    "    Each example is ``\"Q: <question> C: <context window> A:\"`` followed by\n",
    "    ``\" <answer>\"`` and EOS. Only the answer/EOS positions carry labels; prompt\n",
    "    and padding positions are -100 so cross-entropy ignores them.\n",
    "    \"\"\"\n",
    "    def __init__(self, data_path, tokenizer, max_len, max_ans_len):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_len = max_len\n",
    "        self.max_ans_len = max_ans_len\n",
    "        self.data = []\n",
    "        \n",
    "        # SQuAD is UTF-8; be explicit so loading does not depend on the platform locale.\n",
    "        with open(data_path, 'r', encoding='utf-8') as f:\n",
    "            squad = json.load(f)\n",
    "        \n",
    "        for article in squad['data']:\n",
    "            for para in article['paragraphs']:\n",
    "                ctx = para['context']\n",
    "                for qa in para['qas']:\n",
    "                    # Keep answerable questions only, with their first gold answer.\n",
    "                    if not qa['is_impossible'] and qa['answers']:\n",
    "                        ans = qa['answers'][0]['text']\n",
    "                        ans_start = qa['answers'][0]['answer_start']\n",
    "                        \n",
    "                        # Extract relevant context window: up to 200 chars either side of the answer\n",
    "                        start = max(0, ans_start - 200)\n",
    "                        end = min(len(ctx), ans_start + len(ans) + 200)\n",
    "                        \n",
    "                        self.data.append({\n",
    "                            'context': ctx[start:end],\n",
    "                            'question': qa['question'],\n",
    "                            'answer': ans\n",
    "                        })\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return {'input_ids', 'labels'} tensors of length max_len for example idx.\"\"\"\n",
    "        item = self.data[idx]\n",
    "        \n",
    "        # Format: Q: question C: context A: answer\n",
    "        prefix = f\"Q: {item['question']} C: {item['context']} A:\"\n",
    "        answer = f\" {item['answer']}\"\n",
    "        \n",
    "        # Budget: prompt <= max_len - max_ans_len - 2 tokens, answer <= max_ans_len + EOS.\n",
    "        prefix_ids = self.tokenizer.encode(prefix, max_length=self.max_len-self.max_ans_len-2, \n",
    "                                          truncation=True, add_special_tokens=False)\n",
    "        answer_ids = self.tokenizer.encode(answer, max_length=self.max_ans_len, \n",
    "                                          truncation=True, add_special_tokens=False)\n",
    "        answer_ids.append(self.tokenizer.eos_token_id)\n",
    "        \n",
    "        input_ids = (prefix_ids + answer_ids)[:self.max_len]\n",
    "        labels = ([-100] * len(prefix_ids) + answer_ids)[:self.max_len]\n",
    "        \n",
    "        # Right-pad in one step; padded positions are ignored by the loss.\n",
    "        n_pad = self.max_len - len(input_ids)\n",
    "        input_ids += [self.tokenizer.pad_token_id] * n_pad\n",
    "        labels += [-100] * n_pad\n",
    "        \n",
    "        return {\n",
    "            'input_ids': torch.tensor(input_ids),\n",
    "            'labels': torch.tensor(labels)\n",
    "        }\n",
    "\n",
    "\n",
    "def create_mask(seq_len, device):\n",
    "    \"\"\"Boolean causal mask of shape (1, 1, seq_len, seq_len): True where key <= query.\"\"\"\n",
    "    lower = torch.ones(seq_len, seq_len, device=device).tril()\n",
    "    return lower.bool().view(1, 1, seq_len, seq_len)\n",
    "\n",
    "\n",
    "def normalize_answer(s):\n",
    "    \"\"\"Normalise an answer string as the official SQuAD evaluation script does.\n",
    "\n",
    "    Lower-case, strip punctuation, drop the articles a/an/the, collapse whitespace.\n",
    "    Punctuation is stripped *before* articles (the official order), so e.g. \"a-b\"\n",
    "    becomes \"ab\" instead of \"b\"; scores then match the official F1/EM definition.\n",
    "    \"\"\"\n",
    "    s = s.lower()\n",
    "    s = ''.join(c for c in s if c not in string.punctuation)\n",
    "    s = re.sub(r'\\b(a|an|the)\\b', ' ', s)\n",
    "    return ' '.join(s.split())\n",
    "\n",
    "\n",
    "def f1_score(pred, truth):\n",
    "    \"\"\"Token-overlap F1 between a predicted and a gold answer, after normalize_answer.\n",
    "\n",
    "    If either side normalises to no tokens, the score is 1 only when both are empty.\n",
    "    \"\"\"\n",
    "    pred_tokens = normalize_answer(pred).split()\n",
    "    gold_tokens = normalize_answer(truth).split()\n",
    "    \n",
    "    if len(pred_tokens) == 0 or len(gold_tokens) == 0:\n",
    "        return int(pred_tokens == gold_tokens)\n",
    "    \n",
    "    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())\n",
    "    if overlap == 0:\n",
    "        return 0\n",
    "    \n",
    "    precision = overlap / len(pred_tokens)\n",
    "    recall = overlap / len(gold_tokens)\n",
    "    return (2 * precision * recall) / (precision + recall)\n",
    "\n",
    "\n",
    "def exact_match(pred, truth):\n",
    "    \"\"\"Return 1 when pred and truth are identical after normalize_answer, else 0.\"\"\"\n",
    "    same = normalize_answer(pred) == normalize_answer(truth)\n",
    "    return 1 if same else 0\n",
    "\n",
    "\n",
    "def _apply_accumulated_step(model, opt, sched):\n",
    "    \"\"\"Clip accumulated gradients to GRAD_CLIP, step optimizer and scheduler, clear grads.\"\"\"\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
    "    opt.step()\n",
    "    sched.step()\n",
    "    opt.zero_grad()\n",
    "\n",
    "\n",
    "def train_epoch(model, loader, opt, sched, device, epoch):\n",
    "    \"\"\"Train for one epoch with teacher forcing and gradient accumulation.\n",
    "\n",
    "    Uses the globals LABEL_SMOOTHING, ACCUMULATION_STEPS and GRAD_CLIP. The optimizer\n",
    "    and scheduler step every ACCUMULATION_STEPS batches, and once more at the end of\n",
    "    the epoch if a partial window is left (those gradients used to be discarded by\n",
    "    the next epoch's zero_grad()).\n",
    "\n",
    "    Returns:\n",
    "        Mean per-batch loss (before the 1/ACCUMULATION_STEPS scaling).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    opt.zero_grad()\n",
    "    pending = False  # True while gradients are accumulated but not yet applied\n",
    "    \n",
    "    pbar = tqdm(loader, desc=f\"Epoch {epoch}\")\n",
    "    for i, batch in enumerate(pbar):\n",
    "        inp = batch['input_ids'].to(device)\n",
    "        lbl = batch['labels'].to(device)\n",
    "        \n",
    "        mask = create_mask(inp.size(1), device)\n",
    "        logits = model(inp, mask)\n",
    "        \n",
    "        # Shift for next-token prediction: logits at position t are scored against label t+1.\n",
    "        loss = nn.functional.cross_entropy(\n",
    "            logits[:, :-1].reshape(-1, logits.size(-1)),\n",
    "            lbl[:, 1:].reshape(-1),\n",
    "            ignore_index=-100,\n",
    "            label_smoothing=LABEL_SMOOTHING\n",
    "        )\n",
    "        \n",
    "        (loss / ACCUMULATION_STEPS).backward()\n",
    "        pending = True\n",
    "        \n",
    "        if (i + 1) % ACCUMULATION_STEPS == 0:\n",
    "            _apply_accumulated_step(model, opt, sched)\n",
    "            pending = False\n",
    "        \n",
    "        loss_value = loss.item()  # one host sync per batch instead of two\n",
    "        total_loss += loss_value\n",
    "        pbar.set_postfix({'loss': f'{loss_value:.3f}'})\n",
    "    \n",
    "    # Apply a trailing partial accumulation window instead of silently dropping it.\n",
    "    if pending:\n",
    "        _apply_accumulated_step(model, opt, sched)\n",
    "    \n",
    "    return total_loss / max(len(loader), 1)\n",
    "\n",
    "\n",
    "def generate(model, tokenizer, context, question, device, max_len=50):\n",
    "    \"\"\"Greedily decode an answer to ``question`` given ``context``.\n",
    "\n",
    "    The prompt is truncated to MAX_SEQ_LEN - max_len - 5 tokens; decoding stops at\n",
    "    EOS, after ``max_len`` new tokens, or once the sequence reaches MAX_SEQ_LEN.\n",
    "    Returns the decoded continuation with special tokens removed and whitespace stripped.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    seq = tokenizer.encode(\n",
    "        f\"Q: {question} C: {context} A:\",\n",
    "        max_length=MAX_SEQ_LEN - max_len - 5,\n",
    "        truncation=True,\n",
    "        add_special_tokens=False,\n",
    "        return_tensors='pt',\n",
    "    ).to(device)\n",
    "    prompt_len = seq.size(1)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        steps = 0\n",
    "        while steps < max_len and seq.size(1) < MAX_SEQ_LEN:\n",
    "            steps += 1\n",
    "            out = model(seq, create_mask(seq.size(1), device))\n",
    "            token = out[:, -1].argmax(-1, keepdim=True)\n",
    "            seq = torch.cat([seq, token], 1)\n",
    "            if token.item() == tokenizer.eos_token_id:\n",
    "                break\n",
    "    \n",
    "    answer_ids = seq[0, prompt_len:]\n",
    "    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()\n",
    "\n",
    "\n",
    "def evaluate(model, dataset, tokenizer, device, n_samples=300):\n",
    "    \"\"\"Mean SQuAD F1 and exact match of greedy answers on the first ``n_samples`` examples.\n",
    "\n",
    "    ``dataset`` is either an object with a ``.data`` list of\n",
    "    {'context', 'question', 'answer'} dicts or a torch ``Subset`` of one (examples\n",
    "    are taken in the Subset's index order). Returns {'f1': 0.0, 'em': 0.0} when\n",
    "    there is nothing to evaluate instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    if isinstance(dataset, Subset):\n",
    "        items = [dataset.dataset.data[dataset.indices[i]] for i in range(min(n_samples, len(dataset)))]\n",
    "    else:\n",
    "        items = dataset.data[:n_samples]\n",
    "    \n",
    "    if not items:\n",
    "        return {'f1': 0.0, 'em': 0.0}\n",
    "    \n",
    "    f1_sum = em_sum = 0\n",
    "    for item in tqdm(items, desc=\"Eval\"):\n",
    "        pred = generate(model, tokenizer, item['context'], item['question'], device)\n",
    "        f1_sum += f1_score(pred, item['answer'])\n",
    "        em_sum += exact_match(pred, item['answer'])\n",
    "    \n",
    "    return {'f1': f1_sum / len(items), 'em': em_sum / len(items)}\n",
    "\n",
    "\n",
    "def analyze_attention(model, dataset, tokenizer, device, n=30):\n",
    "    \"\"\"Average attention peakiness over the first ``n`` examples of ``dataset``.\n",
    "\n",
    "    One forward pass per prompt stores every layer's attention map. For each stored\n",
    "    query row the largest weight over keys is taken; the prompt's score is the mean of\n",
    "    those maxima (higher = attention concentrated on fewer positions).\n",
    "\n",
    "    Returns the mean score over prompts, or 0 if no attention maps were captured.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    scores = []\n",
    "    \n",
    "    if isinstance(dataset, Subset):\n",
    "        items = [dataset.dataset.data[dataset.indices[i]] for i in range(min(n, len(dataset)))]\n",
    "    else:\n",
    "        items = dataset.data[:n]\n",
    "    \n",
    "    for item in items:\n",
    "        prompt = f\"Q: {item['question']} C: {item['context']} A:\"\n",
    "        ids = tokenizer.encode(prompt, max_length=MAX_SEQ_LEN-MAX_ANSWER_LEN, \n",
    "                              truncation=True, add_special_tokens=False, return_tensors='pt').to(device)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            mask = create_mask(ids.size(1), device)\n",
    "            model(ids, mask, save_attention=True)\n",
    "            weights = [w for w in model.get_attention_weights() if w is not None]\n",
    "        \n",
    "        if weights:\n",
    "            # Batch element 0 of every layer, stacked; the last dim is the key axis\n",
    "            # (softmax is taken over it). Leading dims presumably (layers, heads, query) -- TODO confirm.\n",
    "            stacked = torch.stack([w[0] for w in weights])\n",
    "            # Not a plain mean: softmax rows sum to 1, so the mean of a whole attention\n",
    "            # matrix is always 1/seq_len no matter where the model looks.\n",
    "            scores.append(stacked.max(dim=-1).values.mean().item())\n",
    "    \n",
    "    return np.mean(scores) if scores else 0\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    \n",
    "    print(\"=\"*70)\n",
    "    print(\"SQUAD ANSWER GENERATION WITH GLOVE EMBEDDINGS\")\n",
    "    print(\"=\"*70)\n",
    "    print(f\"Model: {N_LAYERS}L, {D_MODEL}d, {N_HEADS}h\")\n",
    "    print(f\"Device: {device}\")\n",
    "    print(\"=\"*70 + \"\\n\")\n",
    "    \n",
    "    # Download and load GloVe\n",
    "    glove_file = download_and_extract_glove()\n",
    "    \n",
    "    if glove_file is None:\n",
    "        print(\"\\n WARNING: Could not load GloVe embeddings\")\n",
    "        print(\"Proceeding without pretrained embeddings (expect 15-25% F1)\")\n",
    "    \n",
    "    # Download SQuAD datasets\n",
    "    for name in ['train-v2.0.json', 'dev-v2.0.json']:\n",
    "        if not os.path.exists(name):\n",
    "            print(f\"Downloading {name}...\")\n",
    "            urllib.request.urlretrieve(\n",
    "                f'https://rajpurkar.github.io/SQuAD-explorer/dataset/{name}', name)\n",
    "    \n",
    "    # Setup tokenizer\n",
    "    print(\"Loading tokenizer...\")\n",
    "    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    \n",
    "    # Load GloVe embeddings for tokenizer (None -> randomly initialised token embeddings)\n",
    "    if glove_file:\n",
    "        pretrained_embeddings = load_glove_embeddings(glove_file, tokenizer, D_MODEL)\n",
    "    else:\n",
    "        pretrained_embeddings = None\n",
    "    \n",
    "    # Load datasets\n",
    "    print(\"Loading datasets...\")\n",
    "    full_train = SQuADDataset('train-v2.0.json', tokenizer, MAX_SEQ_LEN, MAX_ANSWER_LEN)\n",
    "    full_val = SQuADDataset('dev-v2.0.json', tokenizer, MAX_SEQ_LEN, MAX_ANSWER_LEN)\n",
    "    \n",
    "    # Draw the subsets from a fixed generator so reruns of the notebook train and\n",
    "    # evaluate on the same examples (all seeds below already share this one split).\n",
    "    split_gen = torch.Generator().manual_seed(0)\n",
    "    train_ds = Subset(full_train, torch.randperm(len(full_train), generator=split_gen)[:TRAIN_SUBSET_SIZE])\n",
    "    val_ds = Subset(full_val, torch.randperm(len(full_val), generator=split_gen)[:VAL_SUBSET_SIZE])\n",
    "    \n",
    "    print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}\\n\")\n",
    "    \n",
    "    loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)\n",
    "    \n",
    "    # train_epoch steps the scheduler once per optimizer step (every ACCUMULATION_STEPS\n",
    "    # batches), not once per batch, so OneCycleLR is sized in optimizer steps; sizing it\n",
    "    # in batches left the LR far from annealed when ACCUMULATION_STEPS > 1. Rounded up\n",
    "    # so a partial final window per epoch is also covered.\n",
    "    steps_per_epoch = (len(loader) + ACCUMULATION_STEPS - 1) // ACCUMULATION_STEPS\n",
    "    total_opt_steps = steps_per_epoch * NUM_EPOCHS\n",
    "    \n",
    "    # Model\n",
    "    print(\"Initializing model...\")\n",
    "    # NOTE: after this loop `model` holds the LAST seed's final-epoch weights (not the\n",
    "    # best checkpoint saved to disk); later cells evaluate that object.\n",
    "    n_seed = [1234,1235,1236,1237,1238]\n",
    "    for seed_ in n_seed:\n",
    "        print(\"Training for seed\", seed_)\n",
    "        torch.manual_seed(seed_)\n",
    "        model = GPTAnswerGenerator(\n",
    "            vocab_size=tokenizer.vocab_size,\n",
    "            d_model=D_MODEL,\n",
    "            n_heads=N_HEADS,\n",
    "            n_layers=N_LAYERS,\n",
    "            d_ff=D_FF,\n",
    "            max_seq_len=MAX_SEQ_LEN,\n",
    "            dropout=DROPOUT,\n",
    "            pretrained_embeddings=pretrained_embeddings\n",
    "        ).to(device)\n",
    "\n",
    "        total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
    "        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6\n",
    "        print(f\"Total parameters: {total_params:.1f}M\")\n",
    "        print(f\"Trainable parameters: {trainable_params:.1f}M\\n\")\n",
    "\n",
    "        # Optimizer with differential learning rates for embeddings\n",
    "        if TEST_QK_HYPOTHESIS:\n",
    "            print(\"=\"*70)\n",
    "            print(f\"TESTING Q/K HYPOTHESIS - Q/K LR = {QK_LR_MULTIPLIER}x\")\n",
    "            print(\"=\"*70 + \"\\n\")\n",
    "\n",
    "            qk = [p for n, p in model.named_parameters() if 'q_linear' in n or 'k_linear' in n]\n",
    "            other = [p for n, p in model.named_parameters() if 'q_linear' not in n and 'k_linear' not in n]\n",
    "\n",
    "            print(f\"Q/K params: {sum(p.numel() for p in qk)/1e6:.1f}M\")\n",
    "            print(f\"Other params: {sum(p.numel() for p in other)/1e6:.1f}M\\n\")\n",
    "\n",
    "            opt = torch.optim.AdamW([\n",
    "                {'params': qk, 'lr': BASE_LR * QK_LR_MULTIPLIER, 'weight_decay': WEIGHT_DECAY},\n",
    "                {'params': other, 'lr': BASE_LR, 'weight_decay': WEIGHT_DECAY}\n",
    "            ])\n",
    "            max_lr = [BASE_LR * QK_LR_MULTIPLIER, BASE_LR]\n",
    "        else:\n",
    "            print(\"=\"*70)\n",
    "            print(\"BASELINE (Standard LR)\")\n",
    "            print(\"=\"*70 + \"\\n\")\n",
    "\n",
    "            # Use lower LR for pretrained embeddings if they exist\n",
    "            if pretrained_embeddings is not None:\n",
    "                embedding_params = [model.token_embedding.weight]\n",
    "                other_params = [p for n, p in model.named_parameters() if 'token_embedding' not in n]\n",
    "\n",
    "                opt = torch.optim.AdamW([\n",
    "                    {'params': embedding_params, 'lr': BASE_LR * 0.1, 'weight_decay': 0},  # Fine-tune slowly\n",
    "                    {'params': other_params, 'lr': BASE_LR, 'weight_decay': WEIGHT_DECAY}\n",
    "                ])\n",
    "                max_lr = [BASE_LR * 0.1, BASE_LR]\n",
    "\n",
    "                print(\"Using differential LR: embeddings=0.1x, other=1.0x\\n\")\n",
    "            else:\n",
    "                opt = torch.optim.AdamW(model.parameters(), BASE_LR, weight_decay=WEIGHT_DECAY)\n",
    "                max_lr = BASE_LR\n",
    "\n",
    "        # One scheduler for both branches; list-valued max_lr follows the param-group order above.\n",
    "        sched = torch.optim.lr_scheduler.OneCycleLR(\n",
    "            opt,\n",
    "            max_lr=max_lr,\n",
    "            total_steps=total_opt_steps,\n",
    "            pct_start=WARMUP_STEPS / total_opt_steps\n",
    "        )\n",
    "\n",
    "        # Train\n",
    "        best_f1 = 0\n",
    "        results = {'loss': [], 'train_f1': [], 'val_f1': [], 'val_em': [], 'attn': []}\n",
    "\n",
    "        for e in range(NUM_EPOCHS):\n",
    "            print(f\"\\n{'='*70}\")\n",
    "            print(f\"EPOCH {e+1}/{NUM_EPOCHS}\")\n",
    "            print('='*70)\n",
    "\n",
    "            loss = train_epoch(model, loader, opt, sched, device, e+1)\n",
    "            results['loss'].append(loss)\n",
    "            print(f\"\\nLoss: {loss:.4f}\")\n",
    "\n",
    "            # Eval\n",
    "            train_m = evaluate(model, train_ds, tokenizer, device, 200)\n",
    "            val_m = evaluate(model, val_ds, tokenizer, device, 300)\n",
    "\n",
    "            results['train_f1'].append(train_m['f1'])\n",
    "            results['val_f1'].append(val_m['f1'])\n",
    "            results['val_em'].append(val_m['em'])\n",
    "\n",
    "            gap = train_m['f1'] - val_m['f1']\n",
    "            print(f\"Train F1: {train_m['f1']:.4f} | Val F1: {val_m['f1']:.4f} | Gap: {gap:.4f} | EM: {val_m['em']:.4f}\")\n",
    "\n",
    "            # Sample\n",
    "            if e % 20 == 0:\n",
    "                item = val_ds.dataset.data[val_ds.indices[0]]\n",
    "                pred = generate(model, tokenizer, item['context'], item['question'], device)\n",
    "                print(f\"\\nSample:\")\n",
    "                print(f\"  Q: {item['question'][:60]}...\")\n",
    "                print(f\"  True: {item['answer']}\")\n",
    "                print(f\"  Pred: {pred}\")\n",
    "                print(f\"  F1: {f1_score(pred, item['answer']):.3f}\")\n",
    "\n",
    "            # Attention\n",
    "            if e % 20 == 0 and TEST_QK_HYPOTHESIS:\n",
    "                attn = analyze_attention(model, val_ds, tokenizer, device)\n",
    "                results['attn'].append(attn)\n",
    "                print(f\"Attention: {attn:.4f}\")\n",
    "\n",
    "            # Save best\n",
    "            if val_m['f1'] > best_f1:\n",
    "                best_f1 = val_m['f1']\n",
    "                name = 'best_qk_'+str(seed_)+'.pt' if TEST_QK_HYPOTHESIS else 'best_baseline_'+str(seed_)+'.pt'\n",
    "                torch.save({'model': model.state_dict(), 'f1': best_f1, 'epoch': e+1}, name)\n",
    "                print(f\"✓ SAVED! Best F1: {best_f1:.4f}\")\n",
    "\n",
    "        # Final\n",
    "        print(f\"\\n{'='*70}\")\n",
    "        print(\"FINAL RESULTS\")\n",
    "        print('='*70)\n",
    "        print(f\"Best Val F1: {best_f1*100:.1f}%\")\n",
    "        print(f\"Final Val F1: {results['val_f1'][-1]*100:.1f}%\")\n",
    "        print(f\"Final EM: {results['val_em'][-1]*100:.1f}%\")\n",
    "        print(f\"Train-Val Gap: {results['train_f1'][-1] - results['val_f1'][-1]:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4197bcc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|███████████████████████████████| 40000/40000 [22:35<00:00, 29.52it/s]\n"
     ]
    }
   ],
   "source": [
    "# Eval: greedy-decode up to 40000 training examples with the `model` left by the\n",
    "# training cell (last seed's final-epoch weights, not a saved best checkpoint).\n",
    "train_m = evaluate(model, train_ds, tokenizer, device, 40000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14dcc4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (F1, EM) of the training-set evaluation; each is a mean over examples in [0, 1]\n",
    "train_m[\"f1\"],train_m[\"em\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01114815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation evaluation; for a Subset, evaluate() caps n_samples at len(val_ds)\n",
    "val_m = evaluate(model, val_ds, tokenizer, device, 20000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c37c3907",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (F1, EM) on the validation subset; each is a mean over examples in [0, 1]\n",
    "val_m[\"f1\"],val_m[\"em\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae5121cc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c57fa4a8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50ed7f6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5633d680",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Simple Analysis Functions for Notebook\n",
    "\n",
    "IMPORTANT: Your dataset needs answer span information!\n",
    "\n",
    "First, reload your validation dataset with this code:\n",
    "    val_dataset_full = load_squad_with_spans('dev-v2.0.json', tokenizer)\n",
    "    \n",
    "Then run analysis:\n",
    "    results = analyze_model(model, val_dataset_full, tokenizer, device, n_samples=100)\n",
    "\"\"\"\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from collections import Counter\n",
    "import string\n",
    "import re\n",
    "import json\n",
    "\n",
    "\n",
    "def load_squad_with_spans(data_path, tokenizer):\n",
    "    \"\"\"\n",
    "    Load SQuAD dataset with answer span information\n",
    "    This is needed for sufficiency, faithfulness, and attention analysis\n",
    "\n",
    "    Keeps answerable questions only (first gold answer). Each entry of the returned\n",
    "    object's ``.data`` is a dict with 'context' (up to 200 characters either side of\n",
    "    the answer), 'question', 'answer', and 'answer_start'/'answer_end' -- character\n",
    "    offsets of the answer within that context window (end exclusive).\n",
    "\n",
    "    ``tokenizer`` is accepted for call compatibility but is not used.\n",
    "    \"\"\"\n",
    "    class SQuADWithSpans:\n",
    "        def __init__(self, data_path):\n",
    "            self.data = []\n",
    "            \n",
    "            with open(data_path, 'r', encoding='utf-8') as f:\n",
    "                squad = json.load(f)\n",
    "            \n",
    "            for article in squad['data']:\n",
    "                for para in article['paragraphs']:\n",
    "                    ctx = para['context']\n",
    "                    for qa in para['qas']:\n",
    "                        if qa['is_impossible'] or not qa['answers']:\n",
    "                            continue\n",
    "                        ans = qa['answers'][0]['text']\n",
    "                        ans_start = qa['answers'][0]['answer_start']\n",
    "                        ans_end = ans_start + len(ans)\n",
    "                        \n",
    "                        # Extract focused context (same 200-character window as SQuADDataset)\n",
    "                        start = max(0, ans_start - 200)\n",
    "                        end = min(len(ctx), ans_end + 200)\n",
    "                        \n",
    "                        # Re-base the answer offsets onto the window\n",
    "                        self.data.append({\n",
    "                            'context': ctx[start:end],\n",
    "                            'question': qa['question'],\n",
    "                            'answer': ans,\n",
    "                            'answer_start': ans_start - start,\n",
    "                            'answer_end': ans_end - start\n",
    "                        })\n",
    "        \n",
    "        def __len__(self):\n",
    "            return len(self.data)\n",
    "        \n",
    "        def __getitem__(self, idx):\n",
    "            # Allows ds[i] and wrapping in torch.utils.data.Subset.\n",
    "            return self.data[idx]\n",
    "    \n",
    "    return SQuADWithSpans(data_path)\n",
    "\n",
    "\n",
    "def normalize_answer(s):\n",
    "    \"\"\"Normalize answer for comparison, following the official SQuAD eval script:\n",
    "    lower-case, drop punctuation, drop articles (a/an/the), collapse whitespace.\n",
    "\n",
    "    Punctuation is removed before articles (official order), so \"a-b\" -> \"ab\".\n",
    "    \"\"\"\n",
    "    s = s.lower()\n",
    "    s = ''.join(c for c in s if c not in string.punctuation)\n",
    "    s = re.sub(r'\\b(a|an|the)\\b', ' ', s)\n",
    "    return ' '.join(s.split())\n",
    "\n",
    "\n",
    "def f1_score(pred, truth):\n",
    "    \"\"\"Compute F1 score\n",
    "\n",
    "    Bag-of-tokens F1 over normalize_answer() tokens; empty vs. empty counts as 1,\n",
    "    empty vs. non-empty as 0.\n",
    "    \"\"\"\n",
    "    predicted = normalize_answer(pred).split()\n",
    "    reference = normalize_answer(truth).split()\n",
    "    \n",
    "    if not (predicted and reference):\n",
    "        return int(predicted == reference)\n",
    "    \n",
    "    shared = Counter(predicted) & Counter(reference)\n",
    "    n_shared = sum(shared.values())\n",
    "    if not n_shared:\n",
    "        return 0\n",
    "    \n",
    "    prec = n_shared / len(predicted)\n",
    "    rec = n_shared / len(reference)\n",
    "    return 2 * prec * rec / (prec + rec)\n",
    "\n",
    "\n",
    "def create_mask(seq_len, device):\n",
    "    \"\"\"Create causal mask: bool tensor (1, 1, seq_len, seq_len), True on and below the diagonal.\"\"\"\n",
    "    allowed = torch.tril(torch.ones(seq_len, seq_len, device=device)) != 0\n",
    "    return allowed[None, None]\n",
    "\n",
    "\n",
    "def generate(model, tokenizer, context, question, device, max_len=50, max_seq_len=320):\n",
    "    \"\"\"Generate answer by greedy decoding.\n",
    "\n",
    "    Args:\n",
    "        max_len: maximum number of new tokens.\n",
    "        max_seq_len: total sequence budget (prompt + answer). Defaults to 320, the\n",
    "            value previously hard-coded here; the prompt is truncated to\n",
    "            max_seq_len - max_len - 5 tokens.\n",
    "\n",
    "    Returns the decoded continuation with special tokens removed, stripped.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    prompt = f\"Q: {question} C: {context} A:\"\n",
    "    ids = tokenizer.encode(prompt, max_length=max_seq_len-max_len-5, \n",
    "                          truncation=True, add_special_tokens=False, return_tensors='pt').to(device)\n",
    "    \n",
    "    start_len = ids.size(1)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_len):\n",
    "            if ids.size(1) >= max_seq_len:\n",
    "                break\n",
    "            \n",
    "            mask = create_mask(ids.size(1), device)\n",
    "            logits = model(ids, mask)\n",
    "            next_tok = logits[:, -1].argmax(-1, keepdim=True)\n",
    "            ids = torch.cat([ids, next_tok], 1)\n",
    "            \n",
    "            # Stop after emitting EOS (it is dropped again by skip_special_tokens).\n",
    "            if next_tok.item() == tokenizer.eos_token_id:\n",
    "                break\n",
    "    \n",
    "    return tokenizer.decode(ids[0, start_len:], skip_special_tokens=True).strip()\n",
    "\n",
    "\n",
    "def get_answer_token_positions(context, answer_start, answer_end, tokenizer):\n",
    "    \"\"\"\n",
    "    Get token indices that correspond to answer span in context\n",
    "\n",
    "    ``answer_start``/``answer_end`` are character offsets into ``context`` (end\n",
    "    exclusive). Returns the indices into ``tokenizer.tokenize(context)`` of every\n",
    "    token that overlaps that span.\n",
    "\n",
    "    Assumes a GPT-2 style byte-level BPE, whose token strings spell one symbol per\n",
    "    UTF-8 byte (a space shows up as 'Ġ', a newline as 'Ċ'). Running token lengths are\n",
    "    then exact byte offsets, also for non-ASCII text, where a text search can\n",
    "    misalign. The assumption is checked (total token length must equal the UTF-8\n",
    "    length of ``context``); otherwise a search-based alignment is used.\n",
    "    \"\"\"\n",
    "    tokens = tokenizer.tokenize(context)\n",
    "    \n",
    "    if sum(len(tok) for tok in tokens) != len(context.encode('utf-8')):\n",
    "        return _answer_token_positions_by_search(context, tokens, answer_start, answer_end)\n",
    "    \n",
    "    # Convert the character span to a byte span.\n",
    "    span_lo = len(context[:answer_start].encode('utf-8'))\n",
    "    span_hi = len(context[:answer_end].encode('utf-8'))\n",
    "    \n",
    "    answer_token_indices = []\n",
    "    offset = 0\n",
    "    for idx, tok in enumerate(tokens):\n",
    "        tok_lo, offset = offset, offset + len(tok)\n",
    "        if tok_lo < span_hi and offset > span_lo:  # half-open intervals overlap\n",
    "            answer_token_indices.append(idx)\n",
    "    \n",
    "    return answer_token_indices\n",
    "\n",
    "\n",
    "def _answer_token_positions_by_search(context, tokens, answer_start, answer_end):\n",
    "    \"\"\"Fallback alignment: find each token's text in ``context`` from left to right.\"\"\"\n",
    "    token_to_char = []\n",
    "    current_pos = 0\n",
    "    for token in tokens:\n",
    "        # Handle GPT-2 special characters\n",
    "        token_text = token.replace('Ġ', ' ').replace('Ċ', '\\n').strip()\n",
    "        start_pos = context.find(token_text, current_pos)\n",
    "        if start_pos == -1:\n",
    "            start_pos = current_pos\n",
    "        end_pos = start_pos + len(token_text)\n",
    "        token_to_char.append((start_pos, end_pos))\n",
    "        current_pos = end_pos\n",
    "    \n",
    "    # Find tokens that overlap with answer span\n",
    "    return [idx for idx, (tok_start, tok_end) in enumerate(token_to_char)\n",
    "            if not (tok_end <= answer_start or tok_start >= answer_end)]\n",
    "\n",
    "\n",
    "def get_item_with_spans(dataset, idx):\n",
    "    \"\"\"\n",
    "    Get item with answer span information from dataset\n",
    "    Works with both Subset and regular Dataset\n",
    "\n",
    "    For a Subset, ``idx`` is mapped through ``dataset.indices`` into the wrapped\n",
    "    dataset's ``.data`` list. Returns the raw dict, or None when it has no\n",
    "    'answer_start' key (e.g. examples built by the span-less SQuADDataset).\n",
    "    \"\"\"\n",
    "    from torch.utils.data import Subset\n",
    "    \n",
    "    source, position = dataset, idx\n",
    "    if isinstance(dataset, Subset):\n",
    "        source, position = dataset.dataset, dataset.indices[idx]\n",
    "    item = source.data[position]\n",
    "    \n",
    "    return item if 'answer_start' in item else None\n",
    "\n",
    "\n",
    "def compute_sufficiency(model, tokenizer, item, device):\n",
    "    \"\"\"\n",
    "    Sufficiency: Can model answer correctly using ONLY the answer span?\n",
    "    \n",
    "    The gold span text itself is passed to generate() as the context.\n",
    "    Returns: F1 score when using only answer span as context, or None when the\n",
    "    item lacks 'answer_start'/'answer_end' or the span is blank.\n",
    "    \"\"\"\n",
    "    has_span = 'answer_start' in item and 'answer_end' in item\n",
    "    if not has_span:\n",
    "        return None\n",
    "    \n",
    "    span_text = item['context'][item['answer_start']:item['answer_end']]\n",
    "    if span_text.strip() == '':\n",
    "        return None\n",
    "    \n",
    "    prediction = generate(model, tokenizer, span_text, item['question'], device)\n",
    "    return f1_score(prediction, item['answer'])\n",
    "\n",
    "\n",
    "def compute_faithfulness(model, tokenizer, item, device):\n",
    "    \"\"\"\n",
    "    Faithfulness: Does model's prediction change when answer span is removed?\n",
    "    \n",
    "    The span is replaced by \" [REMOVED] \", and the F1 against item['answer'] is\n",
    "    compared with the F1 on the untouched context.\n",
    "    Returns: F1 drop when answer is removed (higher = more faithful), or None when\n",
    "    the item lacks span offsets.\n",
    "    \"\"\"\n",
    "    if not ('answer_start' in item and 'answer_end' in item):\n",
    "        return None\n",
    "    \n",
    "    ctx = item['context']\n",
    "    lo, hi = item['answer_start'], item['answer_end']\n",
    "    variants = (ctx, ctx[:lo] + \" [REMOVED] \" + ctx[hi:])\n",
    "    \n",
    "    # Full context first, then the ablated one.\n",
    "    f1_full, f1_ablated = (\n",
    "        f1_score(generate(model, tokenizer, text, item['question'], device), item['answer'])\n",
    "        for text in variants\n",
    "    )\n",
    "    \n",
    "    return f1_full - f1_ablated\n",
    "\n",
    "\n",
    "def compute_attention_on_answer_tokens(model, tokenizer, item, device):\n",
    "    \"\"\"\n",
    "    Compute average attention score on answer token positions during generation\n",
    "\n",
    "    At each greedy decoding step (up to 50, sequence capped at 320 tokens), the\n",
    "    newest position's attention is averaged over all layers and heads, and its mean\n",
    "    weight on the gold-answer tokens inside the (possibly truncated) prompt is recorded.\n",
    "    \n",
    "    Returns: Mean of those per-step scores, or None if the item has no span\n",
    "    offsets, no answer tokens are found, or no step produced a score.\n",
    "    \"\"\"\n",
    "    # Check if we have span information\n",
    "    if 'answer_start' not in item or 'answer_end' not in item:\n",
    "        return None\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    # Build prompt; the context starts right after this prefix.\n",
    "    context_offset = len(f\"Q: {item['question']} C: \")\n",
    "    prompt = f\"Q: {item['question']} C: {item['context']} A:\"\n",
    "    prompt_ids = tokenizer.encode(prompt, max_length=270, \n",
    "                                  truncation=True, add_special_tokens=False, \n",
    "                                  return_tensors='pt').to(device)\n",
    "    \n",
    "    # Locate the answer tokens in the prompt's own tokenization. Tokenizing the prefix\n",
    "    # \"Q: ... C: \" separately can end in a standalone space token that the full prompt\n",
    "    # does not have (there the space merges into the first context word), which\n",
    "    # shifted every answer position by one.\n",
    "    answer_positions = get_answer_token_positions(\n",
    "        prompt,\n",
    "        context_offset + item['answer_start'],\n",
    "        context_offset + item['answer_end'],\n",
    "        tokenizer\n",
    "    )\n",
    "    \n",
    "    if len(answer_positions) == 0:\n",
    "        return None  # No answer tokens found\n",
    "    \n",
    "    # Generate and track attention\n",
    "    generated = prompt_ids\n",
    "    attention_scores = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for _ in range(50):  # Max answer length\n",
    "            if generated.size(1) >= 320:\n",
    "                break\n",
    "            \n",
    "            mask = create_mask(generated.size(1), device)\n",
    "            logits = model(generated, mask, save_attention=True)\n",
    "            \n",
    "            attn_weights = [w for w in model.get_attention_weights() if w is not None]\n",
    "            if attn_weights:\n",
    "                # Batch element 0 of each layer, stacked -> presumably (layers, heads, query, key)\n",
    "                # (TODO confirm against MultiHeadAttention). Average over layers AND heads;\n",
    "                # previously only layers were averaged and then head 0 was read.\n",
    "                avg_attn = torch.stack([w[0] for w in attn_weights]).mean(dim=(0, 1))\n",
    "                \n",
    "                # Attention from the newest position to every input position\n",
    "                last_token_attn = avg_attn[-1]\n",
    "                \n",
    "                answer_attn = [last_token_attn[pos].item()\n",
    "                               for pos in answer_positions if pos < last_token_attn.size(0)]\n",
    "                if answer_attn:\n",
    "                    attention_scores.append(np.mean(answer_attn))\n",
    "            \n",
    "            # Generate next token\n",
    "            next_tok = logits[:, -1].argmax(-1, keepdim=True)\n",
    "            generated = torch.cat([generated, next_tok], 1)\n",
    "            \n",
    "            if next_tok.item() == tokenizer.eos_token_id:\n",
    "                break\n",
    "    \n",
    "    # Return mean attention score across generation steps\n",
    "    return np.mean(attention_scores) if attention_scores else None\n",
    "\n",
    "\n",
    "def analyze_model(model, dataset, tokenizer, device, n_samples=100):\n",
    "    \"\"\"\n",
    "    Score a trained model on F1, sufficiency, faithfulness and answer-span attention.\n",
    "    \n",
    "    Args:\n",
    "        model: Trained GPT model\n",
    "        dataset: object with a .data list of item dicts, or a torch Subset of one;\n",
    "            items need 'context', 'question', 'answer', 'answer_start', 'answer_end'\n",
    "        tokenizer: GPT2Tokenizer\n",
    "        device: torch device\n",
    "        n_samples: number of leading items to analyze\n",
    "    \n",
    "    Returns:\n",
    "        dict with per-sample lists, their means/stds and (for >1 sample) the\n",
    "        correlation of F1 with each metric; None when span info is missing or no\n",
    "        sample yielded every metric.\n",
    "    \"\"\"\n",
    "    from torch.utils.data import Subset\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    # Collect the raw item dicts, unwrapping a Subset if needed\n",
    "    if isinstance(dataset, Subset):\n",
    "        limit = min(n_samples, len(dataset))\n",
    "        items = [dataset.dataset.data[dataset.indices[k]] for k in range(limit)]\n",
    "    else:\n",
    "        items = dataset.data[:n_samples]\n",
    "    \n",
    "    banner = \"=\" * 70\n",
    "    \n",
    "    # Span keys are checked on the first item only\n",
    "    if len(items) > 0 and 'answer_start' not in items[0]:\n",
    "        for line in (\"\\n\" + banner,\n",
    "                     \"ERROR: Dataset does not have answer span information!\",\n",
    "                     banner,\n",
    "                     \"\\nYour dataset items need 'answer_start' and 'answer_end' keys.\",\n",
    "                     \"Please reload your dataset with the updated SQuADDataset class that\",\n",
    "                     \"includes span information.\",\n",
    "                     \"\\nTo fix: Reload val_dataset using the code from the training script\",\n",
    "                     \"that includes answer_start and answer_end in the data.\",\n",
    "                     banner):\n",
    "            print(line)\n",
    "        return None\n",
    "    \n",
    "    optional_metrics = ('sufficiency', 'faithfulness', 'attention')\n",
    "    all_metrics = optional_metrics + ('f1',)\n",
    "    results = {name: [] for name in all_metrics}\n",
    "    \n",
    "    print(f\"\\nAnalyzing {len(items)} samples...\")\n",
    "    print(banner)\n",
    "    \n",
    "    valid_samples = 0\n",
    "    for item in tqdm(items, desc=\"Computing metrics\"):\n",
    "        try:\n",
    "            pred = generate(model, tokenizer, item['context'], item['question'], device)\n",
    "            # Dict literal evaluates in order: F1, sufficiency, faithfulness, attention\n",
    "            scores = {\n",
    "                'f1': f1_score(pred, item['answer']),\n",
    "                'sufficiency': compute_sufficiency(model, tokenizer, item, device),\n",
    "                'faithfulness': compute_faithfulness(model, tokenizer, item, device),\n",
    "                'attention': compute_attention_on_answer_tokens(model, tokenizer, item, device),\n",
    "            }\n",
    "            \n",
    "            # Keep the sample only when every optional metric could be computed\n",
    "            if all(scores[name] is not None for name in optional_metrics):\n",
    "                for name in all_metrics:\n",
    "                    results[name].append(scores[name])\n",
    "                valid_samples += 1\n",
    "        except Exception as e:\n",
    "            # Best-effort: report the failure and move on to the next sample\n",
    "            print(f\"\\nError on sample: {e}\")\n",
    "            continue\n",
    "    \n",
    "    if valid_samples == 0:\n",
    "        print(\"\\n\" + banner)\n",
    "        print(\"ERROR: No valid samples were analyzed!\")\n",
    "        print(\"This usually means the dataset doesn't have answer span info.\")\n",
    "        print(banner)\n",
    "        return None\n",
    "    \n",
    "    # Aggregate statistics (key insertion order: means, then stds, then correlations)\n",
    "    for name in all_metrics:\n",
    "        results[f'mean_{name}'] = np.mean(results[name])\n",
    "    for name in optional_metrics:\n",
    "        results[f'std_{name}'] = np.std(results[name])\n",
    "    \n",
    "    # Pearson correlation of F1 with each metric (needs at least two samples)\n",
    "    if len(results['f1']) > 1:\n",
    "        for name in optional_metrics:\n",
    "            results[f'corr_f1_{name}'] = np.corrcoef(results['f1'], results[name])[0, 1]\n",
    "    \n",
    "    # Print summary\n",
    "    print(\"\\n\" + banner)\n",
    "    print(\"RESULTS\")\n",
    "    print(banner)\n",
    "    print(f\"Samples analyzed: {valid_samples}\")\n",
    "    print(f\"\\nMean F1: {results['mean_f1']:.4f}\")\n",
    "    print(f\"\\nSufficiency: {results['mean_sufficiency']:.4f} ± {results['std_sufficiency']:.4f}\")\n",
    "    print(\"  (Can model answer with only answer span?)\")\n",
    "    print(f\"\\nFaithfulness: {results['mean_faithfulness']:.4f} ± {results['std_faithfulness']:.4f}\")\n",
    "    print(\"  (F1 drop when answer removed - higher = more faithful)\")\n",
    "    print(f\"\\nAttention on Answer Tokens: {results['mean_attention']:.4f} ± {results['std_attention']:.4f}\")\n",
    "    print(\"  (Average attention on answer token positions)\")\n",
    "    \n",
    "    if 'corr_f1_attention' in results:\n",
    "        print(\"\\nCorrelations with F1:\")\n",
    "        print(f\"  F1 ↔ Sufficiency:   {results['corr_f1_sufficiency']:>6.3f}\")\n",
    "        print(f\"  F1 ↔ Faithfulness:  {results['corr_f1_faithfulness']:>6.3f}\")\n",
    "        print(f\"  F1 ↔ Attention:     {results['corr_f1_attention']:>6.3f}\")\n",
    "    \n",
    "    print(banner)\n",
    "    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6391ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this first to load dataset with answer span information\n",
    "val_dataset_full = load_squad_with_spans('dev-v2.0.json', tokenizer)\n",
    "print(f\"Loaded {len(val_dataset_full)} validation samples with span info\")\n",
    "\n",
    "# 1. Load your baseline model\n",
    "# map_location=device lets a checkpoint saved on a different device (e.g. CUDA)\n",
    "# be restored onto the device this kernel is using\n",
    "checkpoint = torch.load('best_baseline.pt', map_location=device)\n",
    "model.load_state_dict(checkpoint['model'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75492ec5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze the baseline model on the span-annotated validation split\n",
    "baseline_results = analyze_model(model=model, dataset=val_dataset_full,\n",
    "                                 tokenizer=tokenizer, device=device,\n",
    "                                 n_samples=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc33cece",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training split with answer span information\n",
    "train_dataset_full = load_squad_with_spans('train-v2.0.json', tokenizer)\n",
    "print(f\"Loaded {len(train_dataset_full)} train samples with span info\")\n",
    "\n",
    "# Reload the baseline checkpoint; map_location=device lets a checkpoint saved on\n",
    "# a different device (e.g. CUDA) load onto the device this kernel is using\n",
    "checkpoint = torch.load('best_baseline.pt', map_location=device)\n",
    "model.load_state_dict(checkpoint['model'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59099a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze the baseline model on the span-annotated training split\n",
    "# (each sample runs several greedy generations, so 60k samples is slow)\n",
    "train_baseline_results = analyze_model(model=model, dataset=train_dataset_full,\n",
    "                                       tokenizer=tokenizer, device=device,\n",
    "                                       n_samples=60000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ced7544c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b83ff579",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ffbc04c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7418a280",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "SQuAD Answer Generation with GloVe Embeddings + Q/K Hypothesis Testing\n",
    "\n",
    "EXPECTED PERFORMANCE:\n",
    "- With GloVe embeddings: 40-55% F1 ✓\n",
    "- Training time: ~40-50 minutes\n",
    "- Can reach 50%+ with Q/K hypothesis\n",
    "\"\"\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader, Subset\n",
    "from transformers import GPT2Tokenizer\n",
    "import json\n",
    "from collections import Counter\n",
    "import string\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import os\n",
    "import urllib.request\n",
    "import zipfile\n",
    "\n",
    "# Configuration\n",
    "TEST_QK_HYPOTHESIS = True  # Set True after baseline completes\n",
    "QK_LR_MULTIPLIER = 20   # Q/K learning-rate multiplier (20x); NOTE(review): confirm where it is applied\n",
    "\n",
    "# Optimized for GloVe embeddings\n",
    "D_MODEL = 300  # Must equal the GloVe vector size: GloVe rows are copied directly into the token embedding\n",
    "N_HEADS = 6  # D_MODEL must be divisible by N_HEADS (300 / 6 = 50 dims per head)\n",
    "N_LAYERS = 6\n",
    "D_FF = 1200  # Feed-forward hidden size (4 x D_MODEL)\n",
    "MAX_SEQ_LEN = 320  # Prompt + answer token budget used by generate()\n",
    "MAX_ANSWER_LEN = 50  # Max answer tokens kept per training example\n",
    "DROPOUT = 0.2\n",
    "BATCH_SIZE = 24\n",
    "ACCUMULATION_STEPS = 2  # Effective batch: 48\n",
    "BASE_LR = 5e-5  # Peak OneCycle LR for the non-Q/K parameter group\n",
    "WARMUP_STEPS = 800  # Converted into OneCycleLR's pct_start\n",
    "NUM_EPOCHS = 100\n",
    "GRAD_CLIP = 0.5  # Max global gradient norm\n",
    "WEIGHT_DECAY = 0.05\n",
    "LABEL_SMOOTHING = 0.1\n",
    "TRAIN_SUBSET_SIZE = 60000  # More data with GloVe\n",
    "VAL_SUBSET_SIZE = 10000\n",
    "\n",
    "\n",
    "def download_and_extract_glove():\n",
    "    \"\"\"\n",
    "    Ensure 'glove.6B.300d.txt' exists in the working directory.\n",
    "    \n",
    "    Downloads glove.6B.zip (822MB) from the Hugging Face mirror when it is missing,\n",
    "    then extracts only the 300-d vectors.\n",
    "    \n",
    "    Returns:\n",
    "        The GloVe filename on success, or None if download or extraction failed.\n",
    "    \"\"\"\n",
    "    glove_file = 'glove.6B.300d.txt'\n",
    "    \n",
    "    if os.path.exists(glove_file):\n",
    "        print(f\"✓ GloVe embeddings found: {glove_file}\")\n",
    "        return glove_file\n",
    "    \n",
    "    print(\"\\n\" + \"=\"*70)\n",
    "    print(\"DOWNLOADING GLOVE EMBEDDINGS\")\n",
    "    print(\"=\"*70)\n",
    "    \n",
    "    zip_file = 'glove.6B.zip'\n",
    "    \n",
    "    if not os.path.exists(zip_file):\n",
    "        print(\"Downloading GloVe 6B (822MB)... This may take a few minutes\")\n",
    "        url = 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.zip'\n",
    "        # Stream to a temporary name and rename only after a complete download, so an\n",
    "        # interrupted transfer never leaves a truncated zip that later runs would trust.\n",
    "        partial_file = zip_file + '.part'\n",
    "        \n",
    "        try:\n",
    "            # Download with progress bar; `with` closes the HTTP response\n",
    "            with urllib.request.urlopen(url) as response:\n",
    "                total_size = int(response.headers.get('content-length', 0))\n",
    "                \n",
    "                with open(partial_file, 'wb') as f, tqdm(\n",
    "                    total=total_size, unit='B', unit_scale=True, desc='Downloading'\n",
    "                ) as pbar:\n",
    "                    while True:\n",
    "                        chunk = response.read(8192)\n",
    "                        if not chunk:\n",
    "                            break\n",
    "                        f.write(chunk)\n",
    "                        pbar.update(len(chunk))\n",
    "            \n",
    "            os.replace(partial_file, zip_file)\n",
    "            print(\"✓ Download complete!\")\n",
    "        except Exception as e:\n",
    "            if os.path.exists(partial_file):\n",
    "                os.remove(partial_file)\n",
    "            print(f\"Download failed: {e}\")\n",
    "            print(\"\\nAlternative: Download manually from:\")\n",
    "            print(\"  https://nlp.stanford.edu/projects/glove/\")\n",
    "            print(\"  or https://huggingface.co/stanfordnlp/glove\")\n",
    "            return None\n",
    "    \n",
    "    # Extract\n",
    "    if os.path.exists(zip_file):\n",
    "        print(\"Extracting GloVe embeddings...\")\n",
    "        try:\n",
    "            with zipfile.ZipFile(zip_file, 'r') as zip_ref:\n",
    "                # Only extract the 300d file we need\n",
    "                zip_ref.extract(glove_file)\n",
    "        except (zipfile.BadZipFile, KeyError) as e:\n",
    "            # Corrupt archive or missing member: remove the archive and any partially\n",
    "            # extracted file so the next call starts from a clean download.\n",
    "            print(f\"Extraction failed: {e}\")\n",
    "            for path in (zip_file, glove_file):\n",
    "                if os.path.exists(path):\n",
    "                    os.remove(path)\n",
    "            return None\n",
    "        print(\"✓ Extraction complete!\")\n",
    "        \n",
    "        # Optionally remove zip to save space\n",
    "        # os.remove(zip_file)\n",
    "    \n",
    "    if os.path.exists(glove_file):\n",
    "        return glove_file\n",
    "    else:\n",
    "        print(\"GloVe file not found after extraction\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def load_glove_embeddings(glove_file, tokenizer, embedding_dim=300, seed=None):\n",
    "    \"\"\"\n",
    "    Build a (vocab_size, embedding_dim) embedding matrix for a GPT-2 tokenizer from GloVe.\n",
    "    \n",
    "    A token row gets the GloVe vector of the first of these forms found in GloVe:\n",
    "    the token itself, its lowercase form, or its lowercased, stripped form with the\n",
    "    GPT-2 'Ġ' (space) / 'Ċ' (newline) markers removed. Other rows keep N(0, 0.1) noise.\n",
    "    \n",
    "    Args:\n",
    "        glove_file: GloVe text file with lines \"word v1 v2 ... vD\"\n",
    "        tokenizer: tokenizer exposing vocab_size and get_vocab()\n",
    "        embedding_dim: expected vector size D\n",
    "        seed: optional seed for the random rows (None uses numpy's global RNG)\n",
    "    \n",
    "    Returns:\n",
    "        torch.FloatTensor of shape (vocab_size, embedding_dim)\n",
    "    \n",
    "    Raises:\n",
    "        ValueError: if a GloVe vector does not have embedding_dim components.\n",
    "    \"\"\"\n",
    "    print(\"\\n\" + \"=\"*70)\n",
    "    print(\"LOADING GLOVE EMBEDDINGS\")\n",
    "    print(\"=\"*70)\n",
    "    \n",
    "    # Load GloVe vectors\n",
    "    print(\"Reading GloVe file (this takes ~1 minute)...\")\n",
    "    glove_vectors = {}\n",
    "    \n",
    "    with open(glove_file, 'r', encoding='utf-8') as f:\n",
    "        for line in tqdm(f, total=400000, desc=\"Loading GloVe\"):\n",
    "            values = line.rstrip().split(' ')\n",
    "            if values == ['']:\n",
    "                continue  # blank line\n",
    "            if len(values) != embedding_dim + 1:\n",
    "                # Fail fast with a clear message instead of a broadcast error later\n",
    "                raise ValueError(\n",
    "                    f\"GloVe vector for {values[0]!r} has {len(values) - 1} components, \"\n",
    "                    f\"expected embedding_dim={embedding_dim}\")\n",
    "            glove_vectors[values[0]] = np.asarray(values[1:], dtype='float32')\n",
    "    \n",
    "    print(f\"✓ Loaded {len(glove_vectors):,} GloVe vectors\")\n",
    "    \n",
    "    # Create embedding matrix for tokenizer vocabulary (unmatched rows stay random)\n",
    "    vocab_size = tokenizer.vocab_size\n",
    "    rng = np.random if seed is None else np.random.default_rng(seed)\n",
    "    embedding_matrix = rng.normal(0, 0.1, (vocab_size, embedding_dim)).astype('float32')\n",
    "    \n",
    "    # Match tokenizer vocab with GloVe\n",
    "    print(\"Matching tokenizer vocabulary with GloVe...\")\n",
    "    matched = 0\n",
    "    \n",
    "    for token, idx in tqdm(tokenizer.get_vocab().items(), desc=\"Matching\"):\n",
    "        if idx >= vocab_size:\n",
    "            continue  # added tokens lie outside the base-vocab matrix\n",
    "        token_clean = token.replace('Ġ', '').replace('Ċ', '').lower().strip()\n",
    "        # Try exact, lowercase, then marker-stripped forms; first hit wins\n",
    "        for candidate in (token, token.lower(), token_clean):\n",
    "            vector = glove_vectors.get(candidate)\n",
    "            if vector is not None:\n",
    "                embedding_matrix[idx] = vector\n",
    "                matched += 1\n",
    "                break\n",
    "    \n",
    "    match_rate = 100 * matched / vocab_size\n",
    "    print(f\"✓ Matched {matched:,}/{vocab_size:,} tokens ({match_rate:.1f}%)\")\n",
    "    print(\"=\"*70 + \"\\n\")\n",
    "    \n",
    "    return torch.FloatTensor(embedding_matrix)\n",
    "\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Multi-head scaled dot-product attention.\n",
    "    \n",
    "    Inputs are projected, split into n_heads heads of size d_k = d_model // n_heads,\n",
    "    attended (positions where mask == 0 get a -1e9 score) and merged back through\n",
    "    an output projection. With save_attention=True the post-softmax, pre-dropout\n",
    "    weights, shaped (batch, heads, q_len, k_len), are kept in last_attention_weights.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, d_model, n_heads, dropout=0.1):\n",
    "        super().__init__()\n",
    "        assert d_model % n_heads == 0\n",
    "        self.n_heads = n_heads\n",
    "        self.d_k = d_model // n_heads\n",
    "        \n",
    "        # Keep this construction order: it fixes each layer's random init draws\n",
    "        self.q_linear = nn.Linear(d_model, d_model)\n",
    "        self.k_linear = nn.Linear(d_model, d_model)\n",
    "        self.v_linear = nn.Linear(d_model, d_model)\n",
    "        self.out = nn.Linear(d_model, d_model)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.last_attention_weights = None\n",
    "    \n",
    "    def _split_heads(self, x, batch):\n",
    "        # (batch, seq, d_model) -> (batch, heads, seq, d_k)\n",
    "        return x.view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)\n",
    "    \n",
    "    def forward(self, q, k, v, mask=None, save_attention=False):\n",
    "        batch = q.size(0)\n",
    "        queries = self._split_heads(self.q_linear(q), batch)\n",
    "        keys = self._split_heads(self.k_linear(k), batch)\n",
    "        values = self._split_heads(self.v_linear(v), batch)\n",
    "        \n",
    "        scores = queries @ keys.transpose(-2, -1) / (self.d_k ** 0.5)\n",
    "        if mask is not None:\n",
    "            scores = scores.masked_fill(mask == 0, -1e9)\n",
    "        \n",
    "        weights = scores.softmax(dim=-1)\n",
    "        if save_attention:\n",
    "            self.last_attention_weights = weights.detach()\n",
    "        \n",
    "        mixed = self.dropout(weights) @ values\n",
    "        merged = mixed.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.d_k)\n",
    "        return self.out(merged)\n",
    "\n",
    "\n",
    "class DecoderLayer(nn.Module):\n",
    "    \"\"\"Pre-norm decoder block: x + SelfAttn(LN1(x)), then x + FF(LN2(x)).\"\"\"\n",
    "    \n",
    "    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)\n",
    "        self.ff = nn.Sequential(\n",
    "            nn.Linear(d_model, d_ff),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(d_ff, d_model),\n",
    "            nn.Dropout(dropout)\n",
    "        )\n",
    "        self.norm1 = nn.LayerNorm(d_model)\n",
    "        self.norm2 = nn.LayerNorm(d_model)\n",
    "        \n",
    "    def forward(self, x, mask=None, save_attention=False):\n",
    "        # Pre-norm: normalize once and reuse the result as query, key and value\n",
    "        # (LayerNorm is deterministic, so three separate calls only repeat work)\n",
    "        h = self.norm1(x)\n",
    "        x = x + self.self_attn(h, h, h, mask, save_attention)\n",
    "        x = x + self.ff(self.norm2(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "class GPTAnswerGenerator(nn.Module):\n",
    "    \"\"\"\n",
    "    Decoder-only transformer that maps token ids to next-token logits.\n",
    "    \n",
    "    token + learned position embeddings -> n_layers DecoderLayers -> LayerNorm ->\n",
    "    vocabulary projection. The projection shares its weight with the token\n",
    "    embedding, which may be initialized from pretrained (GloVe) vectors.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len, dropout=0.1, pretrained_embeddings=None):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Submodules are built in a fixed order, which fixes their random init draws\n",
    "        self.token_embedding = nn.Embedding(vocab_size, d_model)\n",
    "        if pretrained_embeddings is not None:\n",
    "            print(\"Initializing token embeddings with GloVe...\")\n",
    "            self.token_embedding.weight.data.copy_(pretrained_embeddings)\n",
    "            print(\"✓ Token embeddings initialized with GloVe\")\n",
    "        \n",
    "        self.position_embedding = nn.Embedding(max_seq_len, d_model)\n",
    "        self.emb_dropout = nn.Dropout(dropout)\n",
    "        self.layers = nn.ModuleList(\n",
    "            [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]\n",
    "        )\n",
    "        self.norm = nn.LayerNorm(d_model)\n",
    "        self.output = nn.Linear(d_model, vocab_size)\n",
    "        \n",
    "        # Weight tying: the output projection reuses the token embedding matrix\n",
    "        self.output.weight = self.token_embedding.weight\n",
    "        \n",
    "        self._init_weights()\n",
    "    \n",
    "    def _init_weights(self):\n",
    "        \"\"\"Xavier-init every matrix parameter except the (tied) token embedding.\"\"\"\n",
    "        gain = 1 / np.sqrt(2)\n",
    "        for param_name, param in self.named_parameters():\n",
    "            if param.dim() <= 1 or 'token_embedding' in param_name:\n",
    "                continue  # vectors (biases, LayerNorm) and the token embedding stay as-is\n",
    "            nn.init.xavier_uniform_(param, gain=gain)\n",
    "    \n",
    "    def forward(self, x, mask=None, save_attention=False):\n",
    "        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)\n",
    "        hidden = self.emb_dropout(self.token_embedding(x) + self.position_embedding(positions))\n",
    "        for block in self.layers:\n",
    "            hidden = block(hidden, mask, save_attention)\n",
    "        return self.output(self.norm(hidden))\n",
    "    \n",
    "    def get_attention_weights(self):\n",
    "        \"\"\"Return each layer's most recently saved attention weights (None if never saved).\"\"\"\n",
    "        return [block.self_attn.last_attention_weights for block in self.layers]\n",
    "\n",
    "\n",
    "class SQuADDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Answerable SQuAD v2 questions formatted for causal-LM answer generation.\n",
    "    \n",
    "    Each item is the prompt \"Q: <question> C: <context window> A:\" followed by\n",
    "    \" <answer>\" + EOS, padded to max_len. Only answer tokens (and EOS) carry\n",
    "    labels; prompt and padding positions are -100 (ignored by the loss).\n",
    "    \"\"\"\n",
    "    def __init__(self, data_path, tokenizer, max_len, max_ans_len):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_len = max_len\n",
    "        self.max_ans_len = max_ans_len\n",
    "        self.data = []\n",
    "        \n",
    "        # Explicit UTF-8: SQuAD is UTF-8 JSON, and the platform default encoding\n",
    "        # (e.g. cp1252 on Windows) can fail on non-ASCII text\n",
    "        with open(data_path, 'r', encoding='utf-8') as f:\n",
    "            squad = json.load(f)\n",
    "        \n",
    "        for article in squad['data']:\n",
    "            for para in article['paragraphs']:\n",
    "                ctx = para['context']\n",
    "                for qa in para['qas']:\n",
    "                    # Skip unanswerable questions; keep only the first gold answer\n",
    "                    if not qa['is_impossible'] and qa['answers']:\n",
    "                        ans = qa['answers'][0]['text']\n",
    "                        ans_start = qa['answers'][0]['answer_start']\n",
    "                        \n",
    "                        # Extract relevant context window: up to 200 chars on each side of the answer\n",
    "                        start = max(0, ans_start - 200)\n",
    "                        end = min(len(ctx), ans_start + len(ans) + 200)\n",
    "                        focused_ctx = ctx[start:end]\n",
    "                        \n",
    "                        self.data.append({\n",
    "                            'context': focused_ctx,\n",
    "                            'question': qa['question'],\n",
    "                            'answer': ans\n",
    "                        })\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        \n",
    "        # Format: Q: question C: context A: answer\n",
    "        prefix = f\"Q: {item['question']} C: {item['context']} A:\"\n",
    "        answer = f\" {item['answer']}\"\n",
    "        \n",
    "        # Reserve room for the answer, EOS and one spare slot.\n",
    "        # NOTE(review): assuming the tokenizer's default right-side truncation, long\n",
    "        # prompts lose their trailing \" A:\" cue; truncating only the context would keep it.\n",
    "        prefix_ids = self.tokenizer.encode(prefix, max_length=self.max_len-self.max_ans_len-2, \n",
    "                                          truncation=True, add_special_tokens=False)\n",
    "        answer_ids = self.tokenizer.encode(answer, max_length=self.max_ans_len, \n",
    "                                          truncation=True, add_special_tokens=False)\n",
    "        answer_ids.append(self.tokenizer.eos_token_id)\n",
    "        \n",
    "        input_ids = prefix_ids + answer_ids\n",
    "        labels = [-100] * len(prefix_ids) + answer_ids\n",
    "        \n",
    "        # Pad (padding positions are ignored by the loss)\n",
    "        while len(input_ids) < self.max_len:\n",
    "            input_ids.append(self.tokenizer.pad_token_id)\n",
    "            labels.append(-100)\n",
    "        \n",
    "        return {\n",
    "            'input_ids': torch.tensor(input_ids[:self.max_len]),\n",
    "            'labels': torch.tensor(labels[:self.max_len])\n",
    "        }\n",
    "\n",
    "\n",
    "def create_mask(seq_len, device):\n",
    "    \"\"\"Causal mask (1, 1, seq_len, seq_len): True where key position <= query position.\"\"\"\n",
    "    allowed = torch.tril(torch.ones(seq_len, seq_len, device=device)).bool()\n",
    "    return allowed[None, None, :, :]\n",
    "\n",
    "\n",
    "def normalize_answer(s):\n",
    "    \"\"\"\n",
    "    Normalize an answer the way the official SQuAD evaluation script does:\n",
    "    lowercase -> remove punctuation -> remove articles (a/an/the) -> collapse whitespace.\n",
    "    \n",
    "    Punctuation must go before articles, otherwise e.g. \"a-team\" would lose its\n",
    "    leading \"a\" (the hyphen isolates it as a word) and scores would diverge from\n",
    "    the official F1/EM.\n",
    "    \"\"\"\n",
    "    s = s.lower()\n",
    "    s = ''.join(c for c in s if c not in string.punctuation)\n",
    "    s = re.sub(r'\\b(a|an|the)\\b', ' ', s)\n",
    "    return ' '.join(s.split())\n",
    "\n",
    "\n",
    "def f1_score(pred, truth):\n",
    "    \"\"\"\n",
    "    SQuAD token-level F1 between a predicted and a gold answer.\n",
    "    \n",
    "    Both strings are normalized and split on whitespace. When either side has no\n",
    "    tokens the score is 1 if both are empty, else 0.\n",
    "    \"\"\"\n",
    "    pred_tokens = normalize_answer(pred).split()\n",
    "    gold_tokens = normalize_answer(truth).split()\n",
    "    \n",
    "    if len(pred_tokens) == 0 or len(gold_tokens) == 0:\n",
    "        return int(pred_tokens == gold_tokens)\n",
    "    \n",
    "    # Multiset intersection: a shared token counts min(pred count, gold count) times\n",
    "    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())\n",
    "    if overlap == 0:\n",
    "        return 0\n",
    "    \n",
    "    precision = overlap / len(pred_tokens)\n",
    "    recall = overlap / len(gold_tokens)\n",
    "    return 2 * precision * recall / (precision + recall)\n",
    "\n",
    "\n",
    "def exact_match(pred, truth):\n",
    "    \"\"\"Return 1 when the normalized prediction equals the normalized gold answer, else 0.\"\"\"\n",
    "    same = normalize_answer(pred) == normalize_answer(truth)\n",
    "    return 1 if same else 0\n",
    "\n",
    "\n",
    "def train_epoch(model, loader, opt, sched, device, epoch):\n",
    "    \"\"\"\n",
    "    Run one training epoch with gradient accumulation.\n",
    "    \n",
    "    The loss is next-token cross-entropy (with LABEL_SMOOTHING) over positions whose\n",
    "    label is not -100. Gradients accumulate over ACCUMULATION_STEPS batches before each\n",
    "    clipped optimizer + scheduler step; a trailing partial group is stepped as well.\n",
    "    NOTE(review): the scheduler must allow ceil(len(loader) / ACCUMULATION_STEPS)\n",
    "    steps per epoch.\n",
    "    \n",
    "    Returns:\n",
    "        Mean per-batch loss for the epoch (0.0 for an empty loader).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    n_batches = len(loader)\n",
    "    opt.zero_grad()\n",
    "    \n",
    "    pbar = tqdm(loader, desc=f\"Epoch {epoch}\")\n",
    "    for i, batch in enumerate(pbar):\n",
    "        inp = batch['input_ids'].to(device)\n",
    "        lbl = batch['labels'].to(device)\n",
    "        \n",
    "        mask = create_mask(inp.size(1), device)\n",
    "        logits = model(inp, mask)\n",
    "        \n",
    "        # Shift for next-token prediction: logits at t are scored against label t + 1\n",
    "        loss = nn.functional.cross_entropy(\n",
    "            logits[:, :-1].reshape(-1, logits.size(-1)),\n",
    "            lbl[:, 1:].reshape(-1),\n",
    "            ignore_index=-100,\n",
    "            label_smoothing=LABEL_SMOOTHING\n",
    "        )\n",
    "        \n",
    "        # Scale so the accumulated gradient averages over a full group\n",
    "        (loss / ACCUMULATION_STEPS).backward()\n",
    "        \n",
    "        # Step at every group boundary and on the final batch, so a trailing partial\n",
    "        # group (still scaled by 1/ACCUMULATION_STEPS) is applied instead of being\n",
    "        # discarded by the zero_grad() at the start of the next epoch\n",
    "        if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == n_batches:\n",
    "            nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
    "            opt.step()\n",
    "            sched.step()\n",
    "            opt.zero_grad()\n",
    "        \n",
    "        batch_loss = loss.item()\n",
    "        total_loss += batch_loss\n",
    "        pbar.set_postfix({'loss': f'{batch_loss:.3f}'})\n",
    "    \n",
    "    return total_loss / n_batches if n_batches else 0.0\n",
    "\n",
    "\n",
    "def generate(model, tokenizer, context, question, device, max_len=50):\n",
    "    \"\"\"\n",
    "    Greedily decode an answer to `question` given `context`.\n",
    "    \n",
    "    The prompt \"Q: <question> C: <context> A:\" is truncated to\n",
    "    MAX_SEQ_LEN - max_len - 5 tokens; decoding stops at EOS, after max_len new\n",
    "    tokens, or once the sequence reaches MAX_SEQ_LEN.\n",
    "    \n",
    "    Returns:\n",
    "        The generated continuation, decoded without special tokens and stripped.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    prompt_budget = MAX_SEQ_LEN - max_len - 5\n",
    "    ids = tokenizer.encode(f\"Q: {question} C: {context} A:\", max_length=prompt_budget,\n",
    "                           truncation=True, add_special_tokens=False,\n",
    "                           return_tensors='pt').to(device)\n",
    "    prompt_len = ids.size(1)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        steps = 0\n",
    "        while steps < max_len and ids.size(1) < MAX_SEQ_LEN:\n",
    "            steps += 1\n",
    "            logits = model(ids, create_mask(ids.size(1), device))\n",
    "            next_tok = logits[:, -1].argmax(-1, keepdim=True)\n",
    "            ids = torch.cat([ids, next_tok], 1)\n",
    "            if next_tok.item() == tokenizer.eos_token_id:\n",
    "                break\n",
    "    \n",
    "    answer_ids = ids[0, prompt_len:]\n",
    "    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()\n",
    "\n",
    "\n",
    "def evaluate(model, dataset, tokenizer, device, n_samples=300):\n",
    "    \"\"\"\n",
    "    Greedy-decode up to n_samples leading items and score them against gold answers.\n",
    "    \n",
    "    Args:\n",
    "        dataset: object with a .data list of {'context', 'question', 'answer'} dicts,\n",
    "            or a torch Subset wrapping one (its first indices are used)\n",
    "    \n",
    "    Returns:\n",
    "        {'f1': mean F1, 'em': mean exact match}; both 0.0 when there is nothing to score.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    if isinstance(dataset, Subset):\n",
    "        items = [dataset.dataset.data[dataset.indices[i]] for i in range(min(n_samples, len(dataset)))]\n",
    "    else:\n",
    "        items = dataset.data[:n_samples]\n",
    "    \n",
    "    # An empty selection would otherwise raise ZeroDivisionError below\n",
    "    if len(items) == 0:\n",
    "        return {'f1': 0.0, 'em': 0.0}\n",
    "    \n",
    "    f1_sum = em_sum = 0\n",
    "    for item in tqdm(items, desc=\"Eval\"):\n",
    "        pred = generate(model, tokenizer, item['context'], item['question'], device)\n",
    "        f1_sum += f1_score(pred, item['answer'])\n",
    "        em_sum += exact_match(pred, item['answer'])\n",
    "    \n",
    "    return {'f1': f1_sum / len(items), 'em': em_sum / len(items)}\n",
    "\n",
    "\n",
    "def analyze_attention(model, dataset, tokenizer, device, n=30):\n",
    "    \"\"\"\n",
    "    Average attention statistic over the prompts of up to n leading dataset items.\n",
    "    \n",
    "    For each prompt (no answer), the saved attention maps are averaged over layers,\n",
    "    head 0 is taken and its mean over all (query, key) pairs is recorded.\n",
    "    NOTE(review): every softmax row sums to 1 (masked keys get ~0 weight), so that\n",
    "    mean is simply 1 / prompt_length - the statistic reflects prompt length, not\n",
    "    where the model attends.\n",
    "    \n",
    "    Returns:\n",
    "        Mean of the per-prompt values, or 0 when nothing was recorded.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    if isinstance(dataset, Subset):\n",
    "        picked = [dataset.dataset.data[dataset.indices[j]] for j in range(min(n, len(dataset)))]\n",
    "    else:\n",
    "        picked = dataset.data[:n]\n",
    "    \n",
    "    per_prompt = []\n",
    "    for item in picked:\n",
    "        prompt_ids = tokenizer.encode(f\"Q: {item['question']} C: {item['context']} A:\",\n",
    "                                      max_length=MAX_SEQ_LEN-MAX_ANSWER_LEN,\n",
    "                                      truncation=True, add_special_tokens=False,\n",
    "                                      return_tensors='pt').to(device)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            model(prompt_ids, create_mask(prompt_ids.size(1), device), save_attention=True)\n",
    "            layer_maps = model.get_attention_weights()\n",
    "            if layer_maps[0] is None:\n",
    "                continue\n",
    "            # (layers, heads, seq, seq) -> mean over layers -> head 0 -> scalar mean\n",
    "            layer_avg = torch.stack([m[0] for m in layer_maps if m is not None]).mean(0)\n",
    "            per_prompt.append(layer_avg[0].mean().item())\n",
    "    \n",
    "    return np.mean(per_prompt) if per_prompt else 0\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    \n",
    "    print(\"=\"*70)\n",
    "    print(\"SQUAD ANSWER GENERATION WITH GLOVE EMBEDDINGS\")\n",
    "    print(\"=\"*70)\n",
    "    print(f\"Model: {N_LAYERS}L, {D_MODEL}d, {N_HEADS}h\")\n",
    "    print(f\"Device: {device}\")\n",
    "    print(\"=\"*70 + \"\\n\")\n",
    "    \n",
    "    # Locate/download GloVe (None on failure); vectors are loaded once the tokenizer exists\n",
    "    glove_file = download_and_extract_glove()\n",
    "    \n",
    "    if glove_file is None:\n",
    "        print(\"\\n WARNING: Could not load GloVe embeddings\")\n",
    "        print(\"Proceeding without pretrained embeddings (expect 15-25% F1)\")\n",
    "    \n",
    "    # Download SQuAD datasets\n",
    "    for name in ['train-v2.0.json', 'dev-v2.0.json']:\n",
    "        if not os.path.exists(name):\n",
    "            print(f\"Downloading {name}...\")\n",
    "            urllib.request.urlretrieve(\n",
    "                f'https://rajpurkar.github.io/SQuAD-explorer/dataset/{name}', name)\n",
    "    \n",
    "    # Setup tokenizer; the EOS token is reused as the padding token\n",
    "    print(\"Loading tokenizer...\")\n",
    "    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    \n",
    "    # Load GloVe embeddings for the tokenizer (None when GloVe is unavailable)\n",
    "    pretrained_embeddings = (\n",
    "        load_glove_embeddings(glove_file, tokenizer, D_MODEL) if glove_file else None\n",
    "    )\n",
    "    \n",
    "    # Load datasets\n",
    "    print(\"Loading datasets...\")\n",
    "    full_train = SQuADDataset('train-v2.0.json', tokenizer, MAX_SEQ_LEN, MAX_ANSWER_LEN)\n",
    "    full_val = SQuADDataset('dev-v2.0.json', tokenizer, MAX_SEQ_LEN, MAX_ANSWER_LEN)\n",
    "    \n",
    "    # Random subsets drawn from torch's global RNG\n",
    "    train_ds = Subset(full_train, torch.randperm(len(full_train))[:TRAIN_SUBSET_SIZE])\n",
    "    val_ds = Subset(full_val, torch.randperm(len(full_val))[:VAL_SUBSET_SIZE])\n",
    "    \n",
    "    print(f\"Train: {len(train_ds)}, Val: {len(val_ds)}\\n\")\n",
    "    \n",
    "    loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)\n",
    "    total_steps = len(loader) * NUM_EPOCHS\n",
    "    \n",
    "    # Model\n",
    "    print(\"Initializing model...\")\n",
    "    model = GPTAnswerGenerator(\n",
    "        vocab_size=tokenizer.vocab_size,\n",
    "        d_model=D_MODEL,\n",
    "        n_heads=N_HEADS,\n",
    "        n_layers=N_LAYERS,\n",
    "        d_ff=D_FF,\n",
    "        max_seq_len=MAX_SEQ_LEN,\n",
    "        dropout=DROPOUT,\n",
    "        pretrained_embeddings=pretrained_embeddings\n",
    "    ).to(device)\n",
    "    \n",
    "    total_params = sum(p.numel() for p in model.parameters()) / 1e6\n",
    "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6\n",
    "    print(f\"Total parameters: {total_params:.1f}M\")\n",
    "    print(f\"Trainable parameters: {trainable_params:.1f}M\\n\")\n",
    "    \n",
    "    # Optional per-layer Q/K learning rates (only used when TEST_QK_HYPOTHESIS is True).\n",
    "    # False reproduces the uniform QK_LR_MULTIPLIER experiment; True gives every Q/K\n",
    "    # tensor its own param group with lr = BASE_LR * qk_lr_scale(layer_index).\n",
    "    QK_LAYERWISE_LR = False\n",
    "    \n",
    "    def qk_lr_scale(layer_idx):\n",
    "        \"\"\"LR multiplier for the Q/K weights of layer `layer_idx` (edit as you like).\"\"\"\n",
    "        return 1.0 + 0.2 * layer_idx     # layer0=1.0x, layer1=1.2x, layer2=1.4x, ...\n",
    "    \n",
    "    def qk_layer_index(param_name):\n",
    "        \"\"\"Return the first all-digit component of a dotted parameter name as an int.\n",
    "        \n",
    "        Meant for names like \"transformer.layers.3.attn.q_linear.weight\" (-> 3);\n",
    "        adjust if the model names its layers differently.\n",
    "        \n",
    "        Raises:\n",
    "            ValueError: if `param_name` has no all-digit component.\n",
    "        \"\"\"\n",
    "        digits = [part for part in param_name.split('.') if part.isdigit()]\n",
    "        if not digits:\n",
    "            raise ValueError(f\"Cannot infer layer index from parameter name {param_name!r}\")\n",
    "        return int(digits[0])\n",
    "    \n",
    "    # Optimizer: each branch fills `param_groups`, and OneCycleLR uses every group's\n",
    "    # configured 'lr' as that group's peak learning rate.\n",
    "    if TEST_QK_HYPOTHESIS:\n",
    "        qk_desc = \"per-layer (qk_lr_scale)\" if QK_LAYERWISE_LR else f\"{QK_LR_MULTIPLIER}x\"\n",
    "        print(\"=\"*70)\n",
    "        print(f\"TESTING Q/K HYPOTHESIS - Q/K LR = {qk_desc}\")\n",
    "        print(\"=\"*70 + \"\\n\")\n",
    "        \n",
    "        # Split parameters by name: Q/K projections vs. everything else\n",
    "        qk_named = [(n, p) for n, p in model.named_parameters() if 'q_linear' in n or 'k_linear' in n]\n",
    "        other = [p for n, p in model.named_parameters() if 'q_linear' not in n and 'k_linear' not in n]\n",
    "        \n",
    "        print(f\"Q/K params: {sum(p.numel() for _, p in qk_named)/1e6:.1f}M\")\n",
    "        print(f\"Other params: {sum(p.numel() for p in other)/1e6:.1f}M\\n\")\n",
    "        \n",
    "        if QK_LAYERWISE_LR:\n",
    "            # One group per Q/K tensor so each layer can get its own LR\n",
    "            qk_groups = [\n",
    "                {'params': [p], 'lr': BASE_LR * qk_lr_scale(qk_layer_index(n)), 'weight_decay': WEIGHT_DECAY}\n",
    "                for n, p in qk_named\n",
    "            ]\n",
    "        else:\n",
    "            qk_groups = [\n",
    "                {'params': [p for _, p in qk_named], 'lr': BASE_LR * QK_LR_MULTIPLIER, 'weight_decay': WEIGHT_DECAY}\n",
    "            ]\n",
    "        param_groups = qk_groups + [{'params': other, 'lr': BASE_LR, 'weight_decay': WEIGHT_DECAY}]\n",
    "    else:\n",
    "        print(\"=\"*70)\n",
    "        print(\"BASELINE (Standard LR)\")\n",
    "        print(\"=\"*70 + \"\\n\")\n",
    "        \n",
    "        if pretrained_embeddings is not None:\n",
    "            # Lower LR (and no weight decay) so the pretrained embeddings are fine-tuned slowly\n",
    "            embedding_params = [model.token_embedding.weight]\n",
    "            other_params = [p for n, p in model.named_parameters() if 'token_embedding' not in n]\n",
    "            param_groups = [\n",
    "                {'params': embedding_params, 'lr': BASE_LR * 0.1, 'weight_decay': 0},\n",
    "                {'params': other_params, 'lr': BASE_LR, 'weight_decay': WEIGHT_DECAY}\n",
    "            ]\n",
    "            print(\"Using differential LR: embeddings=0.1x, other=1.0x\\n\")\n",
    "        else:\n",
    "            param_groups = [{'params': list(model.parameters()), 'lr': BASE_LR, 'weight_decay': WEIGHT_DECAY}]\n",
    "    \n",
    "    # Peak LR per group, in param_groups order (captured before the optimizer is built)\n",
    "    max_lrs = [group['lr'] for group in param_groups]\n",
    "    opt = torch.optim.AdamW(param_groups)\n",
    "    sched = torch.optim.lr_scheduler.OneCycleLR(\n",
    "        opt,\n",
    "        max_lr=max_lrs,\n",
    "        total_steps=total_steps,\n",
    "        pct_start=WARMUP_STEPS / total_steps\n",
    "    )\n",
    "    \n",
    "    # Train. best_f1 starts at -inf so the first epoch always writes a checkpoint,\n",
    "    # even if validation F1 stays at 0 (the analysis cells below load that file).\n",
    "    best_f1 = float('-inf')\n",
    "    results = {'loss': [], 'train_f1': [], 'val_f1': [], 'val_em': [], 'attn': []}\n",
    "    \n",
    "    for e in range(NUM_EPOCHS):\n",
    "        print(f\"\\n{'='*70}\")\n",
    "        print(f\"EPOCH {e+1}/{NUM_EPOCHS}\")\n",
    "        print('='*70)\n",
    "        \n",
    "        loss = train_epoch(model, loader, opt, sched, device, e+1)\n",
    "        results['loss'].append(loss)\n",
    "        print(f\"\\nLoss: {loss:.4f}\")\n",
    "        \n",
    "        # Eval\n",
    "        train_m = evaluate(model, train_ds, tokenizer, device, 200)\n",
    "        val_m = evaluate(model, val_ds, tokenizer, device, 300)\n",
    "        \n",
    "        results['train_f1'].append(train_m['f1'])\n",
    "        results['val_f1'].append(val_m['f1'])\n",
    "        results['val_em'].append(val_m['em'])\n",
    "        \n",
    "        gap = train_m['f1'] - val_m['f1']\n",
    "        print(f\"Train F1: {train_m['f1']:.4f} | Val F1: {val_m['f1']:.4f} | Gap: {gap:.4f} | EM: {val_m['em']:.4f}\")\n",
    "        \n",
    "        # Sample prediction on the first validation-subset item (every other epoch)\n",
    "        if e % 2 == 0:\n",
    "            item = val_ds.dataset.data[int(val_ds.indices[0])]\n",
    "            pred = generate(model, tokenizer, item['context'], item['question'], device)\n",
    "            print(f\"\\nSample:\")\n",
    "            print(f\"  Q: {item['question'][:60]}...\")\n",
    "            print(f\"  True: {item['answer']}\")\n",
    "            print(f\"  Pred: {pred}\")\n",
    "            print(f\"  F1: {f1_score(pred, item['answer']):.3f}\")\n",
    "        \n",
    "        # Attention analysis (every 4th epoch, Q/K experiment only)\n",
    "        if e % 4 == 0 and TEST_QK_HYPOTHESIS:\n",
    "            attn = analyze_attention(model, val_ds, tokenizer, device)\n",
    "            results['attn'].append(attn)\n",
    "            print(f\"Attention: {attn:.4f}\")\n",
    "        \n",
    "        # Save best. NOTE: the Q/K filename says '20x' regardless of QK_LR_MULTIPLIER;\n",
    "        # the analysis cells below load it by this exact name.\n",
    "        if val_m['f1'] > best_f1:\n",
    "            best_f1 = val_m['f1']\n",
    "            name = 'best_qk_20x.pt' if TEST_QK_HYPOTHESIS else 'best_baseline.pt'\n",
    "            torch.save({'model': model.state_dict(), 'f1': best_f1, 'epoch': e+1}, name)\n",
    "            print(f\"✓ SAVED! Best F1: {best_f1:.4f}\")\n",
    "    \n",
    "    # Final\n",
    "    print(f\"\\n{'='*70}\")\n",
    "    print(\"FINAL RESULTS\")\n",
    "    print('='*70)\n",
    "    print(f\"Best Val F1: {best_f1*100:.1f}%\")\n",
    "    print(f\"Final Val F1: {results['val_f1'][-1]*100:.1f}%\")\n",
    "    print(f\"Final EM: {results['val_em'][-1]*100:.1f}%\")\n",
    "    print(f\"Train-Val Gap: {results['train_f1'][-1] - results['val_f1'][-1]:.4f}\")\n",
    "    \n",
    "    if pretrained_embeddings is not None:\n",
    "        if best_f1 >= 0.40:\n",
    "            print(f\"\\n✓ EXCELLENT! Hit target with GloVe embeddings!\")\n",
    "        elif best_f1 >= 0.30:\n",
    "            print(f\"\\n✓ GOOD! GloVe embeddings helping significantly\")\n",
    "        else:\n",
    "            print(f\"\\n⚠ Below expected (40%+) - may need more training\")\n",
    "    \n",
    "    print('='*70)\n",
    "    \n",
    "    name = 'results_qk_glove.pt' if TEST_QK_HYPOTHESIS else 'results_baseline_glove.pt'\n",
    "    torch.save(results, name)\n",
    "    print(f\"\\n✓ Saved to {name}\")\n",
    "    \n",
    "    if not TEST_QK_HYPOTHESIS and best_f1 > 0.30:\n",
    "        print(f\"\\n{'='*70}\")\n",
    "        print(\"BASELINE COMPLETE!\")\n",
    "        print(\"Now set TEST_QK_HYPOTHESIS=True to test your hypothesis\")\n",
    "        print(f\"Target: Beat {best_f1*100:.1f}% F1 with Q/K boosted learning\")\n",
    "        print('='*70)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e228a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this first: load the dev split with answer-span information\n",
    "val_dataset_full = load_squad_with_spans('dev-v2.0.json', tokenizer)\n",
    "print(f\"Loaded {len(val_dataset_full)} validation samples with span info\")\n",
    "\n",
    "# Load the best Q/K-experiment checkpoint saved by the training cell.\n",
    "# map_location keeps this working if the kernel's device differs from the one used to save it.\n",
    "checkpoint = torch.load('best_qk_20x.pt', map_location=device)\n",
    "model.load_state_dict(checkpoint['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "546496f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze the checkpoint loaded above on the span-annotated dev set.\n",
    "# NOTE: despite its name, `baseline_results` holds results for whichever checkpoint was loaded.\n",
    "baseline_results = analyze_model(\n",
    "    model=model, tokenizer=tokenizer, device=device,\n",
    "    dataset=val_dataset_full,\n",
    "    n_samples=5000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b31f6ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this first: load the train split with answer-span information\n",
    "train_dataset_full = load_squad_with_spans('train-v2.0.json', tokenizer)\n",
    "print(f\"Loaded {len(train_dataset_full)} train samples with span info\")\n",
    "\n",
    "# (Re)load the best Q/K-experiment checkpoint saved by the training cell, so this\n",
    "# cell also works on its own; map_location handles a device change since saving.\n",
    "checkpoint = torch.load('best_qk_20x.pt', map_location=device)\n",
    "model.load_state_dict(checkpoint['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0b5dd6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze the checkpoint loaded above on the span-annotated train set.\n",
    "# NOTE: despite its name, `baseline_results_train` holds results for whichever checkpoint was loaded.\n",
    "baseline_results_train = analyze_model(\n",
    "    model=model, tokenizer=tokenizer, device=device,\n",
    "    dataset=train_dataset_full,\n",
    "    n_samples=60000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4917194b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae0f4119",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
