from attacks import *

def get_attack(inputs, labels, model, attack_type, eps=8/255):
	"""Generate adversarial inputs with the attack selected by ``attack_type``.

	Recognised keys are ``'pgd'``, ``'bim'``, ``'pif'``, ``'vmi'``, ``'vni'``,
	``'apgd'`` and ``'anda'``.  Every attack is configured for 10 steps where
	the constructor takes a step count.

	Args:
		inputs: Batch handed unchanged to the attack.
		labels: Labels handed unchanged to the attack.
		model: Model the attack is built around.
		attack_type: Key choosing the attack (see list above).
		eps: Perturbation budget forwarded to PGD, BIM, VMIFGSM, VNIFGSM and
			APGD.  It is not passed to the ``'pif'`` or ``'anda'`` attacks.

	Returns:
		Whatever the chosen attack returns for ``(inputs, labels)``, or
		``None`` when ``attack_type`` matches none of the known keys.
	"""
	# ANDA is exposed as a plain function instead of a callable attack object,
	# so it is handled separately from the object-based attacks below.
	if attack_type == 'anda':
		return anda_attack.attack(inputs, labels, model)

	if attack_type == 'pgd':
		attacker = PGD(model, eps=eps, alpha=1 / 255, steps=10, random_start=True)
	elif attack_type == 'bim':
		attacker = BIM(model, eps=eps, alpha=2 / 255, steps=10)
	elif attack_type == 'pif':
		# NOTE(review): eps is not forwarded here, so PIFGSM runs with its own
		# default budget -- confirm this is intentional.
		attacker = PIFGSM(model, num_iter_set=10)
	elif attack_type == 'vmi':  # attacks.VMIFGSM
		attacker = VMIFGSM(model, eps=eps, alpha=2/255, steps=10, decay=1.0, N=5, beta=3/2)
	elif attack_type == 'vni':
		attacker = VNIFGSM(model, eps=eps, alpha=2 / 255, steps=10, decay=1.0, N=5, beta=3 / 2)
	elif attack_type == 'apgd':
		attacker = APGD(model, norm='Linf', eps=eps, steps=10, n_restarts=1, seed=0, loss='ce', eot_iter=1, rho=.75, verbose=False)
	else:
		# Unknown key: callers receive None rather than an exception.
		return None

	return attacker(inputs, labels)


# Exclusive (low, high) bounds of index ranges that are skipped for ANDA.
_ANDA_SKIP_RANGES = (
	(878, 900),
	(1091, 1110),
	(6400, 6450),
	(9180, 9200),
)

# Individual sample indices that are skipped for ANDA.
_ANDA_SKIP_INDICES = (
	243, 251, 254, 286, 710, 878, 1292, 1357, 3096,
	3406, 4174, 4398, 6400, 6858, 9167, 9180,
)


def skip_anda(i):
	"""Return True when sample index ``i`` should be skipped for ANDA.

	Some samples cause a memory error with ANDA; they are listed here either
	as open index intervals (both ends excluded) or as individual indices.
	"""
	in_skipped_range = any(low < i < high for low, high in _ANDA_SKIP_RANGES)
	return in_skipped_range or i in _ANDA_SKIP_INDICES










