import numpy as np
import tensorflow as tf
import optuna
import os

class Model(object):
    """Two-view network trained with a canonical-correlation loss and optional stochastic gates.

    Each view (``X1``, ``X2``) is multiplied element-wise by a learned, noisy
    gate and then passed through its own stack of dense layers.  The training
    loss is the negative sum of the singular values of the whitened
    cross-covariance of the two outputs (see :meth:`neg_correlation`), plus an
    optional gate regularizer when ``feature_selection`` is enabled.

    The graph is built with the TensorFlow 1.x graph/session API in a private
    ``tf.Graph``; the session is kept in ``self.sess``.
    """

    def __init__(self, input_node1,input_node2, hidden_layers_node, output_node, learning_rate, batch_size, display_step, activation,
            seed=1,
            feature_selection=False,
            a = 1,
            sigma = 0.1,
            lam1=0.5,lam2=0.5,
            param_search=False,initilize=True,u=None,v=None
        ): #Note: a, sigma, lam should be set by params dict that will be passed to this class.
        """Build the graph, create the session and initialize all variables.

        :param input_node1: number of input features of view 1
        :param input_node2: number of input features of view 2
        :param hidden_layers_node: sizes of the dense layers; the same sizes are
            used for both views.  Must be non-empty and its last entry must equal
            ``output_node`` (the loss builds ``output_node``-sized matrices).
        :param output_node: dimensionality of each view's output
        :param learning_rate: step size of the gradient-descent optimizer
        :param batch_size: mini-batch size used by :meth:`train`
        :param display_step: validation/logging period (in epochs) of :meth:`train`
        :param activation: one of 'relu', 'l_relu', 'sigmoid', 'tanh', 'none';
            applied to every layer except the first one
        :param seed: graph-level random seed
        :param feature_selection: if True, add the gate regularizer to the loss
        :param a: slope of the hard sigmoid that maps gate parameters to [0, 1]
        :param sigma: standard deviation of the gate noise
        :param lam1: weight of the view-1 gate regularizer
        :param lam2: weight of the view-2 gate regularizer
        :param param_search: stored on the instance; not used by this class
        :param initilize: if True the gate parameters start around 0, otherwise
            around ``u`` (view 1) and ``v`` (view 2)
        :param u: initial mean of the view-1 gate parameters when ``initilize`` is False
        :param v: initial mean of the view-2 gate parameters when ``initilize`` is False
        :raises ValueError: on an empty ``hidden_layers_node``, a last layer
            size different from ``output_node``, or missing ``u``/``v`` when
            ``initilize`` is False
        """
        if len(hidden_layers_node) == 0:
            raise ValueError('hidden_layers_node must contain at least one layer')
        if hidden_layers_node[-1] != output_node:
            raise ValueError('the last entry of hidden_layers_node must equal output_node')
        if not initilize and (u is None or v is None):
            raise ValueError('u and v must be provided when initilize is False')

        self.param_search = param_search
        # Register hyperparameters for feature selection
        self.a = a
        self.sigma = sigma
        self.lam1 = lam1
        self.lam2 = lam2
        # Register regular hyperparameters
        self.lr = learning_rate
        self.output_node = output_node
        self.batch_size = batch_size
        self.display_step = display_step
        use_all_singular_values = True

        graph = tf.Graph()
        with graph.as_default():
            # The graph-level seed only affects ops created afterwards in the
            # current default graph, so it has to be set here, before any
            # random initializer or noise op is added to this private graph.
            tf.set_random_seed(seed)
            self.sess = tf.Session(graph=graph)
            # tf Graph Input
            X1 = tf.placeholder(tf.float32, [None, input_node1])
            X2 = tf.placeholder(tf.float32, [None, input_node2])
            # Fed with [1.0] to enable the gate noise and [0.0] to disable it.
            train_gates = tf.placeholder(tf.float32, [1], name='train_gates')
            self.nnweights = []

            self.alpha1, layer_out1 = self._build_view(
                '1', X1, input_node1, hidden_layers_node, activation,
                train_gates, 0.0 if initilize else u)
            self.alpha2, layer_out2 = self._build_view(
                '2', X2, input_node2, hidden_layers_node, activation,
                train_gates, 0.0 if initilize else v)

            loss_fun = self.neg_correlation(layer_out1, layer_out2, use_all_singular_values)

            if feature_selection:
                reg1 = self._gate_penalty(self.alpha1)
                reg2 = self._gate_penalty(self.alpha2)
                loss = loss_fun + self.lam1 * reg1 + self.lam2 * reg2
                # Unweighted penalty, reported during training.
                self.reg_gates = reg1 + reg2
            else:
                loss = loss_fun
                self.reg_gates = tf.constant(0.)
            # Get optimizer
            train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

            # Initialize the variables (i.e. assign their default value)
            init_op = tf.global_variables_initializer()
            self.saver = tf.train.Saver()

        # Save into class members
        self.X1 = X1
        self.X2 = X2
        self.train_gates = train_gates
        # Note: the reported loss is the correlation term only (no regularizer).
        self.loss = loss_fun
        self.train_step = train_step
        self.layer_out1 = layer_out1
        self.layer_out2 = layer_out2
        self.sess.run(init_op)

    def _build_view(self, view, inputs, input_dim, hidden_layers_node, activation, train_gates, gate_mean):
        """Create the gate and the dense stack of one view.

        Variable and scope names ('gates<view>/alpha<view>',
        'layer<view><k>/weights<view>', 'layer<view><k>/biases<view>') are part
        of the checkpoint format and must not change.

        :return: tuple ``(alpha, output)`` with the gate parameters and the
            output tensor of the last layer
        """
        with tf.variable_scope('gates' + view, reuse=tf.AUTO_REUSE):
            alpha = tf.get_variable('alpha' + view, [input_dim,],
                                    initializer=tf.truncated_normal_initializer(mean=gate_mean, stddev=0.01))
            prev_x = self._gate_inputs(inputs, alpha, train_gates)

        prev_node = input_dim
        layer_out = prev_x
        for i, n_units in enumerate(hidden_layers_node):
            with tf.variable_scope('layer' + view + str(i + 1), reuse=tf.AUTO_REUSE):
                weights = tf.get_variable('weights' + view, [prev_node, n_units],
                                          initializer=tf.truncated_normal_initializer(stddev=0.1))
                self.nnweights.append(weights)
                biases = tf.get_variable('biases' + view, [n_units],
                                         initializer=tf.constant_initializer(0.0))
                layer_out = tf.matmul(prev_x, weights) + biases
                # The first layer is kept linear; every later layer is activated.
                if i > 0:
                    layer_out = self._apply_activation(layer_out, activation)
            prev_node = n_units
            prev_x = layer_out
        return alpha, layer_out

    def _apply_activation(self, x, activation):
        """Apply the activation selected by name; raise NotImplementedError if unknown."""
        if activation == 'relu':
            return tf.nn.relu(x)
        if activation == 'l_relu':
            return tf.nn.leaky_relu(x)
        if activation == 'sigmoid':
            return tf.nn.sigmoid(x)
        if activation == 'tanh':
            return tf.nn.tanh(x)
        if activation == 'none':
            return x
        raise NotImplementedError('activation not recognized')

    def _gate_penalty(self, alpha):
        """Mean over features of ``0.5 - 0.5 * erf((-1/(2a) - alpha) / (sigma * sqrt(2)))``."""
        # -1.0 keeps this a true division even under Python 2 integer semantics.
        reg = 0.5 - 0.5 * tf.erf((-1.0 / (2 * self.a) - alpha) / (self.sigma * np.sqrt(2)))
        return tf.reduce_mean(reg)

    def _inv_sqrt_psd(self, sigma_hat, eps):
        """Return ``V diag(D^-1/2) V^T`` using only eigenpairs with eigenvalue > ``eps``."""
        eigvals, eigvecs = tf.linalg.eigh(sigma_hat)
        keep = tf.reshape(tf.where(tf.greater(eigvals, eps)), [-1])
        eigvals = tf.gather(eigvals, keep)
        # eigh returns eigenvectors as columns, so select along axis 1.
        eigvecs = tf.gather(eigvecs, keep, axis=1)
        scaled = tf.matmul(eigvecs, tf.linalg.diag(eigvals ** -0.5))
        return tf.matmul(scaled, eigvecs, transpose_b=True)

    def neg_correlation(self, output1, output2, use_all_singular_values):
        """Negative total correlation between the two view outputs.

        Builds ``T = S11^-1/2 S12 S22^-1/2`` from the regularized sample
        covariances of the outputs and returns minus the sum of its singular
        values.

        :param output1: view-1 output, shape ``[batch, output_node]``
        :param output2: view-2 output, shape ``[batch, output_node]``
        :param use_all_singular_values: if True sum all singular values of
            ``T``; otherwise sum the square roots of the (at most
            ``output_node``) largest positive eigenvalues of ``T^T T``
        :return: scalar tensor ``-corr``
        """
        r1 = 1e-4
        r2 = 1e-4
        eps = 1e-12

        # Work with features in rows and samples in columns.
        H1 = tf.transpose(output1)
        H2 = tf.transpose(output2)

        m = tf.cast(tf.shape(H1)[1], tf.float32)

        # Center every feature over the batch (avoids an m x m ones matrix).
        H1bar = H1 - tf.expand_dims(tf.reduce_mean(H1, axis=1), axis=1)
        H2bar = H2 - tf.expand_dims(tf.reduce_mean(H2, axis=1), axis=1)

        scale = 1.0 / (m - 1)
        SigmaHat12 = scale * tf.matmul(H1bar, H2bar, transpose_b=True)
        # r * I keeps the covariances positive definite.
        SigmaHat11 = scale * tf.matmul(H1bar, H1bar, transpose_b=True) + r1 * tf.eye(self.output_node)
        SigmaHat22 = scale * tf.matmul(H2bar, H2bar, transpose_b=True) + r2 * tf.eye(self.output_node)

        # Root inverse of the covariance matrices via eigen decomposition
        SigmaHat11RootInv = self._inv_sqrt_psd(SigmaHat11, eps)
        SigmaHat22RootInv = self._inv_sqrt_psd(SigmaHat22, eps)

        Tval = tf.matmul(tf.matmul(SigmaHat11RootInv, SigmaHat12), SigmaHat22RootInv)

        if use_all_singular_values:
            # all singular values are used to calculate the correlation
            Tval.set_shape([self.output_node, self.output_node])
            s = tf.svd(Tval, compute_uv=False)
            corr = tf.reduce_sum(s)
        else:
            # just the top output_node singular values are used
            eigvals, _ = tf.linalg.eigh(tf.matmul(Tval, Tval, transpose_a=True))
            positive = tf.reshape(tf.where(tf.greater(eigvals, eps)), [-1])
            eigvals = tf.gather(eigvals, positive)
            # Fewer than output_node eigenvalues may survive the filter.
            k = tf.minimum(self.output_node, tf.shape(eigvals)[0])
            top = tf.nn.top_k(eigvals, k=k).values
            corr = tf.reduce_sum(tf.sqrt(top))

        return -corr

    def _to_tensor(self, x, dtype):
        """Convert the input `x` to a tensor of type `dtype`.
        # Arguments
            x: An object to be converted (numpy array, list, tensors).
            dtype: The destination type.
        # Returns
            A tensor.
        """
        return tf.convert_to_tensor(x, dtype=dtype)

    def hard_sigmoid(self, x, a):
        """Piecewise-linear gate: ``clip(a * x + 0.5, 0, 1)``.
        # Arguments
            x: A tensor or variable.
            a: Slope of the linear segment.
        # Returns
            A tensor with values in [0, 1].
        """
        x = a * x + 0.5
        zero = self._to_tensor(0., x.dtype.base_dtype)
        one = self._to_tensor(1., x.dtype.base_dtype)
        return tf.clip_by_value(x, zero, one)

    def _gate_inputs(self, prev_x, alpha, train_gates):
        """Multiply ``prev_x`` by ``hard_sigmoid(alpha + sigma * noise * train_gates)``."""
        # gaussian reparametrization
        base_noise = tf.random_normal(shape=tf.shape(prev_x), mean=0., stddev=1.)
        z = tf.expand_dims(alpha, axis=0) + self.sigma * base_noise * train_gates
        stochastic_gate = self.hard_sigmoid(z, self.a)
        return prev_x * stochastic_gate

    def feature_selector1(self, prev_x, train_gates):
        '''
        Gate the view-1 input with ``self.alpha1`` (gradients can be propagated).
        :param prev_x - input. shape==[batch_size, feature_num]
        :param train_gates - float tensor of shape [1]; 1.0 enables the gate
            noise, 0.0 disables it
        :return: gated input
        '''
        return self._gate_inputs(prev_x, self.alpha1, train_gates)

    def feature_selector2(self, prev_x, train_gates):
        '''
        Gate the view-2 input with ``self.alpha2`` (gradients can be propagated).
        :param prev_x - input. shape==[batch_size, feature_num]
        :param train_gates - float tensor of shape [1]; 1.0 enables the gate
            noise, 0.0 disables it
        :return: gated input
        '''
        return self._gate_inputs(prev_x, self.alpha2, train_gates)

    def eval(self, new_X1, new_X2):
        """Return the correlation loss on the given pair of views, with gate noise disabled."""
        loss = self.sess.run([self.loss], feed_dict={self.X1: new_X1,
                                                     self.X2: new_X2,
                                                     self.train_gates: [0.0]})
        return np.squeeze(loss)

    def get_raw_weights(self):
        """
        Evaluate the dense-layer weight matrices of both views.

        Returns a one-element list wrapping the list of weight arrays (view-1
        layers first, then view-2 layers); the nesting is kept for backward
        compatibility.
        """
        weights = self.sess.run([self.nnweights])
        return weights

    def get_raw_alpha(self):
        """
        evaluate the learned parameters of the stochastic gates of both views
        """
        dp_alpha1, dp_alpha2 = self.sess.run([self.alpha1, self.alpha2])
        return dp_alpha1, dp_alpha2

    def get_prob_alpha(self):
        """
        convert the raw alphas of both views into gate values in [0, 1]
        """
        dp_alpha1, dp_alpha2 = self.get_raw_alpha()
        prob_gate1 = self.compute_learned_prob(dp_alpha1)
        prob_gate2 = self.compute_learned_prob(dp_alpha2)
        return prob_gate1, prob_gate2

    def hard_sigmoid_np(self, x, a):
        """NumPy counterpart of :meth:`hard_sigmoid`: ``clip(a * x + 0.5, 0, 1)``."""
        return np.minimum(1, np.maximum(0, a * x + 0.5))

    def compute_learned_prob(self, alpha):
        """Map raw gate parameters to noise-free gate values using slope ``self.a``."""
        return self.hard_sigmoid_np(alpha, self.a)

    def load(self, model_path=None):
        """Restore all variables from the checkpoint at ``model_path``.

        :raises ValueError: if ``model_path`` is None
        """
        if model_path is None:
            raise ValueError('model_path must be provided')
        self.saver.restore(self.sess, model_path)

    def save(self, step, model_dir=None):
        """Save a checkpoint named ``model`` into ``model_dir``, tagged with ``step``.

        The directory (including missing parents) is created when needed.

        :raises ValueError: if ``model_dir`` is None
        :raises OSError: if the directory cannot be created
        """
        if model_dir is None:
            raise ValueError('model_dir must be provided')
        try:
            os.makedirs(model_dir)
        except OSError:
            # An already existing directory is fine; anything else is a real error.
            if not os.path.isdir(model_dir):
                raise
        model_file = os.path.join(model_dir, "model")
        self.saver.save(self.sess, model_file, global_step=step)

    def train(self, X1,X2,X1_valid,X2_valid,  num_epoch=100):
        """Run mini-batch gradient descent for ``num_epoch`` epochs.

        Batches are taken in order; trailing samples that do not fill a whole
        batch are skipped.  Every ``display_step`` epochs the validation loss
        is computed and printed.

        :param X1: training data of view 1, indexed ``[sample, feature]``
        :param X2: training data of view 2, row-aligned with ``X1``
        :param X1_valid: validation data of view 1
        :param X2_valid: validation data of view 2
        :param num_epoch: number of passes over the training data
        :return: ``(train_accuracies, train_losses, val_accuracies, val_losses)``;
            ``train_losses`` holds one mean batch loss per epoch, ``val_losses``
            one entry per displayed epoch, and both accuracy lists stay empty
        :raises ValueError: if the training set is smaller than one batch
        """
        train_losses, train_accuracies = [], []
        val_losses, val_accuracies = [], []

        total_batch = int(X1.shape[0] / self.batch_size)
        if num_epoch > 0 and total_batch == 0:
            raise ValueError('training set is smaller than batch_size')

        for epoch in range(num_epoch):
            avg_loss = 0.
            # Loop over all batches
            for i in range(total_batch):
                start = i * self.batch_size
                stop = start + self.batch_size
                X1_batch = X1[start:stop, :]
                X2_batch = X2[start:stop, :]

                _, c, reg_fs = self.sess.run([self.train_step, self.loss, self.reg_gates],
                                             feed_dict={self.X1: X1_batch,
                                                        self.X2: X2_batch,
                                                        self.train_gates: [1.0]})
                avg_loss += c

            avg_loss = avg_loss / total_batch
            # One entry per epoch: the mean loss over that epoch's batches.
            train_losses.append(avg_loss)

            # Display logs per epoch step
            if epoch % self.display_step == 0:
                valid_loss = self.eval(X1_valid, X2_valid)
                val_losses.append(valid_loss)

                print("Epoch: {} train loss={:.9f} valid loss= {:.9f}".format(epoch+1,\
                                                                                  avg_loss, valid_loss))
                print("train reg_fs: {}".format(reg_fs))

        print("Optimization Finished!")

        return train_accuracies, train_losses, val_accuracies, val_losses

    def test(self,X_test1,X_test2):
        """Return the outputs of both views for the given inputs, with gate noise disabled."""
        out1, out2 = self.sess.run([self.layer_out1, self.layer_out2],
                                   feed_dict={self.X1: X_test1,
                                              self.X2: X_test2,
                                              self.train_gates: [0.0]})
        return out1, out2

    def evaluate(self, X, y):
        """Compute and print the loss for the view pair ``(X, y)``.

        ``X`` and ``y`` are forwarded to :meth:`eval` as view 1 and view 2.
        This model has no accuracy metric, so the accuracy slot of the returned
        pair is always ``None``.

        :return: ``(acc, loss)`` with ``acc`` set to ``None``
        """
        loss = self.eval(X, y)
        acc = None
        print("test loss: {}, test acc: {}".format(loss, acc))
        return acc, loss
