import numpy as np
import tensorflow as tf
from NeuralNet.NeuralNetPy import NeuralNet


class LinearNet(NeuralNet):
    """Deep linear network trained by gradient descent on mean squared error.

    The end-to-end map of the network is one matrix (``self._operator``) of shape
    ``[output_dim, input_dim]``, formed as the product of all layer weight matrices.
    Predictions are ``inputs @ operator^T``.

    Hyper-parameters are read from attributes presumably set by the ``NeuralNet``
    base class (not visible here; confirm there):
    ``_dtype``, ``_input_dim``, ``_output_dim``, ``_init_scale``, ``_learning_rate``,
    ``_learning_rate_decay`` and ``_learning_rate_decay_every``.
    """

    def __init__(self, sess, train_data, test_data=None, data_recorder=None):
        super().__init__(sess, train_data, test_data, data_recorder)

        # Graph handles; all are populated by build_model().
        self._input_data = None      # placeholder, shape [None, input_dim]
        self._output_data = None     # placeholder, shape [None, output_dim]
        self._global_step = None     # int32 placeholder driving the learning-rate schedule
        self._loss = None            # mean squared error
        self._gradient = None        # list of (gradient, variable) pairs
        self._debug = None
        self._operator = None        # end-to-end linear map, shape [output_dim, input_dim]
        self._gradient_norm = None   # sum of per-variable Frobenius norms of the gradient
        self._train_step = None
        self._merged = None
        self._current_lr = None      # decayed learning-rate tensor

    def build_model(self, hid_layer_num=0, hid_layer_dim=None, layer_init_scale=None):
        """Build the TensorFlow graph, then initialise all global variables.

        Args:
            hid_layer_num: number of hidden layers. 0 means a single weight matrix ``W``.
            hid_layer_dim: list/tuple with exactly ``hid_layer_num`` hidden widths.
                Entries are ordered from the output side to the input side.
            layer_init_scale: optional per-weight-matrix scales, indexed from the
                output side to the input side. When given, matrix ``idx`` is
                initialised to ``np.random.normal(0, 1, shape) * layer_init_scale[idx]``.
                When omitted, weights are ``N(0, init_scale / sqrt(fan_in))``.
                Only used when ``hid_layer_num > 0``.

        Raises:
            Exception: if the parameters have not been set, if ``hid_layer_dim`` is
                not a list/tuple of length ``hid_layer_num``, or if
                ``layer_init_scale`` has fewer than ``hid_layer_num + 1`` entries.
        """
        # for comments refer to build_model in NeuralNet/DeepNet/DeepNetPy
        if not self.check_parameters():
            # run build model after parameters are set
            raise Exception("Cannot build model, need set parameter first")
        self._input_data = tf.compat.v1.placeholder(dtype=self._dtype, shape=[None, self._input_dim])
        self._output_data = tf.compat.v1.placeholder(dtype=self._dtype, shape=[None, self._output_dim])
        self._global_step = tf.compat.v1.placeholder(dtype=tf.int32)

        # Construct the network graph: operator = I @ W_{L} @ ... @ W_1.
        # The identity must share the network dtype. tf.eye defaults to float32,
        # which would break matmul for e.g. float64 networks.
        self._operator = tf.eye(self._output_dim, dtype=self._dtype)
        with tf.compat.v1.variable_scope('LinearNet'):
            if hid_layer_num == 0:
                weight = tf.compat.v1.get_variable(
                    shape=[self._output_dim, self._input_dim], dtype=self._dtype, trainable=True,
                    initializer=tf.random_normal_initializer(
                        stddev=self._init_scale / np.sqrt(self._input_dim)),
                    name='W')
                self._operator = tf.matmul(self._operator, weight)
            else:
                if not isinstance(hid_layer_dim, (list, tuple)) or len(hid_layer_dim) != hid_layer_num:
                    raise Exception('Error: Wrong specification for hidden layers')
                if layer_init_scale is not None and len(layer_init_scale) < hid_layer_num + 1:
                    raise Exception('Error: layer_init_scale needs one entry per weight matrix')
                layer_dim = [self._output_dim] + list(hid_layer_dim) + [self._input_dim]

                for idx in range(hid_layer_num + 1):
                    shape = [layer_dim[idx], layer_dim[idx + 1]]
                    if layer_init_scale is None:
                        # Fan-in scaled Gaussian initialisation.
                        initializer = tf.random_normal_initializer(
                            stddev=self._init_scale / np.sqrt(layer_dim[idx + 1]))
                    else:
                        # Explicit per-layer scale, drawn from NumPy's global RNG.
                        w_init = np.random.normal(0, 1, shape) * layer_init_scale[idx]
                        initializer = tf.constant_initializer(w_init)
                    # Names count up from the input side: W1 touches the input.
                    weight = tf.compat.v1.get_variable(
                        shape=shape, dtype=self._dtype, trainable=True,
                        initializer=initializer,
                        name='W{}'.format(hid_layer_num + 1 - idx))
                    self._operator = tf.matmul(self._operator, weight)

        # Mean squared error between predictions (inputs @ operator^T) and targets.
        self._loss = tf.reduce_mean(
            tf.math.squared_difference(tf.matmul(self._input_data, self._operator, transpose_b=True),
                                       self._output_data))

        # Staircase exponential decay driven by the fed global step.
        learning_rate = tf.compat.v1.train.exponential_decay(self._learning_rate, self._global_step,
                                                             self._learning_rate_decay_every,
                                                             self._learning_rate_decay, staircase=True)
        self._current_lr = learning_rate

        # One optimizer both computes and applies gradients.
        # Gradients themselves do not depend on the learning rate.
        optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
        self._gradient = optimizer.compute_gradients(self._loss)
        # NOTE: this is the SUM of per-variable Frobenius norms, not the norm of the
        # concatenated gradient. Variables that do not affect the loss (e.g. from
        # another model in the same graph) have a None gradient and are skipped.
        self._gradient_norm = tf.math.add_n(
            [tf.norm(grad, ord='fro', axis=(0, 1)) for grad, _ in self._gradient if grad is not None])
        self._train_step = optimizer.apply_gradients(self._gradient)

        self._debug = tf.constant(False)
        self._sess.run(tf.compat.v1.global_variables_initializer())

    def get_prediction_acc(self, inputs, outputs):
        """Accuracy is not defined for this regression model; always returns None."""
        return None

    def update(self, inputs_batch, outputs_batch, global_step_count):
        """Run one gradient-descent step on a batch.

        Returns:
            Tuple ``(batch_loss, grad_norm, lr, debug)``, fetched in the same
            ``session.run`` call as the training step.
        """
        _, batch_loss, grad_norm, lr, debug = self._sess.run(
            fetches=(self._train_step, self._loss, self._gradient_norm, self._current_lr, self._debug),
            feed_dict={self._input_data: inputs_batch,
                       self._output_data: outputs_batch,
                       self._global_step: global_step_count})

        return batch_loss, grad_norm, lr, debug

    def get_loss(self, inputs, outputs):
        """Return the mean squared error of the current network on (inputs, outputs)."""
        loss = self._sess.run(
            fetches=self._loss,
            feed_dict={self._input_data: inputs,
                       self._output_data: outputs,
                       self._global_step: 1})

        return loss

    def get_operator(self):
        """Return the equivalent linear operator (product of all weight matrices).

        The operator depends only on the weights, so no placeholders need feeding.
        """
        return self._sess.run(fetches=self._operator)

    def stop_crit_check(self, **kwargs):
        """Return True once ``kwargs['loss_val']`` has dropped to 1e-7 or below."""
        if kwargs['loss_val'] <= 1e-7:
            print("Quitting, stopping criteria satisfied: training loss < 1e-7")
            return True
        else:
            return False

    def test_func(self):
        pass


if __name__ == "__main__":
    # This module only defines LinearNet; running it directly does nothing.
    ...
