import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np



class residual_block(layers.Layer):
    '''
    a simple residual block: two stacked relu dense layers plus a skip connection

    computes l2_normalize(inputs + dense2(dense1(inputs)))

    arguments

    dim: width of the last input dimension; both dense layers also output `dim`
         units so that the skip connection can be added element-wise
    **kwargs: forwarded to the keras Layer constructor (e.g. name)

    input

    tensor whose last dimension equals dim

    output

    tensor with the same shape as the input
    '''

    def __init__(self, dim, **kwargs):
        super(residual_block, self).__init__(**kwargs)
        self.dim = dim
        self.dense1 = layers.Dense(dim, activation = 'relu', name = 'dense_1')
        self.dense2 = layers.Dense(dim, activation = 'relu', name = 'dense_2')

    def call(self, inputs):

        # the element-wise skip connection below needs matching widths
        assert inputs.shape[-1] == self.dim

        outputs = self.dense1(inputs)
        # chain the layers: dense2 must consume dense1's activations
        # (feeding `inputs` again would silently drop dense1 from the graph)
        outputs = self.dense2(outputs)

        outputs = tf.math.add(inputs, outputs)

        # NOTE(review): l2_normalize is called without an axis, so the norm is
        # taken over the whole tensor rather than per sample -- confirm intended
        return tf.math.l2_normalize(outputs)

class affine_coupling_layer(layers.Layer):

    '''
    implements affine coupling (invertible) transformation for samples from auxiliary distribution, e.g. uniform, standard gaussian
    (1). element-wise affine map x = s * x_0 + t; s and t are trainable when
         actnorm is true, otherwise they are the constants 1 and 0 (identity)
    (2). divide input samples into halves or a random split, e.g. x = [x_1, x_2]
         where x_1 holds the first d columns
    (3). h_1 = x_1 * exp(f1(x_2)) + g1(x_2)
         h_2 = x_2 * exp(f2(h_1)) + g2(h_1)
         f_1, f_2, g_1, g_2 are feed-forward neural nets, and * and + are element-wise, or residual block
         note that the second half is conditioned on the already updated h_1
    (4). return [h_1, h_2][:out_dim], out_dim is the required number of samples for output

    when blocks > 1, step (2)-(3) is repeated with the same four networks
    (weights are shared across repetitions)

    arguments

    in_dim: number of auxiliary samples
    inter_dims: list of hidden dimensions for f_1, f_2, g_1, g_2
                if not a list, use res block instead
    out_dim: required number of samples (must not exceed in_dim)
    divide: divide method, i.e. into halves or randomly split x into x_1, x_2
            an odd in_dim always falls back to the random split
    blocks: number of times the coupling transformation is applied
    actnorm: whether to use actnorm transformation
    name: layer name
    l2_reg: l2 penalty on the kernel of the final dense layer of each network

    input

    samples from the auxiliary distribution, shape [n, in_dim]

    output

    transformed samples, shape [n, out_dim]
    '''

    def __init__(self, in_dim, inter_dims, out_dim, divide = 'half', blocks = 1, actnorm = False, name = None, l2_reg = 0.001):
        # name must be passed by keyword: the first positional parameter of
        # keras' Layer.__init__ is not the name
        super(affine_coupling_layer, self).__init__(name = name)

        assert divide in ['half', 'random']
        assert in_dim >= out_dim
        assert blocks >= 1

        self.inter_dims = inter_dims
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.blocks = blocks
        self.actnorm = actnorm

        if self.actnorm:
            self.s = self.add_weight(name='s_var', shape=[self.in_dim],
                                     initializer='ones', trainable=True)
            self.t = self.add_weight(name='t_var', shape=[self.in_dim],
                                     initializer='zeros', trainable=True)
        else:
            self.s = tf.constant([1.] * self.in_dim, name='s_var')
            self.t = tf.constant([0.] * self.in_dim, name='t_var')

        # an odd width cannot be split into equal halves
        if in_dim % 2 != 0:
            divide = 'random'

        if divide == 'half':
            self.d = in_dim // 2
        else:
            # random split point in [1, in_dim - 1], fixed at construction time;
            # stored as a plain python int
            self.d = int(np.random.choice(range(1, in_dim), 1)[0])

        first_width = self.d                 # width of x_1
        second_width = self.in_dim - self.d  # width of x_2

        # the *_1 networks read x_2 and produce the scale/shift applied to x_1;
        # the *_2 networks read h_1 and produce the scale/shift applied to x_2.
        # the local `inter_dims` is passed on (not self.inter_dims) because keras
        # may wrap list attributes in a tracking container.
        self.t_model_1 = self._build_subnet('t_1', inter_dims, second_width, first_width, l2_reg)
        self.s_model_1 = self._build_subnet('s_1', inter_dims, second_width, first_width, l2_reg)

        self.t_model_2 = self._build_subnet('t_2', inter_dims, first_width, second_width, l2_reg)
        self.s_model_2 = self._build_subnet('s_2', inter_dims, first_width, second_width, l2_reg)

    def _build_subnet(self, prefix, inter_dims, cond_dim, target_dim, l2_reg):
        # build one conditioning network mapping a cond_dim-wide part of the
        # input to the target_dim-wide scale (or shift) for the other part
        net = tf.keras.Sequential(name = prefix + '_model')

        if isinstance(inter_dims, list):
            for i, dim in enumerate(inter_dims):
                net.add(layers.Dense(dim, activation='sigmoid', name = prefix + '_dense_' + str(i+1)))
        else:
            # a residual block keeps the width of what it receives, so it has to
            # be sized by the conditioning input, not by the part being transformed
            net.add(residual_block(cond_dim, name = prefix + '_res'))

        net.add(layers.Dense(target_dim, name = prefix + '_dense_final',
                             kernel_regularizer = tf.keras.regularizers.l2(l2_reg)))
        return net

    def call(self, inputs):

        assert inputs.shape[1] == self.in_dim
        # step (1): element-wise affine map (identity unless actnorm is on)
        inputs = tf.math.add(tf.math.multiply(inputs, self.s), self.t)

        for _ in range(self.blocks):

            inputs_1 = inputs[:,:self.d]
            inputs_2 = inputs[:,self.d:]

            # update the first part conditioned on the second
            s_1 = self.s_model_1(inputs_2)
            s_1 = tf.math.exp(s_1)
            t_1 = self.t_model_1(inputs_2)
            inputs_1 = tf.math.add(tf.math.multiply(inputs_1, s_1), t_1)

            # update the second part conditioned on the updated first part
            s_2 = self.s_model_2(inputs_1)
            s_2 = tf.math.exp(s_2)
            t_2 = self.t_model_2(inputs_1)
            inputs_2 = tf.math.add(tf.math.multiply(inputs_2, s_2), t_2)

            inputs = tf.concat((inputs_1, inputs_2), axis = -1)

        # keep only the leading out_dim columns
        return inputs[:,:self.out_dim]
    
    
class rand_feat_layer(layers.Layer):

    '''
    the sampled-and-held-fixed Fourier random feature layer
    (1). sample x from uniform or standard normal distribution (once, at
         construction time; the samples are stored as a constant)
    (2). transform x into z by (a). affine coupling (invertible) layer
                               (b). MLP (non-invertible)
    (3). take input of type [h, t], where h is the hidden outputs from the former layer
         and t is the timepoint vector for the batch
    (4). construct the fourier random features sin(t*z) and cos(t*z) or
         sin([h,t]*z) and cos([h,t]*z) depending on the 'time_only' argument,
         each scaled by 1 / sqrt(time_dim / 2)
    (5). return [h, fourier_features] ('concat'), or the flattened outer product
         of the fourier features with h ('add')

    arguments

    in_dim: dimension of the inputs h (hidden factors from the former layer)
    combine: how to combine h and the fourier features, i.e. add or concat
             (forced to 'add' when time_dim is not provided)
    time_dim: dimension of the fourier features (divided by 2 for cos and sin features)
              if None, in_dim is used and must be even
    intercept: whether or not to use trainable intercept terms b in sin(t*z + b) as fourier features
    time_only: using only time t to construct fourier features or use [h,t]
    invert_layer: whether to use affine coupling layer (otherwise use MLP) to transform samples
    distribution: auxiliary distribution to sample from
    init_dim: number of samples from the auxiliary distribution, i.e. dimension of x
    num_affine_block: number of affine coupling blocks
    inter_dims: hidden dimensions handed to the affine coupling layer
    actnorm: actnorm flag handed to the affine coupling layer
    l2_reg: l2 penalty handed to the affine coupling layer / MLP dense layers
    name: layer name

    input

    hidden units h from the former layer and the corresponding time points t

    output

    [hidden outputs from the former layer, random Fourier features (sin and cos)]
    for 'concat'; see (5) for 'add'

    '''

    def __init__(self, in_dim,
                 combine = 'concat',
                 time_dim = None,
                 intercept=True,
                 time_only=False,
                 invert_layer=True,
                 distribution = 'normal',
                 init_dim = 100,
                 num_affine_block = 1,
                 inter_dims = [32,16],
                 actnorm = False,
                 l2_reg = 0,
                 name = None):
        # name must be passed by keyword: the first positional parameter of
        # keras' Layer.__init__ is not the name
        super(rand_feat_layer, self).__init__(name = name)

        self.in_dim = in_dim
        self.time_only = time_only
        self.num_affine_block = num_affine_block

        assert num_affine_block >= 1
        assert combine in ['concat', 'add']
        assert distribution in ['normal', 'uniform']

        if time_dim is None:
            assert in_dim % 2 == 0
            print('time dim not provided')
            self.combine = 'add'
            self.feat_dim = in_dim // 2
            # kept for compatibility; call() does not use it at the moment
            self.raw_feat_affine = layers.Dense(in_dim)
        else:
            assert time_dim % 2 == 0
            self.combine = combine
            self.feat_dim = time_dim // 2

        if intercept:
            self.b1 = self.add_weight(name='b_1', shape=[self.feat_dim], initializer='zeros', trainable=True)
            self.b2 = self.add_weight(name='b_2', shape=[self.feat_dim], initializer='zeros', trainable=True)
        else:
            self.b1 = tf.constant([0.] * (self.feat_dim), name='b_1')
            self.b2 = tf.constant([0.] * (self.feat_dim), name='b_2')

        # width of each frequency vector z: one entry per column of the
        # tensor it is multiplied with ([h, t] or just t)
        freq_dim = 1 if time_only else self.in_dim + 1

        if invert_layer:
            self.transform_layer = affine_coupling_layer(init_dim, inter_dims,
                                                         freq_dim, blocks = self.num_affine_block,
                                                         actnorm = actnorm,
                                                         name = 'affine_coupling_layer',
                                                         l2_reg = l2_reg)
        else:
            self.transform_layer = tf.keras.Sequential(name='mlp_transform_layer')
            self.transform_layer.add(layers.Dense(32, activation='relu', name='mlp_dense_1', kernel_regularizer = tf.keras.regularizers.l2(l2_reg)))
            self.transform_layer.add(layers.Dense(16, activation='relu', name='mlp_dense_2', kernel_regularizer = tf.keras.regularizers.l2(l2_reg)))
            self.transform_layer.add(layers.Dense(freq_dim, name='mlp_final', kernel_regularizer = tf.keras.regularizers.l2(l2_reg)))

        # one row of auxiliary samples per fourier feature
        if distribution == 'uniform':
            self.rand_feat = tf.constant(np.random.uniform(low=-1, high=1, size=(self.feat_dim, init_dim)).astype(np.float32),
                                        name='rand_feat_init')
        else:
            self.rand_feat = tf.constant(np.random.randn(self.feat_dim, init_dim).astype(np.float32),
                                        name='rand_feat_init')


    def call(self, inputs, time_inputs):

        # frequencies z, one row per fourier feature
        rand_weights = self.transform_layer(self.rand_feat)

        if not self.time_only:
            full_inputs = tf.concat((inputs, time_inputs), axis=-1)
        else:
            full_inputs = time_inputs

        # the projection onto the frequencies is shared by the sin and cos
        # features, so it is computed once
        projection = tf.linalg.matmul(full_inputs, rand_weights, transpose_b=True)
        scale = np.sqrt(self.feat_dim)

        sin_feat = tf.math.sin(tf.add(projection, self.b1)) / scale
        cos_feat = tf.math.cos(tf.add(projection, self.b2)) / scale

        if self.combine == 'add':
            #inputs = self.raw_feat_affine(inputs)
            # the reshape below indexes shape[1..3], so this branch expects
            # rank-3 inputs [B, L, dim]
            time_feat = tf.concat((sin_feat, cos_feat), axis=-1) #[B, L, time_dim]
            time = tf.expand_dims(time_feat, -1) #[B, L, time_dim, 1]
            inputs = tf.expand_dims(inputs, -2) #[B, L, 1, dim]
            out = tf.linalg.matmul(time, inputs) #[B, L, time_dim, dim]
            shape = out.shape
            out = tf.reshape(out, [-1, shape[1], shape[2] * shape[3]])
            return out
        else:
            return tf.concat((inputs, sin_feat, cos_feat), axis=-1)
        


# class TimeCNNTransform(tf.keras.layers.Layer):
    
#     def __init__(self, in_dim = in_dim, name=None, **kwargs):
#         super(TimeCNNTransform, self).__init__(name=name)
        
#         self.transform = rand_feat_layer(in_dim=in_dim, **kwargs)
        
#     def call(self, inputs):
        
#         inputs = inputs[:,:,:-1]
#         time_inputs = tf.expand_dims(inputs[:,:,-1], axis=-1)
        
#         outputs = self.transform(inputs, time_inputs)
#         return outputs
        



class TimeRNNCell(tf.keras.layers.Layer):

    '''

                         h1(t1)                    h2(t2)
                            |                       |
    rnn_cell1 -> (s1) -> (s1(t1)) -> rnn_cell2 -> ...
        |                   |           |           |
        x1                  t1          x2          t2


    takes [batch, [x, t]] as the per-timestep input (time is the last column)
    returns output (at this timestep), state (pass to the next timestep)
    for now, output and (hidden) state are set to be equal; an lstm cell
    additionally carries its own cell state unchanged

    usage is basically the same as simpleRNNCell:
    can be organized into RNN by tf.keras.layers.RNN(TimeRNNCell)

    arguments

    output_dim: dimension of the hidden units (output and state) in rnn
    name: layer name
    cell_type: rnn, gru or lstm cell
    combine: 'concat' or 'add', passed to rand_feat_layer; with 'concat' a dense
             layer maps the concatenated features back to output_dim
    the rest arguments are the same as rand_feat_layer

    NOTE(review): with combine='add' rand_feat_layer reshapes as if its input
    were rank 3, while this cell hands it rank-2 tensors -- confirm that mode
    is actually usable here
    '''

    def __init__(self, output_dim, name = None,
                 cell_type = 'rnn', combine = 'concat', **kwargs):
        # name must be passed by keyword: the first positional parameter of
        # keras' Layer.__init__ is not the name
        super(TimeRNNCell, self).__init__(name = name)

        assert cell_type in ['gru', 'lstm', 'rnn']

        self.combine = combine
        self.cell_type = cell_type

        # RNNCell are required to have state_size and output_size
        if cell_type == 'lstm':
            # an lstm carries two states (hidden and cell), one size per state
            self.state_size = [output_dim, output_dim]
        else:
            self.state_size = tf.TensorShape([output_dim])
        self.output_size = tf.TensorShape([output_dim])

        if self.combine == 'add':
            assert output_dim % 2 == 0

        if cell_type == 'rnn':
            self.rnn = tf.keras.layers.SimpleRNNCell(units = output_dim, name='simple_rnn')
        elif cell_type == 'gru':
            self.rnn = tf.keras.layers.GRUCell(units = output_dim, name='gru')
        else:
            self.rnn = tf.keras.layers.LSTMCell(units = output_dim, name='lstm')

        self.time_layer = rand_feat_layer(in_dim = output_dim,
                                          combine = combine,
                                          name = 'time_layer',
                                          **kwargs)

        if self.combine == 'concat':
            # use a feed-forward layer to project [h, fourier features] back to output_dim
            self.fnn = tf.keras.Sequential(name='add_method_ffn')
            self.fnn.add(layers.Dense(output_dim, name='add_method_dense'))


    def call(self, inputs, states):

        # the last input column is the time point, the rest feeds the rnn cell
        hidden_inputs = inputs[:,:-1]
        time_inputs = tf.expand_dims(inputs[:,-1], axis=-1)
        rnn_outputs, new_states = self.rnn(inputs = hidden_inputs, states = states)

        rnn_rand_feat_states = self.time_layer(rnn_outputs, time_inputs)

        if self.combine == 'concat':
            rnn_rand_feat_states = self.fnn(rnn_rand_feat_states)

        if self.cell_type == 'lstm':
            # the time-aware output replaces the hidden state; the lstm cell
            # state is passed on as produced by the inner cell
            return rnn_rand_feat_states, [rnn_rand_feat_states, new_states[1]]

        return rnn_rand_feat_states, [rnn_rand_feat_states]

        
        
class BaseRNNCell(tf.keras.layers.Layer):

    '''
     [h1,t1]             [h2,t2]
        |                   |
    rnn_cell1 -> (s1) -> rnn_cell2 -> ...
        |                   |
     [x1,t1]             [x2,t2]

    An extension of ordinary rnn cell, e.g. simpleRNNCell, GRUCell

    takes [batch, [x, t]] as the per-timestep input (time is the last column),
    with the additional feature that t is not involved in RNN computing, but
    directly appended to the outputs (not states) and passed to the next layer.

    usage is basically the same as simpleRNNCell:
    can be organized into RNN by tf.keras.layers.RNN(BaseRNNCell)

    arguments

    output_dim: dimension of the hidden units (state) in rnn; the emitted
                output has output_dim + 1 columns because t is appended
    cell_type: rnn, gru or lstm cell
    name: layer name
    '''

    def __init__(self, output_dim, cell_type = 'rnn', name = None):
        # name must be passed by keyword: the first positional parameter of
        # keras' Layer.__init__ is not the name
        super(BaseRNNCell, self).__init__(name = name)

        assert cell_type in ['rnn', 'gru', 'lstm']
        self.cell_type = cell_type

        if cell_type == 'lstm':
            # an lstm carries two states (hidden and cell), one size per state
            self.state_size = [output_dim, output_dim]
        else:
            self.state_size = tf.TensorShape([output_dim])
        # call() appends the time column to the rnn output
        self.output_size = tf.TensorShape([output_dim + 1])

        if cell_type == 'rnn':
            self.rnn = tf.keras.layers.SimpleRNNCell(units = output_dim, name='simple_rnn')
        elif cell_type == 'gru':
            self.rnn = tf.keras.layers.GRUCell(units = output_dim, name='gru')
        else:
            self.rnn = tf.keras.layers.LSTMCell(units = output_dim, name='lstm')


    def call(self, inputs, states):
        # the last input column is the time point, the rest feeds the rnn cell
        hidden_inputs = inputs[:,:-1]
        time_inputs = tf.reshape(inputs[:,-1], (-1,1))
        rnn_outputs, new_states = self.rnn(inputs = hidden_inputs, states = states)

        outputs = tf.concat((rnn_outputs, time_inputs), axis = -1)

        if self.cell_type == 'lstm':
            # keep both the hidden and the cell state of the inner lstm cell
            return outputs, [new_states[0], new_states[1]]

        return outputs, [rnn_outputs]
    
    
    
    
class RNNOutputTransform(tf.keras.layers.Layer):
    '''
      h1(t1)               h2(t2)
        |                   |
    -------------------------------
         RNNOutputTransform
    -------------------------------
     [h1,t1]             [h2,t2]
        |                   |
    rnn_cell1 -> (s1) -> rnn_cell2 -> ...

    turns a sequence of [h, t] vectors (time in the last column) into
    time-aware features through a rand_feat_layer, optionally followed by a
    stack of dense layers

    arguments

    in_dim: dimension of the hidden part h of the input
    use_ffn: whether to apply dense layers after the random feature layer
    ffn_dims: non-empty list of dense layer widths, required when use_ffn is true
    name: layer name
    the rest arguments are passed to rand_feat_layer

    input

    tensor of shape [batch, timestep, in_dim + 1]

    output

    output of rand_feat_layer, passed through the dense layers if use_ffn is true
    '''

    def __init__(self, in_dim, use_ffn = False, ffn_dims = None, name=None, **kwargs):

        # name must be passed by keyword: the first positional parameter of
        # keras' Layer.__init__ is not the name
        super(RNNOutputTransform, self).__init__(name = name)

        self.rand_feat_layer = rand_feat_layer(in_dim=in_dim, name='rand_feat_layer', **kwargs)

        self.use_ffn = use_ffn
        if self.use_ffn:
            # validate before building anything
            assert isinstance(ffn_dims, list)
            assert len(ffn_dims) >= 1
            self.ffn = tf.keras.Sequential(name='ffn_model')
            for i, dim in enumerate(ffn_dims):
                self.ffn.add(tf.keras.layers.Dense(dim, name='ffn_dense_'+str(i)))

    def call(self, inputs):

        # split off the time column (last feature) from the hidden units
        hidden_inputs = inputs[:,:,:-1]
        time_inputs = tf.expand_dims(inputs[:,:,-1], axis = -1)
        outputs = self.rand_feat_layer(hidden_inputs, time_inputs)

        if self.use_ffn:
            outputs = self.ffn(outputs)

        return outputs