import argparse
from IPython import embed
import numpy as np


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--test_batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--samples_per_epoch", type=int, default=10000)
    parser.add_argument("--optimizer", type=str, default='Adam')
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate of models")
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("-n", "--num_clients", type=int, default=10)
    parser.add_argument("--dataset", type=str, default="imdb")
    parser.add_argument("--defense", type=str, default="none")
    parser.add_argument("--model", type=str, default="CNN")
    parser.add_argument("--loader_type", type=str, choices=["iid", "byLabel", "dirichlet-0.9"], default="iid")
    parser.add_argument("--loader_path", type=str, default="", help="where to save the data partitions")
    parser.add_argument("--AR", type=str, )
    parser.add_argument("--attacker_list", type=str, default="clean")
    parser.add_argument("--attacks", type=str, help="if contains \"backdoor\", activate the corresponding tests")
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default='cuda')
    parser.add_argument("--inner_epochs", type=int, default=1)

    args = parser.parse_args()
    
    ###args.save_model_weights = True
    args.save_model_weights = False
    args.loader_path = f'.data/{args.dataset}_{args.loader_type}.pt'
    if args.attacker_list == 'clean':
        args.attacker_list = []
    else:
        args.attacker_list = [int(s) for s in args.attacker_list.split('-')]
    return args


if __name__ == "__main__":

    import _main

    args = parse_args()
    print("#" * 64)
    for i in vars(args):
        print(f"#{i:>40}: {str(getattr(args, i)):<20}#")
    print("#" * 64)
    _main.main(args)
