
import numpy as np
import tensorflow as tf

from .gnn_utils import *
from .metrics import *
"""
GCN node-classification model trained with an auxiliary GraphCGAN discriminator loss.
"""


class GCN_Model():
    """GCN node classifier whose training loss also includes a GraphCGAN discriminator term.

    The whole TF1 graph (placeholders, forward pass, losses, Adam train op) is
    built in ``__init__``.

    Args:
        N: Number of leading rows of the node matrix that the discriminator
            loss treats as the first class (the only rows it does not mask out).
        d: Input feature dimension (number of rows of the first weight matrix).
        M: Output dimension: width of the ``labels`` placeholder and of the logits.
        batch_size: Number of rows after the first ``N`` that the discriminator
            loss labels as the second class and masks out. NOTE(review): the
            graph therefore assumes ``support`` has exactly ``N + batch_size``
            rows (presumably real nodes followed by generated ones) — confirm
            with the caller.
        hidden_layer: Width of the hidden GCN layer.
        learning_rate: Adam learning rate.
        reg_constant: L2 scale applied to the first-layer weight matrix only.
        model_name: Variable scope that holds this model's weights.
    """

    def __init__(self, N, d, M, batch_size=500,
                 hidden_layer=16, learning_rate=1e-2, reg_constant=5e-4, model_name="GCN"):
        self.hidden_layer = hidden_layer
        self.N = N
        self.d = d
        self.dim = M
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.model_name = model_name
        self.reg_constant = reg_constant
        self._define_placeholder()
        self._build_gcn()

    def _define_placeholder(self):
        """Create the graph's feed placeholders and store them in ``self.placeholders``.

        ``labels_D`` and ``labels_G_mask`` are not read anywhere in this class;
        NOTE(review): presumably callers feed/use them — confirm before removing.
        """
        self.placeholders = {
            'support': tf.sparse_placeholder(tf.float32),
            'features': tf.sparse_placeholder(tf.float32),
            'labels': tf.placeholder(tf.float32, shape=(None, self.dim)),
            'labels_mask': tf.placeholder(tf.int32),
            # Defaults to 0 so evaluation runs need not feed it.
            'dropout': tf.placeholder_with_default(0., shape=()),
            'num_features_nonzero': tf.placeholder(tf.int32),  # helper variable for sparse dropout
            'labels_D': tf.placeholder(tf.float32, shape=(None, 2)),
            'labels_G_mask': tf.placeholder(tf.int32),
        }

    @staticmethod
    def _dropout_enabled(dropout):
        """Return True when dropout ops should be added to the graph.

        A tensor (e.g. the ``dropout`` placeholder) always enables them because
        its value is only known at run time; a Python number enables them only
        when it is non-zero.
        """
        return isinstance(dropout, tf.Tensor) or dropout != 0

    def _layer_gcn(self, support, features, dropout, num_feat_nonzero, hidden_layers,
                   name="GCN", sparse_support=True):
        """Build the GCN forward pass and return ``(predictions, h_last)``.

        With layer widths ``[h_0, ..., h_k]`` this creates ``weight{i}`` and
        ``bias{i}`` in variable scope ``name`` (reusing them if they already
        exist). Each hidden layer computes ``support @ (inputs @ weight_i) + bias_i``.
        The last hidden output is propagated once more through ``support`` to
        give ``h_last``, and ``predictions = h_last @ weight_k``.

        NOTE(review): no activation is applied between layers, and the final
        bias ``bias{k}`` is created but never used. Both are preserved as-is;
        confirm they are intended.

        Args:
            support: Propagation matrix, treated as sparse when ``sparse_support``.
            features: Sparse input feature matrix with ``self.d`` columns.
            dropout: Dropout *rate* (tensor or number); ``0`` disables dropout.
            num_feat_nonzero: Non-zero count of ``features``, used by sparse dropout.
            hidden_layers: Output width of every layer; must have >= 2 entries.
            name: Variable scope for the weights.
            sparse_support: Whether ``support`` is a sparse tensor.
        """
        weights = []
        biases = []
        input_dim = self.d
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            for i, output_dim in enumerate(hidden_layers):
                weights.append(tf.get_variable(
                    "weight" + str(i), initializer=glorot((input_dim, output_dim))))
                biases.append(tf.get_variable(
                    "bias" + str(i), initializer=tf.zeros(output_dim, dtype=tf.float32)))
                input_dim = output_dim

        use_dropout = self._dropout_enabled(dropout)
        inputs = features
        if use_dropout:
            # The helpers take a keep probability, hence ``1 - rate``.
            inputs = sparse_dropout(inputs, 1 - dropout, num_feat_nonzero)

        hidden = []
        inputs_sparse = True  # only the raw feature matrix is sparse; later inputs are dense
        for weight, bias in zip(weights[:-1], biases[:-1]):
            h_i = dot(support, dot(inputs, weight, sparse=inputs_sparse),
                      sparse=sparse_support) + bias
            if use_dropout:
                h_i = tf.nn.dropout(h_i, 1 - dropout)
            hidden.append(h_i)
            inputs = h_i
            inputs_sparse = False

        h_last = dot(support, hidden[-1], sparse=sparse_support)
        predictions = dot(h_last, weights[-1], sparse=False)
        return predictions, h_last

    def _model_func(self, support, features, dropout=None, num_feat_nonzero=None):
        """Run ``_layer_gcn`` with widths ``[hidden_layer, dim]`` in scope ``model_name``.

        ``dropout`` and ``num_feat_nonzero`` default to the corresponding placeholders.
        """
        if dropout is None:
            dropout = self.placeholders['dropout']
        if num_feat_nonzero is None:
            num_feat_nonzero = self.placeholders['num_features_nonzero']
        return self._layer_gcn(support=support,
                               features=features,
                               dropout=dropout,
                               num_feat_nonzero=num_feat_nonzero,
                               hidden_layers=[self.hidden_layer, self.dim],
                               name=self.model_name)

    def __call__(self, support, features):
        """Build another dropout-free forward pass that shares this model's weights.

        Returns ``(predictions, h_last)`` for the given ``support``/``features``.
        """
        return self._model_func(support, features, dropout=0, num_feat_nonzero=0)

    def _build_gcn(self):
        """Build predictions, losses, accuracy and the Adam ``train_op``."""
        self.predictions, self.h2 = self._model_func(
            self.placeholders['support'], self.placeholders['features'])

        # Resolve this model's variables through its own variable scope instead of
        # regex-prefix matching on the bare model name. A bare "GCN" prefix would
        # also capture scopes such as "GCN_2/...". The old code also picked the
        # first-layer weight by position in the global collection.
        with tf.variable_scope(self.model_name, reuse=True) as scope:
            first_weight = tf.get_variable("weight0")
        scope_prefix = scope.name + "/"
        variables = [v for v in tf.global_variables() if v.name.startswith(scope_prefix)]

        # Cross-entropy of the node classifier on the masked (labelled) rows.
        self.loss_CE = masked_softmax_cross_entropy(
            self.predictions, self.placeholders['labels'], self.placeholders['labels_mask'])

        # L2 penalty on the first-layer weights only.
        regularizer = tf.contrib.layers.l2_regularizer(scale=self.reg_constant)
        self.loss_weight = tf.contrib.layers.apply_regularization(regularizer, [first_weight])

        # GraphCGAN discriminator term. The logits are capped at 10 before exp to
        # avoid overflow. Then logits_D = [log(sum_k exp(l_k)), 0], so softmax gives
        # the first column probability Z / (Z + 1) with Z = sum_k exp(l_k).
        capped_logits = tf.minimum(tf.constant(10, dtype=tf.float32), self.predictions)
        pred_label = tf.reshape(tf.log(tf.reduce_sum(tf.exp(capped_logits), 1)), (-1, 1))
        pred_unlabel = tf.zeros_like(pred_label)
        self.predictions_D = tf.concat((pred_label, pred_unlabel), 1)
        # First N rows -> class 0 (kept by the mask); next batch_size rows -> class 1 (masked out).
        labels_D = np.concatenate(([[1, 0]] * self.N, [[0, 1]] * self.batch_size), 0)
        mask_D = np.concatenate(([True] * self.N, [False] * self.batch_size))
        self.Dis_loss = masked_softmax_cross_entropy(self.predictions_D, labels_D, mask_D)

        self.gcn_loss = self.loss_CE + self.loss_weight + self.Dis_loss
        self.acc = masked_accuracy(
            self.predictions, self.placeholders['labels'], self.placeholders['labels_mask'])
        self.var_list = variables
        solver = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        self.train_op = solver.minimize(self.gcn_loss, var_list=self.var_list)

