{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import f1_score\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "from vit import ViT\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "afib_dataset_path = \"Eval/af_csv_signals.csv\"\n",
    "afib_csv = pd.read_csv(afib_dataset_path)\n",
    "afib_anno = pd.read_csv(\"Eval/af_annotations.csv\")\n",
    "\n",
    "afib_signals = np.array(afib_csv)\n",
    "afib_anno = np.array(afib_anno)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT encoder hyper-parameters. These presumably need to match the architecture the\n",
    "# checkpoints in \"ViT Trained Models\" were trained with; otherwise load_state_dict\n",
    "# (strict by default) will reject the weights.\n",
    "NUM_BLOCKS = 4\n",
    "MODEL_DIM = 128\n",
    "NUM_HEADS = 4\n",
    "PATCH_SIZE = 10\n",
    "FS = 100  # presumably the sampling rate in Hz - TODO confirm against vit.ViT\n",
    "L = 10    # presumably the strip length in seconds - TODO confirm against vit.ViT\n",
    "C_IN = 1  # passed to ViT as in_channels\n",
    "\n",
    "model = ViT(\n",
    "    num_blocks=NUM_BLOCKS,\n",
    "    num_heads=NUM_HEADS,\n",
    "    model_dim=MODEL_DIM,\n",
    "    patch_size=PATCH_SIZE,\n",
    "    in_channels=C_IN,\n",
    "    fs=FS,\n",
    "    l=L,\n",
    "    do_prob=0.1,  # presumably a dropout rate; should be inactive after model.eval()\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
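    "# Main Evaluation\n",
    "\n",
    "Each checkpoint in `ViT Trained Models` is loaded into the ViT defined above and used as a frozen feature extractor over all AF signals. The features are then scored with leave-one-group-out cross-validation. For each unique value in the first annotation column (presumably a subject/record ID), a `StandardScaler` + RBF `SVC` is trained on every other group and evaluated on the held-out group, using the third annotation column as the label. Reported metrics are accuracy and binary F1 over the pooled out-of-fold predictions."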
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Eval clocs.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [56:37<00:00, 147.70s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.663883214469173\n",
      "\n",
      "F1 Score: \n",
      "0.5897944306750256\n",
      "\n",
      "Eval cupid_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [20:53<00:00, 54.50s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.8629299119568983\n",
      "\n",
      "F1 Score: \n",
      "0.8430111646234677\n",
      "\n",
      "Eval deaps.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [24:04<00:00, 62.79s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.762522548352029\n",
      "\n",
      "F1 Score: \n",
      "0.7471605001081109\n",
      "\n",
      "Eval jepa_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [41:12<00:00, 107.50s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.7511856550669581\n",
      "\n",
      "F1 Score: \n",
      "0.7051695826963366\n",
      "\n",
      "Eval mae_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [43:59<00:00, 114.78s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.7656882772461743\n",
      "\n",
      "F1 Score: \n",
      "0.7283529998338042\n",
      "\n",
      "Eval mix_up.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [53:14<00:00, 138.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.6193121408689627\n",
      "\n",
      "F1 Score: \n",
      "0.5691843880544553\n",
      "\n",
      "Eval pclr.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [42:14<00:00, 110.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Accu: \n",
      "0.7523922158907644\n",
      "\n",
      "F1 Score: \n",
      "0.7384506669022172\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "train_dir = r\"ViT Trained Models\"\n",
    "BATCH_SIZE = 512\n",
    "CHECKPOINT_EXTS = (\".pth\", \".pt\")\n",
    "\n",
    "\n",
    "def extract_features(model, signals, batch_size=BATCH_SIZE):\n",
    "    \"\"\"Embed every row of `signals` with `model` in mini-batches.\n",
    "\n",
    "    Returns an array with one row per input signal. Each batch output is\n",
    "    reshaped to (batch, -1) rather than `.squeeze()`-d: squeeze would also\n",
    "    drop the batch axis when the last batch holds a single signal, which\n",
    "    breaks the feature matrix and misaligns it with the annotations.\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, signals.shape[0], batch_size):\n",
    "            batch = torch.tensor(signals[start : start + batch_size]).float().to(device)\n",
    "            out = model(batch)\n",
    "            chunks.append(out.reshape(out.shape[0], -1).cpu().numpy())\n",
    "    if not chunks:\n",
    "        return np.empty((0, 0), dtype=np.float32)\n",
    "    return np.concatenate(chunks, axis=0)\n",
    "\n",
    "\n",
    "def loso_svm_predictions(features, anno, group_col=0, label_col=2):\n",
    "    \"\"\"Leave-one-group-out evaluation of an SVM on frozen features.\n",
    "\n",
    "    For every unique value in `anno[:, group_col]`, fits a StandardScaler and an\n",
    "    SVC(gamma='auto') on all other rows and predicts the held-out rows. The\n",
    "    scaler is fit on the training fold only.\n",
    "\n",
    "    Returns (y_true, y_pred) as arrays concatenated over folds in fold order.\n",
    "    \"\"\"\n",
    "    groups = anno[:, group_col]\n",
    "    labels = np.array(anno[:, label_col], dtype=np.int32)\n",
    "    y_true, y_pred = [], []\n",
    "    for group in tqdm(np.unique(groups)):\n",
    "        train_mask = groups != group\n",
    "        test_mask = groups == group\n",
    "\n",
    "        scaler = StandardScaler()\n",
    "        x_train = scaler.fit_transform(features[train_mask])\n",
    "        x_test = scaler.transform(features[test_mask])\n",
    "\n",
    "        svc = SVC(gamma='auto')\n",
    "        svc.fit(x_train, labels[train_mask])\n",
    "        y_pred.extend(svc.predict(x_test))\n",
    "        y_true.extend(labels[test_mask])\n",
    "    return np.array(y_true), np.array(y_pred)\n",
    "\n",
    "\n",
    "# Sorted so the evaluation order (and printed log) is reproducible; os.listdir\n",
    "# order is arbitrary. Non-checkpoint files in the folder are skipped.\n",
    "checkpoints = sorted(f for f in os.listdir(train_dir) if f.endswith(CHECKPOINT_EXTS))\n",
    "\n",
    "for weights in checkpoints:\n",
    "    print(\"\\nEval \" + weights)\n",
    "    weights_path = os.path.join(train_dir, weights)\n",
    "    # map_location lets GPU-saved checkpoints load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    model.eval()\n",
    "    print(\"Model Loaded\")\n",
    "\n",
    "    features = extract_features(model, afib_signals)\n",
    "    assert features.shape[0] == afib_anno.shape[0], \"features/annotations row mismatch\"\n",
    "\n",
    "    final_y, final_results = loso_svm_predictions(features, afib_anno)\n",
    "\n",
    "    print(\"Total Accu: \")\n",
    "    print(np.mean(final_y == final_results))\n",
    "    print(\"\")\n",
    "    print(\"F1 Score: \")\n",
    "    print(f1_score(final_y, final_results))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
