import numpy as np
import tensorflow.compat.v1 as tf
# The model below is written against the TF1 graph/session API
# (tf.placeholder, tf.Session), so TF2 behavior is switched off right after
# the import and before any graph is built.
tf.disable_v2_behavior()
#import optuna
import os

# Seed NumPy's global RNG once, at import time.
np.random.seed(1)

class Model(object):
    """Fully-connected autoencoder with one stochastic gate per training sample.

    The network maps the input back to a vector of the same width
    (``input_node`` is appended to ``hidden_layers_node`` as the output layer)
    and is trained on the squared reconstruction error.  During training every
    row of a batch is multiplied by its own gate; the gates are parameterised
    by the trainable vector ``alpha`` of length ``sample_size``.  When the
    ``test_phase`` placeholder is true the gates are bypassed.

    The graph is built with the TF1 API inside a private ``tf.Graph`` and
    executed through ``self.sess``.
    """

    def __init__(self, input_node,sample_size, hidden_layers_node, batch_size, display_step, activation,embd_layer,
            seed=1,sample_selection=True,normalize=False,use_reg_expectation = True,
            lam=0.5, train_bias = True,sigma=0.5,optimizer='GD'
        ):
        """Build the graph, create the session and initialise all variables.

        # Arguments
            input_node: Width of the input (and of the reconstructed output).
            sample_size: Number of training samples; length of ``alpha``.
            hidden_layers_node: Sequence with the width of each hidden layer.
            batch_size: Number of rows fed per training step.
            display_step: Validate and print every ``display_step`` epochs.
            activation: One of 'relu', 'l_relu', 'sigmoid', 'tanh', 'none';
                applied after every layer, including the output layer.
            embd_layer: Zero-based index of the layer whose output is exposed
                as ``self.embedding``.
            seed: Graph-level TensorFlow random seed.
            sample_selection: Only consulted by ``sample_selector``.
            normalize: If truthy, divide the reconstruction error by the mean
                squared norm of the input rows.
            use_reg_expectation: Selects which gate regulariser is used
                (see ``_gate_regularizer``).
            lam: Weight of the gate regularisation term.
            train_bias: Whether biases are trained when the gates are frozen.
            sigma: Standard deviation of the gate noise.
            optimizer: 'GD' for gradient descent, anything else for Adam.
        # Raises
            NotImplementedError: If ``activation`` is not recognised.
        """
        # Hyperparameters of the stochastic gates.
        self.a = 1
        self.sigma = sigma
        self.lam = lam
        self.sample_selection = sample_selection
        # Regular hyperparameters.
        self.train_bias = train_bias
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.embd_layer = embd_layer
        self.display_step = display_step  # to print loss information during training
        # Integer dtype keeps the layer sizes valid shapes even when no hidden
        # layer is requested (an empty list would otherwise become float64).
        self.hidden_layers_node = np.append(
            np.array(hidden_layers_node, dtype=int), input_node)
        # Stays None when embd_layer does not index an existing layer.
        self.embedding = None

        graph = tf.Graph()
        with graph.as_default():
            # The graph-level seed must be set on *this* graph and before any
            # random op or initializer is created, otherwise it has no effect.
            tf.set_random_seed(seed)
            self.sess = tf.Session(graph=graph)

            # Graph inputs.
            X = tf.placeholder(tf.float32, [None, input_node])
            train_gates = tf.placeholder(tf.float32, [1], name='train_gates')
            default_phase = tf.constant(True, dtype=tf.bool)
            self.test_phase = tf.placeholder_with_default(default_phase, [], name='test_phase')
            self.learning_rate = tf.placeholder(tf.float32, [1], name='learning_rate')
            batch_index = tf.placeholder(tf.float32, [1], name='batch_index')
            self.shift = tf.placeholder(tf.float32, [1], name='shift')

            self.nnweights = []
            self.nnbiases = []

            with tf.variable_scope('gates', reuse=tf.AUTO_REUSE):
                self.alpha = tf.get_variable(
                    'alpha', [self.sample_size,],
                    initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01))
                self.gen_gates(batch_index, train_gates)
                # Gates are applied row-wise, and only outside the test phase.
                gated_x = tf.cond(self.test_phase, lambda: X, lambda: self.gates * X)

            prev_node = input_node
            prev_x = gated_x
            for i, n_units in enumerate(self.hidden_layers_node):
                with tf.variable_scope('layer' + str(i + 1), reuse=tf.AUTO_REUSE):
                    weights = tf.get_variable(
                        'weights', [prev_node, n_units],
                        initializer=tf.truncated_normal_initializer(stddev=0.05))
                    self.nnweights.append(weights)
                    biases = tf.get_variable(
                        'biases', [n_units],
                        initializer=tf.constant_initializer(0.0))
                    self.nnbiases.append(biases)

                    layer_out = self._apply_activation(
                        activation, tf.matmul(prev_x, weights) + biases)

                    if i == self.embd_layer:
                        self.embedding = layer_out

                    prev_node = n_units
                    prev_x = layer_out

            # Output of the model: the last layer reconstructs the input.
            pred = prev_x
            gated_out = tf.cond(self.test_phase, lambda: pred, lambda: self.gates * pred)

            # Squared reconstruction error between gated output and gated input.
            loss_fun = tf.reduce_mean(tf.reduce_sum((gated_out - gated_x) ** 2, axis=1))
            if normalize:
                loss_fun = loss_fun / tf.reduce_mean(tf.reduce_sum(X ** 2, axis=1))

            reg_gates = self._gate_regularizer(use_reg_expectation)

            # The regulariser is subtracted, and only counts while the gates
            # are being trained (train_gates == 1).
            loss = loss_fun - self.lam * reg_gates * train_gates
            self.reg_gates = reg_gates  # for debugging
            self.train_gates = train_gates

            train_step = self._build_train_step(loss, optimizer)

            init_op = tf.global_variables_initializer()
            self.saver = tf.train.Saver()

        # Save into class members.
        self.X = X
        self.batch_index = batch_index
        self.pred = pred
        self.loss = loss
        self.gated_x = gated_x
        self.gated_out = gated_out
        self.gatet_out = gated_out  # legacy (misspelled) name kept for existing callers
        self.train_step = train_step

        self.sess.run(init_op)

    def _apply_activation(self, activation, x):
        """Apply the activation called `activation` to tensor `x`.

        # Raises
            NotImplementedError: If the name is not one of the supported ones.
        """
        if activation == 'none':
            return x
        # Built lazily so that defining the class does not touch TensorFlow.
        activations = {
            'relu': tf.nn.relu,
            'l_relu': tf.nn.leaky_relu,
            'sigmoid': tf.nn.sigmoid,
            'tanh': tf.nn.tanh,
        }
        if activation not in activations:
            raise NotImplementedError('activation not recognized')
        return activations[activation](x)

    def _gate_regularizer(self, use_reg_expectation):
        """Return the scalar gate regulariser computed from ``self.alpha``.

        With `use_reg_expectation` the mean of a closed-form Gaussian
        expression in alpha and sigma is used; otherwise the mean of an
        erf-based term is used.
        NOTE(review): the original author describes the two options as
        "E(z)" versus "p(z>0)" -- confirm against the derivation.
        """
        if use_reg_expectation:
            def phi(x):
                # Standard normal CDF.
                return 0.5 * (1 + tf.erf(x / np.sqrt(2)))

            alpha = self.alpha
            sigma = self.sigma
            mu_z = (
                sigma ** 2 / (sigma * np.sqrt(2 * np.pi))
                * (tf.exp(-alpha ** 2 / (2 * sigma ** 2))
                   - tf.exp(-(1 - alpha) ** 2 / (2 * sigma ** 2)))
                + (alpha - 1) * phi((1 - alpha) / sigma)
                - alpha * phi(-alpha / sigma)
                + 1
            )
            return tf.reduce_mean(mu_z)

        reg = 0.5 - 0.5 * tf.erf(
            (-1 / (2 * self.a) - self.alpha) / (self.sigma * np.sqrt(2)))
        return tf.reduce_mean(reg)

    def _build_train_step(self, loss, optimizer):
        """Create the training op for `loss`.

        When ``train_gates`` is fed as 1 the optimizer minimises over all
        trainable variables; otherwise only over the network weights (plus the
        biases if ``self.train_bias``), leaving the gates untouched.
        """
        if optimizer == 'GD':
            make_optimizer = tf.train.GradientDescentOptimizer
        else:
            make_optimizer = tf.train.AdamOptimizer

        if self.train_bias:
            network_vars = self.nnweights + self.nnbiases
        else:
            network_vars = self.nnweights

        gates_trainable = tf.equal(self.train_gates, tf.constant(1.))[0]
        return tf.cond(
            gates_trainable,
            lambda: make_optimizer(self.learning_rate[0]).minimize(loss),
            lambda: make_optimizer(self.learning_rate[0]).minimize(
                loss, var_list=network_vars))

    def _to_tensor(self, x, dtype):
        """Convert the input `x` to a tensor of type `dtype`.
        # Arguments
            x: An object to be converted (numpy array, list, tensors).
            dtype: The destination type.
        # Returns
            A tensor.
        """
        return tf.convert_to_tensor(x, dtype=dtype)

    def hard_sigmoid(self, x, a):
        """Piecewise-linear gate: ``clip(a * x + 0.5, 0, 1)``.
        # Arguments
            x: A tensor or variable.
            a: Slope of the linear segment.
        # Returns
            A tensor with values in [0, 1].
        """
        x = a * x + 0.5
        zero = self._to_tensor(0., x.dtype.base_dtype)
        one = self._to_tensor(1., x.dtype.base_dtype)
        return tf.clip_by_value(x, zero, one)

    def _sample_gates(self, batch_index, train_gates):
        """Return the vector of gate values for the current batch.

        The alphas of the batch are the ``batch_size`` entries starting at
        ``batch_size * batch_index + shift``; if fewer remain, the batch is
        completed with entries from the beginning of ``alpha`` (wrap-around).
        Gaussian noise scaled by ``sigma`` is added only when `train_gates`
        is non-zero.
        """
        batch_size = tf.cast(self.batch_size, 'int32')
        start = tf.cast(self.batch_size * batch_index + self.shift, 'int32')
        remaining = tf.cast(
            self.sample_size - self.batch_size * batch_index - self.shift, 'int32')

        n_tail = tf.minimum(batch_size, remaining)
        n_wrapped = batch_size - n_tail

        base_noise = tf.random_normal(shape=[self.batch_size], mean=0., stddev=1.)
        alpha_tail = tf.slice(self.alpha, start, n_tail)
        alpha_wrapped = tf.slice(self.alpha, [0], n_wrapped)
        batch_alpha = tf.concat([alpha_tail, alpha_wrapped], -1)

        # Gaussian reparametrization.
        z = batch_alpha + self.sigma * base_noise * train_gates
        return self.hard_sigmoid(z, self.a)

    def gen_gates(self,batch_index,train_gates):
        '''
        Build the per-sample gates of the current batch and store them, with a
        trailing axis of size 1, in ``self.gates``.
        :param batch_index - placeholder with the index of the batch in the epoch
        :param train_gates - 1 during training, 0 during evaluation
        :return: None
        '''
        self.gates = tf.expand_dims(self._sample_gates(batch_index, train_gates), 1)

    def sample_selector(self, prev_x, train_gates,batch_index):
        '''
        sample selector - used at training time (gradients can be propagated)
        :param prev_x - input. shape==[batch_size, feature_num]
        :param train_gates - 1 during training, 0 during evaluation
        :param batch_index - placeholder with the index of the batch in the epoch
        :return: gated input, or prev_x unchanged when sample_selection is off
        '''
        if not self.sample_selection:
            return prev_x
        gate = self._sample_gates(batch_index, train_gates)
        return tf.expand_dims(gate, 1) * prev_x

    def validate(self, new_X):
        """Return the loss on `new_X` with the gates bypassed.

        ``test_phase`` is fed as True and ``train_gates`` as 0, so neither the
        gates nor the regulariser contribute to the returned value.
        """
        loss = self.sess.run([self.loss], feed_dict={self.X: new_X,
                                                     self.train_gates: [0.0],
                                                     self.test_phase: True,
                                                     self.batch_index: [0.0],
                                                     self.shift: [0.0]})
        return np.squeeze(loss)

    def get_raw_alpha(self):
        """
        evaluate the learned parameter for stochastic gates
        """
        return self.sess.run(self.alpha)

    def get_weights(self):
        """
        evaluate the network parameters
        # Returns
            A two-element list ``[weights, biases]``, each a list with one
            array per layer.
        """
        return self.sess.run([self.nnweights, self.nnbiases])

    def get_prob_alpha(self):
        """
        convert the raw alpha into the actual probability
        """
        return self.compute_learned_prob(self.get_raw_alpha())

    def hard_sigmoid_np(self, x, a):
        """NumPy counterpart of ``hard_sigmoid``: ``clip(a * x + 0.5, 0, 1)``."""
        return np.minimum(1, np.maximum(0,a*x+0.5))

    def compute_learned_prob(self, alpha):
        """Map raw `alpha` values to noise-free gate values in [0, 1]."""
        return self.hard_sigmoid_np(alpha, self.a)

    def load(self, model_path=None):
        """Restore all variables from the checkpoint at `model_path`.
        # Raises
            ValueError: If `model_path` is None.
        """
        if model_path is None:
            raise ValueError("model_path must be provided")
        self.saver.restore(self.sess, model_path)

    def save(self, step, model_dir=None):
        """Write a checkpoint named ``model`` for `step` into `model_dir`.

        The directory (including missing parents) is created if needed.
        # Raises
            ValueError: If `model_dir` is None.
        """
        if model_dir is None:
            raise ValueError("model_dir must be provided")
        os.makedirs(model_dir, exist_ok=True)
        model_file = model_dir + "/model"
        self.saver.save(self.sess, model_file, global_step=step)

    def train(self, x_train,x_valid, output_dir, learning_rate, num_epoch=100,train_gates=1):
        """Train for `num_epoch` epochs over `x_train`.

        Batches are taken in order; a short final batch is completed with rows
        from the start of `x_train`, and the next epoch starts after those
        rows.  Every ``display_step`` epochs the model is validated on
        `x_valid`, and the raw alphas of the best validation loss seen so far
        are remembered.

        # Arguments
            x_train: 2-D array; assumed to have ``sample_size`` rows.
            x_valid: 2-D array used for validation.
            output_dir: Unused.
            learning_rate: Step size fed to the optimizer.
            num_epoch: Number of passes over the data.
            train_gates: 1 to train the gates, 0 to freeze them.
        # Returns
            ``(train_losses, val_losses, best_gates)``; ``best_gates`` is an
            empty list if no validation step improved on infinity.
        """
        train_losses = []
        val_losses = []

        self.best_val = np.inf
        self.best_gates = []

        shift = 0
        for epoch in range(num_epoch):
            avg_loss = 0.
            total_batch = int(np.ceil((self.sample_size - shift) / self.batch_size))
            for i in range(total_batch):
                start = shift + self.batch_size * i
                batch_xs = x_train[start:start + self.batch_size, :]
                # Complete a short batch with rows from the start of the data.
                n_missing = self.batch_size - batch_xs.shape[0]
                batch_xs_full = np.vstack((batch_xs, x_train[:n_missing, :]))

                _, c, reg_fs = self.sess.run(
                    [self.train_step, self.loss, self.reg_gates],
                    feed_dict={self.X: batch_xs_full,
                               self.test_phase: False,
                               self.train_gates: [train_gates],
                               self.batch_index: [int(i)],
                               self.learning_rate: [learning_rate],
                               self.shift: [shift]})
                # Only the first layer's weights are inspected, so only they
                # are fetched.
                first_weights = self.sess.run(self.nnweights[0])
                if np.isnan(first_weights).any() or np.isnan(c).any():
                    print("nan")
                avg_loss += c / total_batch

            # Rows borrowed from the start of the data are skipped next epoch.
            shift = self.batch_size - batch_xs.shape[0]

            train_losses.append(avg_loss)
            # Display logs per epoch step
            if (epoch+1) % self.display_step == 0:
                valid_loss = self.validate(x_valid)
                val_losses.append(valid_loss)

                print("Epoch: {} train loss={:.9f} valid loss= {:.9f}".format(epoch+1,avg_loss[0], valid_loss))
                print("train reg_fs: {}".format(reg_fs))

                # A NaN validation loss never compares as an improvement.
                if valid_loss < self.best_val:
                    self.best_val = valid_loss
                    self.best_gates = self.get_raw_alpha()

        return train_losses, val_losses, self.best_gates

    def test(self,X_test):
        """Run `X_test` through the network with the gates bypassed.
        # Returns
            ``(prediction, embedding)``: the output layer and the output of
            layer ``embd_layer``.
        # Raises
            ValueError: If ``embd_layer`` did not select an existing layer.
        """
        if self.embedding is None:
            raise ValueError(
                "embd_layer={} does not index one of the {} layers".format(
                    self.embd_layer, len(self.hidden_layers_node)))
        prediction, embedding = self.sess.run(
            [self.pred, self.embedding],
            feed_dict={self.X: X_test,
                       self.train_gates: [0.0],
                       self.test_phase: True,
                       self.batch_index: [0],
                       self.shift: [0]})
        return prediction, embedding
