{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from model import *\n",
    "from swd_optim import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==> Preparing data..\n"
     ]
    }
   ],
   "source": [
    "print('==> Preparing data..')\n",
    "\n",
    "# Per-channel mean / std handed to transforms.Normalize. They are defined once so the\n",
    "# training and evaluation pipelines can never drift apart.\n",
    "NORM_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "NORM_STD = (0.2023, 0.1994, 0.2010)\n",
    "\n",
    "# Training pipeline: random crop (with padding) and horizontal flip, then normalisation.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])\n",
    "\n",
    "# Evaluation pipeline: no augmentation, identical normalisation.\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimizers(net, opti_name, lr, weight_decay):\n",
    "    if opti_name == 'VanillaSGD':\n",
    "        return optim.SGD(net.parameters(), lr=lr, momentum=0, weight_decay=weight_decay)\n",
    "    elif opti_name == 'SGD':\n",
    "        return optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=False)\n",
    "    elif opti_name == 'SGDS':\n",
    "        return SGDS(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=False)\n",
    "    elif opti_name == 'Adam':\n",
    "        return optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)\n",
    "    elif opti_name == 'AMSGrad':\n",
    "        return optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay,amsgrad=True)\n",
    "    elif opti_name == 'AdamW':\n",
    "        return optim.AdamW(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay/lr)\n",
    "    elif opti_name == 'AdamS':\n",
    "        return AdamS(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)\n",
    "    elif opti_name == 'Adai':\n",
    "        return Adai(net.parameters(), lr=lr, betas=(0.1, 0.99), eps=1e-03, weight_decay=weight_decay)\n",
    "    elif opti_name == 'AdaiS':\n",
    "        return AdaiS(net.parameters(), lr=lr, betas=(0.1, 0.99), eps=1e-03, weight_decay=weight_decay)\n",
    "    else:\n",
    "        raise('Unspecified optimizer.')\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss used by both train() and test(); reduction='mean' averages over each batch.\n",
    "criterion = nn.CrossEntropyLoss(reduction='mean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(net, optimizer, epoch):\n",
    "    \"\"\"Train ``net`` for one epoch over the global ``trainloader``.\n",
    "\n",
    "    Args:\n",
    "        net: Model to optimise; it is switched to training mode.\n",
    "        optimizer: Optimizer stepped once per mini-batch.\n",
    "        epoch: Zero-based epoch index, only used for the progress message.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(training_error, mean_training_loss)``, both averaged over\n",
    "        every sample seen during the epoch.\n",
    "    \"\"\"\n",
    "    print('Epoch: %d' % (epoch+1))\n",
    "    net.train()\n",
    "    train_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for inputs, targets in trainloader:\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # criterion returns the batch mean, so weight it by the batch size to get\n",
    "        # an exact per-sample average once the epoch is over.\n",
    "        train_loss += loss.item() * targets.size(0)\n",
    "        _, predicted = outputs.max(1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets).sum().item()\n",
    "    print(\"Training Loss: \", train_loss/total)\n",
    "    print(\"Training error:\", 1-correct/total)\n",
    "    return 1 - correct/total, train_loss/total\n",
    "\n",
    "def test(net):\n",
    "    \"\"\"Measure the classification error of ``net`` on the global ``testloader``.\n",
    "\n",
    "    The model is put in evaluation mode and gradients are disabled while the\n",
    "    loader is traversed.\n",
    "\n",
    "    Args:\n",
    "        net: Model to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of misclassified samples (``1 - accuracy``).\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in testloader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = net(inputs)\n",
    "            # The predicted class is the index of the largest output along dim 1.\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "    print(\"Test error:\", 1-correct/total)\n",
    "    return 1 - correct/total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_models(model):\n",
    "    if model == 'resnet18':\n",
    "        return ResNet18()\n",
    "    elif model == 'vgg16':\n",
    "        return VGG('VGG16')\n",
    "    else:\n",
    "        raise('Unspecified model.')\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# CIFAR-10 train / test splits kept under ./data (download=True fetches them when absent).\n",
    "# The training split gets the augmenting transform, the test split the plain one.\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimizer_peformance(model, learning_rate, batch_size, weight_decay, epochs, N, mode):\n",
    "    \"\"\"Train a fresh ``model`` with the optimizer ``mode`` and record its curves.\n",
    "\n",
    "    The learning rate is multiplied by 0.1 after every 80 epochs. Once training\n",
    "    has finished the curves are handed to ``save_err``.\n",
    "\n",
    "    Args:\n",
    "        model: Architecture name understood by ``define_models``.\n",
    "        learning_rate: Base learning rate given to the optimizer.\n",
    "        batch_size: Only forwarded to ``save_err``.\n",
    "        weight_decay: Weight-decay coefficient given to the optimizer.\n",
    "        epochs: Number of training epochs.\n",
    "        N: Only forwarded to ``save_err``.\n",
    "        mode: Optimizer name understood by ``optimizers``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(train_loss, train_err, test_err)`` of per-epoch lists.\n",
    "    \"\"\"\n",
    "    network = define_models(model).to(device)\n",
    "    if device == 'cuda':\n",
    "        network = torch.nn.DataParallel(network)\n",
    "        cudnn.benchmark = True\n",
    "\n",
    "    optimizer = optimizers(network, mode, learning_rate, weight_decay)\n",
    "\n",
    "    def step_decay(epoch):\n",
    "        # Multiplicative factor applied to the base learning rate.\n",
    "        return 0.1 ** (epoch // 80)\n",
    "\n",
    "    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=step_decay)\n",
    "\n",
    "    loss_curve, train_err_curve, test_err_curve = [], [], []\n",
    "    started_at = time.time()\n",
    "    print(\"-\"*30+\" Mode \"+ mode+\" starts\" )\n",
    "    for epoch in range(epochs):\n",
    "        epoch_err, epoch_loss = train(network, optimizer, epoch)\n",
    "        train_err_curve.append(epoch_err)\n",
    "        loss_curve.append(epoch_loss)\n",
    "        test_err_curve.append(test(network))\n",
    "        scheduler.step()\n",
    "        print (\"--- %s seconds ---\" % (time.time() - started_at))\n",
    "\n",
    "    run_tag = model+'_'+mode\n",
    "    save_err({mode: loss_curve}, {mode: train_err_curve}, {mode: test_err_curve},\n",
    "             run_tag, learning_rate, batch_size, weight_decay, epochs, N)\n",
    "    return loss_curve, train_err_curve, test_err_curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimizer_performance_comparison(model, batch_size, weight_decay, epochs, N):\n",
    "    \"\"\"Train ``model`` once per optimizer and gather all the resulting curves.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(train_loss, train_err, test_err)`` of dicts, each mapping an\n",
    "        optimizer name to its per-epoch list.\n",
    "    \"\"\"\n",
    "    # (optimizer name, base learning rate) pairs, trained in this order.\n",
    "    runs = [('SGDS', 0.1), ('SGD', 0.1), ('Adam', 1e-3), ('AdamW', 1e-3), ('AdamS', 1e-3)]\n",
    "    train_loss, train_err, test_err = {}, {}, {}\n",
    "    for mode, base_lr in runs:\n",
    "        # Only the first whitespace-separated token names the optimizer.\n",
    "        opti_name = mode.split(None, 1)[0]\n",
    "        curves = optimizer_peformance(model, base_lr, batch_size, weight_decay, epochs, N, opti_name)\n",
    "        train_loss[mode], train_err[mode], test_err[mode] = curves\n",
    "    return train_loss, train_err, test_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_figure(model, train_loss, train_err, test_err, batch_size, weight_decay, epochs, N):\n",
    "    \"\"\"Plot, show and save the test-error and training-loss curves.\n",
    "\n",
    "    Two figures are written to the working directory as PNG and PDF:\n",
    "    ``Test_errors_<tag>`` and ``Training_loss_<tag>``, where ``<tag>`` is\n",
    "    ``<model>_B<batch_size>_N<N>_E<epochs>``.\n",
    "\n",
    "    Args:\n",
    "        model: Model name, only used in the output file names.\n",
    "        train_loss: Dict mapping optimizer name to per-epoch training loss.\n",
    "        train_err: Unused; kept so existing calls keep working.\n",
    "        test_err: Dict mapping optimizer name to per-epoch test error.\n",
    "        batch_size: Only used in the output file names.\n",
    "        weight_decay: Unused; kept so existing calls keep working.\n",
    "        epochs: Number of epochs; every curve must hold this many points.\n",
    "        N: Only used in the output file names.\n",
    "    \"\"\"\n",
    "    figure_name = model + '_B'+str(batch_size) + '_N'+ str(N) + '_E' + str(epochs)\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots\n",
    "    plt.rcParams['image.interpolation'] = 'nearest'\n",
    "    plt.rcParams['image.cmap'] = 'gray'\n",
    "\n",
    "    # Known optimizers keep their usual order (and hence colour); any other key found\n",
    "    # in the results is appended, and missing ones are skipped instead of raising KeyError.\n",
    "    preferred_order = ['SGD', 'SGDS', 'Adam', 'AdamW', 'AdamS']\n",
    "    mode_list = [mode for mode in preferred_order if mode in test_err]\n",
    "    mode_list += [mode for mode in test_err if mode not in preferred_order]\n",
    "    colors = ['red','blue','green','orange','pink','cyan','brown','yellow','black']\n",
    "    epoch_axis = np.arange(1, epochs+1)\n",
    "\n",
    "    def _draw(curves, ylabel, ylim, file_prefix, log_scale=False):\n",
    "        # Draw one figure with a line per optimizer, show it, then save PNG and PDF.\n",
    "        fig = plt.figure()\n",
    "        axes = plt.gca()\n",
    "        if log_scale:\n",
    "            axes.set_yscale('log')\n",
    "        axes.set_ylim(ylim)\n",
    "        axes.set_xlim([0,epochs])\n",
    "        for idx, mode in enumerate(mode_list):\n",
    "            # Wrap around the palette so extra optimizers cannot overflow it.\n",
    "            plt.plot(epoch_axis, curves[mode], label=mode, ls='solid', linewidth=2,\n",
    "                     color=colors[idx % len(colors)])\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.xlabel('Epochs')\n",
    "        plt.grid()\n",
    "        plt.legend()\n",
    "        plt.show()\n",
    "        fig.savefig(file_prefix+figure_name + '.png')\n",
    "        fig.savefig(file_prefix+figure_name+'.pdf', format='pdf', bbox_inches = 'tight')\n",
    "\n",
    "    _draw(test_err, 'Test Error', [0., 0.2], 'Test_errors_')\n",
    "    _draw(train_loss, 'Training Loss', [1e-4, 1.], 'Training_loss_', log_scale=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_err(train_loss, train_err, test_err, model, learning_rate, batch_size, weight_decay, epochs, N):\n",
    "    csvname = model + '_LR'+str(learning_rate) + '_B'+str(batch_size) + '_N'+ str(N) + '_E' + str(epochs)\n",
    "    csvname = 'Curves_' + csvname\n",
    "\n",
    "    current_name = csvname +'.csv'\n",
    "    files_present = glob.glob(current_name)\n",
    "    if files_present:\n",
    "        print('WARNING: This file already exists!')\n",
    "    data_dict = {}\n",
    "    for mode in test_err:\n",
    "        data_dict[mode+'_test_err'] = test_err[mode]\n",
    "        data_dict[mode+'_training_err'] = train_err[mode]\n",
    "        data_dict[mode+'_training_loss'] = train_loss[mode]\n",
    "    df = pd.DataFrame(data=data_dict)\n",
    "    if not files_present:\n",
    "        df.to_csv(current_name, sep=',', header=True, index=False)\n",
    "    else:\n",
    "        print('WARNING: This file already exists!')\n",
    "        for i in range(1,30):\n",
    "            files_present = glob.glob(csvname+'_'+str(i)+'.csv')\n",
    "            if not files_present:\n",
    "                df.to_csv(csvname+'_'+str(i)+'.csv', sep=',', header=True, index=False)\n",
    "                return None\n",
    "    return None\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters shared by every run launched below.\n",
    "batch_size = 128\n",
    "weight_decay = 5e-4\n",
    "epochs = 200\n",
    "\n",
    "# Training-set size; in this notebook it is only used to tag the saved CSV / figure names.\n",
    "N = 50000 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch iterators read as globals by train() and test().\n",
    "# NOTE(review): shuffle=True on the test loader should not change the reported test error\n",
    "# (test() only counts correct predictions over the whole loader) - confirm it is intended.\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Architecture for the next comparison; must be a name define_models() accepts.\n",
    "model = 'resnet18'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one fresh network per optimizer and collect the per-epoch curves (long-running).\n",
    "train_loss, train_err, test_err = optimizer_performance_comparison(model, batch_size, weight_decay, epochs, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot, show and save the test-error and training-loss curves of the current results.\n",
    "plot_figure(model, train_loss, train_err, test_err, batch_size, weight_decay, epochs, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch architecture for the next comparison; must be a name define_models() accepts.\n",
    "model = 'vgg16'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one fresh network per optimizer; this rebinds train_loss / train_err / test_err.\n",
    "train_loss, train_err, test_err = optimizer_performance_comparison(model, batch_size, weight_decay, epochs, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot, show and save the test-error and training-loss curves of the current results.\n",
    "plot_figure(model, train_loss, train_err, test_err, batch_size, weight_decay, epochs, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
