{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "KJ_Ijt20j05r",
    "outputId": "13657666-7b0c-4584-b9e9-2039c5571eb6"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "torch.manual_seed(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "SXVFwPrcj-dE",
    "outputId": "0c571c39-d0a9-4954-fa68-c226ac71bff1"
   },
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0MWbCIsbj-jw"
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import torch.optim as optim\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BtW8tupHj-l3"
   },
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "kwargs = {}\n",
    "trainloader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "batch_size=64, shuffle=True, **kwargs)\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "batch_size=64, shuffle=True, **kwargs)\n",
    "\n",
    "classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "M_-4edhGkLMj",
    "outputId": "5cf74f43-f1fa-473f-9005-0d005ceb1434"
   },
   "outputs": [],
   "source": [
    "#Training data size: \n",
    "print(len(trainloader.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "rr-adbc-j-y5",
    "outputId": "d155ef7e-e81b-462e-8fde-0b51b368dd72"
   },
   "outputs": [],
   "source": [
    "#Test data size:\n",
    "print(len(testloader.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b_a-THe0j-5Y"
   },
   "outputs": [],
   "source": [
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = nn.Dropout2d(0.25)\n",
    "        self.dropout2 = nn.Dropout2d(0.5)\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BY51oDJokl_B"
   },
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SuzL6V0MkmCF"
   },
   "outputs": [],
   "source": [
    "def trainCustomSPPA(net1, net2, device, optimizer, seed, epoch_outer, epoch_inner, acc_array, loss_inner, loss_outer, Name ,many,mode):\n",
    "    #net1 is the network for outer optimization\n",
    "    #net2 is the network for the inner optimization\n",
    "    #seed-2\n",
    "    #epoch_outer-5\n",
    "    #epoch_inner-1\n",
    "    #many-500\n",
    "    #mode ---1 lr 1/sqrt(t)\n",
    "    #mode ---2 lr 1/t\n",
    "    #mode ---other lr=other i.e. constant\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    loss_inner=[]\n",
    "    loss_outer=[]\n",
    "    acc_array=[]\n",
    "    co=0\n",
    "    for epoch in range(epoch_outer):  \n",
    "    # loop over the dataset multiple times\n",
    "        running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            # get the inputs; data is a list of [inputs, labels]\n",
    "\n",
    "            if i % many == 0:\n",
    "                correct = 0\n",
    "                total = 0\n",
    "                with torch.no_grad():\n",
    "                    for data_test in testloader:\n",
    "                        images, labels = data_test\n",
    "                        images, labels = images.cuda(), labels.cuda()\n",
    "                        outputs = net1(images)\n",
    "                        _, predicted = torch.max(outputs.data, 1)\n",
    "                        total += labels.size(0)\n",
    "                        correct += (predicted == labels).sum().item()\n",
    "\n",
    "                print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
    "                    100 * correct / total))\n",
    "                acc=(100 * correct / total)\n",
    "                acc_array.append(acc)\n",
    "            inputs, labels = data\n",
    "            inputs, labels = inputs.cuda(), labels.cuda()\n",
    "            outputs = net1(inputs)\n",
    "            loss1 = criterion(outputs, labels)\n",
    "\n",
    "            co=co+1\n",
    "            if (co==1):\n",
    "                print(\"Loss:\",float(loss1))\n",
    "                loss_outer.append(float(loss1))\n",
    "                print(\"Loss :\",loss1)\n",
    "            loss1.backward()\n",
    "\n",
    "            if i% many == 0:\n",
    "                for epoch2 in range(epoch_inner):  # loop over the dataset multiple times\n",
    "                    running_loss2 = 0.0\n",
    "                    t=0\n",
    "                    for i2, data2 in enumerate(trainloader, 0):\n",
    "                        # get the inputs; data is a list of [inputs, labels]\n",
    "                        t=t+1\n",
    "                        inputs2, labels2 = data2\n",
    "                        inputs2, labels2 = inputs2.cuda(), labels2.cuda()\n",
    "                        optimizer.zero_grad()\n",
    "                        outputs2 = net2(inputs2)\n",
    "                        loss2 = criterion(outputs2, labels2)\n",
    "                        l2=0\n",
    "                        if mode==1:\n",
    "                            lr=1.0/math.sqrt(t)\n",
    "                        elif mode==2:\n",
    "                            lr=1.0/t\n",
    "                        else:\n",
    "                            lr=mode\n",
    "                        for p2 in net2.parameters():\n",
    "                            l2 = l2 + p2.norm()\n",
    "                        loss2 = loss2 + (0.5*lr )*l2\n",
    "                        loss2.backward()\n",
    "                        def closure():\n",
    "                            # zero the parameter gradients\n",
    "                            optimizer.zero_grad()\n",
    "                            outputs2 = net2(inputs2)\n",
    "                            loss2 = criterion(outputs2, labels2)\n",
    "                            l2=0\n",
    "                            if mode==1:\n",
    "                                lr = 1.0/math.sqrt(t)\n",
    "                            elif mode==2:\n",
    "                                lr = 1.0/t\n",
    "                            else:\n",
    "                                lr = mode\n",
    "                            c1=0\n",
    "                            c2=0\n",
    "                            for n2, p2 in net2.named_parameters():\n",
    "                                c2=c2+1\n",
    "                                for n1, p1 in net1.named_parameters():\n",
    "                                    c1=c1+1\n",
    "                                    if (c1==c2):\n",
    "                                        l2=l2+(p2.data-p1.data).norm()\n",
    "                            #print(\"Loss before update \", loss2)\n",
    "                            loss2=loss2+ ( 0.5*lr ) * l2\n",
    "                            #print(\"Loss after update \",loss2)\n",
    "                            loss2.backward()\n",
    "                            return loss2\n",
    "                        optimizer.step(closure)\n",
    "                        running_loss2 += loss2.item()\n",
    "                        if i2 % many == (many-1):    # print every 1000 mini-batches\n",
    "                            loss_inner.append(running_loss2/many)\n",
    "                            if (mode==1):\n",
    "                                print('\\t '+str(Name)+' inner 1/sqrt(t) --- [%d, %5d] loss: %.3f' %\n",
    "                                      (epoch2 + 1, i2 + 1, running_loss2 / many))\n",
    "                            elif (mode==2):\n",
    "                                print('\\t '+str(Name)+' inner 1/t --- [%d, %5d] loss: %.3f' %\n",
    "                                      (epoch2 + 1, i2 + 1, running_loss2 / many))\n",
    "                            else :\n",
    "                                print('\\t '+str(Name)+' inner lr = '+str(lr)+' --- [%d, %5d] loss: %.3f' %\n",
    "                                      (epoch2 + 1, i2 + 1, running_loss2 / many))\n",
    "                            running_loss2 = 0.0\n",
    "                print(\"*************************\")\n",
    "            net1.load_state_dict(net2.state_dict())\n",
    "            \n",
    "            running_loss += loss1.item()\n",
    "            if i % many == (many-1):    # print every 2000 mini-batches\n",
    "                print('SPPA --- [%d, %5d] loss: %.3f' %\n",
    "                      (epoch + 1, i + 1, running_loss / many))\n",
    "                loss_outer.append(running_loss/many)\n",
    "                running_loss = 0.0\n",
    "\n",
    "\n",
    "    print('Finished Training')\n",
    "    return loss_inner, loss_outer, acc_array\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aUOqV3b7ibup"
   },
   "outputs": [],
   "source": [
    "def traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many):\n",
    "    #net1 is the network for optimization\n",
    "    #seed-2\n",
    "    #epoch_no-5\n",
    "    #many-500\n",
    "    torch.manual_seed(seed)\n",
    "    loss_array=[]\n",
    "    acc_array=[]\n",
    "    co=0\n",
    "    for epoch in range(epoch_no):  \n",
    "    # loop over the dataset multiple times\n",
    "        running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            # get the inputs; data is a list of [inputs, labels]\n",
    "\n",
    "                    \n",
    "            if i % many == 0:\n",
    "                correct = 0\n",
    "                total = 0\n",
    "                with torch.no_grad():\n",
    "                    for data_test in testloader:\n",
    "                        images, labels = data_test\n",
    "                        images, labels = images.cuda(), labels.cuda()\n",
    "                        outputs = net1(images)\n",
    "                        _, predicted = torch.max(outputs.data, 1)\n",
    "                        total += labels.size(0)\n",
    "                        correct += (predicted == labels).sum().item()\n",
    "\n",
    "                print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
    "                    100 * correct / total))\n",
    "                acc=(100 * correct / total)\n",
    "                acc_array.append(acc)\n",
    "\n",
    "\n",
    "            inputs, labels = data\n",
    "            inputs, labels = inputs.cuda(), labels.cuda()\n",
    "            def closure():\n",
    "              # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "                outputs = net1(inputs)\n",
    "                loss1 = criterion(outputs, labels)\n",
    "                loss1.backward()\n",
    "                return loss1\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net1(inputs)\n",
    "            loss1 = criterion(outputs, labels)\n",
    "\n",
    "            co=co+1\n",
    "            if (co==1):\n",
    "                print(\"Loss:\",float(loss1))\n",
    "                loss_array.append(float(loss1))\n",
    "                print(\"Loss :\",loss1)\n",
    "            loss1.backward()\n",
    "\n",
    "            optimizer.step(closure)\n",
    "            \n",
    "            running_loss += loss1.item()\n",
    "            if i % many == (many-1):    # print every 2000 mini-batches\n",
    "                print(Name + '--- [%d, %5d] loss: %.3f' %\n",
    "                      (epoch + 1, i + 1, running_loss / many))\n",
    "                loss_array.append(running_loss/many)\n",
    "                running_loss = 0.0\n",
    "\n",
    "\n",
    "    print('Finished Training')\n",
    "    return loss_array, acc_array\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "oM0LKyx-kmLZ",
    "outputId": "c4b75eb3-e7e8-4b32-8556-6a7e04eb67ee"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=1\n",
    "epoch_inner=1\n",
    "many=25\n",
    " \n",
    "mode=1\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adam1=Net()\n",
    "net_adam2=Net()\n",
    "\n",
    "net_adam1.to(device)\n",
    "net_adam2.to(device)\n",
    "\n",
    "optimizer_adam = optim.Adam(net_adam2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)\n",
    "loss_inner_adam=[]\n",
    "loss_outer_adam=[]\n",
    "acc_array_adam=[]\n",
    "\n",
    "loss_inner_adam,loss_outer_adam,acc_array_adam = trainCustomSPPA(net_adam1, net_adam2, device, optimizer_adam, seed, epoch_outer, epoch_inner, acc_array_adam, loss_inner_adam, loss_outer_adam, 'ADAM' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "id": "R7s3XyRPkmPG",
    "outputId": "cff73995-3e34-4e87-845f-00036260fce0"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment (SPPA with Adadelta): the code is wrapped in a bare string\n",
     "# literal, so running this cell only displays the string. Remove the quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=1\n",
     "epoch_inner=1\n",
     "many=25\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_adadelta1=Net()\n",
     "net_adadelta2=Net()\n",
     "\n",
     "net_adadelta1.to(device)\n",
     "net_adadelta2.to(device)\n",
     "\n",
     "optimizer_adadelta = optim.Adadelta(net_adadelta2.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)\n",
     "loss_inner_adadelta=[]\n",
     "loss_outer_adadelta=[]\n",
     "acc_array_adadelta=[]\n",
     "\n",
     "loss_inner_adadelta,loss_outer_adadelta,acc_array_adadelta = trainCustomSPPA(net_adadelta1, net_adadelta2, device, optimizer_adadelta, seed, epoch_outer, epoch_inner, acc_array_adadelta, loss_inner_adadelta, loss_outer_adadelta, 'Adadelta' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "id": "FFsMaW7-kmSw",
    "outputId": "271f1aba-9957-464d-d350-82ab712a869d"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment (SPPA with Adagrad): the code is wrapped in a bare string\n",
     "# literal, so running this cell only displays the string. Remove the quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=1\n",
     "epoch_inner=1\n",
     "many=25\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_adagrad1=Net()\n",
     "net_adagrad2=Net()\n",
     "\n",
     "net_adagrad1.to(device)\n",
     "net_adagrad2.to(device)\n",
     "\n",
     "optimizer_adagrad = optim.Adagrad(net_adagrad2.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)\n",
     "loss_inner_adagrad=[]\n",
     "loss_outer_adagrad=[]\n",
     "acc_array_adagrad=[]\n",
     "\n",
     "loss_inner_adagrad,loss_outer_adagrad,acc_array_adagrad = trainCustomSPPA(net_adagrad1, net_adagrad2, device, optimizer_adagrad, seed, epoch_outer, epoch_inner, acc_array_adagrad, loss_inner_adagrad, loss_outer_adagrad, 'Adagrad' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "jj590XYtkmXF",
    "outputId": "fea39c9d-fe7e-46ec-9e20-225e4162e9cf"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=1\n",
    "epoch_inner=1\n",
    "many=25\n",
    " \n",
    "mode=1\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adamW1=Net()\n",
    "net_adamW2=Net()\n",
    "\n",
    "net_adamW1.to(device)\n",
    "net_adamW2.to(device)\n",
    "\n",
    "optimizer_adamW = optim.AdamW(net_adamW2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)\n",
    "loss_inner_adamW=[]\n",
    "loss_outer_adamW=[]\n",
    "acc_array_adamW=[]\n",
    "\n",
    "loss_inner_adamW,loss_outer_adamW,acc_array_adamW = trainCustomSPPA(net_adamW1, net_adamW2, device, optimizer_adamW, seed, epoch_outer, epoch_inner, acc_array_adamW, loss_inner_adamW, loss_outer_adamW, 'AdamW' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "NUucyhJYFU-m",
    "outputId": "113843c5-5512-487c-889a-1259efa32532"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=1\n",
    "epoch_inner=1\n",
    "many=25\n",
    " \n",
    "mode=2\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adamW1_2=Net()\n",
    "net_adamW2_2=Net()\n",
    "\n",
    "net_adamW1_2.to(device)\n",
    "net_adamW2_2.to(device)\n",
    "\n",
    "optimizer_adamW_2 = optim.AdamW(net_adamW2_2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)\n",
    "loss_inner_adamW_2=[]\n",
    "loss_outer_adamW_2=[]\n",
    "acc_array_adamW_2=[]\n",
    "\n",
    "loss_inner_adamW_2,loss_outer_adamW_2,acc_array_adamW_2 = trainCustomSPPA(net_adamW1_2, net_adamW2_2, device, optimizer_adamW_2, seed, epoch_outer, epoch_inner, acc_array_adamW_2, loss_inner_adamW_2, loss_outer_adamW_2, 'AdamW2' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "id": "kSUE7zSLwSwg",
    "outputId": "ba7beef0-8b66-47fe-dd63-94837aa143d8"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment (SPPA with Adamax): the code is wrapped in a bare string\n",
     "# literal, so running this cell only displays the string. Remove the quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=1\n",
     "epoch_inner=1\n",
     "many=25\n",
     "\n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_adamax1=Net()\n",
     "net_adamax2=Net()\n",
     "\n",
     "net_adamax1.to(device)\n",
     "net_adamax2.to(device)\n",
     "\n",
     "optimizer_adamax = optim.Adamax(net_adamax2.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)\n",
     "loss_inner_adamax=[]\n",
     "loss_outer_adamax=[]\n",
     "acc_array_adamax=[]\n",
     "\n",
     "loss_inner_adamax,loss_outer_adamax,acc_array_adamax = trainCustomSPPA(net_adamax1, net_adamax2, device, optimizer_adamax, seed, epoch_outer, epoch_inner, acc_array_adamax, loss_inner_adamax, loss_outer_adamax, 'Adamax' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "id": "LlnIgu4_wS62",
    "outputId": "0eeb1f1e-1a6c-44b6-c8c5-9094fe73f76d"
   },
   "outputs": [],
   "source": [
     "# Disabled experiment (SPPA with ASGD): the code is wrapped in a bare string\n",
     "# literal, so running this cell only displays the string. Remove the quotes to re-enable it.\n",
     "'''\n",
     "seed=2\n",
     "epoch_outer=1\n",
     "epoch_inner=1\n",
     "many=25\n",
     " \n",
     "mode=1\n",
     "\n",
     "torch.manual_seed(seed)\n",
     "\n",
     "net_ASGD1=Net()\n",
     "net_ASGD2=Net()\n",
     "\n",
     "net_ASGD1.to(device)\n",
     "net_ASGD2.to(device)\n",
     "\n",
     "optimizer_ASGD = optim.ASGD(net_ASGD2.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)\n",
     "loss_inner_ASGD=[]\n",
     "loss_outer_ASGD=[]\n",
     "acc_array_ASGD=[]\n",
     "\n",
     "loss_inner_ASGD,loss_outer_ASGD,acc_array_ASGD = trainCustomSPPA(net_ASGD1, net_ASGD2, device, optimizer_ASGD, seed, epoch_outer, epoch_inner, acc_array_ASGD, loss_inner_ASGD, loss_outer_ASGD, 'ASGD' ,many,mode)\n",
     "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "s6pKgawY7KAV",
    "outputId": "442c5521-dda2-46a8-ac5e-29fd6cba71e0"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=1\n",
    "epoch_inner=1\n",
    "many=25\n",
    " \n",
    "mode=1\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_lbfgs1=Net()\n",
    "net_lbfgs2=Net()\n",
    "\n",
    "net_lbfgs1.to(device)\n",
    "net_lbfgs2.to(device)\n",
    "\n",
    "optimizer_lbfgs = optim.LBFGS(net_lbfgs2.parameters(),max_iter=2, history_size=2)\n",
    "loss_inner_lbfgs=[]\n",
    "loss_outer_lbfgs=[]\n",
    "acc_array_lbfgs=[]\n",
    "\n",
    "loss_inner_lbfgs,loss_outer_lbfgs,acc_array_lbfgs = trainCustomSPPA(net_lbfgs1, net_lbfgs2, device, optimizer_lbfgs, seed, epoch_outer, epoch_inner, acc_array_lbfgs, loss_inner_lbfgs, loss_outer_lbfgs, 'LBFGS' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "Y1OU-aloXMD_",
    "outputId": "df39e4f3-ac31-47e6-c15b-fe4a18a8edba"
   },
   "outputs": [],
   "source": [
    "seed=2\n",
    "epoch_outer=1\n",
    "epoch_inner=1\n",
    "many=25\n",
    " \n",
    "mode=1\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_lbfgs1_2=Net()\n",
    "net_lbfgs2_2=Net()\n",
    "\n",
    "net_lbfgs1_2.to(device)\n",
    "net_lbfgs2_2.to(device)\n",
    "\n",
    "optimizer_lbfgs_2 = optim.LBFGS(net_lbfgs2_2.parameters(),max_iter=3 , history_size=3)\n",
    "loss_inner_lbfgs_2=[]\n",
    "loss_outer_lbfgs_2=[]\n",
    "acc_array_lbfgs_2=[]\n",
    "\n",
    "loss_inner_lbfgs_2,loss_outer_lbfgs_2,acc_array_lbfgs_2 = trainCustomSPPA(net_lbfgs1_2, net_lbfgs2_2, device, optimizer_lbfgs_2, seed, epoch_outer, epoch_inner, acc_array_lbfgs_2, loss_inner_lbfgs_2, loss_outer_lbfgs_2, 'LBFGS' ,many,mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "oszUXekz7Mhs",
    "outputId": "3181de43-4002-40c4-8dce-d94dbba34f9a"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "seed=2\n",
    "epoch_no=1\n",
    "\n",
    "many=25\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_sgd_general2=Net()\n",
    "\n",
    "net_sgd_general2.to(device)\n",
    "\n",
    "optimizer_sgd_general2 =optim.SGD(net_sgd_general2.parameters(), lr=0.001)\n",
    "loss_array_sgd_general2=[]\n",
    "acc_array_sgd_general2=[]\n",
    "loss_array_sgd_general2, acc_array_sgd_general2 = traingeneral(net_sgd_general2, device, optimizer_sgd_general2, seed, epoch_no, acc_array_sgd_general2, loss_array_sgd_general2, 'SGD' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "TnTIc4k-7Mo4",
    "outputId": "81ebdffe-6599-435b-a021-9ec5e53bae97"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "seed=2\n",
    "epoch_no=1\n",
    "\n",
    "many=25\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_sgd_general=Net()\n",
    "\n",
    "net_sgd_general.to(device)\n",
    "\n",
    "optimizer_sgd_general =optim.SGD(net_sgd_general.parameters(), lr=0.001, momentum=0.9)\n",
    "loss_array_sgd_general=[]\n",
    "acc_array_sgd_general=[]\n",
    "loss_array_sgd_general, acc_array_sgd_general = traingeneral(net_sgd_general, device, optimizer_sgd_general, seed, epoch_no, acc_array_sgd_general, loss_array_sgd_general, 'SGD Momentum' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "GV4GR61n7Msy",
    "outputId": "555772eb-54c0-451c-cc12-a8f988b9ed0b"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "seed=2\n",
    "epoch_no=1\n",
    "\n",
    "many=25\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_ASGD_general=Net()\n",
    "\n",
    "net_ASGD_general.to(device)\n",
    "\n",
    "optimizer_ASGD_general = optim.ASGD(net_ASGD_general.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)\n",
    "loss_array_ASGD_general=[]\n",
    "acc_array_ASGD_general=[]\n",
    "loss_array_ASGD_general, acc_array_ASGD_general = traingeneral(net_ASGD_general, device, optimizer_ASGD_general, seed, epoch_no, acc_array_ASGD_general, loss_array_ASGD_general, 'ASGD' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "qTY5LLQw7M5k",
    "outputId": "24c2a498-8302-4316-a054-63f1f363b4f9"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "\n",
    "seed=2\n",
    "epoch_no=1\n",
    "\n",
    "many=25\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adagrad_general=Net()\n",
    "\n",
    "net_adagrad_general.to(device)\n",
    "\n",
    "optimizer_adagrad_general = optim.Adagrad(net_adagrad_general.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)\n",
    "loss_array_adagrad_general=[]\n",
    "acc_array_adagrad_general=[]\n",
    "loss_array_adagrad_general, acc_array_adagrad_general = traingeneral(net_adagrad_general, device, optimizer_adagrad_general, seed, epoch_no, acc_array_adagrad_general, loss_array_adagrad_general, 'Adagrad' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "3vaF_7N07sJp",
    "outputId": "d4de8954-c7c0-4e30-9391-f753a6fe127e"
   },
   "outputs": [],
   "source": [
    "#traingeneral(net1, device, optimizer, seed, epoch_no, acc_array, loss_array, Name ,many)\n",
    "\n",
    "\n",
    "seed=2\n",
    "epoch_no=1\n",
    "\n",
    "many=25\n",
    " \n",
    "torch.manual_seed(seed)\n",
    "\n",
    "net_adam_general=Net()\n",
    "\n",
    "net_adam_general.to(device)\n",
    "\n",
    "optimizer_adam_general = optim.Adam(net_adam_general.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)\n",
    "loss_array_adam_general=[]\n",
    "acc_array_adam_general=[]\n",
    "loss_array_adam_general, acc_array_adam_general = traingeneral(net_adam_general, device, optimizer_adam_general, seed, epoch_no, acc_array_adam_general, loss_array_adam_general, 'ADAM' ,many)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 657
    },
    "id": "X5LQ0eaodaoX",
    "outputId": "6180d05f-81d9-48b9-b523-1eca9240d5ce"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from google.colab import files\n",
    "X=[]\n",
    "for i in range(len(loss_outer_adam)):\n",
    "    X.append(i)\n",
    "\n",
    "plt.figure(figsize=((12,9)))\n",
    "\n",
    "#plt.plot(X, loss_outer_adadelta, label=\"SPPA-Adadelta (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_adagrad, label=\"SPPA-Adagrad (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_adamax, label=\"SPPA-Adamax (1/sqrt(t))\",linewidth=2.5)\n",
    "plt.plot(X, loss_outer_adam, label=\"SPPA-AGD (1/sqrt(t))\",linewidth=2.5,color='darkorchid')\n",
    "#plt.plot(X, loss_outer_adamW, label=\"SPPA-AdamW (1/sqrt(t))\",linewidth=2.5)\n",
    "\n",
    "plt.plot(X, loss_outer_adamW_2, label=\"SPPA-AGD (1/t)\",linewidth=2.5,color='red')\n",
    "#plt.plot(X, loss_outer_ASGD, label=\"SPPA-Asgd (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_RMSprop, label=\"SPPA-RMSProp (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, loss_outer_lbfgs, label=\"SPPA-LBFGS (1/sqrt(t))\",linewidth=2.5)\n",
    "\n",
    "plt.plot(X, loss_outer_lbfgs_2, label=\"SPPA-LBFGS (1/sqrt(t))\",linewidth=2.5, color='green')\n",
    "plt.plot(X, loss_array_adam_general, label=\"Adam \",linewidth=2.5,color='darkorange')\n",
    "plt.plot(X, loss_array_adagrad_general, label=\"Adagrad \",linewidth=2.5)\n",
    "\n",
    "plt.plot(X, loss_array_sgd_general, label=\"SGD with Momentum \",linewidth=2.5,color='blue')\n",
    "\n",
    "plt.plot(X, loss_array_sgd_general2, label=\"SGD\",linewidth=2.5,color='black')\n",
    "\n",
    "\n",
    "plt.xticks(fontsize=26)\n",
    "plt.yticks(fontsize=26)\n",
    "\n",
    "plt.title(\"MNIST Dataset\\n Convolutional Neural Network\",fontsize=26)\n",
    "plt.xlabel(\"25 Stochastic Update (Batch Size 64) \", fontsize=26)\n",
    "plt.ylabel(\"Cross Entropy Loss\", fontsize=26)\n",
    "plt.grid(linestyle=\"--\")\n",
    "plt.legend(fontsize=16)\n",
    "plt.tight_layout()\n",
    "#plt.savefig(\"Adam-SPPAwLBFGSinner_decreasing.png\")\n",
    "plt.savefig(\"mnist-all-algo-mod1.png\")\n",
    "files.download(\"mnist-all-algo-mod1.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 657
    },
    "id": "rdDlKampwx_W",
    "outputId": "46e698e9-10a7-4245-a579-3e17741129b3"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from google.colab import files\n",
    "\n",
    "X=[]\n",
    "for i in range(len(acc_array_adam)):\n",
    "    X.append(i)\n",
    "\n",
    "plt.figure(figsize=((12,9)))\n",
    "\n",
    "\n",
    "#plt.plot(X, acc_array_adadelta, label=\"SPPA-Adadelta (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, acc_array_adagrad, label=\"SPPA-Adagrad (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, acc_array_adamax, label=\"SPPA-Adamax (1/sqrt(t))\",linewidth=2.5)\n",
    "plt.plot(X, acc_array_adam, label=\"SPPA-AGD (1/sqrt(t))\",linewidth=2.5,color='darkorchid')\n",
    "plt.plot(X, acc_array_adamW_2, label=\"SPPA-AGD (1/t)\",linewidth=2.5, color='red')\n",
    "#plt.plot(X, acc_array_ASGD, label=\"SPPA-Asgd (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, acc_array_RMSprop, label=\"SPPA-RMSProp (1/sqrt(t))\",linewidth=2.5)\n",
    "#plt.plot(X, acc_array_lbfgs, label=\"SPPA-LBFGS (1/sqrt(t))\",linewidth=2.5, color='green')\n",
    "plt.plot(X, acc_array_lbfgs_2, label=\"SPPA-LBFGS (1/sqrt(t))\",linewidth=2.5, color='green')\n",
    "plt.plot(X, acc_array_adam_general, label=\"Adam \",linewidth=2.5,color='darkorange')\n",
    "#plt.plot(X, acc_array_adagrad_general, label=\"Adagrad \",linewidth=2.5)\n",
    "\n",
    "plt.plot(X, acc_array_sgd_general, label=\"SGD with Momentum \",linewidth=2.5,color='blue')\n",
    "\n",
    "plt.plot(X, acc_array_sgd_general2, label=\"SGD\",linewidth=2.5,color='black')\n",
    "\n",
    "\n",
    "plt.xticks(fontsize=26)\n",
    "plt.yticks(fontsize=26)\n",
    "\n",
    "plt.title(\"MNIST Dataset\\n Convolutional Neural Network\",fontsize=26)\n",
    "plt.xlabel(\"25 Stochastic Update (Batch Size 64) \", fontsize=26)\n",
    "plt.ylabel(\"Accuracy \", fontsize=26)\n",
    "plt.grid(linestyle=\"--\")\n",
    "plt.legend(fontsize=16)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"mnist-allaccuracy-algo-mod1.png\")\n",
    "\n",
    "files.download(\"mnist-allaccuracy-algo-mod1.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gzeFHANbUQHG"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Supplementary ICLR2021 - MNIST",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
