{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "62469c51-1444-472c-b671-801b0ae093db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initializing ... \n",
      "generating json ...\n"
     ]
    }
   ],
   "source": [
    "import set_path\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import nibabel as nib\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "\n",
    "from scipy import signal, stats\n",
    "from scipy.interpolate import interp1d\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from settings import settings\n",
    "from utils import *\n",
    "from filters import filters\n",
    "from funcs import *\n",
    "\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "from datasets_transfer import HCPAgingTransferDataset\n",
    "from model_utils import CosineWarmupScheduler\n",
    "from models_autoencoder import fMRIAutoEncoder, fMRIStateTransferModel\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ee5a7893-5673-44ce-a093-72bde0497ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the pickle holding the transferred-state model registry from the project settings.\n",
    "varName = 'TransferredStateModels'\n",
    "fName = getattr(settings.projectData.files.general_HCPAging, varName)\n",
    "fPath = settings.projectData.dir.general_HCPAging / fName\n",
    "\n",
    "# pickle.load can execute arbitrary code: only load files produced by this project.\n",
    "with open(fPath, mode='rb') as handle:\n",
    "    TransferredStateModels = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "313c0275-97c7-4afd-9cfb-dc3e4fd6f375",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(modelInfo, checkpoint='best', kf_split_id=0, random_seed=43, kf_num_splits=5):\n",
    "    \"\"\"Evaluate one trained fMRIStateTransferModel on the test partition of a K-fold split.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    modelInfo : object\n",
    "        Exposes ``args`` (with ``training`` and ``model`` sections) and ``model_files``,\n",
    "        a DataFrame with ``checkpoint`` and ``file_path`` columns.\n",
    "    checkpoint : str\n",
    "        Value of the ``checkpoint`` column selecting which weights file to load.\n",
    "    kf_split_id : int\n",
    "        Index of the K-fold split whose test partition is evaluated.\n",
    "    random_seed : int\n",
    "        Seed for ``pl.seed_everything``; also forwarded to both datasets.\n",
    "    kf_num_splits : int\n",
    "        Expected number of K-fold splits; asserted against the training dataset.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    results : dict\n",
    "        Continuous targets -> MAE in de-standardized units; discrete targets -> accuracy,\n",
    "        plus ``AUROC_{name}`` computed from softmax class scores.\n",
    "    results_n : dict\n",
    "        Continuous targets -> MAE divided by the scaler std; discrete targets -> accuracy.\n",
    "    \"\"\"\n",
    "    args = modelInfo.args\n",
    "    device = args.training.device\n",
    "\n",
    "    # Seed the global RNGs via Lightning and force deterministic cuDNN kernels for reproducibility.\n",
    "    pl.seed_everything(random_seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # Both datasets share every option except the segment length and the split they expose.\n",
    "    dataset_kwargs = dict(\n",
    "        region_roi='Yeo100Parc',\n",
    "        output_continuous_targets=args.training.output_continuous_targets,\n",
    "        output_discrete_targets=args.training.output_discrete_targets,\n",
    "        standardized_output=True,\n",
    "        fmri_type=args.training.fmri_type,\n",
    "        kf_split_id=kf_split_id,\n",
    "        overlapping_segments=args.training.overlapping_segments,\n",
    "        normalize_fmri=args.training.normalize_fmri,\n",
    "        random_seed=random_seed,\n",
    "    )\n",
    "\n",
    "    # The training split is only used here for its target scalers (tgt_norms) and session list.\n",
    "    train_dataset = HCPAgingTransferDataset(settings, output_fmri_size=args.training.output_fmri_size, **dataset_kwargs)\n",
    "    train_dataset.train()\n",
    "    # NOTE(review): mutates modelInfo.args in place; nothing in this function reads sessIDs.\n",
    "    args.training.sessIDs = [(row.SUBJECT, row.SESSION) for _, row in train_dataset.train_data_infos.iterrows()]\n",
    "    assert train_dataset.kf_num_splits == kf_num_splits, 'dataset has {} K-fold splits, expected {}'.format(train_dataset.kf_num_splits, kf_num_splits)\n",
    "\n",
    "    # NOTE(review): 375 is a hard-coded test segment length, independent of args.training.output_fmri_size.\n",
    "    test_dataset = HCPAgingTransferDataset(settings, output_fmri_size=375, **dataset_kwargs)\n",
    "    test_dataset.test()\n",
    "    # De-standardize test targets/predictions with the training-set scalers.\n",
    "    test_dataset.tgt_norms = train_dataset.tgt_norms\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=8)\n",
    "\n",
    "    model = fMRIStateTransferModel(args.model)\n",
    "    mdl_path = modelInfo.model_files.set_index('checkpoint').loc[checkpoint].file_path\n",
    "    # Load onto the CPU first so checkpoints saved on another device still load; moved to `device` below.\n",
    "    load_msg = model.load_state_dict(torch.load(mdl_path, map_location='cpu'), strict=False)\n",
    "    if load_msg.missing_keys:\n",
    "        # strict=False would otherwise silently leave these parameters at their initial values.\n",
    "        warnings.warn('checkpoint {} did not provide parameters: {}'.format(mdl_path, load_msg.missing_keys))\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    preds_all = []\n",
    "    tgt_all = []\n",
    "    # Softmax class scores per discrete target, needed for AUROC.\n",
    "    probs_all = {tgt_name: [] for tgt_name in args.training.output_discrete_targets}\n",
    "    desc = 'test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader, desc=desc):\n",
    "            batch = {key: item.to(device) for key, item in batch.items()}\n",
    "            # No input masking at evaluation time.\n",
    "            preds = model.forward(batch['fmri_segs'], mask_ratio=0)\n",
    "\n",
    "            _tgts = {}\n",
    "            _preds = {}\n",
    "            for tgt_name in args.training.output_continuous_targets:\n",
    "                idx = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = preds[:, idx].detach().cpu().numpy()\n",
    "\n",
    "            for tgt_name in args.training.output_discrete_targets:\n",
    "                # head_tgt2dims gives an inclusive [first, last] range of logit columns for this target.\n",
    "                idxs = getattr(args.model.head_tgt2dims, tgt_name)\n",
    "                logits = preds[:, idxs[0]:idxs[1] + 1].detach()\n",
    "                _tgts[tgt_name] = batch[tgt_name].detach().cpu().numpy()\n",
    "                _preds[tgt_name] = np.argmax(logits.cpu().numpy(), axis=1)\n",
    "                probs_all[tgt_name].append(torch.softmax(logits, dim=1).cpu().numpy())\n",
    "\n",
    "            preds_all.append(pd.DataFrame(_preds))\n",
    "            tgt_all.append(pd.DataFrame(_tgts))\n",
    "\n",
    "    preds_all = pd.concat(preds_all, axis=0, ignore_index=True)\n",
    "    tgt_all = pd.concat(tgt_all, axis=0, ignore_index=True)\n",
    "\n",
    "    results = {}\n",
    "    results_n = {}\n",
    "    for tgt_name in args.training.output_continuous_targets:\n",
    "        nscalar = test_dataset.tgt_norms[tgt_name]\n",
    "        tgt = nscalar.inverse_transform(tgt_all[tgt_name].values.reshape(-1, 1)).squeeze()\n",
    "        pred = nscalar.inverse_transform(preds_all[tgt_name].values.reshape(-1, 1)).squeeze()\n",
    "        mae = mean_absolute_error(tgt, pred)\n",
    "        results[tgt_name] = mae\n",
    "        # var_ may be a per-feature array (as in sklearn's StandardScaler); take its single entry so the\n",
    "        # normalized MAE is a plain float rather than a 1-element array.\n",
    "        results_n[tgt_name] = mae / float(np.sqrt(np.ravel(nscalar.var_)[0]))\n",
    "\n",
    "    for tgt_name in args.training.output_discrete_targets:\n",
    "        y_true = tgt_all[tgt_name].values\n",
    "        y_pred = preds_all[tgt_name].values\n",
    "        probs = np.concatenate(probs_all[tgt_name], axis=0)\n",
    "        acc = accuracy_score(y_true, y_pred)\n",
    "        results[tgt_name] = acc\n",
    "        results_n[tgt_name] = acc\n",
    "        # roc_auc_score expects (y_true, y_score) with continuous scores; argmax labels would\n",
    "        # collapse the ROC curve to a single operating point.\n",
    "        if probs.shape[1] == 2:\n",
    "            auroc = roc_auc_score(y_true, probs[:, 1])\n",
    "        else:\n",
    "            auroc = roc_auc_score(y_true, probs, multi_class='ovr', labels=np.arange(probs.shape[1]))\n",
    "        results['AUROC_{:s}'.format(tgt_name)] = auroc\n",
    "\n",
    "    return results, results_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a624c934-d3cb-4311-b9e9-18e398875a6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:03<00:00,  4.06it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  7.97it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.06it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.09it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.02it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every CV fold of the AE_MaskTSN_TRANSFERv3 variant and collect per-fold metrics.\n",
    "kf_num_splits = 5\n",
    "model_group = TransferredStateModels.AE_MaskTSN_TRANSFERv3\n",
    "results = []\n",
    "results_n = []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_group, 'iter{:d}'.format(kf_split_id))\n",
    "    fold_res, fold_res_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_res)\n",
    "    results_n.append(fold_res_n)\n",
    "\n",
    "results = pd.DataFrame(results)\n",
    "# Presumably age is stored in months; /12 reports the MAE in years (TODO confirm).\n",
    "results['age'] = results['age'] / 12\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2eb1a30a-b776-49ac-8949-0ad44d4ad175",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 5.7883+0.3957\n",
      "Gender: 0.9267+0.0107\n",
      "AUROC_Gender: 0.9251+0.0107\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std across the CV folds (pandas std uses ddof=1).\n",
    "dmean = results.mean().to_dict()\n",
    "dstd = results.std().to_dict()\n",
    "for name, mean_val, std_val in ((key, dmean[key], dstd[key]) for key in dmean):\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, mean_val, std_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e1c3ae2e-5f69-4403-8d04-8bd6e8708f88",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.52it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.38it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.57it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.50it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.50it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every CV fold of the AE_DynamicMaskTSN_TRANSFERv3 variant and collect per-fold metrics.\n",
    "kf_num_splits = 5\n",
    "model_group = TransferredStateModels.AE_DynamicMaskTSN_TRANSFERv3\n",
    "results = []\n",
    "results_n = []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_group, 'iter{:d}'.format(kf_split_id))\n",
    "    fold_res, fold_res_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_res)\n",
    "    results_n.append(fold_res_n)\n",
    "\n",
    "results = pd.DataFrame(results)\n",
    "# Presumably age is stored in months; /12 reports the MAE in years (TODO confirm).\n",
    "results['age'] = results['age'] / 12\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "58e28a1a-e210-4587-ab4b-d590d1619719",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 6.5052+0.9630\n",
      "Gender: 0.9112+0.0199\n",
      "AUROC_Gender: 0.9115+0.0203\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std across the CV folds (pandas std uses ddof=1).\n",
    "dmean = results.mean().to_dict()\n",
    "dstd = results.std().to_dict()\n",
    "for name, mean_val, std_val in ((key, dmean[key], dstd[key]) for key in dmean):\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, mean_val, std_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8bba76fa-49c2-4246-9491-3464f938da80",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.48it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.55it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.57it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.58it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.41it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every CV fold of the AE_VanillaTSN_TRANSFERv1 variant and collect per-fold metrics.\n",
    "kf_num_splits = 5\n",
    "model_group = TransferredStateModels.AE_VanillaTSN_TRANSFERv1\n",
    "results = []\n",
    "results_n = []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_group, 'iter{:d}'.format(kf_split_id))\n",
    "    fold_res, fold_res_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_res)\n",
    "    results_n.append(fold_res_n)\n",
    "\n",
    "results = pd.DataFrame(results)\n",
    "# Presumably age is stored in months; /12 reports the MAE in years (TODO confirm).\n",
    "results['age'] = results['age'] / 12\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ed1b6e77-a9f7-465e-94f7-8ed18e9393da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 7.2624+0.9458\n",
      "Gender: 0.8854+0.0250\n",
      "AUROC_Gender: 0.8853+0.0237\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std across the CV folds (pandas std uses ddof=1).\n",
    "dmean = results.mean().to_dict()\n",
    "dstd = results.std().to_dict()\n",
    "for name, mean_val, std_val in ((key, dmean[key], dstd[key]) for key in dmean):\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, mean_val, std_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "54c4adf8-c796-4fd9-a5e1-1c9b991dcd31",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.41it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.56it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.50it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.51it/s]\n",
      "Global seed set to 43\n",
      "test epoch [250|250]: 100%|███████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  8.57it/s]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every CV fold of the AE_VanillaTSN_TRANSFERv2 variant and collect per-fold metrics.\n",
    "kf_num_splits = 5\n",
    "model_group = TransferredStateModels.AE_VanillaTSN_TRANSFERv2\n",
    "results = []\n",
    "results_n = []\n",
    "for kf_split_id in range(kf_num_splits):\n",
    "    modelInfo = getattr(model_group, 'iter{:d}'.format(kf_split_id))\n",
    "    fold_res, fold_res_n = evaluation(modelInfo, kf_split_id=kf_split_id)\n",
    "    results.append(fold_res)\n",
    "    results_n.append(fold_res_n)\n",
    "\n",
    "results = pd.DataFrame(results)\n",
    "# Presumably age is stored in months; /12 reports the MAE in years (TODO confirm).\n",
    "results['age'] = results['age'] / 12\n",
    "results_n = pd.DataFrame(results_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bab0b772-9a9a-45db-b875-6a4bd8fb2f46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age: 8.3302+0.4930\n",
      "Gender: 0.8092+0.0240\n",
      "AUROC_Gender: 0.8103+0.0252\n"
     ]
    }
   ],
   "source": [
    "# Report each metric as mean+std across the CV folds (pandas std uses ddof=1).\n",
    "dmean = results.mean().to_dict()\n",
    "dstd = results.std().to_dict()\n",
    "for name, mean_val, std_val in ((key, dmean[key], dstd[key]) for key in dmean):\n",
    "    print('{:s}: {:.4f}+{:.4f}'.format(name, mean_val, std_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d6d5cd-57c0-487e-8269-6140c0ff5ac1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pyfmri]",
   "language": "python",
   "name": "conda-env-pyfmri-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
