{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2ad6342b",
   "metadata": {},
   "source": [
    "This notebook can be used to reproduce any of the experiments that were conducted to obtain the empirical scaling law for image denoising in _Section 3: Empirical scaling laws for denoising_ from the paper **Scaling Laws For Deep Learning Based Image Reconstruction**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cb3ae67",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "\n",
    "from utils.main_function_helpers import *\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b367458",
   "metadata": {},
   "source": [
    "Specify which experiments to run by indicating training set size and network size. Corresponding hyperparameters are loaded automatically.\n",
    "\n",
    "For all available combinations of training set and network size see the config files in options/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd7e6ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#######################################################\n",
    "# Adjust the following parameters\n",
    "#######################################################\n",
    "# Start or continue training\n",
    "training = True\n",
    "# Evaluate last and best checkpoint on validation set\n",
    "val_testing = True\n",
    "# Evaluate last and best checkpoint on test set\n",
    "test_testing = True\n",
    "\n",
    "\n",
    "# Assign an ID to the experiment\n",
    "exp_nums = ['001','002'] \n",
    "# Path to ImageNet train directory\n",
    "path_to_ImageNet_train = '../../../media/hdd1/ImageNet/ILSVRC/Data/CLS-LOC/'\n",
    "# training set size\n",
    "train_sizes = [300,3000]\n",
    "# network size defined by the number of channels in the first layer\n",
    "channels = [64,128]\n",
    "\n",
    "########################################################\n",
    "\n",
    "# Sanity checks\n",
    "if len(train_sizes) != len(exp_nums) or len(channels) != len(exp_nums):\n",
    "    raise ValueError(\"Specify experiment ID for each experiment\") "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1d581bb",
   "metadata": {},
   "source": [
    "Load hyperparameter configurations for each experiment from options/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "980500fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "hps = []\n",
    "for train_size,channel in zip(train_sizes, channels):\n",
    "    options_name = \"options/trainsize{}_channels{}.txt\".format(train_size,channel)\n",
    "\n",
    "    # Load hyperparameter options\n",
    "    with open(options_name) as handle:\n",
    "        hp = json.load(handle)\n",
    "    hp['path_to_ImageNet_train'] = [path_to_ImageNet_train]\n",
    "    hps.append(hp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "146f01fe",
   "metadata": {},
   "source": [
    "Run training/testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22aafcba",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for ee in range(len(exp_nums)):\n",
    "    \n",
    "    hp = hps[ee]\n",
    "    num_runs = list(np.arange(hp['num_runs'][0]))\n",
    "    \n",
    "    for rr in num_runs:\n",
    "        exp_name =  'E' + exp_nums[ee] + \\\n",
    "                    '_t' + str(hp['train_size'][0]) + \\\n",
    "                    '_l' + str(hp['num_pool_layers'][0]) + \\\n",
    "                    'c' + str(hp['chans'][0]) + \\\n",
    "                    '_bs' + str(hp['batch_size'][0]) +\\\n",
    "                    '_lr' + str(hp['lr'][0])[2:]\n",
    "        if rr>0:\n",
    "            exp_name = exp_name + '_run{}'.format(rr+1)\n",
    "        if not os.path.isdir('./'+exp_name):\n",
    "            os.mkdir('./'+exp_name)\n",
    "        \n",
    "        ########\n",
    "        # Training\n",
    "        ########\n",
    "        if training:  \n",
    "            print('\\n{} - Training\\n'.format(exp_name))\n",
    "            args = get_args(hp,rr)\n",
    "            args.output_dir = './'+exp_name\n",
    "            cli_main(args)\n",
    "            print('\\n{} - Training finished\\n'.format(exp_name))\n",
    "            \n",
    "        ########\n",
    "        # Testing\n",
    "        ########\n",
    "        if val_testing or test_testing:\n",
    "            print('\\n{} - Testing\\n'.format(exp_name))\n",
    "\n",
    "            test_modes = []\n",
    "            if val_testing:\n",
    "                test_modes.append(\"val\")\n",
    "            if test_testing:\n",
    "                test_modes.append(\"test\")\n",
    "\n",
    "            for test_mode in test_modes:\n",
    "                for restore_mode in [\"last\",\"best\"]:\n",
    "                    args = get_args(hp,rr)\n",
    "                    args.output_dir = './'+exp_name\n",
    "                    args.restore_mode = restore_mode\n",
    "                    args.test_mode = test_mode\n",
    "                    args.test_noise_std_min = args.noise_std\n",
    "                    args.test_noise_std_max = args.noise_std\n",
    "                    cli_main_test(args)\n",
    "\n",
    "            print('\\n{} - Testing finished\\n'.format(exp_name))           "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9c5137",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
