import train
import numpy as np
import models
import surrogate
import argparse


# Command-line options for this training run. The parsed values are copied onto
# attributes of the `train` module later in this script.
parser = argparse.ArgumentParser(description='SNN_WEIGHT_INIT')
parser.add_argument('--k', default=0.8, type=float, help='set k value')
parser.add_argument('--lam', default=0.2, type=float, help='set lam value')
parser.add_argument('--lr', default=0.001, type=float, help='set learning rate')
# Help text previously duplicated the --T description ("timesteps").
parser.add_argument('--epoch', default=150, type=int, help='set number of training epochs')
parser.add_argument('--T', default=20, type=int, help='set timesteps value')
parser.add_argument('--model', default='Cifar10NetVgg9', type=str, help='set model to be trained')
parser.add_argument('--runsfolder', default='runs0918-CIFAR10-Cifar10NetVgg9', type=str, help='set folder to save events')
parser.add_argument('--init_mode', default='asymptote_normal', type=str,
                    help="set weight-initialisation mode "
                         "('asymptote_normal' also enables bias correction)")


args = parser.parse_args()
net = args.model

# Resolve the model class before configuring anything, so a typo in --model
# produces a clear command-line error instead of an AttributeError later on.
model_cls = getattr(models, net, None)
if model_cls is None:
    parser.error(f"unknown --model {net!r}: no such attribute in models")

# Bias correction is enabled only for the 'asymptote_normal' init mode.
# NOTE(review): init_param looks like a JSON string (lowercase true/false),
# presumably parsed inside `train` — confirm before changing its format.
train.init_mode = args.init_mode
if train.init_mode == 'asymptote_normal':
    train.init_param = '{"bias_correction":true}'
else:
    train.init_param = '{"bias_correction":false}'

# Hyper-parameters taken from the command line.
train.k = args.k
train.lam = args.lam
train.T = args.T
train.train_epoch = args.epoch
train.learning_rate = args.lr

# Fixed optimiser / learning-rate-scheduler settings.
train.opt = "adam"
train.opt_param = '{"beta1":0.9,"beta2":0.99,"weight_decay": 0}'
train.scheduler = 'MultiStepLR'
# NOTE(review): milestones are hard-coded and do not scale with --epoch;
# with a small --epoch some milestones are never reached.
train.scheduler_param = '{"milestones": [50, 90, 130]}'

# Output location and data settings.
train.runs = args.runsfolder
train.suffix = f'CIFAR10-{train.learning_rate}-{net}'
train.dataset = 'cifar10'
train.batch_size = 32

train.net = model_cls(k=train.k,
                      lam=train.lam,
                      T=train.T,
                      grad=surrogate.ATan,
                      batchnorm=False)

# Turn off the inspector and model-saving flags for this run.
train.inspector_save_full = False
train.inspector = False
train.if_save_model = False

train.main()
train.reset_training()
