{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "62469c51-1444-472c-b671-801b0ae093db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initializing ... \n",
      "generating json ...\n"
     ]
    }
   ],
   "source": [
    "import set_path\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import nibabel as nib\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "\n",
    "from scipy import signal, stats\n",
    "from scipy.interpolate import interp1d\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from settings import settings\n",
    "from utils import *\n",
    "from filters import filters\n",
    "from funcs import *\n",
    "\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "from datasets_transfer import HCPAgingTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from models_autoencoder import fMRIAutoEncoder, fMRIStateTransferModel\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from datasets_transfer import HCP3TTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from baseline_models import BrainNetworkTransformer\n",
    "from baseline_models import GraphTransformer\n",
    "from baseline_models import BrainNetCNN\n",
    "from baseline_models import FBNETGEN\n",
    "from baseline_models import ModelSTAGIN\n",
    "from baseline_models.STAGIN import process_dynamic_fc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "313c0275-97c7-4afd-9cfb-dc3e4fd6f375",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled `BaselineModels` object; the cells below read per-model, per-fold entries from it.\n",
    "# The file name and directory are both looked up in the project `settings`.\n",
    "varName = 'BaselineModels'\n",
    "fName = getattr(settings.projectData.files.general_HCPAging, varName)\n",
    "fPath = settings.projectData.dir.general_HCPAging / fName\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only open files from a trusted source.\n",
    "with open(fPath, 'rb') as handle:\n",
    "    BaselineModels = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1f2dc5b-11cc-4611-8f0b-67c1354af3da",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from einops import repeat\n",
    "\n",
    "def evaluation_STAGIN(modelInfo, checkpoint='best', kf_split_id=0, random_seed=42, kf_num_splits=None):\n",
    "    \"\"\"Evaluate a stored STAGIN checkpoint on the test split of one cross-validation fold.\n",
    "\n",
    "    Args:\n",
    "        modelInfo: entry exposing `args` (the stored run configuration) and\n",
    "            `model_files` (table with `checkpoint` and `file_path` columns).\n",
    "        checkpoint: value of the `checkpoint` column selecting the weights file.\n",
    "        kf_split_id: fold index forwarded to the dataset.\n",
    "        random_seed: seed forwarded to the dataset.\n",
    "        kf_num_splits: expected number of folds. When None, the notebook-level\n",
    "            `kf_num_splits` is used if it is defined; otherwise the check is skipped.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(results, results_n)` of dicts keyed by target name. Continuous\n",
    "        targets: MAE after undoing the target scaler (`results`) and that MAE\n",
    "        divided by sqrt(scaler.var_) (`results_n`). Discrete targets: accuracy in\n",
    "        both dicts, plus an `AUROC_<target>` entry in `results`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the dataset's fold count differs from the expected one.\n",
    "\n",
    "    Side effects:\n",
    "        Overwrites `args.training.sessIDs` with the training-split sessions.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Train and test splits share one configuration; only the mode set afterwards differs.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        output_fmri_size=args.training.output_fmri_size,\n",
    "        fmri_type=args.training.fmri_type,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        output_fc=True,\n",
    "        output_timeseries=True,\n",
    "        random_seed=random_seed,\n",
    "    )\n",
    "\n",
    "    # The training split is never iterated here: it only supplies the session list\n",
    "    # and the target scalers, so no DataLoader is built for it.\n",
    "    train_dataset = HCPAgingTransferDataset(settings, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "\n",
    "    # Fold-count sanity check. Falls back to the notebook-level `kf_num_splits` so\n",
    "    # existing call sites behave as before.\n",
    "    expected_splits = kf_num_splits if kf_num_splits is not None else globals().get('kf_num_splits')\n",
    "    if expected_splits is not None:\n",
    "        assert train_dataset.kf_num_splits == expected_splits, (\n",
    "            'dataset has {} folds, expected {}'.format(train_dataset.kf_num_splits, expected_splits))\n",
    "\n",
    "    test_dataset = HCPAgingTransferDataset(settings, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # Test targets are scaled with the scalers taken from the training split.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    model = ModelSTAGIN(\n",
    "        input_dim=args.model.num_nodes,\n",
    "        hidden_dim=args.model.hidden_dim,\n",
    "        num_classes=args.model.output_dim,\n",
    "        num_heads=args.model.num_heads,\n",
    "        num_layers=args.model.num_layers,\n",
    "        sparsity=args.model.sparsity,\n",
    "        dropout=args.model.dropout,\n",
    "        cls_token=args.model.cls_token,\n",
    "        readout=args.model.readout,\n",
    "    )\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # Load onto CPU first so a checkpoint saved on another device still restores;\n",
    "    # the model is moved to the target device right after.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location='cpu'), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    node_identity = torch.eye(args.model.num_nodes)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        progress = tqdm(test_dataloader,\n",
    "                        desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch))\n",
    "        for batch in progress:\n",
    "            # NOTE(review): presumably (batch, node, time) -> (batch, time, node); confirm against the dataset.\n",
    "            batch['timeseries'] = batch['timeseries'].permute(0,2,1)\n",
    "            # Windowed dynamic connectivity plus the start index of every window.\n",
    "            dyn_a, sampling_points = process_dynamic_fc(batch['timeseries'],\n",
    "                                        args.model.window_size, args.model.window_stride, args.model.dynamic_length)\n",
    "            sampling_endpoints = [p + args.model.window_size for p in sampling_points]\n",
    "\n",
    "            # Identity node features, one matrix per sample and window. Built per batch\n",
    "            # from the actual batch size so a short batch cannot leave a truncated\n",
    "            # tensor behind for later batches.\n",
    "            dyn_v = repeat(node_identity, 'n1 n2 -> b t n1 n2',\n",
    "                           t=len(sampling_points), b=len(dyn_a))\n",
    "            t = batch['timeseries'].permute(1,0,2)\n",
    "\n",
    "            # prediction\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            preds, _, _, _ = model(dyn_v.to(device),\n",
    "                                   dyn_a.to(device),\n",
    "                                   t.to(device),\n",
    "                                   sampling_endpoints)\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            # Continuous targets read a single output column.\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            # Discrete targets take the argmax over an inclusive column range.\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(preds[:,idxs[0]:idxs[1]+1].detach().cpu().numpy(), axis=1)\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        # ravel() instead of squeeze(): stays 1-D even when there is a single sample.\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results_n[tgt_name] = mae / np.sqrt(nscalar.var_)\n",
    "        results[tgt_name] = mae\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        y_pred = preds_all[tgt_name].values\n",
    "        acc = accuracy_score(y_true, y_pred)\n",
    "        results[tgt_name] = acc\n",
    "        # roc_auc_score takes (y_true, y_score): the ground truth must come first.\n",
    "        # The score passed here is the hard predicted label, not a class probability.\n",
    "        results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_pred)\n",
    "        results_n[tgt_name] = acc\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "536f3176-f21b-4daf-b61b-c8f97efae3d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def evaluation(modelInfo, checkpoint='best', kf_split_id=0, random_seed=42, kf_num_splits=None):\n",
    "    \"\"\"Evaluate a stored baseline checkpoint on the test split of one cross-validation fold.\n",
    "\n",
    "    The model class is resolved from `args.model.name` and called as\n",
    "    `model(batch['timeseries'], batch['fc'])`.\n",
    "\n",
    "    Args:\n",
    "        modelInfo: entry exposing `args` (the stored run configuration) and\n",
    "            `model_files` (table with `checkpoint` and `file_path` columns).\n",
    "        checkpoint: value of the `checkpoint` column selecting the weights file.\n",
    "        kf_split_id: fold index forwarded to the dataset.\n",
    "        random_seed: seed forwarded to the dataset.\n",
    "        kf_num_splits: expected number of folds. When None, the notebook-level\n",
    "            `kf_num_splits` is used if it is defined; otherwise the check is skipped.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(results, results_n)` of dicts keyed by target name. Continuous\n",
    "        targets: MAE after undoing the target scaler (`results`) and that MAE\n",
    "        divided by sqrt(scaler.var_) (`results_n`). Discrete targets: accuracy in\n",
    "        both dicts, plus an `AUROC_<target>` entry in `results`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the dataset's fold count differs from the expected one.\n",
    "\n",
    "    Side effects:\n",
    "        Overwrites `args.training.sessIDs` with the training-split sessions.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Train and test splits share one configuration; only the mode set afterwards differs.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        output_fmri_size=args.training.output_fmri_size,\n",
    "        fmri_type=args.training.fmri_type,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        output_fc=True,\n",
    "        output_timeseries=True,\n",
    "        random_seed=random_seed,\n",
    "    )\n",
    "\n",
    "    # The training split is never iterated here: it only supplies the session list\n",
    "    # and the target scalers, so no DataLoader is built for it.\n",
    "    train_dataset = HCPAgingTransferDataset(settings, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "\n",
    "    # Fold-count sanity check. Falls back to the notebook-level `kf_num_splits` so\n",
    "    # existing call sites behave as before.\n",
    "    expected_splits = kf_num_splits if kf_num_splits is not None else globals().get('kf_num_splits')\n",
    "    if expected_splits is not None:\n",
    "        assert train_dataset.kf_num_splits == expected_splits, (\n",
    "            'dataset has {} folds, expected {}'.format(train_dataset.kf_num_splits, expected_splits))\n",
    "\n",
    "    test_dataset = HCPAgingTransferDataset(settings, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # Test targets are scaled with the scalers taken from the training split.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    # NOTE(security): eval() turns a string from the unpickled config into code.\n",
    "    # Only run this on configurations from a trusted source.\n",
    "    model = eval(args.model.name)(args)\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # Load onto CPU first so a checkpoint saved on another device still restores;\n",
    "    # the model is moved to the target device right after.\n",
    "    model.load_state_dict(torch.load(mdl_path, map_location='cpu'), strict=True)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        progress = tqdm(test_dataloader,\n",
    "                        desc='test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch))\n",
    "        for batch in progress:\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "\n",
    "            preds = model(batch['timeseries'], batch['fc'])\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            # Continuous targets read a single output column.\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:,idx].detach().cpu().numpy()\n",
    "\n",
    "            # Discrete targets take the argmax over an inclusive column range.\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(preds[:,idxs[0]:idxs[1]+1].detach().cpu().numpy(), axis=1)\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        # ravel() instead of squeeze(): stays 1-D even when there is a single sample.\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1,1)).ravel()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results_n[tgt_name] = mae / np.sqrt(nscalar.var_)\n",
    "        results[tgt_name] = mae\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        y_pred = preds_all[tgt_name].values\n",
    "        acc = accuracy_score(y_true, y_pred)\n",
    "        results[tgt_name] = acc\n",
    "        # roc_auc_score takes (y_true, y_score): the ground truth must come first.\n",
    "        # The score passed here is the hard predicted label, not a class probability.\n",
    "        results['AUROC_{:s}'.format(tgt_name)] = roc_auc_score(y_true, y_pred)\n",
    "        results_n[tgt_name] = acc\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0603b555-ea2f-4883-8a3d-55f4e4be3ac4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.01it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.25it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.25it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.53it/s]\n",
      "/home/yzy161/anaconda3/envs/pyfmri/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.64it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the BrainNetworkTransformer checkpoints, one per cross-validation fold.\n",
    "kf_num_splits = 5\n",
    "fold_results, fold_results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.BrainNetworkTransformer, f'iter{kf_split_id:d}')\n",
    "    fold_metrics, fold_metrics_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_results.append(fold_metrics)\n",
    "    fold_results_n.append(fold_metrics_n)\n",
    "\n",
    "# One row per fold. 'age' is divided by 12 (presumably months -> years; TODO confirm).\n",
    "results = pd.DataFrame(fold_results).assign(age=lambda df: df.age / 12)\n",
    "results_n = pd.DataFrame(fold_results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6657f0c2-1044-4375-a2d0-38fe2cae34ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 6.1466+0.7136\n",
      "Gender: 0.9021+0.0381\n",
      "AUROC_Gender: 0.9073+0.0285\n"
     ]
    }
   ],
   "source": [
    "# Print every metric as mean+std over the folds (DataFrame.std uses ddof=1 by default).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, fold_mean[name], fold_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aec096c4-c60c-49fc-8da7-3592e4868a82",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 10.11it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.93it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.99it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 10.34it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.86it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the GraphTransformer checkpoints, one per cross-validation fold.\n",
    "kf_num_splits = 5\n",
    "fold_results, fold_results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.GraphTransformer, f'iter{kf_split_id:d}')\n",
    "    fold_metrics, fold_metrics_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_results.append(fold_metrics)\n",
    "    fold_results_n.append(fold_metrics_n)\n",
    "\n",
    "# One row per fold. 'age' is divided by 12 (presumably months -> years; TODO confirm).\n",
    "results = pd.DataFrame(fold_results).assign(age=lambda df: df.age / 12)\n",
    "results_n = pd.DataFrame(fold_results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "46833635-3662-4651-aaa3-0baaf6475a58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 6.7785+0.5590\n",
      "Gender: 0.8896+0.0216\n",
      "AUROC_Gender: 0.8876+0.0221\n"
     ]
    }
   ],
   "source": [
    "# Print every metric as mean+std over the folds (DataFrame.std uses ddof=1 by default).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, fold_mean[name], fold_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fd470798-194f-4636-b2db-c1958ead3942",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.80it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.62it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.74it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.63it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 10.02it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the FBNETGEN checkpoints, one per cross-validation fold.\n",
    "kf_num_splits = 5\n",
    "fold_results, fold_results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.FBNETGEN, f'iter{kf_split_id:d}')\n",
    "    fold_metrics, fold_metrics_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_results.append(fold_metrics)\n",
    "    fold_results_n.append(fold_metrics_n)\n",
    "\n",
    "# One row per fold. 'age' is divided by 12 (presumably months -> years; TODO confirm).\n",
    "results = pd.DataFrame(fold_results).assign(age=lambda df: df.age / 12)\n",
    "results_n = pd.DataFrame(fold_results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eafdd7e3-c3a7-4d02-b8ec-ddda82796f2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 6.6806+1.0033\n",
      "Gender: 0.8950+0.0358\n",
      "AUROC_Gender: 0.8934+0.0349\n"
     ]
    }
   ],
   "source": [
    "# Print every metric as mean+std over the folds (DataFrame.std uses ddof=1 by default).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, fold_mean[name], fold_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f49d9b6c-86b5-45a0-ade8-a66b74235594",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.30it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.90it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  7.18it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  6.76it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  7.09it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the BrainNetCNN checkpoints, one per cross-validation fold.\n",
    "kf_num_splits = 5\n",
    "fold_results, fold_results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.BrainNetCNN, f'iter{kf_split_id:d}')\n",
    "    fold_metrics, fold_metrics_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_results.append(fold_metrics)\n",
    "    fold_results_n.append(fold_metrics_n)\n",
    "\n",
    "# One row per fold. 'age' is divided by 12 (presumably months -> years; TODO confirm).\n",
    "results = pd.DataFrame(fold_results).assign(age=lambda df: df.age / 12)\n",
    "results_n = pd.DataFrame(fold_results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "17d61024-956d-4414-a52f-4e82a86aa793",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 8.7550+0.6938\n",
      "Gender: 0.8883+0.0152\n",
      "AUROC_Gender: 0.8874+0.0158\n"
     ]
    }
   ],
   "source": [
    "# Print every metric as mean+std over the folds (DataFrame.std uses ddof=1 by default).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, fold_mean[name], fold_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6f205472-17d0-45c6-851f-02d172c69490",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.33it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.31it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.30it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the STAGIN_GARO checkpoints, one per cross-validation fold.\n",
    "kf_num_splits = 5\n",
    "fold_results, fold_results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.STAGIN_GARO, f'iter{kf_split_id:d}')\n",
    "    fold_metrics, fold_metrics_n = evaluation_STAGIN(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_results.append(fold_metrics)\n",
    "    fold_results_n.append(fold_metrics_n)\n",
    "\n",
    "# One row per fold. 'age' is divided by 12 (presumably months -> years; TODO confirm).\n",
    "results = pd.DataFrame(fold_results).assign(age=lambda df: df.age / 12)\n",
    "results_n = pd.DataFrame(fold_results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d0794bd4-8738-4c8d-918f-a4004db15664",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 8.5961+0.3725\n",
      "Gender: 0.8067+0.0081\n",
      "AUROC_Gender: 0.8058+0.0103\n"
     ]
    }
   ],
   "source": [
    "# Print every metric as mean+std over the folds (DataFrame.std uses ddof=1 by default).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, fold_mean[name], fold_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b0c91919-95dc-4c29-bf83-44baafa6bfd6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.32it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.31it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.29it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.29it/s]\n",
      "test epoch [100|100]: 100%|█████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the STAGIN_SERO checkpoints, one per cross-validation fold.\n",
    "kf_num_splits = 5\n",
    "fold_results, fold_results_n = [], []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(BaselineModels.STAGIN_SERO, f'iter{kf_split_id:d}')\n",
    "    fold_metrics, fold_metrics_n = evaluation_STAGIN(modelInfo, kf_split_id=kf_split_id)\n",
    "    fold_results.append(fold_metrics)\n",
    "    fold_results_n.append(fold_metrics_n)\n",
    "\n",
    "# One row per fold. 'age' is divided by 12 (presumably months -> years; TODO confirm).\n",
    "results = pd.DataFrame(fold_results).assign(age=lambda df: df.age / 12)\n",
    "results_n = pd.DataFrame(fold_results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8f6546d7-bfad-41b5-83d2-e42d7b5d341b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 8.9209+0.2937\n",
      "Gender: 0.8237+0.0166\n",
      "AUROC_Gender: 0.8257+0.0136\n"
     ]
    }
   ],
   "source": [
    "# Print every metric as mean+std over the folds (DataFrame.std uses ddof=1 by default).\n",
    "fold_mean = results.mean()\n",
    "fold_std = results.std()\n",
    "for name in fold_mean.index:\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, fold_mean[name], fold_std[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "404bcef8-6438-43ef-870a-45adde47dd5b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pyfmri]",
   "language": "python",
   "name": "conda-env-pyfmri-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
