import tensorflow as tf
import numpy as np
import keras.backend as K



class LinfPGDAttack:
  """L-infinity projected-gradient-ascent attack on a model's weights.

  Starting from a set of original weights, the attack repeatedly steps the
  weights along the sign of the loss gradient and projects them back into the
  L-inf ball of radius `epsilon` around the original weights.
  """

  def __init__(self, model, epsilon, num_steps, step_size, random_start, loss_func):
    """Build the placeholders, loss and weight-gradient ops for the attack.

    The attack performs `num_steps` signed-gradient ascent steps of size
    `step_size` on the trainable weights, always staying within `epsilon`
    (L-inf) of the initial weights.

    Args:
      model: wrapper whose `.model` attribute is a Keras model; that Keras
        model is called on the input placeholder to produce logits.
      epsilon: L-inf radius of the allowed weight perturbation.
      num_steps: number of ascent steps taken by `perturb`.
      step_size: size of each signed-gradient step.
      random_start: if truthy, `perturb` starts from a uniformly random point
        inside the epsilon ball instead of the original weights.
      loss_func: 'xent' for softmax cross-entropy or 'cw' for the
        Carlini-Wagner style margin loss; any other value prints a warning
        and falls back to cross-entropy.
    """
    self.model = model
    self.epsilon = epsilon
    self.num_steps = num_steps
    self.step_size = step_size
    self.rand = random_start

    # Created unconditionally so that `perturb` can feed them regardless of
    # the chosen loss. Shapes are fixed here to 784 input features and 10
    # classes (presumably flattened MNIST -- confirm against the model).
    self.inputs = tf.placeholder(tf.float32, [None, 784])
    self.labels = tf.placeholder(tf.float32, [None, 10])
    logits = self.model.model(self.inputs)

    if loss_func == 'cw':
      # Assumes the fed labels are one-hot, so they double as the mask that
      # selects the correct-class logit.
      label_mask = self.labels
      correct_logit = tf.reduce_sum(label_mask * logits, axis=1)
      # Push the correct class far down so reduce_max picks the best wrong one.
      wrong_logit = tf.reduce_max((1 - label_mask) * logits - 1e4 * label_mask, axis=1)
      self.loss = -tf.nn.relu(correct_logit - wrong_logit + 50)
    else:
      if loss_func != 'xent':
        print('Unknown loss function. Defaulting to cross-entropy')
      self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.labels,
                                                          logits=logits)

    # One gradient tensor per trainable weight, in trainable_weights order.
    self.grad = tf.gradients(self.loss, self.model.model.trainable_weights)

  def perturb(self, x_nat, y, weight_nat, sess):
    """Run weight-space PGD and load the perturbed weights into the model.

    Args:
      x_nat: input batch fed to `self.inputs`.
      y: labels fed to `self.labels` (one-hot, matching the placeholder).
      weight_nat: list of the original weight arrays, ordered like
        `self.model.model.trainable_weights` (the order of `self.grad`).
        These arrays are not modified.
      sess: TF session used to evaluate the gradients.

    Returns:
      The Keras model `self.model.model`. After at least one step, its
      trainable weights hold the perturbed values, each within `epsilon` of
      `weight_nat` in L-inf norm.
    """
    # Always work on copies: the in-place update below must never touch the
    # caller's arrays, and the projection has to stay centred on the
    # *original* weights rather than on the moving iterate.
    if self.rand:
      weight_n = [
          w + np.random.uniform(-self.epsilon, self.epsilon, w.shape).astype('float32')
          for w in weight_nat
      ]
    else:
      weight_n = [np.copy(w) for w in weight_nat]

    variables = self.model.model.trainable_weights
    feed = {self.inputs: x_nat, self.labels: y}
    for _ in range(self.num_steps):
      g = sess.run(self.grad, feed)
      for j, w_orig in enumerate(weight_nat):
        # Ascend the loss; 'unsafe' casting keeps the weight array's dtype.
        np.add(weight_n[j], self.step_size * np.sign(g[j]), out=weight_n[j],
               casting='unsafe')
        weight_n[j] = np.clip(weight_n[j], w_orig - self.epsilon, w_orig + self.epsilon)
      # g[j] is the gradient w.r.t. trainable_weights[j], so write each array
      # back to that same variable. Indexing layers[j] instead misaligns when a
      # layer owns zero or several weight tensors (e.g. kernel + bias).
      K.batch_set_value(list(zip(variables, weight_n)))

    return self.model.model