from turtle import width
from typing import Iterable, List, Dict, Tuple, Literal

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.internal import samplers
from tensorflow.keras.layers import (
    Dense
)
from ..set_transformer.layers import (
    RFF
)
from ..utils.geometry import (
    tril_eye
)

tfd = tfp.distributions
tfb = tfp.bijectors


class BatchedHighwayFlow(tfp.experimental.bijectors.HighwayFlow):
    """TFP HighwayFlow with its gating helper overridden.

    The override is meant to let the flow parameters carry a batch
    dimension (per the original author's note).
    """

    def _gated_residual_fraction(self):
        """Return the residual fraction restricted to the gated dimensions.

        A 0/1 mask is built along the event axis: ones for the first
        `gate_first_n` entries, zeros for the remaining `num_ungated`
        entries. The residual fraction is multiplied by that mask.
        """
        gated_ones = tf.ones([self.gate_first_n], dtype=self.dtype)
        ungated_zeros = tf.zeros([self.num_ungated], dtype=self.dtype)
        gate_mask = tf.concat([gated_ones, ungated_zeros], axis=0)
        return self.residual_fraction * gate_mask


class ConditionalHighwayFlow(tfb.Bijector):
    """Highway Flow whose bias and triangular weights are regressed
    from a conditional input at every call."""

    def __init__(
        self,
        rff_kwargs: Dict,
        event_size: int,
        **kwargs
    ):
        """Highway Flow where all weights are functions of a
        conditional input

        Parameters
        ----------
        rff_kwargs : Dict
            forwarded to the RFF layer of each weights regressor
        event_size : int
            from the base distribution
        """
        super().__init__(
            forward_min_event_ndims=1,
            **kwargs
        )

        weights_shape = [event_size, event_size]

        # unconstrained vector -> lower weights matrix
        # (the Chain applies its bijectors from last to first)
        self.lower_bijector = tfb.Chain(
            [
                tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
                tfb.Pad(paddings=[(1, 0), (0, 1)]),
                tfb.FillTriangular(),
            ]
        )
        self.lower_rff = self._weights_regressor(
            rff_kwargs,
            self.lower_bijector.inverse_event_shape(weights_shape)[0]
        )

        # unconstrained vector -> upper weights matrix
        # (softplus applied to the diagonal entries)
        self.upper_bijector = tfb.FillScaleTriL(
            diag_bijector=tfb.Softplus(),
            diag_shift=None
        )
        self.upper_rff = self._weights_regressor(
            rff_kwargs,
            self.upper_bijector.inverse_event_shape(weights_shape)[0]
        )

        self.bias_rff = self._weights_regressor(rff_kwargs, event_size)

    @staticmethod
    def _weights_regressor(rff_kwargs: Dict, units) -> tf.keras.Sequential:
        """Builds an RFF + Dense regressor with `units` outputs.

        The Dense kernel is drawn from a normal with stddev 1e-4
        and the Dense bias starts at zero.
        """
        return tf.keras.Sequential(
            layers=[
                RFF(**rff_kwargs),
                Dense(
                    units=units,
                    kernel_initializer=tf.keras.initializers.RandomNormal(
                        mean=0.0,
                        stddev=1e-4
                    ),
                    bias_initializer="zeros"
                )
            ],
        )

    def _conditioned_flow(
        self,
        conditional_input: tf.Tensor,
        residual_fraction: tf.Tensor
    ) -> BatchedHighwayFlow:
        """Instantiates the highway flow whose parameters are regressed
        from conditional_input (softplus activation, no partial gating)."""
        bias = self.bias_rff(conditional_input)
        upper_weights = self.upper_bijector(
            self.upper_rff(conditional_input)
        )
        lower_weights = self.lower_bijector(
            self.lower_rff(conditional_input)
        )
        return BatchedHighwayFlow(
            residual_fraction=residual_fraction,
            activation_fn=tf.nn.softplus,
            bias=bias,
            upper_diagonal_weights_matrix=upper_weights,
            lower_diagonal_weights_matrix=lower_weights,
            gate_first_n=None
        )

    def forward(
        self,
        x: tf.Tensor,
        conditional_input: tf.Tensor,
        residual_fraction: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Applies the conditioned highway flow to x

        Parameters
        ----------
        x : tf.Tensor
            shape: batch_shape + (event_size,)
        conditional_input : tf.Tensor
            shape: batch_shape + (conditional_size,)
        residual_fraction
            shape: batch_shape + (1,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape + (event_size,)
        """
        flow = self._conditioned_flow(conditional_input, residual_fraction)
        return flow.forward(x)

    def inverse(
        self,
        y: tf.Tensor,
        conditional_input: tf.Tensor,
        residual_fraction: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Inverts the conditioned highway flow at y

        Parameters
        ----------
        y : tf.Tensor
            shape: batch_shape + (event_size,)
        conditional_input : tf.Tensor
            shape: batch_shape + (conditional_size,)
        residual_fraction
            shape: batch_shape + (1,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape + (event_size,)
        """
        flow = self._conditioned_flow(conditional_input, residual_fraction)
        return flow.inverse(y)

    def inverse_log_det_jacobian(
        self,
        y: tf.Tensor,
        event_ndims: int,
        conditional_input: tf.Tensor,
        residual_fraction: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Inverse log det Jacobian of the conditioned highway flow at y

        Parameters
        ----------
        y : tf.Tensor
            shape: batch_shape + (event_size,)
        event_ndims : int
            rank of event shape (i.e. 1 here)
        conditional_input : tf.Tensor
            shape: batch_shape + (conditional_size,)
        residual_fraction
            shape: batch_shape + (1,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape
        """
        flow = self._conditioned_flow(conditional_input, residual_fraction)
        return flow.inverse_log_det_jacobian(
            y=y,
            event_ndims=event_ndims
        )


def build_batched_highway_flow(
    width,
    residual_fraction_initial_value=0.5,
    activation_fn=None,
    gate_first_n=None,
    seed=None,
    validate_args=False
) -> BatchedHighwayFlow:
    """Builds a BatchedHighwayFlow backed by trainable variables.

    Each parameter is stored unconstrained and mapped through a bijector:
    - `residual_fraction` goes through a sigmoid;
    - `upper_diagonal_weights_matrix` goes through `FillScaleTriL` with a
      softplus on the diagonal;
    - `lower_diagonal_weights_matrix` goes through
      `FillTriangular -> Pad -> TransformDiagonal(Shift(1.))`;
    - `bias` is a plain variable of size `width`.
    Weights and bias start from normal draws with stddev 0.01.

    Parameters
    ----------
    width
        input dimension of the bijector
    residual_fraction_initial_value
        initial value for the gating parameter, by default 0.5;
        its dtype (float32 hint) is used for every variable
    activation_fn
        invertible activation (e.g. `tf.nn.softplus`) or None,
        by default None
    gate_first_n
        forwarded to the flow to select the gated dimensions,
        by default None
    seed
        seed for the random initialization, split in three
        (bias, upper, lower), by default None
    validate_args : bool
        forwarded to the flow, by default False

    Returns
    -------
    BatchedHighwayFlow
        the initialized bijector
    """
    residual_fraction_initial_value = tf.convert_to_tensor(
        residual_fraction_initial_value,
        dtype_hint=tf.float32,
        name='residual_fraction_initial_value')
    dtype = residual_fraction_initial_value.dtype
    weights_shape = [width, width]

    def small_normal(shape, sample_seed):
        # shared initial-value sampler for bias and both weight matrices
        return samplers.normal(
            shape=shape, mean=0., stddev=.01, seed=sample_seed)

    bias_seed, upper_seed, lower_seed = samplers.split_seed(seed, n=3)

    lower_bijector = tfb.Chain([
        tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
        tfb.Pad(paddings=[(1, 0), (0, 1)]),
        tfb.FillTriangular()])
    lower_unconstrained = small_normal(
        lower_bijector.inverse_event_shape(weights_shape), lower_seed)

    upper_bijector = tfb.FillScaleTriL(
        diag_bijector=tfb.Softplus(), diag_shift=None)
    upper_unconstrained = small_normal(
        upper_bijector.inverse_event_shape(weights_shape), upper_seed)

    # Variables are created in the same order as the flow's arguments.
    residual_fraction = tfp.util.TransformedVariable(
        initial_value=residual_fraction_initial_value,
        bijector=tfb.Sigmoid(),
        dtype=dtype)
    bias = tf.Variable(small_normal((width,), bias_seed), dtype=dtype)
    upper_weights = tfp.util.TransformedVariable(
        initial_value=upper_bijector.forward(upper_unconstrained),
        bijector=upper_bijector,
        dtype=dtype)
    lower_weights = tfp.util.TransformedVariable(
        initial_value=lower_bijector.forward(lower_unconstrained),
        bijector=lower_bijector,
        dtype=dtype)

    return BatchedHighwayFlow(
        residual_fraction=residual_fraction,
        activation_fn=activation_fn,
        bias=bias,
        upper_diagonal_weights_matrix=upper_weights,
        lower_diagonal_weights_matrix=lower_weights,
        gate_first_n=gate_first_n,
        validate_args=validate_args)


class ConditionalAffine(tfb.Bijector):
    """Affine bijector whose shift (and optional scale) are regressed
    from a conditional input at every call."""

    def __init__(
        self,
        scale_type: Literal["diag", "tril", "none"],
        rff_kwargs: Dict,
        event_size: int,
        **kwargs
    ):
        """Affine transform where the shift and scale
        are functions of a conditional input

        Parameters
        ----------
        scale_type : Literal["diag", "tril", "none"]
            determines parametrization of the affine bijector
        rff_kwargs : Dict
            for shift and scale regressors
        event_size : int
            from the base distribution

        Raises
        ------
        NotImplementedError
            if scale_type is not one of "diag", "tril", "none"
        """
        super().__init__(
            forward_min_event_ndims=1,
            **kwargs
        )

        if scale_type not in ("diag", "tril", "none"):
            raise NotImplementedError(
                f"unknown scale type {scale_type}"
            )

        self.scale_type = scale_type

        self.shift = tf.keras.Sequential(
            layers=[
                RFF(**rff_kwargs),
                Dense(units=event_size)
            ],
        )
        if self.scale_type != "none":
            # "diag" regresses one scale per event dimension, "tril" the
            # n * (n + 1) / 2 entries of a lower-triangular matrix.
            # Integer division keeps `units` an int (n * (n + 1) is even).
            scale_units = (
                event_size
                if self.scale_type == "diag"
                else (event_size * (event_size + 1)) // 2
            )
            self.scale = tf.keras.Sequential(
                layers=[
                    RFF(**rff_kwargs),
                    Dense(units=scale_units)
                ],
            )

    def _conditional_chain(
        self,
        conditional_input: tf.Tensor
    ) -> tfb.Chain:
        """Builds the Shift (+ optional scale) chain whose parameters
        are regressed from conditional_input."""
        bijectors = [
            tfb.Shift(
                shift=self.shift(conditional_input)
            )
        ]
        # NOTE(review): the regressed scale is used as-is, without any
        # positivity constraint on its (diagonal) entries.
        if self.scale_type == "diag":
            bijectors.append(
                tfb.ScaleMatvecDiag(
                    scale_diag=self.scale(conditional_input)
                )
            )
        elif self.scale_type == "tril":
            bijectors.append(
                tfb.ScaleMatvecTriL(
                    scale_tril=tfp.math.fill_triangular(
                        self.scale(conditional_input)
                    )
                )
            )
        return tfb.Chain(bijectors)

    def forward(
        self,
        x: tf.Tensor,
        conditional_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Affine(
            shift=f(conditional_input),
            scale=g(conditional_input)
        ).forward(x)

        Parameters
        ----------
        x : tf.Tensor
            shape: batch_shape + (event_size,)
        conditional_input : tf.Tensor
            shape: batch_shape + (conditional_size,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape + (event_size,)
        """
        return self._conditional_chain(conditional_input).forward(x)

    def inverse(
        self,
        y: tf.Tensor,
        conditional_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Affine(
            shift=f(conditional_input),
            scale=g(conditional_input)
        ).inverse(y)

        Parameters
        ----------
        y : tf.Tensor
            shape: batch_shape + (event_size,)
        conditional_input : tf.Tensor
            shape: batch_shape + (conditional_size,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape + (event_size,)
        """
        return self._conditional_chain(conditional_input).inverse(y)

    def inverse_log_det_jacobian(
        self,
        y: tf.Tensor,
        event_ndims: int,
        conditional_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Affine(
            shift=f(conditional_input),
            scale=g(conditional_input)
        ).inverse_log_det_jacobian(y, event_ndims)

        Parameters
        ----------
        y : tf.Tensor
            shape: batch_shape + (event_size,)
        event_ndims : int
            rank of event shape (i.e. 1 here)
        conditional_input : tf.Tensor
            shape: batch_shape + (conditional_size,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape
        """
        return (
            self
            ._conditional_chain(conditional_input)
            .inverse_log_det_jacobian(
                y=y,
                event_ndims=event_ndims
            )
        )


class Affine(tfb.Bijector):
    """Affine bijector with a trainable shift and an optional
    trainable diagonal or lower-triangular scale."""

    def __init__(
        self,
        scale_type: Literal["diag", "tril", "none"],
        event_size: int,
        **kwargs
    ):
        """Affine transform

        Parameters
        ----------
        scale_type : Literal["diag", "tril", "none"]
            determines parametrization of the affine bijector
        event_size : int
            from the base distribution

        Raises
        ------
        NotImplementedError
            if scale_type is not one of "diag", "tril", "none"
        """
        super().__init__(
            forward_min_event_ndims=1,
            **kwargs
        )

        if scale_type not in ("diag", "tril", "none"):
            raise NotImplementedError(
                f"unknown scale type {scale_type}"
            )

        self.scale_type = scale_type

        # Parameters start at the identity map: zero shift, unit scale.
        self.shift = tf.Variable(
            tf.zeros((event_size,)),
            name="shift"
        )
        if scale_type == "diag":
            self.scale = tf.Variable(
                tf.ones((event_size,)),
                name="diag_scale"
            )
        elif scale_type == "tril":
            self.scale = tfp.util.TransformedVariable(
                tf.eye(event_size),
                bijector=tfb.FillTriangular(),
                name="tril_scale"
            )
        # "none": no scale attribute is created

        steps = [tfb.Shift(self.shift)]
        if scale_type == "diag":
            steps.append(tfb.ScaleMatvecDiag(self.scale))
        elif scale_type == "tril":
            steps.append(tfb.ScaleMatvecTriL(self.scale))
        self.chain = tfb.Chain(steps)

    def forward(
        self,
        x: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Applies the affine chain to x

        Parameters
        ----------
        x : tf.Tensor
            shape: batch_shape + (event_size,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape + (event_size,)
        """
        affine = self.chain
        return affine.forward(x)

    def inverse(
        self,
        y: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Inverts the affine chain at y

        Parameters
        ----------
        y : tf.Tensor
            shape: batch_shape + (event_size,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape + (event_size,)
        """
        affine = self.chain
        return affine.inverse(y)

    def inverse_log_det_jacobian(
        self,
        y: tf.Tensor,
        event_ndims: int,
        **kwargs
    ) -> tf.Tensor:
        """Inverse log det Jacobian of the affine chain at y

        Parameters
        ----------
        y : tf.Tensor
            shape: batch_shape + (event_size,)
        event_ndims : int
            rank of event shape (i.e. 1 here)

        Returns
        -------
        tf.Tensor
            shape: batch_shape
        """
        affine = self.chain
        return affine.inverse_log_det_jacobian(
            y=y,
            event_ndims=event_ndims
        )


class ConditionalNFChain(tfb.Bijector):
    """Chain of conditional NF bijectors that all receive the same
    conditional input."""

    def __init__(
        self,
        nf_type_kwargs_per_bijector: List[
            Tuple[
                Literal["MAF", "affine", "realNVP", "Highway"],
                Dict
            ]
        ],
        event_size: int,
        conditional_event_size: int,
        with_permute: bool = True,
        with_batch_norm: bool = True,
        **kwargs
    ):
        """Stacks multiple NF bijectors of various types
        will broadcast conditional input to those

        Parameters
        ----------
        nf_type_kwargs_per_bijector : List[
            Tuple[Literal["MAF", "affine", "realNVP", "Highway"], Dict]
        ]
            list of (nf_type, nf_kwargs)
            determines number, type and parametrization
            of conditional NF bijectors in the chain
        event_size : int
            from base distribution
        conditional_event_size : int
            from upstream encoding
        with_permute : bool, optional
            permute event tensor between conditional NF bijectors,
            by default True
        with_batch_norm : bool, optional
            normalize event tensor between conditional NF bijectors,
            by default True

        Raises
        ------
        NotImplementedError
            if an nf_type is not handled by build_bijector
        """
        super().__init__(
            forward_min_event_ndims=1,
            **kwargs
        )
        self.event_size = event_size
        self.conditional_event_size = conditional_event_size
        self.regressors = {}

        self.conditional_bijectors = {}
        self.bijectors = []

        # A single residual-fraction regressor is shared by every Highway
        # bijector of the chain. The attribute is reassigned for each
        # Highway entry, so the last entry's rff_kwargs are the ones kept;
        # it is not created at all when the chain has no Highway bijector.
        for (nf_type, nf_kwargs) in nf_type_kwargs_per_bijector:
            if nf_type == "Highway":
                self.residual_fraction_rff = tf.keras.Sequential(
                    layers=[
                        RFF(**nf_kwargs["rff_kwargs"]),
                        Dense(
                            units=1,
                            activation="sigmoid",
                            kernel_initializer=(
                                tf.keras.initializers
                                .RandomNormal(
                                    mean=0.0,
                                    stddev=1e-4
                                )
                            ),
                            bias_initializer=(
                                tf.keras.initializers
                                .Constant(
                                    value=0.5
                                )
                            )
                        )
                    ],
                )

        for b, (nf_type, nf_kwargs) in enumerate(
            nf_type_kwargs_per_bijector
        ):
            if with_permute:
                # NOTE: unseeded draw from numpy's global RNG
                self.bijectors.append(
                    tfb.Permute(
                        permutation=(
                            np
                            .random
                            .permutation(self.event_size,)
                        )
                    )
                )
            if with_batch_norm:
                self.bijectors.append(
                    tfb.Invert(
                        tfb.BatchNormalization()
                    )
                )
            name, bijector = self.build_bijector(
                b=b,
                nf_type=nf_type,
                nf_kwargs=nf_kwargs
            )
            # keyed by (type, name): the name routes the kwargs inside the
            # chain, the type tells which kwargs the bijector expects
            self.conditional_bijectors[nf_type, name] = bijector
            self.bijectors.append(bijector)

        self.chain = tfb.Chain(self.bijectors)

    def build_bijector(
        self,
        b: int,
        nf_type: Literal["MAF", "affine", "realNVP", "Highway"],
        nf_kwargs: Dict,
        **kwargs
    ) -> Tuple[str, tfb.Bijector]:
        """Returns a conditional NF bijector
        of given type

        Parameters
        ----------
        b : int
            used for naming
        nf_type : Literal["MAF", "affine", "realNVP", "Highway"]
            determines type of conditional NF bijector
        nf_kwargs : Dict
            dependent on the NF type:
            forwarded to AutoregressiveNetwork for "MAF",
            to ConditionalAffine for "affine",
            to the RFF layers of the regressors for "realNVP",
            to ConditionalHighwayFlow for "Highway"

        Returns
        -------
        Tuple[str, tfb.Bijector]
            tuple (name : str, nf bijector : tfb.Bijector)

        Raises
        ------
        NotImplementedError
            if nf_type not in ["MAF", "affine", "realNVP", "Highway"]
        """
        if nf_type == "MAF":
            made = tfb.AutoregressiveNetwork(
                params=2,
                event_shape=self.event_size,
                conditional=True,
                conditional_event_shape=(self.conditional_event_size,),
                **nf_kwargs
            )
            name = f"MAF_{b}"

            return (
                name,
                tfb.MaskedAutoregressiveFlow(
                    made,
                    name=name
                )
            )
        elif nf_type == "affine":
            name = f"affine_{b}"
            return (
                name,
                ConditionalAffine(
                    event_size=self.event_size,
                    name=name,
                    **nf_kwargs
                )
            )
        elif nf_type == "realNVP":
            name = f"realNVP_{b}"
            # one regressor per output, each predicting the
            # event_size - event_size // 2 non-masked dimensions
            for regressor in ["shift", "log_scale"]:
                self.regressors[f"{name}_{regressor}"] = (
                    tf.keras.Sequential(
                        layers=[
                            RFF(**nf_kwargs),
                            Dense(
                                units=(
                                    self.event_size
                                    -
                                    self.event_size // 2
                                )
                            )
                        ],
                    )
                )

            def shift_and_log_scale_fn(
                x: tf.Tensor,
                input_depth: int,
                conditional_input: tf.Tensor
            ) -> Tuple[tf.Tensor, tf.Tensor]:
                # regress from the masked part of the event
                # concatenated with the conditional input
                z = tf.concat(
                    [
                        x,
                        conditional_input
                    ],
                    axis=-1
                )
                return (
                    self.regressors[f"{name}_shift"](z),
                    self.regressors[f"{name}_log_scale"](z)
                )

            return (
                name,
                tfb.RealNVP(
                    num_masked=self.event_size // 2,
                    shift_and_log_scale_fn=shift_and_log_scale_fn,
                    name=name
                )
            )
        elif nf_type == "Highway":
            name = f"highway_{b}"
            return (
                name,
                ConditionalHighwayFlow(
                    event_size=self.event_size,
                    name=name,
                    **nf_kwargs
                )
            )
        else:
            raise NotImplementedError(
                f"{nf_type} is not a valid NF type"
            )

    def _condition_kwargs(
        self,
        conditional_input: tf.Tensor
    ) -> Dict[str, Dict[str, tf.Tensor]]:
        """Builds the {bijector name: kwargs} mapping handed to the chain.

        Every conditional bijector receives conditional_input; Highway
        bijectors additionally receive the regressed residual fraction,
        which is evaluated once and shared between them.
        """
        residual_fraction = None
        per_bijector_kwargs = {}
        for nf_type, name in self.conditional_bijectors:
            bijector_kwargs = {"conditional_input": conditional_input}
            if nf_type == "Highway":
                if residual_fraction is None:
                    residual_fraction = self.residual_fraction_rff(
                        conditional_input
                    )
                bijector_kwargs["residual_fraction"] = residual_fraction
            per_bijector_kwargs[name] = bijector_kwargs
        return per_bijector_kwargs

    def forward(
        self,
        chain_input: tf.Tensor,
        conditional_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Transforms chain_input, broadcasting
        conditional_input to all conditional NF bijectors

        Parameters
        ----------
        chain_input : tf.Tensor
            from base distribution
            shape: batch_shape + (event_size,)
        conditional_input : tf.Tensor
            from upstream encoding
            shape: batch_shape + (conditional_event_size,)

        Returns
        -------
        tf.Tensor
            transformed chain input
            shape: batch_shape + (event_size,)
        """
        return self.chain.forward(
            x=chain_input,
            **self._condition_kwargs(conditional_input),
            **kwargs
        )

    def inverse(
        self,
        chain_input: tf.Tensor,
        conditional_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Inverse-transforms chain_input, broadcasting
        conditional_input to all conditional NF bijectors

        Parameters
        ----------
        chain_input : tf.Tensor
            tensor to map back through the chain
            shape: batch_shape + (event_size,)
        conditional_input : tf.Tensor
            from upstream encoding
            shape: batch_shape + (conditional_event_size,)

        Returns
        -------
        tf.Tensor
            inverse-transformed chain input
            shape: batch_shape + (event_size,)
        """
        return self.chain.inverse(
            y=chain_input,
            **self._condition_kwargs(conditional_input),
            **kwargs
        )

    def inverse_log_det_jacobian(
        self,
        chain_input: tf.Tensor,
        event_ndims: int,
        conditional_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Returns inverse log det Jacobian at chain_input,
        broadcasting conditional_input to all conditional NF bijectors

        Parameters
        ----------
        chain_input : tf.Tensor
            point at which the inverse Jacobian is evaluated
            shape: batch_shape + (event_size,)
        event_ndims : int
            rank of event shape (i.e. 1 here)
        conditional_input : tf.Tensor
            from upstream encoding
            shape: batch_shape + (conditional_event_size,)

        Returns
        -------
        tf.Tensor
            shape: batch_shape
        """
        return self.chain.inverse_log_det_jacobian(
            y=chain_input,
            event_ndims=event_ndims,
            **self._condition_kwargs(conditional_input),
            **kwargs
        )


class NFChain(tfb.Bijector):
    """Unconditional stack of normalizing-flow bijectors

    Wraps a `tfb.Chain` whose links are built from a list of
    (nf_type, nf_kwargs) specifications.
    """

    def __init__(
        self,
        nf_type_kwargs_per_bijector: List[
            Tuple[
                Literal["MAF", "affine", "realNVP", "Highway"],
                Dict
            ]
        ],
        event_size: int,
        with_permute: bool = True,
        with_batch_norm: bool = True,
        **kwargs
    ):
        """Stacks multiple NF bijectors of various types

        Parameters
        ----------
        nf_type_kwargs_per_bijector : List[
            Tuple[Literal["MAF", "affine", "realNVP", "Highway"], Dict]
        ]
            list of (nf_type, nf_kwargs)
            determines number, type and parametrization
            of the NF bijectors in the chain
        event_size : int
            from base distribution
        with_permute : bool, optional
            insert a random permutation of the event tensor
            before each NF bijector, by default True
        with_batch_norm : bool, optional
            insert an inverted batch normalization
            before each NF bijector, by default True
        """
        super().__init__(
            forward_min_event_ndims=1,
            **kwargs
        )
        self.event_size = event_size

        self.bijectors = []
        for index, spec in enumerate(nf_type_kwargs_per_bijector):
            nf_type, nf_kwargs = spec

            # optional pre-processing links, then the NF bijector itself
            links = []
            if with_permute:
                shuffled = np.random.permutation(self.event_size,)
                links.append(tfb.Permute(permutation=shuffled))
            if with_batch_norm:
                links.append(tfb.Invert(tfb.BatchNormalization()))

            # build_bijector also returns the bijector name, unused here
            _, flow = self.build_bijector(
                b=index,
                nf_type=nf_type,
                nf_kwargs=nf_kwargs
            )
            links.append(flow)

            self.bijectors.extend(links)

        self.chain = tfb.Chain(self.bijectors)

    def build_bijector(
        self,
        b: int,
        nf_type: Literal["MAF", "affine", "realNVP", "Highway"],
        nf_kwargs: Dict,
        **kwargs
    ) -> Tuple[str, tfb.Bijector]:
        """Builds a single NF bijector of the requested type

        Parameters
        ----------
        b : int
            position in the chain, used for naming
        nf_type : Literal["MAF", "affine", "realNVP", "Highway"]
            determines type of NF bijector
        nf_kwargs : Dict
            dependent on the NF type

        Returns
        -------
        Tuple[str, tfb.Bijector]
            tuple (name : str, nf bijector : tfb.Bijector)

        Raises
        ------
        NotImplementedError
            if nf_type not in ["MAF", "affine", "realNVP", "Highway"]
        """
        if nf_type == "MAF":
            name = f"MAF_{b}"
            # params=2: the network emits two parameters per event dim
            autoregressive_net = tfb.AutoregressiveNetwork(
                params=2,
                event_shape=self.event_size,
                **nf_kwargs
            )
            flow = tfb.MaskedAutoregressiveFlow(
                autoregressive_net,
                name=name
            )
            return name, flow

        if nf_type == "affine":
            name = f"affine_{b}"
            flow = Affine(
                event_size=self.event_size,
                name=name,
                **nf_kwargs
            )
            return name, flow

        if nf_type == "realNVP":
            name = f"realNVP_{b}"
            template = tfb.real_nvp_default_template(
                **nf_kwargs
            )
            # the first half of the event is masked
            flow = tfb.RealNVP(
                num_masked=self.event_size // 2,
                shift_and_log_scale_fn=template,
                name=name
            )
            return name, flow

        if nf_type == "Highway":
            name = f"highway_{b}"
            flow = build_batched_highway_flow(
                width=self.event_size,
                **nf_kwargs
            )
            return name, flow

        raise NotImplementedError(
            f"{nf_type} is not a valid NF type"
        )

    def forward(
        self,
        chain_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Pushes chain_input forward through the whole chain

        Parameters
        ----------
        chain_input : tf.Tensor
            from base distribution
            shape: batch_shape + (event_size,)
        **kwargs
            passed as is to the underlying tfb.Chain

        Returns
        -------
        tf.Tensor
            transformed chain input
            shape: batch_shape + (event_size,)
        """
        transformed = self.chain.forward(
            x=chain_input,
            **kwargs
        )
        return transformed

    def inverse(
        self,
        chain_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Pulls chain_input back through the whole chain

        Parameters
        ----------
        chain_input : tf.Tensor
            tensor to invert
            shape: batch_shape + (event_size,)
        **kwargs
            passed as is to the underlying tfb.Chain

        Returns
        -------
        tf.Tensor
            inverted chain input
            shape: batch_shape + (event_size,)
        """
        inverted = self.chain.inverse(
            y=chain_input,
            **kwargs
        )
        return inverted

    def inverse_log_det_jacobian(
        self,
        chain_input: tf.Tensor,
        event_ndims: int,
        **kwargs
    ) -> tf.Tensor:
        """Returns the inverse log det Jacobian of the chain
        at chain_input

        Parameters
        ----------
        chain_input : tf.Tensor
            point at which the Jacobian is evaluated
            shape: batch_shape + (event_size,)
        event_ndims : int
            rank of event shape (i.e. 1 here)
        **kwargs
            passed as is to the underlying tfb.Chain

        Returns
        -------
        tf.Tensor
            shape: batch_shape
        """
        ildj = self.chain.inverse_log_det_jacobian(
            y=chain_input,
            event_ndims=event_ndims,
            **kwargs
        )
        return ildj


class BijectorArray(tfb.Bijector):
    """Array of independent NF chains, one per entry of `shape`

    The event tensor, whose trailing dimensions are
    shape + (event_size,), is flattened, transformed block by block
    (one NFChain per block of size event_size), then reshaped back.
    """

    def __init__(
        self,
        nf_chain_kwargs: Dict,
        shape: Iterable[int],
        event_size: int,
        **kwargs
    ) -> None:
        """Builds one NFChain per array entry

        Parameters
        ----------
        nf_chain_kwargs : Dict
            keyword arguments shared by all the NFChain bijectors
        shape : Iterable[int]
            shape of the array of bijectors
        event_size : int
            size of the event transformed by each NFChain
        """
        # materialize once: shape may be a list or any other iterable,
        # whereas the code below needs len() and tuple concatenation
        shape = tuple(shape)

        super().__init__(
            forward_min_event_ndims=len(shape) + 1,
            **kwargs
        )

        # total number of chains in the array
        batch_size = int(tf.reduce_prod(shape))

        self.reshaper = tfb.Reshape(
            event_shape_in=shape + (event_size,),
            event_shape_out=(1, batch_size * event_size,)
        )

        self.bijectors = [
            NFChain(
                **nf_chain_kwargs,
                event_size=event_size,
                name=f"nf_chain_{b}"
            )
            for b in range(batch_size)
        ]

        self.blockwise_bijector = tfb.Blockwise(
            bijectors=self.bijectors,
            block_sizes=[event_size] * batch_size,
            name="block"
        )

        # tfb.Chain applies its bijectors from last to first in forward:
        # flatten, transform blockwise, then restore the array shape
        self.chain = tfb.Chain(
            bijectors=[
                tfb.Invert(self.reshaper),
                self.blockwise_bijector,
                self.reshaper
            ]
        )

    def forward(
        self,
        chain_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Transforms chain_input through the array of chains

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        **kwargs
            passed as is to the underlying tfb.Chain

        Returns
        -------
        tf.Tensor
            transformed chain input
        """
        return self.chain.forward(
            x=chain_input,
            **kwargs
        )

    def inverse(
        self,
        chain_input: tf.Tensor,
        **kwargs
    ) -> tf.Tensor:
        """Inverts chain_input through the array of chains

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        **kwargs
            passed as is to the underlying tfb.Chain

        Returns
        -------
        tf.Tensor
            inverted chain input
        """
        return self.chain.inverse(
            y=chain_input,
            **kwargs
        )

    def inverse_log_det_jacobian(
        self,
        chain_input: tf.Tensor,
        event_ndims: int,
        **kwargs
    ) -> tf.Tensor:
        """Returns the inverse log det Jacobian of the array of chains
        at chain_input

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        event_ndims : int
            rank of the event shape
        **kwargs
            passed as is to the underlying tfb.Chain,
            consistently with forward and inverse

        Returns
        -------
        tf.Tensor
            inverse log det Jacobian
        """
        return self.chain.inverse_log_det_jacobian(
            y=chain_input,
            event_ndims=event_ndims,
            **kwargs
        )


# ! Create a method for a batched inverse log det Jacobian that we can then slice

class ConditionalBijectorArray(tfb.Bijector):
    """Array of independent conditional NF chains,
    one per entry of `shape`

    The event tensor, whose trailing dimensions are
    shape + (event_size,), is flattened and transformed block by block
    (one ConditionalNFChain per block of size event_size).
    The conditional input, whose trailing dimensions are
    shape + (conditional_event_size,), is flattened and split the same
    way so that each chain receives its own conditional input.
    """

    def __init__(
        self,
        conditional_nf_chain_kwargs: Dict,
        shape: Iterable[int],
        event_size: int,
        conditional_event_size: int,
        **kwargs
    ) -> None:
        """Builds one ConditionalNFChain per array entry

        Parameters
        ----------
        conditional_nf_chain_kwargs : Dict
            keyword arguments shared by all the
            ConditionalNFChain bijectors
        shape : Iterable[int]
            shape of the array of bijectors
        event_size : int
            size of the event transformed by each chain
        conditional_event_size : int
            size of the conditional input received by each chain
        """
        # materialize once: shape may be a list or any other iterable,
        # whereas the code below needs len() and tuple concatenation
        shape = tuple(shape)

        super().__init__(
            forward_min_event_ndims=len(shape) + 1,
            **kwargs
        )

        # total number of chains in the array
        batch_size = int(tf.reduce_prod(shape))

        self.event_reshaper = tfb.Reshape(
            event_shape_in=shape + (event_size,),
            event_shape_out=(batch_size * event_size,),
            # ! to deal with affine non-batched processing:
            # ! event_shape_out=(1, batch_size * event_size,),
        )
        self.conditional_event_reshaper = tfb.Reshape(
            event_shape_in=shape + (conditional_event_size,),
            event_shape_out=(batch_size * conditional_event_size,)
            # ! to deal with affine non-batched processing:
            # ! event_shape_out=(1, batch_size * conditional_event_size,),
        )
        # maps between the array shape and a flat vector of per-chain
        # values, used for the per-chain inverse log det Jacobians
        self.ildj_reshaper = tfb.Reshape(
            event_shape_in=shape,
            event_shape_out=(batch_size,),
        )

        self.bijectors = [
            ConditionalNFChain(
                **conditional_nf_chain_kwargs,
                event_size=event_size,
                conditional_event_size=conditional_event_size,
                name=f"nf_chain_{b}"
            )
            for b in range(batch_size)
        ]

        self.blockwise_bijector = tfb.Blockwise(
            bijectors=self.bijectors,
            block_sizes=[event_size] * batch_size,
            name="block"
        )

        self.conditional_event_splitter = tfb.Split(
            num_or_size_splits=(
                [conditional_event_size]
                * batch_size
            )
        )

        # turns the flat conditional input into the nested kwargs
        # {"nf_chain_b": {"conditional_input": b-th split}}
        self.encoding_restructurer = tfb.Chain(
            [
                tfb.Restructure(
                    output_structure={
                        f"nf_chain_{b}": {
                            "conditional_input": b
                        }
                        for b in range(batch_size)
                    }
                ),
                self.conditional_event_splitter
            ]
        )

        # tfb.Chain applies its bijectors from last to first in forward:
        # flatten, transform blockwise, then restore the array shape
        self.chain = tfb.Chain(
            bijectors=[
                tfb.Invert(self.event_reshaper),
                self.blockwise_bijector,
                self.event_reshaper
            ]
        )

        self.event_splitter = tfb.Split(
            num_or_size_splits=(
                [event_size]
                * batch_size
            )
        )

    def _build_encodings(
        self,
        conditional_input: tf.Tensor
    ) -> Dict:
        """Flattens and splits conditional_input into the per-chain
        kwargs handed to the blockwise bijector"""
        return self.encoding_restructurer(
            self.conditional_event_reshaper(
                conditional_input
            )
        )

    def forward(
        self,
        chain_input: tf.Tensor,
        conditional_input: tf.Tensor,
    ) -> tf.Tensor:
        """Transforms chain_input through the array of chains,
        each chain receiving its own slice of conditional_input

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        conditional_input : tf.Tensor
            trailing dimensions: shape + (conditional_event_size,)

        Returns
        -------
        tf.Tensor
            transformed chain input
        """
        # "block" is the name of the blockwise bijector in the chain
        return self.chain.forward(
            chain_input,
            block=self._build_encodings(conditional_input),
        )

    def inverse(
        self,
        chain_input: tf.Tensor,
        conditional_input: tf.Tensor,
    ) -> tf.Tensor:
        """Inverts chain_input through the array of chains,
        each chain receiving its own slice of conditional_input

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        conditional_input : tf.Tensor
            trailing dimensions: shape + (conditional_event_size,)

        Returns
        -------
        tf.Tensor
            inverted chain input
        """
        return self.chain.inverse(
            y=chain_input,
            block=self._build_encodings(conditional_input),
        )

    def inverse_log_det_jacobian(
        self,
        chain_input: tf.Tensor,
        conditional_input: tf.Tensor,
        event_ndims: int,
    ) -> tf.Tensor:
        """Returns the inverse log det Jacobian of the whole array
        of chains at chain_input

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        conditional_input : tf.Tensor
            trailing dimensions: shape + (conditional_event_size,)
        event_ndims : int
            rank of the event shape

        Returns
        -------
        tf.Tensor
            inverse log det Jacobian
        """
        return self.chain.inverse_log_det_jacobian(
            y=chain_input,
            event_ndims=event_ndims,
            block=self._build_encodings(conditional_input),
        )

    def inverse_log_det_jacobian_parts(
        self,
        chain_input: tf.Tensor,
        conditional_input: tf.Tensor,
    ) -> tf.Tensor:
        """Returns one inverse log det Jacobian per chain of the array
        instead of a single value for the whole array

        Parameters
        ----------
        chain_input : tf.Tensor
            trailing dimensions: shape + (event_size,)
        conditional_input : tf.Tensor
            trailing dimensions: shape + (conditional_event_size,)

        Returns
        -------
        tf.Tensor
            per-chain inverse log det Jacobians
            trailing dimensions: shape
        """
        # one flat event and one flat conditional input per chain
        xs = self.event_splitter.forward(
            self.event_reshaper.forward(chain_input)
        )
        cis = self.conditional_event_splitter.forward(
            self.conditional_event_reshaper.forward(conditional_input)
        )

        ildjs = [
            bij.inverse_log_det_jacobian(
                chain_input=x,
                conditional_input=ci,
                event_ndims=1
            )
            for bij, x, ci in zip(
                self.bijectors,
                xs,
                cis
            )
        ]

        # stack on the last axis so that the per-chain values form the
        # trailing dimension whatever the number of leading dimensions
        # (identical to axis=1 when there is exactly one leading dim)
        ildj = tf.stack(ildjs, axis=-1)

        return self.ildj_reshaper.inverse(ildj)
