import train
import models
import surrogate
import argparse

# Launcher script: parse command-line options, copy them onto the module-level
# configuration attributes of `train`, build the network and run training.

parser = argparse.ArgumentParser(description='SNN_WEIGHT_INIT')
parser.add_argument('--k', default=0.8, type=float, help='set k value')
parser.add_argument('--lam', default=0.2, type=float, help='set lam value')
parser.add_argument('--lr', default=0.001, type=float, help='set learning rate')
parser.add_argument('--epoch', default=150, type=int, help='set number of training epochs')
parser.add_argument('--T', default=20, type=int, help='set snn timesteps')
parser.add_argument('--model', default='CifarDvsNetVgg9', type=str, help='set model to be trained')
parser.add_argument('--runsfolder', default='runs0922-CIFAR10DVS-L9-ADAM', type=str, help='set folder to save events')
parser.add_argument('--init_mode', default='asymptote_normal', type=str, help='set weight initialization mode')
parser.add_argument('--randseed', default=False, action="store_true")
# `args.hasAug` is read below, so the option must be registered here; without
# it the script fails with AttributeError right after argument parsing.
# NOTE(review): default False (flag off) mirrors --randseed; confirm this is
# the intended default for `train.hasAug`.
parser.add_argument('--hasAug', default=False, action="store_true", help='set train.hasAug to True')

args = parser.parse_args()

train.random_seed = args.randseed

net = args.model  # name of the model class, looked up in `models` below
train.init_mode = args.init_mode
train.hasAug = args.hasAug
# init_param is passed as a JSON string; bias_correction is switched on only
# for the 'asymptote_normal' init mode.
if train.init_mode == 'asymptote_normal':
    train.init_param = '{"bias_correction":true}'
else:
    train.init_param = '{"bias_correction":false}'

# Hyper-parameters taken from the command line.
train.k = args.k
train.lam = args.lam
train.T = args.T
train.train_epoch = args.epoch
train.learning_rate = args.lr

# Optimizer and scheduler settings are fixed for this script; their
# parameters are passed as JSON strings.
train.opt = "adam"
train.opt_param = '{"beta1":0.9,"beta2":0.99,"weight_decay": 0}'
train.scheduler = 'MultiStepLR'
train.scheduler_param = '{"milestones": [50, 90, 130],  "gamma": 0.1}'

# Output location, dataset selection and run naming.
train.runs = args.runsfolder
train.encode_func = 'event'
train.suffix = 'CIFAR10DVS_LR' + str(args.lr)
train.dataset = 'cifar10dvs'
train.batch_size = 16

# Resolve the model class by name from `models` and instantiate it with the
# hyper-parameters set above. An unknown --model name raises AttributeError.
train.net = getattr(models, net)(k=train.k,
                                 lam=train.lam,
                                 T=train.T,
                                 grad=surrogate.ATan)

# Inspector and model saving are disabled for this run.
train.inspector_save_full = False
train.inspector = False
train.if_save_model = False

train.main()
train.reset_training()
