{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36ca04d4",
   "metadata": {},
   "source": [
    "# Loss Landscapes on CIFAR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f199c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import copy\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import models\n",
    "import ops.tests as tests\n",
    "import ops.datasets as datasets\n",
    "import ops.loss_landscapes as lls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48de8d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config_path = \"configs/cifar10_general.json\"\n",
    "config_path = \"configs/cifar100_general.json\"\n",
    "\n",
    "with open(config_path) as f:\n",
    "    args = json.load(f)\n",
    "    print(\"args: \\n\", args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50693515",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_args = copy.deepcopy(args).get(\"dataset\")\n",
    "train_args = copy.deepcopy(args).get(\"train\")\n",
    "val_args = copy.deepcopy(args).get(\"val\")\n",
    "model_args = copy.deepcopy(args).get(\"model\")\n",
    "optim_args = copy.deepcopy(args).get(\"optim\")\n",
    "env_args = copy.deepcopy(args).get(\"env\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5113c601",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)\n",
    "dataset_name = dataset_args[\"name\"]\n",
    "num_classes = len(dataset_train.classes)\n",
    "\n",
    "dataset_train = DataLoader(dataset_train, \n",
    "                           shuffle=True, \n",
    "                           num_workers=train_args.get(\"num_workers\", 4), \n",
    "                           batch_size=train_args.get(\"batch_size\", 128))\n",
    "dataset_test = DataLoader(dataset_test, \n",
    "                          num_workers=val_args.get(\"num_workers\", 4), \n",
    "                          batch_size=val_args.get(\"batch_size\", 128))\n",
    "\n",
    "print(\"Train: %s, Test: %s, Classes: %s\" % (\n",
    "    len(dataset_train.dataset), \n",
    "    len(dataset_test.dataset), \n",
    "    num_classes\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cea9a688",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce626168",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VGG\n",
    "# name = \"vgg_dnn_19\"\n",
    "# name = \"vgg_dnn_smoothing_19\"\n",
    "# name = \"vgg_mcdo_19\"\n",
    "# name = \"vgg_mcdo_smoothing_19\"\n",
    "\n",
    "# ResNet\n",
    "name = \"resnet_dnn_18\"\n",
    "# name = \"resnet_dnn_smoothing_18\"\n",
    "# name = \"resnet_mcdo_18\"\n",
    "# name = \"resnet_mcdo_smoothing_18\"\n",
    "\n",
    "# name = \"resnet_dnn_50\"\n",
    "# name = \"resnet_mcdo_50\"\n",
    "# name = \"resnet_dnn_smoothing_50\"\n",
    "# name = \"resnet_mcdo_smoothing_50\"\n",
    "\n",
    "# Preact ResNet\n",
    "# name = \"preresnet_dnn_50\"\n",
    "# name = \"preresnet_mcdo_50\"\n",
    "# name = \"preresnet_dnn_smoothing_50\"\n",
    "# name = \"preresnet_mcdo_smoothing_50\"\n",
    "\n",
    "# ResNeXt\n",
    "# name = \"resnext_dnn_50\"\n",
    "# name = \"resnext_mcdo_50\"\n",
    "# name = \"resnext_dnn_smoothing_50\"\n",
    "# name = \"resnext_mcdo_smoothing_50\"\n",
    "\n",
    "# WideResNet\n",
    "# name = \"wideresnet_dnn_50\"\n",
    "# name = \"wideresnet_mcdo_50\"\n",
    "# name = \"wideresnet_dnn_smoothing_50\"\n",
    "# name = \"wideresnet_mcdo_smoothing_50\"\n",
    "\n",
    "\n",
    "uid = \"\"  # Model UID required\n",
    "model = models.get_model(name, num_classes=num_classes, \n",
    "                         stem=model_args.get(\"stem\", False))\n",
    "models.load(model, dataset_name, uid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a377a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "model = model.cuda() if gpu else model.cpu()\n",
    "metrics_list = []\n",
    "for n_ff in [1]:\n",
    "    print(\"N: %s, \" % n_ff, end=\"\")\n",
    "    *metrics, cal_diag = tests.test(model, n_ff, dataset_test, verbose=False, gpu=gpu)\n",
    "    metrics_list.append([n_ff, *metrics])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e99d1e33",
   "metadata": {},
   "source": [
    "## Investigate the Loss Landscape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "192d134a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "scale = 1e-1\n",
    "n = 21\n",
    "\n",
    "metrics_grid = lls.get_loss_landscape(\n",
    "    model, 1, dataset_train, \n",
    "    x_min=-1.0 * scale, x_max=1.0 * scale, n_x=n, y_min=-1.0 * scale, y_max=1.0 * scale, n_y=n,\n",
    ")\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "metrics_dir = os.path.join(leaderboard_path, \"%s_%s_%s_x%s_losslandscape.csv\" % (dataset_name, model.name, uid, int(1 / scale)))\n",
    "metrics_list = [[*grid, *metrics] for grid, metrics in metrics_grid.items()]\n",
    "tests.save_metrics(metrics_dir, metrics_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ef9cf2f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
