import argparse
import time
from utils import ArgBoolean, set_seed, none_or_str, \
    elapsed_time

from sequential_pretrained import run

# Command-line entry point: parse the experiment options into a plain dict,
# seed the random number generators, run `sequential_pretrained.run(opts)`,
# and print the total wall-clock time it took.
#
# Import guard: this file must only be executed as a script. An explicit
# check is used instead of `assert`, because asserts are stripped under
# `python -O`. With the assert gone, importing this module would parse
# sys.argv and start a full training run as an import side effect.
if __name__ != "__main__":
    raise ImportError("Invalid usage! Run this script from command line, do not import it!")

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--dataset", help="Dataset (mnist, cifar100, cub200)", type=str, default="cub200")
arg_parser.add_argument("--lr", help="Learning rate", type=float, default=0.001)
arg_parser.add_argument("--epochs", help="Number of epochs for each task", type=int, default=100)
arg_parser.add_argument("--batch_size", help="Number of examples in each minibatch", type=int, default=2 ** 16)
arg_parser.add_argument("--log_every", help="How many steps do I want to log?", type=int, default=10)
# NOTE(review): the default is negative; how a negative value is interpreted
# is decided inside `run`, which is not visible here. Confirm before changing.
arg_parser.add_argument("--update_every", help="Every how many epochs do I update the parameters?", type=int,
                        default=-10)
arg_parser.add_argument("--net", help="Name of the net I want to use (resnet18, resnet50)", type=str,
                        default='resnet18')
arg_parser.add_argument("--device", help="On which device do I want to do the computations", type=str, default='cpu')
arg_parser.add_argument("--n_tasks", help="Number of tasks", type=int, default=20)
arg_parser.add_argument("--order", help="Order of the approximation", type=int, default=10)
# `none_or_str` comes from the project's utils; presumably it maps a textual
# "none" to None so W&B logging can be disabled from the CLI. TODO confirm.
arg_parser.add_argument("--wandb_project", help="Use W&B, this is the project name", type=none_or_str,
                        default="wandb_log")
arg_parser.add_argument("--wandb_group", help="Group within the W&B project name", type=none_or_str,
                        default="Sequential Hippo")
# Strategy switches, parsed by the project's `ArgBoolean` converter.
arg_parser.add_argument("--hippo", help="Do I want to use hippo-based online approx strategy", type=ArgBoolean(),
                        default=False)
arg_parser.add_argument("--ewc", help="Do I want to use the Elastic Weight Consolidation strategy ", type=ArgBoolean(),
                        default=True)
arg_parser.add_argument("--ewc_lambda", help="How much to weight the EWC loss", type=float, default=10 ** -6)
arg_parser.add_argument("--replay", help="Do I want to use the Replay strategy", type=ArgBoolean(), default=False)
arg_parser.add_argument("--seed", help="Seed for random numbers (if < 0, it depends on time)", type=int, default=314)

# Options are handed to `run` as a plain dict keyed by argument name.
opts = vars(arg_parser.parse_args())

# Set up the seeds for the random number generators. Per the --seed help,
# a negative seed is time-dependent; that handling is assumed to live in
# `set_seed` (not visible here).
set_seed(opts['seed'])

# Run the learning algorithm and measure how long it takes.
start_time = time.time()
run(opts)
end_time = time.time()

print("[Elapsed: " + elapsed_time(start_time, end_time) + "]")