{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b5480f0d-9931-433f-9d33-d32a6b29236f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initializing ... \n",
      "generating json ...\n"
     ]
    }
   ],
   "source": [
    "import set_path\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import nibabel as nib\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "\n",
    "from scipy import signal, stats\n",
    "from scipy.interpolate import interp1d\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from settings import settings\n",
    "from utils import *\n",
    "from filters import filters\n",
    "from funcs import *\n",
    "\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "from datasets_transfer import NSDTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from models_autoencoder import fMRIAutoEncoder, fMRIStateTransferModel\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "from datasets_transfer import HCP3TTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from baseline_models import BrainNetworkTransformer\n",
    "from baseline_models import GraphTransformer\n",
    "from baseline_models import BrainNetCNN\n",
    "from baseline_models import FBNETGEN\n",
    "from baseline_models import ModelSTAGIN\n",
    "from baseline_models.STAGIN import process_dynamic_fc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "642080ed-8cf5-4480-8552-84ab1400b3d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled BaselineModels object from the project data directory.\n",
    "# NOTE: pickle.load can execute arbitrary code - only use trusted project files.\n",
    "varName = 'BaselineModels'\n",
    "nsd_files = settings.projectData.files.general_NSD\n",
    "fName = getattr(nsd_files, varName)\n",
    "fPath = settings.projectData.dir.general_NSD / fName\n",
    "with open(fPath, mode='rb') as handle:\n",
    "    BaselineModels = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "85d3abe3-1103-4833-a4a6-66d86d44f141",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from einops import repeat\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "def evaluation_STAGIN(modelInfo, checkpoint='best', kf_split_id=0, random_seed=42, kf_num_splits=5):\n",
    "    \"\"\"Evaluate a trained ModelSTAGIN checkpoint on the NSD test split of one k-fold split.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    modelInfo : object\n",
    "        Provides ``args`` (training/model config) and ``model_files`` (a DataFrame\n",
    "        with ``checkpoint`` and ``file_path`` columns).\n",
    "    checkpoint : str\n",
    "        Row of ``modelInfo.model_files`` to load (default ``'best'``).\n",
    "    kf_split_id : int\n",
    "        Index of the k-fold split to evaluate.\n",
    "    random_seed : int\n",
    "        Seed forwarded to ``NSDTransferDataset``.\n",
    "    kf_num_splits : int\n",
    "        Expected number of k-fold splits; asserted against the dataset. Passed\n",
    "        explicitly instead of being read from a notebook-global variable.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    results : dict\n",
    "        Continuous targets: MAE after inverse-transforming the standardisation.\n",
    "        Discrete targets: accuracy, plus ``'AUROC_{name}'`` computed from softmax\n",
    "        scores (NaN, with a warning, when AUROC is undefined).\n",
    "    results_n : dict\n",
    "        Continuous targets: that MAE divided by sqrt(var_) of the target\n",
    "        normaliser. Discrete targets: accuracy.\n",
    "\n",
    "    Side effect: sets ``modelInfo.args.training.sessIDs`` from the training split.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Train and test datasets take identical arguments; only .train()/.test() differs.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        subject_id=args.training.subject_id,\n",
    "        output_fmri_size=args.training.output_fmri_size,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        output_fc=True,\n",
    "        output_timeseries=True,\n",
    "        random_seed=random_seed,\n",
    "    )\n",
    "\n",
    "    # The training split is only needed for its session ids and target normalisers,\n",
    "    # so no training DataLoader is built.\n",
    "    train_dataset = NSDTransferDataset(settings, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits\n",
    "\n",
    "    test_dataset = NSDTransferDataset(settings, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # Score the test split with the training split's target normalisers.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    model = ModelSTAGIN(\n",
    "        input_dim=args.model.num_nodes,\n",
    "        hidden_dim=args.model.hidden_dim,\n",
    "        num_classes=args.model.output_dim,\n",
    "        num_heads=args.model.num_heads,\n",
    "        num_layers=args.model.num_layers,\n",
    "        sparsity=args.model.sparsity,\n",
    "        dropout=args.model.dropout,\n",
    "        cls_token=args.model.cls_token,\n",
    "        readout=args.model.readout,\n",
    "    )\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location allows checkpoints saved on another device (e.g. GPU) to load here.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    eye = torch.eye(args.model.num_nodes)\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Per-batch class probabilities for each discrete target (needed for AUROC).\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader,\n",
    "                desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)):\n",
    "            batch['timeseries'] = batch['timeseries'].permute(0,2,1)\n",
    "            # Windowed inputs from process_dynamic_fc; each endpoint is its\n",
    "            # sampling point shifted by window_size.\n",
    "            dyn_a, sampling_points = process_dynamic_fc(batch['timeseries'],\n",
    "                                        args.model.window_size, args.model.window_stride, args.model.dynamic_length)\n",
    "            sampling_endpoints = [p + args.model.window_size for p in sampling_points]\n",
    "            # Identity matrices shaped (batch, windows, num_nodes, num_nodes), built per\n",
    "            # batch so the batch dimension always matches dyn_a (the last batch may be smaller).\n",
    "            dyn_v = repeat(eye, 'n1 n2 -> b t n1 n2', t=len(sampling_points), b=len(dyn_a))\n",
    "            t = batch['timeseries'].permute(1,0,2)\n",
    "\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            preds, _, _, _ = model(dyn_v.to(device), dyn_a.to(device), t.to(device), sampling_endpoints)\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                # head_tgt2dims maps a continuous target to a single output column\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # ... and a discrete target to an inclusive [first, last] column range;\n",
    "                # the argmax over that range is the predicted class index.\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                logits = preds[:, idxs[0]:idxs[1]+1]\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(logits.detach().cpu().numpy(), axis=1)\n",
    "                # NOTE(review): assumes these columns are unnormalised logits - confirm.\n",
    "                probs_all[tgt_name].append(torch.softmax(logits, dim=1).detach().cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        # Undo the target standardisation so the MAE is in original units.\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).squeeze()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).squeeze()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # The normaliser was fitted on a single column, so var_ has one entry;\n",
    "        # .item() keeps the normalised MAE a plain float instead of a length-1 array.\n",
    "        results_n[tgt_name] = mae / np.sqrt(nscalar.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        acc = accuracy_score(y_true, preds_all[tgt_name].values)\n",
    "        results[tgt_name] = acc\n",
    "        # AUROC needs scores, not hard labels, and takes y_true first.\n",
    "        probs = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        try:\n",
    "            if probs.shape[1] == 2:\n",
    "                # binary: score of the positive class (index 1)\n",
    "                auroc = roc_auc_score(y_true, probs[:, 1])\n",
    "            else:\n",
    "                auroc = roc_auc_score(y_true, probs, multi_class='ovr',\n",
    "                                      labels=np.arange(probs.shape[1]))\n",
    "        except ValueError as err:\n",
    "            # e.g. a class missing from the test fold: AUROC is undefined there\n",
    "            warnings.warn('AUROC for {:s} is undefined: {}'.format(tgt_name, err))\n",
    "            auroc = np.nan\n",
    "        results['AUROC_{:s}'.format(tgt_name)] = auroc\n",
    "        results_n[tgt_name] = acc\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ce11a7ca-ac1a-48ff-be79-69566c6899c3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def evaluation(modelInfo, checkpoint='best', kf_split_id=0, random_seed=42, kf_num_splits=5):\n",
    "    \"\"\"Evaluate a trained baseline model on the NSD test split of one k-fold split.\n",
    "\n",
    "    The model class is looked up by ``args.model.name`` in the notebook namespace,\n",
    "    built as ``cls(args)`` and called as ``model(timeseries, fc)``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    modelInfo : object\n",
    "        Provides ``args`` (training/model config) and ``model_files`` (a DataFrame\n",
    "        with ``checkpoint`` and ``file_path`` columns).\n",
    "    checkpoint : str\n",
    "        Row of ``modelInfo.model_files`` to load (default ``'best'``).\n",
    "    kf_split_id : int\n",
    "        Index of the k-fold split to evaluate.\n",
    "    random_seed : int\n",
    "        Seed forwarded to ``NSDTransferDataset``.\n",
    "    kf_num_splits : int\n",
    "        Expected number of k-fold splits; asserted against the dataset. Passed\n",
    "        explicitly instead of being read from a notebook-global variable.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    results : dict\n",
    "        Continuous targets: MAE after inverse-transforming the standardisation.\n",
    "        Discrete targets: accuracy.\n",
    "    results_n : dict\n",
    "        Continuous targets: that MAE divided by sqrt(var_) of the target\n",
    "        normaliser. Discrete targets: accuracy.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    NameError\n",
    "        If ``args.model.name`` is not defined in the notebook namespace.\n",
    "\n",
    "    Side effect: sets ``modelInfo.args.training.sessIDs`` from the training split.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Train and test datasets take identical arguments; only .train()/.test() differs.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        subject_id=args.training.subject_id,\n",
    "        output_fmri_size=args.training.output_fmri_size,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        output_fc=True,\n",
    "        output_timeseries=True,\n",
    "        random_seed=random_seed,\n",
    "    )\n",
    "\n",
    "    # The training split is only needed for its session ids and target normalisers,\n",
    "    # so no training DataLoader is built.\n",
    "    train_dataset = NSDTransferDataset(settings, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits\n",
    "\n",
    "    test_dataset = NSDTransferDataset(settings, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # Score the test split with the training split's target normalisers.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    # Look the class up by name instead of eval()-ing a string from the pickled args.\n",
    "    try:\n",
    "        model_cls = globals()[args.model.name]\n",
    "    except KeyError:\n",
    "        raise NameError('model class {!r} is not defined in this notebook'.format(args.model.name)) from None\n",
    "    model = model_cls(args)\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # map_location allows checkpoints saved on another device (e.g. GPU) to load here.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location=device), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader,\n",
    "                desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)):\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            preds = model(batch['timeseries'], batch['fc'])\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                # head_tgt2dims maps a continuous target to a single output column\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # ... and a discrete target to an inclusive [first, last] column range;\n",
    "                # the argmax over that range is the predicted class index.\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(preds[:,idxs[0]:idxs[1]+1].detach().cpu().numpy(), axis=1)\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        # Undo the target standardisation so the MAE is in original units.\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).squeeze()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).squeeze()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # The normaliser was fitted on a single column, so var_ has one entry;\n",
    "        # .item() keeps the normalised MAE a plain float instead of a length-1 array.\n",
    "        results_n[tgt_name] = mae / np.sqrt(nscalar.var_).item()\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        acc = accuracy_score(tgt_all[tgt_name].values, preds_all[tgt_name].values)\n",
    "        results[tgt_name] = acc\n",
    "        results_n[tgt_name] = acc\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1fd0e3b4-a826-41e0-bb25-ef4eb266af04",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.90it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.08it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.00it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.89it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.04it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the default checkpoint of BrainNetworkTransformer on every CV fold.\n",
    "kf_num_splits = 5\n",
    "results, results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.BrainNetworkTransformer, f'iter{kf_split_id:d}')\n",
    "    fold_results, fold_results_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_results)\n",
    "    results_n.append(fold_results_n)\n",
    "\n",
    "# one row per fold, one column per metric\n",
    "results = pd.DataFrame(results)\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "868ef01b-1469-41be-804e-82226a77f35a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task_Accuracy: 0.069+0.003\n",
      "RT: 92.344+2.343\n"
     ]
    }
   ],
   "source": [
    "# Mean and sample std of each metric in `results` across CV folds.\n",
    "col_mean = results.mean()\n",
    "col_std = results.std()\n",
    "for name in col_mean.index:\n",
    "    print('{:s}: {:.3f}+{:.3f}'.format(name, col_mean[name], col_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aa466e2b-128c-42a1-9a3a-c56d499f5446",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.40it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.43it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  4.14it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.23it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.84it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the default checkpoint of GraphTransformer on every CV fold.\n",
    "kf_num_splits = 5\n",
    "results, results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.GraphTransformer, f'iter{kf_split_id:d}')\n",
    "    fold_results, fold_results_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_results)\n",
    "    results_n.append(fold_results_n)\n",
    "\n",
    "# one row per fold, one column per metric\n",
    "results = pd.DataFrame(results)\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ba04e18b-e480-4400-a93f-84e758b31c91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task_Accuracy: 0.075+0.004\n",
      "RT: 96.252+2.133\n"
     ]
    }
   ],
   "source": [
    "# Mean and sample std of each metric in `results` across CV folds.\n",
    "col_mean = results.mean()\n",
    "col_std = results.std()\n",
    "for name in col_mean.index:\n",
    "    print('{:s}: {:.3f}+{:.3f}'.format(name, col_mean[name], col_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7c25a200-4e59-4487-b2f0-5e26368fa319",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.35it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.68it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.50it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.69it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.41it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the default checkpoint of FBNETGEN on every CV fold.\n",
    "kf_num_splits = 5\n",
    "results, results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.FBNETGEN, f'iter{kf_split_id:d}')\n",
    "    fold_results, fold_results_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_results)\n",
    "    results_n.append(fold_results_n)\n",
    "\n",
    "# one row per fold, one column per metric\n",
    "results = pd.DataFrame(results)\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "77a51017-6fd6-46bd-95b5-e6e7e4fc4857",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task_Accuracy: 0.074+0.005\n",
      "RT: 95.349+2.320\n"
     ]
    }
   ],
   "source": [
    "# Mean and sample std of each metric in `results` across CV folds.\n",
    "col_mean = results.mean()\n",
    "col_std = results.std()\n",
    "for name in col_mean.index:\n",
    "    print('{:s}: {:.3f}+{:.3f}'.format(name, col_mean[name], col_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "96554a80-20fb-4758-82f7-140e75cedff1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.02it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  8.50it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  8.50it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  8.54it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  8.64it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the default checkpoint of BrainNetCNN on every CV fold.\n",
    "kf_num_splits = 5\n",
    "results, results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.BrainNetCNN, f'iter{kf_split_id:d}')\n",
    "    fold_results, fold_results_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_results)\n",
    "    results_n.append(fold_results_n)\n",
    "\n",
    "# one row per fold, one column per metric\n",
    "results = pd.DataFrame(results)\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "86e3d31c-5906-44db-8d6e-b8b6408c42e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task_Accuracy: 0.078+0.004\n",
      "RT: 102.911+2.225\n"
     ]
    }
   ],
   "source": [
    "# Mean and sample std of each metric in `results` across CV folds.\n",
    "col_mean = results.mean()\n",
    "col_std = results.std()\n",
    "for name in col_mean.index:\n",
    "    print('{:s}: {:.3f}+{:.3f}'.format(name, col_mean[name], col_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fcfe286a-b94c-4681-97a9-f9ae46d9d39b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.19it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.22it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.24it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.23it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.21it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the default checkpoint of STAGIN (GARO readout) on every CV fold.\n",
    "kf_num_splits = 5\n",
    "results, results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.STAGIN_GARO, f'iter{kf_split_id:d}')\n",
    "    fold_results, fold_results_n = evaluation_STAGIN(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_results)\n",
    "    results_n.append(fold_results_n)\n",
    "\n",
    "# one row per fold, one column per metric\n",
    "results = pd.DataFrame(results)\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eb4e9b7a-011a-4283-924a-5b9a2d8ca8da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task_Accuracy: 0.091+0.002\n",
      "RT: 116.130+2.099\n"
     ]
    }
   ],
   "source": [
    "# Mean and sample std of each metric in `results` across CV folds.\n",
    "col_mean = results.mean()\n",
    "col_std = results.std()\n",
    "for name in col_mean.index:\n",
    "    print('{:s}: {:.3f}+{:.3f}'.format(name, col_mean[name], col_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "096fad7a-7ed7-4ce6-ad6f-b6f11cef5829",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.23it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.29it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.28it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.26it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.28it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the default checkpoint of STAGIN (SERO readout) on every CV fold.\n",
    "kf_num_splits = 5\n",
    "results, results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.STAGIN_SERO, f'iter{kf_split_id:d}')\n",
    "    fold_results, fold_results_n = evaluation_STAGIN(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_results)\n",
    "    results_n.append(fold_results_n)\n",
    "\n",
    "# one row per fold, one column per metric\n",
    "results = pd.DataFrame(results)\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "17ef4f60-8ac3-4f33-a81f-08c3c1c0a870",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task_Accuracy: 0.089+0.003\n",
      "RT: 116.635+2.197\n"
     ]
    }
   ],
   "source": [
    "# Mean and sample std of each metric in `results` across CV folds.\n",
    "col_mean = results.mean()\n",
    "col_std = results.std()\n",
    "for name in col_mean.index:\n",
    "    print('{:s}: {:.3f}+{:.3f}'.format(name, col_mean[name], col_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b41904a-8c45-4165-bac3-be3aac60f555",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pyfmri]",
   "language": "python",
   "name": "conda-env-pyfmri-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
