from __future__ import division
from __future__ import print_function
import time
import tensorflow as tf
from models import IGNNS
from utils import *
# Settings

# Fix the NumPy and TensorFlow graph-level seeds so runs are reproducible.
# NOTE(review): `np` is not imported directly in this file; it presumably
# comes from `from utils import *` — confirm.
seed=1234
np.random.seed(seed)
tf.set_random_seed(seed)

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset', 'cora', 'datasets') # cora, citeseer, pubmed

# All remaining hyper-parameters are defined per dataset, with defaults tuned
# for that dataset. Each "setup" comment lists the values that produced the
# recorded test accuracy, in this order:
#   neighborhood_normalization, learning_rate, initial_value_of_p0, dropout,
#   weight_decay, bias_decay, epochs, hidden, ifs_layers_num, learnable_p,
#   learnable_r, IFS_layer_bais, output_bais, patience
# NOTE(review): FLAGS.dataset is read before the flags below are defined; in
# TF1 this read may trigger implicit parsing of sys.argv, so command-line
# overrides of the per-dataset flags may be rejected or ignored — confirm.
if FLAGS.dataset == 'cora':
    # test_acc=0.863, setup: symmetric,0.005,0.5,0.9,5e-3,5e-4,1000,48,5,False,True,True,False,100
    flags.DEFINE_string('neighborhood_normalization', 'symmetric', 'mean_pooling or symmetric')
    flags.DEFINE_float('learning_rate', 0.005, 'Initial learning rate.')
    flags.DEFINE_float('initial_value_of_p0', 0.5, 'Probability of upper branch of IFS.')
    flags.DEFINE_float('dropout', 0.9, 'Dropout rate (1 - keep probability).')
    flags.DEFINE_float('weight_decay', 5e-3, 'L2 total_loss on embedding matrix.')
    flags.DEFINE_float('bias_decay', 5e-4, 'L2 total_loss on bias.')
    flags.DEFINE_integer('epochs', 1000, 'Number of epochs to train.')
    flags.DEFINE_integer('hidden',48, 'Number of units in hidden layer.')
    flags.DEFINE_integer('ifs_layers_num', 5, 'Number of IFS layers.')
    flags.DEFINE_bool('learnable_p', False, 'Whether or not to train the adjoint probability vector p.')
    flags.DEFINE_bool('learnable_r', True, 'Whether or not to train the repreation layer coefficient r.')
    flags.DEFINE_bool('IFS_layer_bais', True, 'Whether or not to use bias for all iterations in IFS layer.')
    flags.DEFINE_bool('output_bais', False, 'Whether or not to use bias for output layer.')
    flags.DEFINE_integer('patience', 100, 'Early stopping.')

elif FLAGS.dataset == 'citeseer':
    # test_acc=0.751, setup: mean_pooling,0.002,0.5,0.9,5e-3,5e-4,1000,72,4,True,True,False,False,100
    flags.DEFINE_string('neighborhood_normalization', 'mean_pooling', 'mean_pooling or symmetric')
    flags.DEFINE_float('learning_rate', 0.002, 'Initial learning rate.')
    flags.DEFINE_float('initial_value_of_p0', 0.5, 'Probability of upper branch of IFS.')
    flags.DEFINE_float('dropout', 0.9, 'Dropout rate (1 - keep probability).')
    flags.DEFINE_float('weight_decay', 5e-3, 'L2 total_loss on embedding matrix.')
    flags.DEFINE_float('bias_decay', 5e-4, 'L2 total_loss on bias.')
    flags.DEFINE_integer('epochs', 1000, 'Number of epochs to train.')
    flags.DEFINE_integer('hidden', 72, 'Number of units in hidden layer.')
    flags.DEFINE_integer('ifs_layers_num', 4, 'Number of IFS layers')
    flags.DEFINE_bool('learnable_p', True, 'Whether or not to train the adjoint probability vector.')
    flags.DEFINE_bool('learnable_r', True, 'Whether or not to train the repreation layer coefficient r.')
    flags.DEFINE_bool('IFS_layer_bais', False, 'Whether or not to use bias for all iterations in IFS layer.')
    flags.DEFINE_bool('output_bais', False, 'Whether or not to use bias for output layer.')
    flags.DEFINE_integer('patience', 100, 'Early stopping.')
elif FLAGS.dataset == 'pubmed':
    # test_acc=0.805, setup: symmetric,0.01,0.5,0.8,5e-3,5e-3,1000,72,4,False,False,False,100
    # NOTE(review): this recorded setup lists only three booleans, while four
    # boolean flags are defined below (learnable_r=True) — confirm which
    # configuration produced the reported accuracy.
    flags.DEFINE_string('neighborhood_normalization', 'symmetric', 'mean_pooling or symmetric')
    flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
    flags.DEFINE_float('initial_value_of_p0', 0.5, 'Probability of upper branch of IFS.')
    flags.DEFINE_float('dropout', 0.8, 'Dropout rate (1 - keep probability).')
    flags.DEFINE_float('weight_decay', 5e-3, 'L2 total_loss on embedding matrix.')
    flags.DEFINE_float('bias_decay', 5e-3, 'L2 total_loss on bias.')
    flags.DEFINE_integer('epochs', 1000, 'Number of epochs to train.')
    flags.DEFINE_integer('hidden', 72,'Number of units in hidden layer.')
    flags.DEFINE_integer('ifs_layers_num', 4, 'Number of IFS layers')
    flags.DEFINE_bool('learnable_p', False, 'Whether or not to train the adjoint probability vector.')
    flags.DEFINE_bool('learnable_r', True, 'Whether or not to train the repreation layer coefficient r.')
    flags.DEFINE_bool('IFS_layer_bais',False, 'Whether or not to use bias for all iterations in IFS layer.')
    flags.DEFINE_bool('output_bais', False, 'Whether or not to use bias for output layer.')
    flags.DEFINE_integer('patience', 100, 'Early stopping.')
else:
    # Without this, none of the hyper-parameter flags would be defined and the
    # script would fail later with an unrelated-looking AttributeError.
    raise ValueError(
        "Unsupported dataset '{}'; expected one of: cora, citeseer, pubmed.".format(FLAGS.dataset))

# model save path
# The checkpoint name encodes the number of IFS layers, the hidden size and
# the dataset, so runs with different settings do not overwrite each other.
checkpt_dir = 'data/checkpoint'
checkpt_file = '{}/ignns_layer_{}_{}_{}.ckpt'.format(checkpt_dir, FLAGS.ifs_layers_num, FLAGS.hidden, FLAGS.dataset)
# tf.train.Saver.save refuses to write when the parent directory is missing,
# so create it up front (a no-op when it already exists).
tf.gfile.MakeDirs(checkpt_dir)

# Load data
# adj_up / adj_low feed the two IFS branches (support_up / support_low below).
adj, adj_up, adj_low, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(FLAGS.dataset)

# preprocessing
# NOTE(review): after preprocess_features, `features` appears to be a sparse
# tuple whose element [2] is the dense shape (rows, feature_dim), judging by
# how it is indexed below — confirm in utils.
features = preprocess_features(features)
support_up = preprocess_adj(adj_up, normalizing=FLAGS.neighborhood_normalization)
support_low = preprocess_adj(adj_low, normalizing=FLAGS.neighborhood_normalization)

# Define placeholders
placeholders = {
    'support_up': tf.sparse_placeholder(tf.float32),
    'support_low': tf.sparse_placeholder(tf.float32),
    'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)),
    'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
    'labels_mask': tf.placeholder(tf.int32),
    # Defaults to 0 (no dropout) so evaluation need not feed it; training
    # feeds FLAGS.dropout explicitly.
    'dropout': tf.placeholder_with_default(0., shape=()),
    'num_features_nonzero': tf.placeholder(tf.int32)}

# Create model

act=tf.nn.relu
# act=identity_function

model = IGNNS(placeholders,
              input_dim=features[2][1],
              input_num=features[2][0],
              act=act)


def construct_feed_dict(features, support_up, support_low, labels, labels_mask, placeholders):
    """Map each input placeholder to the value that should be fed into it.

    Args:
        features: preprocessed feature tuple; its element ``[1]`` must expose
            ``.shape``, which is fed as the number of non-zero feature entries.
        support_up: value for ``placeholders['support_up']``.
        support_low: value for ``placeholders['support_low']``.
        labels: value for ``placeholders['labels']``.
        labels_mask: value for ``placeholders['labels_mask']``.
        placeholders: dict of placeholders keyed by name.

    Returns:
        dict: placeholder -> value. The ``'dropout'`` placeholder is not
        included; callers add it themselves when they need it.
    """
    return {
        placeholders['labels']: labels,
        placeholders['labels_mask']: labels_mask,
        placeholders['features']: features,
        placeholders['support_up']: support_up,
        placeholders['support_low']: support_low,
        placeholders['num_features_nonzero']: features[1].shape,
    }

# Define model evaluation function

def evaluate(features, support_up, support_low, labels, mask, placeholders):
    """Compute the model's loss and accuracy on the split selected by ``mask``.

    Uses the module-level ``sess`` and ``model``. Dropout is not fed, so the
    dropout placeholder keeps its default value during evaluation.

    Returns:
        tuple: ``(loss, accuracy, elapsed_seconds)``.
    """
    started_at = time.time()
    feed = construct_feed_dict(features, support_up, support_low, labels, mask, placeholders)
    loss, accuracy = sess.run([model.loss, model.accuracy], feed_dict=feed)
    elapsed = time.time() - started_at
    return loss, accuracy, elapsed


# Initialize session
# The `with` block guarantees the session is closed even if training raises.
with tf.Session() as sess:
    saver = tf.train.Saver()
    # Init variables
    sess.run(tf.global_variables_initializer())

    best = 0  # best validation accuracy so far (model-selection criterion)
    patience = 0  # epochs without improvement, counted once selection has started
    start_choosing = 50  # base epoch from which models are selected
    # First epoch at which model selection (and therefore early stopping) is
    # active. citeseer starts 4x later; other datasets never select a model.
    selection_start = {
        'cora': start_choosing,
        'pubmed': start_choosing,
        'citeseer': start_choosing * 4,
    }.get(FLAGS.dataset)
    saved = False  # whether any checkpoint was written during this run
    time_train = []  # wall time per epoch (training step + validation)

    # Train model
    for epoch in range(1, FLAGS.epochs + 1):
        t = time.time()
        # Construct feed dictionary; dropout is enabled only for the training step.
        feed_dict = construct_feed_dict(features, support_up, support_low, y_train, train_mask, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        # Training step
        _, train_loss, train_acc = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)
        # Validation
        val_loss, val_acc, _ = evaluate(features, support_up, support_low, y_val, val_mask, placeholders)
        T = time.time() - t
        time_train.append(T)

        # Model select: keep the checkpoint with the best validation accuracy.
        if selection_start is not None and epoch >= selection_start:
            if val_acc > best:
                best = val_acc
                patience = 0
                saver.save(sess, checkpt_file)
                saved = True
                print("The best model is saved!")
            else:
                patience += 1

        # Early stopping!
        if patience >= FLAGS.patience:
            print("Early stopping!")
            break

        # print results during training step
        # note that test set not involved in model selection
        _, test_acc, _ = evaluate(features, support_up, support_low, y_test, test_mask, placeholders)
        print("Epoch:", '%04d' % (epoch),
              "train_loss=", "{:.5f}".format(train_loss),
              "train_acc=", "{:.5f}".format(train_acc),
              "val_loss=", "{:.5f}".format(val_loss),
              "val_acc=", "{:.5f}".format(val_acc),
              "train_time=", "{:.5f}".format(T),
              "test_acc=", "{:.5f}".format(test_acc)
        )

    # Report timing whether training ended by early stopping or ran all epochs.
    if time_train:
        print("Total training time:", "{:.1f}".format(sum(time_train)))
        print("Average training time per epoch:", "{:.4f}".format(sum(time_train) / len(time_train)))

    if not saved:
        # Without this, restoring `checkpt_file` afterwards would fail, or
        # silently load a stale checkpoint left at the same path by an
        # earlier run.
        saver.save(sess, checkpt_file)
        print("No model was selected during training; the final model is saved instead.")


print("Training completed!")

# test the performance of stored model
# Restore the checkpoint saved during training into a fresh session and
# evaluate it on the test split; the session is closed when the block exits.
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, checkpt_file)
    test_loss, test_acc, test_duration = evaluate(features, support_up, support_low, y_test, test_mask, placeholders)
print("Test set results:",
      "test_loss=", "{:.5f}".format(test_loss),
      "test_accuracy=", "{:.5f}".format(test_acc),
      "total_test_time=", "{:.5f}".format(test_duration))