{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import copy\n",
    "import logging\n",
    "import os\n",
    "import time\n",
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from apex import amp\n",
    "\n",
    "from preact_resnet import PreActResNet18\n",
    "from utils import (upper_limit, lower_limit, std, clamp, get_loaders,\n",
    "    attack_pgd, evaluate_pgd, evaluate_standard)\n",
    "from extragradient import *\n",
    "logger = logging.getLogger(__name__)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
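    "## PPM Trajectory Via SGD\n",
    "\n",
    "Loads the pretrained robust PreActResNet18 checkpoint (`9_25_robust_model_epsilon_8.pth` in `--out-dir`) and fine-tunes it with plain SGD. When `run_eval` is set, clean and PGD test accuracy are reported before training and after every epoch."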
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_args(argv=None):\n",
    "    \"\"\"Parse the command-line options for fine-tuning the pretrained robust model.\n",
    "\n",
    "    `parse_known_args` is used and unrecognised tokens are discarded, so this also\n",
    "    works inside a Jupyter kernel, whose own launch flags (e.g. the `-f` connection\n",
    "    file) would otherwise make `parse_args` exit with an error.\n",
    "\n",
    "    Args:\n",
    "        argv: optional list of argument strings to parse instead of the process\n",
    "            arguments. `None` (the default) reads `sys.argv[1:]` as before; passing a\n",
    "            list lets a cell override settings, e.g. `get_args(['--epochs', '5'])`.\n",
    "\n",
    "    Returns:\n",
    "        argparse.Namespace with the parsed options. `--delta-init` and `--early-stop`\n",
    "        are accepted but not read by any cell of this notebook; `--out-dir` is the\n",
    "        directory the pretrained checkpoint is loaded from.\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--batch-size', default=128, type=int)\n",
    "    parser.add_argument('--data-dir', default='../../cifar-data', type=str)\n",
    "    parser.add_argument('--epochs', default=100, type=int)\n",
    "    parser.add_argument('--lr-schedule', default='cyclic', choices=['cyclic', 'multistep'])\n",
    "    # The defaults have lr-min == lr-max, so a 'cyclic' schedule keeps the learning rate constant.\n",
    "    parser.add_argument('--lr-min', default=0.0005, type=float)\n",
    "    parser.add_argument('--lr-max', default=0.0005, type=float)\n",
    "    parser.add_argument('--weight-decay', default=1e-4, type=float)\n",
    "    parser.add_argument('--momentum', default=0.0, type=float)\n",
    "    parser.add_argument('--delta-init', default='random', choices=['zero', 'random', 'previous'],\n",
    "        help='Perturbation initialization method')\n",
    "    parser.add_argument('--out-dir', default='train_fgsm_output', type=str, help='Output directory')\n",
    "    parser.add_argument('--seed', default=0, type=int, help='Random seed')\n",
    "    parser.add_argument('--early-stop', action='store_true', help='Early stop if overfitting occurs')\n",
    "    parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O1', 'O2'],\n",
    "        help='O0 is FP32 training, O1 is Mixed Precision, and O2 is \"Almost FP16\" Mixed Precision')\n",
    "    parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'],\n",
    "        help='If loss_scale is \"dynamic\", adaptively adjust the loss scale over time')\n",
    "    parser.add_argument('--master-weights', action='store_true',\n",
    "        help='Maintain FP32 master weights to accompany any FP16 model weights, not applicable for O1 opt level')\n",
    "    # Unrecognised arguments (e.g. the kernel's own flags) are deliberately ignored.\n",
    "    args, _unknown = parser.parse_known_args(argv)\n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune the pretrained robust checkpoint with plain SGD. When run_eval is True,\n",
    "# clean and PGD test metrics are printed before training and after every epoch.\n",
    "run_eval = True\n",
    "\n",
    "# Positional arguments for utils.evaluate_pgd -- presumably (attack iterations, restarts);\n",
    "# TODO confirm against its signature.\n",
    "PGD_EVAL_ITERS = 50\n",
    "PGD_EVAL_RESTARTS = 10\n",
    "\n",
    "args = get_args()\n",
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)\n",
    "torch.cuda.manual_seed(args.seed)\n",
    "\n",
    "train_loader, test_loader = get_loaders(args.data_dir, args.batch_size)\n",
    "model = PreActResNet18().cuda()\n",
    "state_dict = torch.load(os.path.join(args.out_dir, '9_25_robust_model_epsilon_8.pth'))\n",
    "model.load_state_dict(state_dict)\n",
    "# Honour --momentum here as well, matching how the ExtraSGD run builds its optimizer.\n",
    "opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum,\n",
    "                      weight_decay=args.weight_decay)\n",
    "amp_args = dict(opt_level=args.opt_level, loss_scale=args.loss_scale, verbosity=False)\n",
    "if args.opt_level == 'O2':\n",
    "    amp_args['master_weights'] = args.master_weights\n",
    "model, opt = amp.initialize(model, opt, **amp_args)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# The scheduler is stepped once per batch, so every step count below is in batches.\n",
    "lr_steps = args.epochs * len(train_loader)\n",
    "if args.lr_schedule == 'cyclic':\n",
    "    # cycle_momentum=False: by default CyclicLR overwrites each param group's momentum\n",
    "    # with values cycling between 0.8 and 0.9, which would silently ignore --momentum.\n",
    "    scheduler = torch.optim.lr_scheduler.CyclicLR(\n",
    "        opt, base_lr=args.lr_min, max_lr=args.lr_max,\n",
    "        step_size_up=lr_steps / 2, step_size_down=lr_steps / 2, cycle_momentum=False)\n",
    "else:  # 'multistep' -- argparse restricts --lr-schedule to these two choices\n",
    "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
    "        opt, milestones=[lr_steps // 2, lr_steps * 3 // 4], gamma=0.1)\n",
    "\n",
    "if run_eval:\n",
    "    model.eval()\n",
    "    start_eval_time = time.time()\n",
    "    test_loss, test_acc = evaluate_standard(test_loader, model)\n",
    "    pgd_loss, pgd_acc = evaluate_pgd(test_loader, model, PGD_EVAL_ITERS, PGD_EVAL_RESTARTS)\n",
    "    end_eval_time = time.time()\n",
    "    print('Test Loss \\t Test Acc \\t PGD Loss \\t PGD Acc')\n",
    "    print('%.4f \\t %.4f \\t %.4f \\t %.4f' % (test_loss, test_acc, pgd_loss, pgd_acc))\n",
    "    print('Total evaluation time: %.4f minutes' % ((end_eval_time - start_eval_time)/60))\n",
    "\n",
    "\n",
    "# Training\n",
    "start_train_time = time.time()\n",
    "logger.info('Epoch \\t Seconds \\t LR \\t \\t Train Loss \\t Train Acc')\n",
    "print('Epoch \\t Seconds \\t LR \\t \\t Train Loss \\t Train Acc')\n",
    "for epoch in range(args.epochs):\n",
    "    start_epoch_time = time.time()\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    train_acc = 0\n",
    "    train_n = 0\n",
    "    for X, y in train_loader:\n",
    "        X, y = X.cuda(), y.cuda()\n",
    "        output = model(X)\n",
    "        loss = criterion(output, y)\n",
    "        opt.zero_grad()\n",
    "        with amp.scale_loss(loss, opt) as scaled_loss:\n",
    "            scaled_loss.backward()\n",
    "        opt.step()\n",
    "        train_loss += loss.item() * y.size(0)\n",
    "        train_acc += (output.max(1)[1] == y).sum().item()\n",
    "        train_n += y.size(0)\n",
    "        scheduler.step()\n",
    "    epoch_time = time.time()\n",
    "    # get_last_lr() reports the rate set by the latest scheduler.step(); calling get_lr()\n",
    "    # outside step() is deprecated and emits a warning.\n",
    "    lr = scheduler.get_last_lr()[0]\n",
    "    print('%d \\t %.1f \\t \\t %.4f \\t %.4f \\t %.4f' % (epoch, epoch_time - start_epoch_time, lr, train_loss/train_n, train_acc/train_n))\n",
    "\n",
    "    if run_eval:\n",
    "        model.eval()\n",
    "        start_eval_time = time.time()\n",
    "        test_loss, test_acc = evaluate_standard(test_loader, model)\n",
    "        pgd_loss, pgd_acc = evaluate_pgd(test_loader, model, PGD_EVAL_ITERS, PGD_EVAL_RESTARTS)\n",
    "        end_eval_time = time.time()\n",
    "        print('Test Loss \\t Test Acc \\t PGD Loss \\t PGD Acc')\n",
    "        print('%.4f \\t %.4f \\t %.4f \\t %.4f' % (test_loss, test_acc, pgd_loss, pgd_acc))\n",
    "        print('Total evaluation time: %.4f minutes' % ((end_eval_time - start_eval_time)/60))\n",
    "\n",
    "train_time = time.time()\n",
    "print('Total train time: %.4f minutes' % ((train_time - start_train_time)/60))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
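    "## PPM Trajectory Via ExtraSGD\n",
    "\n",
    "Repeats the experiment from the same checkpoint with the `ExtraSGD` optimizer from `extragradient` (default 200 epochs). Consecutive batches alternate between an extrapolation (look-ahead) update and the committed `step()`. Clean and PGD test accuracy are reported before training and after every epoch."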
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_args(argv=None):\n",
    "    \"\"\"Parse the command-line options for the ExtraSGD fine-tuning run.\n",
    "\n",
    "    Identical to the options of the SGD run except that `--epochs` defaults to 200.\n",
    "    `parse_known_args` is used and unrecognised tokens are discarded, so this also\n",
    "    works inside a Jupyter kernel, whose own launch flags (e.g. the `-f` connection\n",
    "    file) would otherwise make `parse_args` exit with an error.\n",
    "\n",
    "    Args:\n",
    "        argv: optional list of argument strings to parse instead of the process\n",
    "            arguments. `None` (the default) reads `sys.argv[1:]` as before; passing a\n",
    "            list lets a cell override settings, e.g. `get_args(['--epochs', '5'])`.\n",
    "\n",
    "    Returns:\n",
    "        argparse.Namespace with the parsed options. `--delta-init` and `--early-stop`\n",
    "        are accepted but not read by any cell of this notebook; `--out-dir` is the\n",
    "        directory the pretrained checkpoint is loaded from.\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--batch-size', default=128, type=int)\n",
    "    parser.add_argument('--data-dir', default='../../cifar-data', type=str)\n",
    "    parser.add_argument('--epochs', default=200, type=int)\n",
    "    parser.add_argument('--lr-schedule', default='cyclic', choices=['cyclic', 'multistep'])\n",
    "    # The defaults have lr-min == lr-max, so a 'cyclic' schedule keeps the learning rate constant.\n",
    "    parser.add_argument('--lr-min', default=0.0005, type=float)\n",
    "    parser.add_argument('--lr-max', default=0.0005, type=float)\n",
    "    parser.add_argument('--weight-decay', default=1e-4, type=float)\n",
    "    parser.add_argument('--momentum', default=0.0, type=float)\n",
    "    parser.add_argument('--delta-init', default='random', choices=['zero', 'random', 'previous'],\n",
    "        help='Perturbation initialization method')\n",
    "    parser.add_argument('--out-dir', default='train_fgsm_output', type=str, help='Output directory')\n",
    "    parser.add_argument('--seed', default=0, type=int, help='Random seed')\n",
    "    parser.add_argument('--early-stop', action='store_true', help='Early stop if overfitting occurs')\n",
    "    parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O1', 'O2'],\n",
    "        help='O0 is FP32 training, O1 is Mixed Precision, and O2 is \"Almost FP16\" Mixed Precision')\n",
    "    parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'],\n",
    "        help='If loss_scale is \"dynamic\", adaptively adjust the loss scale over time')\n",
    "    parser.add_argument('--master-weights', action='store_true',\n",
    "        help='Maintain FP32 master weights to accompany any FP16 model weights, not applicable for O1 opt level')\n",
    "    # Unrecognised arguments (e.g. the kernel's own flags) are deliberately ignored.\n",
    "    args, _unknown = parser.parse_known_args(argv)\n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fine-tune the same pretrained checkpoint with ExtraSGD (imported from extragradient).\n",
    "# Consecutive batches form pairs: the first drives an extrapolation (look-ahead) update,\n",
    "# the second the committed step(). Clean and PGD test metrics are printed before\n",
    "# training and after every epoch.\n",
    "\n",
    "# Positional arguments for utils.evaluate_pgd -- presumably (attack iterations, restarts);\n",
    "# TODO confirm against its signature.\n",
    "PGD_EVAL_ITERS = 50\n",
    "PGD_EVAL_RESTARTS = 10\n",
    "\n",
    "args = get_args()\n",
    "# Seed the same way as the SGD run so this trajectory is reproducible and comparable.\n",
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)\n",
    "torch.cuda.manual_seed(args.seed)\n",
    "\n",
    "train_loader, test_loader = get_loaders(args.data_dir, args.batch_size)\n",
    "model = PreActResNet18().cuda()\n",
    "state_dict = torch.load(os.path.join(args.out_dir, '9_25_robust_model_epsilon_8.pth'))\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "opt = ExtraSGD(model.parameters(), lr=args.lr_max, momentum=args.momentum, weight_decay=args.weight_decay)\n",
    "amp_args = dict(opt_level=args.opt_level, loss_scale=args.loss_scale, verbosity=False)\n",
    "if args.opt_level == 'O2':\n",
    "    amp_args['master_weights'] = args.master_weights\n",
    "model, opt = amp.initialize(model, opt, **amp_args)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# The scheduler is stepped once per batch (extrapolation or step), so counts are in batches.\n",
    "lr_steps = args.epochs * len(train_loader)\n",
    "if args.lr_schedule == 'cyclic':\n",
    "    # cycle_momentum=False: by default CyclicLR overwrites each param group's momentum\n",
    "    # with values cycling between 0.8 and 0.9, which would silently ignore --momentum.\n",
    "    scheduler = torch.optim.lr_scheduler.CyclicLR(\n",
    "        opt, base_lr=args.lr_min, max_lr=args.lr_max,\n",
    "        step_size_up=lr_steps / 2, step_size_down=lr_steps / 2, cycle_momentum=False)\n",
    "else:  # 'multistep' -- argparse restricts --lr-schedule to these two choices\n",
    "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
    "        opt, milestones=[lr_steps // 2, lr_steps * 3 // 4], gamma=0.1)\n",
    "\n",
    "\n",
    "model.eval()\n",
    "start_eval_time = time.time()\n",
    "test_loss, test_acc = evaluate_standard(test_loader, model)\n",
    "pgd_loss, pgd_acc = evaluate_pgd(test_loader, model, PGD_EVAL_ITERS, PGD_EVAL_RESTARTS)\n",
    "end_eval_time = time.time()\n",
    "print('Test Loss \\t Test Acc \\t PGD Loss \\t PGD Acc')\n",
    "print('%.4f \\t %.4f \\t %.4f \\t %.4f' % (test_loss, test_acc, pgd_loss, pgd_acc))\n",
    "print('Total evaluation time: %.4f minutes' % ((end_eval_time - start_eval_time)/60))\n",
    "\n",
    "\n",
    "# Training\n",
    "n_batches = len(train_loader)\n",
    "# Only batches that form a complete (extrapolation, step) pair are trained on. With an odd\n",
    "# batch count the last batch would extrapolate with no matching step(). Assuming\n",
    "# extrapolation() moves the weights in place (TODO confirm in extragradient.py), the model\n",
    "# would then be evaluated at the look-ahead point, and the next epoch would start with a\n",
    "# second extrapolation.\n",
    "n_paired = n_batches - n_batches % 2\n",
    "start_train_time = time.time()\n",
    "logger.info('Epoch \\t Seconds \\t LR \\t \\t Train Loss \\t Train Acc')\n",
    "print('Epoch \\t Seconds \\t LR \\t \\t Train Loss \\t Train Acc')\n",
    "for epoch in range(args.epochs):\n",
    "    start_epoch_time = time.time()\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    train_acc = 0\n",
    "    train_n = 0\n",
    "    for i, (X, y) in enumerate(train_loader):\n",
    "        if i >= n_paired:\n",
    "            scheduler.step()  # still advance the schedule so it stays aligned with lr_steps\n",
    "            continue\n",
    "        X, y = X.cuda(), y.cuda()\n",
    "        output = model(X)\n",
    "        loss = criterion(output, y)\n",
    "        opt.zero_grad()\n",
    "        with amp.scale_loss(loss, opt) as scaled_loss:\n",
    "            scaled_loss.backward()\n",
    "        if i % 2 == 0:\n",
    "            opt.extrapolation()  # look-ahead update from the current weights\n",
    "        else:\n",
    "            opt.step()  # committed update, presumably using the gradient at the look-ahead point\n",
    "        train_loss += loss.item() * y.size(0)\n",
    "        train_acc += (output.max(1)[1] == y).sum().item()\n",
    "        train_n += y.size(0)\n",
    "        scheduler.step()\n",
    "    epoch_time = time.time()\n",
    "    # get_last_lr() reports the rate set by the latest scheduler.step(); calling get_lr()\n",
    "    # outside step() is deprecated and emits a warning.\n",
    "    lr = scheduler.get_last_lr()[0]\n",
    "    # max(train_n, 1) guards against a loader with fewer than two batches.\n",
    "    print('%d \\t %.1f \\t \\t %.4f \\t %.4f \\t %.4f' % (epoch, epoch_time - start_epoch_time, lr, train_loss/max(train_n, 1), train_acc/max(train_n, 1)))\n",
    "\n",
    "    model.eval()\n",
    "    start_eval_time = time.time()\n",
    "    test_loss, test_acc = evaluate_standard(test_loader, model)\n",
    "    pgd_loss, pgd_acc = evaluate_pgd(test_loader, model, PGD_EVAL_ITERS, PGD_EVAL_RESTARTS)\n",
    "    end_eval_time = time.time()\n",
    "    print('Test Loss \\t Test Acc \\t PGD Loss \\t PGD Acc')\n",
    "    print('%.4f \\t %.4f \\t %.4f \\t %.4f' % (test_loss, test_acc, pgd_loss, pgd_acc))\n",
    "    print('Total evaluation time: %.4f minutes' % ((end_eval_time - start_eval_time)/60))\n",
    "\n",
    "train_time = time.time()\n",
    "print('Total train time: %.4f minutes' % ((train_time - start_train_time)/60))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
