from typing import (
    Dict,
    Literal,
    Tuple,
    List,
    Optional,
    Union
)
from collections import defaultdict
from copy import copy
import gc

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp

from ...set_transformer.layers import (
    RFF
)
from ...normalizing_flow.bijectors import (
    ConditionalNFChain,
    ConditionalBijectorArray
)
from ..graph import get_inference_order
from ...utils import (
    repeat_to_shape,
    get_ranges,
    TFWithoutReplacementSampler,
    TFFrozenSampler
)

tfd = tfp.distributions
tfb = tfp.bijectors
Root = tfd.JointDistributionCoroutine.Root


class PAVFFamily(tf.keras.Model):
    """Implementation of the PAVI-F architecture
    """

    def __init__(
        self,
        full_hbm: tfd.JointDistributionNamed,
        reduced_hbm: tfd.JointDistributionNamed,
        plates_per_rv: Dict[str, List[str]],
        link_functions: Dict[str, tfb.Bijector],
        posterior_schemes_kwargs: Dict[
            str,
            Tuple[
                Literal[
                    'observed',
                    'flow',
                    'non-parametric',
                    'mean-field',
                    'ASVI',
                    'individual_flow'
                ],
                Dict
            ]
        ],
        encoding_sizes: Dict[
            Tuple[str, ...], int
        ],
        frozen_plates: List[str] = [],
        **kwargs
    ):
        """PAVI-F variational family, using free encodings

        Stores the configuration, then analyses the HBM graph and builds
        the architecture (see `analyse_generative_hbm_graph` and
        `build_architecture`).

        Parameters
        ----------
        full_hbm : tfd.JointDistributionNamed
            HBM instantiated with full cardinalities
        reduced_hbm : tfd.JointDistributionNamed
            HBM instantiated with reduced cardinalities
        plates_per_rv : Dict[str, List[str]]
            dict {rv: ['P0', 'P1']}
        link_functions : Dict[str, tfb.Bijector]
            functions projecting the RV's event space into an unbounded real space
        posterior_schemes_kwargs : Dict[
            str,
            Tuple[
                Literal[
                    'observed',
                    'flow',
                    'non-parametric',
                    'mean-field',
                    'ASVI',
                    'individual_flow'
                ],
                Dict
            ]
        ]
            for each RV, specify a posterior scheme, describing
            how we intend to approximate the RV's posterior, and the
            kwargs associated to the method
        encoding_sizes : Dict[ Tuple[str, ...], int ]
            describe the encoding size associated to each plate level
        frozen_plates : List[str], optional
            not sampled plates, by default []
        **kwargs
            forwarded to the `tf.keras.Model` constructor
        """
        super().__init__(**kwargs)

        self.hbms = {
            "reduced": reduced_hbm,
            "full": full_hbm
        }

        self.plates_per_rv = plates_per_rv

        self.link_functions = link_functions
        self.encoding_sizes = encoding_sizes

        # Split the {rv: (scheme, kwargs)} input into two parallel dicts
        self.posterior_schemes, self.posterior_kwargs = {}, {}
        for rv, (posterior_scheme, posterior_kwargs) in (
            posterior_schemes_kwargs.items()
        ):
            self.posterior_schemes[rv] = posterior_scheme
            self.posterior_kwargs[rv] = posterior_kwargs

        # Keep a private copy: the `[]` default is a single list object
        # shared by every call, and a caller-supplied list would otherwise
        # stay aliased with this instance's state.
        self.frozen_plates = list(frozen_plates)

        self.analyse_generative_hbm_graph()
        self.build_architecture()

    def analyse_generative_hbm_graph(self) -> None:
        """Derive graph- and shape-related bookkeeping from both HBMs

        Sets on the instance: the RV collections (`all_rvs`,
        `observed_rvs`, `latent_rvs`), the parents of every RV, the RVs
        grouped per plate tuple, the inference order over plate tuples,
        the batch/event shapes, the per-plate cardinalities, the full
        batch shape per plate tuple of non-observed RVs and the per-RV
        log-prob scaling factors.
        """
        full_hbm = self.hbms["full"]

        rv_names = full_hbm.model.keys()
        self.all_rvs = rv_names

        self.observed_rvs = [
            rv_name
            for rv_name, scheme in self.posterior_schemes.items()
            if scheme == "observed"
        ]
        self.latent_rvs = [
            rv_name
            for rv_name in rv_names
            if rv_name not in self.observed_rvs
        ]

        # Walk the dependency graph once, recording the parents of each
        # RV, the RVs living on each plate tuple, and which *other* plate
        # tuples each plate tuple depends on
        dependency_graph = full_hbm.resolve_graph()
        self.parents = {}
        self.rvs_per_plates = defaultdict(list)
        upstream_plates = {}
        for child, child_parents in dependency_graph:
            self.parents[child] = list(child_parents)
            child_plates = tuple(self.plates_per_rv[child])
            self.rvs_per_plates[child_plates].append(child)

            upstream = upstream_plates.setdefault(child_plates, set())
            for parent in child_parents:
                parent_plates = tuple(self.plates_per_rv[parent])
                if parent_plates != child_plates:
                    upstream.add(parent_plates)

        self.plates_order = get_inference_order(
            tuple(
                (plates, tuple(plates_parents))
                for plates, plates_parents in upstream_plates.items()
            )
        )

        # The leading `len(plates)` dims of an RV's shape are its batch
        # (plate) dims, the remaining ones form its event shape
        batch_shapes, event_shapes = {}, {}
        for hbm_type, hbm in self.hbms.items():
            rv_shapes = hbm.event_shape
            batch_shapes[hbm_type] = {}
            event_shapes[hbm_type] = {}
            for rv_name in rv_names:
                split_at = len(self.plates_per_rv[rv_name])
                batch_shapes[hbm_type][rv_name] = rv_shapes[rv_name][:split_at]
                event_shapes[hbm_type][rv_name] = rv_shapes[rv_name][split_at:]

        for rv in rv_names:
            assert event_shapes["full"][rv] == event_shapes["reduced"][rv], (
                f"Inconsistent event shape for rv {rv}: "
                f"Shape {event_shapes['full'][rv]} in full HBM vs "
                f"Shape {event_shapes['reduced'][rv]} in reduced HBM."
            )

        # Read each plate's cardinality off the first RV declaring it,
        # and check every later RV agrees
        cardinalities = {
            hbm_type: {}
            for hbm_type in self.hbms.keys()
        }
        for rv in rv_names:
            for plate_axis, plate in enumerate(self.plates_per_rv[rv]):
                card = {
                    hbm_type: shapes[rv][plate_axis]
                    for hbm_type, shapes in batch_shapes.items()
                }
                first_occurrence = plate not in cardinalities["full"].keys()
                for hbm_type in cardinalities.keys():
                    if first_occurrence:
                        cardinalities[hbm_type][plate] = card[hbm_type]
                        continue
                    assert cardinalities[hbm_type][plate] == card[hbm_type], (
                        f"Problem in plates declaration ({hbm_type} HBM): "
                        f"inconsistent cardinality for plate {plate} "
                        f"{card[hbm_type]} (observed for RV {rv}) "
                        f"vs {cardinalities[hbm_type][plate]} (parent RV)."
                    )

        self.event_shapes = event_shapes["full"]
        self.batch_shapes = batch_shapes
        self.cardinalities = cardinalities

        # Frozen plates must have the same cardinality in both HBMs
        for plate in self.frozen_plates:
            full_card = cardinalities["full"][plate]
            reduced_card = cardinalities["reduced"][plate]
            assert full_card == reduced_card, (
                f"Problem with frozen plate {plate}: "
                f"full HBM cardinality ({full_card}) should be equal "
                f"to reduced HBM cardinality ({reduced_card})."
            )

        self.plates_full_batch_shapes = {}
        for rv, plates in self.plates_per_rv.items():
            if rv not in self.observed_rvs:
                self.plates_full_batch_shapes[tuple(plates)] = (
                    batch_shapes["full"][rv]
                )

        # Per-RV factor: product over the RV's plates of the
        # full / reduced cardinality ratio
        reduced_cards = cardinalities["reduced"]
        plate_ratios = {
            plate: full_card / reduced_cards[plate]
            for plate, full_card in cardinalities["full"].items()
        }
        self.log_prob_factors = {
            rv: tf.reduce_prod(
                [plate_ratios[plate] for plate in plates]
            )
            for rv, plates in self.plates_per_rv.items()
        }

    def build_architecture(self) -> None:
        """Create the trainable encodings and the per-RV estimators

        One encoding array is created per plate tuple found in
        `self.plates_full_batch_shapes`. Then, for every RV, a formatter
        bijector (link function chained with a reshape) is stored, and,
        depending on the RV's posterior scheme, the matching RFFs,
        conditional flow or bijector array is instantiated.

        Raises
        ------
        NotImplementedError
            for an `ASVI` RV without parents, or an unknown
            posterior scheme
        """
        self.plate_encodings = {}
        # NOTE(review): this set is built from *every* RV in
        # `plates_per_rv`, so the membership test below does not filter
        # out any key of `plates_full_batch_shapes`.
        latent_plates = {
            tuple(plates)
            for plates in self.plates_per_rv.values()
        }
        for plates, batch_shape in self.plates_full_batch_shapes.items():
            if plates not in latent_plates:
                continue
            encoding_shape = (
                batch_shape
                + (self.encoding_sizes[tuple(plates)],)
            )
            self.plate_encodings[plates] = self.add_weight(
                shape=encoding_shape,
                initializer="random_normal",
                trainable=True,
                name=f"latent_vector_{'_'.join(plates)}"
            )

        event_sizes = {}
        self.formatters = {}
        self.bijectors = {}
        self.rv_encodings = {}

        self.asvi_rffs = {}

        self.mean_field_rffs = {}
        self.mean_field_parameter_reshapers = {}

        self.bijector_arrays = {}

        for rv in self.all_rvs:
            link_function = self.link_functions[rv]
            constrained_shape = link_function.inverse_event_shape(
                self.event_shapes[rv]
            )
            event_size = tf.reduce_prod(constrained_shape)
            event_sizes[rv] = event_size

            # flat vector -> `constrained_shape` -> link function
            formatter = tfb.Chain(
                bijectors=[
                    link_function,
                    tfb.Reshape(
                        event_shape_in=(event_size,),
                        event_shape_out=constrained_shape
                    ),
                ]
            )
            self.formatters[rv] = formatter

            scheme = self.posterior_schemes[rv]
            if scheme in ["observed", "non-parametric"]:
                # nothing to build for those schemes
                continue

            plates_key = tuple(self.plates_per_rv[rv])

            if scheme == "mean-field":
                self.rv_encodings[rv] = self.plate_encodings[plates_key]

                self.mean_field_rffs[rv] = {}
                self.mean_field_parameter_reshapers[rv] = {}
                parameter_shapes = (
                    self.posterior_kwargs[rv]["parameter_shapes"]
                )
                for parameter, shape in parameter_shapes.items():
                    parameter_size = int(tf.reduce_prod(shape))
                    # one regressor per distribution parameter, with a
                    # flat output reshaped to the parameter's shape
                    self.mean_field_rffs[rv][parameter] = RFF(
                        units_per_layers=[parameter_size],
                        dense_kwargs={}
                    )
                    self.mean_field_parameter_reshapers[rv][parameter] = (
                        tfb.Reshape(
                            event_shape_in=(parameter_size,),
                            event_shape_out=shape
                        )
                    )
            elif scheme == "flow":
                self.rv_encodings[rv] = self.plate_encodings[plates_key]
                nf = ConditionalNFChain(
                    event_size=event_size.numpy(),
                    conditional_event_size=int(
                        self.encoding_sizes[plates_key]
                    ),
                    name=f"nf_{rv}",
                    **self.posterior_kwargs[rv]["conditional_nf_chain_kwargs"]
                )

                # the flow sits between the inverted formatter and
                # the formatter
                self.bijectors[rv] = tfb.Chain(
                    bijectors=[
                        formatter,
                        nf,
                        tfb.Invert(formatter)
                    ],
                    name=f"chain_{rv}"
                )
            elif scheme == "ASVI":
                rv_parents = self.parents[rv]
                if len(rv_parents) == 0:
                    raise NotImplementedError(
                        f"Provided RV {rv} with `ASVI` posterior scheme "
                        "yet the RV has no parents in the graph."
                    )

                self.rv_encodings[rv] = self.plate_encodings[plates_key]
                self.asvi_rffs[rv] = {}

                for parent in rv_parents:
                    # `alpha` has the parent's event size, `lambda` is a
                    # single sigmoid-activated unit
                    self.asvi_rffs[rv][parent] = {
                        "alpha": RFF(
                            units_per_layers=[event_sizes[parent]],
                            dense_kwargs={}
                        ),
                        "lambda": RFF(
                            units_per_layers=[1],
                            dense_kwargs={"activation": "sigmoid"}
                        )
                    }
            elif scheme == "individual_flow":
                self.rv_encodings[rv] = self.plate_encodings[plates_key]
                nf_chain_kwargs = (
                    self.posterior_kwargs[rv]["conditional_nf_chain_kwargs"]
                )
                self.bijector_arrays[rv] = ConditionalBijectorArray(
                    conditional_nf_chain_kwargs=nf_chain_kwargs,
                    event_size=int(event_size.numpy()),
                    conditional_event_size=int(
                        self.encoding_sizes[plates_key]
                    ),
                    shape=self.batch_shapes["full"][rv],
                    name=f"bijector_array_{rv}"
                )
            else:
                raise NotImplementedError(
                    f"Provided RV {rv} with unknown posterior scheme "
                    f"`{scheme}`."
                )

    def get_indices_from_plate_idxs(
        self,
        rv: str,
        plate_idxs: Dict[str, List[int]]
    ) -> tf.Tensor:
        """Build an N-d indexing tensor out of per-plate index lists

        The index lists of the RV's plates (taken in the order of
        `self.plates_per_rv[rv]`) are combined with an 'ij' meshgrid,
        and the resulting grids are stacked along a new last axis.

        Parameters
        ----------
        rv : str
            RV whose plates determine which index lists are used
        plate_idxs : Dict[str, List[int]]
            {plate: indices} selected plate indices

        Returns
        -------
        tf.Tensor
            indexing tensor
        """
        per_plate_indices = [
            plate_idxs[plate]
            for plate in self.plates_per_rv[rv]
        ]
        index_grids = tf.meshgrid(*per_plate_indices, indexing='ij')
        return tf.stack(index_grids, axis=-1)

    def gather_from_tensor(
        self,
        rv: str,
        tensor: tf.Tensor,
        plate_idxs: Optional[Dict[str, List[int]]] = None
    ) -> tf.Tensor:
        """Select the entries of a tensor matching the plate indices

        Parameters
        ----------
        rv : str
            RV the tensor is associated to
        tensor : Union[tf.Tensor, np.ndarray]
            tensor to be sliced
        plate_idxs : Optional[Dict[str, List[int]]], optional
            {plate: indices} selected plate indices, by default None.
            When None, the tensor is returned untouched.

        Returns
        -------
        tf.Tensor
            sliced tensor
        """
        # No selection requested: hand the input back as is
        if plate_idxs is None:
            return tensor

        indices = self.get_indices_from_plate_idxs(
            rv=rv,
            plate_idxs=plate_idxs
        )
        return tf.gather_nd(params=tensor, indices=indices)

    def update_tensor(
        self,
        rv: str,
        tensor: tf.Tensor,
        updates: tf.Tensor,
        plate_idxs: Dict[str, List[int]]
    ) -> tf.Tensor:
        """Write updates into a tensor at the selected plate indices

        Parameters
        ----------
        rv : str
            RV the tensor is associated to
        tensor : Union[tf.Tensor, np.ndarray]
            tensor to be updated
        updates: tf.Tensor
            update values
        plate_idxs : Dict[str, List[int]]
            {plate: indices} selected plate indices

        Returns
        -------
        tf.Tensor
            updated tensor
        """
        indices = self.get_indices_from_plate_idxs(
            rv=rv,
            plate_idxs=plate_idxs
        )
        return tf.tensor_scatter_nd_update(
            tensor=tensor,
            indices=indices,
            updates=updates
        )

    def scatter_tensor(
        self,
        rv: str,
        shape: Tuple[int],
        updates: tf.Tensor,
        plate_idxs: Dict[str, List[int]]
    ) -> tf.Tensor:
        """Scatter updates, at the selected plate indices, into a new
        tensor of the requested shape (via `tf.scatter_nd`)

        Parameters
        ----------
        rv : str
            RV the tensor is associated to
        shape : Tuple[int]
            shape of the output tensor
        updates : tf.Tensor
            update values
        plate_idxs : Dict[str, List[int]]
            {plate: indices} selected plate indices

        Returns
        -------
        tf.Tensor
            updates scattered into 0s
        """
        indices = self.get_indices_from_plate_idxs(
            rv=rv,
            plate_idxs=plate_idxs
        )
        return tf.scatter_nd(
            indices=indices,
            updates=updates,
            shape=shape
        )

    def get_all_encodings(
        self,
        plate_idxs: Optional[Dict[str, List[int]]] = None,
        excluded_rvs: List[str] = []
    ) -> Dict[str, tf.Tensor]:
        """For every RV owning an encoding, slice its encoding array
        down to the selected plate indices

        Parameters
        ----------
        plate_idxs : Optional[Dict[str, List[int]]], optional
            {plate: indices} selected plate indices, by default None
            (the encoding arrays are then returned unsliced)
        excluded_rvs : List[str], optional
            RVs to ignore, not returning encodings, by default []
            (only read, never modified)

        Returns
        -------
        Dict[str, tf.Tensor]
            {rv: encoding}, a flat mapping with one entry per RV of
            `self.rv_encodings` that is not excluded
        """
        return {
            rv: self.gather_from_tensor(
                rv,
                latent_vector,
                plate_idxs
            )
            for rv, latent_vector in self.rv_encodings.items()
            if rv not in excluded_rvs
        }

    def repeat_to_plates(
        self,
        tensor: tf.Tensor,
        source_plates: List[str],
        target_plates: List[str],
        batch_ndims: int,
        hbm_type: str
    ) -> tf.Tensor:
        """Broadcast a tensor over the target plates it does not have

        For each target plate absent from the source plates, a new axis
        is inserted at `batch_ndims + <position of the plate in
        target_plates>` and repeated up to that plate's cardinality.

        Parameters
        ----------
        tensor : tf.Tensor
            value to repeat
        source_plates : List[str]
            plates for the input tensor
        target_plates : List[str]
            plates for the output tensor
        batch_ndims : int
            number of leading dims of the input tensor to skip over
        hbm_type : str
            "full" or "reduced", to infer cardinalities

        Returns
        -------
        tf.Tensor
            repeated tensor
        """
        result = tensor
        for position, plate in enumerate(target_plates):
            if plate in source_plates:
                # axis already present in the input
                continue

            new_axis = batch_ndims + position
            expanded = tf.expand_dims(result, axis=new_axis)
            result = tf.repeat(
                expanded,
                repeats=self.cardinalities[hbm_type][plate],
                axis=new_axis
            )

        return result

    def sample(
        self,
        sample_shape: Tuple[int],
        observed_values: Dict[str, tf.Tensor],
        plate_idxs: Optional[Dict[str, List[int]]] = None,
        return_observed_values: bool = False,
        return_latent_values: bool = False,
        repeat_observed_rvs: bool = True,
        excluded_rvs: List[str] = []
    ) -> Union[
        Dict[str, tf.Tensor],
        Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]
    ]:
        """Draw values for the RVs, visited in the order of `self.all_rvs`

        The full HBM is used when `plate_idxs` is None, the reduced HBM
        otherwise. Observed and excluded RVs are skipped; each remaining
        RV is sampled according to its posterior scheme.

        Parameters
        ----------
        sample_shape : Tuple[int]
            shape for the sample, of length at most 1: () or (n_samples,)
        observed_values : Dict[str, tf.Tensor]
            {observed_rv: value}
        plate_idxs : Optional[Dict[str, List[int]]], optional
            {plate: indices} selected plate indices, by default None
        return_observed_values : bool, optional
            keep the observed RVs in the output, by default False
        return_latent_values : bool, optional
            also return the values drawn before application of the
            flows (`flow` and `individual_flow` schemes),
            by default False
        repeat_observed_rvs : bool, optional
            repeat the observed values to the sample shape,
            by default True
        excluded_rvs : List[str], optional
            do not sample those RVs, by default []

        Returns
        -------
        Union[ Dict[str, tf.Tensor], Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]] ]
            either {rv: value} or ({rv: value}, {rv: latent_value})

        Raises
        ------
        NotImplementedError
            if `sample_shape` has more than one dimension
        """
        sample_ndims = len(sample_shape)
        if sample_ndims > 1:
            raise NotImplementedError(
                f"Sampling with batch shape of len {sample_ndims} "
                "not implemented."
            )

        hbm_type = "reduced" if plate_idxs is not None else "full"

        latent_values = {}
        observed_rvs = list(observed_values.keys())
        ignored_rvs = observed_rvs + excluded_rvs

        if repeat_observed_rvs:
            values = {
                observed_rv: repeat_to_shape(
                    observed_value,
                    target_shape=sample_shape,
                    axis=0
                )
                for observed_rv, observed_value in observed_values.items()
            }
        else:
            values = {
                observed_rv: observed_value
                for observed_rv, observed_value in observed_values.items()
            }

        encodings = self.get_all_encodings(
            plate_idxs=plate_idxs,
            excluded_rvs=ignored_rvs
        )

        def draw_from_prior(rv_name):
            # Sample the RV from the HBM: directly for a root RV,
            # conditionally on the already-drawn parent values otherwise
            rv_parents = self.parents[rv_name]
            prior = self.hbms[hbm_type].model[rv_name]
            if len(rv_parents) == 0:
                return prior.sample(sample_shape)
            conditioning = {
                parent: values[parent]
                for parent in rv_parents
            }
            return prior(**conditioning).sample()

        def without_observed(rv_values):
            # Drop the entries belonging to observed RVs
            return {
                rv_name: value
                for rv_name, value in rv_values.items()
                if rv_name not in observed_rvs
            }

        for rv in self.all_rvs:
            if rv in ignored_rvs:
                continue

            scheme = self.posterior_schemes[rv]

            if scheme == "non-parametric":
                values[rv] = draw_from_prior(rv)

            elif scheme == "mean-field":
                # Regress each distribution parameter from the encoding,
                # reshape it, then apply its constraint
                parameters = {}
                for parameter, rff in self.mean_field_rffs[rv].items():
                    constrain = (
                        self.posterior_kwargs[rv]
                        ["parameter_constraints"]
                        [parameter]
                    )
                    reshape = (
                        self.mean_field_parameter_reshapers[rv][parameter]
                    )
                    parameters[parameter] = constrain(
                        reshape(rff(encodings[rv]))
                    )
                distribution = self.posterior_kwargs[rv]["functor"](
                    **parameters
                )
                values[rv] = distribution.sample(sample_shape)

            elif scheme == "flow":
                latent_values[rv] = draw_from_prior(rv)
                # conditioning input routed to the chain's `nf_{rv}` part
                nf_kwargs = {
                    f"nf_{rv}": {
                        "conditional_input": encodings[rv]
                    }
                }
                values[rv] = self.bijectors[rv].forward(
                    latent_values[rv],
                    **nf_kwargs
                )

            elif scheme == "ASVI":
                model_kwargs = {}
                for parent in self.parents[rv]:
                    parent_rffs = self.asvi_rffs[rv][parent]
                    alpha = parent_rffs["alpha"](encodings[rv])
                    lambd = parent_rffs["lambda"](encodings[rv])

                    formatter = self.formatters[parent]
                    parent_value = self.repeat_to_plates(
                        tensor=formatter.inverse(values[parent]),
                        source_plates=self.plates_per_rv[parent],
                        target_plates=self.plates_per_rv[rv],
                        batch_ndims=len(sample_shape),
                        hbm_type=hbm_type
                    )
                    # mix of the parent value and `alpha`, weighted
                    # by `lambd`
                    mixed = (
                        lambd * parent_value
                        + (1 - lambd) * alpha
                    )
                    model_kwargs[parent] = formatter.forward(mixed)
                distribution = self.posterior_kwargs[rv]["functor"](
                    **model_kwargs
                )
                values[rv] = distribution.sample()

            elif scheme == "individual_flow":
                latent_values[rv] = draw_from_prior(rv)

                # The bijector array acts on full-cardinality tensors:
                # with plate indices, scatter the reduced draw into a
                # full-shaped tensor first
                if plate_idxs is None:
                    full_latent_value = latent_values[rv]
                else:
                    full_shape = (
                        self.batch_shapes["full"][rv]
                        + self.event_shapes[rv]
                    )

                    def scatter_to_full(latent_value):
                        return self.scatter_tensor(
                            rv=rv,
                            shape=full_shape,
                            updates=latent_value,
                            plate_idxs=plate_idxs
                        )

                    if sample_ndims == 0:
                        full_latent_value = scatter_to_full(
                            latent_values[rv]
                        )
                    else:
                        # one scatter per sample
                        full_latent_value = tf.vectorized_map(
                            scatter_to_full,
                            latent_values[rv]
                        )

                formatter = self.formatters[rv]
                transformed = self.bijector_arrays[rv].forward(
                    formatter.inverse(full_latent_value),
                    conditional_input=self.rv_encodings[rv]
                )
                full_value = formatter.forward(transformed)

                # ...then gather back the selected entries
                if plate_idxs is None:
                    values[rv] = full_value
                else:
                    def gather_selected(value):
                        return self.gather_from_tensor(
                            rv=rv,
                            tensor=value,
                            plate_idxs=plate_idxs
                        )

                    if sample_ndims == 0:
                        values[rv] = gather_selected(full_value)
                    else:
                        values[rv] = tf.vectorized_map(
                            gather_selected,
                            full_value
                        )

        if return_observed_values:
            if return_latent_values:
                return values, latent_values
            return values

        if return_latent_values:
            return (
                without_observed(values),
                without_observed(latent_values)
            )
        return without_observed(values)

    def _MAP_regression(
        self,
        observed_values: Dict[str, tf.Tensor],
        plate_idxs: Optional[Dict[str, List[int]]] = None,
        return_observed_values: bool = False,
        excluded_rvs: List[str] = []
    ) -> Dict[str, tf.Tensor]:
        """Returns MAP point estimate from the architecture

        Parameters
        ----------
        observed_values : Dict[str, tf.Tensor]
            {observed_rv: value}
        plate_idxs : Optional[Dict[str, List[int]]], optional
            {plate: indices} selected plate indices, by default None
        return_observed_values : bool, optional
            by default False
        repeat_observed_rvs : bool, optional
            to the sample shape, by default True
        excluded_rvs : List[str], optional
            do not sample those RVs, by default []

        Returns
        -------
        Dict[str, tf.Tensor]
            {rv: value}
        """
        hbm_type = (
            "full"
            if plate_idxs is None
            else "reduced"
        )

        observed_rvs = [
            observed_rv for observed_rv in observed_values.keys()
        ]
        ignored_rvs = observed_rvs + excluded_rvs
        values = observed_values
        encodings = self.get_all_encodings(
            plate_idxs=plate_idxs,
            excluded_rvs=ignored_rvs
        )

        for rv in self.all_rvs:
            if rv in ignored_rvs:
                continue
            if self.posterior_schemes[rv] == "non-parametric":
                parents = self.parents[rv]
                if len(parents) == 0:
                    dist = self.hbms[hbm_type].model[rv]
                    try:
                        values[rv] = dist.mode()
                    except NotImplementedError:
                        try:
                            values[rv] = dist.mean()
                        except NotImplementedError:
                            values[rv] = dist.sample()
                else:
                    dist = (
                        self.hbms[hbm_type]
                        .model[rv](
                            **{
                                parent: values[parent]
                                for parent in parents
                            }
                        )
                    )
                    try:
                        values[rv] = dist.mean()
                    except NotImplementedError:
                        values[rv] = dist.sample()
            elif self.posterior_schemes[rv] == "mean-field":
                parameters = {
                    parameter: (
                        self.posterior_kwargs[rv]
                        ["parameter_constraints"]
                        [parameter](
                            self.mean_field_parameter_reshapers[rv][parameter](
                                rff(encodings[rv])
                            )
                        )
                    )
                    for parameter, rff in self.mean_field_rffs[rv].items()
                }
                dist = (
                    self.posterior_kwargs[rv]["functor"](
                        **parameters
                    )
                )

                try:
                    values[rv] = dist.mode()
                except NotImplementedError:
                    try:
                        values[rv] = dist.mean()
                    except NotImplementedError:
                        values[rv] = dist.sample()
            elif self.posterior_schemes[rv] == "flow":
                parents = self.parents[rv]
                if len(parents) == 0:
                    dist = self.hbms[hbm_type].model[rv]

                    if rv == "probs":
                        latent_value = dist.mean()
                    else:
                        try:
                            latent_value = dist.mode()
                        except NotImplementedError:
                            try:
                                latent_value = dist.mean()
                            except NotImplementedError:
                                latent_value = dist.sample()
                else:
                    dist = (
                        self.hbms[hbm_type]
                        .model[rv](
                            **{
                                parent: values[parent]
                                for parent in parents
                            }
                        )
                    )
                    try:
                        latent_value = dist.mode()
                    except NotImplementedError:
                        try:
                            latent_value = dist.mean()
                        except NotImplementedError:
                            latent_value = dist.sample()
            

                values[rv] = self.formatters[rv](
                    tfb.Shift(
                        shift=(
                            self.bijectors[rv]  # tfb.Chain
                            .bijectors[1]  # ConditionalNFChain
                            .bijectors[-1]  # ConditionalAffine
                            .shift(
                                encodings[rv]
                            )
                        )
                    ).forward(
                        self.formatters[rv]
                        .inverse(
                            latent_value
                        )
                    )
                )

        if return_observed_values:
            return values
        else:
            return {
                rv: value
                for rv, value in values.items()
                if rv not in observed_rvs
            }

    def log_prob_parts(
        self,
        values: Dict[str, tf.Tensor],
        latent_values: Optional[Dict[str, tf.Tensor]] = None,
        plate_idxs: Optional[Dict[str, List[int]]] = None
    ) -> Dict[str, tf.Tensor]:
        """Return the log prob for the conditional distributions

        One entry is computed per RV in ``self.latent_rvs``, following
        the posterior scheme registered for that RV in
        ``self.posterior_schemes``. The "full" HBM is used when
        ``plate_idxs`` is None, the "reduced" HBM otherwise.

        Parameters
        ----------
        values : Dict[str, tf.Tensor]
            {rv: value}
        latent_values : Optional[Dict[str, tf.Tensor]], optional
            {rv: latent_value} (before NF), by default None
            When None, the latent values of the "flow" and
            "individual_flow" RVs are recomputed by inverting
            their bijectors on ``values``
        plate_idxs : Optional[Dict[str, List[int]]], optional
            {plate: indices} selected plate indices, by default None

        Returns
        -------
        Dict[str, tf.Tensor]
            {rv: log_prob}

        Raises
        ------
        NotImplementedError
            if the values carry more than one sample dimension
        """
        hbm_type = (
            "full"
            if plate_idxs is None
            else "reduced"
        )

        encodings = self.get_all_encodings(plate_idxs)
        # The number of sample dimensions is inferred from the first
        # latent RV only: rank of its value minus rank of its event shape
        first_latent_rv = self.latent_rvs[0]
        batch_ndims = (
            len(values[first_latent_rv].shape)
            - len(self.hbms[hbm_type].event_shape[first_latent_rv])
        )
        if batch_ndims > 1:
            raise NotImplementedError(
                f"Sample shape of len {batch_ndims} not implemented, "
                "stick to 0 or 1"
            )

        # full_values stays None when latent values are provided by the
        # caller; the "individual_flow" branch of the main loop below
        # then rebuilds the full-sized value on its own
        full_values = None
        if latent_values is None:
            latent_values = {}
            full_values = {}
            for rv, value in values.items():
                if self.posterior_schemes[rv] == "flow":
                    # Invert the conditional flow, conditioned on
                    # the encoding of the RV
                    latent_values[rv] = self.bijectors[rv].inverse(
                        value,
                        **{
                            f"nf_{rv}": {
                                "conditional_input": encodings[rv]
                            }
                        }
                    )
                elif self.posterior_schemes[rv] == "individual_flow":
                    # The bijector array is called on a tensor of full
                    # shape: reduced values are first scattered
                    # at the selected plate indices
                    if plate_idxs is None:
                        full_values[rv] = values[rv]
                    else:
                        if batch_ndims == 0:
                            full_values[rv] = self.scatter_tensor(
                                rv=rv,
                                shape=(
                                    self.batch_shapes["full"][rv]
                                    + self.event_shapes[rv]
                                ),
                                updates=values[rv],
                                plate_idxs=plate_idxs
                            )
                        else:
                            # one scatter per sample
                            full_values[rv] = tf.vectorized_map(
                                lambda v: self.scatter_tensor(
                                    rv=rv,
                                    shape=(
                                        self.batch_shapes["full"][rv]
                                        + self.event_shapes[rv]
                                    ),
                                    updates=v,
                                    plate_idxs=plate_idxs
                                ),
                                values[rv]
                            )
                    # formatter.inverse -> bijector array inverse
                    # -> formatter.forward, conditioned on the
                    # (non-sliced) self.rv_encodings of the RV
                    full_latent_value = (
                        self.formatters[rv]
                        .forward(
                            self.bijector_arrays[rv]
                            .inverse(
                                self.formatters[rv].inverse(full_values[rv]),
                                conditional_input=self.rv_encodings[rv]
                            )
                        )
                    )

                    # Gather back the entries of the selected plate indices
                    if plate_idxs is None:
                        latent_values[rv] = full_latent_value
                    else:
                        if batch_ndims == 0:
                            latent_values[rv] = self.gather_from_tensor(
                                rv=rv,
                                tensor=full_latent_value,
                                plate_idxs=plate_idxs
                            )
                        else:
                            # one gather per sample
                            latent_values[rv] = tf.vectorized_map(
                                lambda flv: self.gather_from_tensor(
                                    rv=rv,
                                    tensor=flv,
                                    plate_idxs=plate_idxs
                                ),
                                full_latent_value
                            )

        log_prob_parts = {}
        for rv in self.latent_rvs:
            if self.posterior_schemes[rv] == "non-parametric":
                # Log prob of the HBM's own conditional distribution,
                # evaluated given the values of the parents (if any)
                parents = self.parents[rv]
                if len(parents) == 0:
                    log_prob_parts[rv] = (
                        self.hbms[hbm_type]
                        .model[rv]
                        .log_prob(
                            values[rv]
                        )
                    )
                else:
                    log_prob_parts[rv] = (
                        self.hbms[hbm_type]
                        .model[rv](
                            **{
                                parent: values[parent]
                                for parent in parents
                            }
                        )
                        .log_prob(
                            values[rv]
                        )
                    )
            elif self.posterior_schemes[rv] == "mean-field":
                # Each distribution parameter is regressed from the
                # encoding: rff -> reshaper -> parameter constraint
                parameters = {
                    parameter: (
                        self.posterior_kwargs[rv]
                        ["parameter_constraints"]
                        [parameter](
                            self.mean_field_parameter_reshapers[rv][parameter](
                                rff(encodings[rv])
                            )
                        )
                    )
                    for parameter, rff in self.mean_field_rffs[rv].items()
                }
                log_prob_parts[rv] = (
                    self.posterior_kwargs[rv]["functor"](
                        **parameters
                    )
                    .log_prob(
                        values[rv]
                    )
                )
            elif self.posterior_schemes[rv] == "flow":
                # Change of variable: log prob of the latent value under
                # the HBM conditional + inverse log det jacobian of the flow
                inv_log_det_jac = self.bijectors[rv].inverse_log_det_jacobian(
                    values[rv],
                    **{
                        f"nf_{rv}": {
                            "conditional_input": encodings[rv]
                        }
                    }
                )
                # Sum the jacobian term over all non-sample dimensions
                if batch_ndims == 0:
                    inv_log_det_jac = tf.reduce_sum(inv_log_det_jac)
                elif batch_ndims == 1:
                    inv_log_det_jac = tf.vectorized_map(
                        tf.reduce_sum,
                        inv_log_det_jac
                    )

                parents = self.parents[rv]
                if len(parents) == 0:
                    latent_log_prob_part = (
                        self.hbms[hbm_type]
                        .model[rv]
                        .log_prob(
                            latent_values[rv]
                        )
                    )
                else:
                    latent_log_prob_part = (
                        self.hbms[hbm_type]
                        .model[rv](
                            **{
                                parent: values[parent]
                                for parent in parents
                            }
                        )
                        .log_prob(
                            latent_values[rv]
                        )
                    )

                log_prob_parts[rv] = latent_log_prob_part + inv_log_det_jac
            elif self.posterior_schemes[rv] == "ASVI":
                model_kwargs = {}
                for parent in self.parents[rv]:
                    # alpha and lambda are both regressed from the
                    # encoding of the RV (unpacked from a 2-item generator)
                    alpha, lambd = (
                        self.asvi_rffs[rv][parent][key](
                            encodings[rv]
                        )
                        for key in ["alpha", "lambda"]
                    )
                    formatter = self.formatters[parent]
                    parent_value = self.repeat_to_plates(
                        tensor=formatter.inverse(values[parent]),
                        source_plates=self.plates_per_rv[parent],
                        target_plates=self.plates_per_rv[rv],
                        batch_ndims=batch_ndims,
                        hbm_type=hbm_type
                    )
                    # Interpolation between the parent's value and alpha,
                    # done in the space given by the inverse of the
                    # parent's formatter, then mapped back through it
                    model_kwargs[parent] = formatter.forward(
                        lambd * parent_value
                        + (1 - lambd) * alpha
                    )
                log_prob_parts[rv] = (
                    self.posterior_kwargs[rv]["functor"](
                        **model_kwargs
                    )
                    .log_prob(
                        values[rv]
                    )
                )
            elif self.posterior_schemes[rv] == "individual_flow":
                # Recover the full-sized value: either rebuilt here
                # (latent values were given by the caller) or reused
                # from the inversion pass above
                if full_values is None:
                    if plate_idxs is None:
                        full_value = values[rv]
                    else:
                        if batch_ndims == 0:
                            full_value = self.scatter_tensor(
                                rv=rv,
                                shape=(
                                    self.batch_shapes["full"][rv]
                                    + self.event_shapes[rv]
                                ),
                                updates=values[rv],
                                plate_idxs=plate_idxs
                            )
                        else:
                            full_value = tf.vectorized_map(
                                lambda v: self.scatter_tensor(
                                    rv=rv,
                                    shape=(
                                        self.batch_shapes["full"][rv]
                                        + self.event_shapes[rv]
                                    ),
                                    updates=v,
                                    plate_idxs=plate_idxs
                                ),
                                values[rv]
                            )
                else:
                    full_value = full_values[rv]

                # Jacobian parts are computed on the full-sized tensor...
                full_inv_log_det_jac = (
                    self.bijector_arrays[rv]
                    .inverse_log_det_jacobian_parts(
                        self.formatters[rv].inverse(full_value),
                        conditional_input=self.rv_encodings[rv],
                    )
                )
                # ...then only the parts of the selected plate
                # indices are kept
                if plate_idxs is None:
                    inv_log_det_jac = full_inv_log_det_jac
                else:
                    if batch_ndims == 0:
                        inv_log_det_jac = self.gather_from_tensor(
                            rv=rv,
                            tensor=full_inv_log_det_jac,
                            plate_idxs=plate_idxs
                        )
                    else:
                        inv_log_det_jac = tf.vectorized_map(
                            lambda fildj: self.gather_from_tensor(
                                rv=rv,
                                tensor=fildj,
                                plate_idxs=plate_idxs
                            ),
                            full_inv_log_det_jac
                        )

                # Sum the jacobian term over all non-sample dimensions
                if batch_ndims == 0:
                    inv_log_det_jac = tf.reduce_sum(inv_log_det_jac)
                elif batch_ndims == 1:
                    inv_log_det_jac = tf.vectorized_map(
                        tf.reduce_sum,
                        inv_log_det_jac
                    )

                parents = self.parents[rv]
                if len(parents) == 0:
                    latent_log_prob_part = (
                        self.hbms[hbm_type]
                        .model[rv]
                        .log_prob(
                            latent_values[rv]
                        )
                    )
                else:
                    latent_log_prob_part = (
                        self.hbms[hbm_type]
                        .model[rv](
                            **{
                                parent: values[parent]
                                for parent in parents
                            }
                        )
                        .log_prob(
                            latent_values[rv]
                        )
                    )

                log_prob_parts[rv] = latent_log_prob_part + inv_log_det_jac

        return log_prob_parts

    def log_prob(
        self,
        values: Dict[str, tf.Tensor],
        latent_values: Optional[Dict[str, tf.Tensor]] = None,
        plate_idxs: Optional[Dict[str, List[int]]] = None
    ) -> tf.Tensor:
        """Calls log_prob_parts and sums out the result

        Parameters
        ----------
        values : Dict[str, tf.Tensor]
            {rv: value}
        latent_values : Optional[Dict[str, tf.Tensor]], optional
            {rv: latent_value} (before NF), by default None
        plate_idxs : Optional[Dict[str, List[int]]], optional
            {plate: indices} selected plate indices, by default None

        Returns
        -------
        tf.Tensor
            log_prob, the sum over RVs of the log prob parts
        """
        log_prob_parts = self.log_prob_parts(
            values=values,
            latent_values=latent_values,
            plate_idxs=plate_idxs
        )
        # The parts are stacked along a new leading axis and summed over
        # it, which keeps any sample dimension intact
        return tf.reduce_sum(
            list(log_prob_parts.values()),
            axis=0
        )

    def find_mode(
        self,
        observed_values: Dict[str, tf.Tensor],
        optimizer: tf.keras.optimizers.Optimizer,
        optimizer_steps: int = 1_000
    ) -> Dict[str, np.ndarray]:
        """Search for a mode of the family by gradient ascent

        A sample of shape (1,) is drawn from the family, each of its
        values is wrapped into a trainable variable constrained by the
        RV's formatter, and the family's log prob is maximized
        with respect to those variables.

        Parameters
        ----------
        observed_values : Dict[str, tf.Tensor]
            {observed_rv: value}
        optimizer : tf.keras.optimizers.Optimizer
            used to minimize the negative log prob
        optimizer_steps : int, optional
            number of optimizer steps to perform, by default 1_000

        Returns
        -------
        Dict[str, np.ndarray]
            {rv: mode_value}
        """
        starting_point = self.sample(
            sample_shape=(1,),
            observed_values=observed_values
        )

        # One constrained variable per sampled RV, initialized at the sample
        variable_sample = {}
        for rv, value in starting_point.items():
            variable_sample[rv] = tfp.util.TransformedVariable(
                initial_value=value,
                bijector=self.formatters[rv],
                name=f"variable_{rv}"
            )

        # The optimizer acts on the underlying unconstrained variables
        optimized_variables = [
            transformed.trainable_variables[0]
            for transformed in variable_sample.values()
        ]

        def negative_log_prob():
            return - self.log_prob(
                values=variable_sample
            )

        @tf.function
        def ascent_step():
            optimizer.minimize(
                loss=negative_log_prob,
                var_list=optimized_variables
            )

        for _ in range(optimizer_steps):
            ascent_step()

        modes = {}
        for rv, transformed in variable_sample.items():
            modes[rv] = transformed.numpy()
        return modes

    def _rec_sample_aggregate(
        self,
        aggregate_sample: Dict[str, tf.Tensor],
        sample_shape: Tuple[int],
        current_plates_order: List[List[str]],
        current_plate_idxs: Dict[str, List[int]],
        current_excluded_rvs: List[str],
        observed_rvs: List[str]
    ) -> Dict[str, tf.Tensor]:
        """Explores the graph, sampling from reduced model and filling
        a sample from the full HBM

        The first entry of ``current_plates_order`` is processed at
        each call. Every plate of that entry that has no indices yet is
        swept over, one batch of indices at a time. Once a single new
        plate remains, the reduced model is sampled for each batch of
        indices, the result is written into the aggregate sample, and
        the exploration continues with the remaining entries.

        Parameters
        ----------
        aggregate_sample : Dict[str, tf.Tensor]
            ongoing sample being filled
        sample_shape : Tuple[int]
            (n_samples,)
        current_plates_order : List[List[str]]
            plates to explore next
        current_plate_idxs : Dict[str, List[int]]
            indices to explore in plates
        current_excluded_rvs : List[str]
            RVs to ignore at this step, for instance child
        observed_rvs : List[str]
            RVs with observed values

        Returns
        -------
        Dict[str, tf.Tensor]
            {rv: value} updated aggregate sample
        """

        if len(current_plates_order) == 0:
            return aggregate_sample

        current_plates, next_plates_order = (
            current_plates_order[0],
            current_plates_order[1:]
        )
        new_plates = [
            plate
            for plate in current_plates
            if plate not in current_plate_idxs
        ]
        if len(new_plates) == 0:
            # No new plate so nothing to do
            return self._rec_sample_aggregate(
                aggregate_sample=aggregate_sample,
                sample_shape=sample_shape,
                current_plates_order=next_plates_order,
                current_plate_idxs=current_plate_idxs,
                current_excluded_rvs=current_excluded_rvs,
                observed_rvs=observed_rvs
            )

        plate = new_plates[0]
        # Work on copies: the caller's indices and exclusion list
        # must not be altered by this level of the recursion
        plate_idxs = copy(current_plate_idxs)
        excluded_rvs = copy(current_excluded_rvs)
        rvs_in_plates = list(self.rvs_per_plates[current_plates])
        # The RVs living in the current plates stop being excluded
        for rv in rvs_in_plates:
            excluded_rvs.remove(rv)

        is_last_new_plate = len(new_plates) == 1
        has_sample_dim = len(sample_shape) == 1

        for plate_idx in get_ranges(
            full_size=self.cardinalities["full"][plate],
            batch_size=self.cardinalities["reduced"][plate]
        ):
            plate_idxs[plate] = plate_idx

            if not is_last_new_plate:
                # there are still plates to exhaust
                # The recursive call handles the same plates and removes
                # their RVs from its own copy of the exclusion list, so
                # it must receive the list as it was before that removal
                # (removing an RV twice raises ValueError)
                aggregate_sample = self._rec_sample_aggregate(
                    aggregate_sample=aggregate_sample,
                    sample_shape=sample_shape,
                    current_plates_order=current_plates_order,
                    current_plate_idxs=plate_idxs,
                    current_excluded_rvs=current_excluded_rvs,
                    observed_rvs=observed_rvs
                )
                continue

            # We exhausted all the required plates
            # Values conditioning the reduced sample: RVs that were not
            # excluded on entry, plus the observed RVs of the current
            # plates, all sliced at the selected indices
            values = self.sample(
                sample_shape=sample_shape,
                observed_values={
                    rv: (
                        tf.vectorized_map(
                            lambda t: self.gather_from_tensor(
                                rv=rv,
                                tensor=t,
                                plate_idxs=plate_idxs
                            ),
                            value
                        )
                        if has_sample_dim
                        else
                        self.gather_from_tensor(
                            rv=rv,
                            tensor=value,
                            plate_idxs=plate_idxs
                        )
                    )
                    for rv, value in aggregate_sample.items()
                    if (
                        (rv not in current_excluded_rvs)
                        or
                        (
                            (rv in rvs_in_plates)
                            and
                            (rv in observed_rvs)
                        )
                    )
                },
                plate_idxs=plate_idxs,
                excluded_rvs=excluded_rvs,
                repeat_observed_rvs=False
            )

            # Write the freshly sampled slices into the aggregate sample
            # (observed RVs already hold their values)
            for rv in rvs_in_plates:
                if rv in observed_rvs:
                    continue
                aggregate_sample[rv] = (
                    tf.vectorized_map(
                        lambda args: self.update_tensor(
                            rv=rv,
                            tensor=args[0],
                            updates=args[1],
                            plate_idxs=plate_idxs
                        ),
                        (
                            aggregate_sample[rv],
                            values[rv]
                        )
                    )
                    if has_sample_dim
                    else
                    self.update_tensor(
                        rv=rv,
                        tensor=aggregate_sample[rv],
                        updates=values[rv],
                        plate_idxs=plate_idxs
                    )
                )
            # Release the reduced sample before going deeper
            del values
            gc.collect()

            aggregate_sample = self._rec_sample_aggregate(
                aggregate_sample=aggregate_sample,
                sample_shape=sample_shape,
                current_plates_order=next_plates_order,
                current_plate_idxs=plate_idxs,
                current_excluded_rvs=excluded_rvs,
                observed_rvs=observed_rvs
            )
        return aggregate_sample

    def sample_aggregate(
        self,
        sample_shape: Tuple[int],
        observed_values: Dict[str, tf.Tensor]
    ) -> Dict[str, tf.Tensor]:
        """Build a sample of the full HBM out of reduced HBM samples

        The sample starts as NaN-filled tensors (observed RVs excepted,
        which are set to their observed value) and is filled by the
        recursive exploration of ``self.plates_order``.

        Parameters
        ----------
        sample_shape : Tuple[int]
            (n_samples,)
        observed_values : Dict[str, tf.Tensor]
            {observed_rv: value}

        Returns
        -------
        Dict[str, tf.Tensor]
            {rv: value}

        Raises
        ------
        NotImplementedError
            if sample_shape has more than one dimension
        """

        if len(sample_shape) > 1:
            raise NotImplementedError(
                "Sampling with shape > 1 not implemented"
            )

        # NaN placeholders, one per RV of the full HBM
        aggregate_sample = {}
        full_event_shapes = self.hbms["full"].event_shape
        for rv, event_shape in full_event_shapes.items():
            placeholder = tf.ones(sample_shape + event_shape)
            aggregate_sample[rv] = placeholder * np.nan

        # Observed values are repeated along the sample dimension
        for observed_rv, value in observed_values.items():
            aggregate_sample[observed_rv] = repeat_to_shape(
                value,
                axis=0,
                target_shape=sample_shape
            )

        # Every RV starts excluded, no plate index is selected yet
        return self._rec_sample_aggregate(
            aggregate_sample=aggregate_sample,
            sample_shape=sample_shape,
            current_plates_order=self.plates_order,
            current_plate_idxs=dict(),
            current_excluded_rvs=list(self.all_rvs),
            observed_rvs=list(observed_values.keys())
        )

    def compile(
        self,
        train_method: Literal[
            "reverse_KL",
            "unregularized_ELBO",
            "_MAP_regression"
        ],
        n_theta_draws: int,
        pop_full_hbm: bool = False,
        **kwargs
    ) -> None:
        """Keras compilation that also stores the training
        hyper-parameters and builds one index sampler per plate

        Parameters
        ----------
        train_method : Literal[
            "reverse_KL",
            "unregularized_ELBO",
            "_MAP_regression"
        ]
            defines which loss to use during training
        n_theta_draws: int
            for Monte Carlo estimation of the ELBO
        pop_full_hbm : bool, optional
            whether to remove the "full" entry from self.hbms,
            by default False
        **kwargs
            forwarded to the parent class' compile

        Raises
        ------
        NotImplementedError
            train_method not in [
                "reverse_KL",
                "unregularized_ELBO",
                "_MAP_regression"
            ]
        """
        supported_train_methods = (
            "reverse_KL",
            "unregularized_ELBO",
            "_MAP_regression"
        )
        if train_method not in supported_train_methods:
            raise NotImplementedError(
                f"unrecognized train method {train_method}"
            )
        self.train_method = train_method
        self.n_theta_draws = n_theta_draws

        # Frozen plates get a frozen sampler, every other plate is
        # sampled without replacement, at its reduced cardinality
        plate_samplers = {}
        for plate, card in self.cardinalities["full"].items():
            if plate in self.frozen_plates:
                sampler = TFFrozenSampler(
                    values=tf.range(0, card)
                )
            else:
                sampler = TFWithoutReplacementSampler(
                    values=tf.range(0, card),
                    size=int(self.cardinalities["reduced"][plate])
                )
            plate_samplers[plate] = sampler
        self.plate_samplers = plate_samplers

        if pop_full_hbm:
            self.hbms.pop("full")

        super().compile(**kwargs)

    def weighted_sum_log_prob_parts(
        self,
        log_prob_parts: Dict[str, tf.Tensor]
    ) -> tf.Tensor:
        """Sum the log prob parts, each one being multiplied
        by the factor stored for its RV in self.log_prob_factors

        Parameters
        ----------
        log_prob_parts : Dict[str, tf.Tensor]
            unweighted log prob parts

        Returns
        -------
        tf.Tensor
            sum over RVs of the weighted log prob parts
        """
        weighted_parts = []
        for rv, log_prob_part in log_prob_parts.items():
            factor = self.log_prob_factors[rv]
            weighted_parts.append(factor * log_prob_part)

        # Summing over axis 0 only collapses the list of RVs
        return tf.reduce_sum(weighted_parts, axis=0)

    def train_step(
        self,
        data: Union[
            Tuple[tf.Tensor],
            Tuple[tf.Tensor, Dict[str, List[int]]]
        ]
    ) -> Dict:
        """keras train step to be compiled

        Parameters
        ----------
        data : Union[
            Tuple[tf.Tensor],
            Tuple[tf.Tensor, Dict[str, List[int]]]
        ]
            either (full_data_slice) or (reduced_data_slice, plate_idxs)
            With a single element, plate indices are drawn from
            self.plate_samplers and the data is sliced accordingly

        Returns
        -------
        Dict
            {loss_type: loss_value}

        Raises
        ------
        ValueError
            if data has neither 1 nor 2 elements
        """
        if len(data) == 1:
            observed_values = data[0]
            plate_idxs = {
                plate: sampler.sample()
                for plate, sampler in self.plate_samplers.items()
            }
            sliced_observed_values = {
                observed_rv: self.gather_from_tensor(
                    rv=observed_rv,
                    tensor=observed_value,
                    plate_idxs=plate_idxs
                )
                for observed_rv, observed_value in observed_values.items()
            }
        elif len(data) == 2:
            sliced_observed_values, plate_idxs = data
        else:
            # Fail early with a readable message instead of an
            # unbound local variable further down
            raise ValueError(
                "data should be (full_data,) or (reduced_data, plate_idxs), "
                f"got {len(data)} elements"
            )

        def finite_mask(tensor):
            # True where the entry is neither NaN nor infinite
            return tf.math.logical_and(
                ~tf.math.is_nan(tensor),
                ~tf.math.is_inf(tensor)
            )

        if self.train_method in [
            "reverse_KL",
            "unregularized_ELBO"
        ]:
            with tf.GradientTape() as tape:
                values, latent_values = self.sample(
                    sample_shape=(self.n_theta_draws,),
                    observed_values=sliced_observed_values,
                    plate_idxs=plate_idxs,
                    return_observed_values=True,
                    return_latent_values=True
                )
                p_log_prob_parts = (
                    self.hbms["reduced"]
                    .log_prob_parts(values)
                )
                p = self.weighted_sum_log_prob_parts(p_log_prob_parts)

                if self.train_method == "unregularized_ELBO":
                    # Non-finite draws are left out of the average
                    indicator = finite_mask(p)
                    loss = tf.reduce_mean(-p[indicator])
                elif self.train_method == "reverse_KL":
                    q_log_prob_parts = self.log_prob_parts(
                        values=values,
                        latent_values=latent_values,
                        plate_idxs=plate_idxs
                    )
                    q = self.weighted_sum_log_prob_parts(q_log_prob_parts)
                    # Keep the draws where both log probs are finite
                    indicator = tf.math.logical_and(
                        finite_mask(p),
                        finite_mask(q)
                    )
                    loss = tf.reduce_mean(q[indicator] - p[indicator])
        elif self.train_method == "_MAP_regression":
            with tf.GradientTape() as tape:
                values = self._MAP_regression(
                    observed_values=sliced_observed_values,
                    plate_idxs=plate_idxs,
                    return_observed_values=True
                )

                p_log_prob_parts = (
                    self.hbms["reduced"]
                    .log_prob_parts(values)
                )
                p = self.weighted_sum_log_prob_parts(p_log_prob_parts)
                # Non-finite entries contribute a zero loss
                loss = tf.where(
                    finite_mask(p),
                    -p,
                    0.
                )

        # Extend a local copy, so that the container returned by
        # self.trainable_variables is never modified
        trainable_vars = list(self.trainable_variables)
        for bijector in self.bijectors.values():
            trainable_vars.extend(bijector.trainable_variables)
        for bijector_array in self.bijector_arrays.values():
            trainable_vars.extend(bijector_array.trainable_variables)

        gradients = tape.gradient(loss, trainable_vars)
        # NaN gradient entries are replaced by zeros before the update,
        # missing gradients (None) are passed through unchanged
        self.optimizer.apply_gradients(
            zip(
                (
                    tf.where(
                        tf.math.is_nan(
                            grad
                        ),
                        tf.zeros_like(grad),
                        grad
                    )
                    if grad is not None
                    else None
                    for grad in gradients
                ),
                trainable_vars
            )
        )

        return {self.train_method: loss}

    def full_train_step(
        self,
        data: Tuple[Dict[str, tf.Tensor]]
    ) -> Dict:
        """Keras train step, not performing stochastic training

        Draws ``self.n_theta_draws`` samples from the variational family
        conditioned on the observed values, scores them under
        ``self.hbms["full"]``, and applies a single optimizer update.

        Parameters
        ----------
        data : Tuple[Dict]
            data from the generative_hbm
            only the first element (the observed values) is used

        Returns
        -------
        Dict
            {loss_type: value} depending on
            the train method

        Raises
        ------
        ValueError
            if ``self.train_method`` is neither ``"reverse_KL"``
            nor ``"unregularized_ELBO"``
        """
        supported_methods = ("reverse_KL", "unregularized_ELBO")
        # Fail early with a readable message: without this check, an
        # unsupported method would leave `tape` and `loss` unbound and
        # surface later as an UnboundLocalError.
        if self.train_method not in supported_methods:
            raise ValueError(
                "full_train_step does not support train_method="
                f"{self.train_method!r}; "
                f"expected one of {supported_methods}"
            )

        observed_values = data[0]

        with tf.GradientTape() as tape:
            values, latent_values = self.sample(
                sample_shape=(self.n_theta_draws,),
                observed_values=observed_values,
                return_observed_values=True,
                return_latent_values=True
            )
            p = (
                self.hbms["full"]
                .log_prob(values)
            )

            if self.train_method == "unregularized_ELBO":
                loss = tf.reduce_mean(-p)
            else:
                # reverse_KL: also needs the variational log-density
                q = self.log_prob(
                    values=values,
                    latent_values=latent_values
                )
                loss = tf.reduce_mean(q - p)

        # Work on an explicit copy so the extension below can never mutate
        # a list owned by the model.
        trainable_vars = list(self.trainable_variables)
        for bijector in self.bijectors.values():
            trainable_vars.extend(bijector.trainable_variables)
        for bijector_array in self.bijector_arrays.values():
            trainable_vars.extend(bijector_array.trainable_variables)

        gradients = tape.gradient(loss, trainable_vars)

        # Replace NaN gradient entries by zeros; variables without a
        # gradient keep their `None` placeholder.
        sanitized_gradients = [
            None
            if grad is None
            else tf.where(
                tf.math.is_nan(grad),
                tf.zeros_like(grad),
                grad
            )
            for grad in gradients
        ]
        self.optimizer.apply_gradients(
            zip(sanitized_gradients, trainable_vars)
        )

        return {self.train_method: loss}

    def save_weights_np(
        self,
        path: str
    ) -> None:
        """Saves weights to path in numpy format

        The list returned by ``self.get_weights()`` is stored as a 1-D
        object array with one entry per weight, written with pickling
        enabled.

        Parameters
        ----------
        path : str
            destination file; ``np.save`` appends a ``.npy``
            extension when the name does not already have one
        """
        weights = self.get_weights()

        # Build the object array explicitly: letting NumPy coerce a list of
        # differently-shaped arrays raises a ValueError on recent NumPy
        # versions, and silently stacks the weights when shapes all match.
        packed_weights = np.empty(len(weights), dtype=object)
        for index, weight in enumerate(weights):
            packed_weights[index] = weight

        np.save(
            path,
            packed_weights,
            allow_pickle=True
        )

    def load_weights_np(
        self,
        path: str
    ) -> None:
        """Loads weights from path in numpy format

        Parameters
        ----------
        path : str
            file to read; when it does not exist but ``path + ".npy"``
            does, the latter is read instead, since ``np.save`` appends
            that extension to names lacking it
        """
        import os

        if isinstance(path, (str, os.PathLike)):
            candidate = os.fsdecode(path)
            if (
                not os.path.exists(candidate)
                and os.path.exists(candidate + ".npy")
            ):
                path = candidate + ".npy"

        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # weight files that come from a trusted source.
        weights = np.load(
            path,
            allow_pickle=True
        )
        # Hand the weights over as a plain list, one entry per weight.
        self.set_weights(list(weights))
