import tensorflow as tf
from tensorflow import keras
from typing import *

from tensorflow.python.ops.init_ops_v2 import _compute_fans
import deepxde as dde
import tensorflow_probability as tfp
import numpy as np

tfb = tfp.bijectors
tfd = tfp.distributions


class Sine(keras.layers.Layer):
    """Keras layer that applies ``sin(w0 * x)`` element-wise to its inputs."""

    def __init__(self, w0: float = 1.0, **kwargs):
        """
        Sine Activation.

        :param w0: factor the inputs are multiplied by before taking the sine.
        :type w0: float
        :param kwargs: extra keyword arguments forwarded to ``keras.layers.Layer``.
        :type kwargs: dict
        """
        super().__init__(**kwargs)
        self.w0 = w0

    def call(self, inputs):
        """Return ``tf.sin(w0 * inputs)``."""
        scaled = self.w0 * inputs
        return tf.sin(scaled)

    def get_config(self):
        """Return the base layer config extended with the ``w0`` entry."""
        return {**super().get_config(), 'w0': self.w0}


class SIRENFirstLayerInitializer(tf.keras.initializers.Initializer):
    """
    Uniform initializer on ``[-scale / fan_in, scale / fan_in]``.

    Registered in this module under the key ``'siren_first_uniform'`` and used
    as the default kernel initializer of the first SIREN layer.
    """

    def __init__(self, scale=1.0, seed=None):
        """
        :param scale: numerator of the sampling limit ``scale / fan_in``.
        :param seed: optional seed passed to ``tf.random.uniform``.
        """
        super().__init__()
        self.scale = scale
        self.seed = seed

    def __call__(self, shape, dtype=tf.float32, **kwargs):
        """
        Sample a tensor of the given shape.

        :param shape: shape of the tensor to create; also used to derive ``fan_in``.
        :param dtype: dtype of the returned tensor. ``None`` falls back to
            ``tf.float32``.
        :param kwargs: accepted and ignored, so callers that pass additional
            initializer keyword arguments do not fail.
        :return: tensor sampled uniformly from ``[-limit, limit]``.
        """
        if dtype is None:
            dtype = tf.float32
        fan_in, _ = _compute_fans(shape)
        # max(1.0, ...) guards against a division by zero for degenerate shapes.
        limit = self.scale / max(1.0, float(fan_in))
        # Honour the requested dtype; otherwise non-float32 layers would
        # receive float32 initial values.
        return tf.random.uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)

    def get_config(self):
        """Return the constructor arguments (``scale`` and ``seed``) on top of the base config."""
        return {**super().get_config(), 'scale': self.scale, 'seed': self.seed}


class SIRENInitializer(tf.keras.initializers.VarianceScaling):
    """
    SIREN kernel initializer built on top of ``VarianceScaling``.

    The parent is configured with ``mode='fan_in'``, ``distribution='uniform'``
    and ``scale = c / (3 * w0**2)``.
    """

    def __init__(self, w0: float = 1.0, c: float = 6.0, seed: int = None):
        """
        :param w0: frequency factor of the sine activation this initializer is paired with.
        :param c: constant controlling the width of the sampling interval.
        :param seed: optional random seed forwarded to ``VarianceScaling``.
        """
        # Uniform variance scaler multiplies by 3.0 for limits, so scale down here to compensate
        self.w0 = w0
        self.c = c
        scale = c / (3.0 * w0 * w0)
        super().__init__(scale=scale, mode='fan_in', distribution='uniform', seed=seed)

    def get_config(self):
        """Return the parent config extended with ``w0`` and ``c``."""
        return {**super().get_config(), 'w0': self.w0, 'c': self.c}

    @classmethod
    def from_config(cls, config):
        """
        Rebuild the initializer from a ``get_config()`` dictionary.

        ``get_config()`` also contains the parent's keys (e.g. ``scale``,
        ``mode``, ``distribution``), which ``__init__`` does not accept because
        they are derived from ``w0`` and ``c``. Only the real constructor
        arguments are picked here so that the round trip
        ``from_config(get_config())`` does not raise a ``TypeError``.
        """
        return cls(
            w0=config.get('w0', 1.0),
            c=config.get('c', 6.0),
            seed=config.get('seed'),
        )


class DenseSIREN(tf.keras.layers.Dense):
    """Dense layer preconfigured with a sine activation and SIREN weight initialization."""

    def __init__(self,
                 units,
                 w0: float = 1.0,
                 c: float = 6.0,
                 activation='sine',
                 use_bias=True,
                 kernel_initializer='siren_uniform',
                 bias_initializer='he_uniform',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """
        Sine Representation Dense layer. Extends Dense layer.
        Follows the paper "Implicit Neural Representations with Periodic
        Activation Functions", with a per-layer `w0` and an adjustable `c`.
        Args:
            units: Positive integer, dimensionality of the output space.
            w0: w0 in the activation step `act(x; w0) = sin(w0 * x)`.
            c: Constant handed to `SIRENInitializer` to scale the weight
                distribution.
            activation: Activation to use. The string `'sine'` (the default)
                is replaced by `Sine(w0=w0)`; any other value is passed to
                `Dense` unchanged.
            use_bias: Boolean, whether the layer uses a bias vector.
            kernel_initializer: Initializer for the `kernel` weights matrix.
                The string `'siren_uniform'` (the default) is replaced by
                `SIRENInitializer(w0=w0, c=c)`.
            bias_initializer: Initializer for the bias vector. The string
                `'siren_uniform'` is replaced by `SIRENInitializer(w0=w0, c=c)`.
            kernel_regularizer: Regularizer function applied to
                the `kernel` weights matrix.
            bias_regularizer: Regularizer function applied to the bias vector.
            activity_regularizer: Regularizer function applied to
                the output of the layer (its "activation").
            kernel_constraint: Constraint function applied to
                the `kernel` weights matrix.
            bias_constraint: Constraint function applied to the bias vector.
            **kwargs: Forwarded to `tf.keras.layers.Dense`.
        # References:
            -   [Implicit Neural Representations with Periodic Activation Functions](https://arxiv.org/abs/2006.09661)
        """
        self.w0 = float(w0)
        self.c = float(c)

        siren_key = 'siren_uniform'

        # Resolve the string shortcuts into concrete objects before handing
        # everything over to Dense.
        if activation == 'sine':
            activation = Sine(w0=w0)
        if kernel_initializer == siren_key:
            kernel_initializer = SIRENInitializer(w0=w0, c=c)
        if bias_initializer == siren_key:
            bias_initializer = SIRENInitializer(w0=w0, c=c)

        super().__init__(units=units,
                         activation=activation,
                         use_bias=use_bias,
                         kernel_initializer=kernel_initializer,
                         bias_initializer=bias_initializer,
                         kernel_regularizer=kernel_regularizer,
                         bias_regularizer=bias_regularizer,
                         activity_regularizer=activity_regularizer,
                         kernel_constraint=kernel_constraint,
                         bias_constraint=bias_constraint,
                         **kwargs)

    def get_config(self):
        """Return the `Dense` config extended with `w0` and `c`."""
        return {**super().get_config(), 'w0': self.w0, 'c': self.c}


def clip_output(x):
    """
    Cap ``x`` element-wise at ``1e13``.

    Hand-written replacement for ``tf.clip_by_value`` (which the original
    author found slow). Only an upper bound is applied. Elements for which
    ``x < 1e13`` is false are replaced by the cap.

    :param x: tensor (or tensor-like value) to clip.
    :type x: tf.Tensor
    :return: tensor with the same shape and dtype as ``x``.
    :rtype: tf.Tensor
    """
    x = tf.convert_to_tensor(x)
    max_val = 1e13
    # Build the cap once, in the dtype of ``x``: a float32 cap would make the
    # comparison fail for e.g. float64 inputs.
    cap = max_val * tf.ones(tf.shape(x), dtype=x.dtype)
    return tf.where(tf.less(x, cap), x, cap)


class SIREN(dde.nn.tensorflow.NN):
    def __init__(self,
                 units: int,
                 output_units: int,
                 output_activation: str = "linear",
                 num_layers: int = 1,
                 w0: float = 30.0,
                 w0_initial: float = 30.0,
                 initial_layer_init: str = 'siren_first_uniform',
                 include_first_layer=True,
                 use_bias: bool = True, **kwargs):
        """
        SIREN model from the paper [Implicit Neural Representations with Periodic Activation Functions](https://arxiv.org/abs/2006.09661).
        Stacks `num_layers` hidden `DenseSIREN` layers followed by one output
        `DenseSIREN` layer.
        Args:
            units: Number of hidden units in the intermediate layers.
            output_units: Number of units in the final layer.
            output_activation: Activation function of the final layer.
            num_layers: Number of hidden layers. One hidden layer is always
                created; `num_layers - 1` further ones are appended.
            w0: w0 in the activation step `act(x; w0) = sin(w0 * x)`, used by
                every hidden layer after the first.
            w0_initial: `w0` of the first hidden layer.
            initial_layer_init: Kernel initialization for the first SIREN layer.
                Can be any valid keras initialization object or string.
                For SIREN, use `siren_uniform` for the general initialization,
                or `siren_first_uniform` which is specific for first layer.
            include_first_layer: When false, `initial_layer_init` is ignored
                and the first layer uses `siren_uniform` instead.
            use_bias: Boolean whether to use bias or not.
            **kwargs: Forwarded to every `DenseSIREN` layer.
        # References:
            -   [Implicit Neural Representations with Periodic Activation Functions](https://arxiv.org/abs/2006.09661)
        """
        super().__init__()

        first_kernel_init = initial_layer_init if include_first_layer else "siren_uniform"
        hidden = [
            DenseSIREN(units, w0=w0_initial, use_bias=use_bias,
                       kernel_initializer=first_kernel_init, **kwargs)
        ]
        hidden.extend(
            DenseSIREN(units, w0=w0, use_bias=use_bias, **kwargs)
            for _ in range(num_layers - 1)
        )

        self.siren_layers = tf.keras.Sequential(hidden)
        # The output layer does not receive `w0`, so it keeps DenseSIREN's default.
        self.final_dense = DenseSIREN(output_units, activation=output_activation,
                                      use_bias=use_bias, **kwargs)

    def call(self, inputs, training=None, mask=None):
        """
        Run the network on `inputs`.

        The optional input transform is applied first. The optional output
        transform is called with the (already transformed) inputs and the raw
        network output. `mask` is not used.
        """
        if self._input_transform is not None:
            inputs = self._input_transform(inputs)

        hidden_features = self.siren_layers(inputs, training=training)
        output = self.final_dense(hidden_features, training=training)

        if self._output_transform is None:
            return output
        return self._output_transform(inputs, output)


def trainable_lu_factorization(event_size, batch_shape=(), seed=None, dtype=tf.float32, name=None):
    """
    Create variables holding the LU factorization of a random orthonormal matrix.

    A uniform random matrix of shape ``batch_shape + [event_size, event_size]``
    is orthonormalized with a QR decomposition and then LU-factorized.

    :param event_size: size of the square matrix.
    :param batch_shape: leading batch dimensions of the matrix.
    :param seed: optional seed for ``tf.random.uniform``.
    :param dtype: dtype of the random matrix.
    :param name: optional name scope; defaults to ``'trainable_lu_factorization'``.
    :return: tuple ``(lower_upper, permutation)`` of ``tf.Variable``s; the
        first is trainable, the second is not.
    """
    scope = name or 'trainable_lu_factorization'
    with tf.name_scope(scope):
        event_size = tf.convert_to_tensor(event_size, dtype_hint=tf.int32, name='event_size')
        batch_shape = tf.convert_to_tensor(batch_shape, dtype_hint=event_size.dtype, name='batch_shape')

        matrix_shape = tf.concat([batch_shape, [event_size, event_size]], axis=0)
        seed_matrix = tf.random.uniform(shape=matrix_shape, dtype=dtype, seed=seed)

        # Keep only the Q factor of the QR decomposition.
        orthonormal, _ = tf.linalg.qr(seed_matrix)
        lu_init, perm_init = tf.linalg.lu(orthonormal)

        lower_upper = tf.Variable(initial_value=lu_init, trainable=True, name='lower_upper')
        # The permutation indices live in a non-trainable variable so that
        # their value isn't re-sampled from run-to-run.
        permutation = tf.Variable(initial_value=perm_init, trainable=False, name='permutation')
        return lower_upper, permutation


def create_bijections(num_layers=3, hidden_units=128, ndim=2) -> List[tfb.Bijector]:
    """
    Return the list of bijectors making up a masked autoregressive flow.

    Each of the ``num_layers`` repetitions appends, in this order, a
    ``MaskedAutoregressiveFlow``, a ``Permute`` that reverses the dimension
    order, and a ``glow.ActivationNormalization``.

    :param num_layers: number of (MAF, permute, act-norm) repetitions.
    :param hidden_units: width of the two hidden layers of every autoregressive network.
    :param ndim: dimensionality of the events the flow operates on.
    :return: flat list with ``3 * num_layers`` bijectors.
    """
    # Reverse the order of the dimensions; for ndim == 2 this is [1, 0].
    reversed_order = list(range(ndim - 1, -1, -1))

    bijectors = []
    for _ in range(num_layers):
        # `params` is the number of outputs per event dimension, not the event
        # size: MaskedAutoregressiveFlow needs exactly two, a shift and a log-scale.
        made = tfb.AutoregressiveNetwork(
            params=2, hidden_units=[hidden_units, hidden_units],
            activation="relu"
        )
        bijectors.append(tfb.MaskedAutoregressiveFlow(made))

        # Now permute (!important!) so that the next autoregressive layer
        # conditions on the dimensions in a different order.
        bijectors.append(tfb.Permute(reversed_order))

        # NOTE: an invertible 1x1 convolution (tfb.ScaleMatvecLU built from
        # trainable_lu_factorization), a Tanh after the first layer and a
        # trainable Shift were tried here before and are currently disabled.
        bijectors.append(tfb.glow.ActivationNormalization(ndim))

    return bijectors


def build_flow(dims, num_layers=3, num_units=64):
    """
    Build the transformed distribution used as flow model.

    The base distribution is a diagonal multivariate normal with mean
    ``[0.0, -1.0]`` and scale ``[0.2, 0.5]``; it is pushed through the chain of
    bijectors returned by ``create_bijections``.

    NOTE(review): the base distribution is hard-coded with two components, so
    this presumably only works for ``dims == 2`` - confirm before using other
    dimensionalities.

    :param dims: event dimensionality handed to ``create_bijections`` as ``ndim``.
    :param num_layers: number of flow layers.
    :param num_units: hidden width of the autoregressive networks.
    :return: ``tfd.TransformedDistribution`` instance.
    """
    base = tfd.MultivariateNormalDiag(
        np.array([0.0, -1], dtype=np.float32),
        np.array([0.2, 0.5], dtype=np.float32),
    )
    chain = tfb.Chain(create_bijections(num_layers, hidden_units=num_units, ndim=dims))
    return tfd.TransformedDistribution(distribution=base, bijector=chain)

class Flow(dde.nn.tensorflow.NN):
    def __init__(self, units: int = 64, dim: int = 2, num_layers: int = 3, use_bias: bool = True, **kwargs):
        """
        Network wrapper around the normalizing flow created by `build_flow`.
        Args:
            units: Hidden width of the autoregressive networks
                (`num_units` of `build_flow`).
            dim: Event dimensionality (`dims` of `build_flow`).
            num_layers: Number of flow layers.
            use_bias: Accepted but not used by this class.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__()
        self.net = build_flow(dim, num_layers=num_layers, num_units=units)

    def call(self, inputs, training=None, mask=None):
        """
        Return the flow's log-probability of `inputs` with a trailing axis added.

        The optional input transform is applied first. The optional output
        transform is called with the (already transformed) inputs and the
        log-probabilities. `mask` is not used.
        """
        if self._input_transform is not None:
            inputs = self._input_transform(inputs)

        log_density = self.net.log_prob(inputs, training=training)
        output = log_density[..., tf.newaxis]

        if self._output_transform is None:
            return output
        return self._output_transform(inputs, output)


# Make the custom classes known to Keras.
# The lowercase keys are the string shortcuts accepted by the layers of this
# module (e.g. activation='sine', kernel_initializer='siren_uniform').
# The class-name keys are needed as well: Keras stores the class name when it
# serializes an object, so a saved config can only be resolved again when the
# class is registered under that name.
tf.keras.utils.get_custom_objects().update({
    'sine': Sine,
    'siren_uniform': SIRENInitializer,
    'siren_first_uniform': SIRENFirstLayerInitializer,
    'DenseSIREN': DenseSIREN,
    'Sine': Sine,
    'SIRENInitializer': SIRENInitializer,
    'SIRENFirstLayerInitializer': SIRENFirstLayerInitializer,
})
