{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "KJ_Ijt20j05r",
    "outputId": "6bc01c97-2bf9-4094-a328-5d553716bf39"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "torch.manual_seed(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "SXVFwPrcj-dE",
    "outputId": "94e5f668-1539-4003-f0ec-6b240f4f1f95"
   },
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0MWbCIsbj-jw"
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import torch.optim as optim\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 101,
     "referenced_widgets": [
      "3bd056c952374f3eb698a5cfb93c0052",
      "542a74be255e452a90c02672ee2e9a8a",
      "29d82448c2f44ec4b8517037c7426373",
      "499db43433dd44a186bc12a58e56fafe",
      "4bfbcfa4bd0d497c9471196fc67cf145",
      "038cd24e5cc34840a127a05d330cc710",
      "810e20dcd6334c1094853fbbcf017bd7",
      "953d696e8fb34b1c820e333c4c76d38a"
     ]
    },
    "id": "BtW8tupHj-l3",
    "outputId": "8047c605-1487-4de5-b90e-227f41e5ba19"
   },
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "M_-4edhGkLMj",
    "outputId": "582839dd-527f-429a-9f61-e6d45d8cfca2"
   },
   "outputs": [],
   "source": [
    "#Training data size: \n",
    "print(len(trainloader.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "rr-adbc-j-y5",
    "outputId": "5a45fdfd-e7a3-4525-e4c8-fddac44c3254"
   },
   "outputs": [],
   "source": [
    "#Test data size:\n",
    "print(len(testloader.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 155
    },
    "id": "duf3_iGBj-2m",
    "outputId": "9859585d-0e04-44af-f722-e47189db70ce"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# functions to show an image\n",
    "\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5     # unnormalize\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "    plt.show()\n",
    "\n",
    "# get some random training images\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# show images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# print labels\n",
    "print(' '.join('%5s' % classes[labels[j]] for j in range(4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b_a-THe0j-5Y"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BY51oDJokl_B"
   },
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SuzL6V0MkmCF"
   },
   "outputs": [],
   "source": [
    "def trainCustomSPPA(net1, net2, device, optimizer, seed, epoch_outer, epoch_inner, acc_array, loss_inner, loss_outer, Name ,many,mode):\n",
    "    #net1 is the network for outer optimization\n",
    "    #net2 is the network for the inner optimization\n",
    "    #seed-2\n",
    "    #epoch_outer-5\n",
    "    #epoch_inner-1\n",
    "    #many-500\n",
    "    #mode ---1 lr 1/sqrt(t)\n",
    "    #mode ---2 lr 1/t\n",
    "    #mode ---other lr=other i.e. constant\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    loss_inner=[]\n",
    "    loss_outer=[]\n",
    "    acc_array=[]\n",
    "    co=0\n",
    "    for epoch in range(epoch_outer):  \n",
    "    # loop over the dataset multiple times\n",
    "        running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            # get the inputs; data is a list of [inputs, labels]\n",
    "            inputs, labels = data\n",
    "            inputs, labels = inputs.cuda(), labels.cuda()\n",
    "            if i % many == 0:\n",
    "                correct = 0\n",
    "                total = 0\n",
    "                with torch.no_grad():\n",
    "                    for data in testloader:\n",
    "                        images, labels = data\n",
    "                        images, labels = images.cuda(), labels.cuda()\n",
    "                        outputs = net1(images)\n",
    "                        _, predicted = torch.max(outputs.data, 1)\n",
    "                        total += labels.size(0)\n",
    "                        correct += (predicted == labels).sum().item()\n",
    "\n",
    "                print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
    "                    100 * correct / total))\n",
    "                acc=(100 * correct / total)\n",
    "                acc_array.append(acc)\n",
    "            outputs = net1(inputs)\n",
    "            loss1 = criterion(outputs, labels)\n",
    "            co=co+1\n",
    "            if (co==1):\n",
    "              print(\"Loss:\",float(loss1))\n",
    "              loss_outer.append(float(loss1))\n",
    "              print(\"Loss :\",loss1)\n",
    "            loss1.backward()\n",
    "\n",
    "            if i% many == 0:\n",
    "                for epoch2 in range(epoch_inner):  # loop over the dataset multiple times\n",
    "                    running_loss2 = 0.0\n",
    "                    t=0\n",
    "                    for i2, data2 in enumerate(trainloader, 0):\n",
    "                        # get the inputs; data is a list of [inputs, labels]\n",
    "                        t=t+1\n",
    "                        inputs2, labels2 = data2\n",
    "                        inputs2, labels2 = inputs2.cuda(), labels2.cuda()\n",
    "                        optimizer.zero_grad()\n",
    "                        outputs2 = net2(inputs2)\n",
    "                        loss2 = criterion(outputs2, labels2)\n",
    "                        l2=0\n",
    "                        if mode==1:\n",
    "                            lr=1.0/math.sqrt(t)\n",
    "                        elif mode==2:\n",
    "                            lr=1.0/t\n",
    "                        else:\n",
    "                            lr=mode\n",
    "                        for p2 in net2.parameters():\n",
    "                            l2 = l2 + p2.norm()\n",
    "                        loss2 = loss2 + (lr )*l2\n",
    "                        loss2.backward()\n",
    "                        def closure():\n",
    "                            # zero the parameter gradients\n",
    "                            optimizer.zero_grad()\n",
    "                            outputs2 = net2(inputs2)\n",
    "                            loss2 = criterion(outputs2, labels2)\n",
    "                            l2=0\n",
    "                            if mode==1:\n",
    "                                lr = 1.0/math.sqrt(t)\n",
    "                            elif mode==2:\n",
    "                                lr = 1.0/t\n",
    "                            else:\n",
    "                                lr = mode\n",
    "                            c1=0\n",
    "                            c2=0\n",
    "                            for n2, p2 in net2.named_parameters():\n",
    "                                c2=c2+1\n",
    "                                for n1, p1 in net1.named_parameters():\n",
    "                                    c1=c1+1\n",
    "                                    if (c1==c2):\n",
    "                                        l2=l2+(p2.data-p1.data).norm()\n",
    "                            #print(\"Loss before update \", loss2)\n",
    "                            loss2=loss2+ ( 0.5*lr ) * l2\n",
    "                            #print(\"Loss after update \",loss2)\n",
    "                            loss2.backward()\n",
    "                            return loss2\n",
    "                        optimizer.step(closure)\n",
    "                        running_loss2 += loss2.item()\n",
    "                        if i2 % many == (many-1):    # print every 1000 mini-batches\n",
    "                            loss_inner.append(running_loss2/many)\n",
    "                            if (mode==1):\n",
    "                                print('\\t '+str(Name)+' inner 1/sqrt(t) --- [%d, %5d] loss: %.3f' %\n",
    "                                      (epoch2 + 1, i2 + 1, running_loss2 / many))\n",
    "                            elif (mode==2):\n",
    "                                print('\\t '+str(Name)+' inner 1/t --- [%d, %5d] loss: %.3f' %\n",
    "                                      (epoch2 + 1, i2 + 1, running_loss2 / many))\n",
    "                            else :\n",
    "                                print('\\t '+str(Name)+' inner lr = '+str(lr)+' --- [%d, %5d] loss: %.3f' %\n",
    "                                      (epoch2 + 1, i2 + 1, running_loss2 / many))\n",
    "                            running_loss2 = 0.0\n",
    "                print(\"*************************\")\n",
    "            net1.load_state_dict(net2.state_dict())\n",
    "            \n",
    "            running_loss += loss1.item()\n",
    "            if i % many == (many-1):    # print every 2000 mini-batches\n",
    "                print('SPPA --- [%d, %5d] loss: %.3f' %\n",
    "                      (epoch + 1, i + 1, running_loss / many))\n",
    "                loss_outer.append(running_loss/many)\n",
    "                running_loss = 0.0\n",
    "\n",
    "\n",
    "    print('Finished Training')\n",
    "    return loss_inner, loss_outer, acc_array\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eWKc82vp-2QR"
   },
   "outputs": [],
   "source": [
    "def traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many):\n",
    "    #net1 is the network for optimization\n",
    "    #seed-2\n",
    "    #epoch_no-5\n",
    "    #many-500\n",
    "    torch.manual_seed(seed)\n",
    "    loss_array=[]\n",
    "    acc_array=[]\n",
    "    co=0\n",
    "    for epoch in range(epoch_no):  \n",
    "    # loop over the dataset multiple times\n",
    "        running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            # get the inputs; data is a list of [inputs, labels]\n",
    "\n",
    "                    \n",
    "            if i % many == 0:\n",
    "                correct = 0\n",
    "                total = 0\n",
    "                with torch.no_grad():\n",
    "                    for data_test in testloader:\n",
    "                        images, labels = data_test\n",
    "                        images, labels = images.cuda(), labels.cuda()\n",
    "                        outputs = net1(images)\n",
    "                        _, predicted = torch.max(outputs.data, 1)\n",
    "                        total += labels.size(0)\n",
    "                        correct += (predicted == labels).sum().item()\n",
    "\n",
    "                print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
    "                    100 * correct / total))\n",
    "                acc=(100 * correct / total)\n",
    "                acc_array.append(acc)\n",
    "\n",
    "\n",
    "            inputs, labels = data\n",
    "            inputs, labels = inputs.cuda(), labels.cuda()\n",
    "            def closure():\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "                outputs = net1(inputs)\n",
    "                loss1 = criterion(outputs, labels)\n",
    "                loss1.backward()\n",
    "                return loss1\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net1(inputs)\n",
    "            loss1 = criterion(outputs, labels)\n",
    "\n",
    "            co=co+1\n",
    "            if (co==1):\n",
    "                print(\"Loss:\",float(loss1))\n",
    "                loss_array.append(float(loss1))\n",
    "                print(\"Loss :\",loss1)\n",
    "            loss1.backward()\n",
    "\n",
    "            optimizer.step(closure)\n",
    "            \n",
    "            running_loss += loss1.item()\n",
    "            if i % many == (many-1):    # print every 2000 mini-batches\n",
    "                print(Name + '--- [%d, %5d] loss: %.3f' %\n",
    "                      (epoch + 1, i + 1, running_loss / many))\n",
    "                loss_array.append(running_loss/many)\n",
    "                running_loss = 0.0\n",
    "\n",
    "\n",
    "    print('Finished Training')\n",
    "    return loss_array, acc_array\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "oM0LKyx-kmLZ",
    "outputId": "1c88413a-3d0c-4b67-f312-93a159ad4f4c"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=5\n",
    "epoch_inner=1\n",
    "many=500\n",
    " \n",
    "mode=1\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adam1=Net()\n",
    "net_adam2=Net()\n",
    "\n",
    "net_adam1.to(device)\n",
    "net_adam2.to(device)\n",
    "\n",
    "optimizer_adam = optim.Adam(net_adam2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)\n",
    "loss_inner_adam=[]\n",
    "loss_outer_adam=[]\n",
    "acc_array_adam=[]\n",
    "\n",
    "loss_inner_adam,loss_outer_adam,acc_array_adam = trainCustomSPPA(net_adam1, net_adam2, device, optimizer_adam, seed, epoch_outer, epoch_inner, acc_array_adam, loss_inner_adam, loss_outer_adam, 'ADAM' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UYgkFc6KgiTv"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=5\n",
    "epoch_inner=1\n",
    "many=500\n",
    " \n",
    "mode=2\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adam1_=Net()\n",
    "net_adam2_=Net()\n",
    "\n",
    "net_adam1_.to(device)\n",
    "net_adam2_.to(device)\n",
    "\n",
    "optimizer_adam_ = optim.Adam(net_adam2_.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)\n",
    "loss_inner_adam_=[]\n",
    "loss_outer_adam_=[]\n",
    "acc_array_adam_=[]\n",
    "\n",
    "loss_inner_adam_,loss_outer_adam_,acc_array_adam_ = trainCustomSPPA(net_adam1_, net_adam2_, device, optimizer_adam_, seed, epoch_outer, epoch_inner, acc_array_adam_, loss_inner_adam_, loss_outer_adam_, 'ADAM' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IALoEfXKjONj"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=5\n",
    "epoch_inner=1\n",
    "many=500\n",
    " \n",
    "mode=2\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adamW1_2=Net()\n",
    "net_adamW2_2=Net()\n",
    "\n",
    "net_adamW1_2.to(device)\n",
    "net_adamW2_2.to(device)\n",
    "\n",
    "optimizer_adamW_2 = optim.AdamW(net_adamW2_2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)\n",
    "loss_inner_adamW_2=[]\n",
    "loss_outer_adamW_2=[]\n",
    "acc_array_adamW_2=[]\n",
    "\n",
    "loss_inner_adamW_2,loss_outer_adamW_2,acc_array_adamW_2 = trainCustomSPPA(net_adamW1_2, net_adamW2_2, device, optimizer_adamW_2, seed, epoch_outer, epoch_inner, acc_array_adamW_2, loss_inner_adamW_2, loss_outer_adamW_2, 'AdamW' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R7s3XyRPkmPG"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment: SPPA with an Adadelta inner solver (mode=1).\n",
     "# The code is wrapped in a string literal, so it does not run; remove the\n",
     "# surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=5\n",
     "epoch_inner=1\n",
     "many=500\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_adadelta1=Net()\n",
     "net_adadelta2=Net()\n",
     "\n",
     "net_adadelta1.to(device)\n",
     "net_adadelta2.to(device)\n",
     "\n",
     "optimizer_adadelta = optim.Adadelta(net_adadelta2.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)\n",
     "loss_inner_adadelta=[]\n",
     "loss_outer_adadelta=[]\n",
     "acc_array_adadelta=[]\n",
     "\n",
     "loss_inner_adadelta,loss_outer_adadelta,acc_array_adadelta = trainCustomSPPA(net_adadelta1, net_adadelta2, device, optimizer_adadelta, seed, epoch_outer, epoch_inner, acc_array_adadelta, loss_inner_adadelta, loss_outer_adadelta, 'Adadelta' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FFsMaW7-kmSw"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment: SPPA with an Adagrad inner solver (mode=1).\n",
     "# The code is wrapped in a string literal, so it does not run; remove the\n",
     "# surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=5\n",
     "epoch_inner=1\n",
     "many=500\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_adagrad1=Net()\n",
     "net_adagrad2=Net()\n",
     "\n",
     "net_adagrad1.to(device)\n",
     "net_adagrad2.to(device)\n",
     "\n",
     "optimizer_adagrad = optim.Adagrad(net_adagrad2.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)\n",
     "loss_inner_adagrad=[]\n",
     "loss_outer_adagrad=[]\n",
     "acc_array_adagrad=[]\n",
     "\n",
     "loss_inner_adagrad,loss_outer_adagrad,acc_array_adagrad = trainCustomSPPA(net_adagrad1, net_adagrad2, device, optimizer_adagrad, seed, epoch_outer, epoch_inner, acc_array_adagrad, loss_inner_adagrad, loss_outer_adagrad, 'Adagrad' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jj590XYtkmXF"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=5\n",
    "epoch_inner=1\n",
    "many=500\n",
    " \n",
    "mode=1\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adamW1=Net()\n",
    "net_adamW2=Net()\n",
    "\n",
    "net_adamW1.to(device)\n",
    "net_adamW2.to(device)\n",
    "\n",
    "optimizer_adamW = optim.AdamW(net_adamW2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)\n",
    "loss_inner_adamW=[]\n",
    "loss_outer_adamW=[]\n",
    "acc_array_adamW=[]\n",
    "\n",
    "loss_inner_adamW,loss_outer_adamW,acc_array_adamW = trainCustomSPPA(net_adamW1, net_adamW2, device, optimizer_adamW, seed, epoch_outer, epoch_inner, acc_array_adamW, loss_inner_adamW, loss_outer_adamW, 'AdamW' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kSUE7zSLwSwg"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment: SPPA with an Adamax inner solver (mode=1).\n",
     "# The code is wrapped in a string literal, so it does not run; remove the\n",
     "# surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=5\n",
     "epoch_inner=1\n",
     "many=500\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_adamax1=Net()\n",
     "net_adamax2=Net()\n",
     "\n",
     "net_adamax1.to(device)\n",
     "net_adamax2.to(device)\n",
     "\n",
     "optimizer_adamax = optim.Adamax(net_adamax2.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)\n",
     "loss_inner_adamax=[]\n",
     "loss_outer_adamax=[]\n",
     "acc_array_adamax=[]\n",
     "\n",
     "loss_inner_adamax,loss_outer_adamax,acc_array_adamax = trainCustomSPPA(net_adamax1, net_adamax2, device, optimizer_adamax, seed, epoch_outer, epoch_inner, acc_array_adamax, loss_inner_adamax, loss_outer_adamax, 'Adamax' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LlnIgu4_wS62"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment: SPPA with an ASGD inner solver (mode=1).\n",
     "# The code is wrapped in a string literal, so it does not run; remove the\n",
     "# surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=5\n",
     "epoch_inner=1\n",
     "many=500\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_ASGD1=Net()\n",
     "net_ASGD2=Net()\n",
     "\n",
     "net_ASGD1.to(device)\n",
     "net_ASGD2.to(device)\n",
     "\n",
     "optimizer_ASGD = optim.ASGD(net_ASGD2.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)\n",
     "loss_inner_ASGD=[]\n",
     "loss_outer_ASGD=[]\n",
     "acc_array_ASGD=[]\n",
     "\n",
     "loss_inner_ASGD,loss_outer_ASGD,acc_array_ASGD = trainCustomSPPA(net_ASGD1, net_ASGD2, device, optimizer_ASGD, seed, epoch_outer, epoch_inner, acc_array_ASGD, loss_inner_ASGD, loss_outer_ASGD, 'ASGD' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OtV0-fIu6iPS"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment: SPPA with an RMSprop inner solver (mode=1).\n",
     "# The code is wrapped in a string literal, so it does not run; remove the\n",
     "# surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=5\n",
     "epoch_inner=1\n",
     "many=500\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_RMSprop1=Net()\n",
     "net_RMSprop2=Net()\n",
     "\n",
     "net_RMSprop1.to(device)\n",
     "net_RMSprop2.to(device)\n",
     "\n",
     "optimizer_RMSprop = optim.RMSprop(net_RMSprop2.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)\n",
     "loss_inner_RMSprop=[]\n",
     "loss_outer_RMSprop=[]\n",
     "acc_array_RMSprop=[]\n",
     "\n",
     "loss_inner_RMSprop,loss_outer_RMSprop,acc_array_RMSprop = trainCustomSPPA(net_RMSprop1, net_RMSprop2, device, optimizer_RMSprop, seed, epoch_outer, epoch_inner, acc_array_RMSprop, loss_inner_RMSprop, loss_outer_RMSprop, 'RMSprop' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s6pKgawY7KAV"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment: SPPA with an L-BFGS inner solver (mode=1,\n",
     "# max_iter=1, history_size=1). The code is wrapped in a string literal, so it\n",
     "# does not run; remove the surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=5\n",
     "epoch_inner=1\n",
     "many=500\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_lbfgs1=Net()\n",
     "net_lbfgs2=Net()\n",
     "\n",
     "net_lbfgs1.to(device)\n",
     "net_lbfgs2.to(device)\n",
     "\n",
     "optimizer_lbfgs = optim.LBFGS(net_lbfgs2.parameters(),max_iter=1,history_size=1)\n",
     "loss_inner_lbfgs=[]\n",
     "loss_outer_lbfgs=[]\n",
     "acc_array_lbfgs=[]\n",
     "\n",
     "loss_inner_lbfgs,loss_outer_lbfgs,acc_array_lbfgs = trainCustomSPPA(net_lbfgs1, net_lbfgs2, device, optimizer_lbfgs, seed, epoch_outer, epoch_inner, acc_array_lbfgs, loss_inner_lbfgs, loss_outer_lbfgs, 'LBFGS' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "e0_Yr9DM-VlJ"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "seed=2\n",
    "epoch_no=5\n",
    "\n",
    "many=500\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adagrad_general=Net()\n",
    "\n",
    "net_adagrad_general.to(device)\n",
    "\n",
    "optimizer_adagrad_general = optim.Adagrad(net_adagrad_general.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)\n",
    "loss_array_adagrad_general=[]\n",
    "acc_array_adagrad_general=[]\n",
    "loss_array_adagrad_general, acc_array_adagrad_general = traingeneral(net_adagrad_general, device, optimizer_adagrad_general, seed, epoch_no, acc_array_adagrad_general, loss_array_adagrad_general, 'Adagrad' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5nfCcbOD-iQG"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "seed=2\n",
    "epoch_no=5\n",
    "\n",
    "many=500\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_sgd_general2=Net()\n",
    "\n",
    "net_sgd_general2.to(device)\n",
    "\n",
    "optimizer_sgd_general2 =optim.SGD(net_sgd_general2.parameters(), lr=0.001)\n",
    "loss_array_sgd_general2=[]\n",
    "acc_array_sgd_general2=[]\n",
    "loss_array_sgd_general2, acc_array_sgd_general2 = traingeneral(net_sgd_general2, device, optimizer_sgd_general2, seed, epoch_no, acc_array_sgd_general2, loss_array_sgd_general2, 'SGD' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tdt_sGoF-iVX"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "seed=2\n",
    "epoch_no=5\n",
    "\n",
    "many=500\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_sgd_general=Net()\n",
    "\n",
    "net_sgd_general.to(device)\n",
    "\n",
    "optimizer_sgd_general =optim.SGD(net_sgd_general.parameters(), lr=0.001, momentum=0.9)\n",
    "loss_array_sgd_general=[]\n",
    "acc_array_sgd_general=[]\n",
    "loss_array_sgd_general, acc_array_sgd_general = traingeneral(net_sgd_general, device, optimizer_sgd_general, seed, epoch_no, acc_array_sgd_general, loss_array_sgd_general, 'SGD Momentum' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fxfljlUu-qUE"
   },
   "outputs": [],
   "source": [
     "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
     "# Disabled experiment: baseline ASGD via traingeneral. The code is wrapped in a\n",
     "# string literal, so it does not run; remove the surrounding quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_no=5\n",
     "\n",
     "many=500\n",
     " \n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_ASGD_general=Net()\n",
     "\n",
     "net_ASGD_general.to(device)\n",
     "\n",
     "optimizer_ASGD_general = optim.ASGD(net_ASGD_general.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)\n",
     "loss_array_ASGD_general=[]\n",
     "acc_array_ASGD_general=[]\n",
     "loss_array_ASGD_general, acc_array_ASGD_general = traingeneral(net_ASGD_general, device, optimizer_ASGD_general, seed, epoch_no, acc_array_ASGD_general, loss_array_ASGD_general, 'ASGD' ,many)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XX7EqWNl-qeZ"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "\n",
    "seed=2\n",
    "epoch_no=5\n",
    "\n",
    "many=500\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adam_general=Net()\n",
    "\n",
    "net_adam_general.to(device)\n",
    "\n",
    "optimizer_adam_general = optim.Adam(net_adam_general.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)\n",
    "loss_array_adam_general=[]\n",
    "acc_array_adam_general=[]\n",
    "loss_array_adam_general, acc_array_adam_general = traingeneral(net_adam_general, device, optimizer_adam_general, seed, epoch_no, acc_array_adam_general, loss_array_adam_general, 'ADAM' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y1OU-aloXMD_"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from google.colab import files\n",
    "X=[]\n",
    "for i in range(125):\n",
    "    X.append(i)\n",
    "\n",
    "plt.figure(figsize=((12,9)))\n",
    "\n",
    "#plt.plot(X, loss_outer_adadelta, label=\"SPPA-Adadelta (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_adagrad, label=\"SPPA-Adagrad (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_adamax, label=\"SPPA-Adamax (1/sqrt(t))\",linewidth=2.5)\n",
    "\n",
    "plt.plot(X, loss_outer_adamW[:125], label=\"SPPA (1/sqrt(t))\",linewidth=2.5,color='green')\n",
    "plt.plot(X, loss_outer_adamW_[:125], label=\"SPPA (1/t)\",linewidth=2.5,color='red')\n",
    "#plt.plot(X, loss_outer_ASGD, label=\"SPPA-Asgd (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_RMSprop, label=\"SPPA-RMSProp \",linewidth=2.5)\n",
    "\n",
    "#plt.plot(X, loss_array_adadelta_general, label=\"Adadelta \",linewidth=2.5)\n",
    "plt.plot(X, loss_array_adam_general[:125], label=\"Adam \",linewidth=2.5,color='darkorange')\n",
    "plt.plot(X, loss_array_adagrad_general[:125], label=\"Adagrad \",linewidth=2.5)\n",
    "#plt.plot(X, loss_array_adamax_general, label=\"Adamax \",linewidth=2.5)\n",
    "\n",
    "\n",
    "plt.plot(X, loss_array_sgd_general[:125], label=\"SGD with Momentum\",linewidth=2.5,color='blue')\n",
    "plt.plot(X, loss_array_sgd_general2[:125], label=\"SGD \",linewidth=2.5,color='black')\n",
    "#plt.plot(X, loss_array_adam_general, label=\"AdamW \",linewidth=2.5)\n",
    "#plt.plot(X, loss_array_ASGD_general, label=\"Asgd \",linewidth=2.5)\n",
    "#plt.plot(X, loss_array_RMSprop_general, label=\"RMSProp \",linewidth=2.5)\n",
    "\n",
    "\n",
    "plt.xticks(fontsize=26)\n",
    "plt.yticks(fontsize=26)\n",
    "plt.title(\"CIFAR10 Dataset\\n Convolutional Neural Network\",fontsize=26)\n",
    "plt.xlabel(\"500 Stochastic Update (Batch Size 4) \", fontsize=26)\n",
    "plt.ylabel(\"Cross Entropy Loss\", fontsize=26)\n",
    "plt.grid(linestyle=\"--\")\n",
    "plt.legend(fontsize=18)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"cifar10-all-algo-mod1.png\")\n",
    "files.download(\"cifar10-all-algo-mod1.png\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy curves for every optimizer compared in this notebook, saved to\n",
    "# cifar10-acc-algo-mod1.png (and downloaded to the browser when running in Colab).\n",
    "# NOTE(review): each acc_array_* is presumably one accuracy value per logged\n",
    "# checkpoint from the training cells above - confirm against those cells.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "try:\n",
    "    from google.colab import files\n",
    "except ImportError:\n",
    "    # Outside Colab there is no browser-download helper; the PNG is still saved locally.\n",
    "    files = None\n",
    "\n",
    "N_POINTS = 125  # maximum number of logged values shown per curve\n",
    "ACC_FIGURE_PATH = \"cifar10-acc-algo-mod1.png\"\n",
    "\n",
    "# (series, legend label, colour); None lets matplotlib pick from its default colour cycle.\n",
    "accuracy_curves = [\n",
    "    (acc_array_adamW, \"SPPA (1/sqrt(t))\", 'green'),\n",
    "    (acc_array_adamW_, \"SPPA (1/t)\", 'red'),\n",
    "    (acc_array_adam_general, \"Adam \", 'darkorange'),\n",
    "    (acc_array_adagrad_general, \"Adagrad \", None),\n",
    "    (acc_array_sgd_general, \"SGD with Momentum\", 'blue'),\n",
    "    (acc_array_sgd_general2, \"SGD \", 'black'),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 9))\n",
    "for series, label, color in accuracy_curves:\n",
    "    # Build x from the slice itself so a run that logged fewer than N_POINTS\n",
    "    # values still plots instead of raising on mismatched x/y lengths.\n",
    "    y = series[:N_POINTS]\n",
    "    ax.plot(range(len(y)), y, label=label, linewidth=2.5, color=color)\n",
    "\n",
    "ax.tick_params(axis='both', labelsize=26)\n",
    "ax.set_title(\"CIFAR10 Dataset\\n Convolutional Neural Network\", fontsize=26)\n",
    "ax.set_xlabel(\"500 Stochastic Update (Batch Size 4) \", fontsize=26)\n",
    "ax.set_ylabel(\"Accuracy\", fontsize=26)\n",
    "ax.grid(linestyle=\"--\")\n",
    "ax.legend(fontsize=18)\n",
    "fig.tight_layout()\n",
    "fig.savefig(ACC_FIGURE_PATH)\n",
    "if files is not None:\n",
    "    files.download(ACC_FIGURE_PATH)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Supplementary ICLR2020 - CIFAR10",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "038cd24e5cc34840a127a05d330cc710": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "29d82448c2f44ec4b8517037c7426373": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_038cd24e5cc34840a127a05d330cc710",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_4bfbcfa4bd0d497c9471196fc67cf145",
      "value": 1
     }
    },
    "3bd056c952374f3eb698a5cfb93c0052": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_29d82448c2f44ec4b8517037c7426373",
       "IPY_MODEL_499db43433dd44a186bc12a58e56fafe"
      ],
      "layout": "IPY_MODEL_542a74be255e452a90c02672ee2e9a8a"
     }
    },
    "499db43433dd44a186bc12a58e56fafe": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_953d696e8fb34b1c820e333c4c76d38a",
      "placeholder": "​",
      "style": "IPY_MODEL_810e20dcd6334c1094853fbbcf017bd7",
      "value": " 170500096/? [00:20&lt;00:00, 34241692.52it/s]"
     }
    },
    "4bfbcfa4bd0d497c9471196fc67cf145": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "542a74be255e452a90c02672ee2e9a8a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "810e20dcd6334c1094853fbbcf017bd7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "953d696e8fb34b1c820e333c4c76d38a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
