{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\adria\\miniconda3\\envs\\pytorch\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import signal\n",
    "from tqdm import tqdm\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import (accuracy_score, classification_report,\n",
    "                             f1_score, roc_auc_score)\n",
    "from sklearn.model_selection import KFold, train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Local: ViT encoder whose saved checkpoints are evaluated below.\n",
    "from vit import ViT\n",
    "\n",
    "# NOTE(review): plt, signal, RandomForestClassifier, SVC, train_test_split,\n",
    "# classification_report and f1_score are not used by any cell in this notebook.\n",
    "\n",
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Suppress every warning for the rest of the kernel session.\n",
    "# NOTE(review): this also hides any sklearn convergence warnings.\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT encoder architecture. load_state_dict in the evaluation cell is strict,\n",
    "# so these values must match the architecture of every saved checkpoint.\n",
    "# NOTE(review): FS and L presumably are the sampling rate (Hz) and strip\n",
    "# length (s); confirm against vit.ViT.\n",
    "NUM_BLOCKS = 4\n",
    "MODEL_DIM = 128\n",
    "NUM_HEADS = 4\n",
    "PATCH_SIZE = 10\n",
    "FS = 100\n",
    "L = 10\n",
    "C_IN = 1\n",
    "DO_PROB = 0.1  # forwarded to ViT as do_prob\n",
    "\n",
    "vit_kwargs = dict(\n",
    "    num_blocks=NUM_BLOCKS,\n",
    "    num_heads=NUM_HEADS,\n",
    "    model_dim=MODEL_DIM,\n",
    "    patch_size=PATCH_SIZE,\n",
    "    in_channels=C_IN,\n",
    "    fs=FS,\n",
    "    l=L,\n",
    "    do_prob=DO_PROB,\n",
    ")\n",
    "model = ViT(**vit_kwargs).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the evaluation set as plain NumPy arrays (pd.read_csv uses the first\n",
    "# row of each file as the header). Each row of `signals` is one strip fed to\n",
    "# the encoder; `annos` supplies the patient id (column 0) and the class\n",
    "# label (column 1) for the matching strip.\n",
    "signals = pd.read_csv(\"Eval/svta_signals.csv\").to_numpy()\n",
    "annos = pd.read_csv(\"Eval/svta_anno.csv\").to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Eval clocs.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:03<00:00,  6.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.520026466696074\n",
      "0.008318691176364133\n",
      "\n",
      "AUC Results: \n",
      "0.5614401020263851\n",
      "0.011256315482775477\n",
      "\n",
      "Eval cupid_0.5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:07<00:00,  6.77s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.5796647551830614\n",
      "0.012281817643625506\n",
      "\n",
      "AUC Results: \n",
      "0.660345425654935\n",
      "0.005417583596678688\n",
      "\n",
      "Eval deaps.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:01<00:00,  6.17s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.4832289369210411\n",
      "0.014072895347400239\n",
      "\n",
      "AUC Results: \n",
      "0.57845642479209\n",
      "0.01085782909157493\n",
      "\n",
      "Eval jepa_0.5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:02<00:00,  6.27s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.5231936479929422\n",
      "0.00654633395402153\n",
      "\n",
      "AUC Results: \n",
      "0.621194214380224\n",
      "0.006905792486262466\n",
      "\n",
      "Eval mae_0.5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:01<00:00,  6.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.5197441552712835\n",
      "0.006230651010254346\n",
      "\n",
      "AUC Results: \n",
      "0.6031814543572512\n",
      "0.009230528813224769\n",
      "\n",
      "Eval mix_up.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:01<00:00,  6.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.5263431848257609\n",
      "0.011591315084995907\n",
      "\n",
      "AUC Results: \n",
      "0.6123618888814298\n",
      "0.010358943063800105\n",
      "\n",
      "Eval pclr.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:03<00:00,  6.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy Results: \n",
      "0.49315394794883105\n",
      "0.013775990873034103\n",
      "\n",
      "AUC Results: \n",
      "0.5855360001127654\n",
      "0.009758637135090713\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "TRAIN_DIR = r\"ViT Trained Models\"\n",
    "BATCH_SIZE = 512  # strips per encoder forward pass\n",
    "N_REPEATS = 10    # cross-validation repetitions, one KFold seed each\n",
    "N_SPLITS = 10     # patient-level folds per repetition\n",
    "\n",
    "\n",
    "def extract_features(encoder, strips, batch_size=BATCH_SIZE):\n",
    "    \"\"\"Embed every row of `strips` with `encoder`.\n",
    "\n",
    "    Returns an array with one row per strip. Each batch's output is\n",
    "    flattened per strip instead of `.squeeze()`-d, so a final batch that\n",
    "    holds a single strip keeps its batch axis and the rows stay aligned\n",
    "    with `strips`.\n",
    "    \"\"\"\n",
    "    batches = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, strips.shape[0], batch_size):\n",
    "            batch = torch.tensor(strips[start:start + batch_size]).float().to(device)\n",
    "            out = encoder(batch)\n",
    "            batches.append(out.reshape(batch.shape[0], -1).cpu().numpy())\n",
    "    return np.concatenate(batches, axis=0)\n",
    "\n",
    "\n",
    "def patient_cv_scores(features, annos, seed, n_splits=N_SPLITS):\n",
    "    \"\"\"Score a logistic-regression probe with patient-level K-fold CV.\n",
    "\n",
    "    annos[:, 0] is the patient id used to build the folds (a patient never\n",
    "    appears in both train and test) and annos[:, 1] the class label.\n",
    "    Out-of-fold predictions from all folds are pooled before scoring.\n",
    "\n",
    "    Returns (accuracy, ROC AUC) for the given KFold `seed`; the AUC is\n",
    "    one-vs-rest when there are more than two classes.\n",
    "    \"\"\"\n",
    "    patient_ids = annos[:, 0]\n",
    "    labels = np.asarray(annos[:, 1], dtype=np.int32)\n",
    "    unique_patients = np.unique(patient_ids)\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)\n",
    "\n",
    "    y_true, y_pred, y_prob = [], [], []\n",
    "    for _, test_pat_idx in kf.split(unique_patients):\n",
    "        # Vectorised row masks (replaces an O(rows * patients) `in` scan).\n",
    "        test_mask = np.isin(patient_ids, unique_patients[test_pat_idx])\n",
    "        train_mask = ~test_mask\n",
    "\n",
    "        # Fit the scaler on the training fold only to avoid test leakage.\n",
    "        scaler = StandardScaler()\n",
    "        x_train = scaler.fit_transform(features[train_mask])\n",
    "        x_test = scaler.transform(features[test_mask])\n",
    "\n",
    "        clf = LogisticRegression(random_state=0)\n",
    "        clf.fit(x_train, labels[train_mask])\n",
    "\n",
    "        y_true.append(labels[test_mask])\n",
    "        y_pred.append(clf.predict(x_test))\n",
    "        y_prob.append(clf.predict_proba(x_test))\n",
    "\n",
    "    y_true = np.concatenate(y_true)\n",
    "    y_pred = np.concatenate(y_pred)\n",
    "    y_prob = np.concatenate(y_prob)\n",
    "    if y_prob.shape[1] == 2:\n",
    "        # Binary task: roc_auc_score wants only the positive-class column.\n",
    "        y_prob = y_prob[:, 1]\n",
    "    return (accuracy_score(y_true, y_pred),\n",
    "            roc_auc_score(y_true, y_prob, multi_class='ovr'))\n",
    "\n",
    "\n",
    "# Sorted so checkpoints are always reported in the same order.\n",
    "for weights in sorted(os.listdir(TRAIN_DIR)):\n",
    "    weights_path = os.path.join(TRAIN_DIR, weights)\n",
    "    if not os.path.isfile(weights_path):\n",
    "        continue  # skip sub-folders such as .ipynb_checkpoints\n",
    "\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    model.eval()\n",
    "    features = extract_features(model, signals)\n",
    "    print(\"\\nEval \" + weights)\n",
    "\n",
    "    total_accus, total_auc = [], []\n",
    "    for seed in tqdm(range(N_REPEATS)):\n",
    "        accuracy, auc = patient_cv_scores(features, annos, seed)\n",
    "        total_accus.append(accuracy)\n",
    "        total_auc.append(auc)\n",
    "\n",
    "    print(\"\\nAccuracy Results: \")\n",
    "    print(np.mean(total_accus))\n",
    "    print(np.std(total_accus))\n",
    "\n",
    "    print(\"\\nAUC Results: \")\n",
    "    print(np.mean(total_auc))\n",
    "    print(np.std(total_auc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
