import numpy as np

import tensorflow as tf

from utils import Batcher 


class Embedder:
    """Learn an embedding in which same-label points are close and
    different-label points are pushed apart.

    The network is an optional stack of LeakyReLU ``Dense`` layers followed by
    a final ``Dense`` layer with ``K`` units whose output goes through
    ``out_activation``.  Training minimizes a pairwise loss between an x-batch
    and a y-batch (see ``build_graph``) plus optional weight penalties, using
    Adam.

    The model lives in its own ``tf.Graph`` and ``tf.compat.v1.Session``; call
    ``close()`` to release the session when the embedder is no longer needed.
    """

    DISTANCE_TYPE_L2_SQ_NORMALIZED = 1  # mean of squared differences over the outputs
    DISTANCE_TYPE_L2_SQ_UNNORMED = 2    # sum of squared differences over the outputs

    def __init__(self, D, K, Bx,
                 learning_rate=1e-2,
                 fit_Bx=1000,
                 fit_cb=None,
                 fit_epochs=100,
                 n_classes=None,
                 l1_reg=0.,
                 l2_reg=0.,
                 l2_infty_reg=0.,
                 fit_By=None,
                 pairwise_cost_pos_multiplier=1.,
                 neg_upper_threshold=1.,
                 out_activation=tf.sigmoid,
                 distance_type=DISTANCE_TYPE_L2_SQ_UNNORMED,
                 intermediate_layer_sizes=None,
                 ):
        """Build the graph and initialize all variables in a new session.

        Args:
            D: input feature dimension (placeholders have shape ``[batch, D]``).
            K: number of units of the final layer (embedding size).
            Bx: static batch size of the x placeholders, or None for any size.
            learning_rate: Adam learning rate, fed on every training step.
            fit_Bx: x batch size used by ``fit``.
            fit_cb: optional object whose ``cb(embedder, feed_dict)`` method is
                called after every training step in ``fit``.
            fit_epochs: default number of epochs for ``fit``.
            n_classes: stored on the instance; not used within this class.
            l1_reg, l2_reg, l2_infty_reg: weights of the three weight penalties.
            fit_By: y batch size used by ``fit``; defaults to ``fit_Bx``.
            pairwise_cost_pos_multiplier: weight of the same-label pair term.
            neg_upper_threshold: margin; different-label pairs stop contributing
                once their distance reaches it.
            out_activation: applied to the final layer's logits.
            distance_type: one of the ``DISTANCE_TYPE_*`` class constants.
            intermediate_layer_sizes: unit counts of the hidden layers, in
                order; None or empty means no hidden layers.

        Raises:
            ValueError: if ``distance_type`` is not a known ``DISTANCE_TYPE_*``.
        """
        # A fresh list per instance (the old ``[]`` default was one shared
        # object); copying also decouples us from later mutation by the caller.
        self.intermediate_layer_sizes = (
            [] if intermediate_layer_sizes is None else list(intermediate_layer_sizes)
        )
        self.distance_type = distance_type

        self.out_activation = out_activation
        self.pairwise_cost_pos_multiplier = pairwise_cost_pos_multiplier
        self.neg_upper_threshold = neg_upper_threshold

        self.l1_reg = l1_reg
        self.l2_reg = l2_reg
        self.l2_infty_reg = l2_infty_reg

        self.n_classes = n_classes

        # build_graph reads the hyper-parameters assigned above.
        self.build_graph(D, K, Bx, learning_rate=learning_rate)

        self.global_epoch = 0

        self.fit_cb = fit_cb
        self.fit_Bx = fit_Bx
        self.fit_By = fit_Bx if fit_By is None else fit_By
        self.fit_epochs = fit_epochs

        self.sess = tf.compat.v1.Session(graph=self.graph)
        self.sess.run(self.t_init)

    def close(self):
        """Release the TensorFlow session owned by this embedder."""
        self.sess.close()

    def get_weights(self):
        """Return the current variable values of every Dense layer.

        One entry per layer: intermediate layers first, the final ``K``-unit
        layer last.  Each entry is what ``sess.run(layer.variables)`` returns
        (presumably ``[kernel, bias]`` for a Keras ``Dense`` layer).
        """
        return [self.sess.run(layer.variables) for layer in self.intermediate_layers]

    def t_get_d_xy(self, x_sigmoids, y_sigmoids):
        """Build the row-wise squared-L2 distance between two tensors.

        ``(x_sigmoids - y_sigmoids) ** 2`` is reduced over axis 1.  The
        reduction sums for ``DISTANCE_TYPE_L2_SQ_UNNORMED`` and averages for
        ``DISTANCE_TYPE_L2_SQ_NORMALIZED``.

        Raises:
            ValueError: if ``self.distance_type`` is not one of the above.
        """
        cls = type(self)
        if self.distance_type == cls.DISTANCE_TYPE_L2_SQ_UNNORMED:
            reduce_fn = tf.reduce_sum
        elif self.distance_type == cls.DISTANCE_TYPE_L2_SQ_NORMALIZED:
            reduce_fn = tf.reduce_mean
        else:
            # Previously ``assert False``: asserts vanish under ``python -O``
            # and the method would then die with an UnboundLocalError instead.
            raise ValueError('Unknown distance type: %r' % (self.distance_type,))

        return reduce_fn((x_sigmoids - y_sigmoids) ** 2, axis=1)

    def get_pairwise_distances(self, Batch_x, Batch_y):
        """Build the matrix of distances between every row of Batch_x and Batch_y.

        Every pairing of a row of ``Batch_x`` with a row of ``Batch_y`` is
        materialized by tiling, so the intermediate tensors have
        ``Bx * By`` rows.

        Returns:
            (t_dyn_Bx, t_dyn_By, d_xy): the dynamic batch sizes and a
            ``[Bx, By]`` tensor with ``d_xy[i, j] = t_get_d_xy(x_i, y_j)``.
        """
        t_dyn_Bx = tf.shape(Batch_x)[0]
        t_dyn_By = tf.shape(Batch_y)[0]
        K = tf.shape(Batch_x)[1]

        # Each x row repeated By times consecutively: x0, x0, ..., x1, x1, ...
        t_x_repl = tf.reshape(tf.tile(Batch_x, [1, t_dyn_By]), [t_dyn_Bx * t_dyn_By, K])
        # The whole y batch repeated Bx times: y0, y1, ..., y0, y1, ...
        t_y_repl = tf.tile(Batch_y, [t_dyn_Bx, 1])

        # Row i * By + j pairs x_i with y_j, so this reshape yields d[i, j].
        d_xy = tf.reshape(self.t_get_d_xy(t_x_repl, t_y_repl), [t_dyn_Bx, t_dyn_By])

        return t_dyn_Bx, t_dyn_By, d_xy

    def build_graph(self, D, K, Bx,
                    learning_rate,
                    init_std_dev=.05):
        """Build the graph: placeholders, network, loss and Adam train op.

        Args:
            D: input feature dimension.
            K: number of units of the final layer.
            Bx: static batch size of the x placeholders, or None.
            learning_rate: stored and later fed to ``t_learning_rate_pl``.
            init_std_dev: unused; kept for interface compatibility.

        ``t_value`` averages a per-pair cost over all (x, y) pairs.  With
        ``d`` the distance from ``t_get_d_xy``:

        * same-label pairs cost ``pairwise_cost_pos_multiplier * d``;
        * different-label pairs cost ``relu(neg_upper_threshold - d)``.

        ``t_loss`` adds the placeholder-weighted weight penalties and is the
        quantity minimized by ``t_train_op``.
        """
        self.learning_rate = learning_rate
        self.D = D
        self.Bx = Bx
        self.K = K

        self.graph = tf.Graph()

        with self.graph.as_default():

            # Bx can be None (any batch size).
            self.t_x_ph = tf.compat.v1.placeholder(tf.float32, shape=[Bx, D])
            self.t_x_labels_ph = tf.compat.v1.placeholder(tf.float32, shape=[Bx])

            # The y side is fed batches of ``fit_By`` rows, which may differ
            # from Bx, so its batch dimension is left dynamic.  Everything
            # downstream uses dynamic shapes (tf.shape / broadcast_to).
            self.t_y_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, D])
            self.t_y_labels_ph = tf.compat.v1.placeholder(tf.float32, shape=[None])

            self.t_reg_l1_pl = tf.compat.v1.placeholder(tf.float32, shape=())
            self.t_reg_l2_pl = tf.compat.v1.placeholder(tf.float32, shape=())
            self.t_reg_l2infty_pl = tf.compat.v1.placeholder(tf.float32, shape=())
            self.t_learning_rate_pl = tf.compat.v1.placeholder(tf.float32, shape=())

            # Hidden layers: create and build all of them first...
            self.intermediate_layers = []
            input_shape = [Bx, D]
            for layer_size in self.intermediate_layer_sizes:
                layer = tf.keras.layers.Dense(units=layer_size,
                                              activation=tf.keras.layers.LeakyReLU(0.2),
                                              kernel_regularizer=None,
                                              bias_regularizer=None)
                layer.build(input_shape=input_shape)
                self.intermediate_layers.append(layer)
                input_shape = [Bx, layer_size]

            # ...then apply them, with shared weights, to both the x and y inputs.
            self.layer_tensors = []
            top_level_input_x = self.t_x_ph
            top_level_input_y = self.t_y_ph
            for layer in self.intermediate_layers:
                top_level_input_x = layer.call(top_level_input_x)
                top_level_input_y = layer.call(top_level_input_y)
                self.layer_tensors.append((top_level_input_x, top_level_input_y))

            self.dense_layer = tf.keras.layers.Dense(units=K,
                                                     activation=None,
                                                     kernel_regularizer=None,
                                                     bias_regularizer=None)
            self.dense_layer.build(input_shape=top_level_input_x.shape)
            # From here on the list also holds the final layer, so the weight
            # penalties below and get_weights() cover every layer.
            self.intermediate_layers.append(self.dense_layer)

            def tf_sum_mean(v):
                # Sum over axis 0, then average whatever remains.
                return tf.reduce_mean(tf.reduce_sum(v, axis=0))

            # variables[0] is presumably the Dense kernel (the bias follows it).
            kernels = [layer.variables[0] for layer in self.intermediate_layers]
            self.t_weight_l1 = sum(tf_sum_mean(tf.math.abs(w)) for w in kernels)
            self.t_weight_l2 = sum(tf_sum_mean(w ** 2) for w in kernels)
            self.t_weight_l2_infty = sum(
                tf_sum_mean(tf.reduce_mean(w ** 2, axis=0)) for w in kernels
            )

            ############## operators

            # out_activation is usually tf.sigmoid.
            self.t_partition_logits_per_x = self.dense_layer.call(top_level_input_x)
            self.t_partition_sigmoids_per_x = self.out_activation(self.t_partition_logits_per_x)

            self.t_partition_logits_per_y = self.dense_layer.call(top_level_input_y)
            self.t_partition_sigmoids_per_y = self.out_activation(self.t_partition_logits_per_y)

            self.t_dyn_Bx, self.t_dyn_By, self.d_xy = self.get_pairwise_distances(
                self.t_partition_sigmoids_per_x,
                self.t_partition_sigmoids_per_y,
            )
            pair_shape = [self.t_dyn_Bx, self.t_dyn_By]

            # 1.0 where the x label equals the y label, else 0.0; shape [Bx, By].
            self.t_label_eq_mask = tf.cast(
                tf.equal(
                    tf.broadcast_to(tf.expand_dims(self.t_x_labels_ph, 1), pair_shape),
                    tf.broadcast_to(tf.expand_dims(self.t_y_labels_ph, 0), pair_shape),
                ),
                tf.float32,
            )

            # d_xy is a reduction of squares, so it is >= 0 and this relu is
            # an identity; it is kept so a threshold could be added later.
            self.positive_vals_all = tf.nn.relu(self.d_xy)
            # Different-label pairs stop costing once d reaches the threshold.
            self.negative_vals_all = tf.nn.relu(self.neg_upper_threshold - self.d_xy)

            self.positive_vals = self.t_label_eq_mask * self.positive_vals_all
            self.negative_vals = (1. - self.t_label_eq_mask) * self.negative_vals_all

            # Per-pair loss, as described in the paper.
            self.total_out = (
                self.pairwise_cost_pos_multiplier * self.positive_vals + self.negative_vals
            )

            # Mean over y for each x, then mean over x: the mean over all pairs.
            self.out_tensor_orig = (
                tf.reduce_sum(self.total_out, axis=1) / tf.cast(self.t_dyn_By, tf.float32)
            )
            self.t_value_orig = tf.reduce_mean(self.out_tensor_orig)

            # Aliases kept for historical reasons.
            self.out_tensor = self.out_tensor_orig
            self.t_value = self.t_value_orig

            # Objective minimized by the optimizer.
            self.t_loss = self.t_value + (
                self.t_reg_l1_pl * self.t_weight_l1
                + self.t_reg_l2_pl * self.t_weight_l2
                + self.t_reg_l2infty_pl * self.t_weight_l2_infty
            )

            self.t_optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=self.t_learning_rate_pl
            )
            self.t_train_op = self.t_optimizer.minimize(self.t_loss)

            self.t_init = tf.compat.v1.global_variables_initializer()

    def fit(self, feat, lbl, num_epochs=None):
        """Train for ``num_epochs`` epochs (default ``self.fit_epochs``).

        ``feat`` and ``lbl`` are handed to two ``Batcher`` instances.  One
        yields x batches of ``fit_Bx`` rows and the other yields y batches of
        ``fit_By`` rows.

        An epoch ends when the x batcher reports the end of its pass.  The y
        batcher's end-of-epoch flag is ignored, and the final x batch is still
        trained on.  If ``fit_cb`` is set, ``fit_cb.cb(self, feed_dict)`` runs
        after every step.
        """
        xbatcher = Batcher(self.fit_Bx, feat, lbl)
        ybatcher = Batcher(self.fit_By, feat, lbl)

        if num_epochs is None:
            num_epochs = self.fit_epochs

        for _ in range(num_epochs):
            self.batch_counter = 0
            epoch_end = False
            while not epoch_end:
                xbatch, xlabels, epoch_end = xbatcher.get_batch()
                ybatch, ylabels, _ = ybatcher.get_batch()

                feed_dict = {self.t_x_ph: xbatch,
                             self.t_x_labels_ph: xlabels,
                             self.t_y_ph: ybatch,
                             self.t_y_labels_ph: ylabels,
                             self.t_reg_l1_pl: self.l1_reg,
                             self.t_reg_l2_pl: self.l2_reg,
                             self.t_reg_l2infty_pl: self.l2_infty_reg,
                             self.t_learning_rate_pl: self.learning_rate,
                             }

                # Kept on the instance for debugging.
                self.curr_feed_dict = feed_dict

                self.sess.run([self.t_train_op], feed_dict=feed_dict)

                if self.fit_cb is not None:
                    self.fit_cb.cb(self, feed_dict)

                self.batch_counter += 1

            self.global_epoch += 1


