import math

import numpy as np
import tensorflow as tf

from inits import glorot, zeros

def sparse_dropout(x, keep_prob, noise_shape):
    """Apply inverted dropout to the stored entries of a sparse tensor.

    Each stored entry of ``x`` survives with probability ``keep_prob``;
    survivors are rescaled by ``1 / keep_prob`` so the expected value is
    unchanged.

    Args:
        x: ``tf.SparseTensor`` whose entries are dropped.
        keep_prob: probability of keeping an entry (Python scalar or tensor).
        noise_shape: shape of the uniform noise; presumably the number of
            stored entries of ``x`` -- TODO confirm against the caller.

    Returns:
        A sparse tensor holding the surviving entries scaled by ``1 / keep_prob``.
    """
    # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0.
    keep_mask = tf.cast(tf.floor(keep_prob + tf.random_uniform(noise_shape)),
                        dtype=tf.bool)
    return tf.sparse_retain(x, keep_mask) * (1. / keep_prob)

def dot(x, y, sparse=False):
    """Multiply ``x`` by ``y`` with the sparse or the dense matmul kernel.

    Args:
        x: left operand; must be a ``tf.SparseTensor`` when ``sparse`` is true.
        y: dense right operand.
        sparse: use ``tf.sparse_tensor_dense_matmul`` instead of ``tf.matmul``.

    Returns:
        The matrix product of ``x`` and ``y``.
    """
    matmul = tf.sparse_tensor_dense_matmul if sparse else tf.matmul
    return matmul(x, y)

class IFSlayer:
    """Graph layer driven by an iterated function system (IFS) of two supports.

    The sparse supports ``support_up`` and ``support_low`` act as the two IFS
    maps, chosen with probabilities ``prob_vec[0]`` and ``prob_vec[1]``.

    Data flow:
        1. The sparse input is projected by the ``input_layer`` weight.
        2. It is pushed through ``ifs_layers_num`` IFS rounds. Round ``i``
           keeps ``2 ** i`` branch representations in
           ``fractal_representation['h_i']``.
        3. Each branch is weighted by the product of branch probabilities
           along its path (``prob_vec_iterate['p_i']``). The weighted sum is
           stored as ``ME['m_i']``.
        4. The rounds are combined by the representation layer.
        5. The result is propagated once more through both supports with the
           ``output_layer`` weight.

    Args:
        placeholders: dict providing ``'dropout'``, ``'support_up'``,
            ``'support_low'`` and ``'num_features_nonzero'``.
        input_dim: width of the input features.
        input_num: positive count used to scale the initial per-round weights
            (presumably the number of nodes -- TODO confirm).
        hide_dim: width of the hidden representation.
        output_dim: width of the layer output.
        act: activation applied after the input projection and after the
            output propagation.
        initial_value_of_p0: probability of the ``support_up`` branch
            (``support_low`` gets ``1 - p0``); only the starting point when
            ``learnable_p`` is true.
        ifs_layers_num: number of IFS rounds.
        ifs_layer_bias: add a per-support bias after every IFS step.
        output_bias: add a bias to the final output.
        learnable_p: make the branch probabilities trainable.
        learnable_r: make the per-round weights trainable.
        representation_layer_type: how the rounds are combined; one of
            ``'weighted_time_average'``, ``'time_average'``,
            ``'weighted_concatenation'``, ``'concatenation'``.

    Raises:
        ValueError: if ``representation_layer_type`` is not one of the above.
    """

    # Accepted values of ``representation_layer_type``.
    _REPRESENTATION_TYPES = ('weighted_time_average', 'time_average',
                             'weighted_concatenation', 'concatenation')

    def __init__(self, placeholders, input_dim, input_num, hide_dim, output_dim,
                 act=tf.nn.relu,
                 initial_value_of_p0=0.5,
                 ifs_layers_num=3,
                 ifs_layer_bias=False,
                 output_bias=False,
                 learnable_p=False,
                 learnable_r=True,
                 representation_layer_type='weighted_time_average'
                 ):
        # Reject unknown types up front. Otherwise 'output_layer' is never
        # created and __call__ fails later with an UnboundLocalError.
        if representation_layer_type not in self._REPRESENTATION_TYPES:
            raise ValueError(
                'unsupported representation_layer_type {!r}; expected one of {}'
                .format(representation_layer_type, self._REPRESENTATION_TYPES))

        self.vars = {}
        self.vars_bias = {}
        self.vars_p = {}
        self.dropout = placeholders['dropout']
        self.support_up = placeholders['support_up']
        self.support_low = placeholders['support_low']
        self.initial_value_of_p0 = initial_value_of_p0
        self.num_features_nonzero = placeholders['num_features_nonzero']
        self.act = act
        self.ifs_layers_num = ifs_layers_num
        self.ifs_layer_bias = ifs_layer_bias
        self.output_bias = output_bias
        self.learnable_p = learnable_p
        self.learnable_coefficient = learnable_r
        self.fractal_representation = {}  # 'h_i' -> list of 2**i branch representations
        self.prob_vec_iterate = {}  # 'p_i' -> list of 2**i branch path probabilities
        self.ME = {}  # 'm_i' -> mathematical expectation of fractal representation
        self.affine = [self.support_up, self.support_low]
        self.representation_layer_type = representation_layer_type

        self.vars['input_layer'] = glorot([input_dim, hide_dim], name='input_layer')

        if self.learnable_p:
            self.vars_p['p0'] = tf.Variable(self.initial_value_of_p0, dtype=tf.float32, name='p0')
            self.vars_p['p1'] = tf.Variable(1 - self.initial_value_of_p0, dtype=tf.float32, name='p1')
            rb_up = tf.nn.relu(self.vars_p['p0'])
            rb_low = tf.nn.relu(self.vars_p['p1'])
            # Normalize to a probability pair. The +0.1 floors keep each
            # branch probability strictly positive and the denominator non-zero.
            total = rb_up + rb_low + 0.2
            self.prob_vec = [tf.divide(rb_up + 0.1, total), tf.divide(rb_low + 0.1, total)]
        else:
            self.prob_vec = [self.initial_value_of_p0, 1 - self.initial_value_of_p0]

        # Concatenation stacks one hide_dim block per IFS round; averaging
        # keeps a single hide_dim block.
        if self.representation_layer_type in ('weighted_concatenation', 'concatenation'):
            output_in_dim = hide_dim * self.ifs_layers_num
        else:
            output_in_dim = hide_dim
        self.vars['output_layer'] = glorot([output_in_dim, output_dim], name='output_layer')

        # ln(n) + 0.577215664 (Euler-Mascheroni constant) is the asymptotic
        # approximation of the n-th harmonic number. Round i (1-based) starts
        # at weight 1 / c ** (i - 1).
        # math.log replaces np.math.log: np.math was only an alias of the
        # stdlib math module and was removed in NumPy 2.0. Staying with a
        # Python float (not np.log's float64) keeps tf.nn.relu(weight) float32.
        coefficient_of_expansion = (math.log(input_num) + 0.577215664) ** 0.5
        initializing_weight_point = [1 / (coefficient_of_expansion ** i)
                                     for i in range(self.ifs_layers_num)]
        for i in range(1, self.ifs_layers_num + 1):
            weight_name = 'ifs_layer_weight_{}'.format(i)
            initial_weight = initializing_weight_point[i - 1]
            if self.learnable_coefficient:
                self.vars[weight_name] = tf.Variable(initial_weight, dtype=tf.float32,
                                                     name=weight_name)
            else:
                self.vars[weight_name] = initial_weight

        if self.output_bias:
            self.vars_bias['output_bias'] = zeros([output_dim], name='output_bias')

        if self.ifs_layer_bias:
            self.vars_bias['ifs_bias_up'] = zeros([hide_dim], name='ifs_bias_up')
            self.vars_bias['ifs_bias_low'] = zeros([hide_dim], name='ifs_bias_low')

    def _ifs_bias(self, affine):
        """Return the IFS bias belonging to the support ``affine``.

        Supports are matched by identity (``is``). ``==`` on TensorFlow
        tensors is not guaranteed to be an identity test: with TF2 tensor
        equality it builds an elementwise op instead.

        Raises:
            ValueError: if ``affine`` is neither ``support_up`` nor ``support_low``.
        """
        if affine is self.affine[0]:
            return self.vars_bias['ifs_bias_up']
        if affine is self.affine[1]:
            return self.vars_bias['ifs_bias_low']
        raise ValueError('affine must be support_up or support_low of this layer')

    def input_and_init_ifs_iterate(self, affine, inputs, weight, ifs_bais=False):
        """Project the sparse input and apply the first IFS step with ``affine``.

        Args:
            affine: one of the two supports in ``self.affine``.
            inputs: sparse input features; dropout is applied to their stored entries.
            weight: input projection weight.
            ifs_bais: add the bias belonging to ``affine``.

        Returns:
            ``affine @ act(dropout(inputs) @ weight)``, plus the bias if requested.
        """
        inputs = sparse_dropout(inputs, 1 - self.dropout, self.num_features_nonzero)
        inputs_dot_weight = self.act(dot(inputs, weight, sparse=True))
        output = dot(affine, inputs_dot_weight, sparse=True)
        if ifs_bais:
            output = tf.add(output, self._ifs_bias(affine))
        return output

    def ifs_iterate(self, affine, inputs, ifs_bais=False):
        """Apply dropout to dense ``inputs``, then one IFS step with ``affine``.

        Args:
            affine: one of the two supports in ``self.affine``.
            inputs: dense branch representation from the previous round.
            ifs_bais: add the bias belonging to ``affine``.

        Returns:
            ``affine @ dropout(inputs)``, plus the bias if requested.
        """
        # keep_prob passed by keyword so the TF1 semantics are explicit.
        inputs = tf.nn.dropout(inputs, keep_prob=1 - self.dropout)
        output = dot(affine, inputs, sparse=True)
        if ifs_bais:
            output = output + self._ifs_bias(affine)
        return output

    def ifs_iterate_with_train_weight(self, affine, inputs, weight):
        """Return ``act(affine @ (dropout(inputs) @ weight))`` for dense ``inputs``."""
        inputs = tf.nn.dropout(inputs, keep_prob=1 - self.dropout)
        inputs_dot_weight = dot(inputs, weight, sparse=False)
        output = dot(affine, inputs_dot_weight, sparse=True)
        return self.act(output)

    def __call__(self, inputs):
        """Build the layer graph for the sparse ``inputs`` and return its output.

        Side effect: repopulates ``fractal_representation``,
        ``prob_vec_iterate`` and ``ME`` with the per-round tensors.
        """
        ######## input layer and IFS layer
        self.fractal_representation['h_1'] = [
            self.input_and_init_ifs_iterate(affine, inputs, self.vars['input_layer'],
                                            ifs_bais=self.ifs_layer_bias)
            for affine in self.affine]
        self.prob_vec_iterate['p_1'] = self.prob_vec
        self.ME['m_1'] = (self.prob_vec_iterate['p_1'][0] * self.fractal_representation['h_1'][0]
                          + self.prob_vec_iterate['p_1'][1] * self.fractal_representation['h_1'][1])
        for i in range(2, self.ifs_layers_num + 1):
            prev_probs = self.prob_vec_iterate['p_{}'.format(i - 1)]
            prev_reprs = self.fractal_representation['h_{}'.format(i - 1)]
            # Both lists use the same branch order: outer loop over the two
            # supports, inner loop over the previous round's branches.
            probs = [p * pb for p in self.prob_vec for pb in prev_probs]
            reprs = [self.ifs_iterate(affine, h, ifs_bais=self.ifs_layer_bias)
                     for affine in self.affine for h in prev_reprs]
            self.prob_vec_iterate['p_{}'.format(i)] = probs
            self.fractal_representation['h_{}'.format(i)] = reprs
            self.ME['m_{}'.format(i)] = sum(p * h for p, h in zip(probs, reprs))

        ######## Representation layer
        rounds = range(1, self.ifs_layers_num + 1)
        if self.representation_layer_type in ('weighted_time_average', 'weighted_concatenation'):
            # relu keeps the (possibly trainable) per-round weights non-negative.
            terms = [tf.nn.relu(self.vars['ifs_layer_weight_{}'.format(i)]) * self.ME['m_{}'.format(i)]
                     for i in rounds]
        else:
            terms = [self.ME['m_{}'.format(i)] for i in rounds]

        if self.representation_layer_type == 'weighted_time_average':
            ergodic_representation = sum(terms)
        elif self.representation_layer_type == 'time_average':
            ergodic_representation = tf.reduce_mean(terms, 0)
        else:  # 'weighted_concatenation' / 'concatenation'
            ergodic_representation = tf.concat(terms, 1)

        ######## output layer
        ### mixed propagation
        mix_up = self.ifs_iterate_with_train_weight(self.support_up, ergodic_representation,
                                                    self.vars['output_layer'])
        mix_low = self.ifs_iterate_with_train_weight(self.support_low, ergodic_representation,
                                                     self.vars['output_layer'])
        # output = tf.reduce_mean([mix_up,mix_low], 0)
        output = self.prob_vec[0] * mix_up + self.prob_vec[1] * mix_low
        ### MLP
        # output = tf.matmul(ergodic_representation,self.vars['output_layer'])
        if self.output_bias:
            output += self.vars_bias['output_bias']
        return output