from typing import Dict, Literal, Tuple, List
from collections import defaultdict

import tensorflow as tf
import tensorflow_probability as tfp

from ...normalizing_flow.bijectors import (
    ConditionalAffine
)
from ..graph import (
    get_inference_order
)

tfd = tfp.distributions
tfb = tfp.bijectors


class CascadingFlows(tf.keras.Model):
    """Cascading flows architecture for HBM inference (see Ambrogioni et al., 2021).

    Each latent RV gets an auxiliary variable. The auxiliary variables form a
    graph whose edges are the reverse of the generative HBM's edges: each
    auxiliary density conditions on the auxiliary values of the RV's children.
    A latent RV's prior draw, once mapped to unconstrained space, is
    concatenated with its auxiliary value. The result is pushed through a
    trainable highway flow.
    """

    def __init__(
        self,
        generative_hbm: tfd.JointDistributionNamed,
        observed_rvs: List[str],
        link_functions: Dict[str, tfb.Bijector],
        observed_rv_reshapers: Dict[str, tfb.Bijector],
        auxiliary_variables_size: int,
        rff_kwargs: Dict,
        nf_kwargs: Dict,
        amortized: bool,
        auxiliary_target_type: Literal["identity", "MF"],
        **kwargs
    ):
        """Cascading flows architecture for HBM inference

        Parameters
        ----------
        generative_hbm : tfd.JointDistributionNamed
            HBM to perform inference upon
        observed_rvs : List[str]
            [rv] list of observed RVs
        link_functions : Dict[str, tfb.Bijector]
            {rv: link_function}
        observed_rv_reshapers : Dict[str, tfb.Bijector]
            {rv: reshaper}. The dict is not mutated. A copy is stored and
            extended with reshapers for the latent RVs.
        auxiliary_variables_size : int
            used to instantiate the auxiliary graph
        rff_kwargs : Dict
            used for observed RV embedders
        nf_kwargs : Dict
            build_trainable_highway_flow kwargs
        amortized : bool
            is the architecture to be amortized
        auxiliary_target_type : Literal["identity", "MF"]
            describes the type of the auxiliary target distribution
            r for the augmented ELBO
        """
        super().__init__(**kwargs)

        self.generative_hbm = generative_hbm
        self.observed_rvs = observed_rvs
        self.link_functions = link_functions
        # Copy: build_architecture adds latent-RV reshapers to this dict, and
        # those must not leak back into the caller's dict.
        self.reshapers = dict(observed_rv_reshapers)

        self.auxiliary_variables_size = auxiliary_variables_size
        self.amortized = amortized
        self.auxiliary_target_type = auxiliary_target_type

        self.analyse_generative_hbm_graph()
        self.build_architecture(
            rff_kwargs=rff_kwargs,
            nf_kwargs=nf_kwargs
        )

    def analyse_generative_hbm_graph(self) -> None:
        """Analyses dependencies in the HBM's graph.

        Sets ``prior_rv_order`` (inference order of the HBM graph),
        ``parents`` / ``children`` ({rv: [rv]} adjacency lists, defaulting
        to an empty list) and ``inverse_rv_order`` (inference order of the
        graph in which every RV depends on its children instead of its
        parents).
        """
        graph = self.generative_hbm.resolve_graph()

        self.prior_rv_order = get_inference_order(
            graph=graph
        )

        self.parents = defaultdict(list)
        self.children = defaultdict(list)

        for child, parents in graph:
            for parent in parents:
                self.parents[child].append(parent)
                self.children[parent].append(child)

        # Reversed graph: each RV is listed with its children as "parents".
        self.inverse_rv_order = get_inference_order(
            graph=tuple(
                (rv, tuple(self.children[rv]))
                for rv in self.prior_rv_order
            )
        )

    def _make_auxiliary_conditional(self, rv: str):
        """Returns the conditional auxiliary density builder for a non-leaf RV.

        The returned callable takes the auxiliary values of ``rv``'s
        children, positionally and in ``self.children[rv]`` order. It returns
        an Independent Normal (one reinterpreted batch dim):
        - loc: the children values stacked on a new last axis, contracted
          with the first ``len(children)`` coupling weights;
        - scale: the last coupling weight.
        For amortized observed RVs, that Normal is further transformed by the
        RV's amortizing bijector.

        A closure with positional arguments is used, so RV names need not be
        valid Python identifiers.

        Parameters
        ----------
        rv : str
            RV whose auxiliary density is built. It must already have an
            entry in ``self.auxiliary_coupling_weights``.
        """
        def auxiliary_conditional(*children_values):
            weights = self.auxiliary_coupling_weights[rv]
            loc = tf.squeeze(
                tf.stack(list(children_values), axis=-1)
                @
                tf.expand_dims(weights[:-1], axis=-1),
                axis=-1
            )
            dist = tfd.Independent(
                tfd.Normal(loc=loc, scale=weights[-1]),
                reinterpreted_batch_ndims=1
            )
            if rv in self.observed_rvs and self.amortized:
                return tfd.TransformedDistribution(
                    dist,
                    bijector=self.amortizing_bijectors[rv]
                )
            return dist

        return auxiliary_conditional

    def build_architecture(
        self,
        rff_kwargs: Dict,
        nf_kwargs: Dict
    ) -> None:
        """Builds whole architecture: auxiliary graph,
        normalizing flows, and observed RV embedders

        Parameters
        ----------
        rff_kwargs : Dict
            used for observed RV embedders
        nf_kwargs : Dict
            build_trainable_highway_flow kwargs
        """
        # ? First deal with the auxiliary model:
        self.auxiliary_coupling_weights = {}
        if self.amortized:
            self.amortizing_bijectors = {}
        auxiliary_model = {}

        for rv in self.inverse_rv_order:
            children = self.children[rv]
            amortized_observed = rv in self.observed_rvs and self.amortized
            if amortized_observed:
                self.amortizing_bijectors[rv] = ConditionalAffine(
                    scale_type="none",
                    rff_kwargs=rff_kwargs,
                    event_size=self.auxiliary_variables_size
                )
            if len(children) == 0:
                # Leaf of the reversed graph: unconditional standard Normal.
                base_dist = tfd.Independent(
                    tfd.Normal(
                        loc=tf.zeros((self.auxiliary_variables_size,)),
                        scale=1
                    ),
                    reinterpreted_batch_ndims=1
                )
                if amortized_observed:
                    auxiliary_model[rv] = tfd.TransformedDistribution(
                        base_dist,
                        bijector=self.amortizing_bijectors[rv]
                    )
                else:
                    auxiliary_model[rv] = base_dist
            else:
                # len(children) + 1 weights on the simplex: the first ones
                # weight each child's auxiliary value, the last one is the
                # Normal's scale.
                self.auxiliary_coupling_weights[rv] = (
                    tfp.util
                    .TransformedVariable(
                        tf.ones((len(children) + 1,))
                        /
                        (len(children) + 1),
                        bijector=tfb.SoftmaxCentered()
                    )
                )
                auxiliary_model[rv] = self._make_auxiliary_conditional(rv)
        self.auxiliary_model = auxiliary_model

        # ? We then construct the HighwayFlow bijectors
        self.hflows = {}
        self.prior_model = {}
        for rv in self.prior_rv_order:
            if rv in self.observed_rvs:
                continue
            shape = (
                self
                .generative_hbm
                .event_shape
                [rv]
            )
            constrained_shape = (
                self
                .link_functions[rv]
                .inverse_event_shape(
                    shape
                )
            )
            event_size = tf.reduce_prod(
                constrained_shape
            )
            self.reshapers[rv] = tfb.Reshape(
                event_shape_in=(event_size,),
                event_shape_out=constrained_shape
            )

            self.prior_model[rv] = self.generative_hbm.model[rv]

            # Flow over [flattened RV value, auxiliary value]; only the
            # first event_size dimensions are gated.
            self.hflows[rv] = (
                tfp.experimental.bijectors
                .build_trainable_highway_flow(
                    width=event_size + self.auxiliary_variables_size,
                    activation_fn=tf.nn.softplus,
                    gate_first_n=event_size,
                    **nf_kwargs
                )
            )

        # Chain applies the reshaper first, then the link function.
        self.constrainers = {}
        for rv in self.prior_rv_order:
            self.constrainers[rv] = tfb.Chain(
                [
                    self.link_functions[rv],
                    self.reshapers[rv]
                ]
            )

        # ? Finally, we build the auxiliary variational density
        self.auxiliary_target_model = {}
        if self.auxiliary_target_type == "identity":
            for rv in self.prior_rv_order:
                if rv in self.observed_rvs:
                    continue
                self.auxiliary_target_model[rv] = auxiliary_model[rv]
        elif self.auxiliary_target_type == "MF":
            self.auxiliary_MF_locs = {}
            self.auxiliary_MF_scales = {}
            for rv in self.prior_rv_order:
                if rv in self.observed_rvs:
                    continue
                self.auxiliary_MF_locs[rv] = tf.Variable(
                    tf.zeros((self.auxiliary_variables_size,))
                )
                fill_triangular = tfb.FillTriangular()
                self.auxiliary_MF_scales[rv] = tfp.util.TransformedVariable(
                    tf.eye(self.auxiliary_variables_size),
                    bijector=fill_triangular
                )

                # Chain applies ScaleMatvecTriL first, then Shift.
                self.auxiliary_target_model[rv] = tfd.TransformedDistribution(
                    tfd.MultivariateNormalDiag(
                        loc=tf.zeros((self.auxiliary_variables_size,)),
                        scale_diag=tf.ones((self.auxiliary_variables_size,))
                    ),
                    tfb.Chain([
                        tfb.Shift(self.auxiliary_MF_locs[rv]),
                        tfb.ScaleMatvecTriL(self.auxiliary_MF_scales[rv])
                    ])
                )

    def sample_parameters_conditioned_to_data(
        self,
        data: Dict[str, tf.Tensor],
        return_internals: bool = False
    ) -> Tuple[Dict[str, tf.Tensor], ...]:
        """samples from auxiliary graph, then from prior,
        cascading NF transforms on the prior sample values

        Parameters
        ----------
        data : Dict[str, tf.Tensor]
            {observed_rv: value}. The first value's leading dimension is
            used as the number of samples. If that dimension is unknown
            (None), 1 sample is drawn.
        return_internals : bool, optional
            return intermediate results useful for
            other methods, by default False

        Returns
        -------
        Tuple[Dict[str, tf.Tensor], ...]
            ``(sample,)`` where sample is {latent_rv: value}, or, when
            ``return_internals`` is True, ``(sample,
            augmented_posterior_values, augmented_prior_values,
            auxiliary_values)``
        """
        batch_size = next(iter(data.values())).shape[0]

        auxiliary_values = {}
        for rv in self.inverse_rv_order:
            if (
                self.amortized
                and
                rv in self.observed_rvs
            ):
                # Condition the amortizing bijector on the observed value
                # mapped back to unconstrained space.
                conditional_dict = dict(
                    bijector_kwargs=dict(
                        conditional_input=(
                            self.constrainers[rv]
                            .inverse(
                                data[rv]
                            )
                        )
                    )
                )
            else:
                conditional_dict = dict()

            if issubclass(
                type(self.auxiliary_model[rv]),
                tfd.Distribution
            ):
                auxiliary_values[rv] = (
                    self.auxiliary_model[rv]
                    .sample(
                        (batch_size or 1,),
                        **conditional_dict
                    )
                )
            else:
                # Conditional density: batch dims come from the children.
                auxiliary_values[rv] = (
                    self.auxiliary_model[rv](
                        *[
                            auxiliary_values[child]
                            for child in self.children[rv]
                        ]
                    )
                    .sample(
                        tuple(),
                        **conditional_dict
                    )
                )

        augmented_prior_values = {}
        augmented_posterior_values = {}
        sample = {}
        for rv in self.prior_rv_order:
            if rv in self.observed_rvs:
                continue
            prior_model = self.prior_model[rv]
            if issubclass(type(prior_model), tfd.Distribution):
                prior_draw = prior_model.sample((batch_size or 1,))
            else:
                # Parents are either observed (data) or already sampled.
                available_values = {**data, **sample}
                prior_draw = prior_model(
                    *[
                        available_values[parent]
                        for parent in self.parents[rv]
                    ]
                ).sample()
            prior_value = self.constrainers[rv].inverse(prior_draw)

            augmented_prior_values[rv] = tf.concat(
                [
                    prior_value,
                    auxiliary_values[rv]
                ],
                axis=-1
            )
            augmented_posterior_values[rv] = (
                self.hflows[rv]
                .forward(augmented_prior_values[rv])
            )
            # Drop the auxiliary dims before constraining back.
            sample[rv] = self.constrainers[rv].forward(
                augmented_posterior_values[rv]
                [..., :-self.auxiliary_variables_size]
            )

        if return_internals:
            return (
                sample,
                augmented_posterior_values,
                augmented_prior_values,
                auxiliary_values
            )
        return (sample,)

    def joint_log_prob_conditioned_to_data(
        self,
        data: Dict[str, tf.Tensor],
        augmented_posterior_values: Dict[str, tf.Tensor],
        auxiliary_values: Dict[str, tf.Tensor]
    ) -> tf.Tensor:
        """Needs auxiliary values to compute latent RVs log prob conditional
        on observed RV's value

        Parameters
        ----------
        data : Dict[str, tf.Tensor]
            {rv: value}
        augmented_posterior_values : Dict[str, tf.Tensor]
            {rv: augmented_value}
        auxiliary_values : Dict[str, tf.Tensor]
            {rv: auxiliary_value}

        Returns
        -------
        tf.Tensor
            log prob tensor, summed over latent RVs
        """

        batch_size = next(iter(data.values())).shape[0]

        log_prob = 0.
        for rv in self.prior_rv_order:
            if rv in self.observed_rvs:
                continue
            prior_dist = tfd.TransformedDistribution(
                tfd.BatchBroadcast(
                    self.prior_model[rv],
                    to_shape=(batch_size,)
                )
                if issubclass(type(self.prior_model[rv]), tfd.Distribution)
                else
                self.prior_model[rv](
                    *[
                        data[parent]
                        for parent in self.parents[rv]
                    ]
                ),
                bijector=tfb.Invert(
                    self.constrainers[rv]
                )
            )
            auxiliary_dist = (
                self.auxiliary_model[rv]
                if issubclass(type(self.auxiliary_model[rv]), tfd.Distribution)
                else
                self.auxiliary_model[rv](
                    *[
                        auxiliary_values[child]
                        for child in self.children[rv]
                    ]
                )
            )

            augmented_prior_dist = tfd.Blockwise(
                [
                    prior_dist,
                    auxiliary_dist
                ]
            )
            augmented_posterior_dist = tfd.TransformedDistribution(
                augmented_prior_dist,
                bijector=self.hflows[rv]
            )

            log_prob += augmented_posterior_dist.log_prob(
                augmented_posterior_values[rv]
            )

        return log_prob

    def MF_log_prob(
        self,
        augmented_posterior_values: Dict[str, tf.Tensor],
        auxiliary_values: Dict[str, tf.Tensor]
    ) -> tf.Tensor:
        """target auxiliary density log prob for
        augmented ELBO

        Every target density is evaluated directly on the batched auxiliary
        dims, so each RV contributes a per-sample log prob. (Previously, leaf
        "identity" targets were wrapped in ``tfd.Sample``, which summed over
        the batch.)

        Parameters
        ----------
        augmented_posterior_values : Dict[str, tf.Tensor]
            {rv: augmented_value}
        auxiliary_values : Dict[str, tf.Tensor]
            {rv: auxiliary_value}

        Returns
        -------
        tf.Tensor
            log prob tensor, summed over latent RVs
        """
        log_prob = 0.
        for rv in self.prior_rv_order:
            if rv in self.observed_rvs:
                continue
            if self.auxiliary_target_type == "identity":
                auxiliary_target_dist = self.auxiliary_target_model[rv]
                if not issubclass(
                    type(auxiliary_target_dist),
                    tfd.Distribution
                ):
                    auxiliary_target_dist = auxiliary_target_dist(
                        *[
                            auxiliary_values[child]
                            for child in self.children[rv]
                        ]
                    )
            elif self.auxiliary_target_type == "MF":
                auxiliary_target_dist = self.auxiliary_target_model[rv]
            else:
                continue
            log_prob += (
                auxiliary_target_dist
                .log_prob(
                    augmented_posterior_values[rv]
                    [..., -self.auxiliary_variables_size:]
                )
            )
        return log_prob

    def compile(
        self,
        train_method: Literal[
            "reverse_KL",
            "unregularized_ELBO"
        ],
        n_theta_draws_per_x: int,
        **kwargs
    ) -> None:
        """Wrapper for Keras compilation, additionally
        specifying hyper-parameters

        Parameters
        ----------
        train_method : Literal[
            "reverse_KL",
            "unregularized_ELBO"
        ]
            defines which loss to use during training
        n_theta_draws_per_x : int
            number of parameter draws per data point, used for
            Monte Carlo estimation of the ELBO
        **kwargs
            forwarded to ``tf.keras.Model.compile``

        Raises
        ------
        NotImplementedError
            train_method not in [
                "reverse_KL",
                "unregularized_ELBO"
            ]
        """
        if train_method not in [
            "reverse_KL",
            "unregularized_ELBO"
        ]:
            raise NotImplementedError(
                f"unrecognized train method {train_method}"
            )
        self.train_method = train_method
        self.n_theta_draws_per_x = n_theta_draws_per_x

        super().compile(**kwargs)

    def train_step(
        self,
        train_data: Tuple[Dict[str, tf.Tensor]]
    ) -> Dict[str, tf.Tensor]:
        """Performs a train step on training data (behavior depends
        on compiled train method)

        Parameters
        ----------
        train_data : Tuple[Dict[str, tf.Tensor]]
            tuple containing the {rv: value} dict
            corresponding to the train data batch

        Returns
        -------
        Dict[str, tf.Tensor]
            {train_method: loss_value}
        """
        data = train_data[0]
        if self.train_method in [
            "reverse_KL",
            "unregularized_ELBO"
        ]:
            # Each data point is repeated n_theta_draws_per_x times along
            # axis 0, giving several parameter draws per data point.
            repeated_rvs = {
                rv: tf.repeat(
                    value,
                    repeats=(self.n_theta_draws_per_x,),
                    axis=0
                )
                for rv, value in data.items()
            }
            with tf.GradientTape() as tape:
                (
                    parameters_sample,
                    augmented_posterior_values,
                    _,
                    auxiliary_values
                ) = self.sample_parameters_conditioned_to_data(
                    data=repeated_rvs,
                    return_internals=True
                )

                observed_values = {
                    observed_rv: repeated_rvs[observed_rv]
                    for observed_rv in self.observed_rvs
                }
                p = self.generative_hbm.log_prob(
                    **parameters_sample,
                    **observed_values
                )

                if self.train_method == "unregularized_ELBO":
                    loss = tf.reduce_mean(-p)
                else:
                    r = self.MF_log_prob(
                        augmented_posterior_values=augmented_posterior_values,
                        auxiliary_values=auxiliary_values,
                    )

                    q = (
                        self
                        .joint_log_prob_conditioned_to_data(
                            data={
                                **parameters_sample,
                                **observed_values
                            },
                            augmented_posterior_values=(
                                augmented_posterior_values
                            ),
                            auxiliary_values=auxiliary_values
                        )
                    )

                    loss = tf.reduce_mean(q - p - r)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # NaN gradient entries are replaced with zeros before the update.
        self.optimizer.apply_gradients(
            zip(
                (
                    tf.where(
                        tf.math.is_nan(
                            grad
                        ),
                        tf.zeros_like(grad),
                        grad
                    )
                    if grad is not None
                    else None
                    for grad in gradients
                ),
                trainable_vars
            )
        )

        return {self.train_method: loss}
