import argparse
import numpy as np
import pandas as pd
import os
from framework.adv import Engine
from framework.config import get_arch

def parse_args(argv=None):
    """Parse command-line options for the train-then-finetune run.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``
            (argparse's default behaviour).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Clean Train')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume_ckpt', default='none', type=str,
                        help='Checkpoint path')
    # NOTE(review): --pathdir is parsed but not currently forwarded to Engine.
    parser.add_argument('--pathdir', default='none', type=str)
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--arch', default='resnet32', type=str)
    parser.add_argument('--train_attack', default='none', type=str)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--train_task', default='sclass', type=str)
    # Task level used for the fine-tune phase. '--test_task' is accepted as
    # an alias (mirroring '--train_task'); both spellings store into
    # ``args.test_class``.
    parser.add_argument('--test_class', '--test_task', dest='test_class',
                        default='class', type=str)
    parser.add_argument('--train_label', default=5, type=int)
    parser.add_argument('--test_label', default=10, type=int)
    parser.add_argument('--repre_root', default='...', type=str)
    parser.add_argument('--repre_folder', default='...', type=str)
    return parser.parse_args(argv)


def main(argv=None):
    """Train an ``Engine`` on the train task, then fine-tune on the test task.

    The first phase builds the engine with the *train* task level and label
    count for both training and testing, calls ``start()`` and then
    ``featureout('./')``. The second phase switches the engine's test task
    level and test label to the CLI test values, then calls
    ``finetunestart()`` and ``output_collect('./')``.

    Args:
        argv: Optional argument list, forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    # The CLI uses the literal string 'none' to mean "no checkpoint".
    resume_checkpoint = None if args.resume_ckpt == 'none' else args.resume_ckpt

    # random seed (shared by the torch and numpy seeds below)
    seed = 0

    engine = Engine(
        dataset=args.dataset,
        # Deferred so the network is built by Engine, not at parse time.
        basic_net=lambda: get_arch(args.arch),
        lr=args.lr,
        resume_checkpoint=resume_checkpoint,
        train_attack=args.train_attack,
        test_freq=10,
        epochs=200,
        batch_size=args.batch_size,
        # Skip the initial test pass when resuming from a checkpoint.
        test_first=0 if resume_checkpoint else 1,
        pytorch_seed=seed,
        numpy_seed=seed,
        # Phase 1 tests on the same task/labels it trains on.
        train_task_level=args.train_task,
        test_task_level=args.train_task,
        repre_save_path=f'{args.dataset}representation{seed}.csv',
        train_label=args.train_label,
        test_label=args.train_label,
        repre_root=args.repre_root,
        repre_save_folder=args.repre_folder,
    )
    engine.start()
    engine.featureout('./')

    print('===== change label task to cifar 10 ======')
    # Phase 2: switch the test task/labels before fine-tuning. This
    # previously read the non-existent ``args.test_task`` and crashed.
    engine.test_task_level = args.test_class
    engine.test_label = args.test_label
    engine.finetunestart()
    engine.output_collect('./')


if __name__ == '__main__':
    main()