{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload\n",
    "import sys\n",
    "sys.path.insert(0, \"../\")\n",
    "sys.path.insert(0, \"../../\")\n",
    "\n",
    "import warnings\n",
    "import os\n",
    "import time\n",
    "\n",
    "import math as m\n",
    "import numpy as np\n",
    "np.random.seed(1)\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import keras\n",
    "import keras.backend as K\n",
    "from keras.callbacks import LearningRateScheduler\n",
    "from keras.optimizers import SGD, Adam\n",
    "from utils_training import history_todict, lr_schedule, StoppingCriteria\n",
    "\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "from save_results import ResultSaver\n",
    "\n",
    "from tasks import load_task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# storing results: one ResultSaver per task, backed by '<task>_model_performances.p'\n",
    "result_saver = {task: ResultSaver(task + '_model_performances.p')\n",
    "                for task in ['C100coarse-WRN', 'C100-WRN', 'C10-VGG']}\n",
    "\n",
    "# directory for trained weights; exist_ok makes this cell safe to re-run\n",
    "os.makedirs(\"saved_weights\", exist_ok=True)\n",
    "\n",
    "# file for monitoring the experiment's progress (appended to by the training loop)\n",
    "monitor_file = 'monitor.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_and_evaluate(task, param):\n",
    "    \"\"\"Train one hyperparameter configuration on `task` and evaluate it on the test set.\n",
    "\n",
    "    Args:\n",
    "        task: task name passed to `load_task` (e.g. 'C100-WRN').\n",
    "        param: hyperparameter dict with keys 'depth_factor', 'width_factor', 'dropout',\n",
    "            'weight_decay', 'lr', 'optimizer' ('adam' selects Adam, anything else SGD\n",
    "            with momentum 0.9), 'batch_size' and 'data_augmentation'.\n",
    "\n",
    "    Returns:\n",
    "        (model, history, test_performance): the trained Keras model, the History\n",
    "        object returned by `fit_generator`, and the result of `model.evaluate` on\n",
    "        the test set (loss followed by accuracy).\n",
    "    \"\"\"\n",
    "    # the two sub-label arrays returned by load_task are not used here\n",
    "    get_model, x_train, y_train, _, x_test, y_test, _ = load_task(task)\n",
    "    model = get_model(param['depth_factor'], param['width_factor'],\n",
    "                      param['dropout'], param['weight_decay'])\n",
    "\n",
    "    # Adam is run with a base learning rate 100x smaller than SGD's\n",
    "    lr = param['lr'] / 100. if param['optimizer'] == 'adam' else param['lr']\n",
    "    batch_size = param['batch_size']\n",
    "\n",
    "    epochs = 250\n",
    "    # NOTE(review): lr_schedule presumably decays lr by a factor 0.2 at epochs 150, 230\n",
    "    # and 240; StoppingCriteria is a project callback -- see utils_training for its rule.\n",
    "    lr_sched = LearningRateScheduler(lr_schedule(lr, 0.2, [150, 230, 240]))\n",
    "    stop = StoppingCriteria(finished=1e-4)\n",
    "\n",
    "    optimizer = Adam(lr) if param['optimizer'] == 'adam' else SGD(lr, momentum=0.9)\n",
    "    model.compile(loss='categorical_crossentropy',\n",
    "                  optimizer=optimizer,\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    if param['data_augmentation']:\n",
    "        # random shifts of up to 12.5% of width/height (reflect-padded) plus horizontal flips\n",
    "        datagen = ImageDataGenerator(width_shift_range=0.125,\n",
    "                                     height_shift_range=0.125,\n",
    "                                     fill_mode='reflect',\n",
    "                                     horizontal_flip=True)\n",
    "    else:\n",
    "        datagen = ImageDataGenerator()\n",
    "\n",
    "    # Silence Keras's slow-callback warnings during training only. A bare\n",
    "    # warnings.simplefilter(\"ignore\") would disable every warning for the\n",
    "    # rest of the kernel session.\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.simplefilter(\"ignore\")\n",
    "        history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                                      steps_per_epoch=x_train.shape[0] // batch_size,\n",
    "                                      epochs=epochs,\n",
    "                                      verbose=0,\n",
    "                                      callbacks=[lr_sched, stop])\n",
    "\n",
    "    test_performance = model.evaluate(x_test, y_test, batch_size=300)\n",
    "\n",
    "    return model, history, test_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create set of hyperparam values\n",
    "# (learning rate, weight decay) pairs\n",
    "effective_lrs = [{'lr':0.01, 'weight_decay':0.},\n",
    "                 {'lr':0.032, 'weight_decay':0.},\n",
    "                 {'lr':0.1, 'weight_decay':0.},\n",
    "                 {'lr':0.1, 'weight_decay':4e-5}]\n",
    "\n",
    "# dropout rate and whether to use data augmentation\n",
    "noise_regularizations = [{'dropout':0., 'data_augmentation':False},\n",
    "                         {'dropout':0.2, 'data_augmentation':False},\n",
    "                         {'dropout':0.4, 'data_augmentation':False},\n",
    "                         {'dropout':0., 'data_augmentation':True}]\n",
    "\n",
    "# depth/width factors passed to the task's get_model\n",
    "architectures = [{'depth_factor':1.,'width_factor':1.},\n",
    "                 {'depth_factor':1.,'width_factor':1.5},\n",
    "                 {'depth_factor':1.5,'width_factor':1.}]\n",
    "\n",
    "# Full Cartesian product of the grids above. A comprehension keeps the loop\n",
    "# variables (param, batch_size, optimizer, ...) out of the notebook namespace.\n",
    "# Key order in each dict: batch_size, optimizer, then the architecture,\n",
    "# noise-regularization and learning-rate keys.\n",
    "params = [{'batch_size': batch_size, 'optimizer': optimizer,\n",
    "           **architecture, **noise_regularization, **effective_lr}\n",
    "          for architecture in architectures\n",
    "          for batch_size in [100, 300]\n",
    "          for optimizer in ['SGD', 'adam']\n",
    "          for noise_regularization in noise_regularizations\n",
    "          for effective_lr in effective_lrs]\n",
    "print(len(params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "task = 'C100-WRN'\n",
    "results = result_saver[task].load_results()\n",
    "\n",
    "# trained weights for this task go to saved_weights/<task>/\n",
    "weights_dir = 'saved_weights/' + task\n",
    "os.makedirs(weights_dir, exist_ok=True)\n",
    "\n",
    "# Train every configuration not already present in the loaded results\n",
    "# (ResultSaver presumably persists them, which makes an interrupted sweep resumable).\n",
    "for param in params:\n",
    "    key = frozenset(param.items())\n",
    "    if key in results.keys():\n",
    "        continue\n",
    "    print(param)\n",
    "    start = time.time()\n",
    "    model, history, test_performance = fit_and_evaluate(task, param)\n",
    "\n",
    "    # weights file name lists the hyperparameters in sorted key order\n",
    "    weights_name = str([str(k) + ':' + str(param[k]) for k in sorted(param)])\n",
    "    model.save_weights(weights_dir + '/' + weights_name + '.h5')\n",
    "\n",
    "    result_saver[task].update_results([], {key: {'history': history_todict(history),\n",
    "                                                 'test_performance': test_performance}})\n",
    "\n",
    "    with open(monitor_file, 'a') as monitor:\n",
    "        monitor.write(task + ', ' + str(param) + ': done in ' + str(time.time() - start) + ' seconds.\\n')\n",
    "\n",
    "    # reset the Keras backend state before building the next model\n",
    "    K.clear_session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
