import numpy as np
import tensorflow as tf
import datetime

from ICE_model import ICE
from tensorflow.python.platform import flags
from data_cifar import Dataloader_CF
from data_tiered import Dataloader


FLAGS = flags.FLAGS

## Training options
flags.DEFINE_integer('train_iterations', 70000, 'number of training iterations.')

flags.DEFINE_integer('batch_size', 64, 'number of examples sampled per dataset')
flags.DEFINE_float('lr', 0.001, 'the learning rate of ICE')
flags.DEFINE_float('dropout_rate', 0, 'dropout_rate')
flags.DEFINE_string('network', 'ResNet12', 'network name')
flags.DEFINE_integer('base_num_filters', 32, 'number of filters for conv nets')
flags.DEFINE_bool('data_aug', True, 'if false, do not use data augmentation.')

## Logging, saving, and testing options
flags.DEFINE_string('logdir', 'logs/', 'directory for checkpoints.')
# NOTE(review): `train` is not read anywhere else in this script as shown;
# the script always trains.
flags.DEFINE_bool('train', True, 'True to train, False to test.')

## Attack and surrogate-model options
flags.DEFINE_float('train_attack_step_size', 11.7, 'the base learning rate of the generator')
flags.DEFINE_integer('num_outputs', 1000, 'the output classes of the surrogate model')
flags.DEFINE_integer('attack_decay_iter', 3000,
                     'multiply train_attack_step_size by 0.9 every this many iterations '
                     '(<= 0 disables the decay)')

## Dataset selection
flags.DEFINE_string('target_data', 'cifar100',
                    'name of the held-out target dataset (excluded from training)')
flags.DEFINE_list('train_data_names', ['cifar10', 'cifar100', 'tiered1', 'tiered2'],
                  'comma-separated training datasets; target_data is removed from this list')

# Hold the target dataset out of the training set. Filtering in place, rather
# than calling list.remove, has two effects:
# - no error is raised when target_data is absent from the list;
# - every occurrence is dropped if it was listed more than once.
FLAGS.train_data_names[:] = [data_id for data_id in FLAGS.train_data_names
                             if data_id != FLAGS.target_data]
print(FLAGS.train_data_names)



def _checkpoint_path(exp_string, itr):
    """Return the checkpoint prefix for iteration `itr` of experiment `exp_string`."""
    return FLAGS.logdir + '/' + exp_string + '/model' + str(itr)


def _empty_history():
    """Return a new dict mapping each training dataset name to an empty list."""
    return {data_id: [] for data_id in FLAGS.train_data_names}


def _fill_batch_feeds(sess, model, data_loaders, feed_dict):
    """Sample one training batch per data loader and add it to `feed_dict`.

    Loaders whose key contains 'cifar' return image batches directly. For every
    other loader:
    - the returned file names are fed to its `image_lists` placeholder;
    - its `iterator` op is run;
    - `out_images` is fetched to obtain the images.

    Images and labels are stored under `model.input_images[key]` and
    `model.input_labels[key]`.
    """
    for key, data_loader in data_loaders.items():
        if 'cifar' in key:
            batch_images, batch_labels, _ = data_loader.get_batch_data(FLAGS.batch_size, train=True)
        else:
            batch_files, batch_labels, _ = data_loader.get_batch_data(FLAGS.batch_size, train=True)
            # Use a feed dict private to this loader so only its own file list is fed.
            sess.run(data_loader.iterator, feed_dict={data_loader.image_lists: batch_files})
            batch_images = sess.run(data_loader.out_images)
        feed_dict[model.input_images[key]] = batch_images
        feed_dict[model.input_labels[key]] = batch_labels


def _attack_step_size(itr):
    """Return the attack step size to use at iteration `itr`.

    If FLAGS.attack_decay_iter > 0, the base step size is multiplied by 0.9
    once every `attack_decay_iter` iterations. Otherwise the step size is
    constant.
    """
    if FLAGS.attack_decay_iter <= 0:
        return FLAGS.train_attack_step_size
    step_size = FLAGS.train_attack_step_size * 0.9 ** (itr // FLAGS.attack_decay_iter)
    # Announce the value on the first two iterations of every decay period.
    if itr % FLAGS.attack_decay_iter < 2:
        print('change the attack step size to:' + str(step_size) + ', ----------------------------')
    return step_size


def train(model, saver, sess, exp_string, data_loaders):
    """Run the ICE training loop.

    Each iteration does the following:
    - feeds one batch per dataset;
    - runs `model.metatrain_op`;
    - records the per-dataset target losses/accuracies and the distortion
      statistics.

    Running averages are printed every PRINT_INTERVAL iterations. A checkpoint
    of `saver`'s variables is written every SAVE_INTERVAL iterations and once
    more after the final iteration.

    Args:
        model: ICE model exposing the placeholders/tensors referenced below.
        saver: tf.train.Saver used to write checkpoints.
        sess: session in which the graph is run.
        exp_string: experiment name. Checkpoints go to
            FLAGS.logdir + '/' + exp_string + '/model<itr>'.
        data_loaders: dict mapping dataset name to its data loader.
    """
    PRINT_INTERVAL = 50
    SAVE_INTERVAL = 20 * PRINT_INTERVAL

    print(exp_string)
    print('Done initializing, starting training.')

    # The fetch list never changes, so build it once. The two "*_all"
    # distortion tensors are still fetched, but only the mean L2 and the
    # L-inf values are reported.
    fetches = [model.metatrain_op,
               model.train_target_loss1, model.train_target_accuracy1,
               model.train_target_loss2, model.train_target_accuracy2,
               model.training_l2_distortion_mean, model.training_l2_distortion_all,
               model.training_l1_distortion_all, model.training_l_inf_distortion_all]

    target_losses1, target_accs1 = _empty_history(), _empty_history()
    target_losses2, target_accs2 = _empty_history(), _empty_history()
    l2_distortions_mean, l_inf_distortions_all = [], []

    num_iterations = FLAGS.train_iterations
    for itr in range(num_iterations):
        feed_dict = {model.lr: FLAGS.lr}
        _fill_batch_feeds(sess, model, data_loaders, feed_dict)
        feed_dict[model.attack_step_size] = _attack_step_size(itr)

        (_, losses1, accs1, losses2, accs2,
         l2_mean, _l2_all, _l1_all, l_inf_all) = sess.run(fetches, feed_dict)

        # The target metrics are indexed by dataset name.
        for data_id in FLAGS.train_data_names:
            target_losses1[data_id].append(losses1[data_id])
            target_accs1[data_id].append(accs1[data_id])
            target_losses2[data_id].append(losses2[data_id])
            target_accs2[data_id].append(accs2[data_id])
        l2_distortions_mean.append(l2_mean)
        l_inf_distortions_all.append(l_inf_all)

        if itr != 0 and itr % PRINT_INTERVAL == 0:
            print_str = 'Iteration ' + str(itr)
            for data_id in FLAGS.train_data_names:
                print_str += ':::%.4f, %.4f' % (np.mean(target_losses1[data_id]),
                                                np.mean(target_accs1[data_id]))
            for data_id in FLAGS.train_data_names:
                print_str += ', %.4f, %.4f' % (np.mean(target_losses2[data_id]),
                                               np.mean(target_accs2[data_id]))
            print_str += ', %.4f, %.4f' % (np.mean(l2_distortions_mean),
                                           np.mean(l_inf_distortions_all))

            # strftime instead of str(now)[:-7]: str() drops the fractional
            # part when microsecond == 0, which would cut off the seconds.
            print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), print_str)

            target_losses1, target_accs1 = _empty_history(), _empty_history()
            target_losses2, target_accs2 = _empty_history(), _empty_history()
            l2_distortions_mean, l_inf_distortions_all = [], []

        if itr != 0 and itr % SAVE_INTERVAL == 0:
            saver.save(sess, _checkpoint_path(exp_string, itr))

    # Final checkpoint, named after the last iteration that ran. Skipped when
    # no iteration ran; previously this raised UnboundLocalError on `itr`.
    if num_iterations > 0:
        saver.save(sess, _checkpoint_path(exp_string, num_iterations - 1))
    
    
    
def main():
    """Build the ICE graph, load the target models and run training.

    Steps:
    - create one data loader per dataset in FLAGS.train_data_names;
    - construct the model's training and optimizing graphs;
    - checkpoint only the variables under the 'Substitute' scope;
    - write checkpoints under FLAGS.logdir, exactly as given on the command
      line, in a sub-directory named after the main hyper-parameters.

    The session is closed when training finishes or fails.
    """
    # Dataset names containing 'cifar' use Dataloader_CF; all others use Dataloader.
    data_loaders = {}
    for data_id in FLAGS.train_data_names:
        if 'cifar' in data_id:
            data_loaders[data_id] = Dataloader_CF(data=data_id)
        else:
            data_loaders[data_id] = Dataloader(data=data_id)

    model = ICE()
    model.construct_training_graph()
    model.construct_optimizing_graph()

    substitute_vars = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope='Substitute')
    # Print the checkpointed variables, skipping those whose name contains
    # 'Adam' (presumably optimizer slot variables).
    for var in substitute_vars:
        if 'Adam' not in var.name:
            print(var)
    # max_to_keep=0: no checkpoint written during training is deleted.
    saver = tf.train.Saver(substitute_vars, max_to_keep=0)

    exp_string = ''.join([
        str(FLAGS.network), '_', str(FLAGS.target_data),
        '.lr_', str(FLAGS.lr),
        '.drop_', str(FLAGS.dropout_rate),
        '.nfs_', str(FLAGS.base_num_filters),
        '.DA_', str(FLAGS.data_aug)[0],
        '.alr_', str(FLAGS.train_attack_step_size),
        '.no_', str(FLAGS.num_outputs),
        '.adecay_', str(FLAGS.attack_decay_iter),
    ])

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    try:
        # InteractiveSession installs itself as the default session, so .run() uses it.
        tf.global_variables_initializer().run()
        print('loading target models')
        model.load_target_models(sess=sess)
        train(model, saver, sess, exp_string, data_loaders)
    finally:
        sess.close()
    

if __name__ == '__main__':
    # Entry point when executed as a script: build everything and train.
    main()





