import argparse

from sde.legacy.learn_synthetic_data import train

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--num-classes', default=30, type=int)
    arg_parser.add_argument('--lr', default=1e-4, type=float)
    arg_parser.add_argument('--eta-min', default=1e-6, type=float)
    arg_parser.add_argument('--num-epochs', default=30, type=int)
    arg_parser.add_argument('--train-split', default=0.8, type=float)
    arg_parser.add_argument('--devices', nargs='+', default=[0, 1], type=int)
    arg_parser.add_argument('--data-root', default="synthetic_dataset/")
    arg_parser.add_argument('--validation-in-train', action='store_true')
    args = arg_parser.parse_args()
    print("Run with arguments")
    for name, val in vars(args).items():
        print(f"{name}: {val}")
    train(
        num_classes=args.num_classes,
        num_epochs=args.num_epochs,
        devices=args.devices,
        lr=args.lr,
        eta_min=args.eta_min,
        data_root=args.data_root,
        train_split=args.train_split,
        validation_in_train=args.validation_in_train,
    )
