from ifs_layer import *
from metrics import *
# Module-level handles to the TF1 command-line flag registry; `FLAGS` supplies
# the hyper-parameters read by the model below.
flags = tf.app.flags
FLAGS = tf.app.flags.FLAGS

class IGNNS(object):
    """Model wrapping a single ``IFSlayer`` with loss, accuracy and an Adam op.

    The whole TF1 graph (forward pass, loss, accuracy, training op) is built
    eagerly from ``__init__`` via ``build()``. Hyper-parameters not passed to
    the constructor are read from the global ``FLAGS`` object.
    """

    def __init__(self, placeholders, input_dim, input_num, act=tf.nn.relu,
                 p_decay=0.0005):
        """Store configuration and build the graph.

        Args:
            placeholders: dict of tensors; read keys are ``'features'``
                (model input), ``'labels'`` and ``'labels_mask'``. The second
                dimension of ``'labels'`` must be statically known because it
                is used as the output dimension.
            input_dim: forwarded to ``IFSlayer`` as ``input_dim``.
            input_num: forwarded to ``IFSlayer`` as ``input_num``.
            act: activation forwarded to ``IFSlayer``.
            p_decay: L2 coefficient applied to the layer's ``vars_p``
                variables. Defaults to the value that used to be hard-coded.
        """
        self.placeholders = placeholders
        self.inputs = placeholders['features']
        self.input_dim = input_dim
        self.input_num = input_num
        self.act = act
        self.p_decay = p_decay
        # Output width comes from the static shape of the label tensor.
        self.output_dim = placeholders['labels'].get_shape().as_list()[1]

        # Populated by build().
        self.layers = []
        self.activations = []
        self.outputs = None
        self.opt_op = None
        self.loss = 0
        self.accuracy = 0

        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        self.build()

    def _loss(self):
        """Add L2 regularisation and masked cross-entropy to ``self.loss``.

        Only the first layer's ``vars``, ``vars_bias`` and ``vars_p`` dicts
        are regularised, each group with its own coefficient.
        """
        ifs_layer = self.layers[0]
        # Weight decay, one coefficient per variable group.
        for var in ifs_layer.vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
        for var in ifs_layer.vars_bias.values():
            self.loss += FLAGS.bias_decay * tf.nn.l2_loss(var)
        for var in ifs_layer.vars_p.values():
            self.loss += self.p_decay * tf.nn.l2_loss(var)
        # Cross-entropy restricted to the entries selected by labels_mask.
        self.loss += masked_softmax_cross_entropy(self.outputs,
                                                  self.placeholders['labels'],
                                                  self.placeholders['labels_mask'])

    def _accuracy(self):
        """Set ``self.accuracy`` to the masked accuracy of ``self.outputs``."""
        self.accuracy = masked_accuracy(self.outputs,
                                        self.placeholders['labels'],
                                        self.placeholders['labels_mask'])

    def _build(self):
        """Create the single ``IFSlayer`` and append it to ``self.layers``."""
        # Flag names (including the `*_bais` spelling) must match the flag
        # definitions elsewhere in the project.
        ifs = IFSlayer(input_dim=self.input_dim,
                       output_dim=self.output_dim,
                       input_num=self.input_num,
                       placeholders=self.placeholders,
                       act=self.act,
                       ifs_layers_num=FLAGS.ifs_layers_num,
                       hide_dim=FLAGS.hidden,
                       initial_value_of_p0=FLAGS.initial_value_of_p0,
                       learnable_p=FLAGS.learnable_p,
                       ifs_layer_bias=FLAGS.IFS_layer_bais,
                       output_bias=FLAGS.output_bais,
                       learnable_r=FLAGS.learnable_r,
                       )
        self.layers.append(ifs)

    def build(self):
        """Build layers, run the forward pass, and create loss/accuracy/train op."""
        self._build()
        # Chain layers: each consumes the previous activation.
        self.activations.append(self.inputs)
        for layer in self.layers:
            hidden = layer(self.activations[-1])
            self.activations.append(hidden)
        self.outputs = self.activations[-1]
        self._loss()
        self._accuracy()
        self.opt_op = self.optimizer.minimize(self.loss)

    def predict(self):
        """Return the softmax of the model outputs."""
        return tf.nn.softmax(self.outputs)