{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "9742f0d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import dataset\n",
    "from models.clam import CLAM\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "1a3624ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "65450\n",
      "65450\n",
      "17504\n",
      "17504\n",
      "17504\n"
     ]
    }
   ],
   "source": [
    "#Real Dataset\n",
    "\n",
    "\n",
    "\n",
    "real_df_1 = \"real_songs_2.csv\"\n",
    "real_df_1 = pd.read_csv(real_df_1)\n",
    "real_df_2 = \"real_yt_covers.csv\"\n",
    "real_df_2 = pd.read_csv(real_df_2)\n",
    "\n",
    "\n",
    "real_mert_1 = \"real_songs_mert\"\n",
    "real_mert_2 = \"yt_covers_mert\"\n",
    "\n",
    "real_wav2vec2_1 = \"real_songs_wav2vec2\"\n",
    "real_wav2vec2_2 = \"yt_covers_wav2vec2\"\n",
    "\n",
    "print(len(real_df_1))\n",
    "print(len(os.listdir(real_mert_1)))\n",
    "print(len(os.listdir(real_wav2vec2_1)))\n",
    "\n",
    "print(len(real_df_2))\n",
    "print(len(os.listdir(real_mert_2)))\n",
    "print(len(os.listdir(real_wav2vec2_2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "7baef53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "fake_df = \"ai_generated_music_metadata.csv\"\n",
    "fake_df = pd.read_csv(fake_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "836f41c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>model_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ai_covers_0.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ai_covers_1.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ai_covers_2.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ai_covers_3.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ai_covers_4.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          filename model_name\n",
       "0  ai_covers_0.mp3  AI_COVERS\n",
       "1  ai_covers_1.mp3  AI_COVERS\n",
       "2  ai_covers_2.mp3  AI_COVERS\n",
       "3  ai_covers_3.mp3  AI_COVERS\n",
       "4  ai_covers_4.mp3  AI_COVERS"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fake_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "e23ec37b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64972\n",
      "64958\n",
      "64958\n"
     ]
    }
   ],
   "source": [
    "#Fake Dataset\n",
    "\n",
    "\n",
    "fake_df = \"ai_generated_music_metadata.csv\"\n",
    "fake_df = pd.read_csv(fake_df)\n",
    "fake_mert = \"ai_generated_music_mert\"\n",
    "fake_wav2vec2 = \"ai_generated_music_wav2vec2\"\n",
    "\n",
    "\n",
    "print(len(fake_df))\n",
    "print(len(os.listdir(fake_mert)))\n",
    "print(len(os.listdir(fake_wav2vec2)))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "8c40f0b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 47945/47945 [00:12<00:00, 3700.12it/s]\n",
      "100%|██████████| 17504/17504 [00:01<00:00, 10312.79it/s]\n",
      "100%|██████████| 64972/64972 [00:17<00:00, 3676.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "17504\n",
      "64958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def remove_filename(df, folder):\n",
    "    filenames = os.listdir(folder)\n",
    "    #make copy of df\n",
    "    df_copy = df.copy()\n",
    "    for index, row in tqdm(df_copy.iterrows(), total=df_copy.shape[0]):\n",
    "        filename = row['filename']\n",
    "        if filename.endswith('.mp3'):\n",
    "            filename = filename[:-4]\n",
    "        if filename.endswith('.wav'):\n",
    "            filename = filename[:-4]\n",
    "\n",
    "        if filename + '.pt' not in filenames:\n",
    "            df_copy.drop(index, inplace=True)\n",
    "    return df_copy\n",
    "\n",
    "\n",
    "real_df_1 = remove_filename(real_df_1, real_mert_1)\n",
    "real_df_2 = remove_filename(real_df_2, real_mert_2)\n",
    "fake_df = remove_filename(fake_df, fake_mert)\n",
    "print(len(real_df_1))\n",
    "print(len(real_df_2))\n",
    "print(len(fake_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "6e654958",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "17504\n"
     ]
    }
   ],
   "source": [
    "print(len(real_df_1))\n",
    "print(len(real_df_2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b9facf9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65449\n"
     ]
    }
   ],
   "source": [
    "real_df = pd.concat([real_df_1, real_df_2])\n",
    "print(len(real_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "2b17865a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "model_name\n",
       "suno_3.5     23695\n",
       "Udio_1.5     19500\n",
       "riffusion     7043\n",
       "Yue           5278\n",
       "Diffrythm     4606\n",
       "suno_3        3512\n",
       "AI_COVERS     1166\n",
       "suno_2         110\n",
       "suno_4          48\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fake_df[\"model_name\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "901d4ef6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17047\n",
      "47911\n",
      "0.7375688906678162\n",
      "48273\n",
      "17176\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "fake_df_test = fake_df[fake_df[\"model_name\"].isin([\"riffusion\", \"suno_3\", \"AI_COVERS\"  , 'suno_4' , 'Yue'])]\n",
    "print(len(fake_df_test))\n",
    "fake_df_train = fake_df[~fake_df[\"model_name\"].isin([\"riffusion\", \"suno_3\", \"AI_COVERS\" , \"suno_4\" , \"Yue\"] )]\n",
    "print(len(fake_df_train))\n",
    "\n",
    "ratio = len(fake_df_train) / ( len(fake_df_test) + len(fake_df_train)  )\n",
    "print(ratio)\n",
    "\n",
    "real_df = pd.concat([real_df_1, real_df_2])\n",
    "\n",
    "real_df_train, real_df_test = train_test_split(real_df, test_size= 1 - ratio, random_state=42)\n",
    "print(len(real_df_train))\n",
    "print(len(real_df_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "af6715f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preloading embeddings for train split...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  24%|██▍       | 11644/48273 [00:18<00:59, 617.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_8382.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  33%|███▎      | 15907/48273 [00:25<00:56, 573.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_5219.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  58%|█████▊    | 28051/48273 [00:45<00:33, 611.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_11880.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  60%|█████▉    | 28942/48273 [00:46<00:30, 629.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_9557.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  67%|██████▋   | 32432/48273 [00:52<00:26, 596.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PytorchStreamReader failed locating file data/0: file not found\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_11542.pt: PytorchStreamReader failed locating file data/0: file not found. Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU: 100%|██████████| 48273/48273 [01:18<00:00, 615.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preloading embeddings for train split...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:   1%|▏         | 642/47911 [00:01<01:21, 576.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for diffrythm_566.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  10%|█         | 4918/47911 [00:08<01:12, 594.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_122.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  11%|█         | 5098/47911 [00:08<01:11, 594.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_278.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  14%|█▍        | 6919/47911 [00:11<01:09, 593.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_2113.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  35%|███▍      | 16573/47911 [00:28<01:07, 461.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_11812.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  49%|████▉     | 23435/47911 [00:40<00:39, 619.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_18613.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  59%|█████▉    | 28483/47911 [00:48<00:33, 586.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_23683.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  61%|██████    | 29094/47911 [00:49<00:31, 594.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_607.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  75%|███████▍  | 35726/47911 [01:00<00:19, 627.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_7188.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU:  97%|█████████▋| 46695/47911 [01:24<00:02, 483.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_18205.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train data to CPU: 100%|██████████| 47911/47911 [01:26<00:00, 550.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preloading embeddings for test split...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:   5%|▌         | 944/17176 [00:01<00:27, 595.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_13185.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:   7%|▋         | 1251/17176 [00:02<00:26, 610.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_9011.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  44%|████▎     | 7510/17176 [00:12<00:15, 620.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_10928.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  61%|██████▏   | 10551/17176 [00:17<00:10, 609.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_7006.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  77%|███████▋  | 13280/17176 [00:21<00:06, 611.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_5289.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  80%|████████  | 13768/17176 [00:22<00:05, 593.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PytorchStreamReader failed locating file data/0: file not found\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_7797.pt: PytorchStreamReader failed locating file data/0: file not found. Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  98%|█████████▊| 16863/17176 [00:27<00:00, 617.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_15133.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU: 100%|██████████| 17176/17176 [00:28<00:00, 607.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preloading embeddings for test split...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  47%|████▋     | 7961/17047 [00:13<00:15, 586.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_1802.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU:  48%|████▊     | 8261/17047 [00:14<00:15, 584.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_2149.pt: . Skipping sample.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test data to CPU: 100%|██████████| 17047/17047 [00:29<00:00, 584.36it/s]\n"
     ]
    }
   ],
   "source": [
    "from utils import dataset\n",
    "\n",
    "train_real_dataset = dataset.SongsDataset(real_df_train, real_mert_1, real_wav2vec2_1, label = 0)\n",
    "train_fake_dataset = dataset.SongsDataset(fake_df_train, fake_mert, fake_wav2vec2, label = 1)\n",
    "\n",
    "test_real_dataset = dataset.SongsDataset(real_df_test, real_mert_1, real_wav2vec2_1, label = 0 , split = \"test\")\n",
    "test_fake_dataset = dataset.SongsDataset(fake_df_test, fake_mert, fake_wav2vec2, label = 1 , split = \"test\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "296bb56e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48268\n",
      "47901\n",
      "17169\n",
      "17045\n"
     ]
    }
   ],
   "source": [
    "print(len(train_real_dataset))\n",
    "print(len(train_fake_dataset))\n",
    "print(len(test_real_dataset))\n",
    "print(len(test_fake_dataset))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "738fc507",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "96169\n",
      "34214\n"
     ]
    }
   ],
   "source": [
    "train_dataset = torch.utils.data.ConcatDataset([train_real_dataset, train_fake_dataset])\n",
    "test_dataset = torch.utils.data.ConcatDataset([test_real_dataset, test_fake_dataset])\n",
    "print(len(train_dataset))\n",
    "print(len(test_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "359ca8a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 Training: 100%|██████████| 4809/4809 [00:34<00:00, 139.99batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50\n",
      "  Train Total Loss: 0.2594 | CLS Loss: 0.2510 | Align Loss (raw): 0.0170\n",
      "  Train Acc: 0.8870 | F1: 0.8842 | Recall: 0.8665 | Precision: 0.9025\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 Validation: 100%|██████████| 1203/1203 [00:02<00:00, 440.96batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1407 | CLS Loss: 0.1329 | Align Loss (raw): 0.0156\n",
      "  Val Acc: 0.9574 | F1: 0.9586 | Recall: 0.9860 | Precision: 0.9328\n",
      "  New Best F1: 0.9586\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 175.38batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/50\n",
      "  Train Total Loss: 0.1040 | CLS Loss: 0.0981 | Align Loss (raw): 0.0119\n",
      "  Train Acc: 0.9667 | F1: 0.9664 | Recall: 0.9640 | Precision: 0.9688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 750.26batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0744 | CLS Loss: 0.0703 | Align Loss (raw): 0.0082\n",
      "  Val Acc: 0.9769 | F1: 0.9767 | Recall: 0.9696 | Precision: 0.9840\n",
      "  New Best F1: 0.9767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 164.27batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/50\n",
      "  Train Total Loss: 0.0786 | CLS Loss: 0.0748 | Align Loss (raw): 0.0075\n",
      "  Train Acc: 0.9745 | F1: 0.9743 | Recall: 0.9731 | Precision: 0.9755\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 722.97batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0956 | CLS Loss: 0.0912 | Align Loss (raw): 0.0089\n",
      "  Val Acc: 0.9701 | F1: 0.9707 | Recall: 0.9925 | Precision: 0.9499\n",
      "  Model not saved. Best F1 so far: 0.9767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/50 Training: 100%|██████████| 4809/4809 [00:31<00:00, 154.45batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/50\n",
      "  Train Total Loss: 0.0647 | CLS Loss: 0.0618 | Align Loss (raw): 0.0059\n",
      "  Train Acc: 0.9793 | F1: 0.9792 | Recall: 0.9783 | Precision: 0.9801\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 685.72batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0724 | CLS Loss: 0.0703 | Align Loss (raw): 0.0043\n",
      "  Val Acc: 0.9741 | F1: 0.9735 | Recall: 0.9535 | Precision: 0.9945\n",
      "  Model not saved. Best F1 so far: 0.9767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 172.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/50\n",
      "  Train Total Loss: 0.0566 | CLS Loss: 0.0541 | Align Loss (raw): 0.0049\n",
      "  Train Acc: 0.9812 | F1: 0.9811 | Recall: 0.9797 | Precision: 0.9825\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 732.36batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0463 | CLS Loss: 0.0438 | Align Loss (raw): 0.0051\n",
      "  Val Acc: 0.9851 | F1: 0.9852 | Recall: 0.9912 | Precision: 0.9793\n",
      "  New Best F1: 0.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 171.51batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/50\n",
      "  Train Total Loss: 0.0511 | CLS Loss: 0.0488 | Align Loss (raw): 0.0045\n",
      "  Train Acc: 0.9838 | F1: 0.9837 | Recall: 0.9829 | Precision: 0.9846\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 746.57batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0468 | CLS Loss: 0.0453 | Align Loss (raw): 0.0031\n",
      "  Val Acc: 0.9846 | F1: 0.9845 | Recall: 0.9749 | Precision: 0.9943\n",
      "  Model not saved. Best F1 so far: 0.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 177.00batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/50\n",
      "  Train Total Loss: 0.0479 | CLS Loss: 0.0458 | Align Loss (raw): 0.0041\n",
      "  Train Acc: 0.9840 | F1: 0.9839 | Recall: 0.9835 | Precision: 0.9843\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 742.96batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0524 | CLS Loss: 0.0505 | Align Loss (raw): 0.0039\n",
      "  Val Acc: 0.9841 | F1: 0.9843 | Recall: 0.9956 | Precision: 0.9733\n",
      "  Model not saved. Best F1 so far: 0.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 170.08batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/50\n",
      "  Train Total Loss: 0.0443 | CLS Loss: 0.0424 | Align Loss (raw): 0.0039\n",
      "  Train Acc: 0.9853 | F1: 0.9852 | Recall: 0.9847 | Precision: 0.9857\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 713.64batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0723 | CLS Loss: 0.0699 | Align Loss (raw): 0.0049\n",
      "  Val Acc: 0.9763 | F1: 0.9758 | Recall: 0.9554 | Precision: 0.9971\n",
      "  Model not saved. Best F1 so far: 0.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 172.36batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/50\n",
      "  Train Total Loss: 0.0407 | CLS Loss: 0.0388 | Align Loss (raw): 0.0037\n",
      "  Train Acc: 0.9866 | F1: 0.9865 | Recall: 0.9858 | Precision: 0.9873\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 735.45batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0702 | CLS Loss: 0.0687 | Align Loss (raw): 0.0030\n",
      "  Val Acc: 0.9750 | F1: 0.9745 | Recall: 0.9541 | Precision: 0.9959\n",
      "  Model not saved. Best F1 so far: 0.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 168.19batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/50\n",
      "  Train Total Loss: 0.0391 | CLS Loss: 0.0374 | Align Loss (raw): 0.0033\n",
      "  Train Acc: 0.9871 | F1: 0.9870 | Recall: 0.9863 | Precision: 0.9878\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 736.04batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0766 | CLS Loss: 0.0747 | Align Loss (raw): 0.0036\n",
      "  Val Acc: 0.9732 | F1: 0.9725 | Recall: 0.9494 | Precision: 0.9968\n",
      "  Model not saved. Best F1 so far: 0.9852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 173.89batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11/50\n",
      "  Train Total Loss: 0.0382 | CLS Loss: 0.0365 | Align Loss (raw): 0.0033\n",
      "  Train Acc: 0.9874 | F1: 0.9873 | Recall: 0.9868 | Precision: 0.9879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 670.62batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0418 | CLS Loss: 0.0407 | Align Loss (raw): 0.0022\n",
      "  Val Acc: 0.9864 | F1: 0.9863 | Recall: 0.9776 | Precision: 0.9951\n",
      "  New Best F1: 0.9863\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.87batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12/50\n",
      "  Train Total Loss: 0.0363 | CLS Loss: 0.0347 | Align Loss (raw): 0.0032\n",
      "  Train Acc: 0.9880 | F1: 0.9880 | Recall: 0.9873 | Precision: 0.9886\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 704.31batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0580 | CLS Loss: 0.0554 | Align Loss (raw): 0.0050\n",
      "  Val Acc: 0.9798 | F1: 0.9801 | Recall: 0.9930 | Precision: 0.9674\n",
      "  Model not saved. Best F1 so far: 0.9863\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 171.99batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13/50\n",
      "  Train Total Loss: 0.0329 | CLS Loss: 0.0313 | Align Loss (raw): 0.0031\n",
      "  Train Acc: 0.9899 | F1: 0.9899 | Recall: 0.9891 | Precision: 0.9906\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 735.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0420 | CLS Loss: 0.0404 | Align Loss (raw): 0.0031\n",
      "  Val Acc: 0.9870 | F1: 0.9868 | Recall: 0.9779 | Precision: 0.9960\n",
      "  New Best F1: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 162.88batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14/50\n",
      "  Train Total Loss: 0.0324 | CLS Loss: 0.0309 | Align Loss (raw): 0.0031\n",
      "  Train Acc: 0.9891 | F1: 0.9891 | Recall: 0.9887 | Precision: 0.9894\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 672.23batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0277 | CLS Loss: 0.0265 | Align Loss (raw): 0.0024\n",
      "  Val Acc: 0.9912 | F1: 0.9912 | Recall: 0.9904 | Precision: 0.9919\n",
      "  New Best F1: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.16batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15/50\n",
      "  Train Total Loss: 0.0319 | CLS Loss: 0.0304 | Align Loss (raw): 0.0029\n",
      "  Train Acc: 0.9893 | F1: 0.9892 | Recall: 0.9887 | Precision: 0.9897\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 722.09batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0388 | CLS Loss: 0.0375 | Align Loss (raw): 0.0026\n",
      "  Val Acc: 0.9872 | F1: 0.9872 | Recall: 0.9874 | Precision: 0.9870\n",
      "  Model not saved. Best F1 so far: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 167.07batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16/50\n",
      "  Train Total Loss: 0.0299 | CLS Loss: 0.0284 | Align Loss (raw): 0.0030\n",
      "  Train Acc: 0.9901 | F1: 0.9900 | Recall: 0.9894 | Precision: 0.9906\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 692.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0325 | CLS Loss: 0.0315 | Align Loss (raw): 0.0021\n",
      "  Val Acc: 0.9894 | F1: 0.9893 | Recall: 0.9819 | Precision: 0.9968\n",
      "  Model not saved. Best F1 so far: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 169.16batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17/50\n",
      "  Train Total Loss: 0.0311 | CLS Loss: 0.0296 | Align Loss (raw): 0.0029\n",
      "  Train Acc: 0.9898 | F1: 0.9898 | Recall: 0.9891 | Precision: 0.9905\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 694.30batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0411 | CLS Loss: 0.0402 | Align Loss (raw): 0.0019\n",
      "  Val Acc: 0.9850 | F1: 0.9849 | Recall: 0.9740 | Precision: 0.9960\n",
      "  Model not saved. Best F1 so far: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 171.29batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18/50\n",
      "  Train Total Loss: 0.0287 | CLS Loss: 0.0273 | Align Loss (raw): 0.0028\n",
      "  Train Acc: 0.9906 | F1: 0.9905 | Recall: 0.9900 | Precision: 0.9911\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 660.80batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0347 | CLS Loss: 0.0331 | Align Loss (raw): 0.0033\n",
      "  Val Acc: 0.9895 | F1: 0.9896 | Recall: 0.9952 | Precision: 0.9840\n",
      "  Model not saved. Best F1 so far: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 162.64batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19/50\n",
      "  Train Total Loss: 0.0278 | CLS Loss: 0.0265 | Align Loss (raw): 0.0028\n",
      "  Train Acc: 0.9906 | F1: 0.9906 | Recall: 0.9902 | Precision: 0.9910\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 735.65batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0375 | CLS Loss: 0.0363 | Align Loss (raw): 0.0025\n",
      "  Val Acc: 0.9883 | F1: 0.9882 | Recall: 0.9795 | Precision: 0.9970\n",
      "  Model not saved. Best F1 so far: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 168.66batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20/50\n",
      "  Train Total Loss: 0.0275 | CLS Loss: 0.0261 | Align Loss (raw): 0.0027\n",
      "  Train Acc: 0.9904 | F1: 0.9904 | Recall: 0.9899 | Precision: 0.9909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 670.37batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0349 | CLS Loss: 0.0340 | Align Loss (raw): 0.0017\n",
      "  Val Acc: 0.9880 | F1: 0.9879 | Recall: 0.9792 | Precision: 0.9967\n",
      "  Model not saved. Best F1 so far: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 162.66batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21/50\n",
      "  Train Total Loss: 0.0255 | CLS Loss: 0.0241 | Align Loss (raw): 0.0028\n",
      "  Train Acc: 0.9919 | F1: 0.9918 | Recall: 0.9917 | Precision: 0.9920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 692.52batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0253 | CLS Loss: 0.0242 | Align Loss (raw): 0.0023\n",
      "  Val Acc: 0.9916 | F1: 0.9916 | Recall: 0.9915 | Precision: 0.9917\n",
      "  New Best F1: 0.9916\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22/50 Training: 100%|██████████| 4809/4809 [00:27<00:00, 172.11batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22/50\n",
      "  Train Total Loss: 0.0256 | CLS Loss: 0.0243 | Align Loss (raw): 0.0027\n",
      "  Train Acc: 0.9914 | F1: 0.9914 | Recall: 0.9909 | Precision: 0.9918\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 725.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0234 | CLS Loss: 0.0221 | Align Loss (raw): 0.0027\n",
      "  Val Acc: 0.9920 | F1: 0.9920 | Recall: 0.9899 | Precision: 0.9941\n",
      "  New Best F1: 0.9920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 167.74batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23/50\n",
      "  Train Total Loss: 0.0248 | CLS Loss: 0.0235 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9917 | F1: 0.9916 | Recall: 0.9913 | Precision: 0.9920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 729.83batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0403 | CLS Loss: 0.0386 | Align Loss (raw): 0.0033\n",
      "  Val Acc: 0.9874 | F1: 0.9875 | Recall: 0.9965 | Precision: 0.9787\n",
      "  Model not saved. Best F1 so far: 0.9920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24/50 Training: 100%|██████████| 4809/4809 [00:30<00:00, 160.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24/50\n",
      "  Train Total Loss: 0.0245 | CLS Loss: 0.0232 | Align Loss (raw): 0.0028\n",
      "  Train Acc: 0.9921 | F1: 0.9920 | Recall: 0.9915 | Precision: 0.9925\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 655.17batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0290 | CLS Loss: 0.0277 | Align Loss (raw): 0.0025\n",
      "  Val Acc: 0.9908 | F1: 0.9909 | Recall: 0.9953 | Precision: 0.9865\n",
      "  Model not saved. Best F1 so far: 0.9920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.78batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25/50\n",
      "  Train Total Loss: 0.0238 | CLS Loss: 0.0224 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9922 | F1: 0.9922 | Recall: 0.9915 | Precision: 0.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25/50 Validation: 100%|██████████| 1203/1203 [00:02<00:00, 541.34batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0266 | CLS Loss: 0.0247 | Align Loss (raw): 0.0039\n",
      "  Val Acc: 0.9915 | F1: 0.9914 | Recall: 0.9866 | Precision: 0.9963\n",
      "  Model not saved. Best F1 so far: 0.9920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 162.64batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26/50\n",
      "  Train Total Loss: 0.0229 | CLS Loss: 0.0216 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9924 | F1: 0.9924 | Recall: 0.9918 | Precision: 0.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 733.49batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0237 | CLS Loss: 0.0226 | Align Loss (raw): 0.0023\n",
      "  Val Acc: 0.9921 | F1: 0.9921 | Recall: 0.9887 | Precision: 0.9955\n",
      "  New Best F1: 0.9921\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 165.51batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27/50\n",
      "  Train Total Loss: 0.0230 | CLS Loss: 0.0218 | Align Loss (raw): 0.0025\n",
      "  Train Acc: 0.9926 | F1: 0.9925 | Recall: 0.9920 | Precision: 0.9931\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 721.75batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0249 | CLS Loss: 0.0239 | Align Loss (raw): 0.0020\n",
      "  Val Acc: 0.9915 | F1: 0.9915 | Recall: 0.9882 | Precision: 0.9949\n",
      "  Model not saved. Best F1 so far: 0.9921\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 167.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28/50\n",
      "  Train Total Loss: 0.0211 | CLS Loss: 0.0198 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9930 | F1: 0.9930 | Recall: 0.9926 | Precision: 0.9933\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 675.86batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0242 | CLS Loss: 0.0231 | Align Loss (raw): 0.0022\n",
      "  Val Acc: 0.9918 | F1: 0.9919 | Recall: 0.9928 | Precision: 0.9909\n",
      "  Model not saved. Best F1 so far: 0.9921\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 169.85batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29/50\n",
      "  Train Total Loss: 0.0221 | CLS Loss: 0.0209 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9928 | F1: 0.9928 | Recall: 0.9925 | Precision: 0.9931\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 703.97batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0239 | CLS Loss: 0.0227 | Align Loss (raw): 0.0025\n",
      "  Val Acc: 0.9914 | F1: 0.9914 | Recall: 0.9892 | Precision: 0.9936\n",
      "  Model not saved. Best F1 so far: 0.9921\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 164.47batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30/50\n",
      "  Train Total Loss: 0.0208 | CLS Loss: 0.0195 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9933 | F1: 0.9932 | Recall: 0.9930 | Precision: 0.9935\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 701.74batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0230 | CLS Loss: 0.0218 | Align Loss (raw): 0.0025\n",
      "  Val Acc: 0.9928 | F1: 0.9928 | Recall: 0.9905 | Precision: 0.9950\n",
      "  New Best F1: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.71batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 31/50\n",
      "  Train Total Loss: 0.0205 | CLS Loss: 0.0192 | Align Loss (raw): 0.0026\n",
      "  Train Acc: 0.9932 | F1: 0.9931 | Recall: 0.9929 | Precision: 0.9934\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 657.09batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0225 | CLS Loss: 0.0214 | Align Loss (raw): 0.0023\n",
      "  Val Acc: 0.9925 | F1: 0.9925 | Recall: 0.9944 | Precision: 0.9906\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 165.42batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 32/50\n",
      "  Train Total Loss: 0.0204 | CLS Loss: 0.0191 | Align Loss (raw): 0.0025\n",
      "  Train Acc: 0.9934 | F1: 0.9933 | Recall: 0.9930 | Precision: 0.9937\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 653.18batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0304 | CLS Loss: 0.0293 | Align Loss (raw): 0.0022\n",
      "  Val Acc: 0.9903 | F1: 0.9903 | Recall: 0.9844 | Precision: 0.9962\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.01batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 33/50\n",
      "  Train Total Loss: 0.0189 | CLS Loss: 0.0176 | Align Loss (raw): 0.0025\n",
      "  Train Acc: 0.9942 | F1: 0.9942 | Recall: 0.9942 | Precision: 0.9942\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 621.40batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0240 | CLS Loss: 0.0223 | Align Loss (raw): 0.0034\n",
      "  Val Acc: 0.9917 | F1: 0.9917 | Recall: 0.9925 | Precision: 0.9910\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 168.41batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 34/50\n",
      "  Train Total Loss: 0.0188 | CLS Loss: 0.0176 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9940 | F1: 0.9940 | Recall: 0.9937 | Precision: 0.9943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 743.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0235 | CLS Loss: 0.0223 | Align Loss (raw): 0.0024\n",
      "  Val Acc: 0.9924 | F1: 0.9924 | Recall: 0.9899 | Precision: 0.9949\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 161.79batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 35/50\n",
      "  Train Total Loss: 0.0180 | CLS Loss: 0.0167 | Align Loss (raw): 0.0025\n",
      "  Train Acc: 0.9943 | F1: 0.9943 | Recall: 0.9939 | Precision: 0.9946\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 628.79batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0245 | CLS Loss: 0.0230 | Align Loss (raw): 0.0028\n",
      "  Val Acc: 0.9917 | F1: 0.9917 | Recall: 0.9927 | Precision: 0.9908\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36/50 Training: 100%|██████████| 4809/4809 [00:30<00:00, 159.26batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 36/50\n",
      "  Train Total Loss: 0.0181 | CLS Loss: 0.0169 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9942 | F1: 0.9941 | Recall: 0.9941 | Precision: 0.9942\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 643.90batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0438 | CLS Loss: 0.0423 | Align Loss (raw): 0.0031\n",
      "  Val Acc: 0.9861 | F1: 0.9863 | Recall: 0.9982 | Precision: 0.9746\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 167.30batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 37/50\n",
      "  Train Total Loss: 0.0179 | CLS Loss: 0.0167 | Align Loss (raw): 0.0023\n",
      "  Train Acc: 0.9946 | F1: 0.9946 | Recall: 0.9943 | Precision: 0.9948\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 683.80batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0249 | CLS Loss: 0.0236 | Align Loss (raw): 0.0027\n",
      "  Val Acc: 0.9920 | F1: 0.9921 | Recall: 0.9935 | Precision: 0.9907\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.39batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 38/50\n",
      "  Train Total Loss: 0.0171 | CLS Loss: 0.0160 | Align Loss (raw): 0.0023\n",
      "  Train Acc: 0.9945 | F1: 0.9945 | Recall: 0.9941 | Precision: 0.9949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 667.11batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0287 | CLS Loss: 0.0279 | Align Loss (raw): 0.0017\n",
      "  Val Acc: 0.9906 | F1: 0.9906 | Recall: 0.9843 | Precision: 0.9969\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 169.11batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 39/50\n",
      "  Train Total Loss: 0.0179 | CLS Loss: 0.0167 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9941 | F1: 0.9940 | Recall: 0.9935 | Precision: 0.9946\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 664.85batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0411 | CLS Loss: 0.0395 | Align Loss (raw): 0.0032\n",
      "  Val Acc: 0.9879 | F1: 0.9881 | Recall: 0.9982 | Precision: 0.9781\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 165.28batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 40/50\n",
      "  Train Total Loss: 0.0164 | CLS Loss: 0.0151 | Align Loss (raw): 0.0025\n",
      "  Train Acc: 0.9947 | F1: 0.9947 | Recall: 0.9942 | Precision: 0.9952\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 678.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0303 | CLS Loss: 0.0291 | Align Loss (raw): 0.0024\n",
      "  Val Acc: 0.9918 | F1: 0.9917 | Recall: 0.9863 | Precision: 0.9973\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 166.47batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 41/50\n",
      "  Train Total Loss: 0.0173 | CLS Loss: 0.0161 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9943 | F1: 0.9943 | Recall: 0.9940 | Precision: 0.9946\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 715.50batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0313 | CLS Loss: 0.0302 | Align Loss (raw): 0.0021\n",
      "  Val Acc: 0.9900 | F1: 0.9899 | Recall: 0.9824 | Precision: 0.9975\n",
      "  Model not saved. Best F1 so far: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 166.30batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 42/50\n",
      "  Train Total Loss: 0.0171 | CLS Loss: 0.0159 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9945 | F1: 0.9945 | Recall: 0.9939 | Precision: 0.9950\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 664.51batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0196 | CLS Loss: 0.0186 | Align Loss (raw): 0.0019\n",
      "  Val Acc: 0.9939 | F1: 0.9939 | Recall: 0.9923 | Precision: 0.9954\n",
      "  New Best F1: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 164.90batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 43/50\n",
      "  Train Total Loss: 0.0157 | CLS Loss: 0.0146 | Align Loss (raw): 0.0022\n",
      "  Train Acc: 0.9948 | F1: 0.9948 | Recall: 0.9945 | Precision: 0.9951\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 692.08batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0214 | CLS Loss: 0.0203 | Align Loss (raw): 0.0022\n",
      "  Val Acc: 0.9925 | F1: 0.9925 | Recall: 0.9941 | Precision: 0.9909\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 170.39batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44/50\n",
      "  Train Total Loss: 0.0165 | CLS Loss: 0.0153 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9946 | F1: 0.9945 | Recall: 0.9942 | Precision: 0.9948\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 678.23batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0266 | CLS Loss: 0.0255 | Align Loss (raw): 0.0021\n",
      "  Val Acc: 0.9923 | F1: 0.9923 | Recall: 0.9953 | Precision: 0.9893\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 163.78batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 45/50\n",
      "  Train Total Loss: 0.0154 | CLS Loss: 0.0142 | Align Loss (raw): 0.0024\n",
      "  Train Acc: 0.9950 | F1: 0.9949 | Recall: 0.9950 | Precision: 0.9949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 672.57batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0227 | CLS Loss: 0.0216 | Align Loss (raw): 0.0021\n",
      "  Val Acc: 0.9928 | F1: 0.9928 | Recall: 0.9901 | Precision: 0.9954\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 164.94batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 46/50\n",
      "  Train Total Loss: 0.0149 | CLS Loss: 0.0138 | Align Loss (raw): 0.0022\n",
      "  Train Acc: 0.9952 | F1: 0.9952 | Recall: 0.9949 | Precision: 0.9956\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 631.79batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0415 | CLS Loss: 0.0405 | Align Loss (raw): 0.0018\n",
      "  Val Acc: 0.9871 | F1: 0.9869 | Recall: 0.9757 | Precision: 0.9984\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 161.08batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 47/50\n",
      "  Train Total Loss: 0.0163 | CLS Loss: 0.0151 | Align Loss (raw): 0.0023\n",
      "  Train Acc: 0.9944 | F1: 0.9944 | Recall: 0.9938 | Precision: 0.9950\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 630.97batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0333 | CLS Loss: 0.0318 | Align Loss (raw): 0.0030\n",
      "  Val Acc: 0.9903 | F1: 0.9904 | Recall: 0.9977 | Precision: 0.9831\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 162.27batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 48/50\n",
      "  Train Total Loss: 0.0151 | CLS Loss: 0.0139 | Align Loss (raw): 0.0023\n",
      "  Train Acc: 0.9953 | F1: 0.9953 | Recall: 0.9953 | Precision: 0.9953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48/50 Validation: 100%|██████████| 1203/1203 [00:02<00:00, 573.80batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0217 | CLS Loss: 0.0208 | Align Loss (raw): 0.0019\n",
      "  Val Acc: 0.9929 | F1: 0.9929 | Recall: 0.9900 | Precision: 0.9957\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49/50 Training: 100%|██████████| 4809/4809 [00:28<00:00, 168.77batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 49/50\n",
      "  Train Total Loss: 0.0150 | CLS Loss: 0.0138 | Align Loss (raw): 0.0023\n",
      "  Train Acc: 0.9954 | F1: 0.9954 | Recall: 0.9952 | Precision: 0.9956\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 655.62batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0249 | CLS Loss: 0.0237 | Align Loss (raw): 0.0024\n",
      "  Val Acc: 0.9915 | F1: 0.9915 | Recall: 0.9876 | Precision: 0.9954\n",
      "  Model not saved. Best F1 so far: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50/50 Training: 100%|██████████| 4809/4809 [00:29<00:00, 164.68batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50/50\n",
      "  Train Total Loss: 0.0152 | CLS Loss: 0.0141 | Align Loss (raw): 0.0023\n",
      "  Train Acc: 0.9951 | F1: 0.9951 | Recall: 0.9949 | Precision: 0.9953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50/50 Validation: 100%|██████████| 1203/1203 [00:01<00:00, 643.18batch/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0251 | CLS Loss: 0.0242 | Align Loss (raw): 0.0017\n",
      "  Val Acc: 0.9923 | F1: 0.9923 | Recall: 0.9875 | Precision: 0.9971\n",
      "  Model not saved. Best F1 so far: 0.9939\n",
      "Training Finished.\n",
      "Best Validation Accuracy achieved: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
    "from torch.optim import Adam , AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "from sklearn.metrics import roc_curve\n",
    "import numpy as np\n",
    "\n",
    "def compute_eer(y_true, y_scores):\n",
    "    \"\"\"Return the equal error rate (EER) for binary labels and their scores.\n",
    "\n",
    "    The ROC curve is built with ``roc_curve``; the EER is taken at the curve\n",
    "    point where the false-positive and false-negative rates are closest, as\n",
    "    the mean of those two rates.\n",
    "    \"\"\"\n",
    "    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_scores)\n",
    "    false_neg_rate = 1 - true_pos_rate\n",
    "    # Index of the operating point where FPR and FNR are nearest to each other.\n",
    "    closest = np.nanargmin(np.abs(false_neg_rate - false_pos_rate))\n",
    "    return (false_pos_rate[closest] + false_neg_rate[closest]) / 2\n",
    "\n",
    "\n",
    "# NOTE: this cell rebinds `train_dataset` to its training subset, so re-running it\n",
    "# without rebuilding `train_dataset` first splits the already-split subset again.\n",
    "val_split = 0.2\n",
    "train_size = int((1 - val_split) * len(train_dataset))\n",
    "val_size = len(train_dataset) - train_size\n",
    "# Fixed generator seed so the train/validation partition is the same on every run.\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(\n",
    "    train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)\n",
    ")\n",
    "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
    "# Evaluation loaders keep a fixed order: shuffling adds nothing to the metrics and\n",
    "# only makes the per-batch loss averages vary from run to run.\n",
    "val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Constructor arguments for the two input branches of CLAM\n",
    "# (presumably channel count and embedding width of each feature set -- confirm in models.clam).\n",
    "in_channel1 = 13\n",
    "in_channel2 = 13\n",
    "embed_dim1 = 768\n",
    "embed_dim2 = 768\n",
    "# Best validation scores seen so far; updated by the training loop.\n",
    "best_acc = 0\n",
    "best_f1 = 0\n",
    "\n",
    "model = CLAM(in_channel1 , in_channel2 , embed_dim1 , embed_dim2).to(device)\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "\n",
    "classification_criterion = nn.BCEWithLogitsLoss() # For fake/real classification\n",
    "alignment_criterion = nn.MSELoss()             # For aligning embeddings (real vs fake)\n",
    "# Checkpoint location used when a new best validation F1 is reached.\n",
    "save_name = \"best_model_mse_alignment.pt\"\n",
    "save_folder = \"model_wts\"\n",
    "# --- Hyperparameters ---\n",
    "epochs = 50\n",
    "alignment_loss_weight = 0.5 # Weight factor for the alignment loss \n",
    "print(\"Starting Training...\")\n",
    "\n",
    "\n",
    "def _alignment_loss(real_emb, fake_emb, labels):\n",
    "    \"\"\"Return the alignment loss between the two embeddings on real samples.\n",
    "\n",
    "    Only samples with label 0 (the ones this notebook treats as real) take part.\n",
    "    When the batch contains none, a zero scalar on the labels' device is returned\n",
    "    so the result can always be added to the classification loss.\n",
    "    \"\"\"\n",
    "    real_mask = labels == 0\n",
    "    if not real_mask.any():\n",
    "        return torch.zeros((), device=labels.device)\n",
    "    return alignment_criterion(real_emb[real_mask], fake_emb[real_mask])\n",
    "\n",
    "\n",
    "# torch.save does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    # --- Training ---\n",
    "    model.train()\n",
    "    train_loss_total = 0\n",
    "    train_loss_cls = 0\n",
    "    train_loss_align = 0\n",
    "    train_preds_all = []\n",
    "    train_labels_all = []\n",
    "\n",
    "    for batch in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{epochs} Training\", unit=\"batch\"):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        embed1, embed2, labels = batch\n",
    "        embed1 = embed1.to(device)\n",
    "        embed2 = embed2.to(device)\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        # --- Forward pass ---\n",
    "        outputs, real_emb, fake_emb = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)\n",
    "\n",
    "        # Classification loss uses every sample; alignment loss only the real ones.\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "        alignment_loss = _alignment_loss(real_emb, fake_emb, labels)\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "        # --- Backward pass and optimization ---\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # --- Accumulate metrics and losses for reporting ---\n",
    "        train_loss_total += total_loss.item()\n",
    "        train_loss_cls += classification_loss.item()\n",
    "        train_loss_align += alignment_loss.item() # raw alignment loss, before weighting\n",
    "\n",
    "        preds = torch.sigmoid(outputs).detach().round() # Get predictions (0 or 1)\n",
    "        train_preds_all.extend(preds.cpu().numpy())\n",
    "        train_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "    # --- Calculate Epoch Metrics (Training) ---\n",
    "    # Losses are averaged over all batches; a batch without real samples adds 0\n",
    "    # to the alignment sum.\n",
    "    num_batches = len(train_loader)\n",
    "    avg_train_loss_total = train_loss_total / num_batches\n",
    "    avg_train_loss_cls = train_loss_cls / num_batches\n",
    "    avg_train_loss_align = train_loss_align / num_batches\n",
    "\n",
    "    train_acc = accuracy_score(train_labels_all, train_preds_all)\n",
    "    train_f1 = f1_score(train_labels_all, train_preds_all, average='binary', zero_division=0)\n",
    "    train_recall = recall_score(train_labels_all, train_preds_all, average='binary', zero_division=0)\n",
    "    train_precision = precision_score(train_labels_all, train_preds_all, average='binary', zero_division=0)\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{epochs}\")\n",
    "    print(f\"  Train Total Loss: {avg_train_loss_total:.4f} | CLS Loss: {avg_train_loss_cls:.4f} | Align Loss (raw): {avg_train_loss_align:.4f}\")\n",
    "    print(f\"  Train Acc: {train_acc:.4f} | F1: {train_f1:.4f} | Recall: {train_recall:.4f} | Precision: {train_precision:.4f}\")\n",
    "\n",
    "\n",
    "    # --- Validation ---\n",
    "    model.eval()\n",
    "    val_loss_total = 0\n",
    "    val_loss_cls = 0\n",
    "    val_loss_align = 0\n",
    "    val_preds_all = []\n",
    "    val_labels_all = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(val_loader, desc=f\"Epoch {epoch+1}/{epochs} Validation\", unit=\"batch\"):\n",
    "            embed1, embed2, labels = batch\n",
    "            embed1 = embed1.to(device)\n",
    "            embed2 = embed2.to(device)\n",
    "            labels = labels.float().to(device).view(-1)\n",
    "\n",
    "            # --- Forward pass (validation) ---\n",
    "            outputs, real_emb, fake_emb = model.forward_training(embed1, embed2)\n",
    "            outputs = outputs.view(-1)\n",
    "\n",
    "            # Same loss composition as in training, without gradient tracking.\n",
    "            classification_loss = classification_criterion(outputs, labels)\n",
    "            alignment_loss = _alignment_loss(real_emb, fake_emb, labels)\n",
    "            total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "            val_loss_total += total_loss.item()\n",
    "            val_loss_cls += classification_loss.item()\n",
    "            val_loss_align += alignment_loss.item()\n",
    "\n",
    "            preds = torch.sigmoid(outputs).detach().round()\n",
    "            val_preds_all.extend(preds.cpu().numpy())\n",
    "            val_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "    num_batches_val = len(val_loader)\n",
    "    avg_val_loss_total = val_loss_total / num_batches_val\n",
    "    avg_val_loss_cls = val_loss_cls / num_batches_val\n",
    "    avg_val_loss_align = val_loss_align / num_batches_val\n",
    "\n",
    "    val_acc = accuracy_score(val_labels_all, val_preds_all)\n",
    "    val_f1 = f1_score(val_labels_all, val_preds_all, average='binary', zero_division=0)\n",
    "    val_recall = recall_score(val_labels_all, val_preds_all, average='binary', zero_division=0)\n",
    "    val_precision = precision_score(val_labels_all, val_preds_all, average='binary', zero_division=0)\n",
    "\n",
    "    print(f\"  Val Total Loss: {avg_val_loss_total:.4f} | CLS Loss: {avg_val_loss_cls:.4f} | Align Loss (raw): {avg_val_loss_align:.4f}\")\n",
    "    print(f\"  Val Acc: {val_acc:.4f} | F1: {val_f1:.4f} | Recall: {val_recall:.4f} | Precision: {val_precision:.4f}\")\n",
    "\n",
    "    # Track the best validation accuracy for the final report. It is independent of\n",
    "    # checkpointing, which is driven by F1 below.\n",
    "    best_acc = max(best_acc, val_acc)\n",
    "\n",
    "    # Checkpoint only when the validation F1 improves.\n",
    "    if val_f1 > best_f1:\n",
    "        best_f1 = val_f1\n",
    "        print(f\"  New Best F1: {best_f1:.4f}\")\n",
    "        torch.save(model.state_dict(), os.path.join(save_folder, save_name))\n",
    "    else:\n",
    "        print(f\"  Model not saved. Best F1 so far: {best_f1:.4f}\")\n",
    "\n",
    "\n",
    "\n",
    "print(\"Training Finished.\")\n",
    "print(f\"Best Validation Accuracy achieved: {best_acc:.4f}\")\n",
    "print(f\"Best Validation F1 achieved (saved checkpoint): {best_f1:.4f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "2426b781",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing: 100%|██████████| 2139/2139 [00:05<00:00, 362.92batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Total Loss: 0.4141 | CLS Loss: 0.3729 | Align Loss (raw): 0.0824\n",
      "Test Acc: 0.9163 | F1: 0.9088 | Recall: 0.8374 | Precision: 0.9937\n",
      "EER: 0.0649\n"
     ]
    }
   ],
   "source": [
    "#On test set\n",
    "# Restore the checkpoint saved at the best validation F1. map_location keeps the\n",
    "# load working when the weights were saved on a different device than the current one.\n",
    "model.load_state_dict(torch.load(os.path.join(save_folder, save_name), map_location=device))\n",
    "model.eval()\n",
    "test_loss_total = 0\n",
    "test_loss_cls = 0\n",
    "test_loss_align = 0\n",
    "test_preds_all = []\n",
    "test_labels_all = []\n",
    "test_scores_all = []\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(test_loader, desc=\"Testing\", unit=\"batch\"):\n",
    "        embed1, embed2, labels = batch\n",
    "        embed1 = embed1.to(device)\n",
    "        embed2 = embed2.to(device)\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        # --- Forward pass (testing) ---\n",
    "        outputs , real_emb, fake_emb  = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)\n",
    "\n",
    "        # --- Calculate Classification Loss (Testing) ---\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "\n",
    "        # Real samples carry label 0, the same selection the training and validation\n",
    "        # loops use; selecting label 1 here would measure alignment on the other class.\n",
    "        real_indices = (labels == 0).nonzero(as_tuple=True)[0]\n",
    "        alignment_loss = torch.tensor(0.0).to(device)\n",
    "        if real_indices.nelement() > 0:\n",
    "            real_emb_filtered = real_emb[real_indices]\n",
    "            fake_emb_filtered = fake_emb[real_indices]\n",
    "            alignment_loss = alignment_criterion(real_emb_filtered, fake_emb_filtered)\n",
    "\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "        test_loss_total += total_loss.item()\n",
    "        test_loss_cls += classification_loss.item()\n",
    "        test_loss_align += alignment_loss.item()\n",
    "\n",
    "        # Probabilities feed the EER; their rounded values are the hard predictions.\n",
    "        scores = torch.sigmoid(outputs)\n",
    "        test_preds_all.extend(scores.round().cpu().numpy())\n",
    "        test_labels_all.extend(labels.cpu().numpy())\n",
    "        test_scores_all.extend(scores.cpu().numpy())\n",
    "\n",
    "# Losses are averaged over all batches; the metrics below work on plain NumPy values,\n",
    "# so they do not need the no_grad context.\n",
    "num_batches_test = len(test_loader)\n",
    "avg_test_loss_total = test_loss_total / num_batches_test\n",
    "avg_test_loss_cls = test_loss_cls / num_batches_test\n",
    "avg_test_loss_align = test_loss_align / num_batches_test\n",
    "test_acc = accuracy_score(test_labels_all, test_preds_all)\n",
    "test_f1 = f1_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "test_recall = recall_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "test_precision = precision_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "eer = compute_eer(test_labels_all, test_scores_all)  # Calculate EER using the raw scores\n",
    "print(f\"Test Total Loss: {avg_test_loss_total:.4f} | CLS Loss: {avg_test_loss_cls:.4f} | Align Loss (raw): {avg_test_loss_align:.4f}\")\n",
    "print(f\"Test Acc: {test_acc:.4f} | F1: {test_f1:.4f} | Recall: {test_recall:.4f} | Precision: {test_precision:.4f}\")\n",
    "print(f\"EER: {eer:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
