

import numpy as np
import tensorflow as tf

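"""
Generator trainer: repeatedly runs a graph cGAN model's training op, then
returns the values of the model's `sy_covariates` and `sy_links` tensors.
"""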
class GraphCGAN_Trainer:
  """Run the training op of a graph cGAN model and fetch its outputs.

  The wrapped ``cgan_model`` must expose ``N``, ``batch_size``, a
  ``placeholders`` dict with keys ``'support_orig'``, ``'features_orig'``,
  ``'sy_index'`` and ``'sy_index0'``, plus the graph nodes
  ``cgan_train_op``, ``sy_covariates`` and ``sy_links``.
  """

  def __init__(self, cgan_model):
    self.cgan_model = cgan_model

  def fit(self, sess, adj_dense, features_dense, epoch_num=5):
    """Run ``cgan_train_op`` ``epoch_num`` times, then fetch the outputs.

    Each iteration is a single ``sess.run`` of the train op on one freshly
    sampled batch. Despite the parameter name, an iteration is therefore one
    optimizer step, not a pass over all N indices.

    Args:
      sess: Session exposing ``run(fetches, feed_dict=...)``
        (e.g. ``tf.compat.v1.Session``).
      adj_dense: Value fed to ``placeholders['support_orig']`` on every step.
      features_dense: Value fed to ``placeholders['features_orig']`` on every
        step.
      epoch_num: Number of train-op runs; ``0`` skips training entirely.

    Returns:
      The list returned by ``sess.run([sy_covariates, sy_links])``.

    Raises:
      ValueError: From ``np.random.choice`` when ``batch_size > N``, because
        sampling is done without replacement.
    """
    print("="*40 + "Run cGan" + "="*40)
    model = self.cgan_model
    placeholders = model.placeholders

    # These feeds are identical on every step, so build them once.
    # sy_index0 holds the indices N .. N+batch_size-1, i.e. just past the
    # [0, N) range that sy_index samples from -- presumably the slots of the
    # generated nodes (TODO confirm against the model).
    base_feed = {
        placeholders['support_orig']: adj_dense,
        placeholders['features_orig']: features_dense,
        placeholders['sy_index0']: np.arange(model.N, model.N + model.batch_size),
    }
    sy_index = placeholders['sy_index']

    for _ in range(epoch_num):
      # A new batch of distinct indices drawn from [0, N) on every step.
      idx = np.random.choice(model.N, model.batch_size, replace=False)
      feed_dict = dict(base_feed)
      feed_dict[sy_index] = idx
      sess.run([model.cgan_train_op], feed_dict=feed_dict)

    # NOTE(review): fetched without a feed_dict, so this assumes neither
    # tensor depends on a placeholder -- confirm against the model graph.
    return sess.run([model.sy_covariates, model.sy_links])
