'''Basic adv_train example'''
import argparse
from framework.adv import Engine
from framework.config import get_arch

if __name__ == '__main__':
  # Script entry point: read the training/attack settings from the command
  # line, then hand them to an Engine and start it.
  cli = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')

  # Optimiser and checkpoint flags.
  cli.add_argument('--lr', type=float, default=0.1, help='learning rate')
  cli.add_argument(
      '--resume_ckpt', type=str, default='none', help='Checkpoint path')

  # Attack flags.
  cli.add_argument(
      '--eps', type=float, default=8.0, help='Epsilon of maximum attack')
  cli.add_argument(
      '--step_size', type=float, default=2.0,
      help='PGD step size in adversarial training')
  cli.add_argument(
      '--step_num', type=int, default=7,
      help='Number of PGD steps in adversarial training')
  cli.add_argument('--pathdir', type=str, default='none')
  cli.add_argument('--attack', type=str, default='pgd')

  # Schedule, dataset, evaluation and architecture flags.
  cli.add_argument('--epochs', type=int, default=200)
  cli.add_argument('--dataset', type=str, default='cifar10')
  cli.add_argument('--test_freq', type=int, default=5)
  cli.add_argument('--test_trainset', action='store_true')
  cli.add_argument('--arch', type=str, default='10xresnet32')
  opts = cli.parse_args()

  # The literal string 'none' is the command-line placeholder for "not
  # given"; the Engine receives None in that case.
  out_dir = opts.pathdir
  if out_dir == 'none':
    out_dir = None

  ckpt_path = opts.resume_ckpt
  if ckpt_path == 'none':
    ckpt_path = None

  # test_first is 0 when a checkpoint is being resumed and 1 otherwise.
  if ckpt_path:
    run_test_first = 0
  else:
    run_test_first = 1

  # Keyword arguments are collected in the order the Engine is called with.
  engine_kwargs = {
      'lr': opts.lr,
      'dataset': opts.dataset,
      # The network is passed as a zero-argument factory, not an instance.
      'basic_net': lambda: get_arch(opts.arch),
      'resume_checkpoint': ckpt_path,
      'eps': opts.eps,
      'step_size': opts.step_size,
      'step_num': opts.step_num,
      'attack': opts.attack,
      'pathdir': out_dir,
      'epochs': opts.epochs,
      'test_first': run_test_first,
      'test_trainset': opts.test_trainset,
      'test_freq': opts.test_freq,
  }
  trainer = Engine(**engine_kwargs)
  trainer.start()
