{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7a7375e7",
   "metadata": {},
   "source": [
    "This notebook reproduces any of the experiments used to obtain the empirical scaling law for compressive sensing in accelerated MRI, as reported in _Section 4: Empirical scaling laws for compressive sensing_ of the paper **Scaling Laws For Deep Learning Based Image Reconstruction**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5609530c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Select the GPU before importing anything that may initialise CUDA (the fastmri\n",
    "# helpers presumably import torch): CUDA reads this variable only when it is first\n",
    "# initialised, so setting it afterwards can be silently ignored.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Explicit names instead of `import *`, so it is clear where each helper comes from\n",
    "from fastmri.main_functions_helpers import (\n",
    "    build_args,\n",
    "    cli_main,\n",
    "    create_fastmri_dirs_yaml,\n",
    "    evaluate_reconstructions,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f8780f2",
   "metadata": {},
   "source": [
    "Specify which experiments to run by giving the training set size and the network size of each one. The corresponding hyperparameters are loaded automatically.\n",
    "\n",
    "For all available combinations of training set size and network size, see the config files in `options/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91573332",
   "metadata": {},
   "outputs": [],
   "source": [
    "#######################################################\n",
    "# Adjust the following parameters\n",
    "#######################################################\n",
    "# Start or continue training\n",
    "training = True\n",
    "# Evaluate last and best checkpoint on validation and test set\n",
    "testing = True\n",
    "\n",
    "# Assign an ID to each experiment\n",
    "exp_nums = ['001','002']\n",
    "# Path to fastMRI brain directory containing both the training and validation set.\n",
    "# NOTE(review): the value includes a 'brain_path: ' prefix, presumably because\n",
    "# create_fastmri_dirs_yaml writes it verbatim into a YAML file - confirm before changing it.\n",
    "path_to_fastMRI_brain_dataset = \"brain_path: ../../../media/ssd1/fastMRIdata/brain\"\n",
    "# training set size\n",
    "train_sizes = [50,100]\n",
    "# network size defined by the number of channels in the first layer\n",
    "channels = [64,128]\n",
    "\n",
    "########################################################\n",
    "\n",
    "# Sanity check: experiment i is described by exp_nums[i], train_sizes[i] and\n",
    "# channels[i], so the three lists must have the same length.\n",
    "if not len(exp_nums) == len(train_sizes) == len(channels):\n",
    "    raise ValueError(\n",
    "        \"Specify an experiment ID, a training set size and a channel count for each experiment \"\n",
    "        \"(got {} IDs, {} training set sizes, {} channel counts)\".format(\n",
    "            len(exp_nums), len(train_sizes), len(channels)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d2da898",
   "metadata": {},
   "source": [
    "Load the hyperparameter configuration of each experiment from `options/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf04a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_hyperparameters(train_size, channel):\n",
    "    \"\"\"Read the JSON hyperparameter file options/trainsize<train_size>_channels<channel>.txt.\"\"\"\n",
    "    options_file = \"options/trainsize{}_channels{}.txt\".format(train_size, channel)\n",
    "    with open(options_file) as handle:\n",
    "        return json.load(handle)\n",
    "\n",
    "\n",
    "# One hyperparameter configuration per experiment, in the same order as exp_nums\n",
    "hps = [load_hyperparameters(size, chans) for size, chans in zip(train_sizes, channels)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce7bbbfd",
   "metadata": {},
   "source": [
    "Run training and/or testing for each experiment, as selected by the `training` and `testing` flags."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc79d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train and/or evaluate every configured experiment; each experiment may have several runs.\n",
    "\n",
    "def make_exp_name(exp_num, hp, run):\n",
    "    \"\"\"Return the experiment directory name 'E<id>_t<num_examples>_l4c<chans>_bs1_lr001'.\n",
    "\n",
    "    The '_l4', '_bs1' and '_lr001' parts are fixed labels, not read from `hp`.\n",
    "    Every run after the first gets a '_run<k>' suffix with k = run + 1.\n",
    "    \"\"\"\n",
    "    name = 'E{}_t{}_l4c{}_bs1_lr001'.format(exp_num, hp['num_examples'][0], hp['chans'][0])\n",
    "    if run > 0:\n",
    "        name += '_run{}'.format(run + 1)\n",
    "    return name\n",
    "\n",
    "\n",
    "def checkpoint_epoch(ckpt_path):\n",
    "    \"\"\"Return the text between 'epoch=' and the following '-step' in a checkpoint path.\n",
    "\n",
    "    Returns None when the path contains no 'epoch=' marker, so that a malformed\n",
    "    path does not produce a garbled metrics filename.\n",
    "    \"\"\"\n",
    "    _, found, rest = str(ckpt_path).partition('epoch=')\n",
    "    if not found:\n",
    "        return None\n",
    "    return rest.partition('-step')[0]\n",
    "\n",
    "\n",
    "def metrics_path(exp_name, test_mode, which_checkpoint, epoch=None):\n",
    "    \"\"\"Return the metrics .pkl path handed to evaluate_reconstructions.\n",
    "\n",
    "    'test_on_val' / 'test_on_test' are shortened to 'val' / 'test' in the name;\n",
    "    if `epoch` is given, an '_<epoch>ep' suffix is inserted before '.pkl'.\n",
    "    \"\"\"\n",
    "    short_mode = test_mode[len('test_on_'):] if test_mode in ('test_on_val', 'test_on_test') else test_mode\n",
    "    stem = './{0}/log_files/metrics_{0}_{1}_{2}'.format(exp_name, short_mode, which_checkpoint)\n",
    "    if epoch is not None:\n",
    "        stem += '_{}ep'.format(epoch)\n",
    "    return stem + '.pkl'\n",
    "\n",
    "\n",
    "TEST_MODES = [\"test_on_val\", \"test_on_test\"]\n",
    "CHECKPOINTS = [\"last\", \"best\"]\n",
    "\n",
    "# Guard against stale kernel state: hps must have been loaded for the current experiment list.\n",
    "if len(hps) != len(exp_nums):\n",
    "    raise ValueError(\"Re-run the cell that loads the hyperparameters \"\n",
    "                     \"({} configs loaded for {} experiments)\".format(len(hps), len(exp_nums)))\n",
    "\n",
    "for exp_num, hp in zip(exp_nums, hps):\n",
    "    # int() so run indices are plain ints even if num_runs is stored as a float\n",
    "    for rr in range(int(hp['num_runs'][0])):\n",
    "        exp_name = make_exp_name(exp_num, hp, rr)\n",
    "        os.makedirs('./' + exp_name, exist_ok=True)\n",
    "        create_fastmri_dirs_yaml(path_to_fastMRI_brain_dataset, exp_name)\n",
    "\n",
    "        ########\n",
    "        # Training\n",
    "        ########\n",
    "        if training:\n",
    "            print('\\n{} - Training\\n'.format(exp_name))\n",
    "            args = build_args(hp, rr)\n",
    "            cli_main(args)\n",
    "            print('\\n{} - Training finished\\n'.format(exp_name))\n",
    "\n",
    "        ########\n",
    "        # Testing\n",
    "        ########\n",
    "        if testing:\n",
    "            print('\\n{} - Testing\\n'.format(exp_name))\n",
    "            for test_mode in TEST_MODES:\n",
    "                for resume_from_which_checkpoint in CHECKPOINTS:\n",
    "                    # NOTE(review): resume_from_which_checkpoint is not passed to build_args, so unless\n",
    "                    # build_args picks the checkpoint some other way, 'last' and 'best' evaluate the\n",
    "                    # same checkpoint under different file names - confirm against build_args.\n",
    "                    args = build_args(hp, rr, test_mode)\n",
    "                    args.mode = \"test\"\n",
    "                    args.logger = False\n",
    "                    # NOTE(review): both test modes read multicoil_val; presumably test_mode selects\n",
    "                    # the subset inside build_args - confirm.\n",
    "                    args.test_path = args.data_path / \"multicoil_val\"\n",
    "                    cli_main(args)\n",
    "                    epoch = None\n",
    "                    if resume_from_which_checkpoint == \"best\":\n",
    "                        # Tag the best-checkpoint metrics with the epoch parsed from its checkpoint path\n",
    "                        epoch = checkpoint_epoch(args.resume_from_checkpoint)\n",
    "                    metrics_filename = metrics_path(exp_name, test_mode, resume_from_which_checkpoint, epoch)\n",
    "                    evaluate_reconstructions(test_mode, metrics_filename)\n",
    "\n",
    "            print('\\n{} - Testing finished\\n'.format(exp_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22691ea3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0c9d57",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
