{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b354b274-052b-49d2-828f-915eead6ec68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initializing ... \n",
      "generating json ...\n"
     ]
    }
   ],
   "source": [
    "import set_path\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import nibabel as nib\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "\n",
    "from scipy import signal, stats\n",
    "from scipy.interpolate import interp1d\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from settings import settings\n",
    "from utils import *\n",
    "from filters import filters\n",
    "from funcs import *\n",
    "\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "from datasets_transfer import HCP3TTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from baseline_models import BrainNetworkTransformer\n",
    "from baseline_models import GraphTransformer\n",
    "from baseline_models import BrainNetCNN\n",
    "from baseline_models import FBNETGEN\n",
    "from baseline_models import ModelSTAGIN\n",
    "from baseline_models.STAGIN import process_dynamic_fc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ebb49b3-f5ac-4c20-8623-c93615a927f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled registry of trained baselines; later cells read\n",
    "# BaselineModels.<Model>.iter<k> for each fold k. pickle.load can execute arbitrary\n",
    "# code, so only point this at a trusted project file.\n",
    "registry_name = 'BaselineModels'\n",
    "registry_path = settings.projectData.dir.general_HCP3T / getattr(settings.projectData.files.general_HCP3T, registry_name)\n",
    "with open(registry_path, 'rb') as registry_file:\n",
    "    BaselineModels = pickle.load(registry_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "491a8fe8-bf61-4f91-a36a-8126b82e5fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from einops import repeat\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "def evaluation_STAGIN(modelInfo, checkpoint='best', kf_split_id=0, kf_num_splits=5, num_workers=8):\n",
    "    \"\"\"Evaluate one trained STAGIN fold checkpoint on the test set of its k-fold split.\n",
    "\n",
    "    Args:\n",
    "        modelInfo: per-fold record exposing ``args`` (training/model config) and\n",
    "            ``model_files`` (a table with ``checkpoint`` and ``file_path`` columns).\n",
    "        checkpoint: label in ``model_files.checkpoint`` of the weights to load, e.g. 'best'.\n",
    "        kf_split_id: index of the k-fold split whose test set is evaluated.\n",
    "        kf_num_splits: expected number of k-fold splits, asserted against the dataset\n",
    "            (default 5, the value every caller in this notebook uses).\n",
    "        num_workers: number of DataLoader worker processes for the test set.\n",
    "\n",
    "    Returns:\n",
    "        (results, results_n): dicts keyed by target name. Continuous targets: ``results``\n",
    "        holds the MAE in original (de-standardised) units and ``results_n`` that MAE divided\n",
    "        by sqrt(var_) of the training-fold target scaler. Discrete targets: both hold the\n",
    "        accuracy, and ``results`` also holds ``AUROC_<name>`` computed from softmax class\n",
    "        scores (NaN when the test fold contains a single class).\n",
    "\n",
    "    Side effect: sets ``modelInfo.args.training.sessIDs`` to the training-fold\n",
    "    (SUBJECT, SESSION) pairs.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Train and test datasets share every option; only .train() / .test() differs.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        fmri_type=args.training.fmri_type,\n",
    "        output_fmri_size=args.training.output_fmri_size,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        output_fc=True,\n",
    "        output_timeseries=True,\n",
    "    )\n",
    "\n",
    "    train_dataset = HCP3TTransferDataset(settings, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits\n",
    "\n",
    "    test_dataset = HCP3TTransferDataset(settings, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # Reuse the training-fold target scalers so test targets are (de)standardised consistently.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size,\n",
    "                                 shuffle=False, num_workers=num_workers)\n",
    "\n",
    "    model = ModelSTAGIN(\n",
    "        input_dim=args.model.num_nodes,\n",
    "        hidden_dim=args.model.hidden_dim,\n",
    "        num_classes=args.model.output_dim,\n",
    "        num_heads=args.model.num_heads,\n",
    "        num_layers=args.model.num_layers,\n",
    "        sparsity=args.model.sparsity,\n",
    "        dropout=args.model.dropout,\n",
    "        cls_token=args.model.cls_token,\n",
    "        readout=args.model.readout,\n",
    "    )\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location lets weights saved on another device (e.g. a GPU) load onto `device`.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Per discrete target: per-batch (batch, n_classes) softmax scores, used for AUROC.\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "\n",
    "    desc = 'test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader, desc=desc):\n",
    "            # Transpose the last two timeseries axes before computing windowed dynamic FC.\n",
    "            timeseries = batch['timeseries'].permute(0, 2, 1)\n",
    "            dyn_a, sampling_points = process_dynamic_fc(timeseries, args.model.window_size,\n",
    "                                                        args.model.window_stride, args.model.dynamic_length)\n",
    "            sampling_endpoints = [p + args.model.window_size for p in sampling_points]\n",
    "            # Identity node features: one (num_nodes x num_nodes) eye per sample and per window,\n",
    "            # sized to this batch so the final, smaller batch needs no special casing.\n",
    "            dyn_v = repeat(torch.eye(args.model.num_nodes), 'n1 n2 -> b t n1 n2',\n",
    "                           t=len(sampling_points), b=len(dyn_a))\n",
    "            t = timeseries.permute(1, 0, 2)\n",
    "\n",
    "            preds, _, _, _ = model(dyn_v.to(device), dyn_a.to(device), t.to(device), sampling_endpoints)\n",
    "\n",
    "            # Targets stay on the CPU: they are only converted to numpy for scoring.\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:, idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # head_tgt2dims gives an inclusive [first, last] column range for this target.\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                logits = preds[:, idxs[0]:idxs[1] + 1]\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(logits.detach().cpu().numpy(), axis=1)\n",
    "                probs_all[tgt_name].append(torch.softmax(logits, dim=1).detach().cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1, 1)).squeeze()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1, 1)).squeeze()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # var_ is presumably a length-1 array (single-feature scaler); .item() keeps the\n",
    "        # normalised MAE a plain float instead of a 1-element array.\n",
    "        results_n[tgt_name] = mae / np.sqrt(nscalar.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        acc = accuracy_score(y_true, preds_all[tgt_name].values)\n",
    "        results[tgt_name] = acc\n",
    "        results_n[tgt_name] = acc\n",
    "\n",
    "        # roc_auc_score takes (y_true, y_score) with continuous scores, not argmax labels.\n",
    "        probs = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        auroc_key = 'AUROC_{:s}'.format(tgt_name)\n",
    "        if np.unique(y_true).size < 2:\n",
    "            results[auroc_key] = np.nan  # AUROC is undefined when the fold has one class\n",
    "        elif probs.shape[1] == 2:\n",
    "            results[auroc_key] = roc_auc_score(y_true, probs[:, 1])\n",
    "        else:\n",
    "            results[auroc_key] = roc_auc_score(y_true, probs, multi_class='ovr',\n",
    "                                               labels=np.arange(probs.shape[1]))\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e881eab9-ff59-4c59-b75c-513d50b34861",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "def evaluation(modelInfo, checkpoint='best', kf_split_id=0, kf_num_splits=5, num_workers=8):\n",
    "    \"\"\"Evaluate one trained baseline fold checkpoint on the test set of its k-fold split.\n",
    "\n",
    "    The model class is resolved from ``args.model.name``, built as ``cls(args)`` and called\n",
    "    as ``model(timeseries, fc)``.\n",
    "\n",
    "    Args:\n",
    "        modelInfo: per-fold record exposing ``args`` (training/model config) and\n",
    "            ``model_files`` (a table with ``checkpoint`` and ``file_path`` columns).\n",
    "        checkpoint: label in ``model_files.checkpoint`` of the weights to load, e.g. 'best'.\n",
    "        kf_split_id: index of the k-fold split whose test set is evaluated.\n",
    "        kf_num_splits: expected number of k-fold splits, asserted against the dataset\n",
    "            (default 5, the value every caller in this notebook uses).\n",
    "        num_workers: number of DataLoader worker processes for the test set.\n",
    "\n",
    "    Returns:\n",
    "        (results, results_n): dicts keyed by target name. Continuous targets: ``results``\n",
    "        holds the MAE in original (de-standardised) units and ``results_n`` that MAE divided\n",
    "        by sqrt(var_) of the training-fold target scaler. Discrete targets: both hold the\n",
    "        accuracy, and ``results`` also holds ``AUROC_<name>`` computed from softmax class\n",
    "        scores (NaN when the test fold contains a single class).\n",
    "\n",
    "    Side effect: sets ``modelInfo.args.training.sessIDs`` to the training-fold\n",
    "    (SUBJECT, SESSION) pairs before the model is constructed.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Train and test datasets share every option; only .train() / .test() differs.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        fmri_type=args.training.fmri_type,\n",
    "        output_fmri_size=args.training.output_fmri_size,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        output_fc=True,\n",
    "        output_timeseries=True,\n",
    "    )\n",
    "\n",
    "    train_dataset = HCP3TTransferDataset(settings, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits\n",
    "\n",
    "    test_dataset = HCP3TTransferDataset(settings, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # Reuse the training-fold target scalers so test targets are (de)standardised consistently.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size,\n",
    "                                 shuffle=False, num_workers=num_workers)\n",
    "\n",
    "    # NOTE(review): eval() resolves the class from a string in the pickled config; a\n",
    "    # name -> class dict would avoid evaluating arbitrary expressions.\n",
    "    model = eval(args.model.name)(args)\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location lets weights saved on another device (e.g. a GPU) load onto `device`.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Per discrete target: per-batch (batch, n_classes) softmax scores, used for AUROC.\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "\n",
    "    desc = 'test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader, desc=desc):\n",
    "            # Only the model inputs go to the device; targets stay on the CPU for scoring.\n",
    "            preds = model(batch['timeseries'].to(device), batch['fc'].to(device))\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:, idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # head_tgt2dims gives an inclusive [first, last] column range for this target.\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                logits = preds[:, idxs[0]:idxs[1] + 1]\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(logits.detach().cpu().numpy(), axis=1)\n",
    "                probs_all[tgt_name].append(torch.softmax(logits, dim=1).detach().cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1, 1)).squeeze()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1, 1)).squeeze()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # var_ is presumably a length-1 array (single-feature scaler); .item() keeps the\n",
    "        # normalised MAE a plain float instead of a 1-element array.\n",
    "        results_n[tgt_name] = mae / np.sqrt(nscalar.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        acc = accuracy_score(y_true, preds_all[tgt_name].values)\n",
    "        results[tgt_name] = acc\n",
    "        results_n[tgt_name] = acc\n",
    "\n",
    "        # roc_auc_score takes (y_true, y_score) with continuous scores, not argmax labels.\n",
    "        probs = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        auroc_key = 'AUROC_{:s}'.format(tgt_name)\n",
    "        if np.unique(y_true).size < 2:\n",
    "            results[auroc_key] = np.nan  # AUROC is undefined when the fold has one class\n",
    "        elif probs.shape[1] == 2:\n",
    "            results[auroc_key] = roc_auc_score(y_true, probs[:, 1])\n",
    "        else:\n",
    "            results[auroc_key] = roc_auc_score(y_true, probs, multi_class='ovr',\n",
    "                                               labels=np.arange(probs.shape[1]))\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "40be0d60-6f72-4742-a119-6ab1e800fc39",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  3.31it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.31it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.32it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.16it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.46it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every k-fold checkpoint of the BrainNetworkTransformer baseline;\n",
    "# results / results_n end up with one row of metrics per fold.\n",
    "kf_num_splits = 5\n",
    "model_folds = BaselineModels.BrainNetworkTransformer\n",
    "fold_raw, fold_norm = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_folds, 'iter{:d}'.format(kf_split_id))\n",
    "    raw_metrics, norm_metrics = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_raw.append(raw_metrics)\n",
    "    fold_norm.append(norm_metrics)\n",
    "\n",
    "results = pd.DataFrame(fold_raw)\n",
    "results_n = pd.DataFrame(fold_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "786b4247-bbbd-4eb6-8de9-489c866d33a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 1.5356+0.0329\n",
      "PicSeq_Unadj: 7.3302+0.2775\n",
      "PMAT24_A_CR: 2.4268+0.0818\n",
      "PMAT24_A_SI: 1.8453+0.0979\n",
      "PicVocab_Unadj: 4.8642+0.0589\n",
      "IWRD_TOT: 1.6045+0.0532\n",
      "ListSort_Unadj: 6.2123+0.1803\n",
      "LifeSatisf_Unadj: 5.1754+0.1546\n",
      "Gender: 0.9441+0.0098\n",
      "AUROC_Gender: 0.9439+0.0090\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std over the folds (pandas std, i.e. ddof=1).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric_name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric_name, fold_mean[metric_name], fold_std[metric_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd5bdbe5-19c2-4d6b-976b-faa261fef4fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.87it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.36it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.94it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.52it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.75it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every k-fold checkpoint of the GraphTransformer baseline;\n",
    "# results / results_n end up with one row of metrics per fold.\n",
    "kf_num_splits = 5\n",
    "model_folds = BaselineModels.GraphTransformer\n",
    "fold_raw, fold_norm = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_folds, 'iter{:d}'.format(kf_split_id))\n",
    "    raw_metrics, norm_metrics = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_raw.append(raw_metrics)\n",
    "    fold_norm.append(norm_metrics)\n",
    "\n",
    "results = pd.DataFrame(fold_raw)\n",
    "results_n = pd.DataFrame(fold_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "24a3ad7b-0dfa-4c33-9b49-08a2ffcb1a2e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 1.6747+0.0539\n",
      "PicSeq_Unadj: 8.1897+0.4625\n",
      "PMAT24_A_CR: 2.7290+0.0670\n",
      "PMAT24_A_SI: 2.1292+0.0718\n",
      "PicVocab_Unadj: 5.3925+0.1346\n",
      "IWRD_TOT: 1.8105+0.0748\n",
      "ListSort_Unadj: 6.9073+0.2103\n",
      "LifeSatisf_Unadj: 5.7117+0.0773\n",
      "Gender: 0.9000+0.0105\n",
      "AUROC_Gender: 0.8993+0.0094\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std over the folds (pandas std, i.e. ddof=1).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric_name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric_name, fold_mean[metric_name], fold_std[metric_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f2244c8e-e24b-4955-9edb-5d65cb367bb9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:02<00:00,  9.91it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 20.53it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 21/21 [00:01<00:00, 19.42it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 20.99it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 21.39it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every k-fold checkpoint of the BrainNetCNN baseline;\n",
    "# results / results_n end up with one row of metrics per fold.\n",
    "kf_num_splits = 5\n",
    "model_folds = BaselineModels.BrainNetCNN\n",
    "fold_raw, fold_norm = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_folds, 'iter{:d}'.format(kf_split_id))\n",
    "    raw_metrics, norm_metrics = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_raw.append(raw_metrics)\n",
    "    fold_norm.append(norm_metrics)\n",
    "\n",
    "results = pd.DataFrame(fold_raw)\n",
    "results_n = pd.DataFrame(fold_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0a86df43-49a5-4d9d-b24c-305de3d85718",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.1248+0.0496\n",
      "PicSeq_Unadj: 10.2212+0.1345\n",
      "PMAT24_A_CR: 3.2372+0.1134\n",
      "PMAT24_A_SI: 2.6408+0.1373\n",
      "PicVocab_Unadj: 6.6798+0.1954\n",
      "IWRD_TOT: 2.2706+0.0461\n",
      "ListSort_Unadj: 8.4924+0.1554\n",
      "LifeSatisf_Unadj: 7.1420+0.2513\n",
      "Gender: 0.9068+0.0180\n",
      "AUROC_Gender: 0.9089+0.0155\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std over the folds (pandas std, i.e. ddof=1).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric_name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric_name, fold_mean[metric_name], fold_std[metric_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3a47c9e0-5bcf-4813-83f7-7216bd56e342",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 10.60it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  9.68it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  9.30it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 10.56it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.61it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every k-fold checkpoint of the FBNETGEN baseline;\n",
    "# results / results_n end up with one row of metrics per fold.\n",
    "kf_num_splits = 5\n",
    "model_folds = BaselineModels.FBNETGEN\n",
    "fold_raw, fold_norm = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_folds, 'iter{:d}'.format(kf_split_id))\n",
    "    raw_metrics, norm_metrics = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_raw.append(raw_metrics)\n",
    "    fold_norm.append(norm_metrics)\n",
    "\n",
    "results = pd.DataFrame(fold_raw)\n",
    "results_n = pd.DataFrame(fold_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f4deed54-ead1-4c08-be48-dc38fed7f6c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 1.8137+0.0301\n",
      "PicSeq_Unadj: 8.6149+0.2143\n",
      "PMAT24_A_CR: 2.9341+0.1082\n",
      "PMAT24_A_SI: 2.3387+0.1070\n",
      "PicVocab_Unadj: 5.8340+0.1449\n",
      "IWRD_TOT: 1.9630+0.0440\n",
      "ListSort_Unadj: 7.3091+0.0955\n",
      "LifeSatisf_Unadj: 6.0883+0.1044\n",
      "Gender: 0.8805+0.0100\n",
      "AUROC_Gender: 0.8793+0.0097\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std over the folds (pandas std, i.e. ddof=1).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric_name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric_name, fold_mean[metric_name], fold_std[metric_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "be7cb7d1-219f-4106-ad3c-e47331c539a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.00s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.02s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.01s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.03s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.03s/it]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every k-fold checkpoint of the STAGIN_GARO baseline (STAGIN-specific\n",
    "# evaluation path); results / results_n end up with one row of metrics per fold.\n",
    "kf_num_splits = 5\n",
    "model_folds = BaselineModels.STAGIN_GARO\n",
    "fold_raw, fold_norm = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_folds, 'iter{:d}'.format(kf_split_id))\n",
    "    raw_metrics, norm_metrics = evaluation_STAGIN(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_raw.append(raw_metrics)\n",
    "    fold_norm.append(norm_metrics)\n",
    "\n",
    "results = pd.DataFrame(fold_raw)\n",
    "results_n = pd.DataFrame(fold_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0aec7a09-c25c-445b-8b1a-bc8ebe61bf3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.0592+0.0628\n",
      "PicSeq_Unadj: 10.1815+0.2560\n",
      "PMAT24_A_CR: 3.4490+0.0372\n",
      "PMAT24_A_SI: 2.6953+0.0718\n",
      "PicVocab_Unadj: 6.8648+0.2876\n",
      "IWRD_TOT: 2.2845+0.0483\n",
      "ListSort_Unadj: 8.5514+0.1614\n",
      "LifeSatisf_Unadj: 7.1389+0.3813\n",
      "Gender: 0.8834+0.0094\n",
      "AUROC_Gender: 0.8833+0.0091\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std over the folds (pandas std, i.e. ddof=1).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric_name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric_name, fold_mean[metric_name], fold_std[metric_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "02c9c693-a1b3-45db-88da-c87c2ab6dd94",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.02s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.01s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.00it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.02s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.02s/it]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every k-fold checkpoint of the STAGIN_SERO baseline (STAGIN-specific\n",
    "# evaluation path); results / results_n end up with one row of metrics per fold.\n",
    "kf_num_splits = 5\n",
    "model_folds = BaselineModels.STAGIN_SERO\n",
    "fold_raw, fold_norm = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_folds, 'iter{:d}'.format(kf_split_id))\n",
    "    raw_metrics, norm_metrics = evaluation_STAGIN(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_raw.append(raw_metrics)\n",
    "    fold_norm.append(norm_metrics)\n",
    "\n",
    "results = pd.DataFrame(fold_raw)\n",
    "results_n = pd.DataFrame(fold_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "579edb5e-64ae-404b-b6e9-13534ff5bc5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.1055+0.0568\n",
      "PicSeq_Unadj: 10.3130+0.2418\n",
      "PMAT24_A_CR: 3.5090+0.0337\n",
      "PMAT24_A_SI: 2.7244+0.0592\n",
      "PicVocab_Unadj: 6.7722+0.1911\n",
      "IWRD_TOT: 2.2660+0.0358\n",
      "ListSort_Unadj: 8.5983+0.3411\n",
      "LifeSatisf_Unadj: 7.1882+0.1782\n",
      "Gender: 0.8873+0.0136\n",
      "AUROC_Gender: 0.8869+0.0141\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std over the folds (pandas std, i.e. ddof=1).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric_name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric_name, fold_mean[metric_name], fold_std[metric_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f0b5aa-8799-4c82-91d8-27a820100c51",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pyfmri]",
   "language": "python",
   "name": "conda-env-pyfmri-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
