import numpy as np
import tensorflow as tf
import datetime

from ICE_model import ICE
from tensorflow.python.platform import flags
import os
from data_cifar import Dataloader_CF
from data_tiered import Dataloader


FLAGS = flags.FLAGS


# Training options
flags.DEFINE_integer('batch_size', 100, 'number of images per test iteration')
flags.DEFINE_string('network', 'resnet12', 'network name')
flags.DEFINE_integer('base_num_filters', 32, 'number of filters for conv nets.')

flags.DEFINE_bool('data_aug', True, '')
flags.DEFINE_float('dropout_rate', 0.0, '')

# Test-time attack budget. The defaults are roughly 15/255 and 1.5/255,
# presumably for images scaled to [0, 1] -- TODO confirm against the loaders.
flags.DEFINE_float('test_eps', 0.05882359, 'maximum perturbation size (eps) of the test-time attack')
flags.DEFINE_float('test_attack_step_size', 0.00588239,
                   'per-step attack size; always overridden with test_eps / num_test_steps')
flags.DEFINE_integer('num_test_steps', 10, 'number of attack steps used at test time')

flags.DEFINE_string('substitute_file', '../', 'the directory for the trained checkpoints')
flags.DEFINE_integer('num_outputs', 1000, 'the output classes of the surrogate model')

flags.DEFINE_list('train_data_names',['cifar10', 'cifar100', 'tiered1', 'tiered2'], '')
flags.DEFINE_string('target_data', 'cifar100',
                    'dataset whose validation split is attacked (excluded from train_data_names)')

# NOTE(review): reading FLAGS at import time presumably triggers implicit
# argv parsing in TF's flags wrapper, so every flag must be defined above here.
# The target dataset is excluded from the training datasets. Guard the removal
# so a user-supplied train_data_names list that omits it does not crash.
if FLAGS.target_data in FLAGS.train_data_names:
    FLAGS.train_data_names.remove(FLAGS.target_data)

# The step size is always derived from eps and the step count, so any
# --test_attack_step_size given on the command line is overwritten here.
FLAGS.test_attack_step_size = FLAGS.test_eps/FLAGS.num_test_steps
# NOTE(review): not referenced anywhere in this script's visible code.
LEN_MODELS = 100


def _load_val_batch(sess, data_loador):
    """Return ``(images, labels, is_last_batch)`` for the next validation batch."""
    if 'cifar' in FLAGS.target_data:
        # The CIFAR loader's batch is fed to the model as-is.
        return data_loador.get_batch_data(FLAGS.batch_size, train=False)
    # The tiered loader returns file references. They are decoded through the
    # loader's own validation input pipeline before being fed to the model.
    batch_files, batch_labels, val_end = data_loador.get_batch_data(FLAGS.batch_size, train=False)
    sess.run(data_loador.iterator_val, feed_dict={data_loador.image_lists_val: batch_files})
    batch_images = sess.run(data_loador.out_images_val)
    return batch_images, batch_labels, val_end


def _append_by_key(store, batch_outputs):
    """Append each value of the dict *batch_outputs* to ``store[key]``."""
    for key, value in batch_outputs.items():
        store.setdefault(key, []).append(value)


def _clean_and_attack_accuracy(labels, clean_scores, attack_scores):
    """Return (clean accuracy, accuracy under attack on clean-correct samples).

    *labels* are integer class ids. The score arrays are argmax-ed along axis 1
    to get predictions. The attack accuracy is NaN when no sample is
    classified correctly on clean inputs (nothing to attack).
    """
    clean_correct = np.argmax(clean_scores, axis=1) == labels
    attack_correct = np.argmax(attack_scores, axis=1) == labels
    num_clean_correct = np.sum(clean_correct)
    if num_clean_correct == 0:
        # Report NaN explicitly instead of relying on a warning-raising 0/0.
        attack_acc = float('nan')
    else:
        attack_acc = np.sum(attack_correct & clean_correct) / num_clean_correct
    return np.mean(clean_correct), attack_acc


def _mean_and_ci95(values):
    """Return the mean of *values* and the half-width of its normal 95% CI."""
    return np.mean(values), 1.96 * np.std(values) / np.sqrt(len(values))


def test(model, sess, data_loador, itr):
    """Evaluate the attack from the current substitute on the validation split.

    Runs the testing graph over every validation batch of ``data_loador`` and
    prints the following:

    * mean target losses and distortion statistics;
    * per target-model clean accuracy;
    * accuracy under attack, restricted to samples classified correctly when
      clean;
    * the matching attack success rate.

    The last column of the accuracy lines is the mean over target models plus
    a 95% CI half-width.

    Args:
        model: ICE model whose testing graph has already been constructed.
        sess: TensorFlow session holding the restored variables.
        data_loador: ``Dataloader_CF`` (CIFAR) or ``Dataloader`` (tiered)
            instance for ``FLAGS.target_data``.
        itr: Identifier (the checkpoint epoch in ``main``) used only in the
            printed report.

    Returns:
        The mean, over target models, of the accuracy under attack on
        clean-correct samples (lower means a more effective attack). NaN if
        any target model has no clean-correct sample.
    """
    # The fetched tensors are the same for every batch, so build the list once.
    fetches = [
        model.test_target_loss1, model.test_target_loss2,
        model.testing_l2_distortion_mean, model.testing_l2_distortion_all,
        model.testing_l1_distortion_all, model.testing_l_inf_distortion_all,
        model.test_clean_outputs, model.test_attack_outputs,
    ]

    target_losses1, target_losses2 = [], []
    l2_means, l2_alls, l1_alls, l_inf_alls = [], [], [], []
    clean_outputs = {}
    attack_outputs = {}
    labels = []

    val_end = False
    while not val_end:
        batch_images, batch_labels, val_end = _load_val_batch(sess, data_loador)
        labels.append(batch_labels)

        feed_dict = {
            model.image: batch_images,
            model.label: batch_labels,
            model.eps: FLAGS.test_eps,
            model.attack_step_size: FLAGS.test_attack_step_size,
        }
        (loss1, loss2, l2_mean, l2_all, l1_all, l_inf_all,
         clean_output, attack_output) = sess.run(fetches, feed_dict)

        target_losses1.append(loss1)
        target_losses2.append(loss2)
        l2_means.append(l2_mean)
        l2_alls.append(l2_all)
        l1_alls.append(l1_all)
        l_inf_alls.append(l_inf_all)
        # Outputs come back as dicts keyed by target model; accumulate per key.
        _append_by_key(clean_outputs, clean_output)
        _append_by_key(attack_outputs, attack_output)

    # Labels are converted to class ids by argmax along axis 1.
    labels = np.argmax(np.concatenate(labels, axis=0), axis=1)
    print(labels.shape)

    clean_accuracies, attack_accuracies = [], []
    print_clean_str = 'Validation  clean ' + str(itr)
    print_attack_str = 'Validation attack ' + str(itr)
    print_attack_success_rate = 'Validation_attack ' + str(itr)
    for key in clean_outputs:
        clean_acc, attack_acc = _clean_and_attack_accuracy(
            labels,
            np.concatenate(clean_outputs[key], axis=0),
            np.concatenate(attack_outputs[key], axis=0))
        clean_accuracies.append(clean_acc)
        attack_accuracies.append(attack_acc)
        print_clean_str += ': %.5f' % clean_acc
        print_attack_str += ': %.5f' % attack_acc
        # Success rate on clean-correct samples is the complement of the
        # accuracy that survives the attack.
        print_attack_success_rate += ': %.5f' % (1 - attack_acc)

    mean_acc_clean, ci95_clean = _mean_and_ci95(clean_accuracies)
    mean_acc_attack, ci95_attack = _mean_and_ci95(attack_accuracies)
    print_clean_str += ': %.5f, %.5f' % (mean_acc_clean, ci95_clean)
    print_attack_str += ': %.5f, %.5f' % (mean_acc_attack, ci95_attack)

    print_str = 'Validation %s:%.5f, %.5f:::%.5f, %.5f, %.5f, %.5f' % (
        itr,
        np.mean(target_losses1), np.mean(target_losses2),
        np.mean(l2_means), np.mean(l2_alls),
        np.mean(l1_alls), np.mean(l_inf_alls))

    print('------------------------------------------', itr)
    print(print_str)
    print(print_clean_str)
    print(print_attack_str)
    print(print_attack_success_rate)
    print('------------------------------------------')

    return mean_acc_attack




def _list_checkpoint_epochs(directory):
    """Return the sorted, de-duplicated epochs N of ``model<N>.index`` files in *directory*."""
    prefix, suffix = 'model', '.index'
    epochs = set()
    for name in os.listdir(directory):
        # Require the exact name that ``'model' + str(epoch)`` restores from.
        # This skips temp files and names like ``model_best.index`` or
        # ``foo_model3.index``.
        if name.startswith(prefix) and name.endswith(suffix):
            digits = name[len(prefix):-len(suffix)]
            if digits.isdecimal():
                epochs.add(int(digits))
    return sorted(epochs)


def main():
    """Evaluate every saved substitute checkpoint against the target models.

    The function performs these steps:

    1. Build the ICE testing graph and load the target models.
    2. Restore each ``model<epoch>`` checkpoint in ``FLAGS.substitute_file``,
       in ascending epoch order, into the ``Substitute`` variables.
    3. Run ``test`` on each restored checkpoint.
    4. Report the epoch with the lowest mean attack accuracy so far.
    """
    if 'cifar' in FLAGS.target_data:
        data_loador = Dataloader_CF(data=FLAGS.target_data)
    else:
        data_loador = Dataloader(data=FLAGS.target_data)

    model = ICE()

    print('create testing graph')
    model.construct_testing_graph(FLAGS.num_test_steps)

    # Only variables under the 'Substitute' scope are restored per checkpoint.
    substitute_vars = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope='Substitute')

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    try:
        print('----------------------------------')
        for net in model.test_target_nets:
            print(FLAGS.target_data + '/' + net.__class__.__name__ + str(net.size))
        print('----------------------------------')

        test_setting = 'step_' + str(FLAGS.num_test_steps) + '.size_' + str(FLAGS.test_attack_step_size)
        test_setting += '.eps_' + str(FLAGS.test_eps) + '_' + str(FLAGS.target_data) \
                        + '.out' + str(FLAGS.num_outputs)
        print('Testing setting:', test_setting)

        tf.global_variables_initializer().run()
        print('loading target models')
        model.load_target_models(sess=sess)

        model_epochs = _list_checkpoint_epochs(FLAGS.substitute_file)
        if not model_epochs:
            print('No model<N>.index checkpoints found in', FLAGS.substitute_file)

        saver = tf.train.Saver(substitute_vars, max_to_keep=0)
        min_acc = 1.0
        best_epoch = 0
        for epoch in model_epochs:
            model_file = os.path.join(FLAGS.substitute_file, 'model' + str(epoch))
            saver.restore(sess, model_file)
            # strftime rather than str()[:-7]: str() omits the microsecond part
            # when it is 0, which would otherwise chop off the seconds.
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print(timestamp, "testing model: " + model_file)
            mean_acc = test(model, sess, data_loador, epoch)

            # Lower attack accuracy means a more effective substitute.
            if mean_acc < min_acc:
                min_acc = mean_acc
                best_epoch = epoch
            print('Val min attack acc:', min_acc, 'Epoch:', best_epoch)
    finally:
        sess.close()




if __name__ == '__main__':
    # Run the checkpoint sweep only when executed as a script, not on import.
    main()





