import argparse
import os
from os.path import join
import pprint
import sys
import time
import glob
from utils import ArgNumber, ArgInit, ArgBoolean, set_seed, none_or_str, elapsed_time  #, save_dict, generate_experiment_name, set_seed, elapsed_time, plot

from permuted_base import run

# This file is a command-line entry point: at module level it parses argv and
# launches a full run, so importing it must be refused. An explicit raise is
# used instead of `assert`, because assert statements are removed when Python
# runs with -O, which would silently let an import start an experiment. The
# exception type (AssertionError) is kept identical to the original guard.
if __name__ != "__main__":
    raise AssertionError("Invalid usage! Run this script from command line, do not import it!")

# Command-line options. After parsing, every option becomes a key of the
# `opts` dict (via `vars(...)`), and that dict is what the experiment runner
# receives. ArgumentDefaultsHelpFormatter appends each option's default value
# to its --help entry; several defaults (e.g. --replay is True) are not
# obvious from the flag name alone.
arg_parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
arg_parser.add_argument("--lr", help="Learning rate", type=float, default=0.001)
arg_parser.add_argument("--epochs", help="Number of epochs for each task", type=int, default=2)
arg_parser.add_argument("--batch_size", help="Number of examples in each minibatch", type=int, default=100)
arg_parser.add_argument("--log_every", help="How many steps do I want to log?", type=int, default=1000)
arg_parser.add_argument("--net", help="Name of the net I want to use", type=str, default='small')
arg_parser.add_argument("--dataset", help="Name of the dataset I want to use", type=str, default='mnist')
arg_parser.add_argument("--device", help="On which device do I want to do the computations", type=str, default='cpu')
arg_parser.add_argument("--n_tasks", help="Number of tasks", type=int, default=10)
arg_parser.add_argument("--order", help="Order of the approximation", type=int, default=10)
arg_parser.add_argument("--n_perm_pix", help="Number of permuted pixels for each task", type=int, default=100)
# none_or_str and ArgBoolean are converter helpers imported from the project's
# utils module; their exact parsing rules are defined there.
arg_parser.add_argument("--wandb_project", help="Use W&B, this is the project name", type=none_or_str, default=None)
arg_parser.add_argument("--wandb_group", help="Group within the W&B project name", type=none_or_str, default=None)
arg_parser.add_argument("--hippo", help="Do I want to use online approx strategy", type=ArgBoolean(), default=False)
arg_parser.add_argument("--replay", help="If I want to use a replay method", type=ArgBoolean(), default=True)
arg_parser.add_argument("--tau", help="Quantization step", type=float, default=0.1)
# Weight decay is an L2 *regularization* (penalty) coefficient, not a
# normalization; the help text previously mislabelled it.
arg_parser.add_argument("--weight_decay", help="Weight of the L2 regularization term (weight decay)", type=float, default=0)
arg_parser.add_argument("--seed", help="Seed for random numbers (if < 0, it depends on time)", type=int, default=-1)
# Parse sys.argv and expose the options as a plain dict keyed by option name
# (leading dashes stripped, e.g. "--batch_size" -> "batch_size").
opts = vars(arg_parser.parse_args())

# Seed the random number generators before any stochastic work begins.
# set_seed comes from utils; per the --seed help text, a negative seed is
# meant to make the seed time-dependent.
set_seed(opts['seed'])

# Launch the experiment with the parsed options and time it (wall clock).
start_time = time.time()
run(opts)
end_time = time.time()

# Report the duration, formatted by the utils helper.
elapsed = elapsed_time(start_time, end_time)
message = "[Elapsed: " + elapsed + "]"
print(message)
