
import tensorflow as tf
from .generator_utils_toy_mnist import *

"""
Generator model for GraphCGAN: synthesizes node attributes and links.
"""
class GraphCGAN_Generator():
  """Generator that appends ``batch_size`` synthetic nodes to a graph.

  Noise from ``sample_noise`` is mapped by two generator heads to
  (a) attribute rows (``generator_covariate``, scope name ``'cGan_attr'``) and
  (b) link rows (``generator_link``, scope name ``'cGan_link'``).
  Both heads share the same noise tensor ``self.z``. The generated rows are
  appended to the original graph. The augmented graph is embedded with
  ``attr_model`` / ``gnn_model``, and the generator is trained with a
  feature-matching loss plus a pull-away term on those embeddings.

  Args:
    N: number of original nodes (width of ``support_orig``; the augmented
      graph has ``N + batch_size`` nodes).
    d: attribute dimension (width of ``features_orig``).
    M: stored as ``self.dim``; not read inside this class.
    attr_model: callable ``(support_sp, features_sp)`` returning a pair whose
      second element is used as the attribute feature space.
    gnn_model: callable with the same call/return shape, used on the
      augmented support.
    learning_rate_cgan: Adam learning rate for the generator.
    batch_size: number of generated ("complementary") nodes.
    noise_dim: dimension of the noise passed to ``sample_noise``.
    cov_hidden_layer: hidden width of the attribute generator.
    link_hidden_layer: hidden width of the link generator.
    fm_weight: multiplier on the feature-matching term(s) of ``cgan_loss``.
    use_link_loss: if True, also add the link feature-matching and
      pull-away losses to ``cgan_loss``. The default False reproduces the
      original objective (attribute losses only).
  """

  def __init__(self, N, d, M, attr_model, gnn_model, learning_rate_cgan=5e-3,
               batch_size=500, noise_dim=8, cov_hidden_layer=128,
               link_hidden_layer=128, fm_weight=10, use_link_loss=False):
    # batch_size is the number of complementary (generated) nodes.
    self.batch_size = batch_size
    self.noise_dim = noise_dim
    self.N = N
    self.d = d
    self.dim = M  # not read inside this class
    self.learning_rate_cgan = learning_rate_cgan
    self.attr_model = attr_model
    self.gnn_model = gnn_model
    self._define_placeholder()
    self._build_CGAN(cov_hidden_layer=cov_hidden_layer,
                     link_hidden_layer=link_hidden_layer,
                     fm_weight=fm_weight,
                     use_link_loss=use_link_loss)

  def _define_placeholder(self):
    """Create the feed placeholders used by the generator graph.

    sy_index / sy_index0: int32 vectors of length ``batch_size`` selecting
      rows of the augmented feature spaces for the two sides of the
      feature-matching loss. Presumably these are real vs. generated node
      indices; confirm with the code that builds the feed dict.
    features_orig: dense float32 attributes of the original nodes, ``[?, d]``.
    support_orig: dense float32 support of the original nodes, ``[?, N]``
      (presumably a normalized adjacency; its row count must equal N, see
      ``_build_CGAN``).
    """
    self.placeholders = {
        'sy_index': tf.placeholder(tf.int32, shape=[self.batch_size]),
        'sy_index0': tf.placeholder(tf.int32, shape=[self.batch_size]),
        'features_orig': tf.placeholder(tf.float32, [None, self.d]),
        'support_orig': tf.placeholder(tf.float32, [None, self.N]),
    }

  def _build_CGAN(self, cov_hidden_layer=128, link_hidden_layer=128,
                  fm_weight=10, use_link_loss=False):
    """Build both generator heads, the generator loss, and its train op.

    Sets ``self.z``, ``self.sy_covariates``, ``self.sy_links``,
    ``self.attr_feature_space``, ``self.link_feature_space``, the individual
    loss tensors (``self.loss_fm_attr``, ``self.loss_pt_attr``,
    ``self.loss_fm_link``, ``self.loss_pt_link``), ``self.cgan_loss``,
    ``self.cgan_var_list`` and ``self.cgan_train_op``.

    Args:
      cov_hidden_layer: hidden width of the attribute generator.
      link_hidden_layer: hidden width of the link generator.
      fm_weight: multiplier on the feature-matching term(s).
      use_link_loss: also optimize the link generator. When False (default,
        matching the original objective), ``cgan_loss`` does not depend on
        ``sy_links``. The ``cGan_link`` variables then receive no gradient,
        and ``cgan_train_op`` leaves them unchanged.
    """
    # Both generator heads share this noise sample.
    self.z = sample_noise(self.batch_size, dim=self.noise_dim)

    # --- generate attributes -------------------------------------------
    self.sy_covariates = generator_covariate(
        self.z, self.d, hidden_layer=cov_hidden_layer, name='cGan_attr')
    # Augmented attribute matrix: original rows first, generated rows after.
    new_attr = tf.concat(
        (self.placeholders['features_orig'], self.sy_covariates), 0)
    new_attr_sp = dense_to_sparse(new_attr, tf.int64)
    # The attribute branch is fed an identity support over all
    # N + batch_size nodes, presumably so it sees attributes without edges.
    identity_sp = dense_to_sparse(
        tf.eye(self.N + self.batch_size, dtype=tf.float32), tf.int64)
    _, self.attr_feature_space = self.attr_model(identity_sp, new_attr_sp)

    # --- generate links ------------------------------------------------
    self.sy_links = generator_link(
        self.z, self.N, hidden_layer=link_hidden_layer, name='cGan_link')
    # Augmented support, block layout:
    #   [[support_orig, sy_links^T],
    #    [sy_links,     I_batch   ]]
    # For these concats to line up, sy_links must be (batch_size, N) and
    # support_orig must be (N, N).
    top = tf.concat(
        (self.placeholders['support_orig'], tf.transpose(self.sy_links)), 1)
    bottom = tf.concat(
        (self.sy_links, tf.eye(self.batch_size, dtype=tf.float32)), 1)
    new_support_sp = dense_to_sparse(tf.concat((top, bottom), 0), tf.int64)
    _, self.link_feature_space = self.gnn_model(new_support_sp, new_attr_sp)

    # --- feature-matching + pull-away losses ----------------------------
    self.loss_fm_attr, self.loss_pt_attr = self._matching_losses(
        self.attr_feature_space)
    self.loss_fm_link, self.loss_pt_link = self._matching_losses(
        self.link_feature_space)

    self.cgan_loss = fm_weight * self.loss_fm_attr + self.loss_pt_attr
    if use_link_loss:
      # Mirror the attribute objective for the link branch.
      self.cgan_loss = (self.cgan_loss
                        + fm_weight * self.loss_fm_link + self.loss_pt_link)

    # get_collection's scope is a prefix match, so this presumably picks up
    # both the 'cGan_attr' and the 'cGan_link' variables (assuming the
    # generator helpers scope their variables under `name`).
    self.cgan_var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, 'cGan')
    solver = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate_cgan, name="cGan")
    self.cgan_train_op = solver.minimize(
        self.cgan_loss, var_list=self.cgan_var_list)

  def _matching_losses(self, feature_space):
    """Return ``(feature_matching, pull_away)`` losses for one embedding.

    feature_matching: squared L2 distance between the column means of the
      rows selected by ``sy_index0`` and by ``sy_index``.
    pull_away: ``pullaway_loss`` applied to the ``sy_index0`` rows.
    """
    med = tf.gather(feature_space, self.placeholders['sy_index'], axis=0)
    med_0 = tf.gather(feature_space, self.placeholders['sy_index0'], axis=0)
    loss_fm = tf.reduce_sum(
        (tf.reduce_mean(med_0, 0) - tf.reduce_mean(med, 0)) ** 2)
    return loss_fm, pullaway_loss(med_0)



