{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Default to term_width =  118\n"
     ]
    }
   ],
   "source": [
    "'''Train CIFAR10 with PyTorch.'''\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import os\n",
    "import argparse\n",
    "\n",
    "from models import *\n",
    "from utils import progress_bar\n",
    "\n",
    "import pandas as pd \n",
    "from plotnine import * \n",
    "import datetime as dt\n",
    "import pytz\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==> Preparing data..\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "==> Building model..\n",
      "\n",
      "Epoch: 0\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 64, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 256, 32, 32])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 128, 32, 32])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 128, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 512, 16, 16])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 16, 16])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 256, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 1024, 8, 8])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 512, 8, 8])\n",
      "torch.Size([128, 512, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 512, 4, 4])\n",
      "torch.Size([128, 512, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n",
      "torch.Size([128, 512, 4, 4])\n",
      "torch.Size([128, 512, 4, 4])\n",
      "torch.Size([128, 2048, 4, 4])\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Configuration: stand-ins for the --lr / --resume options of the original CLI script.\n",
    "args_resume = False  # load ./checkpoint/ckpt2.pth before training\n",
    "args_lr = 1e-2  # initial SGD learning rate\n",
    "seed = 0  # seeds torch's default RNG; GPU runs are still not bit-exact while cudnn.benchmark is on\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "os.makedirs('csv', exist_ok=True)\n",
    "\n",
    "# Timestamp in US Eastern time, formatted like 01312024_0945PM, for tagging output files.\n",
    "est_timezone = pytz.timezone('America/New_York')\n",
    "est_now = dt.datetime.now(est_timezone)\n",
    "dt_str = est_now.strftime('%m%d%Y_%I%M%p')\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "best_acc = 0  # best test accuracy seen so far (percent)\n",
    "start_epoch = 0  # first epoch to run; moved past the checkpointed epoch when resuming\n",
    "\n",
    "# Data\n",
    "print('==> Preparing data..')\n",
    "# Per-channel normalization shared by the train and test pipelines.\n",
    "# NOTE(review): this std triple is the one commonly copied from pytorch-cifar; verify it\n",
    "# against the dataset statistics if exact normalization matters.\n",
    "cifar10_mean = (0.4914, 0.4822, 0.4465)\n",
    "cifar10_std = (0.2023, 0.1994, 0.2010)\n",
    "\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(cifar10_mean, cifar10_std),\n",
    "])\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(cifar10_mean, cifar10_std),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=True, download=True, transform=transform_train)\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=128, shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=False, download=True, transform=transform_test)\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=100, shuffle=False, num_workers=2)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer',\n",
    "           'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "# Model: keep exactly one constructor uncommented.\n",
    "print('==> Building model..')\n",
    "# net = VGG('VGG19')\n",
    "# net = ResNet18()\n",
    "\n",
    "# net = PreActResNet18()\n",
    "# net = BalancedPreActResNet18()\n",
    "net = PreActResNet101()\n",
    "# net = BalancedPreActResNet101()\n",
    "\n",
    "# net = GoogLeNet()\n",
    "# net = DenseNet121()\n",
    "# net = ResNeXt29_2x64d()\n",
    "# net = MobileNet()\n",
    "# net = MobileNetV2()\n",
    "# net = DPN92()\n",
    "# net = ShuffleNetG2()\n",
    "# net = SENet18()\n",
    "# net = ShuffleNetV2(1)\n",
    "# net = EfficientNetB0()\n",
    "# net = RegNetX_200MF()\n",
    "# net = SimpleDLA()\n",
    "net = net.to(device)\n",
    "if device == 'cuda':\n",
    "    net = torch.nn.DataParallel(net)\n",
    "    cudnn.benchmark = True\n",
    "\n",
    "if args_resume:\n",
    "    # Load checkpoint.\n",
    "    print('==> Resuming from checkpoint..')\n",
    "    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'\n",
    "    # map_location puts the saved tensors on the current device rather than the one\n",
    "    # they were saved from.\n",
    "    checkpoint = torch.load('./checkpoint/ckpt2.pth', map_location=device)\n",
    "    net.load_state_dict(checkpoint['net'])\n",
    "    best_acc = checkpoint['acc']\n",
    "    # The stored epoch was already trained and evaluated; resume with the next one.\n",
    "    start_epoch = checkpoint['epoch'] + 1\n",
    "    # NOTE(review): optimizer/scheduler state is not checkpointed, so SGD momentum and\n",
    "    # the LR schedule restart from scratch on resume.\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=args_lr,\n",
    "                      momentum=0.9, weight_decay=5e-4)\n",
    "# Cosine decay over 200 epochs; the LR only changes when scheduler.step() is called.\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)\n",
    "\n",
    "\n",
    "# Training\n",
    "def train(epoch, csv_path=None):\n",
    "    \"\"\"Run one training epoch over ``trainloader`` and print running loss/accuracy.\n",
    "\n",
    "    Uses the notebook globals ``net``, ``trainloader``, ``optimizer``, ``criterion``\n",
    "    and ``device``.\n",
    "\n",
    "    Args:\n",
    "        epoch: Epoch index, shown in the header and recorded in each logged row.\n",
    "        csv_path: Optional CSV file to append the per-batch running metrics to\n",
    "            (a header row is written only when the file does not exist yet).\n",
    "    \"\"\"\n",
    "    print('\\nEpoch: %d' % epoch)\n",
    "    net.train()\n",
    "    train_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    row_list = []\n",
    "    for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        _, predicted = outputs.max(1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets).sum().item()\n",
    "\n",
    "        # Running (cumulative-so-far) metrics for this batch.\n",
    "        row_list.append({'epoch': epoch, 'batch': batch_idx,\n",
    "                         'train loss': train_loss / (batch_idx + 1),\n",
    "                         'train accuracy': 100. * correct / total})\n",
    "\n",
    "    if not row_list:\n",
    "        # Empty loader: nothing to report, and batch_idx would be unbound below.\n",
    "        return\n",
    "\n",
    "    if csv_path is not None:\n",
    "        pd.DataFrame(row_list).to_csv(csv_path, mode='a',\n",
    "                                      header=not os.path.exists(csv_path), index=False)\n",
    "\n",
    "    progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n",
    "                 % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))\n",
    "\n",
    "\n",
    "def test(epoch, csv_path=None):\n",
    "    \"\"\"Evaluate ``net`` on ``testloader``; checkpoint it when test accuracy improves.\n",
    "\n",
    "    Uses the notebook globals ``net``, ``testloader``, ``criterion`` and ``device`` and\n",
    "    updates the global ``best_acc``. The checkpoint written to ./checkpoint/ckpt2.pth\n",
    "    holds the model state_dict ('net'), the accuracy in percent ('acc') and the\n",
    "    epoch index ('epoch').\n",
    "\n",
    "    Args:\n",
    "        epoch: Epoch index, recorded in the checkpoint and in each logged row.\n",
    "        csv_path: Optional CSV file to append the per-batch running metrics to\n",
    "            (a header row is written only when the file does not exist yet).\n",
    "    \"\"\"\n",
    "    global best_acc\n",
    "    net.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    row_list = []\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (inputs, targets) in enumerate(testloader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "\n",
    "            # Running (cumulative-so-far) metrics for this batch.\n",
    "            row_list.append({'epoch': epoch, 'batch': batch_idx,\n",
    "                             'test loss': test_loss / (batch_idx + 1),\n",
    "                             'test accuracy': 100. * correct / total})\n",
    "\n",
    "    if not row_list:\n",
    "        # Empty loader: accuracy is undefined (total == 0), so skip reporting/saving.\n",
    "        return\n",
    "\n",
    "    if csv_path is not None:\n",
    "        pd.DataFrame(row_list).to_csv(csv_path, mode='a',\n",
    "                                      header=not os.path.exists(csv_path), index=False)\n",
    "\n",
    "    progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n",
    "                 % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))\n",
    "\n",
    "    # Save a checkpoint whenever test accuracy beats the best seen so far.\n",
    "    acc = 100.*correct/total\n",
    "    if acc > best_acc:\n",
    "        print('Saving..')\n",
    "        state = {\n",
    "            'net': net.state_dict(),\n",
    "            'acc': acc,\n",
    "            'epoch': epoch,\n",
    "        }\n",
    "        os.makedirs('checkpoint', exist_ok=True)\n",
    "        torch.save(state, './checkpoint/ckpt2.pth')\n",
    "        best_acc = acc\n",
    "\n",
    "\n",
    "# Runs exactly one epoch (start_epoch) of training; test() and scheduler.step() are disabled.\n",
    "for epoch in [start_epoch]:\n",
    "    train(epoch)\n",
    "    # test(epoch)\n",
    "    # scheduler.step()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
