{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import f1_score,roc_auc_score\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import os\n",
    "\n",
    "from vit import ViT\n",
    "\n",
    "\n",
    "# Run the encoder on the GPU when torch can see one, otherwise on the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# CPSC2018 evaluation data as NumPy arrays. The evaluation cells index rows of both arrays with\n",
    "# the same fold indices and read the class label from column 0 of `annos`.\n",
    "# NOTE(review): read_csv treats the first row as a header - confirm both CSVs were written with one.\n",
    "signals = pd.read_csv(\"Eval/cpsc2018_signals.csv\").to_numpy()\n",
    "annos = pd.read_csv(\"Eval/cpsc2018_annos.csv\").to_numpy()\n",
    "\n",
    "# Signals and labels are paired purely by row position, so their row counts must agree.\n",
    "assert signals.shape[0] == annos.shape[0], (signals.shape, annos.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder hyper-parameters. They must match the configuration the checkpoints were trained\n",
    "# with, otherwise model.load_state_dict (strict by default) rejects the weights.\n",
    "# NOTE(review): FS and L presumably mean sampling rate (Hz) and strip length (s) - confirm in vit.py.\n",
    "FS = 100\n",
    "L = 10\n",
    "C_IN = 1\n",
    "PATCH_SIZE = 10\n",
    "\n",
    "NUM_BLOCKS = 4\n",
    "NUM_HEADS = 4\n",
    "MODEL_DIM = 128\n",
    "\n",
    "# Build the encoder and place it on the selected device in one expression (.to returns the module).\n",
    "model = ViT(\n",
    "    num_blocks=NUM_BLOCKS,\n",
    "    num_heads=NUM_HEADS,\n",
    "    model_dim=MODEL_DIM,\n",
    "    patch_size=PATCH_SIZE,\n",
    "    in_channels=C_IN,\n",
    "    fs=FS,\n",
    "    l=L,\n",
    "    do_prob=0.1,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
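    "# Main Evaluation\n",
    "\n",
    "Each checkpoint in `ViT Trained Models` is used as a frozen feature extractor. An SVC is then scored with 10 repetitions of shuffled 10-fold cross-validation, and the mean and standard deviation of accuracy, macro-F1 and one-vs-rest AUC are reported."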
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Eval: clocs.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [39:13<00:00, 235.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.4714466712739001\n",
      "0.0019152255762736085\n",
      "3.6680890080125755e-06\n",
      "\n",
      "F1: \n",
      "0.41366043938628827\n",
      "0.0015700715276558586\n",
      "2.4651246019556015e-06\n",
      "AUC\n",
      "0.8266554964265289\n",
      "0.0002391819666834899\n",
      "5.720801318658207e-08\n",
      "\n",
      "Eval: cupid_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [26:42<00:00, 160.26s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6854756968440452\n",
      "0.0010963981238343455\n",
      "1.2020888459474728e-06\n",
      "\n",
      "F1: \n",
      "0.650017356386059\n",
      "0.0016213751203946274\n",
      "2.6288572810346926e-06\n",
      "AUC\n",
      "0.9278566053100077\n",
      "0.00025801378644996193\n",
      "6.657111399824655e-08\n",
      "\n",
      "Eval: deaps.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [28:56<00:00, 173.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6670237272517853\n",
      "0.0020729030390526882\n",
      "4.296927009313871e-06\n",
      "\n",
      "F1: \n",
      "0.6342198474279475\n",
      "0.0027822656486396037\n",
      "7.741002139599955e-06\n",
      "AUC\n",
      "0.9186302277553746\n",
      "0.0003604673515120129\n",
      "1.2993671150608508e-07\n",
      "\n",
      "Eval: jepa_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [33:45<00:00, 202.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5628887353144438\n",
      "0.0017513547854179428\n",
      "3.0672435844063282e-06\n",
      "\n",
      "F1: \n",
      "0.5144391545378366\n",
      "0.0024425892404027617\n",
      "5.96624219733134e-06\n",
      "AUC\n",
      "0.8803548134421486\n",
      "0.0004045761315990073\n",
      "1.6368184625961729e-07\n",
      "\n",
      "Eval: mae_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [27:49<00:00, 166.95s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5929394148813638\n",
      "0.0019379496200140716\n",
      "3.7556487297126843e-06\n",
      "\n",
      "F1: \n",
      "0.5430598836820741\n",
      "0.002702530937238135\n",
      "7.303673466729233e-06\n",
      "AUC\n",
      "0.8943054457277999\n",
      "0.0003222665136353211\n",
      "1.0385570581066461e-07\n",
      "\n",
      "Eval: mix_up.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [37:51<00:00, 227.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5020732550103663\n",
      "0.0019114466739377075\n",
      "3.6536283873075244e-06\n",
      "\n",
      "F1: \n",
      "0.4514573877149702\n",
      "0.003688208403465567\n",
      "1.3602881227394027e-05\n",
      "AUC\n",
      "0.8367167227022044\n",
      "0.0004002996687223686\n",
      "1.6023982477923808e-07\n",
      "\n",
      "Eval: pclr.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [30:23<00:00, 182.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6309145358212394\n",
      "0.0017077865001474844\n",
      "2.9165347300859937e-06\n",
      "\n",
      "F1: \n",
      "0.5940538997712297\n",
      "0.00279889693395078\n",
      "7.833824046879076e-06\n",
      "AUC\n",
      "0.9028375845091263\n",
      "0.00043823870155278557\n",
      "1.9205315953867145e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Linear-probe evaluation: for every checkpoint in checkpoint_dir, embed all signals with the\n",
    "# frozen encoder, then score an SVC under 10 repetitions (seeds 0-9) of shuffled 10-fold CV.\n",
    "# Prints the mean and standard deviation across repetitions of accuracy, macro-F1 and OvR AUC.\n",
    "checkpoint_dir = \"ViT Trained Models\"\n",
    "\n",
    "# Sorted for a reproducible order; only .pth files are loaded, so stray entries such as\n",
    "# .ipynb_checkpoints or .DS_Store cannot make torch.load crash.\n",
    "checkpoint_names = sorted(name for name in os.listdir(checkpoint_dir) if name.endswith(\".pth\"))\n",
    "\n",
    "for weights in checkpoint_names:\n",
    "    # Reset the global RNGs so every checkpoint is evaluated from the same random state.\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    bs = 512\n",
    "    weights_path = os.path.join(checkpoint_dir, weights)\n",
    "\n",
    "    print(\"\\nEval: \" + weights)\n",
    "    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    model.eval()\n",
    "    print(\"Model Loaded\")\n",
    "\n",
    "    seed_accus, seed_f1, seed_auc = [], [], []\n",
    "\n",
    "    # Embed every signal with the frozen encoder, `bs` rows at a time.\n",
    "    features = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, signals.shape[0], bs):\n",
    "            batch = torch.tensor(signals[start : start + bs, :]).float().to(device)\n",
    "            # reshape (not squeeze) keeps the batch axis even when the last batch has one row.\n",
    "            batch_features = model(batch).reshape(batch.shape[0], -1)\n",
    "            features.extend(batch_features.cpu().numpy())\n",
    "    features = np.array(features)\n",
    "\n",
    "    for seed in tqdm(np.arange(10)):\n",
    "        final_results, final_y, final_probs = [], [], []\n",
    "\n",
    "        kf = KFold(n_splits=10, shuffle=True, random_state=seed)\n",
    "\n",
    "        for train_index, test_index in kf.split(np.arange(signals.shape[0])):\n",
    "            svc = SVC(gamma='auto', probability=True)\n",
    "\n",
    "            y_train = np.array(annos[train_index, 0], dtype=np.int32)\n",
    "            y_val = np.array(annos[test_index, 0], dtype=np.int32)\n",
    "\n",
    "            # Scaler statistics come from the training fold only, so nothing leaks from the test fold.\n",
    "            scaler = StandardScaler()\n",
    "            train_features = scaler.fit_transform(features[train_index, :])\n",
    "            test_features = scaler.transform(features[test_index, :])\n",
    "\n",
    "            svc.fit(train_features, y_train)\n",
    "            final_results.extend(svc.predict(test_features))\n",
    "            final_y.extend(y_val)\n",
    "            final_probs.extend(svc.predict_proba(test_features))\n",
    "\n",
    "        # Scores use the out-of-fold predictions pooled over all 10 folds.\n",
    "        seed_accus.append(np.mean(np.array(final_y) == np.array(final_results)))\n",
    "        seed_f1.append(f1_score(final_y, final_results, average='macro'))\n",
    "        seed_auc.append(roc_auc_score(final_y, final_probs, multi_class='ovr'))\n",
    "\n",
    "    print(\"Accuracy: \")\n",
    "    print(np.mean(seed_accus))\n",
    "    print(np.std(seed_accus))\n",
    "    print(\"\")\n",
    "\n",
    "    print(\"F1: \")\n",
    "    print(np.mean(seed_f1))\n",
    "    print(np.std(seed_f1))\n",
    "\n",
    "    print(\"AUC\")\n",
    "    print(np.mean(seed_auc))\n",
    "    print(np.std(seed_auc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
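    "# Ablation\n",
    "\n",
    "This section applies the same protocol to the checkpoints in `ViT Models Ablation`, except that the SVC is fit without probability estimates. Only accuracy and macro-F1 are reported (no AUC)."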
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Eval: cupid_0.4.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:06<00:00, 24.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6761345312140061\n",
      "0.0012000936445640262\n",
      "\n",
      "F1: \n",
      "0.6338767364544758\n",
      "0.0014298034730277503\n",
      "\n",
      "Eval: cupid_0.6.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:03<00:00, 24.36s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6588919603777932\n",
      "0.0015099236471854262\n",
      "\n",
      "F1: \n",
      "0.6155023812364807\n",
      "0.0020231307629660217\n",
      "\n",
      "Eval: jepa_0.4.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:38<00:00, 27.84s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6298663902326652\n",
      "0.0014014719965910964\n",
      "\n",
      "F1: \n",
      "0.5862753855801976\n",
      "0.0013023109911869596\n",
      "\n",
      "Eval: jepa_0.6.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:40<00:00, 28.05s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5790831605620823\n",
      "0.0011299468756904444\n",
      "\n",
      "F1: \n",
      "0.5323873644862083\n",
      "0.001749489944901544\n",
      "\n",
      "Eval: jepa_spec_0.4.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:47<00:00, 28.78s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6730131306150656\n",
      "0.0009546121918909657\n",
      "\n",
      "F1: \n",
      "0.6363144778169956\n",
      "0.001614657722821414\n",
      "\n",
      "Eval: jepa_spec_0.5.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:30<00:00, 27.02s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.663441603317208\n",
      "0.0013696352449845332\n",
      "\n",
      "F1: \n",
      "0.6270644647503697\n",
      "0.002375522710446359\n",
      "\n",
      "Eval: jepa_spec_0.6.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:58<00:00, 29.89s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5937571988021194\n",
      "0.0024821402897122337\n",
      "\n",
      "F1: \n",
      "0.5492590012270677\n",
      "0.003157662438310486\n",
      "\n",
      "Eval: mae_0.4.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:08<00:00, 30.88s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6201105736005529\n",
      "0.0016368651392297262\n",
      "\n",
      "F1: \n",
      "0.5730549994063189\n",
      "0.001874965868545553\n",
      "\n",
      "Eval: mae_0.6.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:15<00:00, 31.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5505183137525915\n",
      "0.0015177670039188344\n",
      "\n",
      "F1: \n",
      "0.4951443808246248\n",
      "0.0019188700774471193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Ablation linear probe over every checkpoint in checkpoint_dir: same embedding and\n",
    "# 10 x 10-fold CV protocol as the main evaluation, but the SVC is fit without probability\n",
    "# estimates, so only accuracy and macro-F1 (mean and std over the 10 seeds) are reported.\n",
    "checkpoint_dir = \"ViT Models Ablation\"\n",
    "\n",
    "# Sorted for a reproducible order; only .pth files are loaded, so stray entries such as\n",
    "# .ipynb_checkpoints or .DS_Store cannot make torch.load crash.\n",
    "checkpoint_names = sorted(name for name in os.listdir(checkpoint_dir) if name.endswith(\".pth\"))\n",
    "\n",
    "for weights in checkpoint_names:\n",
    "    # Reset the global RNGs so every checkpoint is evaluated from the same random state.\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    bs = 512\n",
    "    weights_path = os.path.join(checkpoint_dir, weights)\n",
    "\n",
    "    print(\"\\nEval: \" + weights)\n",
    "    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    model.eval()\n",
    "    print(\"Model Loaded\")\n",
    "\n",
    "    seed_accus, seed_f1 = [], []\n",
    "\n",
    "    # Embed every signal with the frozen encoder, `bs` rows at a time.\n",
    "    features = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, signals.shape[0], bs):\n",
    "            batch = torch.tensor(signals[start : start + bs, :]).float().to(device)\n",
    "            # reshape (not squeeze) keeps the batch axis even when the last batch has one row.\n",
    "            batch_features = model(batch).reshape(batch.shape[0], -1)\n",
    "            features.extend(batch_features.cpu().numpy())\n",
    "    features = np.array(features)\n",
    "\n",
    "    for seed in tqdm(np.arange(10)):\n",
    "        final_results, final_y = [], []\n",
    "\n",
    "        kf = KFold(n_splits=10, shuffle=True, random_state=seed)\n",
    "\n",
    "        for train_index, test_index in kf.split(np.arange(signals.shape[0])):\n",
    "            # No probability=True here, so predict_proba (and hence AUC) is unavailable.\n",
    "            svc = SVC(gamma='auto')\n",
    "\n",
    "            y_train = np.array(annos[train_index, 0], dtype=np.int32)\n",
    "            y_val = np.array(annos[test_index, 0], dtype=np.int32)\n",
    "\n",
    "            # Scaler statistics come from the training fold only, so nothing leaks from the test fold.\n",
    "            scaler = StandardScaler()\n",
    "            train_features = scaler.fit_transform(features[train_index, :])\n",
    "            test_features = scaler.transform(features[test_index, :])\n",
    "\n",
    "            svc.fit(train_features, y_train)\n",
    "            final_results.extend(svc.predict(test_features))\n",
    "            final_y.extend(y_val)\n",
    "\n",
    "        # Scores use the out-of-fold predictions pooled over all 10 folds.\n",
    "        seed_accus.append(np.mean(np.array(final_y) == np.array(final_results)))\n",
    "        seed_f1.append(f1_score(final_y, final_results, average='macro'))\n",
    "\n",
    "    print(\"Accuracy: \")\n",
    "    print(np.mean(seed_accus))\n",
    "    print(np.std(seed_accus))\n",
    "    print(\"\")\n",
    "\n",
    "    print(\"F1: \")\n",
    "    print(np.mean(seed_f1))\n",
    "    print(np.std(seed_f1))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
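    "# Icentia Ablation\n",
    "\n",
    "This section applies the same protocol to the checkpoints in `ViT Icentia`. As in the ablation above, the SVC is fit without probability estimates, so only accuracy and macro-F1 are reported."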
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Eval: cupid_icentia.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:10<00:00, 25.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6433425478000461\n",
      "0.001783816102161753\n",
      "\n",
      "F1: \n",
      "0.5988811853770768\n",
      "0.002979108053320549\n",
      "\n",
      "Eval: jepa_icentia.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:35<00:00, 27.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.5986754204100438\n",
      "0.001534155094033808\n",
      "\n",
      "F1: \n",
      "0.5479750481046538\n",
      "0.002074122447229188\n",
      "\n",
      "Eval: jepa_spec_icentia.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:19<00:00, 25.94s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6153420870767105\n",
      "0.0013053601996320956\n",
      "\n",
      "F1: \n",
      "0.5602081271914408\n",
      "0.00163905961761644\n",
      "\n",
      "Eval: mae_icentia.pth\n",
      "Model Loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:16<00:00, 31.69s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: \n",
      "0.6153420870767105\n",
      "0.0013053601996320956\n",
      "\n",
      "F1: \n",
      "0.5602081271914408\n",
      "0.00163905961761644\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Icentia ablation linear probe over every checkpoint in checkpoint_dir: same embedding and\n",
    "# 10 x 10-fold CV protocol as the main evaluation, but the SVC is fit without probability\n",
    "# estimates, so only accuracy and macro-F1 (mean and std over the 10 seeds) are reported.\n",
    "# NOTE(review): in the saved outputs, mae_icentia.pth and jepa_spec_icentia.pth report identical\n",
    "# accuracy and F1 to every digit - confirm the two checkpoint files actually differ.\n",
    "checkpoint_dir = \"ViT Icentia\"\n",
    "\n",
    "# Sorted for a reproducible order; only .pth files are loaded, so stray entries such as\n",
    "# .ipynb_checkpoints or .DS_Store cannot make torch.load crash.\n",
    "checkpoint_names = sorted(name for name in os.listdir(checkpoint_dir) if name.endswith(\".pth\"))\n",
    "\n",
    "for weights in checkpoint_names:\n",
    "    # Reset the global RNGs so every checkpoint is evaluated from the same random state.\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    bs = 512\n",
    "    weights_path = os.path.join(checkpoint_dir, weights)\n",
    "\n",
    "    print(\"\\nEval: \" + weights)\n",
    "    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    model.eval()\n",
    "    print(\"Model Loaded\")\n",
    "\n",
    "    seed_accus, seed_f1 = [], []\n",
    "\n",
    "    # Embed every signal with the frozen encoder, `bs` rows at a time.\n",
    "    features = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, signals.shape[0], bs):\n",
    "            batch = torch.tensor(signals[start : start + bs, :]).float().to(device)\n",
    "            # reshape (not squeeze) keeps the batch axis even when the last batch has one row.\n",
    "            batch_features = model(batch).reshape(batch.shape[0], -1)\n",
    "            features.extend(batch_features.cpu().numpy())\n",
    "    features = np.array(features)\n",
    "\n",
    "    for seed in tqdm(np.arange(10)):\n",
    "        final_results, final_y = [], []\n",
    "\n",
    "        kf = KFold(n_splits=10, shuffle=True, random_state=seed)\n",
    "\n",
    "        for train_index, test_index in kf.split(np.arange(signals.shape[0])):\n",
    "            # No probability=True here, so predict_proba (and hence AUC) is unavailable.\n",
    "            svc = SVC(gamma='auto')\n",
    "\n",
    "            y_train = np.array(annos[train_index, 0], dtype=np.int32)\n",
    "            y_val = np.array(annos[test_index, 0], dtype=np.int32)\n",
    "\n",
    "            # Scaler statistics come from the training fold only, so nothing leaks from the test fold.\n",
    "            scaler = StandardScaler()\n",
    "            train_features = scaler.fit_transform(features[train_index, :])\n",
    "            test_features = scaler.transform(features[test_index, :])\n",
    "\n",
    "            svc.fit(train_features, y_train)\n",
    "            final_results.extend(svc.predict(test_features))\n",
    "            final_y.extend(y_val)\n",
    "\n",
    "        # Scores use the out-of-fold predictions pooled over all 10 folds.\n",
    "        seed_accus.append(np.mean(np.array(final_y) == np.array(final_results)))\n",
    "        seed_f1.append(f1_score(final_y, final_results, average='macro'))\n",
    "\n",
    "    print(\"Accuracy: \")\n",
    "    print(np.mean(seed_accus))\n",
    "    print(np.std(seed_accus))\n",
    "    print(\"\")\n",
    "\n",
    "    print(\"F1: \")\n",
    "    print(np.mean(seed_f1))\n",
    "    print(np.std(seed_f1))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
