from typing import (
    Dict,
    Literal,
    Tuple,
    List,
    Optional,
    Union,
    Any
)
from collections import defaultdict
from functools import partial

import networkx as nx
import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp

from networkx.algorithms import tree
from ...set_transformer.layers import (
    RFF
)
from ...set_transformer.models import (
    SetTransformer
)
from ...normalizing_flow.bijectors import (
    ConditionalNFChain
)
from ..graph import get_inference_order
from ...utils import (
    repeat_to_shape,
)

# Shorthand aliases for the TensorFlow Probability sub-modules used below.
tfd, tfb = tfp.distributions, tfp.bijectors
# Alias of `tfd.JointDistributionCoroutine.Root`.
Root = tfd.JointDistributionCoroutine.Root


class FrozenPartial:
    """Callable binding a fixed set of keyword arguments to a function.

    Calling the instance with positional arguments forwards them to
    ``func`` together with the frozen keyword arguments::

        FrozenPartial(f, {"k": v})(a, b) == f(a, b, k=v)

    Only keyword arguments can be frozen, and no extra keyword arguments
    are accepted at call time.
    """

    def __init__(
        self,
        func,
        frozen_kwargs: Dict[str, Any]
    ) -> None:
        """
        Parameters
        ----------
        func : Callable
            function to call
        frozen_kwargs : Dict[str, Any]
            keyword arguments passed to ``func`` on every call
        """
        self.func = func
        # Shallow copy: mutating the caller's dict afterwards must not
        # change the arguments that were frozen here.
        self.frozen_kwargs = dict(frozen_kwargs)

    def __call__(
        self,
        *variable_args: Any
    ) -> Any:
        """Call ``func(*variable_args, **frozen_kwargs)`` and return its result."""
        return self.func(
            *variable_args,
            **self.frozen_kwargs
        )


class PAVEFamily(tf.keras.Model):
    """Implementation of the PAVI-E architecture
    """

    def __init__(
        self,
        full_hbm: tfd.JointDistributionNamed,
        reduced_hbm: tfd.JointDistributionNamed,
        plates_per_rv: Dict[str, List[str]],
        link_functions: Dict[str, tfb.Bijector],
        posterior_schemes_kwargs: Dict[
            str,
            Tuple[
                Literal[
                    'observed',
                    'flow',
                    'non-parametric',
                    'mean-field',
                    'ASVI'
                ],
                Dict
            ]
        ],
        encoding_sizes: Dict[
            Tuple[str, ...], int
        ],
        embedder_rff_kwargs: Dict,
        set_transformer_kwargs: Dict,
        expansion_encoding_sizes: Optional[Dict[
            str, int
        ]] = None,
        frozen_plates: Optional[List[str]] = None,
        **kwargs
    ):
        """PAVI-E variational family, using an encoder of the observed data

        Parameters
        ----------
        full_hbm : tfd.JointDistributionNamed
            HBM instantiated with full cardinalities
        reduced_hbm : tfd.JointDistributionNamed
            HBM instantiated with reduced cardinalities
        plates_per_rv : Dict[str, List[str]]
            dict {rv: ['P0', 'P1']}, listing the plates of each RV in the
            order of its leading (batch) axes
        link_functions : Dict[str, tfb.Bijector]
            per-RV bijector between an unbounded real space and the RV's
            event space
        posterior_schemes_kwargs : Dict[str, Tuple[Literal[
                'observed', 'flow', 'non-parametric', 'mean-field', 'ASVI'
            ], Dict]]
            for each RV, a (posterior_scheme, kwargs) pair describing
            how we intend to approximate the RV's posterior, and the
            kwargs associated to the method
        encoding_sizes : Dict[Tuple[str, ...], int]
            encoding size associated to each plate level
        embedder_rff_kwargs : Dict
            RFF kwargs for the embedder of observed RV values
        set_transformer_kwargs : Dict
            SetTransformer kwargs for encoder plate contraction
        expansion_encoding_sizes : Optional[Dict[str, int]], optional
            when a parent plate level contains a plate absent from its
            child plate level, size of the learned encoding for that
            plate, by default None (no expansion encoding sizes)
        frozen_plates : Optional[List[str]], optional
            plates that are not sub-sampled: their cardinality must be the
            same in the full and reduced HBMs, by default None (no frozen
            plates)
        **kwargs
            forwarded to ``tf.keras.Model``
        """
        super().__init__(**kwargs)

        self.hbms = {
            "reduced": reduced_hbm,
            "full": full_hbm
        }

        self.plates_per_rv = plates_per_rv

        self.link_functions = link_functions
        self.encoding_sizes = encoding_sizes
        # `None` defaults avoid sharing one mutable default object between
        # all instances of the class.
        self.expansion_encoding_sizes = (
            {}
            if expansion_encoding_sizes is None
            else expansion_encoding_sizes
        )
        self.posterior_schemes, self.posterior_kwargs = {}, {}
        for rv, (posterior_scheme, posterior_kwargs) in (
            posterior_schemes_kwargs.items()
        ):
            self.posterior_schemes[rv] = posterior_scheme
            self.posterior_kwargs[rv] = posterior_kwargs

        self.frozen_plates = [] if frozen_plates is None else frozen_plates

        self.analyse_generative_hbm_graph()
        self.build_architecture(
            set_transformer_kwargs=set_transformer_kwargs,
            embedder_rff_kwargs=embedder_rff_kwargs
        )

    def _rec_collect_observed_rv_dependencies(
        self,
        observed_rv: str,
        current_rv: str
    ) -> None:
        """Recursively walk up the ancestry of ``current_rv`` and record
        that ``observed_rv`` depends on every ancestor.

        For each ancestor ``parent``, the pair
        ``(observed_rv, tuple(self.plates_per_rv[parent]))`` is added to
        ``self.observed_rv_dependencies[parent]``.

        If an ancestor already holds that pair, it was already visited for
        this observed rv, and its own ancestry has been recorded (the RV
        graph is acyclic). The walk therefore stops there, which avoids
        re-exploring shared ancestors without changing the result.

        Parameters
        ----------
        observed_rv : str
            current observed rv, the ancestry of which
            we try to determine
        current_rv : str
            current explored rv in the recursion
        """
        for parent in self.parents[current_rv]:
            dependency = (observed_rv, tuple(self.plates_per_rv[parent]))
            parent_dependencies = self.observed_rv_dependencies[parent]
            if dependency in parent_dependencies:
                # Already explored from another path for this observed rv.
                continue
            parent_dependencies.add(dependency)
            self._rec_collect_observed_rv_dependencies(
                observed_rv=observed_rv,
                current_rv=parent
            )

    def analyse_generative_hbm_graph(self) -> None:
        """Analyse the full and reduced HBMs and cache graph/shape metadata.

        Attributes set
        --------------
        all_rvs, observed_rvs, latent_rvs
            all RVs of the full HBM, split into those with the "observed"
            posterior scheme and the others
        parents : Dict[str, List[str]]
            parents of each RV, from the full HBM's resolved graph
        rvs_per_plates : Dict[Tuple[str, ...], List[str]]
            RVs grouped by plate level
        parent_plates : Dict[Tuple[str, ...], List[Tuple[str, ...]]]
            parent plate levels of each plate level, taken from a maximum
            branching of the plate-level dependency graph
        plates_order
            plate-level ordering as returned by ``get_inference_order``
        observed_rv_dependencies
            for each RV, the set of (observed_rv, plates of the RV) pairs
            for every observed RV descending from it
        event_shapes, batch_shapes, cardinalities
            event shapes (full HBM), and per-HBM-type batch shapes and
            plate cardinalities
        plates_full_batch_shapes, log_prob_factors
            full-HBM batch shape per plate level of non-observed RVs, and
            per-RV product of the full/reduced cardinality ratios of its
            plates

        Raises
        ------
        AssertionError
            if an RV's event shape differs between the full and reduced
            HBMs, if a plate has inconsistent cardinalities across RVs, or
            if a frozen plate has different full and reduced cardinalities.
            These checks are raised explicitly (not with ``assert``) so they
            still run under ``python -O``.
        """
        all_rvs = self.hbms["full"].model.keys()
        self.all_rvs = all_rvs

        self.observed_rvs = [
            rv
            for rv, scheme in self.posterior_schemes.items()
            if scheme == "observed"
        ]
        self.latent_rvs = [
            rv for rv in all_rvs if rv not in self.observed_rvs
        ]

        # --- RV graph, grouped by plate level ---
        graph = self.hbms["full"].resolve_graph()
        self.parents = {}
        self.rvs_per_plates = defaultdict(list)
        for child, parents in graph:
            self.parents[child] = list(parents)
            child_p = tuple(self.plates_per_rv[child])
            self.rvs_per_plates[child_p].append(child)

        # --- plate-level graph: edges go from child to parent plate level ---
        parent_plates = {
            plates: set()
            for plates in self.rvs_per_plates
        }
        for child, parents in graph:
            child_p = tuple(self.plates_per_rv[child])
            for parent in parents:
                parent_p = tuple(self.plates_per_rv[parent])
                if child_p != parent_p:
                    parent_plates[child_p].add(parent_p)

        plate_graph = nx.from_dict_of_lists(
            parent_plates,
            create_using=nx.DiGraph
        )
        plate_branching = tree.maximum_branching(plate_graph)
        self.parent_plates = nx.to_dict_of_lists(plate_branching)
        self.plates_order = get_inference_order(
            tuple(
                (child_p, tuple(parent_ps))
                for child_p, parent_ps in self.parent_plates.items()
            )
        )

        self.observed_rv_dependencies = defaultdict(set)
        for observed_rv in self.observed_rvs:
            self._rec_collect_observed_rv_dependencies(
                observed_rv=observed_rv,
                current_rv=observed_rv
            )

        # --- batch/event shapes: the first len(plates) axes are plate axes ---
        batch_shapes = {
            hbm_type: {}
            for hbm_type in self.hbms
        }
        event_shapes = {
            hbm_type: {}
            for hbm_type in self.hbms
        }
        for hbm_type, hbm in self.hbms.items():
            total_shape = hbm.event_shape
            for rv in all_rvs:
                n_plates = len(self.plates_per_rv[rv])
                batch_shapes[hbm_type][rv] = total_shape[rv][:n_plates]
                event_shapes[hbm_type][rv] = total_shape[rv][n_plates:]
        for rv in all_rvs:
            # `not (a == b)` keeps the `==` comparison TensorShape defines.
            if not (event_shapes["full"][rv] == event_shapes["reduced"][rv]):
                raise AssertionError(
                    f"Inconsistent event shape for rv {rv}: "
                    f"Shape {event_shapes['full'][rv]} in full HBM vs "
                    f"Shape {event_shapes['reduced'][rv]} in reduced HBM."
                )

        # --- plate cardinalities, checked for consistency across RVs ---
        cardinalities = {
            hbm_type: {}
            for hbm_type in self.hbms
        }

        for rv in all_rvs:
            for plate_idx, plate in enumerate(self.plates_per_rv[rv]):
                card = {
                    hbm_type: batch_shape[rv][plate_idx]
                    for hbm_type, batch_shape in batch_shapes.items()
                }
                if plate not in cardinalities["full"]:
                    for hbm_type in cardinalities:
                        cardinalities[hbm_type][plate] = card[hbm_type]
                else:
                    for hbm_type in cardinalities:
                        if not (cardinalities[hbm_type][plate] == card[hbm_type]):
                            raise AssertionError(
                                f"Problem in plates declaration ({hbm_type} HBM): "
                                f"inconsistent cardinality for plate {plate} "
                                f"{card[hbm_type]} (observed for RV {rv}) "
                                f"vs {cardinalities[hbm_type][plate]} (parent RV)."
                            )

        self.event_shapes = event_shapes["full"]
        self.batch_shapes = batch_shapes
        self.cardinalities = cardinalities

        for plate in self.frozen_plates:
            full_card = cardinalities["full"][plate]
            reduced_card = cardinalities["reduced"][plate]
            if not (full_card == reduced_card):
                raise AssertionError(
                    f"Problem with frozen plate {plate}: "
                    f"full HBM cardinality ({full_card}) should be equal "
                    f"to reduced HBM cardinality ({reduced_card})."
                )

        self.plates_full_batch_shapes = {}
        for rv, plates in self.plates_per_rv.items():
            if rv in self.observed_rvs:
                continue
            self.plates_full_batch_shapes[tuple(plates)] = batch_shapes["full"][rv]

        # Per-plate full/reduced cardinality ratio; each RV's factor is the
        # product over its plates (1 for an RV without plates).
        plate_ratios = {
            plate: full_card / cardinalities["reduced"][plate]
            for plate, full_card in cardinalities["full"].items()
        }
        self.log_prob_factors = {
            rv: tf.reduce_prod(
                [plate_ratios[plate] for plate in plates]
            )
            for rv, plates in self.plates_per_rv.items()
        }

    def _rec_build_encoder(
        self,
        set_transformer_kwargs: Dict,
        current_encoding_level: Tuple[str, Tuple[str, ...]]
    ) -> None:
        """Recursively goes up the plate dependencies
        to construct encoders

        For each parent plate level of the current level, a
        ``tf.keras.Sequential`` encoder is stored in
        ``self.encoders[(observed_rv, parent_plates)]``. Its layers are:

        1. for each plate of the current level absent from the parent level
           (last to first): a SetTransformer attending over that plate's
           axis, followed by a squeeze of that axis;
        2. for each plate of the parent level absent after contraction: a
           concatenation with a learned per-plate "expansion encoding"
           weight (created once per plate and cached in
           ``self.expansion_encodings``);
        3. an RFF projecting to ``self.encoding_sizes[parent_plates]``.

        The method then recurses on the parent level.

        Parameters
        ----------
        set_transformer_kwargs : Dict
            SetTransformer kwargs to contract plates
        current_encoding_level : Tuple[str, Tuple[str, ...]]
            tuple of (observed_rv, plates)
        """
        current_observed_rv, current_plates = current_encoding_level
        for parent_plates in self.parent_plates[current_plates]:
            encoding_level = (current_observed_rv, tuple(parent_plates))
            # Plates present at the current level but not at the parent one.
            contracted_plates = [
                plate
                for plate in current_plates
                if plate not in parent_plates
            ]

            layers = []
            # Contract from the last plate to the first: squeezing a higher
            # axis leaves the lower axes' positions unchanged, so
            # `current_plates.index(...)` stays valid for the next plates.
            for contracted_plate in reversed(contracted_plates):
                # `1 +` skips one leading axis (presumably the sample axis
                # - TODO confirm).
                attention_axis = 1 + current_plates.index(contracted_plate)
                layers.append(
                    SetTransformer(
                        **set_transformer_kwargs,
                        attention_axes=(attention_axis,)
                    )
                )
                # Assumes the SetTransformer output keeps a singleton
                # dimension on the attention axis - TODO confirm.
                layers.append(
                    tf.keras.layers.Lambda(
                        partial(tf.squeeze, axis=attention_axis)
                    )
                )

            plates_after_contraction = [
                plate
                for plate in current_plates
                if plate in parent_plates
            ]
            for plate_idx, parent_plate in enumerate(parent_plates):
                if parent_plate not in plates_after_contraction:
                    # One trainable expansion encoding per plate, shared by
                    # every encoder that needs to expand to that plate.
                    if parent_plate not in self.expansion_encodings.keys():
                        self.expansion_encodings[parent_plate] = self.add_weight(
                            shape=(
                                (self.cardinalities["full"][parent_plate],)
                                + (self.expansion_encoding_sizes[parent_plate],)
                            ),
                            initializer="random_normal",
                            trainable=True,
                            name=f"expansion_encoding_{parent_plate}"
                        )

                    # Concatenate (last axis) the repeated encoding with the
                    # repeated expansion encoding. The exact repetition
                    # semantics depend on `repeat_to_shape` (not visible
                    # here). FrozenPartial evaluates `target_shape`, `axis` and
                    # `expansion_encoding` now, so the lambda does not
                    # late-bind the loop variables.
                    layers.append(
                        tf.keras.layers.Lambda(
                            FrozenPartial(
                                lambda encoding, target_shape, axis, expansion_encoding: tf.concat(
                                    [
                                        repeat_to_shape(
                                            encoding,
                                            target_shape=target_shape,
                                            axis=axis
                                        ),
                                        repeat_to_shape(
                                            repeat_to_shape(
                                                expansion_encoding,
                                                target_shape=encoding.shape[:axis],
                                                axis=-3
                                            ),
                                            target_shape=encoding.shape[axis:-1],
                                            axis=-2
                                        )
                                    ],
                                    axis=-1
                                ),
                                frozen_kwargs=dict(
                                    target_shape=(self.cardinalities["full"][parent_plate],),
                                    axis=plate_idx + 1,
                                    expansion_encoding=self.expansion_encodings[parent_plate]
                                )
                            )
                        )
                    )

            # Final projection to the parent level's encoding size.
            layers.append(
                RFF(
                    units_per_layers=[self.encoding_sizes[parent_plates]],
                    dense_kwargs=dict()
                )
            )

            self.encoders[encoding_level] = tf.keras.Sequential(
                layers=layers
            )
            self._rec_build_encoder(
                set_transformer_kwargs=set_transformer_kwargs,
                current_encoding_level=encoding_level
            )

    def build_architecture(
        self,
        set_transformer_kwargs: Dict,
        embedder_rff_kwargs: Dict
    ) -> None:
        """Build architecture: encoder and conditional density estimators

        For every observed RV, an RFF embedder is created at the RV's own
        plate level, and ``_rec_build_encoder`` builds the encoders of the
        upstream plate levels.

        For every RV, a formatter is built: ``tfb.Chain([link, reshape])``.
        Its forward pass reshapes a flat vector of size ``event_size`` and
        then applies the RV's link function. Depending on the RV's posterior
        scheme, the following modules are also created:

        - "observed" / "non-parametric": nothing else
        - "mean-field": one RFF + reshaper per distribution parameter
        - "flow": a ConditionalNFChain, wrapped as
          ``formatter o nf o formatter^-1`` in ``self.bijectors[rv]``
        - "ASVI": an ("alpha", "lambda") pair of RFFs per parent RV

        Parameters
        ----------
        set_transformer_kwargs : Dict
            SetTransformer kwargs to contract plates
        embedder_rff_kwargs : Dict
            RFF kwargs for observed RV values

        Raises
        ------
        NotImplementedError
            for an "ASVI" RV without parents, or an unknown posterior scheme
        """
        self.expansion_encodings = {}
        self.encoders = {}

        for observed_rv in self.observed_rvs:
            current_encoding_level = (
                observed_rv,
                tuple(self.plates_per_rv[observed_rv])
            )
            self.encoders[current_encoding_level] = RFF(
                **embedder_rff_kwargs
            )
            self._rec_build_encoder(
                set_transformer_kwargs=set_transformer_kwargs,
                current_encoding_level=current_encoding_level
            )

        # Flat (pre-link-function) event sizes are computed for all RVs up
        # front: the ASVI branch needs the sizes of *parent* RVs, which must
        # not depend on the iteration order of `self.all_rvs`.
        constrained_shapes = {}
        event_sizes = {}
        for rv in self.all_rvs:
            constrained_shapes[rv] = (
                self
                .link_functions[rv]
                .inverse_event_shape(
                    self.event_shapes[rv]
                )
            )
            event_sizes[rv] = int(tf.reduce_prod(constrained_shapes[rv]))

        # Attribute assignment order is kept: Keras tracks layers (and thus
        # orders weights) in assignment order.
        self.formatters = {}
        self.bijectors = {}

        self.asvi_rffs = {}

        self.mean_field_rffs = {}
        self.mean_field_parameter_reshapers = {}

        self.bijector_arrays = {}

        for rv in self.all_rvs:
            event_size = event_sizes[rv]

            reshaper = tfb.Reshape(
                event_shape_in=(event_size,),
                event_shape_out=constrained_shapes[rv]
            )

            formatter = tfb.Chain(
                bijectors=[
                    self.link_functions[rv],
                    reshaper,
                ]
            )
            self.formatters[rv] = formatter

            scheme = self.posterior_schemes[rv]
            if scheme in [
                "observed",
                "non-parametric"
            ]:
                pass
            elif scheme == "mean-field":

                self.mean_field_rffs[rv] = {}
                self.mean_field_parameter_reshapers[rv] = {}
                for parameter, shape in (
                    self.posterior_kwargs[rv]
                    ["parameter_shapes"]
                    .items()
                ):
                    parameter_size = int(tf.reduce_prod(shape))
                    self.mean_field_rffs[rv][parameter] = RFF(
                        units_per_layers=[parameter_size],
                        dense_kwargs=dict()
                    )
                    self.mean_field_parameter_reshapers[rv][parameter] = tfb.Reshape(
                        event_shape_in=(parameter_size,),
                        event_shape_out=shape
                    )
            elif scheme == "flow":
                # The conditional input is the concatenation of one encoding
                # per observed dependency, each of the RV's plate-level size.
                nf = ConditionalNFChain(
                    event_size=event_size,
                    conditional_event_size=int(
                        len(self.observed_rv_dependencies[rv])
                        *
                        self.encoding_sizes[tuple(self.plates_per_rv[rv])]
                    ),
                    name=f"nf_{rv}",
                    **self.posterior_kwargs[rv]["conditional_nf_chain_kwargs"]
                )

                self.bijectors[rv] = tfb.Chain(
                    bijectors=[
                        formatter,
                        nf,
                        tfb.Invert(formatter)
                    ],
                    name=f"chain_{rv}"
                )
            elif scheme == "ASVI":
                parents = self.parents[rv]
                if len(parents) == 0:
                    raise NotImplementedError(
                        f"Provided RV {rv} with `ASVI` posterior scheme "
                        "yet the RV has no parents in the graph."
                    )

                self.asvi_rffs[rv] = {}

                for parent in parents:
                    self.asvi_rffs[rv][parent] = {
                        # `alpha` lives in the parent's flat unconstrained
                        # space; `lambda` is a mixing weight in (0, 1).
                        "alpha": RFF(
                            units_per_layers=[event_sizes[parent]],
                            dense_kwargs=dict()
                        ),
                        "lambda": RFF(
                            units_per_layers=[1],
                            dense_kwargs=dict(activation="sigmoid")
                        )
                    }
            else:
                raise NotImplementedError(
                    f"Provided RV {rv} with unknown posterior scheme "
                    f"`{self.posterior_schemes[rv]}`."
                )

    def repeat_to_plates(
        self,
        tensor: tf.Tensor,
        source_plates: List[str],
        target_plates: List[str],
        batch_ndims: int,
        hbm_type: str
    ) -> tf.Tensor:
        """Tile ``tensor`` from its ``source_plates`` to ``target_plates``.

        Target plates are visited in order. Each plate missing from
        ``source_plates`` gets a new axis, inserted at
        ``batch_ndims + <index of the plate in target_plates>`` and repeated
        to the plate's cardinality in the ``hbm_type`` HBM.

        NOTE(review): the axis arithmetic assumes ``source_plates`` appear
        in ``target_plates`` in the same relative order - confirm with
        callers.

        Parameters
        ----------
        tensor : tf.Tensor
            value to repeat
        source_plates : List[str]
            plates of the input tensor
        target_plates : List[str]
            plates of the output tensor
        batch_ndims : int
            number of leading axes to leave untouched
        hbm_type : str
            "full" or "reduced", selects the cardinalities to use

        Returns
        -------
        tf.Tensor
            tensor with one extra, repeated axis per missing plate
        """
        # Lazily yields (axis, repeats) only for the plates to add.
        insertions = (
            (batch_ndims + position, self.cardinalities[hbm_type][plate])
            for position, plate in enumerate(target_plates)
            if plate not in source_plates
        )

        result = tensor
        for axis, n_repeats in insertions:
            expanded = tf.expand_dims(result, axis=axis)
            result = tf.repeat(expanded, repeats=n_repeats, axis=axis)
        return result

    def _rec_encode_data(
        self,
        encodings: Dict[Tuple[str, Tuple[str, ...]], tf.Tensor],
        current_encoding_level: Tuple[str, Tuple[str, ...]]
    ) -> Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]:
        """Depth-first propagation of an encoding to upstream plate levels.

        For every parent plate level of the current level, that level's
        encoder is applied to the current level's encoding. The result is
        stored in ``encodings`` (updated in place), and the recursion
        continues from it.

        Parameters
        ----------
        encodings : Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]
            dict of {encoding_level: value}; must already contain
            ``current_encoding_level``
        current_encoding_level : Tuple[str, Tuple[str, ...]]
            tuple of (observed_rv, plates)

        Returns
        -------
        Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]
            the same, updated, ``encodings`` dict
        """
        observed_rv, plates = current_encoding_level
        for upstream_plates in self.parent_plates[plates]:
            upstream_level = (observed_rv, tuple(upstream_plates))
            upstream_encoder = self.encoders[upstream_level]
            encodings[upstream_level] = upstream_encoder(
                encodings[current_encoding_level]
            )
            encodings = self._rec_encode_data(
                encodings=encodings,
                current_encoding_level=upstream_level
            )
        return encodings

    def encode_data(
        self,
        data: Dict[str, tf.Tensor]
    ) -> Dict[str, tf.Tensor]:
        """Encodes the data in the backward direction

        Each observed RV is mapped through the inverse of its formatter
        and embedded at its own plate level. It is then encoded recursively
        towards its parent plate levels. For every latent RV, the encodings
        of the levels listed in ``self.observed_rv_dependencies[rv]`` are
        concatenated along the last axis.

        Parameters
        ----------
        data : Dict[str, tf.Tensor]
            dict of {rv: value} for observed rvs

        Returns
        -------
        Dict[str, tf.Tensor]
            dict of {latent_rv: concatenated encoding}

        Notes
        -----
        Dependency levels are concatenated in sorted order. They are stored
        in a ``set`` of string tuples, whose iteration order depends on the
        per-process string hash seed. Iterating the set directly would
        permute the encoding features between Python processes (e.g.
        between training and reloading saved weights).
        """
        encodings = {}
        for observed_rv in self.observed_rvs:
            encoding_level = (observed_rv, tuple(self.plates_per_rv[observed_rv]))
            flat_value = self.formatters[observed_rv].inverse(
                data[observed_rv]
            )
            encodings[encoding_level] = self.encoders[encoding_level](
                flat_value
            )
            encodings = self._rec_encode_data(
                encodings=encodings,
                current_encoding_level=encoding_level
            )

        return {
            rv: tf.concat(
                [
                    encodings[key]
                    for key in sorted(self.observed_rv_dependencies[rv])
                ],
                axis=-1
            )
            for rv in self.latent_rvs
        }

    def _sample_hbm_conditional(
        self,
        rv: str,
        values: Dict[str, tf.Tensor],
        hbm_type: str,
        sample_shape: Tuple[int]
    ) -> tf.Tensor:
        """Sample ``rv`` from the ``hbm_type`` HBM's own distribution.

        A root RV (no parents) is sampled with ``sample_shape``. Otherwise,
        the HBM's conditional for ``rv`` is built from the parents' entries
        in ``values`` and sampled once (the parent values presumably already
        carry the sample dimension).
        """
        parents = self.parents[rv]
        if len(parents) == 0:
            return self.hbms[hbm_type].model[rv].sample(sample_shape)
        conditional = self.hbms[hbm_type].model[rv](
            **{
                parent: values[parent]
                for parent in parents
            }
        )
        return conditional.sample()

    def sample(
        self,
        sample_shape: Tuple[int],
        observed_values: Dict[str, tf.Tensor],
        hbm_type: Literal["full", "reduced"] = "full",
        encodings: Optional[Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]] = None,
        return_observed_values: bool = False,
        return_latent_values: bool = False,
        repeat_observed_rvs_and_encodings: bool = True,
        excluded_rvs: Optional[List[str]] = None
    ) -> Union[
        Dict[str, tf.Tensor],
        Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]
    ]:
        """sample from the architecture

        RVs are visited in the order of ``self.all_rvs``. Each one is
        sampled according to its posterior scheme, given the values already
        sampled for its parents and its data encoding:

        - "non-parametric": from the HBM's own (conditional) distribution
        - "mean-field": from ``posterior_kwargs[rv]["functor"]``, with
          parameters regressed from the RV's encoding
        - "flow": a draw from the HBM's own (conditional) distribution,
          pushed through the RV's conditional normalizing flow
        - "ASVI": from ``posterior_kwargs[rv]["functor"]``, each parent value
          being replaced by ``lambda * parent + (1 - lambda) * alpha``
          (computed in the parent's flat unconstrained space)

        Parameters
        ----------
        sample_shape : Tuple[int]
            shape for the sample, must be of length 1: (n_samples,)
        observed_values : Dict[str, tf.Tensor]
            {observed_rv: value}; must contain at least one RV, whose leading
            dimension (after the optional repeat) is the effective sample size
        hbm_type : Literal['full', 'reduced'], optional
            determines sampling from full or reduced HBM, by default "full"
        encodings : Optional[Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]], optional
            {latent_rv: encoding_value}, as returned by ``encode_data``,
            by default None (computed from ``observed_values``)
        return_observed_values : bool, optional
            also return the observed RVs' values, by default False
        return_latent_values : bool, optional
            also return flow RVs' values before application of NFs,
            by default False
        repeat_observed_rvs_and_encodings : bool, optional
            repeat observed values and encodings ``n_samples`` times along
            axis 0, by default True
        excluded_rvs : Optional[List[str]], optional
            do not sample those RVs, by default None (no exclusion)

        Returns
        -------
        Union[ Dict[str, tf.Tensor], Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]] ]
            either {rv: value} or ({rv: value}, {rv: latent_value})

        Raises
        ------
        NotImplementedError
            if ``sample_shape`` is not of length 1
        """
        sample_ndims = len(sample_shape)
        if sample_ndims != 1:
            raise NotImplementedError(
                f"Sampling with batch shape of len {sample_ndims} "
                "not implemented."
            )
        sample_size = sample_shape[0]

        observed_rvs = list(observed_values.keys())
        observed_rv_set = set(observed_rvs)
        ignored_rvs = set(observed_rvs)
        if excluded_rvs is not None:
            ignored_rvs.update(excluded_rvs)

        values = {
            observed_rv: (
                tf.repeat(
                    observed_value,
                    repeats=sample_size,
                    axis=0
                )
                if repeat_observed_rvs_and_encodings
                else
                observed_value
            )
            for observed_rv, observed_value in observed_values.items()
        }

        if encodings is None:
            encodings = self.encode_data(
                observed_values
            )

        if repeat_observed_rvs_and_encodings:
            encodings = {
                rv: tf.repeat(
                    encoding,
                    repeats=sample_size,
                    axis=0
                )
                for rv, encoding in encodings.items()
            }

        effective_sample_shape = (
            values[observed_rvs[0]].shape[0],
        )

        latent_values = {}
        for rv in self.all_rvs:
            if rv in ignored_rvs:
                continue
            scheme = self.posterior_schemes[rv]
            if scheme == "non-parametric":
                values[rv] = self._sample_hbm_conditional(
                    rv=rv,
                    values=values,
                    hbm_type=hbm_type,
                    sample_shape=effective_sample_shape
                )
            elif scheme == "mean-field":
                parameters = {
                    parameter: (
                        self.posterior_kwargs[rv]
                        ["parameter_constraints"]
                        [parameter](
                            self.mean_field_parameter_reshapers[rv][parameter](
                                rff(encodings[rv])
                            )
                        )
                    )
                    for parameter, rff in self.mean_field_rffs[rv].items()
                }
                # The parameters are regressed from `encodings[rv]`, which
                # already carries the leading sample axis (assuming the RFF
                # and reshaper keep leading axes). The distribution's batch
                # shape therefore already contains that axis. Sampling with
                # `effective_sample_shape` would prepend a second one, so
                # draw once, like the other conditional draws.
                values[rv] = (
                    self.posterior_kwargs[rv]["functor"](
                        **parameters
                    )
                    .sample()
                )
            elif scheme == "flow":
                latent_values[rv] = self._sample_hbm_conditional(
                    rv=rv,
                    values=values,
                    hbm_type=hbm_type,
                    sample_shape=effective_sample_shape
                )
                values[rv] = (
                    self.bijectors[rv]
                    .forward(
                        latent_values[rv],
                        **{
                            f"nf_{rv}": {
                                "conditional_input": encodings[rv]
                            }
                        }
                    )
                )
            elif scheme == "ASVI":
                model_kwargs = {}
                for parent in self.parents[rv]:
                    parent_rffs = self.asvi_rffs[rv][parent]
                    alpha = parent_rffs["alpha"](encodings[rv])
                    lambd = parent_rffs["lambda"](encodings[rv])
                    formatter = self.formatters[parent]
                    # Parent value in its flat unconstrained space, tiled to
                    # the child's plates.
                    parent_value = self.repeat_to_plates(
                        tensor=formatter.inverse(values[parent]),
                        source_plates=self.plates_per_rv[parent],
                        target_plates=self.plates_per_rv[rv],
                        batch_ndims=len(sample_shape),
                        hbm_type=hbm_type
                    )
                    model_kwargs[parent] = formatter.forward(
                        lambd * parent_value
                        + (1 - lambd) * alpha
                    )
                values[rv] = (
                    self.posterior_kwargs[rv]["functor"](
                        **model_kwargs
                    )
                    .sample()
                )

        if not return_observed_values:
            values = {
                rv: value
                for rv, value in values.items()
                if rv not in observed_rv_set
            }
            latent_values = {
                rv: value
                for rv, value in latent_values.items()
                if rv not in observed_rv_set
            }

        if return_latent_values:
            return values, latent_values
        return values

    def log_prob_parts(
        self,
        values: Dict[str, tf.Tensor],
        latent_values: Optional[Dict[str, tf.Tensor]] = None,
        hbm_type: Literal["full", "reduced"] = "full",
        encodings: Optional[Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]] = None,
        encodings_repeats: int = 1
    ) -> Dict[str, tf.Tensor]:
        """Return the log prob for the conditional distributions

        Entries are produced for the latent RVs (``self.latent_rvs``) whose
        posterior scheme is "non-parametric", "mean-field", "flow" or "ASVI".

        Parameters
        ----------
        values : Dict[str, tf.Tensor]
            {rv: value}, with exactly one leading sample dimension
        latent_values : Optional[Dict[str, tf.Tensor]], optional
            {rv: latent_value} (before NF), by default None.
            When None, the latent values of "flow" RVs are recovered
            by inverting their bijector.
        hbm_type : Literal["full", "reduced"], optional
            'reduced' or 'full' HBM to evaluate, by default "full"
        encodings : Optional[Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]], optional
            conditioning encodings, indexed as ``encodings[rv]`` for each
            latent RV, by default None. When None, they are computed with
            ``self.encode_data`` from the observed RVs found in ``values``.
        encodings_repeats : int, optional
            number of times provided encodings are repeated along axis 0
            to match the sample size of ``values``, by default 1.
            Ignored when ``encodings`` is None.

        Returns
        -------
        Dict[str, tf.Tensor]
            {rv: log_prob}

        Raises
        ------
        NotImplementedError
            if ``values`` does not have exactly one sample dimension
        AssertionError
            if the (repeated) encodings and ``values`` disagree on their
            leading (sample) dimension
        """
        first_latent_rv = self.latent_rvs[0]
        batch_ndims = (
            len(values[first_latent_rv].shape)
            - len(self.hbms[hbm_type].event_shape[first_latent_rv])
        )
        if batch_ndims != 1:
            # The sample-size check and the per-sample reduction of the
            # flow log-det-Jacobian below both assume one leading sample axis
            raise NotImplementedError(
                f"Sample shape of len {batch_ndims} not implemented, "
                "stick to 1"
            )

        if encodings is None:
            encodings = self.encode_data(
                {
                    observed_rv: values[observed_rv]
                    for observed_rv in self.observed_rvs
                }
            )
        else:
            encodings = {
                rv: tf.repeat(
                    encoding,
                    repeats=encodings_repeats,
                    axis=0
                )
                for rv, encoding in encodings.items()
            }
        encoding_effective_sample_size = encodings[first_latent_rv].shape[0]
        values_effective_sample_size = values[first_latent_rv].shape[0]
        if encoding_effective_sample_size != values_effective_sample_size:
            raise AssertionError(
                f"Encodings effective sample size ({encoding_effective_sample_size}) "
                f"is different from values effective sample size ({values_effective_sample_size}). "
                "Use `encodings_repeats` accordingly."
            )

        def flow_kwargs(rv: str) -> Dict[str, Dict[str, tf.Tensor]]:
            # kwargs routing the RV's encoding to its conditional NF
            # component (named `nf_{rv}`)
            return {
                f"nf_{rv}": {
                    "conditional_input": encodings[rv]
                }
            }

        if latent_values is None:
            latent_values = {
                rv: self.bijectors[rv].inverse(value, **flow_kwargs(rv))
                for rv, value in values.items()
                if self.posterior_schemes[rv] == "flow"
            }

        log_prob_parts = {}
        for rv in self.latent_rvs:
            scheme = self.posterior_schemes[rv]
            if scheme == "non-parametric":
                log_prob_parts[rv] = self._hbm_conditional_log_prob(
                    rv=rv,
                    values=values,
                    target=values[rv],
                    hbm_type=hbm_type
                )
            elif scheme == "mean-field":
                # each parameter: RFF(encoding) -> reshape -> constraint
                parameters = {
                    parameter: (
                        self.posterior_kwargs[rv]
                        ["parameter_constraints"]
                        [parameter](
                            self.mean_field_parameter_reshapers[rv][parameter](
                                rff(encodings[rv])
                            )
                        )
                    )
                    for parameter, rff in self.mean_field_rffs[rv].items()
                }
                log_prob_parts[rv] = (
                    self.posterior_kwargs[rv]["functor"](**parameters)
                    .log_prob(values[rv])
                )
            elif scheme == "flow":
                inv_log_det_jac = self.bijectors[rv].inverse_log_det_jacobian(
                    values[rv],
                    **flow_kwargs(rv)
                )
                # sum every axis but the leading sample axis
                inv_log_det_jac = tf.vectorized_map(
                    tf.reduce_sum,
                    inv_log_det_jac
                )
                # change of variables: HBM conditional log prob evaluated
                # at the latent (pre-NF) value + inverse log-det-Jacobian
                latent_log_prob_part = self._hbm_conditional_log_prob(
                    rv=rv,
                    values=values,
                    target=latent_values[rv],
                    hbm_type=hbm_type
                )
                log_prob_parts[rv] = latent_log_prob_part + inv_log_det_jac
            elif scheme == "ASVI":
                model_kwargs = {}
                for parent in self.parents[rv]:
                    alpha, lambd = (
                        self.asvi_rffs[rv][parent][key](
                            encodings[rv]
                        )
                        for key in ["alpha", "lambda"]
                    )
                    formatter = self.formatters[parent]
                    parent_value = self.repeat_to_plates(
                        tensor=formatter.inverse(values[parent]),
                        source_plates=self.plates_per_rv[parent],
                        target_plates=self.plates_per_rv[rv],
                        batch_ndims=batch_ndims,
                        hbm_type=hbm_type
                    )
                    # lambda-gated mix, in the space given by
                    # formatter.inverse, of the parent value and alpha
                    model_kwargs[parent] = formatter.forward(
                        lambd * parent_value
                        + (1 - lambd) * alpha
                    )
                log_prob_parts[rv] = (
                    self.posterior_kwargs[rv]["functor"](**model_kwargs)
                    .log_prob(values[rv])
                )

        return log_prob_parts

    def _hbm_conditional_log_prob(
        self,
        rv: str,
        values: Dict[str, tf.Tensor],
        target: tf.Tensor,
        hbm_type: Literal["full", "reduced"]
    ) -> tf.Tensor:
        """Log prob of ``target`` under the HBM conditional of ``rv``

        ``self.hbms[hbm_type].model[rv]`` is used directly when ``rv``
        has no parents. Otherwise it is called with the parents' entries
        of ``values`` to obtain the conditional distribution.
        """
        distribution = self.hbms[hbm_type].model[rv]
        parents = self.parents[rv]
        if len(parents) > 0:
            distribution = distribution(
                **{
                    parent: values[parent]
                    for parent in parents
                }
            )
        return distribution.log_prob(target)

    def log_prob(
        self,
        values: Dict[str, tf.Tensor],
        latent_values: Optional[Dict[str, tf.Tensor]] = None,
        hbm_type: Literal["full", "reduced"] = "full",
        encodings: Optional[Dict[Tuple[str, Tuple[str, ...]], tf.Tensor]] = None,
        encodings_repeats: int = 1
    ) -> tf.Tensor:
        """Total log prob of ``values`` under the variational family

        Forwards every argument unchanged to ``log_prob_parts`` and sums
        the returned per-RV terms.

        Returns
        -------
        tf.Tensor
            log_prob, the per-RV terms summed along axis 0 of their stack
        """
        parts = self.log_prob_parts(
            values=values,
            latent_values=latent_values,
            hbm_type=hbm_type,
            encodings=encodings,
            encodings_repeats=encodings_repeats,
        )
        per_rv_terms = list(parts.values())
        return tf.reduce_sum(per_rv_terms, axis=0)

    def compile(
        self,
        train_method: Literal[
            "reverse_KL",
            "unregularized_ELBO"
        ],
        n_theta_draws: int,
        pop_full_hbm: bool = False,
        **kwargs
    ) -> None:
        """Wrapper for Keras compilation, additionally
        specifying hyper-parameters

        Parameters
        ----------
        train_method : Literal[
            "reverse_KL",
            "unregularized_ELBO"
        ]
            defines which loss to use during training.
            "_MAP_regression" is also accepted here, but this class's
            ``train_step`` only handles the two methods above.
        n_theta_draws: int
            for Monte Carlo estimation of the ELBO
        pop_full_hbm: bool = False
            pop full HBM from architecture to avoid OOM
            when evaluating weights. Safe to repeat: a no-op if the full
            HBM was already popped.
        **kwargs
            forwarded to the parent Keras ``compile``

        Raises
        ------
        NotImplementedError
            train_method not in [
                "reverse_KL",
                "unregularized_ELBO",
                "_MAP_regression"
            ]
        """
        if train_method not in (
            "reverse_KL",
            "unregularized_ELBO",
            "_MAP_regression"
        ):
            raise NotImplementedError(
                f"unrecognized train method {train_method}"
            )
        self.train_method = train_method
        self.n_theta_draws = n_theta_draws

        if pop_full_hbm:
            # default avoids a KeyError when compile is called again
            # after the full HBM has already been removed
            self.hbms.pop("full", None)

        super().compile(**kwargs)

    def weighted_sum_log_prob_parts(
        self,
        log_prob_parts: Dict[str, tf.Tensor]
    ) -> tf.Tensor:
        """Sum per-RV log probs, each scaled by its cardinality factor

        Every part is multiplied by ``self.log_prob_factors[rv]`` (the
        cardinality ratio between the full and reduced HBMs) before the
        parts are summed over RVs.

        Parameters
        ----------
        log_prob_parts : Dict[str, tf.Tensor]
            unweighted log prob parts, {rv: log_prob}

        Returns
        -------
        tf.Tensor
            sum over RVs (axis 0 of the stacked weighted parts)
        """
        weighted_parts = []
        for rv, part in log_prob_parts.items():
            factor = self.log_prob_factors[rv]
            weighted_parts.append(factor * part)
        return tf.reduce_sum(weighted_parts, axis=0)

    def train_step(
        self,
        data: Tuple[tf.Tensor]
    ) -> Dict:
        """keras train step to be compiled

        Samples ``self.n_theta_draws`` draws with ``self.sample`` (reduced
        HBM) and computes the weighted reduced-HBM log prob ``p``. The loss is
        ``mean(-p)`` for "unregularized_ELBO" or ``mean(q - p)`` for
        "reverse_KL", with ``q`` the weighted log prob from
        ``self.log_prob_parts``. One optimizer step is then applied.

        Parameters
        ----------
        data : Tuple[tf.Tensor]
            reduced data slice; ``data[0]`` holds the observed values

        Returns
        -------
        Dict
            {loss_type: loss_value}

        Raises
        ------
        NotImplementedError
            if ``self.train_method`` is neither "reverse_KL" nor
            "unregularized_ELBO" (otherwise no tape / loss would exist)
        """
        if self.train_method not in (
            "reverse_KL",
            "unregularized_ELBO"
        ):
            raise NotImplementedError(
                f"train_step does not support train method {self.train_method}"
            )

        observed_values = data[0]

        with tf.GradientTape() as tape:

            encodings = self.encode_data(
                observed_values
            )

            values, latent_values = self.sample(
                sample_shape=(self.n_theta_draws,),
                observed_values=observed_values,
                hbm_type="reduced",
                encodings=encodings,
                return_observed_values=True,
                return_latent_values=True
            )
            p_log_prob_parts = (
                self.hbms["reduced"]
                .log_prob_parts(values)
            )
            p = self.weighted_sum_log_prob_parts(p_log_prob_parts)

            if self.train_method == "unregularized_ELBO":
                loss = tf.reduce_mean(-p)
            else:  # "reverse_KL"
                q_log_prob_parts = self.log_prob_parts(
                    values=values,
                    latent_values=latent_values,
                    hbm_type="reduced",
                    encodings=encodings,
                    encodings_repeats=self.n_theta_draws
                )
                q = self.weighted_sum_log_prob_parts(q_log_prob_parts)
                loss = tf.reduce_mean(q - p)

        # copy before extending in place, so the list returned by the
        # Keras property is never mutated
        trainable_vars = list(self.trainable_variables)
        for bijector in self.bijectors.values():
            trainable_vars += bijector.trainable_variables
        for bijector_array in self.bijector_arrays.values():
            trainable_vars += bijector_array.trainable_variables

        gradients = tape.gradient(loss, trainable_vars)
        # NaN gradient entries are zeroed element-wise; None gradients
        # are passed through unchanged
        sanitized_gradients = (
            tf.where(
                tf.math.is_nan(grad),
                tf.zeros_like(grad),
                grad
            )
            if grad is not None
            else None
            for grad in gradients
        )
        self.optimizer.apply_gradients(
            zip(sanitized_gradients, trainable_vars)
        )

        return {self.train_method: loss}

    def save_weights_np(
        self,
        path: str
    ) -> None:
        """Saves weights to path in numpy format

        The weights returned by ``self.get_weights()`` are stored as a 1D
        object array, one element per weight. Arrays of differing shapes
        therefore round-trip. Passing the raw list to ``np.save`` fails
        on such ragged input: recent numpy refuses to infer an object
        array, and older numpy may fail to broadcast when leading
        dimensions coincide.

        Parameters
        ----------
        path : str
            destination file; ``np.save`` appends ``.npy`` when missing
        """
        weights = self.get_weights()
        # element-wise assignment keeps each weight as its own object,
        # never merged into a higher-dimensional array
        packed = np.empty(len(weights), dtype=object)
        for index, weight in enumerate(weights):
            packed[index] = weight
        np.save(
            path,
            packed,
            allow_pickle=True
        )

    def load_weights_np(
        self,
        path: str
    ) -> None:
        """Loads weights from path in numpy format

        Counterpart of ``save_weights_np``. ``np.save`` appends ``.npy``
        to paths lacking that suffix, whereas ``np.load`` does not. So
        when ``path`` does not exist but ``path + ".npy"`` does, the
        latter is loaded.

        Security: the file is unpickled (``allow_pickle=True``); only load
        weight files from trusted sources.

        Parameters
        ----------
        path : str
            file written by ``save_weights_np`` (with or without ``.npy``)
        """
        import os

        if isinstance(path, (str, os.PathLike)):
            fs_path = os.fspath(path)
            if (
                isinstance(fs_path, str)
                and not os.path.exists(fs_path)
                and os.path.exists(fs_path + ".npy")
            ):
                path = fs_path + ".npy"
        weights = np.load(
            path,
            allow_pickle=True
        )
        self.set_weights(weights)
