'''Basic adv_test example'''
import argparse
from framework.adv import Engine
from framework.config import get_arch

if __name__ == '__main__':

  parser = argparse.ArgumentParser(description='Robust test')
  parser.add_argument('--dataset', default='cifar10', type=str)
  parser.add_argument('--resume_ckpt', default='checkpoint.pth', type=str,
                      help='Checkpoint path')
  parser.add_argument('--eps', default=8.0, type=float, help='Epsilon of maximum attack')
  parser.add_argument('--step_size', default=2.0, type=float, help='PGD step size in adversarial training')
  parser.add_argument('--step_num', default=7, type=int, help='Number of PGD steps in adversarial training')
  parser.add_argument('--attack', default='pgd', type=str)
  parser.add_argument('--arch', default='10xresnet32', type=str)
  parser.add_argument('--test_trainset', action='store_true')
  args = parser.parse_args()


  resume_checkpoint = args.resume_ckpt
  engine = Engine(resume_checkpoint=resume_checkpoint,
                  dataset=args.dataset,
                  basic_net=lambda: get_arch(args.arch),
                  eps=args.eps,
                  step_size=args.step_size,
                  step_num=args.step_num,
                  attack=args.attack,
                  test_trainset=args.test_trainset)

  engine.test(engine.start_epoch)