{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9742f0d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import dataset\n",
    "from models.clam import CLAM\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1a3624ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "65450\n",
      "65450\n",
      "17504\n",
      "17504\n",
      "17504\n"
     ]
    }
   ],
   "source": [
    "# Real songs dataset: two metadata CSVs, each paired with folders of\n",
    "# precomputed MERT and wav2vec2 embeddings.\n",
    "real_df_1 = pd.read_csv(\"real_songs_2.csv\")\n",
    "real_df_2 = pd.read_csv(\"real_yt_covers.csv\")\n",
    "\n",
    "real_mert_1, real_wav2vec2_1 = \"real_songs_mert\", \"real_songs_wav2vec2\"\n",
    "real_mert_2, real_wav2vec2_2 = \"yt_covers_mert\", \"yt_covers_wav2vec2\"\n",
    "\n",
    "# Sanity check per source: CSV rows, then MERT and wav2vec2 files on disk.\n",
    "for source_df, mert_dir, wav2vec2_dir in (\n",
    "    (real_df_1, real_mert_1, real_wav2vec2_1),\n",
    "    (real_df_2, real_mert_2, real_wav2vec2_2),\n",
    "):\n",
    "    print(len(source_df))\n",
    "    print(len(os.listdir(mert_dir)))\n",
    "    print(len(os.listdir(wav2vec2_dir)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7baef53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Metadata for the AI-generated tracks (columns: filename, model_name).\n",
    "FAKE_METADATA_CSV = \"ai_generated_music_metadata.csv\"\n",
    "fake_df = pd.read_csv(FAKE_METADATA_CSV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "836f41c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>model_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ai_covers_0.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ai_covers_1.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ai_covers_2.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ai_covers_3.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ai_covers_4.mp3</td>\n",
       "      <td>AI_COVERS</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          filename model_name\n",
       "0  ai_covers_0.mp3  AI_COVERS\n",
       "1  ai_covers_1.mp3  AI_COVERS\n",
       "2  ai_covers_2.mp3  AI_COVERS\n",
       "3  ai_covers_3.mp3  AI_COVERS\n",
       "4  ai_covers_4.mp3  AI_COVERS"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows of the AI-generated metadata.\n",
    "fake_df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e23ec37b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64972\n",
      "64958\n",
      "64958\n"
     ]
    }
   ],
   "source": [
    "# Fake (AI-generated) dataset: metadata CSV plus folders of precomputed\n",
    "# MERT and wav2vec2 embeddings.\n",
    "fake_df = pd.read_csv(\"ai_generated_music_metadata.csv\")\n",
    "fake_mert = \"ai_generated_music_mert\"\n",
    "fake_wav2vec2 = \"ai_generated_music_wav2vec2\"\n",
    "\n",
    "# Sanity check: CSV rows, then MERT and wav2vec2 files on disk.\n",
    "print(len(fake_df))\n",
    "for embedding_dir in (fake_mert, fake_wav2vec2):\n",
    "    print(len(os.listdir(embedding_dir)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c40f0b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 47945/47945 [00:17<00:00, 2757.04it/s]\n",
      "100%|██████████| 17504/17504 [00:02<00:00, 8391.61it/s]\n",
      "100%|██████████| 64972/64972 [00:21<00:00, 2976.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "17504\n",
      "64958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def remove_filename(df, folder):\n",
    "    \"\"\"Return the rows of ``df`` whose embedding file exists in ``folder``.\n",
    "\n",
    "    A row is kept when ``folder`` contains ``stem + '.pt'``, where ``stem`` is\n",
    "    ``row['filename']`` with a trailing ``.mp3`` and then a trailing ``.wav``\n",
    "    stripped. Rows whose filename is not a string are dropped.\n",
    "\n",
    "    The result is a new DataFrame that keeps the original index labels;\n",
    "    ``df`` itself is not modified.\n",
    "    \"\"\"\n",
    "    # Set membership is O(1) per row; testing against a list re-scans the\n",
    "    # whole directory listing for every row.\n",
    "    available = set(os.listdir(folder))\n",
    "\n",
    "    def has_embedding(filename):\n",
    "        if not isinstance(filename, str):\n",
    "            return False\n",
    "        if filename.endswith('.mp3'):\n",
    "            filename = filename[:-4]\n",
    "        if filename.endswith('.wav'):\n",
    "            filename = filename[:-4]\n",
    "        return filename + '.pt' in available\n",
    "\n",
    "    # Positional boolean mask: unlike drop(label), this stays correct when\n",
    "    # index labels repeat (e.g. after pd.concat without ignore_index).\n",
    "    keep = [has_embedding(name) for name in df['filename']]\n",
    "    return df.loc[keep].copy()\n",
    "\n",
    "\n",
    "# Keep only metadata rows that have a MERT embedding on disk.\n",
    "real_df_1 = remove_filename(real_df_1, real_mert_1)\n",
    "real_df_2 = remove_filename(real_df_2, real_mert_2)\n",
    "fake_df = remove_filename(fake_df, fake_mert)\n",
    "for frame in (real_df_1, real_df_2, fake_df):\n",
    "    print(len(frame))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6e654958",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47945\n",
      "17504\n"
     ]
    }
   ],
   "source": [
    "for frame in (real_df_1, real_df_2):\n",
    "    print(len(frame))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b9facf9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65449\n"
     ]
    }
   ],
   "source": [
    "# Stack both real sources; each keeps its own index from read_csv, so index\n",
    "# labels may repeat in the combined frame.\n",
    "real_df = pd.concat([real_df_1, real_df_2], axis=0)\n",
    "print(len(real_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2b17865a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "model_name\n",
       "suno_3.5     23695\n",
       "Udio_1.5     19500\n",
       "riffusion     7043\n",
       "Yue           5278\n",
       "Diffrythm     4606\n",
       "suno_3        3512\n",
       "AI_COVERS     1166\n",
       "suno_2         110\n",
       "suno_4          48\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of AI-generated samples per generator.\n",
    "fake_df.model_name.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "901d4ef6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17047\n",
      "47911\n",
      "0.7375688906678162\n",
      "48273\n",
      "17176\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Fake samples from these generators go only to the test set, so none of these\n",
    "# models is seen in training. Defined once so train and test masks cannot drift.\n",
    "HELD_OUT_GENERATORS = [\"riffusion\", \"suno_3\", \"AI_COVERS\", \"suno_4\", \"Yue\"]\n",
    "is_held_out = fake_df[\"model_name\"].isin(HELD_OUT_GENERATORS)\n",
    "\n",
    "fake_df_test = fake_df[is_held_out]\n",
    "print(len(fake_df_test))\n",
    "fake_df_train = fake_df[~is_held_out]\n",
    "print(len(fake_df_train))\n",
    "\n",
    "# Share of fake samples used for training; the real songs are split with the\n",
    "# same proportion below.\n",
    "ratio = len(fake_df_train) / (len(fake_df_test) + len(fake_df_train))\n",
    "print(ratio)\n",
    "\n",
    "real_df = pd.concat([real_df_1, real_df_2])\n",
    "\n",
    "real_df_train, real_df_test = train_test_split(real_df, test_size=1 - ratio, random_state=42)\n",
    "print(len(real_df_train))\n",
    "print(len(real_df_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "af6715f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preloading embeddings for train split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "abc8895461ba44f98de562d3ec9433a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading train data to CPU:   0%|          | 0/48273 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/run/media/arnesh/MySSD/dataset/utils/dataset.py:32: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  mert = torch.load(mert_path, map_location=torch.device('cpu'))\n",
      "/run/media/arnesh/MySSD/dataset/utils/dataset.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  wav2vec2 = torch.load(wav2vec2_path, map_location=torch.device('cpu'))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_8382.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_5219.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_11880.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_9557.pt: Ran out of input. Skipping sample.\n",
      "PytorchStreamReader failed locating file data/0: file not found\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_11542.pt: PytorchStreamReader failed locating file data/0: file not found. Skipping sample.\n",
      "Preloading embeddings for train split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1871e91312724d5f9477647823e0754d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading train data to CPU:   0%|          | 0/47911 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for diffrythm_566.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_122.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_278.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_2113.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_11812.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_18613.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_5_23683.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_607.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_7188.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for udio_1_5_18205.pt: Ran out of input. Skipping sample.\n",
      "Preloading embeddings for test split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee19aa93542f45568e926187a14ed94b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading test data to CPU:   0%|          | 0/17176 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_13185.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_9011.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_10928.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_7006.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_5289.pt: Ran out of input. Skipping sample.\n",
      "PytorchStreamReader failed locating file data/0: file not found\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_7797.pt: PytorchStreamReader failed locating file data/0: file not found. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for yt_covers_15133.pt: Ran out of input. Skipping sample.\n",
      "Preloading embeddings for test split...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab432e4de64f4d2b8eeb1651278da7f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading test data to CPU:   0%|          | 0/17047 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_1802.pt: Ran out of input. Skipping sample.\n",
      "Ran out of input\n",
      "\n",
      "Warning: Error loading embeddings for suno_3_2149.pt: Ran out of input. Skipping sample.\n"
     ]
    }
   ],
   "source": [
    "from utils import dataset\n",
    "\n",
    "# Labels: 0 = real songs, 1 = AI-generated songs.\n",
    "# NOTE(review): real_df_train/real_df_test contain rows from both real CSVs, but\n",
    "# every real dataset reads embeddings from real_mert_1/real_wav2vec2_1 only\n",
    "# (real_df_2 was filtered against real_mert_2). Confirm those folders also hold\n",
    "# the yt_covers_* files.\n",
    "real_dirs = (real_mert_1, real_wav2vec2_1)\n",
    "fake_dirs = (fake_mert, fake_wav2vec2)\n",
    "\n",
    "train_real_dataset = dataset.SongsDataset(real_df_train, *real_dirs, label=0)\n",
    "train_fake_dataset = dataset.SongsDataset(fake_df_train, *fake_dirs, label=1)\n",
    "\n",
    "test_real_dataset = dataset.SongsDataset(real_df_test, *real_dirs, label=0, split=\"test\")\n",
    "test_fake_dataset = dataset.SongsDataset(fake_df_test, *fake_dirs, label=1, split=\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "296bb56e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48268\n",
      "47901\n",
      "17169\n",
      "17045\n"
     ]
    }
   ],
   "source": [
    "for part in (train_real_dataset, train_fake_dataset, test_real_dataset, test_fake_dataset):\n",
    "    print(len(part))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "738fc507",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "96169\n",
      "34214\n"
     ]
    }
   ],
   "source": [
    "# Each split combines its real (label 0) and fake (label 1) datasets.\n",
    "train_dataset = torch.utils.data.ConcatDataset([train_real_dataset, train_fake_dataset])\n",
    "test_dataset = torch.utils.data.ConcatDataset([test_real_dataset, test_fake_dataset])\n",
    "for split_dataset in (train_dataset, test_dataset):\n",
    "    print(len(split_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23726ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing: 100%|██████████| 2139/2139 [00:13<00:00, 163.11batch/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Total Loss: 0.3613 | CLS Loss: 0.3245 | Align Loss (raw): 0.0736\n",
      "Test Acc: 0.9311 | F1: 0.9264 | Recall: 0.8709 | Precision: 0.9894\n",
      "EER: 0.0652\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
    "from torch.optim import Adam , AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "from sklearn.metrics import roc_curve\n",
    "import numpy as np\n",
    "\n",
    "def compute_eer(y_true, y_scores):\n",
    "    \"\"\"Return the equal error rate (EER) for binary labels and scores.\n",
    "\n",
    "    Uses the ROC point where |FNR - FPR| is smallest and returns the mean\n",
    "    of FPR and FNR at that point.\n",
    "    \"\"\"\n",
    "    fpr, tpr, thresholds = roc_curve(y_true, y_scores)\n",
    "    fnr = 1 - tpr\n",
    "    # Find the ROC point closest to FPR == FNR\n",
    "    eer_threshold_index = np.nanargmin(np.absolute((fnr - fpr)))\n",
    "    eer = (fpr[eer_threshold_index] + fnr[eer_threshold_index]) / 2\n",
    "    return eer\n",
    "\n",
    "\n",
    "def triplet_alignment_loss(real_emb, fake_emb, labels, criterion):\n",
    "    \"\"\"In-batch triplet alignment loss over the samples with label 1.\n",
    "\n",
    "    For every ordered pair (i, j), i != j, of selected samples the triplet is\n",
    "    (anchor=real_emb[i], positive=fake_emb[i], negative=fake_emb[j]). All\n",
    "    triplets go through ``criterion`` in one call; assuming the criterion uses\n",
    "    reduction='mean' (the TripletMarginLoss default), this equals averaging\n",
    "    one call per triplet.\n",
    "    Returns a zero tensor when fewer than two samples have label 1.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the datasets use label 1 for AI-generated songs, so this\n",
    "    # selects fake samples despite the real_/fake_ embedding names; confirm it\n",
    "    # matches the training objective.\n",
    "    selected = (labels == 1).nonzero(as_tuple=True)[0]\n",
    "    n = selected.numel()\n",
    "    if n < 2:\n",
    "        return torch.tensor(0.0, device=labels.device)\n",
    "    anchors = real_emb[selected]\n",
    "    positives = fake_emb[selected]\n",
    "    rows = torch.arange(n, device=selected.device)\n",
    "    # Row-major flattening of the off-diagonal grid gives the same triplet\n",
    "    # order as looping over i, then over every j != i.\n",
    "    anchor_idx, negative_idx = torch.meshgrid(rows, rows, indexing=\"ij\")\n",
    "    off_diagonal = anchor_idx != negative_idx\n",
    "    anchor_idx = anchor_idx[off_diagonal]\n",
    "    negative_idx = negative_idx[off_diagonal]\n",
    "    return criterion(anchors[anchor_idx], positives[anchor_idx], positives[negative_idx])\n",
    "\n",
    "\n",
    "val_split = 0.2\n",
    "# Re-split from the per-class datasets so re-running this cell does not split\n",
    "# an already-split subset a second time.\n",
    "full_train_dataset = torch.utils.data.ConcatDataset([train_real_dataset, train_fake_dataset])\n",
    "train_size = int((1 - val_split) * len(full_train_dataset))\n",
    "val_size = len(full_train_dataset) - train_size\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(full_train_dataset, [train_size, val_size] , generator=torch.Generator().manual_seed(42))\n",
    "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
    "# No shuffling for evaluation: the alignment loss depends on which samples\n",
    "# share a batch, so a fixed order keeps reported losses reproducible.\n",
    "val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "in_channel1 = 13\n",
    "in_channel2 = 13\n",
    "embed_dim1 = 768\n",
    "embed_dim2 = 768\n",
    "best_acc = 0\n",
    "best_f1 = 0\n",
    "\n",
    "model = CLAM(in_channel1 , in_channel2 , embed_dim1 , embed_dim2).to(device)\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "\n",
    "classification_criterion = nn.BCEWithLogitsLoss() # For fake/real classification\n",
    "margin = 0.2 # You might need to tune this hyperparameter\n",
    "alignment_criterion = nn.TripletMarginLoss(margin=margin)\n",
    "save_name = f\"best_model_triplet_loss_margin_{margin}.pth\" # Update save name\n",
    "save_folder = \"model_wts\"\n",
    "# --- Hyperparameters ---\n",
    "epochs = 50\n",
    "alignment_loss_weight = 0.5 # Weight factor for the alignment loss\n",
    "print(\"Evaluating saved checkpoint on the test set...\")\n",
    "\n",
    "# map_location: a checkpoint saved on GPU also loads on a CPU-only machine.\n",
    "# weights_only: the file is a state_dict, so arbitrary unpickling is not needed.\n",
    "checkpoint_path = os.path.join(save_folder, save_name)\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))\n",
    "model.eval()\n",
    "test_loss_total = 0\n",
    "test_loss_cls = 0\n",
    "test_loss_align = 0\n",
    "test_preds_all = []\n",
    "test_labels_all = []\n",
    "test_scores_all = []\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(test_loader, desc=\"Testing\", unit=\"batch\"):\n",
    "        embed1, embed2, labels = batch\n",
    "        embed1 = embed1.to(device)\n",
    "        embed2 = embed2.to(device)\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        # --- Forward pass (testing) ---\n",
    "        outputs , real_emb, fake_emb  = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)\n",
    "\n",
    "        # --- Losses: classification + weighted in-batch alignment ---\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "        alignment_loss = triplet_alignment_loss(real_emb, fake_emb, labels, alignment_criterion)\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "        test_loss_total += total_loss.item()\n",
    "        test_loss_cls += classification_loss.item()\n",
    "        test_loss_align += alignment_loss.item()\n",
    "\n",
    "        # Sigmoid probabilities feed the EER; rounding them gives hard predictions.\n",
    "        scores = torch.sigmoid(outputs)\n",
    "        test_preds_all.extend(scores.round().cpu().numpy())\n",
    "        test_labels_all.extend(labels.cpu().numpy())\n",
    "        test_scores_all.extend(scores.cpu().numpy())\n",
    "\n",
    "    num_batches_test = len(test_loader)\n",
    "    avg_test_loss_total = test_loss_total / num_batches_test\n",
    "    avg_test_loss_cls = test_loss_cls / num_batches_test\n",
    "    avg_test_loss_align = test_loss_align / num_batches_test\n",
    "    test_acc = accuracy_score(test_labels_all, test_preds_all)\n",
    "    test_f1 = f1_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "    test_recall = recall_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "    test_precision = precision_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "    eer = compute_eer(test_labels_all, test_scores_all)  # Calculate EER using the raw scores\n",
    "    print(f\"Test Total Loss: {avg_test_loss_total:.4f} | CLS Loss: {avg_test_loss_cls:.4f} | Align Loss (raw): {avg_test_loss_align:.4f}\")\n",
    "    print(f\"Test Acc: {test_acc:.4f} | F1: {test_f1:.4f} | Recall: {test_recall:.4f} | Precision: {test_precision:.4f}\")\n",
    "    print(f\"EER: {eer:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "359ca8a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 Training: 100%|██████████| 4809/4809 [01:51<00:00, 43.15batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50\n",
      "  Train Total Loss: 0.3677 | CLS Loss: 0.2768 | Align Loss (raw): 0.1817\n",
      "  Train Acc: 0.8733 | F1: 0.8679 | Recall: 0.8364 | Precision: 0.9018\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 Validation: 100%|██████████| 1203/1203 [00:07<00:00, 170.38batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1858 | CLS Loss: 0.1267 | Align Loss (raw): 0.1183\n",
      "  Val Acc: 0.9502 | F1: 0.9492 | Recall: 0.9305 | Precision: 0.9687\n",
      "  New Best F1: 0.9492\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/50 Training: 100%|██████████| 4809/4809 [01:51<00:00, 43.15batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/50\n",
      "  Train Total Loss: 0.1449 | CLS Loss: 0.1007 | Align Loss (raw): 0.0885\n",
      "  Train Acc: 0.9661 | F1: 0.9658 | Recall: 0.9634 | Precision: 0.9683\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.12batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1055 | CLS Loss: 0.0740 | Align Loss (raw): 0.0643\n",
      "  Val Acc: 0.9755 | F1: 0.9755 | Recall: 0.9753 | Precision: 0.9757\n",
      "  New Best F1: 0.9755\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/50 Training: 100%|██████████| 4809/4809 [01:48<00:00, 44.42batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/50\n",
      "  Train Total Loss: 0.1046 | CLS Loss: 0.0790 | Align Loss (raw): 0.0513\n",
      "  Train Acc: 0.9733 | F1: 0.9732 | Recall: 0.9717 | Precision: 0.9746\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 179.84batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.1034 | CLS Loss: 0.0780 | Align Loss (raw): 0.0541\n",
      "  Val Acc: 0.9731 | F1: 0.9726 | Recall: 0.9543 | Precision: 0.9916\n",
      "  Model not saved. Best F1 so far: 0.9755\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.14batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/50\n",
      "  Train Total Loss: 0.0842 | CLS Loss: 0.0672 | Align Loss (raw): 0.0340\n",
      "  Train Acc: 0.9775 | F1: 0.9774 | Recall: 0.9761 | Precision: 0.9787\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 180.43batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0651 | CLS Loss: 0.0522 | Align Loss (raw): 0.0259\n",
      "  Val Acc: 0.9825 | F1: 0.9825 | Recall: 0.9817 | Precision: 0.9832\n",
      "  New Best F1: 0.9825\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.23batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/50\n",
      "  Train Total Loss: 0.0753 | CLS Loss: 0.0606 | Align Loss (raw): 0.0294\n",
      "  Train Acc: 0.9795 | F1: 0.9793 | Recall: 0.9782 | Precision: 0.9804\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 181.14batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0711 | CLS Loss: 0.0528 | Align Loss (raw): 0.0364\n",
      "  Val Acc: 0.9829 | F1: 0.9829 | Recall: 0.9795 | Precision: 0.9863\n",
      "  New Best F1: 0.9829\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6/50 Training: 100%|██████████| 4809/4809 [01:49<00:00, 43.84batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/50\n",
      "  Train Total Loss: 0.0669 | CLS Loss: 0.0543 | Align Loss (raw): 0.0252\n",
      "  Train Acc: 0.9813 | F1: 0.9812 | Recall: 0.9799 | Precision: 0.9826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 173.42batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0651 | CLS Loss: 0.0535 | Align Loss (raw): 0.0232\n",
      "  Val Acc: 0.9816 | F1: 0.9815 | Recall: 0.9706 | Precision: 0.9926\n",
      "  Model not saved. Best F1 so far: 0.9829\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7/50 Training: 100%|██████████| 4809/4809 [01:49<00:00, 43.83batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/50\n",
      "  Train Total Loss: 0.0627 | CLS Loss: 0.0505 | Align Loss (raw): 0.0243\n",
      "  Train Acc: 0.9831 | F1: 0.9830 | Recall: 0.9822 | Precision: 0.9837\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7/50 Validation: 100%|██████████| 1203/1203 [00:07<00:00, 168.20batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0576 | CLS Loss: 0.0477 | Align Loss (raw): 0.0197\n",
      "  Val Acc: 0.9837 | F1: 0.9839 | Recall: 0.9923 | Precision: 0.9756\n",
      "  New Best F1: 0.9839\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8/50 Training: 100%|██████████| 4809/4809 [01:49<00:00, 43.83batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/50\n",
      "  Train Total Loss: 0.0573 | CLS Loss: 0.0461 | Align Loss (raw): 0.0224\n",
      "  Train Acc: 0.9845 | F1: 0.9844 | Recall: 0.9837 | Precision: 0.9851\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8/50 Validation: 100%|██████████| 1203/1203 [00:07<00:00, 168.33batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0563 | CLS Loss: 0.0477 | Align Loss (raw): 0.0172\n",
      "  Val Acc: 0.9846 | F1: 0.9847 | Recall: 0.9910 | Precision: 0.9786\n",
      "  New Best F1: 0.9847\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9/50 Training: 100%|██████████| 4809/4809 [01:48<00:00, 44.25batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/50\n",
      "  Train Total Loss: 0.0563 | CLS Loss: 0.0454 | Align Loss (raw): 0.0218\n",
      "  Train Acc: 0.9843 | F1: 0.9842 | Recall: 0.9835 | Precision: 0.9849\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.24batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0538 | CLS Loss: 0.0427 | Align Loss (raw): 0.0222\n",
      "  Val Acc: 0.9864 | F1: 0.9863 | Recall: 0.9808 | Precision: 0.9919\n",
      "  New Best F1: 0.9863\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10/50 Training: 100%|██████████| 4809/4809 [01:49<00:00, 44.00batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/50\n",
      "  Train Total Loss: 0.0517 | CLS Loss: 0.0414 | Align Loss (raw): 0.0206\n",
      "  Train Acc: 0.9860 | F1: 0.9859 | Recall: 0.9851 | Precision: 0.9868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 172.69batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0488 | CLS Loss: 0.0387 | Align Loss (raw): 0.0201\n",
      "  Val Acc: 0.9879 | F1: 0.9880 | Recall: 0.9921 | Precision: 0.9839\n",
      "  New Best F1: 0.9880\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11/50 Training: 100%|██████████| 4809/4809 [01:48<00:00, 44.24batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11/50\n",
      "  Train Total Loss: 0.0493 | CLS Loss: 0.0397 | Align Loss (raw): 0.0192\n",
      "  Train Acc: 0.9867 | F1: 0.9866 | Recall: 0.9860 | Precision: 0.9873\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0468 | CLS Loss: 0.0385 | Align Loss (raw): 0.0167\n",
      "  Val Acc: 0.9867 | F1: 0.9868 | Recall: 0.9875 | Precision: 0.9860\n",
      "  Model not saved. Best F1 so far: 0.9880\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12/50 Training: 100%|██████████| 4809/4809 [01:49<00:00, 44.00batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12/50\n",
      "  Train Total Loss: 0.0459 | CLS Loss: 0.0367 | Align Loss (raw): 0.0185\n",
      "  Train Acc: 0.9874 | F1: 0.9874 | Recall: 0.9873 | Precision: 0.9874\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 173.08batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0458 | CLS Loss: 0.0380 | Align Loss (raw): 0.0174\n",
      "  Val Acc: 0.9871 | F1: 0.9871 | Recall: 0.9926 | Precision: 0.9817\n",
      "  Model not saved. Best F1 so far: 0.9880\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13/50 Training: 100%|██████████| 4809/4809 [01:48<00:00, 44.50batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13/50\n",
      "  Train Total Loss: 0.0439 | CLS Loss: 0.0351 | Align Loss (raw): 0.0175\n",
      "  Train Acc: 0.9882 | F1: 0.9882 | Recall: 0.9874 | Precision: 0.9889\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 172.40batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0424 | CLS Loss: 0.0339 | Align Loss (raw): 0.0170\n",
      "  Val Acc: 0.9889 | F1: 0.9889 | Recall: 0.9871 | Precision: 0.9906\n",
      "  New Best F1: 0.9889\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.73batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14/50\n",
      "  Train Total Loss: 0.0417 | CLS Loss: 0.0334 | Align Loss (raw): 0.0167\n",
      "  Train Acc: 0.9888 | F1: 0.9887 | Recall: 0.9882 | Precision: 0.9892\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0541 | CLS Loss: 0.0456 | Align Loss (raw): 0.0169\n",
      "  Val Acc: 0.9858 | F1: 0.9859 | Recall: 0.9947 | Precision: 0.9773\n",
      "  Model not saved. Best F1 so far: 0.9889\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.04batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15/50\n",
      "  Train Total Loss: 0.0408 | CLS Loss: 0.0328 | Align Loss (raw): 0.0159\n",
      "  Train Acc: 0.9891 | F1: 0.9890 | Recall: 0.9884 | Precision: 0.9896\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.41batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0459 | CLS Loss: 0.0330 | Align Loss (raw): 0.0259\n",
      "  Val Acc: 0.9891 | F1: 0.9891 | Recall: 0.9900 | Precision: 0.9882\n",
      "  New Best F1: 0.9891\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.74batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16/50\n",
      "  Train Total Loss: 0.0402 | CLS Loss: 0.0323 | Align Loss (raw): 0.0158\n",
      "  Train Acc: 0.9890 | F1: 0.9890 | Recall: 0.9886 | Precision: 0.9894\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.52batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0420 | CLS Loss: 0.0350 | Align Loss (raw): 0.0140\n",
      "  Val Acc: 0.9886 | F1: 0.9885 | Recall: 0.9825 | Precision: 0.9945\n",
      "  Model not saved. Best F1 so far: 0.9891\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 44.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17/50\n",
      "  Train Total Loss: 0.0380 | CLS Loss: 0.0305 | Align Loss (raw): 0.0150\n",
      "  Train Acc: 0.9894 | F1: 0.9894 | Recall: 0.9886 | Precision: 0.9901\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.33batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0459 | CLS Loss: 0.0392 | Align Loss (raw): 0.0134\n",
      "  Val Acc: 0.9870 | F1: 0.9869 | Recall: 0.9792 | Precision: 0.9947\n",
      "  Model not saved. Best F1 so far: 0.9891\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18/50\n",
      "  Train Total Loss: 0.0373 | CLS Loss: 0.0297 | Align Loss (raw): 0.0150\n",
      "  Train Acc: 0.9900 | F1: 0.9899 | Recall: 0.9889 | Precision: 0.9909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.56batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0383 | CLS Loss: 0.0316 | Align Loss (raw): 0.0133\n",
      "  Val Acc: 0.9897 | F1: 0.9897 | Recall: 0.9865 | Precision: 0.9929\n",
      "  New Best F1: 0.9897\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.78batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19/50\n",
      "  Train Total Loss: 0.0359 | CLS Loss: 0.0288 | Align Loss (raw): 0.0141\n",
      "  Train Acc: 0.9895 | F1: 0.9895 | Recall: 0.9891 | Precision: 0.9898\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.47batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0676 | CLS Loss: 0.0621 | Align Loss (raw): 0.0109\n",
      "  Val Acc: 0.9795 | F1: 0.9799 | Recall: 0.9974 | Precision: 0.9630\n",
      "  Model not saved. Best F1 so far: 0.9897\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.09batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20/50\n",
      "  Train Total Loss: 0.0351 | CLS Loss: 0.0282 | Align Loss (raw): 0.0137\n",
      "  Train Acc: 0.9905 | F1: 0.9904 | Recall: 0.9899 | Precision: 0.9910\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.68batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0376 | CLS Loss: 0.0322 | Align Loss (raw): 0.0107\n",
      "  Val Acc: 0.9896 | F1: 0.9896 | Recall: 0.9857 | Precision: 0.9935\n",
      "  Model not saved. Best F1 so far: 0.9897\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 44.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21/50\n",
      "  Train Total Loss: 0.0333 | CLS Loss: 0.0267 | Align Loss (raw): 0.0133\n",
      "  Train Acc: 0.9909 | F1: 0.9908 | Recall: 0.9904 | Precision: 0.9912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.67batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0419 | CLS Loss: 0.0367 | Align Loss (raw): 0.0103\n",
      "  Val Acc: 0.9883 | F1: 0.9884 | Recall: 0.9941 | Precision: 0.9827\n",
      "  Model not saved. Best F1 so far: 0.9897\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.92batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22/50\n",
      "  Train Total Loss: 0.0324 | CLS Loss: 0.0260 | Align Loss (raw): 0.0128\n",
      "  Train Acc: 0.9910 | F1: 0.9909 | Recall: 0.9904 | Precision: 0.9915\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0520 | CLS Loss: 0.0480 | Align Loss (raw): 0.0095\n",
      "  Val Acc: 0.9866 | F1: 0.9867 | Recall: 0.9951 | Precision: 0.9784\n",
      "  Model not saved. Best F1 so far: 0.9897\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.08batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23/50\n",
      "  Train Total Loss: 0.0319 | CLS Loss: 0.0257 | Align Loss (raw): 0.0125\n",
      "  Train Acc: 0.9912 | F1: 0.9912 | Recall: 0.9905 | Precision: 0.9918\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.77batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0350 | CLS Loss: 0.0299 | Align Loss (raw): 0.0102\n",
      "  Val Acc: 0.9901 | F1: 0.9901 | Recall: 0.9914 | Precision: 0.9889\n",
      "  New Best F1: 0.9901\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.07batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24/50\n",
      "  Train Total Loss: 0.0309 | CLS Loss: 0.0248 | Align Loss (raw): 0.0123\n",
      "  Train Acc: 0.9917 | F1: 0.9916 | Recall: 0.9911 | Precision: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.12batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0324 | CLS Loss: 0.0280 | Align Loss (raw): 0.0087\n",
      "  Val Acc: 0.9907 | F1: 0.9907 | Recall: 0.9884 | Precision: 0.9930\n",
      "  New Best F1: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.78batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25/50\n",
      "  Train Total Loss: 0.0305 | CLS Loss: 0.0242 | Align Loss (raw): 0.0125\n",
      "  Train Acc: 0.9915 | F1: 0.9914 | Recall: 0.9909 | Precision: 0.9919\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.29batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0546 | CLS Loss: 0.0476 | Align Loss (raw): 0.0140\n",
      "  Val Acc: 0.9856 | F1: 0.9858 | Recall: 0.9970 | Precision: 0.9748\n",
      "  Model not saved. Best F1 so far: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.72batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26/50\n",
      "  Train Total Loss: 0.0295 | CLS Loss: 0.0235 | Align Loss (raw): 0.0121\n",
      "  Train Acc: 0.9917 | F1: 0.9916 | Recall: 0.9911 | Precision: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.84batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0413 | CLS Loss: 0.0352 | Align Loss (raw): 0.0123\n",
      "  Val Acc: 0.9884 | F1: 0.9884 | Recall: 0.9948 | Precision: 0.9822\n",
      "  Model not saved. Best F1 so far: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.00batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27/50\n",
      "  Train Total Loss: 0.0285 | CLS Loss: 0.0227 | Align Loss (raw): 0.0115\n",
      "  Train Acc: 0.9922 | F1: 0.9922 | Recall: 0.9915 | Precision: 0.9928\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 174.58batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0352 | CLS Loss: 0.0303 | Align Loss (raw): 0.0097\n",
      "  Val Acc: 0.9906 | F1: 0.9906 | Recall: 0.9937 | Precision: 0.9876\n",
      "  Model not saved. Best F1 so far: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 44.96batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28/50\n",
      "  Train Total Loss: 0.0283 | CLS Loss: 0.0225 | Align Loss (raw): 0.0116\n",
      "  Train Acc: 0.9920 | F1: 0.9920 | Recall: 0.9915 | Precision: 0.9925\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0530 | CLS Loss: 0.0483 | Align Loss (raw): 0.0091\n",
      "  Val Acc: 0.9842 | F1: 0.9841 | Recall: 0.9715 | Precision: 0.9969\n",
      "  Model not saved. Best F1 so far: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.29batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29/50\n",
      "  Train Total Loss: 0.0283 | CLS Loss: 0.0227 | Align Loss (raw): 0.0112\n",
      "  Train Acc: 0.9921 | F1: 0.9921 | Recall: 0.9912 | Precision: 0.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.53batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0348 | CLS Loss: 0.0296 | Align Loss (raw): 0.0106\n",
      "  Val Acc: 0.9902 | F1: 0.9902 | Recall: 0.9885 | Precision: 0.9920\n",
      "  Model not saved. Best F1 so far: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 44.98batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30/50\n",
      "  Train Total Loss: 0.0270 | CLS Loss: 0.0216 | Align Loss (raw): 0.0108\n",
      "  Train Acc: 0.9923 | F1: 0.9922 | Recall: 0.9916 | Precision: 0.9929\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.43batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0444 | CLS Loss: 0.0398 | Align Loss (raw): 0.0091\n",
      "  Val Acc: 0.9880 | F1: 0.9881 | Recall: 0.9958 | Precision: 0.9805\n",
      "  Model not saved. Best F1 so far: 0.9907\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.93batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 31/50\n",
      "  Train Total Loss: 0.0268 | CLS Loss: 0.0216 | Align Loss (raw): 0.0105\n",
      "  Train Acc: 0.9926 | F1: 0.9925 | Recall: 0.9918 | Precision: 0.9932\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.70batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0331 | CLS Loss: 0.0281 | Align Loss (raw): 0.0099\n",
      "  Val Acc: 0.9913 | F1: 0.9913 | Recall: 0.9894 | Precision: 0.9931\n",
      "  New Best F1: 0.9913\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.91batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 32/50\n",
      "  Train Total Loss: 0.0260 | CLS Loss: 0.0209 | Align Loss (raw): 0.0102\n",
      "  Train Acc: 0.9926 | F1: 0.9926 | Recall: 0.9921 | Precision: 0.9930\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.36batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0495 | CLS Loss: 0.0441 | Align Loss (raw): 0.0109\n",
      "  Val Acc: 0.9854 | F1: 0.9856 | Recall: 0.9942 | Precision: 0.9771\n",
      "  Model not saved. Best F1 so far: 0.9913\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.07batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 33/50\n",
      "  Train Total Loss: 0.0256 | CLS Loss: 0.0206 | Align Loss (raw): 0.0101\n",
      "  Train Acc: 0.9930 | F1: 0.9929 | Recall: 0.9923 | Precision: 0.9936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 175.89batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0343 | CLS Loss: 0.0300 | Align Loss (raw): 0.0085\n",
      "  Val Acc: 0.9903 | F1: 0.9903 | Recall: 0.9877 | Precision: 0.9928\n",
      "  Model not saved. Best F1 so far: 0.9913\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.88batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 34/50\n",
      "  Train Total Loss: 0.0254 | CLS Loss: 0.0202 | Align Loss (raw): 0.0104\n",
      "  Train Acc: 0.9930 | F1: 0.9929 | Recall: 0.9923 | Precision: 0.9936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.38batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0305 | CLS Loss: 0.0259 | Align Loss (raw): 0.0092\n",
      "  Val Acc: 0.9917 | F1: 0.9917 | Recall: 0.9894 | Precision: 0.9941\n",
      "  New Best F1: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.77batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 35/50\n",
      "  Train Total Loss: 0.0241 | CLS Loss: 0.0191 | Align Loss (raw): 0.0099\n",
      "  Train Acc: 0.9934 | F1: 0.9934 | Recall: 0.9928 | Precision: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.74batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0544 | CLS Loss: 0.0496 | Align Loss (raw): 0.0096\n",
      "  Val Acc: 0.9845 | F1: 0.9843 | Recall: 0.9715 | Precision: 0.9974\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.80batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 36/50\n",
      "  Train Total Loss: 0.0249 | CLS Loss: 0.0198 | Align Loss (raw): 0.0101\n",
      "  Train Acc: 0.9929 | F1: 0.9929 | Recall: 0.9922 | Precision: 0.9935\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.11batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0380 | CLS Loss: 0.0334 | Align Loss (raw): 0.0090\n",
      "  Val Acc: 0.9897 | F1: 0.9898 | Recall: 0.9959 | Precision: 0.9837\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 37/50\n",
      "  Train Total Loss: 0.0233 | CLS Loss: 0.0184 | Align Loss (raw): 0.0099\n",
      "  Train Acc: 0.9933 | F1: 0.9933 | Recall: 0.9927 | Precision: 0.9939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.86batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0316 | CLS Loss: 0.0274 | Align Loss (raw): 0.0083\n",
      "  Val Acc: 0.9916 | F1: 0.9916 | Recall: 0.9901 | Precision: 0.9931\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.01batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 38/50\n",
      "  Train Total Loss: 0.0230 | CLS Loss: 0.0184 | Align Loss (raw): 0.0093\n",
      "  Train Acc: 0.9937 | F1: 0.9936 | Recall: 0.9931 | Precision: 0.9942\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.36batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0433 | CLS Loss: 0.0339 | Align Loss (raw): 0.0186\n",
      "  Val Acc: 0.9894 | F1: 0.9893 | Recall: 0.9823 | Precision: 0.9964\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 39/50\n",
      "  Train Total Loss: 0.0220 | CLS Loss: 0.0171 | Align Loss (raw): 0.0097\n",
      "  Train Acc: 0.9939 | F1: 0.9939 | Recall: 0.9935 | Precision: 0.9943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.50batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0446 | CLS Loss: 0.0386 | Align Loss (raw): 0.0121\n",
      "  Val Acc: 0.9884 | F1: 0.9883 | Recall: 0.9802 | Precision: 0.9965\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.13batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 40/50\n",
      "  Train Total Loss: 0.0221 | CLS Loss: 0.0173 | Align Loss (raw): 0.0095\n",
      "  Train Acc: 0.9937 | F1: 0.9937 | Recall: 0.9933 | Precision: 0.9941\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 175.93batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0723 | CLS Loss: 0.0683 | Align Loss (raw): 0.0080\n",
      "  Val Acc: 0.9805 | F1: 0.9801 | Recall: 0.9629 | Precision: 0.9980\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.10batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 41/50\n",
      "  Train Total Loss: 0.0217 | CLS Loss: 0.0169 | Align Loss (raw): 0.0095\n",
      "  Train Acc: 0.9941 | F1: 0.9941 | Recall: 0.9936 | Precision: 0.9946\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.70batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0381 | CLS Loss: 0.0314 | Align Loss (raw): 0.0136\n",
      "  Val Acc: 0.9912 | F1: 0.9911 | Recall: 0.9862 | Precision: 0.9961\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.89batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 42/50\n",
      "  Train Total Loss: 0.0212 | CLS Loss: 0.0168 | Align Loss (raw): 0.0089\n",
      "  Train Acc: 0.9942 | F1: 0.9942 | Recall: 0.9939 | Precision: 0.9945\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 178.17batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0305 | CLS Loss: 0.0262 | Align Loss (raw): 0.0087\n",
      "  Val Acc: 0.9912 | F1: 0.9912 | Recall: 0.9903 | Precision: 0.9920\n",
      "  Model not saved. Best F1 so far: 0.9917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.79batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 43/50\n",
      "  Train Total Loss: 0.0208 | CLS Loss: 0.0162 | Align Loss (raw): 0.0092\n",
      "  Train Acc: 0.9945 | F1: 0.9944 | Recall: 0.9940 | Precision: 0.9948\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.36batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0292 | CLS Loss: 0.0256 | Align Loss (raw): 0.0072\n",
      "  Val Acc: 0.9918 | F1: 0.9918 | Recall: 0.9930 | Precision: 0.9906\n",
      "  New Best F1: 0.9918\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.87batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44/50\n",
      "  Train Total Loss: 0.0213 | CLS Loss: 0.0168 | Align Loss (raw): 0.0090\n",
      "  Train Acc: 0.9939 | F1: 0.9939 | Recall: 0.9934 | Precision: 0.9944\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 178.83batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0505 | CLS Loss: 0.0460 | Align Loss (raw): 0.0091\n",
      "  Val Acc: 0.9862 | F1: 0.9860 | Recall: 0.9741 | Precision: 0.9982\n",
      "  Model not saved. Best F1 so far: 0.9918\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.04batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 45/50\n",
      "  Train Total Loss: 0.0197 | CLS Loss: 0.0151 | Align Loss (raw): 0.0093\n",
      "  Train Acc: 0.9950 | F1: 0.9950 | Recall: 0.9947 | Precision: 0.9952\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.36batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0288 | CLS Loss: 0.0245 | Align Loss (raw): 0.0086\n",
      "  Val Acc: 0.9923 | F1: 0.9922 | Recall: 0.9898 | Precision: 0.9947\n",
      "  New Best F1: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 46/50\n",
      "  Train Total Loss: 0.0209 | CLS Loss: 0.0163 | Align Loss (raw): 0.0092\n",
      "  Train Acc: 0.9945 | F1: 0.9945 | Recall: 0.9941 | Precision: 0.9949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 175.37batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0416 | CLS Loss: 0.0366 | Align Loss (raw): 0.0098\n",
      "  Val Acc: 0.9878 | F1: 0.9877 | Recall: 0.9796 | Precision: 0.9959\n",
      "  Model not saved. Best F1 so far: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.72batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 47/50\n",
      "  Train Total Loss: 0.0196 | CLS Loss: 0.0152 | Align Loss (raw): 0.0088\n",
      "  Train Acc: 0.9947 | F1: 0.9947 | Recall: 0.9946 | Precision: 0.9947\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 177.37batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0288 | CLS Loss: 0.0250 | Align Loss (raw): 0.0075\n",
      "  Val Acc: 0.9918 | F1: 0.9918 | Recall: 0.9906 | Precision: 0.9930\n",
      "  Model not saved. Best F1 so far: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48/50 Training: 100%|██████████| 4809/4809 [01:47<00:00, 44.80batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 48/50\n",
      "  Train Total Loss: 0.0195 | CLS Loss: 0.0151 | Align Loss (raw): 0.0088\n",
      "  Train Acc: 0.9945 | F1: 0.9945 | Recall: 0.9944 | Precision: 0.9946\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.61batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0384 | CLS Loss: 0.0347 | Align Loss (raw): 0.0074\n",
      "  Val Acc: 0.9901 | F1: 0.9902 | Recall: 0.9940 | Precision: 0.9864\n",
      "  Model not saved. Best F1 so far: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 44.95batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 49/50\n",
      "  Train Total Loss: 0.0189 | CLS Loss: 0.0147 | Align Loss (raw): 0.0085\n",
      "  Train Acc: 0.9950 | F1: 0.9950 | Recall: 0.9945 | Precision: 0.9955\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 178.20batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0316 | CLS Loss: 0.0278 | Align Loss (raw): 0.0076\n",
      "  Val Acc: 0.9915 | F1: 0.9915 | Recall: 0.9877 | Precision: 0.9953\n",
      "  Model not saved. Best F1 so far: 0.9922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50/50 Training: 100%|██████████| 4809/4809 [01:46<00:00, 45.02batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50/50\n",
      "  Train Total Loss: 0.0196 | CLS Loss: 0.0152 | Align Loss (raw): 0.0087\n",
      "  Train Acc: 0.9946 | F1: 0.9945 | Recall: 0.9940 | Precision: 0.9950\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50/50 Validation: 100%|██████████| 1203/1203 [00:06<00:00, 176.91batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Val Total Loss: 0.0444 | CLS Loss: 0.0380 | Align Loss (raw): 0.0128\n",
      "  Val Acc: 0.9892 | F1: 0.9892 | Recall: 0.9809 | Precision: 0.9976\n",
      "  Model not saved. Best F1 so far: 0.9922\n",
      "Training Finished.\n",
      "Best Validation Accuracy achieved: 0.0000\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
    "from torch.optim import Adam , AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "from sklearn.metrics import roc_curve\n",
    "import numpy as np\n",
    "\n",
    "def compute_eer(y_true, y_scores, pos_label=1):\n",
    "    \"\"\"Compute the Equal Error Rate (EER) of a binary scorer.\n",
    "\n",
    "    The EER is the ROC operating point where the false positive rate equals the\n",
    "    false negative rate (1 - TPR). The ROC curve is only sampled at the score\n",
    "    thresholds, so the threshold where |FNR - FPR| is smallest is used and the\n",
    "    two rates are averaged there.\n",
    "\n",
    "    Args:\n",
    "        y_true: ground-truth binary labels.\n",
    "        y_scores: scores where higher means more likely ``pos_label``\n",
    "            (e.g. sigmoid probabilities).\n",
    "        pos_label: label treated as the positive class (default 1, the same\n",
    "            positive class the binary F1/recall/precision calls here use).\n",
    "\n",
    "    Returns:\n",
    "        The EER as a float in [0, 1].\n",
    "    \"\"\"\n",
    "    fpr, tpr, _ = roc_curve(y_true, y_scores, pos_label=pos_label)\n",
    "    fnr = 1 - tpr\n",
    "    # Index of the threshold where FPR and FNR are closest (ideally equal).\n",
    "    eer_index = np.nanargmin(np.abs(fnr - fpr))\n",
    "    return float((fpr[eer_index] + fnr[eer_index]) / 2)\n",
    "\n",
    "\n",
    "# --- Train/validation split and data loaders ---\n",
    "# random_split returns new Subset objects; binding them to new names keeps the full\n",
    "# `train_dataset` intact, so re-running this cell re-splits the original data instead\n",
    "# of splitting an already-split subset again.\n",
    "seed = 42\n",
    "val_split = 0.2\n",
    "train_size = int((1 - val_split) * len(train_dataset))\n",
    "val_size = len(train_dataset) - train_size\n",
    "train_subset, val_subset = torch.utils.data.random_split(\n",
    "    train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(seed)\n",
    ")\n",
    "train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)\n",
    "# Evaluation loaders are not shuffled: accuracy/F1/EER do not depend on order, and a fixed\n",
    "# order makes the batch-dependent (alignment) losses reproducible between runs.\n",
    "val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
    "# Global side effect: silences all warnings for the rest of the kernel session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(seed)  # seed torch's global RNG (weight init, training-loader shuffle order)\n",
    "\n",
    "# --- Model and optimizer ---\n",
    "# Per-input sizes passed to CLAM (suffix 1/2 presumably pairs with embed1/embed2 -- confirm).\n",
    "in_channel1 = 13\n",
    "in_channel2 = 13\n",
    "embed_dim1 = 768\n",
    "embed_dim2 = 768\n",
    "model = CLAM(in_channel1, in_channel2, embed_dim1, embed_dim2).to(device)\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "\n",
    "# --- Losses ---\n",
    "classification_criterion = nn.BCEWithLogitsLoss()  # fake/real classification on raw logits\n",
    "margin = 0.2  # triplet margin; may need tuning\n",
    "alignment_criterion = nn.TripletMarginLoss(margin=margin)\n",
    "alignment_loss_weight = 0.5  # weight of the alignment loss in the total loss\n",
    "\n",
    "# --- Training bookkeeping ---\n",
    "epochs = 50\n",
    "best_acc = 0\n",
    "best_f1 = 0\n",
    "save_folder = \"model_wts\"\n",
    "save_name = f\"best_model_triplet_loss_margin_{margin}.pth\"\n",
    "print(\"Starting Training...\")\n",
    "\n",
    "\n",
    "def _alignment_loss(real_emb, fake_emb, labels):\n",
    "    \"\"\"In-batch triplet alignment loss over the samples labelled 0.\n",
    "\n",
    "    For every ordered pair (i, j), i != j, of label-0 samples in the batch the triplet is\n",
    "    anchor = real_emb[i], positive = fake_emb[i], negative = fake_emb[j]. Returns the mean\n",
    "    triplet loss, or a zero tensor when fewer than two label-0 samples are present\n",
    "    (no negatives can be formed).\n",
    "    \"\"\"\n",
    "    idx = (labels == 0).nonzero(as_tuple=True)[0]\n",
    "    n = idx.numel()\n",
    "    if n < 2:\n",
    "        return torch.zeros((), device=labels.device)\n",
    "    anchors = real_emb[idx]\n",
    "    positives = fake_emb[idx]\n",
    "    # All ordered pairs (i, j) with i != j, in the same order as a nested i/j loop.\n",
    "    off_diag = ~torch.eye(n, dtype=torch.bool, device=idx.device)\n",
    "    i_idx, j_idx = off_diag.nonzero(as_tuple=True)\n",
    "    # One batched call; reduction='mean' averages over all n*(n-1) triplets, which equals\n",
    "    # the mean of the individually computed per-triplet losses.\n",
    "    return alignment_criterion(anchors[i_idx], positives[i_idx], positives[j_idx])\n",
    "\n",
    "\n",
    "def _binary_metrics(y_true, y_pred):\n",
    "    \"\"\"Return (accuracy, F1, recall, precision) for 0/1 labels and predictions (positive class 1).\"\"\"\n",
    "    return (\n",
    "        accuracy_score(y_true, y_pred),\n",
    "        f1_score(y_true, y_pred, average='binary', zero_division=0),\n",
    "        recall_score(y_true, y_pred, average='binary', zero_division=0),\n",
    "        precision_score(y_true, y_pred, average='binary', zero_division=0),\n",
    "    )\n",
    "\n",
    "\n",
    "os.makedirs(save_folder, exist_ok=True)  # torch.save does not create missing directories\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # --- Training ---\n",
    "    model.train()\n",
    "    train_loss_total = 0.0\n",
    "    train_loss_cls = 0.0\n",
    "    train_loss_align = 0.0\n",
    "    train_preds_all = []\n",
    "    train_labels_all = []\n",
    "\n",
    "    for embed1, embed2, labels in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{epochs} Training\", unit=\"batch\"):\n",
    "        embed1 = embed1.to(device)\n",
    "        embed2 = embed2.to(device)\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs, real_emb, fake_emb = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)  # logits, one per sample\n",
    "\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "        alignment_loss = _alignment_loss(real_emb, fake_emb, labels)\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss_total += total_loss.item()\n",
    "        train_loss_cls += classification_loss.item()\n",
    "        train_loss_align += alignment_loss.item()  # unweighted alignment loss\n",
    "        preds = torch.sigmoid(outputs).detach().round()  # hard 0/1 predictions\n",
    "        train_preds_all.extend(preds.cpu().numpy())\n",
    "        train_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "    # Loss averages are over all batches; batches with < 2 label-0 samples add 0 alignment loss.\n",
    "    num_batches = len(train_loader)\n",
    "    avg_train_loss_total = train_loss_total / num_batches\n",
    "    avg_train_loss_cls = train_loss_cls / num_batches\n",
    "    avg_train_loss_align = train_loss_align / num_batches\n",
    "    train_acc, train_f1, train_recall, train_precision = _binary_metrics(train_labels_all, train_preds_all)\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{epochs}\")\n",
    "    print(f\"  Train Total Loss: {avg_train_loss_total:.4f} | CLS Loss: {avg_train_loss_cls:.4f} | Align Loss (raw): {avg_train_loss_align:.4f}\")\n",
    "    print(f\"  Train Acc: {train_acc:.4f} | F1: {train_f1:.4f} | Recall: {train_recall:.4f} | Precision: {train_precision:.4f}\")\n",
    "\n",
    "    # --- Validation ---\n",
    "    model.eval()\n",
    "    val_loss_total = 0.0\n",
    "    val_loss_cls = 0.0\n",
    "    val_loss_align = 0.0\n",
    "    val_preds_all = []\n",
    "    val_labels_all = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for embed1, embed2, labels in tqdm(val_loader, desc=f\"Epoch {epoch+1}/{epochs} Validation\", unit=\"batch\"):\n",
    "            embed1 = embed1.to(device)\n",
    "            embed2 = embed2.to(device)\n",
    "            labels = labels.float().to(device).view(-1)\n",
    "\n",
    "            outputs, real_emb, fake_emb = model.forward_training(embed1, embed2)\n",
    "            outputs = outputs.view(-1)\n",
    "\n",
    "            classification_loss = classification_criterion(outputs, labels)\n",
    "            alignment_loss = _alignment_loss(real_emb, fake_emb, labels)\n",
    "            total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "            val_loss_total += total_loss.item()\n",
    "            val_loss_cls += classification_loss.item()\n",
    "            val_loss_align += alignment_loss.item()\n",
    "            preds = torch.sigmoid(outputs).round()\n",
    "            val_preds_all.extend(preds.cpu().numpy())\n",
    "            val_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "    num_batches_val = len(val_loader)\n",
    "    avg_val_loss_total = val_loss_total / num_batches_val\n",
    "    avg_val_loss_cls = val_loss_cls / num_batches_val\n",
    "    avg_val_loss_align = val_loss_align / num_batches_val\n",
    "    val_acc, val_f1, val_recall, val_precision = _binary_metrics(val_labels_all, val_preds_all)\n",
    "\n",
    "    print(f\"  Val Total Loss: {avg_val_loss_total:.4f} | CLS Loss: {avg_val_loss_cls:.4f} | Align Loss (raw): {avg_val_loss_align:.4f}\")\n",
    "    print(f\"  Val Acc: {val_acc:.4f} | F1: {val_f1:.4f} | Recall: {val_recall:.4f} | Precision: {val_precision:.4f}\")\n",
    "\n",
    "    # Best accuracy is tracked for reporting only; checkpoint selection uses F1.\n",
    "    best_acc = max(best_acc, val_acc)\n",
    "\n",
    "    # Keep the weights with the highest validation F1.\n",
    "    if val_f1 > best_f1:\n",
    "        best_f1 = val_f1\n",
    "        print(f\"  New Best F1: {best_f1:.4f}\")\n",
    "        torch.save(model.state_dict(), os.path.join(save_folder, save_name))\n",
    "    else:\n",
    "        print(f\"  Model not saved. Best F1 so far: {best_f1:.4f}\")\n",
    "\n",
    "print(\"Training Finished.\")\n",
    "print(f\"Best Validation F1 achieved (saved checkpoint): {best_f1:.4f}\")\n",
    "print(f\"Best Validation Accuracy achieved: {best_acc:.4f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2426b781",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing: 100%|██████████| 2139/2139 [00:14<00:00, 152.63batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Total Loss: 0.3658 | CLS Loss: 0.3240 | Align Loss (raw): 0.0835\n",
      "Test Acc: 0.9302 | F1: 0.9255 | Recall: 0.8692 | Precision: 0.9894\n",
      "EER: 0.0649\n"
     ]
    }
   ],
   "source": [
    "# --- Evaluate the best (highest validation F1) checkpoint on the test set ---\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "model.load_state_dict(torch.load(os.path.join(save_folder, save_name), map_location=device))\n",
    "model.eval()\n",
    "test_loss_total = 0.0\n",
    "test_loss_cls = 0.0\n",
    "test_loss_align = 0.0\n",
    "test_preds_all = []\n",
    "test_labels_all = []\n",
    "test_scores_all = []  # sigmoid probabilities, used for the EER\n",
    "\n",
    "with torch.no_grad():\n",
    "    for embed1, embed2, labels in tqdm(test_loader, desc=\"Testing\", unit=\"batch\"):\n",
    "        embed1 = embed1.to(device)\n",
    "        embed2 = embed2.to(device)\n",
    "        labels = labels.float().to(device).view(-1)\n",
    "\n",
    "        outputs, real_emb, fake_emb = model.forward_training(embed1, embed2)\n",
    "        outputs = outputs.view(-1)\n",
    "        classification_loss = classification_criterion(outputs, labels)\n",
    "\n",
    "        # Alignment loss on the label-0 samples -- the same subset training/validation use.\n",
    "        real_indices = (labels == 0).nonzero(as_tuple=True)[0]\n",
    "        n_real = real_indices.numel()\n",
    "        if n_real > 1:\n",
    "            real_emb_filtered = real_emb[real_indices]\n",
    "            fake_emb_filtered = fake_emb[real_indices]\n",
    "            # Every ordered pair (i, j), i != j: anchor real[i], positive fake[i], negative fake[j].\n",
    "            i_idx, j_idx = (~torch.eye(n_real, dtype=torch.bool, device=device)).nonzero(as_tuple=True)\n",
    "            alignment_loss = alignment_criterion(\n",
    "                real_emb_filtered[i_idx], fake_emb_filtered[i_idx], fake_emb_filtered[j_idx]\n",
    "            )\n",
    "        else:\n",
    "            alignment_loss = torch.zeros((), device=device)\n",
    "        total_loss = classification_loss + alignment_loss_weight * alignment_loss\n",
    "\n",
    "        test_loss_total += total_loss.item()\n",
    "        test_loss_cls += classification_loss.item()\n",
    "        test_loss_align += alignment_loss.item()\n",
    "\n",
    "        scores = torch.sigmoid(outputs)\n",
    "        test_scores_all.extend(scores.cpu().numpy())\n",
    "        test_preds_all.extend(scores.round().cpu().numpy())\n",
    "        test_labels_all.extend(labels.cpu().numpy())\n",
    "\n",
    "num_batches_test = len(test_loader)\n",
    "avg_test_loss_total = test_loss_total / num_batches_test\n",
    "avg_test_loss_cls = test_loss_cls / num_batches_test\n",
    "avg_test_loss_align = test_loss_align / num_batches_test\n",
    "test_acc = accuracy_score(test_labels_all, test_preds_all)\n",
    "test_f1 = f1_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "test_recall = recall_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "test_precision = precision_score(test_labels_all, test_preds_all, average='binary', zero_division=0)\n",
    "eer = compute_eer(test_labels_all, test_scores_all)  # uses the continuous scores, not hard predictions\n",
    "print(f\"Test Total Loss: {avg_test_loss_total:.4f} | CLS Loss: {avg_test_loss_cls:.4f} | Align Loss (raw): {avg_test_loss_align:.4f}\")\n",
    "print(f\"Test Acc: {test_acc:.4f} | F1: {test_f1:.4f} | Recall: {test_recall:.4f} | Precision: {test_precision:.4f}\")\n",
    "print(f\"EER: {eer:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c39c682",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set results from a different run, kept for comparison (that run's configuration is not recorded here).\n",
    "# Testing: 100%|██████████| 2139/2139 [00:03<00:00, 595.20batch/s]\n",
    "# Test Total Loss: 0.4560 | CLS Loss: 0.4164 | Align Loss (raw): 0.0793\n",
    "# Test Acc: 0.9185 | F1: 0.9114 | Recall: 0.8414 | Precision: 0.9942\n",
    "# EER: 0.0625"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "uni3d",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
