"""Conditional RealNVP normalizing flow built from TFP bijectors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_probability import bijectors

import numpy as np

# Public API of this module.
__all__ = ["ConditionalRealNVPFlow"]


def _use_static_shape(input_tensor, ndims):
    """Decides whether static (graph-construction-time) shapes can be used.

    Args:
        input_tensor: Object exposing a `shape` attribute with an
            `is_fully_defined()` method (e.g. a `tf.Tensor`).
        ndims: Number of event dimensions; only a Python `int` counts as
            statically known.

    Returns:
        `True` iff the shape of `input_tensor` is fully defined and `ndims`
        is an instance of `int`.
    """
    if not input_tensor.shape.is_fully_defined():
        return False
    return isinstance(ndims, int)


class ConditionalChain(bijectors.ConditionalBijector, bijectors.Chain):
    """`bijectors.Chain` that also derives from `bijectors.ConditionalBijector`.

    Adds no behavior of its own; it only combines the two base classes.
    NOTE(review): `ConditionalRealNVPFlow` does not use it (it applies its
    bijectors with explicit loops); confirm whether it is used elsewhere.
    """


class ConditionalRealNVPFlow(bijectors.ConditionalBijector):
    """Conditional normalizing flow made of stacked RealNVP coupling layers.

    The flow is the list `self.flow` created by `build()`. It contains
    `num_coupling_layers` `bijectors.RealNVP` coupling layers. A
    `bijectors.Permute` that reverses the event dimensions sits between each
    pair of consecutive coupling layers.

    Every RealNVP layer receives the condition keyword arguments given to
    `forward` / `inverse` / the log-det-jacobian methods. Its
    shift/log-scale network (built by `conditioned_real_nvp_template`)
    concatenates them to its input.

    `forward` applies `self.flow` in list order; `inverse` applies it in
    reverse list order.
    """

    def __init__(self,
                 num_coupling_layers=2,
                 hidden_layer_sizes=(64,),
                 use_batch_normalization=False,
                 event_dims=None,
                 is_constant_jacobian=False,
                 validate_args=False,
                 name="conditional_real_nvp_flow"):
        """Instantiates the `ConditionalRealNVPFlow` normalizing flow.

        Args:
            num_coupling_layers: Python `int`, number of RealNVP coupling
                layers.
            hidden_layer_sizes: Iterable of Python `int`s, widths of the
                hidden dense layers of each coupling layer's
                shift/log-scale network.
            use_batch_normalization: Python `bool`. Must be `False`; batch
                normalization is not yet supported.
            event_dims: Shape of a single event (an `int` or a sequence of
                `int`s). Required. Its product is used as the event size
                for the RealNVP masking (`num_masked = size // 2`) and for
                the permutations between coupling layers.
            is_constant_jacobian: Python `bool`. Default: `False`. When `True` the
                implementation assumes `log_scale` does not depend on the forward domain
                (`x`) or inverse domain (`y`) values. (No validation is made;
                `is_constant_jacobian=False` is always safe but possibly computationally
                inefficient.)
            validate_args: Python `bool` indicating whether arguments should be
                checked for correctness.
            name: Python `str`, name given to ops managed by this object.

        Raises:
            NotImplementedError: if `use_batch_normalization` is `True`.
            ValueError: if `event_dims` is `None`.
        """
        # NOTE(review): these are set before `build()` and the base-class
        # constructor run, presumably so they exist early. Confirm whether
        # they are still needed; the base constructor also receives `name`.
        self._graph_parents = []
        self._name = name

        self._num_coupling_layers = num_coupling_layers
        self._hidden_layer_sizes = tuple(hidden_layer_sizes)
        if use_batch_normalization:
            raise NotImplementedError(
                "TODO(hartikainen): Batch normalization is not yet supported"
                " for ConditionalRealNVPFlow.")
        self._use_batch_normalization = use_batch_normalization

        # Explicit check instead of `assert`, which is stripped under `-O`.
        if event_dims is None:
            raise ValueError(
                "`event_dims` must be specified for ConditionalRealNVPFlow.")
        self._event_dims = event_dims

        self.build()

        super(ConditionalRealNVPFlow, self).__init__(
            forward_min_event_ndims=1,
            inverse_min_event_ndims=1,
            is_constant_jacobian=is_constant_jacobian,
            validate_args=validate_args,
            name=name)

    def build(self):
        """Creates the bijectors of the flow and stores them in `self.flow`."""
        # Number of units in a single event. `int` also turns the float 1.0
        # that `np.prod` returns for an empty shape into a usable integer.
        event_size = int(np.prod(self._event_dims))

        flow = []
        for i in range(self._num_coupling_layers):
            if self._use_batch_normalization:
                # Currently unreachable: `__init__` rejects batch norm.
                batch_normalization_bijector = bijectors.BatchNormalization()
                flow.append(batch_normalization_bijector)

            real_nvp_bijector = bijectors.RealNVP(
                num_masked=event_size // 2,
                shift_and_log_scale_fn=conditioned_real_nvp_template(
                    hidden_layers=self._hidden_layer_sizes,
                    # TODO: test tf.nn.relu
                    activation=tf.nn.tanh),
                name='real_nvp_{}'.format(i))
            flow.append(real_nvp_bijector)

            if i < self._num_coupling_layers - 1:
                # Reverse the event dimensions so consecutive coupling
                # layers mask different units.
                permute_bijector = bijectors.Permute(
                    permutation=list(reversed(range(event_size))),
                    name='permute_{}'.format(i))
                # TODO(hartikainen): We need to force _is_constant_jacobian due
                # to the event_dim caching. See the issue filed at github:
                # https://github.com/tensorflow/probability/issues/122
                permute_bijector._is_constant_jacobian = False
                flow.append(permute_bijector)

        # Note: `bijectors.Chain` applies its bijectors in the _reverse_ of
        # list order, whereas `self.flow` is applied in list order by
        # `_forward`. Replacing the loops below with a Chain would therefore
        # need `list(reversed(self.flow))` to keep the same order.
        self.flow = flow

    def _get_flow_conditions(self, **condition_kwargs):
        """Maps each RealNVP bijector's name to the full condition kwargs.

        Bijectors not in the returned dict (the permutations) are called
        without conditions.
        """
        return {
            bijector.name: condition_kwargs
            for bijector in self.flow
            if isinstance(bijector, bijectors.RealNVP)
        }

    @staticmethod
    def _initial_event_shape(tensor, event_ndims):
        """Returns the trailing `event_ndims` dimensions of `tensor`'s shape.

        The result is a static `TensorShape` when `_use_static_shape` allows
        it, otherwise a dynamic shape `Tensor`.
        """
        if _use_static_shape(tensor, event_ndims):
            return tensor.shape[tensor.shape.ndims - event_ndims:]
        return tf.shape(tensor)[tf.rank(tensor) - event_ndims:]

    def _next_event_shape_and_ndims(self, tensor, event_shape, event_ndims,
                                    static_shape_fn, dynamic_shape_fn):
        """Propagates `(event_shape, event_ndims)` through one bijector.

        Args:
            tensor: The bijector's input, used only to decide whether static
                shapes can be used.
            event_shape: Current event shape (`TensorShape` or `Tensor`).
            event_ndims: Current number of event dims (`int` or `Tensor`).
            static_shape_fn: The bijector's `*_event_shape` method.
            dynamic_shape_fn: The bijector's `*_event_shape_tensor` method.

        Returns:
            The `(event_shape, event_ndims)` pair after the bijector.
            `event_ndims` is a Python `int` whenever it can be determined
            statically.
        """
        if _use_static_shape(tensor, event_ndims):
            event_shape = static_shape_fn(event_shape)
            return event_shape, self._maybe_get_static_event_ndims(
                event_shape.ndims)

        event_shape = dynamic_shape_fn(event_shape)
        event_ndims = tf.size(event_shape)
        static_event_ndims = self._maybe_get_static_event_ndims(event_ndims)
        if static_event_ndims is not None:
            event_ndims = static_event_ndims
        return event_shape, event_ndims

    def _forward(self, x, **condition_kwargs):
        """Applies the bijectors of `self.flow` to `x` in list order."""
        # TODO(hartikainen): Once tfp.bijectors.Chain supports conditioning,
        # replace this loop with a Chain's `forward`.
        conditions = self._get_flow_conditions(**condition_kwargs)
        for bijector in self.flow:
            x = bijector.forward(x, **conditions.get(bijector.name, {}))
        return x

    def _inverse(self, y, **condition_kwargs):
        """Applies the inverses of `self.flow` to `y` in reverse order."""
        # TODO(hartikainen): Once tfp.bijectors.Chain supports conditioning,
        # replace this loop with a Chain's `inverse`.
        conditions = self._get_flow_conditions(**condition_kwargs)
        for bijector in reversed(self.flow):
            y = bijector.inverse(y, **conditions.get(bijector.name, {}))
        return y

    def _forward_log_det_jacobian(self, x, **condition_kwargs):
        """Sums the forward log-det-jacobians of the bijectors of the flow.

        Each bijector's term is evaluated at that bijector's own input,
        i.e. `x` after all preceding bijectors have been applied. The
        event shape and ndims are tracked along the way.
        """
        # TODO(hartikainen): Once tfp.bijectors.Chain supports conditioning,
        # replace this with a Chain's `forward_log_det_jacobian`.
        conditions = self._get_flow_conditions(**condition_kwargs)

        fldj = tf.cast(0., dtype=x.dtype.base_dtype)
        event_ndims = self._maybe_get_static_event_ndims(
            self.forward_min_event_ndims)
        event_shape = self._initial_event_shape(x, event_ndims)

        for b in self.flow:
            b_conditions = conditions.get(b.name, {})
            fldj += b.forward_log_det_jacobian(
                x, event_ndims=event_ndims, **b_conditions)
            # Shape bookkeeping uses `x` *before* this bijector is applied.
            event_shape, event_ndims = self._next_event_shape_and_ndims(
                x, event_shape, event_ndims,
                b.forward_event_shape, b.forward_event_shape_tensor)
            x = b.forward(x, **b_conditions)

        return fldj

    def _inverse_log_det_jacobian(self, y, **condition_kwargs):
        """Sums the inverse log-det-jacobians of the flow's bijectors.

        Bijectors are visited in reverse list order. Each term is evaluated
        at that bijector's own input in the inverse direction.
        """
        # TODO(hartikainen): Once tfp.bijectors.Chain supports conditioning,
        # replace this with a Chain's `inverse_log_det_jacobian`.
        conditions = self._get_flow_conditions(**condition_kwargs)

        ildj = tf.cast(0., dtype=y.dtype.base_dtype)
        event_ndims = self._maybe_get_static_event_ndims(
            self.inverse_min_event_ndims)
        event_shape = self._initial_event_shape(y, event_ndims)

        for b in reversed(self.flow):
            b_conditions = conditions.get(b.name, {})
            ildj += b.inverse_log_det_jacobian(
                y, event_ndims=event_ndims, **b_conditions)
            # Shape bookkeeping uses `y` *before* this inverse is applied.
            event_shape, event_ndims = self._next_event_shape_and_ndims(
                y, event_shape, event_ndims,
                b.inverse_event_shape, b.inverse_event_shape_tensor)
            y = b.inverse(y, **b_conditions)

        return ildj


def conditioned_real_nvp_template(hidden_layers,
                                  shift_only=False,
                                  activation=tf.nn.relu,
                                  name=None,
                                  *args,  # pylint: disable=keyword-arg-before-vararg
                                  **kwargs):
    """Builds a conditional shift/log-scale network for `bijectors.RealNVP`.

    The returned template does the following:
    - concatenates the condition keyword arguments (in sorted key order) to
      its input along the last axis;
    - runs the result through dense layers of widths `hidden_layers`, using
      `activation`;
    - applies a final linear dense layer.
    It then returns `(shift, log_scale)`.

    Args:
        hidden_layers: Iterable of Python `int`s, widths of the hidden dense
            layers.
        shift_only: Python `bool`. If `True`, only a shift is produced and
            `log_scale` is returned as `None`.
        activation: Activation applied to the hidden layers (the output layer
            has none).
        name: Name of the `tf.name_scope` opened while the template is
            created. The template itself is always created with the name
            "conditioned_real_nvp_template".
        *args: Extra positional arguments forwarded to `tf.layers.dense`.
        **kwargs: Extra keyword arguments forwarded to `tf.layers.dense`.

    Returns:
        A template created by `tf.make_template`, callable as
        `fn(x, output_units, **condition_kwargs) -> (shift, log_scale)`.
    """
    with tf.name_scope(name, "conditioned_real_nvp_template"):

        def _fn(x, output_units, **condition_kwargs):
            """MLP which concatenates the condition kwargs to its input.

            Args:
                x: Input tensor. The conditions are concatenated to it along
                    the last axis, so they presumably need matching leading
                    dims and dtype (not checked here).
                output_units: Number of shift (and log-scale) outputs.
                **condition_kwargs: Conditioning tensors, concatenated in
                    sorted key order.

            Returns:
                `(shift, log_scale)`; `log_scale` is `None` if `shift_only`.
            """
            # Sorted keys keep the input layout independent of keyword
            # order. List concatenation (instead of `(x, *conditions)`)
            # stays valid Python 2 syntax, matching the `__future__` imports.
            conditions = [condition_kwargs[k] for k in sorted(condition_kwargs)]
            x = tf.concat([x] + conditions, axis=-1)

            # NOTE(review): positional `args` come after the `inputs`/`units`
            # /`activation` keywords. They would therefore bind to
            # `tf.layers.dense`'s leading positional parameters and collide
            # with those keywords, so only `**kwargs` appear usable. Confirm.
            for units in hidden_layers:
                x = tf.layers.dense(
                    inputs=x,
                    units=units,
                    activation=activation,
                    *args,  # pylint: disable=keyword-arg-before-vararg
                    **kwargs)
            x = tf.layers.dense(
                inputs=x,
                units=(1 if shift_only else 2) * output_units,
                activation=None,
                *args,  # pylint: disable=keyword-arg-before-vararg
                **kwargs)

            if shift_only:
                return x, None

            # First half of the last axis is the shift, second the log-scale.
            shift, log_scale = tf.split(x, 2, axis=-1)
            return shift, log_scale

        return tf.make_template("conditioned_real_nvp_template", _fn)
