{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8f27cf84-02d9-447e-a40f-c36fb06ee674",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initializing ... \n",
      "generating json ...\n"
     ]
    }
   ],
   "source": [
    "import set_path\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import nibabel as nib\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "\n",
    "from scipy import signal, stats\n",
    "from scipy.interpolate import interp1d\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from settings import settings\n",
    "from utils import *\n",
    "from filters import filters\n",
    "from funcs import *\n",
    "\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "from datasets_transfer import HCP3TTransferDataset, HCP7TTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from baseline_models import BrainNetworkTransformer\n",
    "from baseline_models import GraphTransformer\n",
    "from baseline_models import BrainNetCNN\n",
    "from baseline_models import FBNETGEN\n",
    "from baseline_models import ModelSTAGIN\n",
    "from baseline_models.STAGIN import process_dynamic_fc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e337b044-9037-4614-93c5-9d072431842b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled registry of trained baseline models; later cells read it as\n",
    "# BaselineModels.<ModelName>.iter<k> to get the per-fold model records.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load this project-generated file.\n",
    "varName = 'BaselineModels'\n",
    "fName = getattr(settings.projectData.files.general_HCP3T, varName)\n",
    "fPath = settings.projectData.dir.general_HCP3T / fName\n",
    "with open(fPath, mode='rb') as pkl_file:\n",
    "    BaselineModels = pickle.load(pkl_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a5aae5b4-b771-43fc-9520-2d67bf154246",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from einops import repeat\n",
    "\n",
    "def evaluation_HCP3T_STAGIN(modelInfo, fmri_type, checkpoint='best', kf_split_id=0, kf_num_splits=5):\n",
    "    \"\"\"Evaluate a trained STAGIN model on the test split of one HCP-3T k-fold split.\n",
    "\n",
    "    The training split of the same fold is built only to obtain its target normalizers\n",
    "    (tgt_norms), which the test set reuses, and to check the number of folds.\n",
    "\n",
    "    Args:\n",
    "        modelInfo: per-fold model record; must provide `args` and `model_files`\n",
    "            (a table with `checkpoint` and `file_path` columns).\n",
    "        fmri_type: fMRI variant used for the test data (the training split uses\n",
    "            args.training.fmri_type).\n",
    "        checkpoint: value in `model_files.checkpoint` selecting the weights to load.\n",
    "        kf_split_id: index of the k-fold split to evaluate.\n",
    "        kf_num_splits: expected number of k-fold splits, asserted against the dataset.\n",
    "            Defaults to 5, the value every caller in this notebook uses; previously it\n",
    "            was read from a notebook-level global.\n",
    "\n",
    "    Returns:\n",
    "        (results, results_n), two dicts keyed by target name:\n",
    "          * continuous targets -> MAE in original units (results) and that MAE divided by\n",
    "            sqrt(var_) of the training split's normalizer (results_n);\n",
    "          * discrete targets -> accuracy in both dicts, plus `AUROC_<target>` in\n",
    "            results, computed from softmax class probabilities.\n",
    "\n",
    "    Side effect: sets modelInfo.args.training.sessIDs to the training (SUBJECT, SESSION) pairs.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Training split of the same fold: only needed for its target normalizers.\n",
    "    train_dataset = HCP3TTransferDataset(settings, region_roi='Yeo100Parc',\n",
    "                            output_continuous_targets = args.training.output_continuous_targets,\n",
    "                            output_discrete_targets = args.training.output_discrete_targets,\n",
    "                            standardized_output = True,\n",
    "                            fmri_type=args.training.fmri_type, \n",
    "                            output_fmri_size = args.training.output_fmri_size,\n",
    "                            kf_split_id=kf_split_id,\n",
    "                            overlapping_segments=args.training.overlapping_segments, \n",
    "                            normalize_fmri=args.training.normalize_fmri,\n",
    "                            output_fc=True,\n",
    "                            output_timeseries=True)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits, (\n",
    "        'dataset has {} k-fold splits, expected {}'.format(train_dataset.kf_num_splits, kf_num_splits))\n",
    "\n",
    "    test_dataset = HCP3TTransferDataset(settings, region_roi='Yeo100Parc',\n",
    "                            output_continuous_targets = args.training.output_continuous_targets,\n",
    "                            output_discrete_targets = args.training.output_discrete_targets,\n",
    "                            standardized_output = True,\n",
    "                            output_fmri_size = args.training.output_fmri_size,\n",
    "                            fmri_type=fmri_type,\n",
    "                            kf_split_id=kf_split_id,\n",
    "                            overlapping_segments=args.training.overlapping_segments, \n",
    "                            normalize_fmri=args.training.normalize_fmri,\n",
    "                            output_fc=True,\n",
    "                            output_timeseries=True)\n",
    "    test_dataset.test()\n",
    "    # Test targets must be (de)standardized with the training split's normalizers.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    model = ModelSTAGIN(\n",
    "                    input_dim=args.model.num_nodes,\n",
    "                    hidden_dim=args.model.hidden_dim,\n",
    "                    num_classes=args.model.output_dim,\n",
    "                    num_heads=args.model.num_heads,\n",
    "                    num_layers=args.model.num_layers,\n",
    "                    sparsity=args.model.sparsity,\n",
    "                    dropout=args.model.dropout,\n",
    "                    cls_token=args.model.cls_token,\n",
    "                    readout=args.model.readout,\n",
    "                )\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location: checkpoints saved on another device still load onto args.training.device.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Per discrete target: list of per-batch (batch, n_classes) probability arrays for AUROC.\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader,\n",
    "                desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)):\n",
    "            batch['timeseries'] = batch['timeseries'].permute(0,2,1)\n",
    "            # Sliding-window dynamic FC graphs and their window start points.\n",
    "            # NOTE(review): if process_dynamic_fc samples a random window (e.g. when\n",
    "            # dynamic_length is set), these metrics vary between runs unless a seed is fixed -- confirm.\n",
    "            dyn_a, sampling_points = process_dynamic_fc(batch['timeseries'], \n",
    "                                        args.model.window_size, args.model.window_stride, args.model.dynamic_length)\n",
    "            sampling_endpoints = [p + args.model.window_size for p in sampling_points]\n",
    "            # Identity node features, one copy per sample and per window; sized from the actual\n",
    "            # batch so the final (possibly smaller) batch needs no special casing.\n",
    "            dyn_v = repeat(torch.eye(args.model.num_nodes), 'n1 n2 -> b t n1 n2',\n",
    "                           t=len(sampling_points), b=len(dyn_a))\n",
    "            t = batch['timeseries'].permute(1,0,2)\n",
    "\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            preds, _, _, _ = model(dyn_v.to(device), \n",
    "                                dyn_a.to(device), \n",
    "                                t.to(device), \n",
    "                                sampling_endpoints)\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # Discrete heads occupy the inclusive column range [idxs[0], idxs[1]].\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                head_out = preds[:, idxs[0]:idxs[1]+1]\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(head_out.detach().cpu().numpy(), axis=1)\n",
    "                # AUROC needs continuous scores, not hard labels; head outputs are presumably\n",
    "                # logits (TODO confirm), so convert them to class probabilities.\n",
    "                probs_all[tgt_name].append(torch.softmax(head_out, dim=1).detach().cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        scaler = test_dataset.tgt_norms[tgt_name]\n",
    "        # Undo standardization so the MAE is in the target's original units; ravel keeps\n",
    "        # 1-D arrays even for a single-sample fold.\n",
    "        tgt = scaler.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        pred = scaler.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # var_ is a size-1 array for a one-column scaler; .item() stores a plain float.\n",
    "        results_n[tgt_name] = mae / np.sqrt(scaler.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        y_pred = preds_all[tgt_name].values\n",
    "        y_prob = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        results[tgt_name] = accuracy_score(y_true, y_pred)\n",
    "        # roc_auc_score(y_true, y_score): labels first, then a continuous score --\n",
    "        # the positive-class probability for binary heads, one-vs-rest otherwise.\n",
    "        if y_prob.shape[1] == 2:\n",
    "            results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_prob[:, 1])\n",
    "        else:\n",
    "            results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_prob, multi_class='ovr')\n",
    "        results_n[tgt_name] = results[tgt_name]\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "49e459c2-8315-40d5-8b2f-ad8052169261",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "def evaluation_HCP3T(modelInfo, fmri_type, checkpoint='best', kf_split_id=0, kf_num_splits=5):\n",
    "    \"\"\"Evaluate a trained baseline model on the test split of one HCP-3T k-fold split.\n",
    "\n",
    "    The model class is looked up by name from args.model.name, constructed as cls(args)\n",
    "    and called as model(timeseries, fc). The training split of the same fold is built\n",
    "    only to obtain its target normalizers (tgt_norms), which the test set reuses, and\n",
    "    to check the number of folds.\n",
    "\n",
    "    Args:\n",
    "        modelInfo: per-fold model record; must provide `args` and `model_files`\n",
    "            (a table with `checkpoint` and `file_path` columns).\n",
    "        fmri_type: fMRI variant used for the test data (the training split uses\n",
    "            args.training.fmri_type).\n",
    "        checkpoint: value in `model_files.checkpoint` selecting the weights to load.\n",
    "        kf_split_id: index of the k-fold split to evaluate.\n",
    "        kf_num_splits: expected number of k-fold splits, asserted against the dataset.\n",
    "            Defaults to 5, the value every caller in this notebook uses; previously it\n",
    "            was read from a notebook-level global.\n",
    "\n",
    "    Returns:\n",
    "        (results, results_n), two dicts keyed by target name:\n",
    "          * continuous targets -> MAE in original units (results) and that MAE divided by\n",
    "            sqrt(var_) of the training split's normalizer (results_n);\n",
    "          * discrete targets -> accuracy in both dicts, plus `AUROC_<target>` in\n",
    "            results, computed from softmax class probabilities.\n",
    "\n",
    "    Side effect: sets modelInfo.args.training.sessIDs to the training (SUBJECT, SESSION) pairs.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Training split of the same fold: only needed for its target normalizers.\n",
    "    train_dataset = HCP3TTransferDataset(settings, region_roi='Yeo100Parc',\n",
    "                            output_continuous_targets = args.training.output_continuous_targets,\n",
    "                            output_discrete_targets = args.training.output_discrete_targets,\n",
    "                            standardized_output = True,\n",
    "                            fmri_type=args.training.fmri_type, \n",
    "                            output_fmri_size = args.training.output_fmri_size,\n",
    "                            kf_split_id=kf_split_id,\n",
    "                            overlapping_segments=args.training.overlapping_segments, \n",
    "                            normalize_fmri=args.training.normalize_fmri,\n",
    "                            output_fc=True,\n",
    "                            output_timeseries=True)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits, (\n",
    "        'dataset has {} k-fold splits, expected {}'.format(train_dataset.kf_num_splits, kf_num_splits))\n",
    "\n",
    "    test_dataset = HCP3TTransferDataset(settings, region_roi='Yeo100Parc',\n",
    "                            output_continuous_targets = args.training.output_continuous_targets,\n",
    "                            output_discrete_targets = args.training.output_discrete_targets,\n",
    "                            standardized_output = True,\n",
    "                            output_fmri_size = args.training.output_fmri_size,\n",
    "                            fmri_type=fmri_type,\n",
    "                            kf_split_id=kf_split_id,\n",
    "                            overlapping_segments=args.training.overlapping_segments, \n",
    "                            normalize_fmri=args.training.normalize_fmri,\n",
    "                            output_fc=True,\n",
    "                            output_timeseries=True)\n",
    "    test_dataset.test()\n",
    "    # Test targets must be (de)standardized with the training split's normalizers.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    # NOTE(review): eval() on a name taken from pickled args; acceptable only because the\n",
    "    # pickle itself is trusted project output. The name must resolve to an imported model class.\n",
    "    model = eval(args.model.name)(args)\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location: checkpoints saved on another device still load onto args.training.device.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Per discrete target: list of per-batch (batch, n_classes) probability arrays for AUROC.\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader,\n",
    "                desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)):\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            preds = model(batch['timeseries'], batch['fc'])\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # Discrete heads occupy the inclusive column range [idxs[0], idxs[1]].\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                head_out = preds[:, idxs[0]:idxs[1]+1]\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(head_out.detach().cpu().numpy(), axis=1)\n",
    "                # AUROC needs continuous scores, not hard labels; head outputs are presumably\n",
    "                # logits (TODO confirm), so convert them to class probabilities.\n",
    "                probs_all[tgt_name].append(torch.softmax(head_out, dim=1).detach().cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        scaler = test_dataset.tgt_norms[tgt_name]\n",
    "        # Undo standardization so the MAE is in the target's original units; ravel keeps\n",
    "        # 1-D arrays even for a single-sample fold.\n",
    "        tgt = scaler.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        pred = scaler.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # var_ is a size-1 array for a one-column scaler; .item() stores a plain float.\n",
    "        results_n[tgt_name] = mae / np.sqrt(scaler.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        y_pred = preds_all[tgt_name].values\n",
    "        y_prob = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        results[tgt_name] = accuracy_score(y_true, y_pred)\n",
    "        # roc_auc_score(y_true, y_score): labels first, then a continuous score --\n",
    "        # the positive-class probability for binary heads, one-vs-rest otherwise.\n",
    "        if y_prob.shape[1] == 2:\n",
    "            results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_prob[:, 1])\n",
    "        else:\n",
    "            results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_prob, multi_class='ovr')\n",
    "        results_n[tgt_name] = results[tgt_name]\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "015fecd9-4ee4-4632-9051-6e9c1708536b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation_HCP7T(modelInfo, fmri_type, checkpoint='best', kf_split_id=0, kf_num_splits=5):\n",
    "    \"\"\"Evaluate a baseline model trained on one HCP-3T fold against the HCP-7T data.\n",
    "\n",
    "    The HCP-3T training split of the same fold is built only to obtain its target\n",
    "    normalizers (tgt_norms), which the 7T test set reuses, and to check the number of\n",
    "    folds. The 7T test set covers all of its data_infos, not just one fold's test subset.\n",
    "    The model class is looked up by name from args.model.name, constructed as cls(args)\n",
    "    and called as model(timeseries, fc).\n",
    "\n",
    "    Args:\n",
    "        modelInfo: per-fold model record; must provide `args` and `model_files`\n",
    "            (a table with `checkpoint` and `file_path` columns).\n",
    "        fmri_type: fMRI variant used for the 7T test data (the 3T training split uses\n",
    "            args.training.fmri_type).\n",
    "        checkpoint: value in `model_files.checkpoint` selecting the weights to load.\n",
    "        kf_split_id: index of the k-fold split whose model and normalizers are used.\n",
    "        kf_num_splits: expected number of k-fold splits, asserted against the dataset.\n",
    "            Defaults to 5, the value every caller in this notebook uses; previously it\n",
    "            was read from a notebook-level global.\n",
    "\n",
    "    Returns:\n",
    "        (results, results_n), two dicts keyed by target name:\n",
    "          * continuous targets -> MAE in original units (results) and that MAE divided by\n",
    "            sqrt(var_) of the training split's normalizer (results_n);\n",
    "          * discrete targets -> accuracy in both dicts, plus `AUROC_<target>` in\n",
    "            results, computed from softmax class probabilities.\n",
    "\n",
    "    Side effect: sets modelInfo.args.training.sessIDs to the training (SUBJECT, SESSION) pairs.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # HCP-3T training split of the same fold: only needed for its target normalizers.\n",
    "    train_dataset = HCP3TTransferDataset(settings, region_roi='Yeo100Parc',\n",
    "                            output_continuous_targets = args.training.output_continuous_targets,\n",
    "                            output_discrete_targets = args.training.output_discrete_targets,\n",
    "                            standardized_output = True,\n",
    "                            fmri_type=args.training.fmri_type, \n",
    "                            output_fmri_size = args.training.output_fmri_size,\n",
    "                            kf_split_id=kf_split_id,\n",
    "                            overlapping_segments=args.training.overlapping_segments, \n",
    "                            normalize_fmri=args.training.normalize_fmri,\n",
    "                            output_fc=True,\n",
    "                            output_timeseries=True)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits, (\n",
    "        'dataset has {} k-fold splits, expected {}'.format(train_dataset.kf_num_splits, kf_num_splits))\n",
    "\n",
    "    test_dataset = HCP7TTransferDataset(settings, region_roi='Yeo100Parc',\n",
    "                            output_continuous_targets = args.training.output_continuous_targets,\n",
    "                            output_discrete_targets = args.training.output_discrete_targets,\n",
    "                            standardized_output = True,\n",
    "                            output_fmri_size = args.training.output_fmri_size,\n",
    "                            fmri_type=fmri_type,\n",
    "                            kf_split_id=kf_split_id,\n",
    "                            overlapping_segments=args.training.overlapping_segments, \n",
    "                            normalize_fmri=args.training.normalize_fmri,\n",
    "                            output_fc=True,\n",
    "                            output_timeseries=True)\n",
    "    test_dataset.test()\n",
    "    # Evaluate on every 7T record rather than the fold's test subset (presumably because\n",
    "    # no 7T data was used for training -- confirm).\n",
    "    test_dataset.test_data_infos = test_dataset.data_infos\n",
    "    # Test targets must be (de)standardized with the 3T training split's normalizers.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    # NOTE(review): eval() on a name taken from pickled args; acceptable only because the\n",
    "    # pickle itself is trusted project output. The name must resolve to an imported model class.\n",
    "    model = eval(args.model.name)(args)\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location: checkpoints saved on another device still load onto args.training.device.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Per discrete target: list of per-batch (batch, n_classes) probability arrays for AUROC.\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader,\n",
    "                desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)):\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            preds = model(batch['timeseries'], batch['fc'])\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # Discrete heads occupy the inclusive column range [idxs[0], idxs[1]].\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                head_out = preds[:, idxs[0]:idxs[1]+1]\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(head_out.detach().cpu().numpy(), axis=1)\n",
    "                # AUROC needs continuous scores, not hard labels; head outputs are presumably\n",
    "                # logits (TODO confirm), so convert them to class probabilities.\n",
    "                probs_all[tgt_name].append(torch.softmax(head_out, dim=1).detach().cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        scaler = test_dataset.tgt_norms[tgt_name]\n",
    "        # Undo standardization so the MAE is in the target's original units; ravel keeps\n",
    "        # 1-D arrays even for a single-sample set.\n",
    "        tgt = scaler.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        pred = scaler.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # var_ is a size-1 array for a one-column scaler; .item() stores a plain float.\n",
    "        results_n[tgt_name] = mae / np.sqrt(scaler.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        y_pred = preds_all[tgt_name].values\n",
    "        y_prob = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        results[tgt_name] = accuracy_score(y_true, y_pred)\n",
    "        # roc_auc_score(y_true, y_score): labels first, then a continuous score --\n",
    "        # the positive-class probability for binary heads, one-vs-rest otherwise.\n",
    "        if y_prob.shape[1] == 2:\n",
    "            results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_prob[:, 1])\n",
    "        else:\n",
    "            results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_prob, multi_class='ovr')\n",
    "        results_n[tgt_name] = results[tgt_name]\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1c28be06-6970-4843-8c72-a0c2bf2ab2c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  3.30it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.17it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.24it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.55it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# k-fold evaluation of BrainNetworkTransformer on held-out HCP-3T test splits.\n",
    "# kf_num_splits must match the dataset's k-fold count.\n",
    "kf_num_splits = 5\n",
    "fold_scores, fold_scores_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.BrainNetworkTransformer, f'iter{kf_split_id:d}')\n",
    "    raw_scores, norm_scores = evaluation_HCP3T(modelInfo, fmri_type='minimal_processed', kf_split_id=kf_split_id)\n",
    "    fold_scores.append(raw_scores)\n",
    "    fold_scores_n.append(norm_scores)\n",
    "\n",
    "# One row per fold, one column per metric.\n",
    "results = pd.DataFrame(fold_scores)\n",
    "results_n = pd.DataFrame(fold_scores_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ffc458ae-7e6a-4f49-8bbd-bae0890d7340",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 1.9771+0.0511\n",
      "PicSeq_Unadj: 9.7627+0.3798\n",
      "PMAT24_A_CR: 3.5155+0.0764\n",
      "PMAT24_A_SI: 2.7558+0.0401\n",
      "PicVocab_Unadj: 6.6365+0.1772\n",
      "IWRD_TOT: 2.1507+0.0453\n",
      "ListSort_Unadj: 8.2115+0.1106\n",
      "LifeSatisf_Unadj: 6.8336+0.2032\n",
      "Gender: 0.7962+0.0157\n",
      "AUROC_Gender: 0.8015+0.0065\n"
     ]
    }
   ],
   "source": [
    "# Per-metric mean and sample std across folds, printed as mean+std.\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric in results.columns:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric, fold_mean[metric], fold_std[metric]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cfc7967c-81f9-4a75-91ba-67dcc82c4c22",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  9.60it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 10.79it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 10.87it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 10.79it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  9.95it/s]\n"
     ]
    }
   ],
   "source": [
    "# k-fold evaluation of FBNETGEN on held-out HCP-3T test splits.\n",
    "# kf_num_splits must match the dataset's k-fold count.\n",
    "kf_num_splits = 5\n",
    "fold_scores, fold_scores_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.FBNETGEN, f'iter{kf_split_id:d}')\n",
    "    raw_scores, norm_scores = evaluation_HCP3T(modelInfo, fmri_type='minimal_processed', kf_split_id=kf_split_id)\n",
    "    fold_scores.append(raw_scores)\n",
    "    fold_scores_n.append(norm_scores)\n",
    "\n",
    "# One row per fold, one column per metric.\n",
    "results = pd.DataFrame(fold_scores)\n",
    "results_n = pd.DataFrame(fold_scores_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2d8e3cc4-a234-4bbf-8d90-7cbf8fb17a27",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.3593+0.0895\n",
      "PicSeq_Unadj: 11.4059+0.3023\n",
      "PMAT24_A_CR: 3.9870+0.2486\n",
      "PMAT24_A_SI: 3.2132+0.2242\n",
      "PicVocab_Unadj: 8.0195+0.4363\n",
      "IWRD_TOT: 2.6026+0.0618\n",
      "ListSort_Unadj: 10.0944+0.3037\n",
      "LifeSatisf_Unadj: 8.0268+0.3803\n",
      "Gender: 0.7210+0.0180\n",
      "AUROC_Gender: 0.7219+0.0187\n"
     ]
    }
   ],
   "source": [
    "# Per-metric mean and sample std across folds, printed as mean+std.\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric in results.columns:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric, fold_mean[metric], fold_std[metric]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5f740b12-9671-4ee7-8dce-710fb0c2e89b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:02<00:00, 10.08it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 22.04it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 21/21 [00:01<00:00, 19.21it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 21.17it/s]\n",
      "test epoch [200|200]: 100%|███████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 19.89it/s]\n"
     ]
    }
   ],
   "source": [
    "# k-fold evaluation of BrainNetCNN on held-out HCP-3T test splits.\n",
    "# kf_num_splits must match the dataset's k-fold count.\n",
    "kf_num_splits = 5\n",
    "fold_scores, fold_scores_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.BrainNetCNN, f'iter{kf_split_id:d}')\n",
    "    raw_scores, norm_scores = evaluation_HCP3T(modelInfo, fmri_type='minimal_processed', kf_split_id=kf_split_id)\n",
    "    fold_scores.append(raw_scores)\n",
    "    fold_scores_n.append(norm_scores)\n",
    "\n",
    "# One row per fold, one column per metric.\n",
    "results = pd.DataFrame(fold_scores)\n",
    "results_n = pd.DataFrame(fold_scores_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fc7b2f08-671a-41ac-8ec0-5240712d63a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.1714+0.0726\n",
      "PicSeq_Unadj: 11.0126+0.1546\n",
      "PMAT24_A_CR: 4.0775+0.2098\n",
      "PMAT24_A_SI: 3.4183+0.2749\n",
      "PicVocab_Unadj: 7.6274+0.2375\n",
      "IWRD_TOT: 2.4323+0.0707\n",
      "ListSort_Unadj: 9.4639+0.1139\n",
      "LifeSatisf_Unadj: 7.3871+0.2752\n",
      "Gender: 0.7334+0.0337\n",
      "AUROC_Gender: 0.7485+0.0180\n"
     ]
    }
   ],
   "source": [
    "# Per-metric mean and sample std across folds, printed as mean+std.\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric in results.columns:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric, fold_mean[metric], fold_std[metric]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e085886e-d74d-4f83-ab5b-8c3840cca453",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.51it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.45it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.47it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.91it/s]\n",
      "test epoch [200|200]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.80it/s]\n"
     ]
    }
   ],
   "source": [
    "# k-fold evaluation of GraphTransformer on held-out HCP-3T test splits.\n",
    "# kf_num_splits must match the dataset's k-fold count.\n",
    "kf_num_splits = 5\n",
    "fold_scores, fold_scores_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.GraphTransformer, f'iter{kf_split_id:d}')\n",
    "    raw_scores, norm_scores = evaluation_HCP3T(modelInfo, fmri_type='minimal_processed', kf_split_id=kf_split_id)\n",
    "    fold_scores.append(raw_scores)\n",
    "    fold_scores_n.append(norm_scores)\n",
    "\n",
    "# One row per fold, one column per metric.\n",
    "results = pd.DataFrame(fold_scores)\n",
    "results_n = pd.DataFrame(fold_scores_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cbedeb0b-90f1-4527-ae5a-8c22935dd9cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.1224+0.0738\n",
      "PicSeq_Unadj: 10.5955+0.3656\n",
      "PMAT24_A_CR: 3.6740+0.0270\n",
      "PMAT24_A_SI: 2.9251+0.0389\n",
      "PicVocab_Unadj: 7.5351+0.4597\n",
      "IWRD_TOT: 2.3452+0.0619\n",
      "ListSort_Unadj: 8.8986+0.1509\n",
      "LifeSatisf_Unadj: 7.4709+0.1699\n",
      "Gender: 0.7553+0.0158\n",
      "AUROC_Gender: 0.7541+0.0151\n"
     ]
    }
   ],
   "source": [
    "# Per-metric mean and sample std across folds, printed as mean+std.\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric in results.columns:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric, fold_mean[metric], fold_std[metric]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "aafee523-c4b5-46c5-b0cc-e5275ccdc600",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.00it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.00it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.01it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.01s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.01s/it]\n"
     ]
    }
   ],
   "source": [
    "# k-fold evaluation of STAGIN (GARO readout variant) on held-out HCP-3T test splits.\n",
    "# kf_num_splits must match the dataset's k-fold count.\n",
    "kf_num_splits = 5\n",
    "fold_scores, fold_scores_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.STAGIN_GARO, f'iter{kf_split_id:d}')\n",
    "    raw_scores, norm_scores = evaluation_HCP3T_STAGIN(modelInfo, fmri_type='minimal_processed', kf_split_id=kf_split_id)\n",
    "    fold_scores.append(raw_scores)\n",
    "    fold_scores_n.append(norm_scores)\n",
    "\n",
    "# One row per fold, one column per metric.\n",
    "results = pd.DataFrame(fold_scores)\n",
    "results_n = pd.DataFrame(fold_scores_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "25daa8e5-397d-4cc8-aa89-ab562216f924",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.2518+0.0579\n",
      "PicSeq_Unadj: 11.2006+0.1651\n",
      "PMAT24_A_CR: 3.9347+0.1596\n",
      "PMAT24_A_SI: 3.1005+0.1824\n",
      "PicVocab_Unadj: 8.0808+0.3403\n",
      "IWRD_TOT: 2.4127+0.1011\n",
      "ListSort_Unadj: 9.6447+0.2600\n",
      "LifeSatisf_Unadj: 7.5849+0.2314\n",
      "Gender: 0.7003+0.0186\n",
      "AUROC_Gender: 0.7111+0.0126\n"
     ]
    }
   ],
   "source": [
    "# Per-metric mean and sample std across folds, printed as mean+std.\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric in results.columns:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric, fold_mean[metric], fold_std[metric]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c8e614b4-067c-4691-837f-16d33dbd3805",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.02s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.03s/it]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.01it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.01it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.00s/it]\n"
     ]
    }
   ],
   "source": [
    "# k-fold evaluation of STAGIN (SERO readout variant) on held-out HCP-3T test splits.\n",
    "# kf_num_splits must match the dataset's k-fold count.\n",
    "kf_num_splits = 5\n",
    "fold_scores, fold_scores_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.STAGIN_SERO, f'iter{kf_split_id:d}')\n",
    "    raw_scores, norm_scores = evaluation_HCP3T_STAGIN(modelInfo, fmri_type='minimal_processed', kf_split_id=kf_split_id)\n",
    "    fold_scores.append(raw_scores)\n",
    "    fold_scores_n.append(norm_scores)\n",
    "\n",
    "# One row per fold, one column per metric.\n",
    "results = pd.DataFrame(fold_scores)\n",
    "results_n = pd.DataFrame(fold_scores_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3a23039b-4e13-4a38-a073-a1544b8c8fe3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSQI_Score: 2.3678+0.0662\n",
      "PicSeq_Unadj: 11.2919+0.4094\n",
      "PMAT24_A_CR: 3.9959+0.1608\n",
      "PMAT24_A_SI: 3.1322+0.1564\n",
      "PicVocab_Unadj: 7.9757+0.4137\n",
      "IWRD_TOT: 2.4672+0.1213\n",
      "ListSort_Unadj: 9.8141+0.4207\n",
      "LifeSatisf_Unadj: 8.1704+0.2182\n",
      "Gender: 0.6922+0.0356\n",
      "AUROC_Gender: 0.6998+0.0286\n"
     ]
    }
   ],
   "source": [
    "# Per-metric mean and sample std across folds, printed as mean+std.\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for metric in results.columns:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(metric, fold_mean[metric], fold_std[metric]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2fa2672-bf66-4e77-a807-498be864fd91",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pyfmri]",
   "language": "python",
   "name": "conda-env-pyfmri-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
