{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dfd7fa2e",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "# Loss-Controlled Asymmetric Momentum (LCAM)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ea08a5c",
   "metadata": {},
   "source": [
    "# CIFAR-10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b1abe94",
   "metadata": {},
   "source": [
    "# Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f1c4b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from torchvision import models\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ---- Data preprocessing ----\n",
    "# Per-channel mean/std normalization shared by the training and test pipelines.\n",
    "# (Alternative kept from the original notebook: mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5).)\n",
    "normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "\n",
    "# Training pipeline: random 32x32 crop with 4px padding and random horizontal flip,\n",
    "# followed by tensor conversion and normalization.\n",
    "train_steps = [\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]\n",
    "transform_train = transforms.Compose(train_steps)\n",
    "\n",
    "# Test pipeline: no augmentation, only tensor conversion and normalization.\n",
    "transform_test = transforms.Compose([transforms.ToTensor(), normalize])\n",
    "\n",
    "# ---- Datasets ----\n",
    "# Both splits are stored under the current directory and downloaded when requested.\n",
    "dataset_options = dict(root='./', download=True)\n",
    "train_set = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **dataset_options)\n",
    "test_set = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **dataset_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "325b7ca9",
   "metadata": {},
   "source": [
    "# Wide Residual Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b7ad9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicBlock(nn.Module):\n",
    "    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):\n",
    "        super(BasicBlock, self).__init__()\n",
    "        self.bn1 = nn.BatchNorm2d(in_planes)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
    "                               padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(out_planes)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,\n",
    "                               padding=1, bias=False)\n",
    "        self.droprate = dropRate\n",
    "        self.equalInOut = (in_planes == out_planes)\n",
    "        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,\n",
    "                               padding=0, bias=False) or None\n",
    "    def forward(self, x):\n",
    "        if not self.equalInOut:\n",
    "            x = self.relu1(self.bn1(x))\n",
    "        else:\n",
    "            out = self.relu1(self.bn1(x))\n",
    "        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))\n",
    "        if self.droprate > 0:\n",
    "            out = F.dropout(out, p=self.droprate, training=self.training)\n",
    "        out = self.conv2(out)\n",
    "        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n",
    "\n",
    "class NetworkBlock(nn.Module):\n",
    "    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):\n",
    "        super(NetworkBlock, self).__init__()\n",
    "        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)\n",
    "    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):\n",
    "        layers = []\n",
    "        for i in range(int(nb_layers)):\n",
    "            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))\n",
    "        return nn.Sequential(*layers)\n",
    "    def forward(self, x):\n",
    "        return self.layer(x)\n",
    "\n",
    "class WideResNet(nn.Module):\n",
    "    \"\"\"Wide residual network: 3x3 stem, three NetworkBlock stages, BN-ReLU,\n",
    "    global average pooling and a linear classifier.\n",
    "\n",
    "    Args:\n",
    "        depth: total depth; must satisfy ``(depth - 4) % 6 == 0``.\n",
    "        num_classes: number of classifier outputs.\n",
    "        widen_factor: multiplier applied to the stage widths 16, 32 and 64.\n",
    "        dropRate: dropout probability forwarded to every block.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if ``depth`` is not of the form 6n + 4.\n",
    "    \"\"\"\n",
    "    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):\n",
    "        super(WideResNet, self).__init__()\n",
    "        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]\n",
    "        assert (depth - 4) % 6 == 0, \"depth must be of the form 6n + 4\"\n",
    "        # Integer division: the number of blocks per stage is a count, not a float.\n",
    "        n = (depth - 4) // 6\n",
    "        block = BasicBlock\n",
    "        # 1st conv before any network block\n",
    "        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,\n",
    "                               padding=1, bias=False)\n",
    "        # 1st block\n",
    "        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)\n",
    "        # 2nd block\n",
    "        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)\n",
    "        # 3rd block\n",
    "        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)\n",
    "        # global average pooling and classifier\n",
    "        self.bn1 = nn.BatchNorm2d(nChannels[3])\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.fc = nn.Linear(nChannels[3], num_classes)\n",
    "        self.nChannels = nChannels[3]\n",
    "\n",
    "        # Weight initialization: Kaiming-normal (fan_out) for convolutions, unit scale and\n",
    "        # zero shift for batch norm, zero bias for the classifier (its weight keeps the\n",
    "        # nn.Linear default initialization).\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                m.weight.data.fill_(1)\n",
    "                m.bias.data.zero_()\n",
    "            elif isinstance(m, nn.Linear):\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class logits for a batch of images ``x``.\"\"\"\n",
    "        out = self.conv1(x)\n",
    "        out = self.block1(out)\n",
    "        out = self.block2(out)\n",
    "        out = self.block3(out)\n",
    "        out = self.relu(self.bn1(out))\n",
    "        # Global average pooling to 1x1 for any input resolution. A fixed 8x8 window is\n",
    "        # only global for 32x32 inputs, and combined with view(-1, C) it would fold the\n",
    "        # leftover spatial positions into the batch dimension for other sizes.\n",
    "        out = F.adaptive_avg_pool2d(out, 1)\n",
    "        out = torch.flatten(out, 1)\n",
    "        return self.fc(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15609ebf",
   "metadata": {},
   "source": [
    "# 01 Figure 3 - Group 1 - Fixed0.9 - Black Line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aafc067",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Experiment: fixed momentum 0.9 with a step learning-rate schedule.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so the test set is iterated without shuffling.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "learningratevalue = 0.1\n",
    "weightdecayvalue = 5e-4\n",
    "momentumvalue = 0.9\n",
    "\n",
    "num_epochs = 150\n",
    "iteration = 0\n",
    "# Step schedule: after the epoch with this 0-based index has finished, switch to the given rate.\n",
    "lr_schedule = {30: 0.02, 60: 0.004, 90: 0.0008}\n",
    "writer = SummaryWriter(\"./runs/Cifar10-0.1-m0.9-*0.2_30_60_90\")\n",
    "\n",
    "model = WideResNet(28, 10, 10, dropRate=0.0)\n",
    "model = model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr = learningratevalue, momentum = momentumvalue, weight_decay = weightdecayvalue)#, nesterov = True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "try:\n",
    "    for epoch in range(num_epochs):\n",
    "        # ---- Training pass ----\n",
    "        running_loss = 0.0\n",
    "        lossaverage = 0\n",
    "        model.train()\n",
    "        for i, (inputs, labels) in enumerate(train_loader):\n",
    "            print(\"\\r                                                                                             \",end=\"\")\n",
    "            print(\"\\rTraining Epoch {} ... {:.2f}% ... Current Learning Rate: {}\".format(epoch + 1, 100*i/len(train_loader), learningratevalue),end=\"\")\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Read the scalar loss once per batch.\n",
    "            batch_loss = loss.item()\n",
    "            running_loss += batch_loss\n",
    "            lossaverage = running_loss / (i + 1)\n",
    "            iteration = iteration + 1\n",
    "            # Fixed-momentum baseline: the LCAM momentum switch is intentionally not applied here.\n",
    "\n",
    "        # ---- Evaluation pass ----\n",
    "        model.eval()\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        with torch.no_grad():\n",
    "            for i, (images, labels) in enumerate(test_loader):\n",
    "                print(\"\\r                                                                                             \",end=\"\")\n",
    "                print(\"\\rTesting Epoch {} .... {:.2f}%\".format(epoch + 1, 100*i/len(test_loader)),end=\"\")\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "        test_acc = 100 * correct / total\n",
    "\n",
    "        # ---- Logging (values are those used during this epoch) ----\n",
    "        writer.add_scalar('Learning Rate / Epoch', learningratevalue, epoch + 1)\n",
    "        writer.add_scalar('Weight Decay / Epoch', weightdecayvalue, epoch + 1)\n",
    "        writer.add_scalar('TestError / Epoch', 100 - test_acc, epoch + 1)\n",
    "        writer.add_scalar('TrainLoss / Epoch', lossaverage, epoch + 1)\n",
    "\n",
    "        print('\\rEpoch [{}/{}] Iteration [{}] Test Error: {:.2f}% Loss: {:.6f} Learning Rate: {}'.format(epoch + 1, num_epochs, iteration, 100 - test_acc, lossaverage, learningratevalue))\n",
    "\n",
    "        # Apply the step schedule for the following epochs.\n",
    "        if epoch in lr_schedule:\n",
    "            learningratevalue = lr_schedule[epoch]\n",
    "            optimizer.param_groups[0]['lr'] = learningratevalue\n",
    "finally:\n",
    "    # Flush and release the TensorBoard event file even if training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41a308d7",
   "metadata": {},
   "source": [
    "# 02 Figure 3 - Group 2 - Fixed0.95 - Blue Line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1840ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: fixed momentum 0.95 with a step learning-rate schedule.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so the test set is iterated without shuffling.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "learningratevalue = 0.1\n",
    "weightdecayvalue = 5e-4\n",
    "momentumvalue = 0.95\n",
    "\n",
    "num_epochs = 150\n",
    "iteration = 0\n",
    "# Step schedule: after the epoch with this 0-based index has finished, switch to the given rate.\n",
    "lr_schedule = {30: 0.02, 60: 0.004, 90: 0.0008}\n",
    "writer = SummaryWriter(\"./runs/Cifar10-0.1-m0.95-*0.2_30_60_90\")\n",
    "\n",
    "model = WideResNet(28, 10, 10, dropRate=0.0)\n",
    "model = model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr = learningratevalue, momentum = momentumvalue, weight_decay = weightdecayvalue)#, nesterov = True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "try:\n",
    "    for epoch in range(num_epochs):\n",
    "        # ---- Training pass ----\n",
    "        running_loss = 0.0\n",
    "        lossaverage = 0\n",
    "        model.train()\n",
    "        for i, (inputs, labels) in enumerate(train_loader):\n",
    "            print(\"\\r                                                                                             \",end=\"\")\n",
    "            print(\"\\rTraining Epoch {} ... {:.2f}% ... Current Learning Rate: {}\".format(epoch + 1, 100*i/len(train_loader), learningratevalue),end=\"\")\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Read the scalar loss once per batch.\n",
    "            batch_loss = loss.item()\n",
    "            running_loss += batch_loss\n",
    "            lossaverage = running_loss / (i + 1)\n",
    "            iteration = iteration + 1\n",
    "            # Fixed-momentum baseline: the LCAM momentum switch is intentionally not applied here.\n",
    "\n",
    "        # ---- Evaluation pass ----\n",
    "        model.eval()\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        with torch.no_grad():\n",
    "            for i, (images, labels) in enumerate(test_loader):\n",
    "                print(\"\\r                                                                                             \",end=\"\")\n",
    "                print(\"\\rTesting Epoch {} .... {:.2f}%\".format(epoch + 1, 100*i/len(test_loader)),end=\"\")\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "        test_acc = 100 * correct / total\n",
    "\n",
    "        # ---- Logging (values are those used during this epoch) ----\n",
    "        writer.add_scalar('Learning Rate / Epoch', learningratevalue, epoch + 1)\n",
    "        writer.add_scalar('Weight Decay / Epoch', weightdecayvalue, epoch + 1)\n",
    "        writer.add_scalar('TestError / Epoch', 100 - test_acc, epoch + 1)\n",
    "        writer.add_scalar('TrainLoss / Epoch', lossaverage, epoch + 1)\n",
    "\n",
    "        print('\\rEpoch [{}/{}] Iteration [{}] Test Error: {:.2f}% Loss: {:.6f} Learning Rate: {}'.format(epoch + 1, num_epochs, iteration, 100 - test_acc, lossaverage, learningratevalue))\n",
    "\n",
    "        # Apply the step schedule for the following epochs.\n",
    "        if epoch in lr_schedule:\n",
    "            learningratevalue = lr_schedule[epoch]\n",
    "            optimizer.param_groups[0]['lr'] = learningratevalue\n",
    "finally:\n",
    "    # Flush and release the TensorBoard event file even if training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26023b7f",
   "metadata": {},
   "source": [
    "# 03 Figure 3 - Group 3 - 0.95_0.9 - Green Line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f79ab6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: LCAM switching momentum between 0.95 and 0.9, with a step learning-rate schedule.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so the test set is iterated without shuffling.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "learningratevalue = 0.1\n",
    "weightdecayvalue = 5e-4\n",
    "momentumvalue = 0.9\n",
    "\n",
    "num_epochs = 150\n",
    "iteration = 0\n",
    "# Step schedule: after the epoch with this 0-based index has finished, switch to the given rate.\n",
    "lr_schedule = {30: 0.02, 60: 0.004, 90: 0.0008}\n",
    "writer = SummaryWriter(\"./runs/Cifar10-0.1-m0.95_0.9-*0.2_30_60_90\")\n",
    "\n",
    "model = WideResNet(28, 10, 10, dropRate=0.0)\n",
    "model = model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr = learningratevalue, momentum = momentumvalue, weight_decay = weightdecayvalue)#, nesterov = True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "try:\n",
    "    for epoch in range(num_epochs):\n",
    "        # ---- Training pass ----\n",
    "        running_loss = 0.0\n",
    "        lossaverage = 0\n",
    "        model.train()\n",
    "        for i, (inputs, labels) in enumerate(train_loader):\n",
    "            print(\"\\r                                                                                             \",end=\"\")\n",
    "            print(\"\\rTraining Epoch {} ... {:.2f}% ... Current Learning Rate: {}\".format(epoch + 1, 100*i/len(train_loader), learningratevalue),end=\"\")\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Read the scalar loss once per batch.\n",
    "            batch_loss = loss.item()\n",
    "            running_loss += batch_loss\n",
    "            lossaverage = running_loss / (i + 1)\n",
    "            iteration = iteration + 1\n",
    "\n",
    "            ###LCAM###\n",
    "            # Choose the momentum for the next step by comparing this batch's loss with the\n",
    "            # running average of the current epoch (the average already includes this batch).\n",
    "            if batch_loss < lossaverage:\n",
    "                optimizer.param_groups[0]['momentum'] = 0.95 #Right Sparse\n",
    "            else:\n",
    "                optimizer.param_groups[0]['momentum'] = 0.9 #Left Non-Sparse\n",
    "            ###LCAM###\n",
    "\n",
    "        # ---- Evaluation pass ----\n",
    "        model.eval()\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        with torch.no_grad():\n",
    "            for i, (images, labels) in enumerate(test_loader):\n",
    "                print(\"\\r                                                                                             \",end=\"\")\n",
    "                print(\"\\rTesting Epoch {} .... {:.2f}%\".format(epoch + 1, 100*i/len(test_loader)),end=\"\")\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "        test_acc = 100 * correct / total\n",
    "\n",
    "        # ---- Logging (values are those used during this epoch) ----\n",
    "        writer.add_scalar('Learning Rate / Epoch', learningratevalue, epoch + 1)\n",
    "        writer.add_scalar('Weight Decay / Epoch', weightdecayvalue, epoch + 1)\n",
    "        writer.add_scalar('TestError / Epoch', 100 - test_acc, epoch + 1)\n",
    "        writer.add_scalar('TrainLoss / Epoch', lossaverage, epoch + 1)\n",
    "\n",
    "        print('\\rEpoch [{}/{}] Iteration [{}] Test Error: {:.2f}% Loss: {:.6f} Learning Rate: {}'.format(epoch + 1, num_epochs, iteration, 100 - test_acc, lossaverage, learningratevalue))\n",
    "\n",
    "        # Apply the step schedule for the following epochs.\n",
    "        if epoch in lr_schedule:\n",
    "            learningratevalue = lr_schedule[epoch]\n",
    "            optimizer.param_groups[0]['lr'] = learningratevalue\n",
    "finally:\n",
    "    # Flush and release the TensorBoard event file even if training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6ef45b6",
   "metadata": {},
   "source": [
    "# 04 Figure 3 - Group 4 - 0.9_0.95 - Red Line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cf2537c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: LCAM switching momentum between 0.9 and 0.95, with a step learning-rate schedule.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so the test set is iterated without shuffling.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "learningratevalue = 0.1\n",
    "weightdecayvalue = 5e-4\n",
    "momentumvalue = 0.9\n",
    "\n",
    "num_epochs = 150\n",
    "iteration = 0\n",
    "# Step schedule: after the epoch with this 0-based index has finished, switch to the given rate.\n",
    "lr_schedule = {30: 0.02, 60: 0.004, 90: 0.0008}\n",
    "writer = SummaryWriter(\"./runs/Cifar10-0.1-m0.9_0.95-*0.2_30_60_90\")\n",
    "\n",
    "model = WideResNet(28, 10, 10, dropRate=0.0)\n",
    "model = model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr = learningratevalue, momentum = momentumvalue, weight_decay = weightdecayvalue)#, nesterov = True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "try:\n",
    "    for epoch in range(num_epochs):\n",
    "        # ---- Training pass ----\n",
    "        running_loss = 0.0\n",
    "        lossaverage = 0\n",
    "        model.train()\n",
    "        for i, (inputs, labels) in enumerate(train_loader):\n",
    "            print(\"\\r                                                                                             \",end=\"\")\n",
    "            print(\"\\rTraining Epoch {} ... {:.2f}% ... Current Learning Rate: {}\".format(epoch + 1, 100*i/len(train_loader), learningratevalue),end=\"\")\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Read the scalar loss once per batch.\n",
    "            batch_loss = loss.item()\n",
    "            running_loss += batch_loss\n",
    "            lossaverage = running_loss / (i + 1)\n",
    "            iteration = iteration + 1\n",
    "\n",
    "            ###LCAM###\n",
    "            # Choose the momentum for the next step by comparing this batch's loss with the\n",
    "            # running average of the current epoch (the average already includes this batch).\n",
    "            if batch_loss < lossaverage:\n",
    "                optimizer.param_groups[0]['momentum'] = 0.9 #Right Sparse\n",
    "            else:\n",
    "                optimizer.param_groups[0]['momentum'] = 0.95 #Left Non-Sparse\n",
    "            ###LCAM###\n",
    "\n",
    "        # ---- Evaluation pass ----\n",
    "        model.eval()\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        with torch.no_grad():\n",
    "            for i, (images, labels) in enumerate(test_loader):\n",
    "                print(\"\\r                                                                                             \",end=\"\")\n",
    "                print(\"\\rTesting Epoch {} .... {:.2f}%\".format(epoch + 1, 100*i/len(test_loader)),end=\"\")\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "        test_acc = 100 * correct / total\n",
    "\n",
    "        # ---- Logging (values are those used during this epoch) ----\n",
    "        writer.add_scalar('Learning Rate / Epoch', learningratevalue, epoch + 1)\n",
    "        writer.add_scalar('Weight Decay / Epoch', weightdecayvalue, epoch + 1)\n",
    "        writer.add_scalar('TestError / Epoch', 100 - test_acc, epoch + 1)\n",
    "        writer.add_scalar('TrainLoss / Epoch', lossaverage, epoch + 1)\n",
    "\n",
    "        print('\\rEpoch [{}/{}] Iteration [{}] Test Error: {:.2f}% Loss: {:.6f} Learning Rate: {}'.format(epoch + 1, num_epochs, iteration, 100 - test_acc, lossaverage, learningratevalue))\n",
    "\n",
    "        # Apply the step schedule for the following epochs.\n",
    "        if epoch in lr_schedule:\n",
    "            learningratevalue = lr_schedule[epoch]\n",
    "            optimizer.param_groups[0]['lr'] = learningratevalue\n",
    "finally:\n",
    "    # Flush and release the TensorBoard event file even if training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cadd9a2c",
   "metadata": {},
   "source": [
    "# 05 Figure 5 - Group 1 - 0.9 - Black Line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a5c9db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: fixed momentum 0.9 with a per-iteration exponential learning-rate decay.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so the test set is iterated without shuffling.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "learningratevalue = 0.1\n",
    "weightdecayvalue = 5e-4\n",
    "momentumvalue = 0.9\n",
    "\n",
    "num_epochs = 150\n",
    "iteration = 0\n",
    "# The learning rate is multiplied by lr_decay before every training iteration of the\n",
    "# epochs whose 0-based index is greater than decay_after_epoch.\n",
    "lr_decay = 0.99985 #0.99993441\n",
    "decay_after_epoch = 30\n",
    "writer = SummaryWriter(\"./runs/Cifar10-0.1-m0.9-30*0.99985\")\n",
    "\n",
    "model = WideResNet(28, 10, 10, dropRate=0.0)\n",
    "model = model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr = learningratevalue, momentum = momentumvalue, weight_decay = weightdecayvalue)#, nesterov = True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "try:\n",
    "    for epoch in range(num_epochs):\n",
    "        # ---- Training pass ----\n",
    "        running_loss = 0.0\n",
    "        lossaverage = 0\n",
    "        model.train()\n",
    "        for i, (inputs, labels) in enumerate(train_loader):\n",
    "            print(\"\\r                                                                                             \",end=\"\")\n",
    "            if epoch > decay_after_epoch:\n",
    "                learningratevalue = lr_decay * learningratevalue\n",
    "                optimizer.param_groups[0]['lr'] = learningratevalue\n",
    "            print(\"\\rTraining Epoch {} ... {:.2f}% ... Current Learning Rate: {}\".format(epoch + 1, 100*i/len(train_loader), learningratevalue),end=\"\")\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Read the scalar loss once per batch.\n",
    "            batch_loss = loss.item()\n",
    "            running_loss += batch_loss\n",
    "            lossaverage = running_loss / (i + 1)\n",
    "            iteration = iteration + 1\n",
    "            # Fixed-momentum baseline: the LCAM momentum switch is intentionally not applied here.\n",
    "\n",
    "        # ---- Evaluation pass ----\n",
    "        model.eval()\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        with torch.no_grad():\n",
    "            for i, (images, labels) in enumerate(test_loader):\n",
    "                print(\"\\r                                                                                             \",end=\"\")\n",
    "                print(\"\\rTesting Epoch {} .... {:.2f}%\".format(epoch + 1, 100*i/len(test_loader)),end=\"\")\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "        test_acc = 100 * correct / total\n",
    "\n",
    "        # ---- Logging (learning rate is the value reached at the end of this epoch) ----\n",
    "        writer.add_scalar('Learning Rate / Epoch', learningratevalue, epoch + 1)\n",
    "        writer.add_scalar('Weight Decay / Epoch', weightdecayvalue, epoch + 1)\n",
    "        writer.add_scalar('TestError / Epoch', 100 - test_acc, epoch + 1)\n",
    "        writer.add_scalar('TrainLoss / Epoch', lossaverage, epoch + 1)\n",
    "\n",
    "        print('\\rEpoch [{}/{}] Iteration [{}] Test Error: {:.2f}% Loss: {:.6f} Learning Rate: {}'.format(epoch + 1, num_epochs, iteration, 100 - test_acc, lossaverage, learningratevalue))\n",
    "finally:\n",
    "    # Flush and release the TensorBoard event file even if training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0006446",
   "metadata": {},
   "source": [
    "# 06 Figure 5 - Group 2 - 0.9_0.95 - Green Line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc3181a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: LCAM switching momentum between 0.9 and 0.95, with a per-iteration\n",
    "# exponential learning-rate decay.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so the test set is iterated without shuffling.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "learningratevalue = 0.1\n",
    "weightdecayvalue = 5e-4\n",
    "momentumvalue = 0.9\n",
    "\n",
    "num_epochs = 150\n",
    "iteration = 0\n",
    "# The learning rate is multiplied by lr_decay before every training iteration of the\n",
    "# epochs whose 0-based index is greater than decay_after_epoch.\n",
    "lr_decay = 0.99985 #0.99993441\n",
    "decay_after_epoch = 30\n",
    "writer = SummaryWriter(\"./runs/Cifar10-0.1-m0.9_0.95-30*0.99985\")\n",
    "\n",
    "model = WideResNet(28, 10, 10, dropRate=0.0)\n",
    "model = model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr = learningratevalue, momentum = momentumvalue, weight_decay = weightdecayvalue)#, nesterov = True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "try:\n",
    "    for epoch in range(num_epochs):\n",
    "        # ---- Training pass ----\n",
    "        running_loss = 0.0\n",
    "        lossaverage = 0\n",
    "        model.train()\n",
    "        for i, (inputs, labels) in enumerate(train_loader):\n",
    "            print(\"\\r                                                                                             \",end=\"\")\n",
    "            if epoch > decay_after_epoch:\n",
    "                learningratevalue = lr_decay * learningratevalue\n",
    "                optimizer.param_groups[0]['lr'] = learningratevalue\n",
    "            print(\"\\rTraining Epoch {} ... {:.2f}% ... Current Learning Rate: {}\".format(epoch + 1, 100*i/len(train_loader), learningratevalue),end=\"\")\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Read the scalar loss once per batch.\n",
    "            batch_loss = loss.item()\n",
    "            running_loss += batch_loss\n",
    "            lossaverage = running_loss / (i + 1)\n",
    "            iteration = iteration + 1\n",
    "\n",
    "            ###LCAM###\n",
    "            # Choose the momentum for the next step by comparing this batch's loss with the\n",
    "            # running average of the current epoch (the average already includes this batch).\n",
    "            if batch_loss < lossaverage:\n",
    "                optimizer.param_groups[0]['momentum'] = 0.9 #Right Sparse\n",
    "            else:\n",
    "                optimizer.param_groups[0]['momentum'] = 0.95 #Left Non-Sparse\n",
    "            ###LCAM###\n",
    "\n",
    "        # ---- Evaluation pass ----\n",
    "        model.eval()\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        with torch.no_grad():\n",
    "            for i, (images, labels) in enumerate(test_loader):\n",
    "                print(\"\\r                                                                                             \",end=\"\")\n",
    "                print(\"\\rTesting Epoch {} .... {:.2f}%\".format(epoch + 1, 100*i/len(test_loader)),end=\"\")\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "        test_acc = 100 * correct / total\n",
    "\n",
    "        # ---- Logging (learning rate is the value reached at the end of this epoch) ----\n",
    "        writer.add_scalar('Learning Rate / Epoch', learningratevalue, epoch + 1)\n",
    "        writer.add_scalar('Weight Decay / Epoch', weightdecayvalue, epoch + 1)\n",
    "        writer.add_scalar('TestError / Epoch', 100 - test_acc, epoch + 1)\n",
    "        writer.add_scalar('TrainLoss / Epoch', lossaverage, epoch + 1)\n",
    "\n",
    "        print('\\rEpoch [{}/{}] Iteration [{}] Test Error: {:.2f}% Loss: {:.6f} Learning Rate: {}'.format(epoch + 1, num_epochs, iteration, 100 - test_acc, lossaverage, learningratevalue))\n",
    "finally:\n",
    "    # Flush and release the TensorBoard event file even if training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e57c62ef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
