import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"]='2'
import numpy as np
from setup_mnist import MNIST

# Hyperparameters.
# epsilon: constant added to W0 and W1 to build the perturbed activations,
#          and the scale factor inside the regularizer.
# lambda_: weight of the regularizer term in the training loss.
# mu:      weight of the (W_inf + W_1) weight-norm penalty in the training loss.
epsilon = 0.01
lambda_ = 1e-2
mu = 0.01

# Load the dataset via the project-local MNIST wrapper.
# NOTE(review): the arrays are fed to placeholders shaped [None,28,28,1] and
# [None,10], so images are presumably (N,28,28,1) and labels one-hot (N,10)
# -- confirm against setup_mnist.
data = MNIST()
X_train = data.train_data
y_train = data.train_labels
X_test = data.test_data
y_test = data.test_labels

def _init_weight(fan_in, fan_out):
    """Create a trainable [fan_in, fan_out] matrix from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal([fan_in, fan_out], stddev=0.1))


def _relu_layer(inputs, weights):
    """Bias-free dense layer followed by ReLU."""
    return tf.nn.relu(tf.matmul(inputs, weights))


# Weights of the bias-free 784-128-64-32-10 fully connected network.
# Creation order is kept (W0..W3) so the default variable names, and hence
# the checkpoint keys, stay the same.
W0 = _init_weight(784, 128)
W1 = _init_weight(128, 64)
W2 = _init_weight(64, 32)
W3 = _init_weight(32, 10)

# Graph inputs: image batch and 10-way label batch.
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])

# Clean forward pass: flatten each image to 784 values, three ReLU layers,
# then a linear output layer producing the logits `o`.
xImage = tf.reshape(x, [-1, 784])
layer1 = _relu_layer(xImage, W0)
layer2 = _relu_layer(layer1, W1)
layer3 = _relu_layer(layer2, W2)
o = tf.matmul(layer3, W3)

# Perturbed activations: the first two layers recomputed with epsilon added
# to every entry of W0 and W1. These feed the regularizer only.
layer1_ = _relu_layer(xImage, W0 + epsilon)
layer2_ = _relu_layer(layer1_, W1 + epsilon)
	
'''
correct_logits = tf.reduce_sum(y * tf.nn.softmax(o))
wrong_logits =  tf.reduce_max((1-y)*tf.nn.softmax(o))
correct_margin =  correct_logits - wrong_logits 
'''

def _max_l1(tensor, axis):
    """Return the largest 1-norm of `tensor` taken along `axis`."""
    return tf.reduce_max(tf.norm(tensor, ord=1, axis=axis))


# Per-example softmax cross-entropy on the clean logits.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=o, labels=y)

# Class index of the FIRST example in the batch only; the regularizer below
# is therefore defined with respect to that single example's label.
y_true = tf.argmax(y, axis=1)[0]
# Largest column-wise L1 distance between the true-class column of W3 and
# any column of W3.
max_yy = _max_l1(W3[:, y_true:y_true + 1] - W3, axis=0)

# Regularizer for the "perturb all layers" setting: a sum of three terms.
# NOTE(review): max_yy multiplies only the first term; confirm this matches
# the intended bound.
input_term = (
    max_yy * epsilon * tf.norm(x, ord=1)
    * (_max_l1(W2, axis=1) * _max_l1(W1, axis=1))
)
layer1_term = _max_l1(W2, axis=1) * epsilon * _max_l1(layer1_, axis=0)
layer2_term = epsilon * _max_l1(layer2_, axis=0)
regularizer = input_term + layer1_term + layer2_term

# Disabled alternative (perturb one layer):
#regularizer_one = max_yy*epsilon*tf.reduce_max(tf.norm(layer2,ord=1,axis=0))

# Weight-norm penalties.
# W_inf: sum over layers of the maximum absolute row sum of each weight matrix.
W_inf = tf.reduce_sum(
    _max_l1(W1, axis=1)
    + _max_l1(W2, axis=1)
    + _max_l1(W3, axis=1)
    + _max_l1(W0, axis=1)
)
# W_1: sum over layers of the maximum absolute column sum of each weight matrix.
W_1 = tf.reduce_sum(
    _max_l1(W0, axis=0)
    + _max_l1(W1, axis=0)
    + _max_l1(W2, axis=0)
    + _max_l1(W3, axis=0)
)

# Training objective (perturb all layers): summed cross-entropy plus the
# weighted regularizer and weight-norm penalty.
loss = tf.reduce_sum(cross_entropy) + lambda_ * regularizer + mu * (W_inf + W_1)

# Disabled alternatives kept for reference:
#perturb one layer
#loss = tf.reduce_sum(cross_entropy)+lambda_*regularizer_one+mu*(W_inf+W_1)

#loss = correct_margin - lambda_*regularizer
#ramp_fn = tf.case({tf.less_equal(loss, 0): lambda:1., tf.greater_equal(loss, gamma): lambda:0.},default=lambda: 1-loss/gamma, exclusive=True)

# Optimizer step and batch accuracy (fraction of examples whose arg-max
# logit matches the arg-max label).
optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.1)
train = optimizer.minimize(loss)
is_correct = tf.equal(tf.argmax(y, 1), tf.argmax(o, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Session setup: the Saver is created once the graph is complete, then all
# variables are initialised.
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

epochs = 20
batchSize = 32
#np.random.seed(8)
#idx = np.random.randint(0,55000,size=1000)
#X_train = X_train[idx]
#y_train = y_train[idx]

for epoch in range(epochs):
    for steps in range(X_train.shape[0] // batchSize):
        # One randomly chosen example from each consecutive window of
        # `batchSize` training samples is fed per step. The loss reads the
        # label of the first fed example only, so a single example is fed.
        idx = np.random.randint(0, batchSize, size=1)
        window = slice(steps * batchSize, (steps + 1) * batchSize)
        x_batch = X_train[window][idx]
        # Labels must match the [None, 10] placeholder `y`; the previous
        # reshape(-1, 20) cannot be applied to a single 10-way label row.
        y_batch = y_train[window][idx].reshape(-1, 10)
        sess.run(train, {x: x_batch, y: y_batch})

    # Per-epoch progress: accuracy over the full training set. The loss and
    # test accuracy that used to be evaluated here were never used.
    train_acc = sess.run(accuracy, {x: X_train, y: y_train})
    print("Epoch: ",epoch, 'acc:', train_acc)

# Make sure the checkpoint directory exists before saving; Saver.save does
# not create missing parent directories.
if not os.path.isdir('models'):
    os.makedirs('models')
saver.save(sess,'models/sls_mnist')

# Final report: train/test accuracy and the generalisation gap.
train_acc = sess.run(accuracy, {x:X_train, y:y_train})
tacc = sess.run(accuracy, {x:X_test, y:y_test})
print("Test Acc:", tacc, "train Acc:", train_acc)
print("Test - Train: ", tacc-train_acc, "epsilon:",epsilon, 'mu:',mu)
