from consts import *

class Network:
    """Two-layer ReLU network trained with hinge-loss update steps (TF1 graph mode).

    First layer: weights ``W`` of shape ``(r, D)`` and biases ``B`` of length
    ``r``, followed by a ReLU. Second layer: the ReLU outputs are summed with
    fixed unit weights and ``SECOND_LAYER_BIAS`` is added. ``W`` and ``B`` are
    kept as NumPy-side state and fed into the graph built by
    ``prepere_update_network`` on every ``sess.run``.
    """

    def __init__(self, W, B):
        """Store the first-layer parameters.

        Args:
            W: First-layer weight matrix; must have ``D`` columns. Its row
                count ``r`` is the number of hidden units.
            B: First-layer bias vector; its leading dimension must equal ``r``.

        Raises:
            ValueError: If ``W`` and ``B`` disagree on ``r`` or ``W`` does not
                have ``D`` columns.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if W.shape[0] != B.shape[0]:
            raise ValueError(
                "W and B must have the same number of rows, got %d and %d"
                % (W.shape[0], B.shape[0]))
        if W.shape[1] != D:
            raise ValueError(
                "W must have %d columns, got %d" % (D, W.shape[1]))
        self.W = W
        self.B = B
        self.r = W.shape[0]

    def _feed_dict(self):
        """Map the graph's W/B placeholders to the current NumPy parameters."""
        return {self.tf_W: self.W, self.tf_B: self.B}

    def check_predriction_for_dataset(self, sess):
        """Count the samples whose margin ``f(x) * y`` is strictly positive.

        ``prepere_update_network`` must have been called first, since it
        creates the placeholders and the ``mask_two_margin`` tensor.

        Args:
            sess: TensorFlow session used to evaluate the graph.

        Returns:
            Number of samples with margin > 0, as a NumPy scalar.
        """
        current_margin = sess.run(self.mask_two_margin, self._feed_dict())
        return np.sum(current_margin)

    def prepere_update_network(self, X, Y):
        """Build the TF1 graph for the forward pass and the update rule.

        Creates the ``tf_W``/``tf_B`` placeholders and stores every
        intermediate tensor as an attribute so other methods (and external
        callers) can fetch them. Below, ``N`` is the number of rows of ``X``.

        Args:
            X: Dataset inputs, embedded in the graph as a constant. Its second
                dimension must be ``D`` for the matmuls below to be
                shape-compatible with ``W``.
            Y: Labels, multiplied element-wise with the network output.
                Presumably a length-N vector of +/-1 -- TODO confirm with caller.
        """
        # Placeholders use TF_TYPE, matching the X/Y constants: the matmul
        # below requires the W/B side and the X side to share a dtype.
        self.tf_W = tf.placeholder(TF_TYPE, name='W', shape=[self.r, D])
        self.tf_B = tf.placeholder(TF_TYPE, name='B', shape=[self.r])
        tf_X = tf.constant(X, dtype=TF_TYPE, name="X")
        tf_Y = tf.constant(Y, dtype=TF_TYPE, name="Y")
        n_samples = X.shape[0]

        # Fold the bias into the matmul: X gets a trailing column of ones and
        # B becomes the last row under W^T -> X_matrix [N, D+1], W_matrix [D+1, r].
        self.X_matrix = tf.concat(
            [tf_X, tf.ones([n_samples, 1], dtype=TF_TYPE)], axis=1)
        self.W_matrix = tf.concat(
            [tf.transpose(self.tf_W), [self.tf_B]], axis=0)

        # First layer: pre-activations [N, r], then ReLU (negatives -> 0).
        self.first_layer_output = tf.matmul(self.X_matrix, self.W_matrix)
        self.mask_first_layer_output = tf.where(
            self.first_layer_output < 0,
            tf.zeros(self.first_layer_output.shape, dtype=TF_TYPE),
            self.first_layer_output)

        # Second layer: sum hidden units with fixed unit weights, add constant bias -> [N].
        self.second_layer_output = tf.reshape(
            tf.matmul(self.mask_first_layer_output,
                      tf.ones([self.r, 1], dtype=TF_TYPE)),
            [n_samples]) + SECOND_LAYER_BIAS

        # Margin f(x) * y. mask_two_margin is 1 where margin > 0 and 0 otherwise
        # (a zero margin passes through the inner tf.where unchanged, i.e. as 0).
        self.margin = tf.multiply(tf.transpose(self.second_layer_output), tf_Y)
        self.mask_one_margin = tf.where(
            self.margin < 0,
            tf.zeros(self.margin.shape, dtype=TF_TYPE),
            self.margin)
        self.mask_two_margin = tf.where(
            self.margin > 0,
            tf.ones(self.margin.shape, dtype=TF_TYPE),
            self.mask_one_margin)

        # Hinge term HINGE_LOSS_CONST - margin; mask_two_hinge_loss_map is 1
        # where it is positive (the sample contributes to the update), else 0.
        self.hinge_loss_map = HINGE_LOSS_CONST - self.margin
        self.mask_one_hinge_loss_map = tf.where(
            self.hinge_loss_map < 0,
            tf.zeros(self.hinge_loss_map.shape, dtype=TF_TYPE),
            self.hinge_loss_map)
        self.mask_two_hinge_loss_map = tf.where(
            self.hinge_loss_map > 0,
            tf.ones(self.hinge_loss_map.shape, dtype=TF_TYPE),
            self.mask_one_hinge_loss_map)

        # ReLU activity map [r, N]: 1 where hidden unit j was active on sample i, else 0.
        self.relu_reset_map = tf.transpose(self.mask_first_layer_output)
        self.mask_relu_reset_map = tf.where(
            self.relu_reset_map > 0,
            tf.ones(self.relu_reset_map.shape, dtype=TF_TYPE),
            self.relu_reset_map)

        # total_map[j, i] = y_i when unit j is active on sample i and sample i
        # has a positive hinge term, else 0 (the [N] masks broadcast across rows).
        self.total_reset_map = tf.multiply(
            self.mask_relu_reset_map, self.mask_two_hinge_loss_map)
        self.total_map = tf.multiply(self.total_reset_map, tf_Y)
        # Sum over samples: [r, N] @ [N, D] -> [r, D] for W, [r] for B. With unit
        # second-layer weights this is the negative (sub)gradient of the summed
        # hinge loss, which is why update_network *adds* it to W and B.
        self.w_update_rule = tf.matmul(self.total_map, tf_X)
        self.b_update_rule = tf.reduce_sum(self.total_map, axis=1)

    def update_network(self, sess, lr):
        """Apply one update step to W and B and report convergence status.

        Args:
            sess: TensorFlow session used to evaluate the graph.
            lr: Step size multiplying the update terms.

        Returns:
            Tuple ``(global_minimum_point, local_minimum_point,
            non_zero_loss_sample_counter)``, all evaluated at the parameters
            *before* this step:

            - global_minimum_point: True if no sample has a positive hinge term.
            - local_minimum_point: True if both update terms are all zeros.
            - non_zero_loss_sample_counter: number of samples with a positive
              hinge term.
        """
        (current_w_update_rule,
         current_b_update_rule,
         current_hinge_loss_map) = sess.run(
            [self.w_update_rule, self.b_update_rule,
             self.mask_two_hinge_loss_map],
            self._feed_dict())

        # Rebind (not +=) so arrays passed in by the caller are never mutated.
        self.W = self.W + lr * current_w_update_rule
        self.B = self.B + lr * current_b_update_rule

        # All-zero update terms: the step moved nothing (treated as a local minimum).
        local_minimum_point = (np.sum(np.abs(current_w_update_rule)) == 0
                               and np.sum(np.abs(current_b_update_rule)) == 0)
        # No sample with a positive hinge term: treated as a global minimum.
        non_zero_loss_sample_counter = np.sum(current_hinge_loss_map)
        global_minimum_point = non_zero_loss_sample_counter == 0
        return global_minimum_point, local_minimum_point, non_zero_loss_sample_counter