from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from networks import ResNet, SeResNet, DenseNet, MobileNetV1, MobileNetV3, ResNet12, VggNet


# Shared command-line flag registry (tensorflow.python.platform.flags).
# Everything below reads its configuration from it: target_data, batch_size,
# test_eps, test_attack_step_size, DI_prob.
FLAGS = flags.FLAGS


def input_diversity(input_tensor, image_width=32, image_resize=36, prob=None):
    """Randomly resize-and-pad an image batch (input-diversity transform).

    When the transform is applied, the whole batch goes through these steps:
    it is resized (nearest neighbour) to a random square side in
    ``[image_width, image_resize)``, zero-padded at a random offset to
    ``image_resize x image_resize``, and resized back to
    ``image_width x image_width``. Otherwise ``input_tensor`` is returned
    unchanged. A single random draw is shared by every image in the batch.

    Args:
        input_tensor: 4-D image batch. Presumably NHWC with side
            ``image_width``; confirm against callers.
        image_width: side length of the output. Defaults to 32 (the original
            hard-coded value); pass the dataset's size for non-32px inputs.
        image_resize: side length of the padded canvas. Defaults to 36 and
            must be larger than ``image_width``.
        prob: probability of applying the transform. ``None`` means
            ``FLAGS.DI_prob``.

    Returns:
        A tensor selected by ``tf.cond``: either the transformed batch or
        ``input_tensor``.

    Raises:
        ValueError: if ``image_resize <= image_width``, because the random
            side length would have an empty range.
    """
    if image_resize <= image_width:
        raise ValueError('image_resize (%r) must be larger than image_width (%r)'
                         % (image_resize, image_width))
    if prob is None:
        prob = FLAGS.DI_prob

    # Integer maxval is exclusive, so rnd lies in [image_width, image_resize).
    rnd = tf.random_uniform((), image_width, image_resize, dtype=tf.int32)
    rescaled = tf.image.resize_images(input_tensor, [rnd, rnd],
                                      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Split the leftover border randomly between the two sides of each axis.
    h_rem = image_resize - rnd
    w_rem = image_resize - rnd
    pad_top = tf.random_uniform((), 0, h_rem, dtype=tf.int32)
    pad_bottom = h_rem - pad_top
    pad_left = tf.random_uniform((), 0, w_rem, dtype=tf.int32)
    pad_right = w_rem - pad_left
    padded = tf.pad(rescaled, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
                    constant_values=0.)
    # The paddings are data-dependent, so restore the static shape explicitly.
    padded.set_shape((input_tensor.shape[0], image_resize, image_resize, input_tensor.shape[-1]))
    padded = tf.image.resize_images(padded, [image_width, image_width],
                                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return tf.cond(tf.random_uniform(shape=[1])[0] < prob, lambda: padded, lambda: input_tensor)




def gkern(kernlen=21, nsig=3):
    """Return a normalised 2-D Gaussian kernel as a NumPy array.

    The standard-normal density is sampled at ``kernlen`` evenly spaced points
    spanning ``[-nsig, nsig]``. The kernel is the outer product of that 1-D
    profile with itself, divided by its total so the entries sum to 1.

    Args:
        kernlen: side length of the square kernel.
        nsig: half-width of the sampled range, in standard deviations.

    Returns:
        A ``(kernlen, kernlen)`` array.
    """
    import scipy.stats as st

    profile = st.norm.pdf(np.linspace(-nsig, nsig, kernlen))
    unnormalised = np.outer(profile, profile)
    return unnormalised / unnormalised.sum()

# 15x15 Gaussian kernel spanning +/-3 standard deviations, as float32.
kernel = gkern(15, 3).astype(np.float32)
# One copy of the kernel per colour channel, laid out as (15, 15, 3, 1):
# np.stack -> (3, 15, 15), swapaxes(2, 0) -> (15, 15, 3), expand_dims -> (15, 15, 3, 1).
# NOTE(review): that matches a TF depthwise-conv filter layout; presumably
# consumed elsewhere in the file -- confirm.
stack_kernel = np.expand_dims(np.swapaxes(np.stack([kernel] * 3), 2, 0), 3)



class ICE:
    """Transfer-attack graph builder.

    A substitute network crafts the perturbation, and a fixed list of target
    networks evaluates both the clean and the perturbed images. All
    configuration comes from ``FLAGS``: ``target_data``, ``batch_size``,
    ``test_eps`` and ``test_attack_step_size``.
    """

    # Square input side length for each supported target dataset.
    _IMAGE_SIZES = {'cifar10': 32,
                    'cifar100': 32,
                    'tiered1': 84,
                    'tiered2': 56}

    # Checkpoint directory, relative to ../../, for each target dataset.
    _CHECKPOINT_DIRS = {'cifar10': 'target_models_cifar10',
                        'cifar100': 'target_models_cifar100',
                        'tiered1': 'target_models_Tiered1',
                        'tiered2': 'target_models_Tiered2'}

    # Target architectures whose checkpoint file is not named '<ClassName><size>'.
    _FIXED_CHECKPOINT_NAMES = {'MobileNetV1': 'Mobile1',
                               'MobileNetV2': 'Mobile2',
                               'MobileNetV3': 'Mobile3',
                               'VggNet': 'VggNet16',
                               'ShuffleNetV1': 'Shuffle1',
                               'ShuffleNetV2': 'Shuffle2'}

    def __init__(self, dim_input=1, dim_output=1):
        """Create the input placeholders, the substitute net and the target nets.

        Args:
            dim_input: stored on the instance; not otherwise used in this class.
            dim_output: stored on the instance; not otherwise used in this class.

        Raises:
            ValueError: if ``FLAGS.target_data`` is not a supported dataset.
        """
        self.dim_input = dim_input
        self.dim_output = dim_output
        # Scalars that default to the FLAGS values but can be overridden per
        # run through feed_dict.
        self.attack_step_size = tf.placeholder_with_default(FLAGS.test_attack_step_size, ())
        self.eps = tf.placeholder_with_default(FLAGS.test_eps, ())

        self.substitute_net = ResNet12(0)

        print('The substitute model is', self.substitute_net.__class__.__name__)

        # Number of output classes per target dataset.
        self.dim_mapping = {'cifar10': 10,
                            'cifar100': 100,
                            'tiered1': 351,
                            'tiered2': 257}
        # Validate up front. Previously the dim_mapping lookup below raised a
        # KeyError, which made the ValueError branch unreachable.
        if FLAGS.target_data not in self.dim_mapping:
            raise ValueError('Unrecognized target data')
        num_classes = self.dim_mapping[FLAGS.target_data]

        self.test_target_nets = [MobileNetV3(0, dim_output=num_classes),
                                 VggNet('VGG16', dim_output=num_classes),
                                 ResNet(18, dim_output=num_classes),
                                 ResNet(34, dim_output=num_classes),
                                 SeResNet(26, dim_output=num_classes),
                                 DenseNet(26, dim_output=num_classes)]

        self.loss_func = tf.nn.softmax_cross_entropy_with_logits_v2

        image_size = self._IMAGE_SIZES[FLAGS.target_data]
        self.image = tf.placeholder(tf.float32,
                                    shape=[FLAGS.batch_size, image_size, image_size, 3],
                                    name='image')
        # One-hot labels: argmax over axis 1 is used as the class index below.
        self.label = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, num_classes], name='label')

        self.target_net_vars = {}
        self.target_vars = []
        self.target_loaders = {}
        self.test_clean_outputs = {}
        self.test_attack_outputs = {}

        # Variable-scope prefix under which each dataset's target nets live.
        self.scope_mapping = {'cifar10': 'Target',
                              'cifar100': 'Target_100',
                              'tiered1': 'Tiered_Target1',
                              'tiered2': 'Tiered_Target2'}
        self.data_attacks = {}
        self.data_attacks_steps = []

    def construct_testing_graph(self, update_steps):
        """Build the attack graph and the target-model evaluation graph.

        The substitute net is applied ``update_steps + 1`` times. Each pass
        takes a signed-gradient step that increases the mean prediction
        entropy, then clips to [0, 1]. Every step is appended to
        ``self.data_attacks_steps``.

        The image after ``update_steps`` steps is then re-projected: each
        changed pixel is moved to clean +/- ``self.eps``, then clipped to
        [0, 1]. Both the clean and the attacked batch are evaluated on every
        target network, and losses and accuracies are averaged over the nets.

        Args:
            update_steps: number of attack steps whose result is evaluated;
                must be >= 1.

        Raises:
            ValueError: if ``update_steps < 1``. The step selection below needs
                at least two appended steps.
        """
        if update_steps < 1:
            raise ValueError('update_steps must be >= 1, got %r' % (update_steps,))

        attack = self.image
        self.train_accs = []
        self.train_losses = []

        print('Using substitute model')
        for step in range(update_steps + 1):
            with tf.variable_scope('Substitute', reuse=tf.AUTO_REUSE):
                logit = self.substitute_net.forward(attack, True)

                # Mean element-wise entropy of the softmax. The 1e-12 keeps log() finite.
                prob = tf.nn.softmax(logit) + 1e-12
                entropy = -prob * tf.log(prob)
                train_loss = tf.reduce_mean(entropy)

                if step == 0:
                    self.substitute_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                                             scope='Substitute')

            grad = tf.gradients(train_loss, attack)[0]
            # sign() is invariant to the positive per-sample L1 scaling that was
            # applied here before. Dropping that scaling avoids 0/0 = NaN when a
            # sample's gradient is all zero; such pixels now simply do not move.
            attack = attack + self.attack_step_size * tf.sign(grad)
            attack = tf.clip_by_value(attack, 0.0, 1.0)
            self.data_attacks_steps.append(attack)

        self.train_accs = tf.zeros(dtype=tf.float32, shape=[1, ])

        # NOTE: the final appended step is not evaluated. [-2] is the image after
        # `update_steps` steps of this call.
        attack_last = self.data_attacks_steps[-2]

        # Push every pixel that moved to exactly eps from the clean image.
        # Use the feedable self.eps; it defaults to FLAGS.test_eps.
        eps = tf.cast(self.eps, attack_last.dtype)
        attack_last = self.image + eps * tf.sign(attack_last - self.image)
        attack_last = tf.clip_by_value(attack_last, 0.0, 1.0)
        self.testing_attack_images_last = attack_last

        # Perturbation statistics, on a 0-255 pixel scale.
        distortion = (self.testing_attack_images_last - self.image) * 255

        l2_distortion = tf.square(distortion)
        # RMS over every element of the batch.
        self.testing_l2_distortion_mean = tf.sqrt(tf.reduce_mean(l2_distortion))
        # Batch mean of per-image L2 norms.
        l2_distortion_all = tf.sqrt(tf.reduce_sum(l2_distortion, axis=[1, 2, 3]))
        self.testing_l2_distortion_all = tf.reduce_mean(l2_distortion_all)

        # Batch mean of per-image L1 norms.
        l1_distortion = tf.abs(distortion)
        l1_distortion_all = tf.reduce_sum(l1_distortion, axis=[1, 2, 3])
        self.testing_l1_distortion_all = tf.reduce_mean(l1_distortion_all)

        # Batch mean of per-image L-inf norms.
        l_inf_distortion = tf.reduce_max(l1_distortion, axis=[1, 2, 3])
        self.testing_l_inf_distortion_all = tf.reduce_mean(l_inf_distortion)

        # Suffix 1 = clean images, suffix 2 = attacked images.
        self.test_target_loss1 = 0
        self.test_target_loss2 = 0
        self.test_target_accuracy1 = 0
        self.test_target_accuracy2 = 0

        scope = self.scope_mapping[FLAGS.target_data]
        true_classes = tf.argmax(self.label, 1)
        for net in self.test_target_nets:
            net_name = net.__class__.__name__
            scope_name = scope + '/' + net_name + str(net.size)

            # Clean pass: this creates the target net's variables on first use.
            with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
                output1 = net.forward(self.image, False)
            clean_probs = tf.nn.softmax(output1)
            self.test_clean_outputs[scope_name] = clean_probs
            self.test_target_accuracy1 += tf.contrib.metrics.accuracy(tf.argmax(clean_probs, 1),
                                                                      true_classes)
            loss1 = self.loss_func(logits=output1, labels=self.label)
            self.test_target_loss1 += tf.reduce_mean(loss1)

            net_vars = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name)
            self.target_net_vars[scope_name] = net_vars
            self.target_vars.extend(net_vars)
            self.target_loaders[scope_name] = tf.train.Saver(net_vars, max_to_keep=0)

            begin_var_ids = {id(v) for v in tf.global_variables()}
            if 'SqueezeNet' in net_name:
                # NOTE(review): presumably resets a layer counter so the second
                # forward pass reuses the same variable names -- confirm.
                net.conv_num = 0
            # Attacked pass: reuse=True requires the variables created above.
            with tf.variable_scope(scope_name, reuse=True):
                output2 = net.forward(self.testing_attack_images_last, False)
                attack_probs = tf.nn.softmax(output2)
                self.test_attack_outputs[scope_name] = attack_probs
                self.test_target_accuracy2 += tf.contrib.metrics.accuracy(tf.argmax(attack_probs, 1),
                                                                          true_classes)
                loss2 = self.loss_func(logits=output2, labels=self.label)
                self.test_target_loss2 += tf.reduce_mean(loss2)

            # Debug aid: report any variable the reuse pass created unexpectedly.
            for var in tf.global_variables():
                if id(var) not in begin_var_ids:
                    print('added var', var)

        num_nets = len(self.test_target_nets)
        self.test_target_loss1 = self.test_target_loss1 / num_nets
        self.test_target_loss2 = self.test_target_loss2 / num_nets
        self.test_target_accuracy1 = self.test_target_accuracy1 / num_nets
        self.test_target_accuracy2 = self.test_target_accuracy2 / num_nets

    def _checkpoint_name(self, model):
        """Return the checkpoint file stem for ``model`` within its dataset directory.

        Raises:
            ValueError: if the model's class has no known checkpoint name.
        """
        class_name = model.__class__.__name__
        if class_name in ('ResNet', 'SeResNet', 'SqueezeNet'):
            return class_name + str(model.size)
        if class_name == 'DenseNet':
            # Models with a truthy `BC` attribute are stored with a 'BC' suffix.
            return class_name + str(model.size) + ('BC' if model.BC else '')
        if class_name not in self._FIXED_CHECKPOINT_NAMES:
            # Fail loudly rather than restore from an undefined or stale path.
            raise ValueError('No checkpoint known for target model %s' % class_name)
        return self._FIXED_CHECKPOINT_NAMES[class_name]

    def load_target_models(self, sess):
        """Restore every target network's weights into ``sess``.

        Must run after ``construct_testing_graph``, which creates the savers
        in ``self.target_loaders``. Checkpoints are read from
        ``../../<dataset dir>/<checkpoint name>`` relative to the working
        directory, and each path is printed after it is restored.

        Raises:
            KeyError: if ``FLAGS.target_data`` is not a supported dataset.
            ValueError: if a target network's class has no known checkpoint name.
        """
        checkpoint_dir = self._CHECKPOINT_DIRS[FLAGS.target_data]
        scope = self.scope_mapping[FLAGS.target_data]
        for model in self.test_target_nets:
            model_name = model.__class__.__name__ + str(model.size)
            model_path = '../../' + checkpoint_dir + '/' + self._checkpoint_name(model)
            self.target_loaders[scope + '/' + model_name].restore(sess, model_path)
            print(model_path)

