


import tensorflow as tf
from .gnn_utils import *

"""
GCN trainer

"""

class GCN_Trainer():
    """Train and evaluate a GCN model by running its ops in a TensorFlow session.

    The wrapped ``model`` must expose the attributes used below:
    ``placeholders`` (a dict containing at least a ``'dropout'`` key),
    ``train_op``, ``gcn_loss``, ``acc``, ``predictions`` and ``h2``.
    """

    def __init__(self, model, dropout=0.5):
        """
        Args:
            model: GCN model whose tensors/ops are run by this trainer.
            dropout: value fed to ``model.placeholders['dropout']`` during
                training steps; evaluation and validation always feed 0.
        """
        self.model = model
        self.dropout = dropout

    def _softmax_predictions(self):
        """Return a softmax op over ``model.predictions``, created once and reused.

        In graph mode every ``tf.nn.softmax`` call adds a new op to the graph,
        so building it inside each evaluation would grow the graph on every
        call. The cache is keyed on the predictions tensor so that replacing
        ``self.model`` still yields the matching op.
        """
        predictions = self.model.predictions
        cached = getattr(self, "_softmax_cache", None)
        if cached is None or cached[0] is not predictions:
            cached = (predictions, tf.nn.softmax(predictions))
            self._softmax_cache = cached
        return cached[1]

    def evaluate_on_test(self, sess, features_sp, adj_sp, y_test, test_mask):
        """Run the model on the test split with dropout disabled and print accuracy.

        Returns:
            (test_cost, test_acc, pred_result, feature_space): the evaluated
            values of ``model.gcn_loss``, ``model.acc``, the softmax of
            ``model.predictions`` and ``model.h2`` for the test feed.
        """
        placeholders = self.model.placeholders
        feed_dict_test = construct_feed_dict(features_sp, adj_sp, y_test, test_mask, placeholders)
        feed_dict_test.update({placeholders['dropout']: 0.})
        test_cost, test_acc, pred_result, feature_space = sess.run(
            [self.model.gcn_loss, self.model.acc, self._softmax_predictions(), self.model.h2],
            feed_dict=feed_dict_test)
        print('Test accuracy:', test_acc)
        return test_cost, test_acc, pred_result, feature_space

    def fit(self, sess, features_sp, adj_sp, y_train, train_mask, y_val=None, val_mask=None,
            epoch_num=201, print_every=10):
        """Train for ``epoch_num`` epochs, validating every ``print_every`` epochs.

        Each epoch runs one ``model.train_op`` step on the whole input in a
        single feed (full-batch), with dropout ``self.dropout``. On epochs
        0, print_every, 2*print_every, ... and only when ``y_val`` is given,
        the model is evaluated on the validation split with dropout 0 and a
        progress line is printed.

        Returns:
            (val_cost, val_acc) from the most recent validation run, or
            (0, 0) if validation never ran.

        Raises:
            ValueError: if ``print_every`` is less than 1.
        """
        if print_every < 1:
            raise ValueError("print_every must be >= 1, got %r" % (print_every,))

        print("="*40 + "Run gcn" + "="*40)
        placeholders = self.model.placeholders

        # Inputs are identical every epoch, so build the feed dicts once.
        # Note: the whole graph is fed at every step, which can be heavy.
        feed_dict = construct_feed_dict(features_sp, adj_sp, y_train, train_mask, placeholders)
        feed_dict.update({placeholders['dropout']: self.dropout})
        feed_dict_val = None
        if y_val is not None:
            feed_dict_val = construct_feed_dict(features_sp, adj_sp, y_val, val_mask, placeholders)
            feed_dict_val.update({placeholders['dropout']: 0.})

        # This trainer has no discriminator; the D_loss column is always 0 and
        # is kept only so the progress-line format stays unchanged.
        d_loss = 0
        val_cost, val_acc = 0, 0
        for epoch in range(epoch_num):
            _, train_loss, train_acc = sess.run(
                [self.model.train_op, self.model.gcn_loss, self.model.acc], feed_dict=feed_dict)
            if feed_dict_val is None or epoch % print_every != 0:
                continue
            val_cost, val_acc = sess.run(
                [self.model.gcn_loss, self.model.acc], feed_dict=feed_dict_val)
            print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss), "D_loss=", "{:.5f}".format(d_loss),
                  "train_acc=", "{:.5f}".format(train_acc), "val_loss=", "{:.5f}".format(val_cost),
                  "val_acc=", "{:.5f}".format(val_acc))
        print("Optimization Finished!")
        return val_cost, val_acc



