{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9742f0d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import dataset\n",
    "from models.clam import CLAM\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1a3624ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "65450\n",
      "65450\n",
      "17504\n",
      "17504\n",
      "17504\n"
     ]
    }
   ],
   "source": [
    "# Real dataset: metadata CSVs and the folders that hold their .pt embedding files.\n",
    "real_df_1 = pd.read_csv(\"real_songs_2.csv\")\n",
    "real_df_2 = pd.read_csv(\"real_yt_covers.csv\")\n",
    "\n",
    "# MERT embedding folders\n",
    "real_mert_1, real_mert_2 = \"real_songs_mert\", \"yt_covers_mert\"\n",
    "# wav2vec2 embedding folders\n",
    "real_wav2vec2_1, real_wav2vec2_2 = \"real_songs_wav2vec2\", \"yt_covers_wav2vec2\"\n",
    "\n",
    "# For each source: CSV row count, then the entry count of its MERT and wav2vec2 folders.\n",
    "print(len(real_df_1))\n",
    "for _folder in (real_mert_1, real_wav2vec2_1):\n",
    "    print(len(os.listdir(_folder)))\n",
    "\n",
    "print(len(real_df_2))\n",
    "for _folder in (real_mert_2, real_wav2vec2_2):\n",
    "    print(len(os.listdir(_folder)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7baef53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Metadata for the AI-generated tracks.\n",
    "fake_df = pd.read_csv(\"ai_generated_music_metadata.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "836f41c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>model_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ai_covers_0.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ai_covers_1.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ai_covers_2.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ai_covers_3.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ai_covers_4.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          filename model_name\n",
       "0  ai_covers_0.mp3  AI_COVERS\n",
       "1  ai_covers_1.mp3  AI_COVERS\n",
       "2  ai_covers_2.mp3  AI_COVERS\n",
       "3  ai_covers_3.mp3  AI_COVERS\n",
       "4  ai_covers_4.mp3  AI_COVERS"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the AI-generated metadata.\n",
    "fake_df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e23ec37b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64972\n",
      "64958\n",
      "64958\n"
     ]
    }
   ],
   "source": [
    "# Fake dataset: metadata CSV and the folders that hold its MERT / wav2vec2 .pt files.\n",
    "fake_df = pd.read_csv(\"ai_generated_music_metadata.csv\")\n",
    "fake_mert, fake_wav2vec2 = \"ai_generated_music_mert\", \"ai_generated_music_wav2vec2\"\n",
    "\n",
    "# CSV row count, then the entry count of each embedding folder.\n",
    "print(len(fake_df))\n",
    "for _folder in (fake_mert, fake_wav2vec2):\n",
    "    print(len(os.listdir(_folder)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8c40f0b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 47945/47945 [00:16<00:00, 2964.25it/s]\n",
      "100%|██████████| 17504/17504 [00:02<00:00, 8556.84it/s]\n",
      "100%|██████████| 64972/64972 [00:21<00:00, 2978.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "17504\n",
      "64958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def remove_filename(df, folder):\n",
    "    \"\"\"Return a copy of `df` restricted to rows whose `.pt` file exists in `folder`.\n",
    "\n",
    "    A trailing `.mp3` and then a trailing `.wav` are stripped from each row's\n",
    "    `filename`; the row is kept when `<stem>.pt` is listed in `folder`.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with a string `filename` column.\n",
    "        folder: Directory whose entries are compared against `<stem>.pt`.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with the surviving rows and their original index labels;\n",
    "        `df` itself is left untouched.\n",
    "    \"\"\"\n",
    "    # Set membership is O(1); the earlier list lookup plus one `drop` per missing\n",
    "    # row was quadratic, and label-based `drop` also removed every row sharing a\n",
    "    # duplicated index label instead of just the offending one.\n",
    "    available = set(os.listdir(folder))\n",
    "\n",
    "    def has_embedding(filename):\n",
    "        # Same stripping order as before: `.mp3` first, then `.wav`.\n",
    "        for ext in ('.mp3', '.wav'):\n",
    "            if filename.endswith(ext):\n",
    "                filename = filename[:-len(ext)]\n",
    "        return filename + '.pt' in available\n",
    "\n",
    "    # Positional boolean mask, so duplicated index labels cannot cause misalignment.\n",
    "    keep = df['filename'].map(has_embedding).astype(bool).to_numpy()\n",
    "    return df.loc[keep].copy()\n",
    "\n",
    "\n",
    "# Keep only the metadata rows that have a matching file in their MERT folder.\n",
    "real_df_1, real_df_2, fake_df = (\n",
    "    remove_filename(frame, mert_dir)\n",
    "    for frame, mert_dir in (\n",
    "        (real_df_1, real_mert_1),\n",
    "        (real_df_2, real_mert_2),\n",
    "        (fake_df, fake_mert),\n",
    "    )\n",
    ")\n",
    "print(len(real_df_1), len(real_df_2), len(fake_df), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6e654958",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17504\n"
     ]
    }
   ],
   "source": [
    "# Row count of the yt-covers metadata frame.\n",
    "print(real_df_2.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b9facf9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65449\n"
     ]
    }
   ],
   "source": [
    "# Stack both real-music metadata frames; original row labels are kept, so labels may repeat.\n",
    "real_df = pd.concat((real_df_1, real_df_2), ignore_index=False)\n",
    "print(real_df.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2b17865a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "model_name\n",
       "suno_3.5     23695\n",
       "Udio_1.5     19500\n",
       "riffusion     7043\n",
       "Yue           5278\n",
       "Diffrythm     4606\n",
       "suno_3        3512\n",
       "AI_COVERS     1166\n",
       "suno_2         110\n",
       "suno_4          48\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of fake tracks per value of `model_name`, most frequent first.\n",
    "fake_df[\"model_name\"].value_counts(sort=True, ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "901d4ef6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17047\n",
      "47911\n",
      "0.7375688906678162\n",
      "48273\n",
      "17176\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "# Models whose tracks are kept out of the training frame and used only for the test frame.\n",
    "held_out_models = [\"riffusion\", \"suno_3\", \"AI_COVERS\", \"suno_4\", \"Yue\"]\n",
    "is_held_out = fake_df[\"model_name\"].isin(held_out_models)\n",
    "\n",
    "fake_df_test = fake_df[is_held_out]\n",
    "print(len(fake_df_test))\n",
    "fake_df_train = fake_df[~is_held_out]\n",
    "print(len(fake_df_train))\n",
    "\n",
    "# Share of fake tracks that fall in the training frame; the real data is split in the same proportion.\n",
    "ratio = len(fake_df_train) / (len(fake_df_train) + len(fake_df_test))\n",
    "print(ratio)\n",
    "\n",
    "real_df = pd.concat((real_df_1, real_df_2))\n",
    "\n",
    "real_df_train, real_df_test = train_test_split(\n",
    "    real_df,\n",
    "    test_size=1 - ratio,\n",
    "    random_state=42,\n",
    ")\n",
    "print(len(real_df_train))\n",
    "print(len(real_df_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "af6715f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preloading embeddings for train split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27673df5feaf4e9bbff738cd838f60d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading train data to CPU:   0%|          | 0/48273 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/run/media/arnesh/MySSD/dataset/utils/dataset.py:32: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  mert = torch.load(mert_path, map_location=torch.device('cpu'))\n",
      "/run/media/arnesh/MySSD/dataset/utils/dataset.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  wav2vec2 = torch.load(wav2vec2_path, map_location=torch.device('cpu'))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_8382.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_5219.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_11880.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_9557.pt: Ran out of input. Skipping sample.\n",
      "PytorchStreamReader failed locating file data/0: file not found\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_11542.pt: PytorchStreamReader failed locating file data/0: file not found. Skipping sample.\n",
      "Preloading embeddings for train split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4925c768f7aa45ccafab2b0daa109d81",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading train data to CPU:   0%|          | 0/47911 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for diffrythm_566.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_122.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_278.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_2113.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_11812.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_18613.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_23683.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_607.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_7188.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_18205.pt: Ran out of input. Skipping sample.\n",
      "Preloading embeddings for test split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5170ba31ea0944e2830280766d611091",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading test data to CPU:   0%|          | 0/17176 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_13185.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_9011.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_10928.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_7006.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_5289.pt: Ran out of input. Skipping sample.\n",
      "PytorchStreamReader failed locating file data/0: file not found\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_7797.pt: PytorchStreamReader failed locating file data/0: file not found. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_15133.pt: Ran out of input. Skipping sample.\n",
      "Preloading embeddings for test split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7a5d64ae1ac4343975041bc322715c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading test data to CPU:   0%|          | 0/17047 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_1802.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_2149.pt: Ran out of input. Skipping sample.\n"
     ]
    }
   ],
   "source": [
    "from utils import dataset\n",
    "\n",
    "# Real tracks are given label 0 and AI-generated tracks label 1.\n",
    "# NOTE(review): real_df_train / real_df_test also contain yt_covers rows, yet only the real_songs\n",
    "# folders (real_mert_1 / real_wav2vec2_1) are passed - presumably those folders hold both sets; confirm.\n",
    "train_real_dataset = dataset.SongsDataset(real_df_train, real_mert_1, real_wav2vec2_1, label=0)\n",
    "train_fake_dataset = dataset.SongsDataset(fake_df_train, fake_mert, fake_wav2vec2, label=1)\n",
    "\n",
    "test_real_dataset = dataset.SongsDataset(\n",
    "    real_df_test, real_mert_1, real_wav2vec2_1, label=0, split=\"test\"\n",
    ")\n",
    "test_fake_dataset = dataset.SongsDataset(\n",
    "    fake_df_test, fake_mert, fake_wav2vec2, label=1, split=\"test\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "296bb56e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48268\n",
      "47901\n",
      "17169\n",
      "17045\n"
     ]
    }
   ],
   "source": [
    "# Sizes of the four datasets: train real, train fake, test real, test fake.\n",
    "print(\n",
    "    len(train_real_dataset),\n",
    "    len(train_fake_dataset),\n",
    "    len(test_real_dataset),\n",
    "    len(test_fake_dataset),\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "738fc507",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "96169\n",
      "34214\n"
     ]
    }
   ],
   "source": [
    "# Merge the real and fake datasets of each split into a single dataset.\n",
    "train_dataset = torch.utils.data.ConcatDataset((train_real_dataset, train_fake_dataset))\n",
    "test_dataset = torch.utils.data.ConcatDataset((test_real_dataset, test_fake_dataset))\n",
    "print(len(train_dataset), len(test_dataset), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "359ca8a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 154.15batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50\n",
      "  Train Total Loss: 0.4242 | CLS Loss: 0.4176 | Align Loss (raw): 0.0132\n",
      "  Train Acc: 0.7721 | F1: 0.7606 | Recall: 0.7326 | Precision: 0.7907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 570.69batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.3361 | CLS Loss: 0.3312 | Align Loss (raw): 0.0098\n",
      "  Val Acc: 0.8404 | F1: 0.8590 | Recall: 0.9785 | Precision: 0.7654\n",
      "  New Best F1: 0.8590. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 158.68batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/50\n",
      "  Train Total Loss: 0.2112 | CLS Loss: 0.2088 | Align Loss (raw): 0.0049\n",
      "  Train Acc: 0.9095 | F1: 0.9074 | Recall: 0.8981 | Precision: 0.9170\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 563.82batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1824 | CLS Loss: 0.1800 | Align Loss (raw): 0.0048\n",
      "  Val Acc: 0.9292 | F1: 0.9313 | Recall: 0.9665 | Precision: 0.8985\n",
      "  New Best F1: 0.9313. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 156.41batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/50\n",
      "  Train Total Loss: 0.1481 | CLS Loss: 0.1461 | Align Loss (raw): 0.0040\n",
      "  Train Acc: 0.9452 | F1: 0.9443 | Recall: 0.9403 | Precision: 0.9484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 606.43batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1325 | CLS Loss: 0.1302 | Align Loss (raw): 0.0045\n",
      "  Val Acc: 0.9537 | F1: 0.9543 | Recall: 0.9744 | Precision: 0.9350\n",
      "  New Best F1: 0.9543. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.34batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/50\n",
      "  Train Total Loss: 0.1137 | CLS Loss: 0.1118 | Align Loss (raw): 0.0037\n",
      "  Train Acc: 0.9590 | F1: 0.9584 | Recall: 0.9554 | Precision: 0.9614\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 602.23batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1036 | CLS Loss: 0.1029 | Align Loss (raw): 0.0016\n",
      "  Val Acc: 0.9636 | F1: 0.9627 | Recall: 0.9466 | Precision: 0.9794\n",
      "  New Best F1: 0.9627. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.21batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/50\n",
      "  Train Total Loss: 0.0921 | CLS Loss: 0.0906 | Align Loss (raw): 0.0030\n",
      "  Train Acc: 0.9673 | F1: 0.9669 | Recall: 0.9665 | Precision: 0.9673\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 645.90batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0941 | CLS Loss: 0.0935 | Align Loss (raw): 0.0013\n",
      "  Val Acc: 0.9680 | F1: 0.9672 | Recall: 0.9489 | Precision: 0.9862\n",
      "  New Best F1: 0.9672. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 157.64batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/50\n",
      "  Train Total Loss: 0.0787 | CLS Loss: 0.0776 | Align Loss (raw): 0.0022\n",
      "  Train Acc: 0.9722 | F1: 0.9719 | Recall: 0.9715 | Precision: 0.9723\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 641.71batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0790 | CLS Loss: 0.0781 | Align Loss (raw): 0.0018\n",
      "  Val Acc: 0.9737 | F1: 0.9734 | Recall: 0.9668 | Precision: 0.9800\n",
      "  New Best F1: 0.9734. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 162.11batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/50\n",
      "  Train Total Loss: 0.0733 | CLS Loss: 0.0719 | Align Loss (raw): 0.0027\n",
      "  Train Acc: 0.9745 | F1: 0.9742 | Recall: 0.9735 | Precision: 0.9749\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 629.78batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0758 | CLS Loss: 0.0744 | Align Loss (raw): 0.0029\n",
      "  Val Acc: 0.9764 | F1: 0.9762 | Recall: 0.9734 | Precision: 0.9789\n",
      "  New Best F1: 0.9762. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 158.56batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/50\n",
      "  Train Total Loss: 0.0648 | CLS Loss: 0.0639 | Align Loss (raw): 0.0017\n",
      "  Train Acc: 0.9772 | F1: 0.9769 | Recall: 0.9768 | Precision: 0.9770\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 615.72batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0783 | CLS Loss: 0.0759 | Align Loss (raw): 0.0048\n",
      "  Val Acc: 0.9764 | F1: 0.9762 | Recall: 0.9770 | Precision: 0.9755\n",
      "  New Best F1: 0.9762. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 158.34batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/50\n",
      "  Train Total Loss: 0.0603 | CLS Loss: 0.0593 | Align Loss (raw): 0.0019\n",
      "  Train Acc: 0.9792 | F1: 0.9789 | Recall: 0.9787 | Precision: 0.9791\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 637.93batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1062 | CLS Loss: 0.1053 | Align Loss (raw): 0.0018\n",
      "  Val Acc: 0.9615 | F1: 0.9600 | Recall: 0.9284 | Precision: 0.9937\n",
      "  Model not saved. Best F1 so far: 0.9762\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.67batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/50\n",
      "  Train Total Loss: 0.0581 | CLS Loss: 0.0571 | Align Loss (raw): 0.0019\n",
      "  Train Acc: 0.9801 | F1: 0.9799 | Recall: 0.9803 | Precision: 0.9795\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 586.70batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0632 | CLS Loss: 0.0604 | Align Loss (raw): 0.0056\n",
      "  Val Acc: 0.9808 | F1: 0.9808 | Recall: 0.9854 | Precision: 0.9762\n",
      "  New Best F1: 0.9808. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 159.44batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11/50\n",
      "  Train Total Loss: 0.0513 | CLS Loss: 0.0503 | Align Loss (raw): 0.0019\n",
      "  Train Acc: 0.9820 | F1: 0.9818 | Recall: 0.9821 | Precision: 0.9815\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 636.66batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1342 | CLS Loss: 0.1340 | Align Loss (raw): 0.0006\n",
      "  Val Acc: 0.9590 | F1: 0.9571 | Recall: 0.9205 | Precision: 0.9967\n",
      "  Model not saved. Best F1 so far: 0.9808\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 162.99batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12/50\n",
      "  Train Total Loss: 0.0471 | CLS Loss: 0.0464 | Align Loss (raw): 0.0015\n",
      "  Train Acc: 0.9826 | F1: 0.9824 | Recall: 0.9827 | Precision: 0.9822\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 618.88batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0662 | CLS Loss: 0.0632 | Align Loss (raw): 0.0061\n",
      "  Val Acc: 0.9827 | F1: 0.9826 | Recall: 0.9788 | Precision: 0.9863\n",
      "  New Best F1: 0.9826. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.51batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13/50\n",
      "  Train Total Loss: 0.0456 | CLS Loss: 0.0448 | Align Loss (raw): 0.0018\n",
      "  Train Acc: 0.9845 | F1: 0.9843 | Recall: 0.9842 | Precision: 0.9843\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 649.14batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0792 | CLS Loss: 0.0755 | Align Loss (raw): 0.0074\n",
      "  Val Acc: 0.9772 | F1: 0.9773 | Recall: 0.9916 | Precision: 0.9635\n",
      "  Model not saved. Best F1 so far: 0.9826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 158.19batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14/50\n",
      "  Train Total Loss: 0.0448 | CLS Loss: 0.0442 | Align Loss (raw): 0.0012\n",
      "  Train Acc: 0.9843 | F1: 0.9841 | Recall: 0.9841 | Precision: 0.9841\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 630.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0569 | CLS Loss: 0.0562 | Align Loss (raw): 0.0014\n",
      "  Val Acc: 0.9808 | F1: 0.9808 | Recall: 0.9880 | Precision: 0.9738\n",
      "  Model not saved. Best F1 so far: 0.9826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.17batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15/50\n",
      "  Train Total Loss: 0.0420 | CLS Loss: 0.0414 | Align Loss (raw): 0.0011\n",
      "  Train Acc: 0.9854 | F1: 0.9853 | Recall: 0.9851 | Precision: 0.9854\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 625.01batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0616 | CLS Loss: 0.0611 | Align Loss (raw): 0.0010\n",
      "  Val Acc: 0.9808 | F1: 0.9805 | Recall: 0.9706 | Precision: 0.9906\n",
      "  Model not saved. Best F1 so far: 0.9826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 155.55batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16/50\n",
      "  Train Total Loss: 0.0380 | CLS Loss: 0.0374 | Align Loss (raw): 0.0012\n",
      "  Train Acc: 0.9858 | F1: 0.9856 | Recall: 0.9857 | Precision: 0.9855\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 614.66batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0995 | CLS Loss: 0.0940 | Align Loss (raw): 0.0109\n",
      "  Val Acc: 0.9712 | F1: 0.9717 | Recall: 0.9954 | Precision: 0.9491\n",
      "  Model not saved. Best F1 so far: 0.9826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 158.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17/50\n",
      "  Train Total Loss: 0.0366 | CLS Loss: 0.0360 | Align Loss (raw): 0.0011\n",
      "  Train Acc: 0.9873 | F1: 0.9872 | Recall: 0.9874 | Precision: 0.9869\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 631.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0542 | CLS Loss: 0.0534 | Align Loss (raw): 0.0015\n",
      "  Val Acc: 0.9836 | F1: 0.9834 | Recall: 0.9785 | Precision: 0.9884\n",
      "  New Best F1: 0.9834. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.80batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18/50\n",
      "  Train Total Loss: 0.0340 | CLS Loss: 0.0335 | Align Loss (raw): 0.0010\n",
      "  Train Acc: 0.9878 | F1: 0.9877 | Recall: 0.9879 | Precision: 0.9875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 652.88batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0630 | CLS Loss: 0.0625 | Align Loss (raw): 0.0010\n",
      "  Val Acc: 0.9811 | F1: 0.9808 | Recall: 0.9724 | Precision: 0.9893\n",
      "  Model not saved. Best F1 so far: 0.9834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.25batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19/50\n",
      "  Train Total Loss: 0.0346 | CLS Loss: 0.0340 | Align Loss (raw): 0.0012\n",
      "  Train Acc: 0.9871 | F1: 0.9870 | Recall: 0.9870 | Precision: 0.9870\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 600.33batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0507 | CLS Loss: 0.0502 | Align Loss (raw): 0.0009\n",
      "  Val Acc: 0.9841 | F1: 0.9840 | Recall: 0.9824 | Precision: 0.9856\n",
      "  New Best F1: 0.9840. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20/50\n",
      "  Train Total Loss: 0.0313 | CLS Loss: 0.0308 | Align Loss (raw): 0.0010\n",
      "  Train Acc: 0.9886 | F1: 0.9885 | Recall: 0.9888 | Precision: 0.9883\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 646.18batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0442 | CLS Loss: 0.0435 | Align Loss (raw): 0.0014\n",
      "  Val Acc: 0.9867 | F1: 0.9865 | Recall: 0.9831 | Precision: 0.9900\n",
      "  New Best F1: 0.9865. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.77batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21/50\n",
      "  Train Total Loss: 0.0321 | CLS Loss: 0.0317 | Align Loss (raw): 0.0009\n",
      "  Train Acc: 0.9882 | F1: 0.9881 | Recall: 0.9883 | Precision: 0.9879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 613.26batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0662 | CLS Loss: 0.0630 | Align Loss (raw): 0.0065\n",
      "  Val Acc: 0.9817 | F1: 0.9818 | Recall: 0.9931 | Precision: 0.9708\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 157.93batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22/50\n",
      "  Train Total Loss: 0.0277 | CLS Loss: 0.0274 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9900 | F1: 0.9899 | Recall: 0.9903 | Precision: 0.9895\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 620.89batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0504 | CLS Loss: 0.0496 | Align Loss (raw): 0.0017\n",
      "  Val Acc: 0.9860 | F1: 0.9859 | Recall: 0.9842 | Precision: 0.9877\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 163.24batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23/50\n",
      "  Train Total Loss: 0.0289 | CLS Loss: 0.0285 | Align Loss (raw): 0.0007\n",
      "  Train Acc: 0.9893 | F1: 0.9892 | Recall: 0.9897 | Precision: 0.9887\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 603.24batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0654 | CLS Loss: 0.0647 | Align Loss (raw): 0.0015\n",
      "  Val Acc: 0.9788 | F1: 0.9786 | Recall: 0.9765 | Precision: 0.9807\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.86batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24/50\n",
      "  Train Total Loss: 0.0280 | CLS Loss: 0.0275 | Align Loss (raw): 0.0011\n",
      "  Train Acc: 0.9902 | F1: 0.9901 | Recall: 0.9899 | Precision: 0.9903\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 649.96batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0641 | CLS Loss: 0.0623 | Align Loss (raw): 0.0036\n",
      "  Val Acc: 0.9806 | F1: 0.9807 | Recall: 0.9926 | Precision: 0.9691\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 166.72batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25/50\n",
      "  Train Total Loss: 0.0259 | CLS Loss: 0.0256 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9913 | F1: 0.9912 | Recall: 0.9914 | Precision: 0.9909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 646.77batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0573 | CLS Loss: 0.0549 | Align Loss (raw): 0.0048\n",
      "  Val Acc: 0.9839 | F1: 0.9837 | Recall: 0.9785 | Precision: 0.9889\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 164.70batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26/50\n",
      "  Train Total Loss: 0.0262 | CLS Loss: 0.0257 | Align Loss (raw): 0.0010\n",
      "  Train Acc: 0.9910 | F1: 0.9909 | Recall: 0.9913 | Precision: 0.9904\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 631.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0509 | CLS Loss: 0.0499 | Align Loss (raw): 0.0021\n",
      "  Val Acc: 0.9853 | F1: 0.9852 | Recall: 0.9859 | Precision: 0.9844\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.50batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27/50\n",
      "  Train Total Loss: 0.0235 | CLS Loss: 0.0232 | Align Loss (raw): 0.0007\n",
      "  Train Acc: 0.9912 | F1: 0.9911 | Recall: 0.9913 | Precision: 0.9909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 576.93batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0550 | CLS Loss: 0.0536 | Align Loss (raw): 0.0028\n",
      "  Val Acc: 0.9839 | F1: 0.9838 | Recall: 0.9885 | Precision: 0.9792\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 159.27batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28/50\n",
      "  Train Total Loss: 0.0246 | CLS Loss: 0.0242 | Align Loss (raw): 0.0008\n",
      "  Train Acc: 0.9919 | F1: 0.9918 | Recall: 0.9921 | Precision: 0.9916\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 650.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0523 | CLS Loss: 0.0507 | Align Loss (raw): 0.0032\n",
      "  Val Acc: 0.9865 | F1: 0.9864 | Recall: 0.9836 | Precision: 0.9892\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.43batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29/50\n",
      "  Train Total Loss: 0.0259 | CLS Loss: 0.0254 | Align Loss (raw): 0.0009\n",
      "  Train Acc: 0.9909 | F1: 0.9908 | Recall: 0.9911 | Precision: 0.9904\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 643.66batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0479 | CLS Loss: 0.0472 | Align Loss (raw): 0.0014\n",
      "  Val Acc: 0.9834 | F1: 0.9833 | Recall: 0.9859 | Precision: 0.9807\n",
      "  Model not saved. Best F1 so far: 0.9865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 165.73batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30/50\n",
      "  Train Total Loss: 0.0227 | CLS Loss: 0.0222 | Align Loss (raw): 0.0011\n",
      "  Train Acc: 0.9922 | F1: 0.9921 | Recall: 0.9927 | Precision: 0.9914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 634.15batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0509 | CLS Loss: 0.0488 | Align Loss (raw): 0.0043\n",
      "  Val Acc: 0.9868 | F1: 0.9868 | Recall: 0.9921 | Precision: 0.9815\n",
      "  New Best F1: 0.9868. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.16batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 31/50\n",
      "  Train Total Loss: 0.0196 | CLS Loss: 0.0193 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9933 | F1: 0.9932 | Recall: 0.9936 | Precision: 0.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 661.51batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0478 | CLS Loss: 0.0457 | Align Loss (raw): 0.0042\n",
      "  Val Acc: 0.9867 | F1: 0.9866 | Recall: 0.9859 | Precision: 0.9872\n",
      "  Model not saved. Best F1 so far: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 32/50\n",
      "  Train Total Loss: 0.0210 | CLS Loss: 0.0207 | Align Loss (raw): 0.0007\n",
      "  Train Acc: 0.9924 | F1: 0.9923 | Recall: 0.9919 | Precision: 0.9926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 638.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0530 | CLS Loss: 0.0528 | Align Loss (raw): 0.0004\n",
      "  Val Acc: 0.9848 | F1: 0.9845 | Recall: 0.9752 | Precision: 0.9940\n",
      "  Model not saved. Best F1 so far: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 162.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 33/50\n",
      "  Train Total Loss: 0.0211 | CLS Loss: 0.0209 | Align Loss (raw): 0.0004\n",
      "  Train Acc: 0.9925 | F1: 0.9924 | Recall: 0.9926 | Precision: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 621.84batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0487 | CLS Loss: 0.0477 | Align Loss (raw): 0.0019\n",
      "  Val Acc: 0.9859 | F1: 0.9857 | Recall: 0.9806 | Precision: 0.9910\n",
      "  Model not saved. Best F1 so far: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.81batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 34/50\n",
      "  Train Total Loss: 0.0201 | CLS Loss: 0.0198 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9929 | F1: 0.9928 | Recall: 0.9925 | Precision: 0.9931\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 629.67batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0522 | CLS Loss: 0.0500 | Align Loss (raw): 0.0043\n",
      "  Val Acc: 0.9860 | F1: 0.9860 | Recall: 0.9890 | Precision: 0.9830\n",
      "  Model not saved. Best F1 so far: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 155.03batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 35/50\n",
      "  Train Total Loss: 0.0219 | CLS Loss: 0.0215 | Align Loss (raw): 0.0007\n",
      "  Train Acc: 0.9927 | F1: 0.9926 | Recall: 0.9925 | Precision: 0.9926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 651.83batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0561 | CLS Loss: 0.0557 | Align Loss (raw): 0.0006\n",
      "  Val Acc: 0.9848 | F1: 0.9845 | Recall: 0.9765 | Precision: 0.9927\n",
      "  Model not saved. Best F1 so far: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 156.16batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 36/50\n",
      "  Train Total Loss: 0.0197 | CLS Loss: 0.0192 | Align Loss (raw): 0.0009\n",
      "  Train Acc: 0.9934 | F1: 0.9933 | Recall: 0.9933 | Precision: 0.9933\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 647.33batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0408 | CLS Loss: 0.0399 | Align Loss (raw): 0.0019\n",
      "  Val Acc: 0.9878 | F1: 0.9877 | Recall: 0.9844 | Precision: 0.9910\n",
      "  New Best F1: 0.9877. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 159.30batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 37/50\n",
      "  Train Total Loss: 0.0180 | CLS Loss: 0.0177 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9935 | F1: 0.9935 | Recall: 0.9940 | Precision: 0.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 685.44batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0498 | CLS Loss: 0.0487 | Align Loss (raw): 0.0022\n",
      "  Val Acc: 0.9865 | F1: 0.9865 | Recall: 0.9936 | Precision: 0.9796\n",
      "  Model not saved. Best F1 so far: 0.9877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 38/50\n",
      "  Train Total Loss: 0.0175 | CLS Loss: 0.0172 | Align Loss (raw): 0.0005\n",
      "  Train Acc: 0.9937 | F1: 0.9937 | Recall: 0.9935 | Precision: 0.9938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 635.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0500 | CLS Loss: 0.0485 | Align Loss (raw): 0.0029\n",
      "  Val Acc: 0.9863 | F1: 0.9862 | Recall: 0.9862 | Precision: 0.9862\n",
      "  Model not saved. Best F1 so far: 0.9877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 39/50\n",
      "  Train Total Loss: 0.0210 | CLS Loss: 0.0205 | Align Loss (raw): 0.0012\n",
      "  Train Acc: 0.9926 | F1: 0.9925 | Recall: 0.9929 | Precision: 0.9921\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 664.22batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0489 | CLS Loss: 0.0479 | Align Loss (raw): 0.0020\n",
      "  Val Acc: 0.9877 | F1: 0.9876 | Recall: 0.9890 | Precision: 0.9862\n",
      "  Model not saved. Best F1 so far: 0.9877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 166.42batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 40/50\n",
      "  Train Total Loss: 0.0191 | CLS Loss: 0.0188 | Align Loss (raw): 0.0007\n",
      "  Train Acc: 0.9933 | F1: 0.9932 | Recall: 0.9936 | Precision: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 664.79batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0558 | CLS Loss: 0.0551 | Align Loss (raw): 0.0013\n",
      "  Val Acc: 0.9843 | F1: 0.9840 | Recall: 0.9742 | Precision: 0.9940\n",
      "  Model not saved. Best F1 so far: 0.9877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 170.37batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 41/50\n",
      "  Train Total Loss: 0.0172 | CLS Loss: 0.0168 | Align Loss (raw): 0.0007\n",
      "  Train Acc: 0.9941 | F1: 0.9940 | Recall: 0.9946 | Precision: 0.9935\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 591.77batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0570 | CLS Loss: 0.0563 | Align Loss (raw): 0.0015\n",
      "  Val Acc: 0.9845 | F1: 0.9843 | Recall: 0.9765 | Precision: 0.9922\n",
      "  Model not saved. Best F1 so far: 0.9877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 164.70batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 42/50\n",
      "  Train Total Loss: 0.0192 | CLS Loss: 0.0188 | Align Loss (raw): 0.0008\n",
      "  Train Acc: 0.9937 | F1: 0.9936 | Recall: 0.9934 | Precision: 0.9938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 646.24batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0402 | CLS Loss: 0.0397 | Align Loss (raw): 0.0012\n",
      "  Val Acc: 0.9882 | F1: 0.9881 | Recall: 0.9885 | Precision: 0.9877\n",
      "  New Best F1: 0.9881. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 158.81batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 43/50\n",
      "  Train Total Loss: 0.0158 | CLS Loss: 0.0156 | Align Loss (raw): 0.0005\n",
      "  Train Acc: 0.9942 | F1: 0.9941 | Recall: 0.9943 | Precision: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 575.90batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0649 | CLS Loss: 0.0648 | Align Loss (raw): 0.0002\n",
      "  Val Acc: 0.9813 | F1: 0.9810 | Recall: 0.9678 | Precision: 0.9945\n",
      "  Model not saved. Best F1 so far: 0.9881\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 156.75batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44/50\n",
      "  Train Total Loss: 0.0149 | CLS Loss: 0.0148 | Align Loss (raw): 0.0002\n",
      "  Train Acc: 0.9945 | F1: 0.9945 | Recall: 0.9950 | Precision: 0.9940\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 624.92batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0508 | CLS Loss: 0.0499 | Align Loss (raw): 0.0017\n",
      "  Val Acc: 0.9873 | F1: 0.9872 | Recall: 0.9839 | Precision: 0.9905\n",
      "  Model not saved. Best F1 so far: 0.9881\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.06batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 45/50\n",
      "  Train Total Loss: 0.0186 | CLS Loss: 0.0183 | Align Loss (raw): 0.0005\n",
      "  Train Acc: 0.9940 | F1: 0.9939 | Recall: 0.9942 | Precision: 0.9936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 626.03batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0684 | CLS Loss: 0.0682 | Align Loss (raw): 0.0004\n",
      "  Val Acc: 0.9803 | F1: 0.9800 | Recall: 0.9691 | Precision: 0.9911\n",
      "  Model not saved. Best F1 so far: 0.9881\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46/50 Training: 100%|██████████| 1970/1970 [00:13<00:00, 150.75batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 46/50\n",
      "  Train Total Loss: 0.0165 | CLS Loss: 0.0163 | Align Loss (raw): 0.0004\n",
      "  Train Acc: 0.9941 | F1: 0.9940 | Recall: 0.9940 | Precision: 0.9940\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46/50 Validation: 100%|██████████| 493/493 [00:01<00:00, 424.22batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0414 | CLS Loss: 0.0405 | Align Loss (raw): 0.0017\n",
      "  Val Acc: 0.9890 | F1: 0.9888 | Recall: 0.9857 | Precision: 0.9920\n",
      "  New Best F1: 0.9888. Saving model to model_wts/best_model_cosine_similarity_loss.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 155.67batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 47/50\n",
      "  Train Total Loss: 0.0161 | CLS Loss: 0.0158 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9948 | F1: 0.9947 | Recall: 0.9948 | Precision: 0.9947\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 598.14batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0750 | CLS Loss: 0.0693 | Align Loss (raw): 0.0114\n",
      "  Val Acc: 0.9831 | F1: 0.9832 | Recall: 0.9954 | Precision: 0.9713\n",
      "  Model not saved. Best F1 so far: 0.9888\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 160.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 48/50\n",
      "  Train Total Loss: 0.0133 | CLS Loss: 0.0127 | Align Loss (raw): 0.0012\n",
      "  Train Acc: 0.9955 | F1: 0.9955 | Recall: 0.9956 | Precision: 0.9953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 659.64batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0571 | CLS Loss: 0.0558 | Align Loss (raw): 0.0025\n",
      "  Val Acc: 0.9848 | F1: 0.9846 | Recall: 0.9783 | Precision: 0.9909\n",
      "  Model not saved. Best F1 so far: 0.9888\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49/50 Training: 100%|██████████| 1970/1970 [00:11<00:00, 172.82batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 49/50\n",
      "  Train Total Loss: 0.0143 | CLS Loss: 0.0138 | Align Loss (raw): 0.0011\n",
      "  Train Acc: 0.9954 | F1: 0.9954 | Recall: 0.9956 | Precision: 0.9951\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 678.62batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0514 | CLS Loss: 0.0511 | Align Loss (raw): 0.0006\n",
      "  Val Acc: 0.9873 | F1: 0.9872 | Recall: 0.9829 | Precision: 0.9915\n",
      "  Model not saved. Best F1 so far: 0.9888\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50/50 Training: 100%|██████████| 1970/1970 [00:12<00:00, 161.75batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50/50\n",
      "  Train Total Loss: 0.0144 | CLS Loss: 0.0141 | Align Loss (raw): 0.0006\n",
      "  Train Acc: 0.9952 | F1: 0.9951 | Recall: 0.9954 | Precision: 0.9949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50/50 Validation: 100%|██████████| 493/493 [00:00<00:00, 650.43batch/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1052 | CLS Loss: 0.1034 | Align Loss (raw): 0.0035\n",
      "  Val Acc: 0.9777 | F1: 0.9771 | Recall: 0.9588 | Precision: 0.9960\n",
      "  Model not saved. Best F1 so far: 0.9888\n",
      "Training Finished.\n",
      "Best Validation F1 achieved: 0.9888\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
    "from torch.optim import Adam , AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "from sklearn.metrics import roc_curve\n",
    "import numpy as np\n",
    "\n",
    "def compute_eer(y_true, y_scores):\n",
    "    \"\"\"Approximate the equal error rate (EER) of a binary scorer.\n",
    "\n",
    "    Builds the ROC curve with ``roc_curve(y_true, y_scores)``, locates the\n",
    "    threshold index at which the false-positive and false-negative rates are\n",
    "    closest to each other, and returns the mean of the two rates at that index.\n",
    "    \"\"\"\n",
    "    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_scores)\n",
    "    false_neg_rate = 1 - true_pos_rate\n",
    "    # The EER operating point is where FPR and FNR cross; use the closest sample.\n",
    "    crossing = np.nanargmin(np.abs(false_neg_rate - false_pos_rate))\n",
    "    return (false_pos_rate[crossing] + false_neg_rate[crossing]) / 2\n",
    "\n",
    "\n",
    "# Hold out 20% of the training data for validation; the split is seeded so it\n",
    "# is reproducible on a fresh kernel.\n",
    "# NOTE: `train_dataset` is rebound to the training subset below, so re-running\n",
    "# this cell without first re-creating `train_dataset` splits the subset again.\n",
    "val_split = 0.2\n",
    "train_size = int((1 - val_split) * len(train_dataset))\n",
    "val_size = len(train_dataset) - train_size\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(\n",
    "    train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)\n",
    ")\n",
    "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
    "# Evaluation loaders keep a fixed order: the reported losses are means of\n",
    "# per-batch losses, so shuffling would make them vary from run to run.\n",
    "val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Constructor arguments for CLAM, one (in_channel, embed_dim) pair per input stream.\n",
    "# NOTE(review): 13 is presumably the number of stacked encoder layers in each\n",
    "# pre-extracted embedding and 768 their hidden size - confirm against the dataset.\n",
    "in_channel1 = 13\n",
    "in_channel2 = 13\n",
    "embed_dim1 = 768\n",
    "embed_dim2 = 768\n",
    "best_acc = 0  # not updated by the training loop in this cell; checkpoints are selected on F1\n",
    "best_f1 = 0\n",
    "\n",
    "model = CLAM(in_channel1 , in_channel2 , embed_dim1 , embed_dim2).to(device)\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "\n",
    "# Fake/real classification on raw logits (BCEWithLogitsLoss applies the sigmoid itself).\n",
    "classification_criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "# Folder that receives the best checkpoint; created on demand.\n",
    "save_folder = \"model_wts\"\n",
    "os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "epochs = 50\n",
    "# Weight of the alignment term in: total = classification + weight * alignment.\n",
    "alignment_loss_weight = 0.5\n",
    "print(\"Starting Training...\")\n",
    "\n",
    "\n",
    "def _real_alignment_loss(main_emb, aligned_emb, labels):\n",
    "    \"\"\"Alignment loss between the two embeddings, restricted to real samples (label 0).\n",
    "\n",
    "    Returns a zero scalar on `device` when the batch holds no real sample, so the\n",
    "    result can always be added to the classification loss.\n",
    "    `alignment_criterion` is expected to be defined in an earlier cell; it is called\n",
    "    with an all-ones target, i.e. each selected pair is asked to be similar.\n",
    "    \"\"\"\n",
    "    real_indices = (labels == 0).nonzero(as_tuple=True)[0]\n",
    "    if real_indices.nelement() == 0:\n",
    "        return torch.tensor(0.0, device=device)\n",
    "    target_ones = torch.ones(real_indices.nelement(), device=device)\n",
    "    return alignment_criterion(main_emb[real_indices], aligned_emb[real_indices], target_ones)\n",
    "\n",
    "\n",
    "def _binary_metrics(y_true, y_pred):\n",
    "    \"\"\"Return (accuracy, F1, recall, precision) for binary labels and predictions.\"\"\"\n",
    "    return (\n",
    "        accuracy_score(y_true, y_pred),\n",
    "        f1_score(y_true, y_pred, average='binary', zero_division=0),\n",
    "        recall_score(y_true, y_pred, average='binary', zero_division=0),\n",
    "        precision_score(y_true, y_pred, average='binary', zero_division=0),\n",
    "    )\n",
    "\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # ---- Training pass ----\n",
    "    model.train()\n",
    "    train_loss_total = 0\n",
    "    train_loss_cls = 0\n",
    "    train_loss_align = 0\n",
    "    train_preds_all = []\n",
    "    train_labels_all = []\n",
    "\n",
    "    for batch in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{epochs} Training\", unit=\"batch\"):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        embed1, embed2, labels = batch\n",
    "        embed1 = embed1.to(device) # Instrumental features\n",
    "        embed2 = embed2.to(device) # Vocal features\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        outputs, main_embeddings, aligned_embeddings = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)\n",
    "\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "        alignment_loss_val = _real_alignment_loss(main_embeddings, aligned_embeddings, labels)\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss_val\n",
    "\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss_total += total_loss.item()\n",
    "        train_loss_cls += classification_loss.item()\n",
    "        train_loss_align += alignment_loss_val.item()\n",
    "\n",
    "        # Threshold the sigmoid probability at 0.5 to get hard predictions.\n",
    "        preds = torch.sigmoid(outputs).detach().round()\n",
    "        train_preds_all.extend(preds.cpu().numpy())\n",
    "        train_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "    num_batches = len(train_loader)\n",
    "    avg_train_loss_total = train_loss_total / num_batches if num_batches > 0 else 0\n",
    "    avg_train_loss_cls = train_loss_cls / num_batches if num_batches > 0 else 0\n",
    "    # Averaged over all batches; batches without any real sample contribute 0,\n",
    "    # which pulls this average down.\n",
    "    avg_train_loss_align = train_loss_align / num_batches if num_batches > 0 else 0\n",
    "\n",
    "    train_acc, train_f1, train_recall, train_precision = _binary_metrics(train_labels_all, train_preds_all)\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{epochs}\")\n",
    "    print(f\"  Train Total Loss: {avg_train_loss_total:.4f} | CLS Loss: {avg_train_loss_cls:.4f} | Align Loss (raw): {avg_train_loss_align:.4f}\")\n",
    "    print(f\"  Train Acc: {train_acc:.4f} | F1: {train_f1:.4f} | Recall: {train_recall:.4f} | Precision: {train_precision:.4f}\")\n",
    "\n",
    "    # ---- Validation pass (no gradient tracking, no optimizer step) ----\n",
    "    model.eval()\n",
    "    val_loss_total = 0\n",
    "    val_loss_cls = 0\n",
    "    val_loss_align = 0\n",
    "    val_preds_all = []\n",
    "    val_labels_all = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(val_loader, desc=f\"Epoch {epoch+1}/{epochs} Validation\", unit=\"batch\"):\n",
    "            embed1, embed2, labels = batch\n",
    "            embed1 = embed1.to(device)\n",
    "            embed2 = embed2.to(device)\n",
    "            labels = labels.float().to(device).view(-1)\n",
    "\n",
    "            outputs, main_embeddings, aligned_embeddings = model.forward_training(embed1, embed2)\n",
    "            outputs = outputs.view(-1)\n",
    "\n",
    "            classification_loss = classification_criterion(outputs, labels)\n",
    "            alignment_loss_val = _real_alignment_loss(main_embeddings, aligned_embeddings, labels)\n",
    "            total_loss = classification_loss + alignment_loss_weight * alignment_loss_val\n",
    "\n",
    "            val_loss_total += total_loss.item()\n",
    "            val_loss_cls += classification_loss.item()\n",
    "            val_loss_align += alignment_loss_val.item()\n",
    "\n",
    "            preds = torch.sigmoid(outputs).detach().round()\n",
    "            val_preds_all.extend(preds.cpu().numpy())\n",
    "            val_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "    num_batches_val = len(val_loader)\n",
    "    avg_val_loss_total = val_loss_total / num_batches_val if num_batches_val > 0 else 0\n",
    "    avg_val_loss_cls = val_loss_cls / num_batches_val if num_batches_val > 0 else 0\n",
    "    avg_val_loss_align = val_loss_align / num_batches_val if num_batches_val > 0 else 0\n",
    "\n",
    "    val_acc, val_f1, val_recall, val_precision = _binary_metrics(val_labels_all, val_preds_all)\n",
    "\n",
    "    print(f\"  Val Total Loss: {avg_val_loss_total:.4f} | CLS Loss: {avg_val_loss_cls:.4f} | Align Loss (raw): {avg_val_loss_align:.4f}\")\n",
    "    print(f\"  Val Acc: {val_acc:.4f} | F1: {val_f1:.4f} | Recall: {val_recall:.4f} | Precision: {val_precision:.4f}\")\n",
    "\n",
    "    # Checkpoint only when validation F1 improves; `save_name` is expected to be\n",
    "    # defined in an earlier cell.\n",
    "    if val_f1 > best_f1:\n",
    "        best_f1 = val_f1\n",
    "        print(f\"  New Best F1: {best_f1:.4f}. Saving model to {os.path.join(save_folder, save_name)}\")\n",
    "        torch.save(model.state_dict(), os.path.join(save_folder, save_name))\n",
    "    else:\n",
    "        print(f\"  Model not saved. Best F1 so far: {best_f1:.4f}\")\n",
    "\n",
    "\n",
    "print(\"Training Finished.\")\n",
    "print(f\"Best Validation F1 achieved: {best_f1:.4f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2426b781",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing: 100%|██████████| 2139/2139 [00:05<00:00, 414.24batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Total Loss: 1.0951 | CLS Loss: 0.4286 | Align Loss (raw): 1.3329\n",
      "Test Acc: 0.9285 | F1: 0.9232 | Recall: 0.8630 | Precision: 0.9924\n",
      "EER: 0.0611\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the checkpoint stored at save_folder/save_name on the test set.\n",
    "# map_location lets a checkpoint that was saved on GPU load on a CPU-only machine.\n",
    "model.load_state_dict(torch.load(os.path.join(save_folder, save_name), map_location=device))\n",
    "model.eval()\n",
    "test_loss_total = 0\n",
    "test_loss_cls = 0\n",
    "test_loss_align = 0\n",
    "test_preds_all = []\n",
    "test_labels_all = []\n",
    "test_scores_all = []\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(test_loader, desc=\"Testing\", unit=\"batch\"):\n",
    "        embed1, embed2, labels = batch\n",
    "        embed1 = embed1.to(device)\n",
    "        embed2 = embed2.to(device)\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        outputs, real_emb, fake_emb = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)\n",
    "\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "\n",
    "        # The training and validation loops compute the alignment loss on the\n",
    "        # label-0 (real) samples; use the same selection here so the test loss\n",
    "        # is comparable with the train/val numbers.\n",
    "        real_indices = (labels == 0).nonzero(as_tuple=True)[0]\n",
    "        alignment_loss = torch.tensor(0.0, device=device)\n",
    "        if real_indices.nelement() > 0:\n",
    "            # All-ones target: each selected pair is asked to be similar.\n",
    "            target_ones = torch.ones(real_indices.nelement(), device=device)\n",
    "            alignment_loss = alignment_criterion(real_emb[real_indices], fake_emb[real_indices], target_ones)\n",
    "\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "        test_loss_total += total_loss.item()\n",
    "        test_loss_cls += classification_loss.item()\n",
    "        test_loss_align += alignment_loss.item()\n",
    "\n",
    "        # Keep the probabilities for EER and the 0.5-thresholded values for the metrics.\n",
    "        scores = torch.sigmoid(outputs).detach()\n",
    "        test_preds_all.extend(scores.round().cpu().numpy())\n",
    "        test_labels_all.extend(labels.cpu().numpy())\n",
    "        test_scores_all.extend(scores.cpu().numpy())\n",
    "\n",
    "num_batches_test = len(test_loader)\n",
    "avg_test_loss_total = test_loss_total / num_batches_test\n",
    "avg_test_loss_cls = test_loss_cls / num_batches_test\n",
    "avg_test_loss_align = test_loss_align / num_batches_test\n",
    "test_acc = accuracy_score(test_labels_all, test_preds_all)\n",
    "test_f1 = f1_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "test_recall = recall_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "test_precision = precision_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "eer = compute_eer(test_labels_all, test_scores_all)  # EER is computed from the sigmoid probabilities\n",
    "print(f\"Test Total Loss: {avg_test_loss_total:.4f} | CLS Loss: {avg_test_loss_cls:.4f} | Align Loss (raw): {avg_test_loss_align:.4f}\")\n",
    "print(f\"Test Acc: {test_acc:.4f} | F1: {test_f1:.4f} | Recall: {test_recall:.4f} | Precision: {test_precision:.4f}\")\n",
    "print(f\"EER: {eer:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c39c682",
   "metadata": {},
   "outputs": [],
   "source": [
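    "# NOTE: pasted test-set output whose numbers differ from the stored output of the\n",
    "# evaluation cell above - presumably from an earlier run/configuration (TODO: record which).\n",
    "# Testing: 100%|██████████| 2139/2139 [00:03<00:00, 595.20batch/s]\n",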
    "# Test Total Loss: 0.4560 | CLS Loss: 0.4164 | Align Loss (raw): 0.0793\n",
    "# Test Acc: 0.9185 | F1: 0.9114 | Recall: 0.8414 | Precision: 0.9942\n",
    "# EER: 0.0625"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "uni3d",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
