import argparse
# Command-line interface for the training driver. Every option is read back
# from `args` by the run-configuration and training-dispatch code below.
parser = argparse.ArgumentParser()
parser.add_argument("--train", action="store_true", help="Set to initiate training")
parser.add_argument("--test", action="store_true", help="Set to initiate testing")
# n is an exponent: the group size is 4 * 2^n.
parser.add_argument("--n", type=int, default=0, help="Value of n for training (group size = 4 * 2^n)")
parser.add_argument("--reflection", action="store_true", help="whether to use reflection padding")
# Forwarded to the trainers as Load_Previous / load_previous.
parser.add_argument("--load", action="store_true", help="Set to load the previously saved model before training")
parser.add_argument("--mnist", action="store_true", help="Using MNIST dataset")
parser.add_argument("--cifar10", action="store_true", help="Using CIFAR10 dataset")
parser.add_argument("--cifar100", action="store_true", help="Using CIFAR100 dataset")
parser.add_argument("--stl10", action="store_true", help="Use STL10 dataset")
parser.add_argument("--setup", type=int, default=0, help="Quickly choosing setups for training in cifar.")
parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")
parser.add_argument("--epoch", type=int, default=50, help="Number of epochs to train")
parser.add_argument("--batch", type=int, default=128, help="Batch size")
# Starting value of the run counter; it is incremented after every run.
parser.add_argument("--century", type=int, default=1, help="Index of the first training run. Default = 1")
parser.add_argument("--lora", action="store_true", help="Set to enable LoRa")
parser.add_argument("--inter", type=int, default=1,
                    help="Value forwarded to the trainers as inter_dim (low rank adaptation). Default = 1")
parser.add_argument("--matrix", type=int, default=0, help="Specify the dimension of the low rank matrix.")
parser.add_argument("--decay", type=float, default=0, help="Weight decay for training.")
parser.add_argument("--head", action="store_true", help="Set to enable multi-head")
parser.add_argument("--share", type=int, default=3, help="# of shared layers for multi-head. Default = 3")
# This help text used to be a copy of --share's (including a wrong "Default = 3").
parser.add_argument("--shared_choice", type=int, default=0,
                    help="Shared-layer choice for multi-head (forwarded as shared_choice). Default = 0")
parser.add_argument("--visual", action="store_true", help="Set to enable visualization")
parser.add_argument("--sim", action="store_true", help="Set to enable similarity")
parser.add_argument("--test_aug", action="store_true", help="Set to enable test_aug (MNIST training only)")
args = parser.parse_args()

# --- Run configuration derived from the parsed CLI arguments --------------
# Passed to cifar_train as data_choice: True exactly when --cifar100 was given.
data_choice = True if args.cifar100 else False

# Training schedule.
century = args.century  # index of the current training run
lr = args.lr
batch = args.batch

# Model options.
# n stands for group size = 4* 2^n. Default set to 0.
n = args.n
reflection = args.reflection
head = args.head

# Low-rank adaptation options.
lora = args.lora
inter = args.inter
dim = args.matrix

# --visual always forces loading a saved model; otherwise honour --load.
# NOTE(review): presumably visualization needs trained weights -- confirm.
load = True if args.visual else args.load
print("Please follow the readme")
# Dispatch to the dataset-specific trainer. Trainer modules are imported lazily
# so that only the selected dataset's dependencies are loaded.
if args.train:
    if args.mnist:
        from mnist.train_mnist import mnist_train
        # Train run after run without end; every run after the first
        # resumes from the previously saved model (load = True).
        while True:
            print("This is the {}-th run of training on MNIST.".format(century))
            mnist_train(lora, Load_Previous=load, n=n, l_r=lr, matrix_dim=dim,
                        inter_dim=inter, head=head, epoch=args.epoch,
                        shared_choice=args.shared_choice, visualize_filters=args.visual,
                        test_aug=args.test_aug, reflection=reflection)
            century += 1
            load = True
    elif args.cifar10 or args.cifar100:
        from cifar.train_cifar import cifar_train
        # data_choice is True when --cifar100 was given. Report the dataset that
        # is actually trained; previously this always printed "CIFAR10".
        dataset_name = "CIFAR100" if data_choice else "CIFAR10"
        print("This is the {}-th run of training on {}.".format(century, dataset_name))
        # First run: honours --load and --visual and starts from --lr.
        new_learning = cifar_train(lora, load_previous=load, n=n, set_up=args.setup, l_r=lr,
                                   b_s=batch, inter_dim=inter, weight_decay=args.decay, head=head,
                                   shared_layer=args.share, shared_choice=args.shared_choice,
                                   reflection=reflection, data_choice=data_choice, visualize_filters=args.visual)
        # Later runs: always resume from the saved model. The value returned by
        # the previous run is fed back in as the next l_r. Filter visualization
        # is only requested on the first run.
        while True:
            century += 1
            print("This is the {}-th run of training on {}.".format(century, dataset_name))
            new_learning = cifar_train(lora, load_previous=True, n=n, set_up=args.setup, l_r=new_learning,
                                       b_s=batch, inter_dim=inter, weight_decay=args.decay, head=head,
                                       shared_layer=args.share, shared_choice=args.shared_choice,
                                       reflection=reflection, data_choice=data_choice)

    elif args.stl10:
        from stl10.train_stl10 import stl10_train

        print("This is the {}-th run of training on STL10.".format(century))
        new_learning = stl10_train(lora, Load_Previous=load, n=n,
                                   b_s=batch, inter_dim=inter, head=head,
                                   shared_choice=args.shared_choice,
                                   reflection=reflection)