from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf

from tensorflow.python.platform import flags
# from utils import mse, xent, conv_block, normalize
from networks import ResNet, SeResNet, ResNext, DenseNet, MobileNetV1, MobileNetV2, MobileNetV3, \
    ShuffleNetV1, ShuffleNetV2, SqueezeNet, ResNet12

import math
import pwd
import os


# Parsed command-line flags (TensorFlow's flags module). ICE reads its
# configuration from here, e.g. lr, batch_size, target_data, train_data_names.
FLAGS = flags.FLAGS


class ICE:
    """Train a substitute network whose entropy-gradient perturbations fool target networks.

    For each training dataset, the substitute network (``ResNet12``) produces a
    one-step perturbation of the input images along the gradient of its own
    summed prediction entropy. Target networks, restored by
    ``load_target_models``, are evaluated on the clean and perturbed images.
    ``construct_optimizing_graph`` then updates only the substitute's variables
    so that the target loss on the perturbed images increases.

    Configuration is read from ``FLAGS`` (``lr``, ``train_attack_step_size``,
    ``batch_size``, ``target_data``, ``train_data_names``).
    """

    # Dataset ids known to this class, in the canonical order used to build
    # every per-dataset dict.
    _DATASETS = ('cifar10', 'cifar100', 'tiered1', 'tiered2')

    # Directory (relative to ``../../``) holding each dataset's target-model
    # checkpoints.
    _CHECKPOINT_ROOTS = {'cifar10': 'target_models_cifar10',
                         'cifar100': 'target_models_cifar100',
                         'tiered1': 'target_models_Tiered1',
                         'tiered2': 'target_models_Tiered2'}

    # Architectures whose checkpoint directory is ``<ClassName><size>``.
    # DenseNet additionally gets a ``BC`` suffix when ``model.BC`` is truthy.
    _SIZED_CHECKPOINT_ARCHS = ('ResNet', 'SeResNet', 'DenseNet', 'SqueezeNet')

    # Architectures whose checkpoint directory does not include the size.
    _FIXED_CHECKPOINT_DIRS = {'MobileNetV1': 'Mobile1',
                              'MobileNetV2': 'Mobile2',
                              'MobileNetV3': 'Mobile3',
                              'ShuffleNetV1': 'Shuffle1',
                              'ShuffleNetV2': 'Shuffle2'}

    def __init__(self):
        """Create the placeholders and network objects used by the graph builders.

        Every dataset in ``_DATASETS`` other than ``FLAGS.target_data`` becomes a
        training source. Its image placeholder, label placeholder and list of
        target networks are exposed through the dicts ``input_images``,
        ``input_labels`` and ``train_target_nets``, keyed by dataset id.

        Raises:
            ValueError: if ``FLAGS.target_data`` is not a known dataset id.
        """
        # Feedable scalars that default to the flag values.
        self.lr = tf.placeholder_with_default(FLAGS.lr, ())
        self.attack_step_size = tf.placeholder_with_default(FLAGS.train_attack_step_size, ())

        self.substitute_net = ResNet12(0)

        print('The substitute model is', self.substitute_net.__class__.__name__)

        # Temporary value; replaced by a per-dataset dict further down.
        self.train_target_nets = [ResNet(18)]
        self.test_target_nets = [MobileNetV3(0)]

        self.loss_func = tf.nn.softmax_cross_entropy_with_logits_v2

        batch = FLAGS.batch_size
        # Image placeholders are [batch, H, W, 3]; label placeholders are
        # [batch, num_classes]. They are created in a fixed order so the default
        # placeholder op names stay stable.
        self.image_cifar10 = tf.placeholder(tf.float32, shape=[batch, 32, 32, 3])
        self.label_cifar10 = tf.placeholder(tf.float32, shape=[batch, 10])

        self.image_cifar100 = tf.placeholder(tf.float32, shape=[batch, 32, 32, 3])
        self.label_cifar100 = tf.placeholder(tf.float32, shape=[batch, 100])

        self.image_tiered1 = tf.placeholder(tf.float32, shape=[batch, 84, 84, 3])
        self.label_tiered1 = tf.placeholder(tf.float32, shape=[batch, 351])

        self.image_tiered2 = tf.placeholder(tf.float32, shape=[batch, 56, 56, 3])
        self.label_tiered2 = tf.placeholder(tf.float32, shape=[batch, 257])

        # Generic pair; not used by the graph builders in this class.
        self.image = tf.placeholder(tf.float32, shape=[batch, 84, 84, 3])
        self.label = tf.placeholder(tf.float32, shape=[batch, 10])

        all_nets = {'cifar10': [ResNet(18, dim_output=10)],
                    'cifar100': [ResNet(18, dim_output=100)],
                    'tiered1': [ResNet(18, dim_output=351)],
                    'tiered2': [ResNet(18, dim_output=257)]}
        all_images = {'cifar10': self.image_cifar10,
                      'cifar100': self.image_cifar100,
                      'tiered1': self.image_tiered1,
                      'tiered2': self.image_tiered2}
        all_labels = {'cifar10': self.label_cifar10,
                      'cifar100': self.label_cifar100,
                      'tiered1': self.label_tiered1,
                      'tiered2': self.label_tiered2}

        target_data = FLAGS.target_data
        if target_data not in self._DATASETS:
            # Fail here rather than leaving input_images/input_labels unset,
            # which would only surface later as an AttributeError.
            raise ValueError('Unknown FLAGS.target_data %r; expected one of: %s'
                             % (target_data, ', '.join(self._DATASETS)))

        # Every dataset except FLAGS.target_data is a training source.
        sources = [name for name in self._DATASETS if name != target_data]
        self.input_images = {name: all_images[name] for name in sources}
        self.input_labels = {name: all_labels[name] for name in sources}
        self.train_target_nets = {name: all_nets[name] for name in sources}

        # Populated by construct_training_graph.
        self.target_net_vars = {}
        self.target_vars = []
        self.target_loaders = {}
        # Not populated anywhere in this class.
        self.test_clean_outputs = {}
        self.test_attack_outputs = {}

        # Outer variable scope under which each dataset's target networks live;
        # also the prefix of the keys in target_loaders.
        self.scope_mapping = {'cifar10': 'Target',
                              'cifar100': 'Target_100',
                              'tiered1': 'Tiered_Target1',
                              'tiered2': 'Tiered_Target2'}

        self.train_data_names = FLAGS.train_data_names

    def _substitute_attack(self, image):
        """Return ``image`` moved one step along the substitute's entropy gradient.

        Runs the substitute network inside the shared ``Substitute`` variable
        scope and records its trainable variables in ``self.substitute_vars``.
        The step direction is the L1-normalised gradient plus small sign and
        arctan terms. The result is clipped to [0, 1], which assumes pixel
        values live in [0, 1] (TODO confirm against the data pipeline).
        """
        with tf.variable_scope('Substitute', reuse=tf.AUTO_REUSE):
            substitute_output = self.substitute_net.forward(image, True)
            self.substitute_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                                     scope='Substitute')

        # The 1e-12 keeps log() finite when softmax underflows to 0.
        prob = tf.nn.softmax(substitute_output) + 1e-12
        entropy = -prob * tf.log(prob)
        # tf.gradients sums the element-wise entropy before differentiating.
        grad = tf.gradients(entropy, image)[0]

        def _nonzero(denominator):
            # An all-zero gradient makes the denominator 0 and the quotient
            # 0/0 = NaN. Substituting 1 there gives a zero step for that example
            # and leaves every other example unchanged.
            return tf.where(tf.equal(denominator, 0.0),
                            tf.ones_like(denominator), denominator)

        abs_grad = tf.abs(grad)
        l1_norm = tf.reduce_sum(abs_grad, axis=[1, 2, 3], keepdims=True)
        mean_abs_grad = tf.reduce_mean(abs_grad, axis=[1, 2, 3], keepdims=True)

        l1_normalized = grad / _nonzero(l1_norm)
        unit_mean_grad = grad / _nonzero(mean_abs_grad)

        sign_grad = tf.sign(grad)
        # Maps the mean-normalised gradient into (-1, 1).
        atan_grad = tf.atan(unit_mean_grad) * (2 / 3.1415926)
        direction = l1_normalized + 0.01 * sign_grad + 0.01 * atan_grad

        attack = image + self.attack_step_size * direction
        return tf.clip_by_value(attack, 0.0, 1.0)

    def _distortion_metrics(self, distortion):
        """Return (RMS, mean L2, mean L1, mean L-inf) of a batch of perturbations.

        RMS is taken over every element of the batch. The other three are
        per-example norms over axes 1-3, averaged over the batch.
        """
        squared = tf.square(distortion)
        absolute = tf.abs(distortion)
        l2_rms = tf.sqrt(tf.reduce_mean(squared))
        l2_mean = tf.reduce_mean(tf.sqrt(tf.reduce_sum(squared, axis=[1, 2, 3])))
        l1_mean = tf.reduce_mean(tf.reduce_sum(absolute, axis=[1, 2, 3]))
        l_inf_mean = tf.reduce_mean(tf.reduce_max(absolute, axis=[1, 2, 3]))
        return l2_rms, l2_mean, l1_mean, l_inf_mean

    def _target_metrics(self, net, images, labels):
        """Return (accuracy, batch-mean loss) of target ``net`` on ``images``.

        Intended to be called inside the net's own variable scope so that any
        variables ``net.forward`` creates or reuses live under that scope.
        """
        logits = net.forward(images, False)
        accuracy = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(logits), 1),
                                               tf.argmax(labels, 1))
        loss = tf.reduce_mean(self.loss_func(logits=logits, labels=labels))
        return accuracy, loss

    def construct_training_graph(self):
        """Build the attack, distortion and target-evaluation ops.

        Sets:
            training_attack_images: dict data_id -> perturbed image tensor.
            training_l2_distortion_mean / training_l2_distortion_all /
            training_l1_distortion_all / training_l_inf_distortion_all:
                distortion metrics averaged over the training datasets.
            train_target_loss1 / train_target_accuracy1: dict data_id ->
                loss / accuracy of the target nets on the clean images,
                averaged over that dataset's target nets.
            train_target_loss2 / train_target_accuracy2: the same on the
                perturbed images.
            target_net_vars / target_vars / target_loaders: each target net's
                global variables and a ``tf.train.Saver`` for them, keyed by
                ``<scope_mapping[data_id]>/<ClassName><size>``.
        """
        # Not updated anywhere in this class.
        self.training_substitute_acc = tf.zeros(shape=[1, ])
        self.training_attack_images = {}

        l2_rms_list, l2_mean_list, l1_mean_list, l_inf_mean_list = [], [], [], []
        for data_id in self.train_data_names:
            image = self.input_images[data_id]
            attack = self._substitute_attack(image)
            self.training_attack_images[data_id] = attack

            l2_rms, l2_mean, l1_mean, l_inf_mean = self._distortion_metrics(attack - image)
            l2_rms_list.append(l2_rms)
            l2_mean_list.append(l2_mean)
            l1_mean_list.append(l1_mean)
            l_inf_mean_list.append(l_inf_mean)

        self.training_l2_distortion_mean = tf.reduce_mean(l2_rms_list)
        self.training_l2_distortion_all = tf.reduce_mean(l2_mean_list)
        self.training_l1_distortion_all = tf.reduce_mean(l1_mean_list)
        self.training_l_inf_distortion_all = tf.reduce_mean(l_inf_mean_list)

        self.train_target_loss1 = {}        # loss on the clean images
        self.train_target_accuracy1 = {}    # accuracy on the clean images
        self.train_target_loss2 = {}        # loss on the perturbed images
        self.train_target_accuracy2 = {}    # accuracy on the perturbed images

        for data_id in self.train_data_names:
            scope_name = self.scope_mapping[data_id]
            target_nets = self.train_target_nets[data_id]
            clean_images = self.input_images[data_id]
            attack_images = self.training_attack_images[data_id]
            label = self.input_labels[data_id]

            clean_accuracy = clean_loss = attack_accuracy = attack_loss = 0
            with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
                for net in target_nets:
                    net_scope = net.__class__.__name__ + str(net.size)

                    with tf.variable_scope(net_scope, reuse=tf.AUTO_REUSE):
                        accuracy, loss = self._target_metrics(net, clean_images, label)
                    clean_accuracy += accuracy
                    clean_loss += loss

                    # Register the net's variables and a Saver used by
                    # load_target_models to restore them.
                    full_scope = scope_name + '/' + net_scope
                    net_vars = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                                 scope=full_scope)
                    self.target_net_vars[full_scope] = net_vars
                    self.target_vars.extend(net_vars)
                    self.target_loaders[full_scope] = tf.train.Saver(net_vars, max_to_keep=0)

                    # Re-enter the same scope (AUTO_REUSE) to evaluate the same
                    # weights on the perturbed images.
                    with tf.variable_scope(net_scope, reuse=tf.AUTO_REUSE):
                        accuracy, loss = self._target_metrics(net, attack_images, label)
                    attack_accuracy += accuracy
                    attack_loss += loss

            num_nets = len(target_nets)
            self.train_target_loss1[data_id] = clean_loss / num_nets
            self.train_target_accuracy1[data_id] = clean_accuracy / num_nets
            self.train_target_loss2[data_id] = attack_loss / num_nets
            self.train_target_accuracy2[data_id] = attack_accuracy / num_nets

    def construct_optimizing_graph(self):
        """Create the Adam update (``metatrain_op``) that trains the substitute network.

        The objective maximises the summed target loss on the perturbed images,
        so the optimizer minimises its negation over ``self.substitute_vars``
        only. Gradients are clipped element-wise to [-10, 10]. Variables that
        receive no gradient (``None``) are skipped, because
        ``tf.clip_by_value`` cannot take ``None``.
        """
        optimizer = tf.train.AdamOptimizer(self.lr)

        total_loss = -sum(self.train_target_loss2.values())

        gvs = optimizer.compute_gradients(total_loss, self.substitute_vars)
        self.gvs = [(tf.clip_by_value(grad, -10, 10), var)
                    for grad, var in gvs if grad is not None]
        self.metatrain_op = optimizer.apply_gradients(self.gvs)

    def _checkpoint_path(self, model, checkpoint_root):
        """Return the checkpoint path for target ``model`` under ``checkpoint_root``.

        Raises:
            ValueError: if ``model``'s class has no known checkpoint directory.
        """
        class_name = model.__class__.__name__
        if class_name in self._SIZED_CHECKPOINT_ARCHS:
            sub_dir = class_name + str(model.size)
            if class_name == 'DenseNet' and model.BC:
                sub_dir += 'BC'
        elif class_name in self._FIXED_CHECKPOINT_DIRS:
            sub_dir = self._FIXED_CHECKPOINT_DIRS[class_name]
        else:
            # Raising here avoids restoring from an unbound or stale path left
            # over from a previous model.
            raise ValueError('No checkpoint location known for target model class %r'
                             % class_name)
        return '../../' + checkpoint_root + '/' + sub_dir

    def load_target_models(self, sess):
        """Restore every training target network from its checkpoint.

        Requires ``construct_training_graph`` to have registered the Savers in
        ``target_loaders``. Each restored path is printed.

        Args:
            sess: session passed to each ``Saver.restore`` call.

        Raises:
            ValueError: if a target network's class has no known checkpoint directory.
            KeyError: if a dataset id or loader key is unknown.
        """
        for data_id in self.train_data_names:
            checkpoint_root = self._CHECKPOINT_ROOTS[data_id]
            scope_name = self.scope_mapping[data_id]
            for model in self.train_target_nets[data_id]:
                model_path = self._checkpoint_path(model, checkpoint_root)
                loader_key = scope_name + '/' + model.__class__.__name__ + str(model.size)
                self.target_loaders[loader_key].restore(sess, model_path)
                print(model_path)




