from pgd_weight import LinfPGDAttack
from setup_mnist import MNIST, MNISTModel
import tensorflow as tf
import keras.backend as K
import numpy as np
import keras

import os
# Expose only CUDA device index 2 to this process. This must run before any
# TensorFlow session touches the GPU, which is why it sits at import time.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

def attack(args):
	"""Run a weight-space L-inf PGD attack on an MNIST model and report accuracy.

	The MNIST test set is processed in batches of 1000 examples. For each batch
	a ``LinfPGDAttack`` is built and its ``perturb`` method is called with the
	model's original weights. The model returned by the last ``perturb`` call
	is evaluated on the full test and training sets and saved, and the
	accuracies are printed.

	Args:
		args: dict of command-line options (as produced by
			``vars(parser.parse_args())``) with the keys 'epsilon',
			'num of steps', 'step size', 'random start', 'attack model'
			and 'save attacked model'.

	Raises:
		ValueError: if the test set holds fewer examples than one batch, so no
			attack could be run (previously this surfaced as a NameError).
	"""
	data = MNIST()
	x_test = data.test_data.reshape(-1, 784)
	y_test = data.test_labels
	x_train = data.train_data.reshape(-1, 784)
	y_train = data.train_labels

	batch_size = 1000
	# Integer division: a trailing partial batch is not attacked.
	steps = x_test.shape[0] // batch_size
	if steps == 0:
		raise ValueError(
			"need at least %d test examples to run one attack batch" % batch_size)

	# Build a session that actually uses these GPU options and make Keras use
	# it. Previously the config was created but never passed to any session,
	# so allow_growth had no effect.
	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True
	# Let ops without a GPU kernel fall back to CPU instead of failing.
	config.allow_soft_placement = True
	sess = tf.Session(config=config)
	K.set_session(sess)  # must precede model loading so weights live in sess

	model = MNISTModel(args["attack model"])
	# First weight array (get_weights()[0]) of every layer, captured before the
	# attack. NOTE(review): assumes every layer has weights — confirm.
	weight_nat = [layer.get_weights()[0] for layer in model.model.layers]

	perturbed = None
	for i in range(steps):
		lo, hi = i * batch_size, (i + 1) * batch_size
		# NOTE(review): an attack object is still built per batch, as before.
		# If its constructor adds TF ops, hoisting it out of the loop would
		# stop the graph from growing — verify against pgd_weight first.
		attacker = LinfPGDAttack(
			model,
			epsilon=args['epsilon'],
			num_steps=args["num of steps"],
			step_size=args["step size"],
			random_start=args["random start"],
			loss_func='xent',
		)
		perturbed = attacker.perturb(x_test[lo:hi], y_test[lo:hi], weight_nat, sess)

	# Evaluate once, on the model from the final batch. The original evaluated
	# the full train and test sets inside the loop and kept only the last
	# result. evaluate() does not modify weights, so the reported numbers match.
	_, test_acc = perturbed.evaluate(x_test, y_test, verbose=0)
	_, train_acc = perturbed.evaluate(x_train, y_train, verbose=0)

	perturbed.save(args["save attacked model"])
	print("epsilon:", args['epsilon'])
	print("test_acc:", test_acc)
	print("train_acc:", train_acc)
def parse_bool(text):
	"""Convert a command-line word such as 'true', 'False', '1' or 'no' to a bool.

	argparse's ``type=bool`` calls ``bool(text)``, which is True for every
	non-empty string (including "False"), so it cannot express a False flag.

	Args:
		text: the raw option string (a bool is passed through unchanged).

	Returns:
		True or False.

	Raises:
		argparse.ArgumentTypeError: if *text* is not a recognised boolean word.
	"""
	import argparse
	if isinstance(text, bool):
		return text
	value = text.strip().lower()
	if value in ("1", "true", "t", "yes", "y", "on"):
		return True
	if value in ("0", "false", "f", "no", "n", "off"):
		return False
	raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (text,))


if __name__ == "__main__":
	import argparse
	parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
	parser.add_argument("-ep", "--epsilon", type=float, default=0, help="epsilon to use")
	parser.add_argument("-s", "--save attacked model", type=str, default='models/weight_perturb', help="path of saving model")
	parser.add_argument("-am", "--attack model", type=str, default="models/original", help='path of model attacked')
	parser.add_argument("-step", "--num of steps", type=int, default=100, help="number of attack steps")
	parser.add_argument("-step_size", "--step size", type=float, default=0.0005, help="step size of each attack step")
	# type=parse_bool (not bool) so that "-rs False" really disables random start.
	parser.add_argument("-rs", "--random start", type=parse_bool, default=True, help="decide to use random start")
	args = vars(parser.parse_args())
	attack(args)
#np.save('results/beta_1e-2/20/fix/narrow/e-3.npy',tacc)
#np.save('results/normal_200.npy',tacc)
		
