{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this blanket filter also hides sklearn ConvergenceWarnings raised\n",
    "# by the LogisticRegression probes in the evaluation cell; narrow it if those matter.\n",
    "# It stays before the other imports so import-time warnings are silenced too.\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from vit import ViT\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "l_af_signals = np.load(\"Eval/l_af_signals.npy\")\n",
    "# The evaluation cell uses column 0 as the patient id and column 1 as the class\n",
    "# label; np.array gives a writable copy indexable as [row, col].\n",
    "l_af_anno = np.array(pd.read_csv(\"Eval/l_af_annos.csv\"))\n",
    "\n",
    "# Merge label 3 into label 2 (vectorised; replaces a per-row Python loop).\n",
    "# NOTE(review): the reason for merging these two classes is not recorded here.\n",
    "l_af_anno[l_af_anno[:, 1] == 3, 1] = 2\n",
    "\n",
    "# Feature rows (computed from l_af_signals) are later selected by annotation row\n",
    "# positions, so both arrays must have the same number of rows.\n",
    "assert l_af_signals.shape[0] == l_af_anno.shape[0], \"signal/annotation row count mismatch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT backbone hyperparameters. load_state_dict is strict by default, so these\n",
    "# presumably must match the architecture the evaluated checkpoints were trained with.\n",
    "NUM_BLOCKS = 4\n",
    "MODEL_DIM = 128\n",
    "NUM_HEADS = 4\n",
    "PATCH_SIZE = 10\n",
    "FS = 100       # passed as `fs`; presumably the sampling rate in Hz (TODO confirm in vit.py)\n",
    "L = 10         # passed as `l`; meaning is defined in vit.py (TODO confirm)\n",
    "C_IN = 1       # passed as `in_channels`\n",
    "DO_PROB = 0.1  # passed as `do_prob` (previously an inline literal); presumably a dropout rate\n",
    "\n",
    "model = ViT(\n",
    "    num_blocks=NUM_BLOCKS,\n",
    "    num_heads=NUM_HEADS,\n",
    "    model_dim=MODEL_DIM,\n",
    "    patch_size=PATCH_SIZE,\n",
    "    in_channels=C_IN,\n",
    "    fs=FS,\n",
    "    l=L,\n",
    "    do_prob=DO_PROB,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
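    "# Main Evaluation\n",
    "\n",
    "Linear-probe evaluation of each checkpoint in `ViT Trained Models`. All signals are embedded with the frozen ViT. A standardised logistic-regression classifier is then scored with patient-grouped 10-fold cross-validation, repeated for 10 shuffle seeds. For each checkpoint, the mean and standard deviation (over seeds) of pooled accuracy and one-vs-rest ROC AUC are reported."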
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Eval clocs.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [47:21<00:00, 284.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.6775692031402174\n",
      "0.009826711780677072\n",
      "\n",
      "AUC Results: \n",
      "0.765167171831721\n",
      "0.014210023518024003\n",
      "\n",
      "Eval cupid_0.5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [43:49<00:00, 262.92s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.8791636216120127\n",
      "0.003338301005074082\n",
      "\n",
      "AUC Results: \n",
      "0.934668557443284\n",
      "0.0019335426211997487\n",
      "\n",
      "Eval deaps.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [45:43<00:00, 274.35s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.8429990078566207\n",
      "0.0051954092696666\n",
      "\n",
      "AUC Results: \n",
      "0.8819042803866803\n",
      "0.007208380524447907\n",
      "\n",
      "Eval jepa_0.5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [46:04<00:00, 276.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.7812512249856042\n",
      "0.0047707654105323604\n",
      "\n",
      "AUC Results: \n",
      "0.8678041651834831\n",
      "0.003970643245403917\n",
      "\n",
      "Eval mae_0.5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [46:13<00:00, 277.37s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.8048440298949666\n",
      "0.006164349052660974\n",
      "\n",
      "AUC Results: \n",
      "0.8842139737923944\n",
      "0.005927803478116117\n",
      "\n",
      "Eval mix_up.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [42:26<00:00, 254.67s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.6101598216876769\n",
      "0.007614861410339213\n",
      "\n",
      "AUC Results: \n",
      "0.6476143019469276\n",
      "0.017465333194692657\n",
      "\n",
      "Eval pclr.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [35:47<00:00, 214.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.8076615347683201\n",
      "0.003352883286242947\n",
      "\n",
      "AUC Results: \n",
      "0.8010406345810829\n",
      "0.005988818946593131\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Linear-probe evaluation of every checkpoint in train_dir: embed all signals\n",
    "# with the frozen ViT, then score a standardised logistic-regression probe with\n",
    "# patient-grouped K-fold CV, repeated over several KFold shuffle seeds.\n",
    "train_dir = r\"ViT Trained Models\"\n",
    "bs = 512        # inference batch size\n",
    "N_REPEATS = 10  # CV repetitions, one KFold shuffle seed each\n",
    "N_SPLITS = 10   # folds per repetition\n",
    "\n",
    "\n",
    "def extract_features(model, signals, batch_size, device):\n",
    "    \"\"\"Embed `signals` batch by batch with `model` (no gradients) and stack the rows.\n",
    "\n",
    "    Each batch output is reshaped to (rows_in_batch, -1) instead of squeezed, so a\n",
    "    trailing batch holding a single signal still yields one feature row.\n",
    "    Returns an array with one row per row of `signals`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    chunks = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, signals.shape[0], batch_size):\n",
    "            batch = torch.tensor(signals[start:start + batch_size]).float().to(device)\n",
    "            out = model(batch)\n",
    "            chunks.append(out.reshape(batch.shape[0], -1).cpu().numpy())\n",
    "    return np.concatenate(chunks, axis=0)\n",
    "\n",
    "\n",
    "def patient_cv_scores(features, anno, seed, n_splits=N_SPLITS):\n",
    "    \"\"\"Run one repetition of patient-grouped K-fold CV for a logistic-regression probe.\n",
    "\n",
    "    Folds are drawn over the unique patient ids in anno[:, 0], so every row of a\n",
    "    patient falls on the same side of each split; anno[:, 1] holds the labels.\n",
    "    The scaler is fit on each fold's training rows only, and predictions from all\n",
    "    folds are pooled before scoring.\n",
    "\n",
    "    Returns (accuracy, one-vs-rest ROC AUC) of the pooled predictions.\n",
    "    \"\"\"\n",
    "    patient_ids = anno[:, 0]\n",
    "    labels = np.array(anno[:, 1], dtype=np.int32)\n",
    "    unique_patients = np.unique(patient_ids)\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)\n",
    "\n",
    "    y_true, y_pred, y_prob = [], [], []\n",
    "    for _, test_pat_index in kf.split(unique_patients):\n",
    "        # Vectorised row selection; keeps original row order within each side.\n",
    "        is_test = np.isin(patient_ids, unique_patients[test_pat_index])\n",
    "\n",
    "        scaler = StandardScaler()\n",
    "        x_train = scaler.fit_transform(features[~is_test])\n",
    "        x_test = scaler.transform(features[is_test])\n",
    "\n",
    "        clf = LogisticRegression(random_state=0)\n",
    "        clf.fit(x_train, labels[~is_test])\n",
    "\n",
    "        y_true.append(labels[is_test])\n",
    "        y_pred.append(clf.predict(x_test))\n",
    "        y_prob.append(clf.predict_proba(x_test))\n",
    "\n",
    "    y_true = np.concatenate(y_true)\n",
    "    accuracy = np.mean(y_true == np.concatenate(y_pred))\n",
    "    auc = roc_auc_score(y_true, np.concatenate(y_prob), multi_class='ovr')\n",
    "    return accuracy, auc\n",
    "\n",
    "\n",
    "# sorted() makes the evaluation order independent of the filesystem; entries\n",
    "# that are not files (e.g. a .ipynb_checkpoints directory) are skipped.\n",
    "checkpoint_names = sorted(\n",
    "    name for name in os.listdir(train_dir)\n",
    "    if os.path.isfile(os.path.join(train_dir, name))\n",
    ")\n",
    "\n",
    "for weights in checkpoint_names:\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    weights_path = os.path.join(train_dir, weights)\n",
    "    # map_location lets GPU-saved checkpoints load on a CPU-only machine.\n",
    "    # torch.load unpickles the file, so only load trusted checkpoints.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    features = extract_features(model, l_af_signals, bs, device)\n",
    "    print(\"\\nEval \" + weights)\n",
    "\n",
    "    total_accus, total_auc = [], []\n",
    "    for seed in tqdm(range(N_REPEATS)):\n",
    "        accuracy, auc = patient_cv_scores(features, l_af_anno, seed)\n",
    "        total_accus.append(accuracy)\n",
    "        total_auc.append(auc)\n",
    "\n",
    "    print(\"\\nAccuracy Results: \")\n",
    "    print(np.mean(total_accus))\n",
    "    print(np.std(total_accus))\n",
    "\n",
    "    print(\"\\nAUC Results: \")\n",
    "    print(np.mean(total_auc))\n",
    "    print(np.std(total_auc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
