import os

gpu = [0, 2, 3]

# Comma-separated device list for CUDA_VISIBLE_DEVICES, e.g. "0,2,3".
# An empty list would silently hide every GPU, so reject it explicitly.
if not gpu:
	raise ValueError("gpu must list at least one device id")
gpuid = ",".join(str(device) for device in gpu)

datasets = "mnist" # "cifar10", "imagenet"

# Architecture trained for each supported dataset.
MODEL_FOR_DATASET = {
	"cifar10": "resnet110",
	"mnist": "lenet",
	"imagenet": "resnet50",
}

# Fail fast on an unsupported dataset. Otherwise model_type would stay None
# and the later string concatenation that builds the command would crash
# with an opaque TypeError.
if datasets not in MODEL_FOR_DATASET:
	raise ValueError("unsupported dataset %r; expected one of %s"
		% (datasets, sorted(MODEL_FOR_DATASET)))
model_type = MODEL_FOR_DATASET[datasets]

# ---- Training hyperparameters ----------------------------------------

# Adversarial training toggle: adds --adv-training plus the attack flags.
adv_train = False

# Noise level (--noise) and noise samples per input (--num-noise-vec,
# which is not passed for the "cohen" and "stab" methods).
noise, num_noise = 0.5, 2

# Attack settings (--epsilon / --num-steps / --warmup), used when adv_train
# is set or the method is "salman". "salman" also passes --attack.
eps, steps, warm = 255, 10, 10
attack_type = "PGD" # "DDN"

# "drt" loss weights (--lhs-weights / --rhs-weights) and --num-models.
lhs_weights, rhs_weights = 5.0, 10.0
num_models = 3

# Learning rate (--lr) and whether to continue from a checkpoint (--resume).
init_lr, resume = 0.01, True

# --lbd is used by "stab" and "macer"; --margin and --beta only by "macer".
lbd, margin, beta = 2.0, 8, 16

# Training method: one of "cohen", "macer", "salman", "stab", "drt".
# (An earlier variant chose "salman" when adv_train was set, else "cohen".)
method = "cohen"

# Ensure the output directory for generated scripts exists. makedirs creates
# intermediate directories, and exist_ok avoids the check-then-create race.
os.makedirs("train_scripts/Certified", exist_ok=True)

# Compose the generated script's path from its pieces, then join once:
#   train_scripts/Certified/<dataset>_run_<method>[_attack][_method-params]_<noise>_<resume|scratch>.sh
name_parts = ["train_scripts/Certified/", datasets, "_run_", method]

# Attack hyperparameters are part of the name whenever they are in use.
if adv_train or method == "salman":
	name_parts.append("_%d_%d_%d" % (eps, steps, warm))

# Method-specific hyperparameters; methods not listed contribute nothing.
method_suffix = {
	"salman": "_%s" % attack_type,
	"drt": "_%.1f_%.1f" % (lhs_weights, rhs_weights),
	"stab": "_%.2f" % lbd,
	"macer": "_%.2f_%.2f_%.2f" % (lbd, margin, beta),
}
name_parts.append(method_suffix.get(method, ""))

name_parts.append("_%.2f" % noise)
name_parts.append("_resume" if resume else "_scratch")
name_parts.append(".sh")

tmp = "".join(name_parts)
print(tmp)

import os

# Base command: restrict visible GPUs, pick the per-method training script,
# and pass dataset, model, noise level and learning rate.
commd = "CUDA_VISIBLE_DEVICES=" + gpuid + " python train/Certified/train_%s.py " % (method) + datasets + " " + \
		model_type + " --noise %.2f --lr %.6f " % (noise, init_lr)

# --num-noise-vec is passed to every method except "cohen" and "stab".
if method not in ("cohen", "stab"):
	commd += "--num-noise-vec %d " % (num_noise)

if resume:
	commd += "--resume "

# Attack hyperparameters are shared by adversarial training and "salman".
# Emit them once, keyed on the same condition used for the file name, so
# they are not duplicated when both apply.
if adv_train:
	commd += "--adv-training "
if adv_train or method == "salman":
	commd += "--epsilon %d --num-steps %d --warmup %d " % (eps, steps, warm)

if method == "salman":
	commd += "--attack %s " % (attack_type)
elif method == "drt":
	commd += "--lhs-weights %.1f --rhs-weights %.1f " % (lhs_weights, rhs_weights)
elif method == "stab":
	commd += "--lbd %.2f " % (lbd)
elif method == "macer":
	commd += "--lbd %.2f --margin %.2f --beta %.2f " % (lbd, margin, beta)

commd += "--num-models %d" % (num_models)


print(commd)
# Write the script directly rather than through os.system("echo ..."): no
# shell is spawned, no quoting hazards, and I/O errors raise instead of being
# ignored. The trailing newline matches what echo produced.
with open(tmp, "w") as script_file:
	script_file.write(commd + "\n")
