from typing import Dict, Literal, Tuple, List, Union

import tensorflow as tf
import tensorflow_probability as tfp

from ...utils import repeat_to_shape
from ...normalizing_flow.bijectors import (
    ConditionalNFChain,
)

tfd = tfp.distributions
tfb = tfp.bijectors


class UIVFamily(tf.keras.Model):

    def __init__(
        self,
        generative_hbm: tfd.JointDistributionNamed,
        link_functions: Dict[str, tfb.Bijector],
        observed_rv: str,
        conditional_nf_chain_kwargs: Dict,
        embedding_RV_size: int,
        **kwargs
    ) -> None:
        """Set up an Unbiased Implicit VI (UIVI) variational family.

        UIVI is an unstructured implicit VI architecture: a sample from a
        base distribution is pushed through a transform that is itself
        randomized by an auxiliary variable epsilon, and the result covers
        every latent RV of the generative model at once.

        Parameters
        ----------
        generative_hbm (tfd.JointDistributionNamed):
            prior model over which inference is performed
        link_functions (Dict[str, tfb.Bijector]):
            one bijector per latent RV; build_architecture chains it
            after a reshape to produce that RV's value
        observed_rv (str):
            name of the observed RV, expected to be one of the generative
            HBM's keys; every other key is treated as latent
        conditional_nf_chain_kwargs (Dict):
            keyword arguments forwarded to ConditionalNFChain, the
            stochastic transform named "h"
        embedding_RV_size (int):
            dimension of epsilon, the RV that randomizes the transform h
        **kwargs:
            forwarded unchanged to the parent class constructor
        """
        super().__init__(**kwargs)

        # build_architecture reads all four attributes below, so they
        # have to be in place before it is called.
        self.generative_hbm = generative_hbm
        self.link_functions = link_functions
        self.observed_rv = observed_rv
        self.embedding_RV_size = embedding_RV_size

        self.build_architecture(
            conditional_nf_chain_kwargs=conditional_nf_chain_kwargs
        )

    def build_architecture(
        self,
        conditional_nf_chain_kwargs: Dict
    ) -> None:
        """Assembles the bijector chain and the distributions of the family.

        Sets the attributes latent_rvs, splitter, restructurer, constrainer,
        base_dist ("u"), conditional_nf_chain ("h"), event_bijector,
        transformed_dist ("z") and embedded_dist ("epsilon").

        Args:
            conditional_nf_chain_kwargs (Dict):
                ConditionalNFChain kwargs for the stochastic transform h
        """

        hbm_event_shapes = self.generative_hbm.event_shape

        # Every RV of the generative model except the observed one.
        self.latent_rvs = [
            name
            for name in hbm_event_shapes.keys()
            if name != self.observed_rv
        ]

        flat_sizes = []
        rv_bijectors = {}
        for name in self.latent_rvs:
            link = self.link_functions[name]
            # Event shape on the input side of the link function's
            # forward pass, given the RV's event shape on the output side.
            pre_link_shape = link.inverse_event_shape(
                hbm_event_shapes[name]
            )
            flat_size = tf.reduce_prod(pre_link_shape)
            flat_sizes.append(flat_size)

            unflatten = tfb.Reshape(
                event_shape_in=(flat_size,),
                event_shape_out=pre_link_shape
            )
            # tfb.Chain applies its list from last to first:
            # un-flatten the vector slice, then apply the link function.
            rv_bijectors[name] = tfb.Chain([link, unflatten])

        # Cuts the flat vector into one slice per latent RV ...
        self.splitter = tfb.Split(flat_sizes)

        # ... turns the list of slices into a {rv: slice} dict ...
        self.restructurer = tfb.Restructure(
            {
                name: index
                for index, name in enumerate(self.latent_rvs)
            }
        )

        # ... and applies each RV's own reshape + link function.
        self.constrainer = tfb.JointMap(rv_bijectors)

        # Plain Python int: used below to size the base distribution
        # and the conditional NF chain.
        total_event_size = int(tf.reduce_sum(flat_sizes))

        self.base_dist = tfd.Independent(
            tfd.Normal(
                loc=tf.zeros((total_event_size,)),
                scale=1
            ),
            reinterpreted_batch_ndims=1,
            name="u"
        )

        # The name "h" is the key used to route `conditional_input`
        # through `bijector_kwargs` in sample / log_prob_parts_ze.
        self.conditional_nf_chain = ConditionalNFChain(
            **conditional_nf_chain_kwargs,
            event_size=total_event_size,
            conditional_event_size=self.embedding_RV_size,
            name="h"
        )

        # Applied from last to first: h, split, restructure, constrain.
        self.event_bijector = tfb.Chain(
            [
                self.constrainer,
                self.restructurer,
                self.splitter,
                self.conditional_nf_chain
            ]
        )

        self.transformed_dist = tfd.TransformedDistribution(
            self.base_dist,
            bijector=self.event_bijector,
            name="z"
        )

        self.embedded_dist = tfd.Independent(
            tfd.Normal(
                loc=tf.zeros((self.embedding_RV_size,)),
                scale=1
            ),
            reinterpreted_batch_ndims=1,
            name="epsilon"
        )

    def sample(
        self,
        sample_shape: Tuple[int, ...],
        return_epsilon: bool = False
    ) -> Union[
        Dict[str, tf.Tensor],
        Tuple[Dict[str, tf.Tensor], tf.Tensor]
    ]:
        """Draws epsilon, then draws z from the transformed distribution
        with the transform h conditioned on that epsilon

        Args:
            sample_shape (Tuple[int, ...]):
                shape of the sample, used for both epsilon and z
            return_epsilon (bool, optional):
                also return the epsilon draw that conditioned the
                transform. Defaults to False.

        Returns:
            Union[ Dict[str, tf.Tensor], Tuple[Dict[str, tf.Tensor], tf.Tensor] ]:
                z, or the pair (z, epsilon) when return_epsilon is True
        """

        epsilon = self.embedded_dist.sample(sample_shape)

        # Keyed by "h", the name given to the conditional NF chain,
        # so that only that bijector receives the conditional input.
        conditioning = {"h": {"conditional_input": epsilon}}

        z = self.transformed_dist.sample(
            sample_shape,
            bijector_kwargs=conditioning
        )

        return (z, epsilon) if return_epsilon else z

    def log_prob_parts_ze(
        self,
        z: Dict[str, tf.Tensor],
        epsilon: tf.Tensor
    ) -> Dict[str, tf.Tensor]:
        """Computes the log probs of the two inner distributions

        Args:
            z (Dict[str, tf.Tensor]):
                {rv: value} sample
            epsilon (tf.Tensor):
                embedding distribution sample, also used to condition
                the transform h when scoring z

        Returns:
            Dict[str, tf.Tensor]:
                {"z": value, "epsilon": value} log probs dict, where "z"
                is the log prob of z under the transformed distribution
                conditioned on epsilon, and "epsilon" is the log prob of
                epsilon under the embedding distribution
        """

        # Same routing as in sample(): "h" names the conditional NF chain.
        conditioning = {"h": {"conditional_input": epsilon}}

        log_q_z = self.transformed_dist.log_prob(
            z,
            bijector_kwargs=conditioning
        )
        log_q_epsilon = self.embedded_dist.log_prob(epsilon)

        return {"z": log_q_z, "epsilon": log_q_epsilon}

    def log_prob_ze(
        self,
        z: Dict[str, tf.Tensor],
        epsilon: tf.Tensor
    ) -> tf.Tensor:
        """Calls log_prob_parts_ze and sums the z and epsilon results

        Args:
            z (Dict[str, tf.Tensor]):
                {rv: value} sample
            epsilon (tf.Tensor):
                embedding distribution sample

        Returns:
            tf.Tensor: joint log prob value
        """
        parts = self.log_prob_parts_ze(z, epsilon)

        # No `axis` is given, so tf.reduce_sum collapses every dimension:
        # the two parts are added AND any batch dimensions are summed out.
        # NOTE(review): if one value per draw is wanted (e.g. independent
        # MCMC chains in run_mcmc), this would need axis=0 - confirm intent.
        return tf.reduce_sum(list(parts.values()))

    def run_mcmc(
        self,
        z: Dict[str, tf.Tensor],
        epsilon: tf.Tensor,
        step_size: float = 0.1,
        num_leapfrog_steps: int = 5
    ) -> tf.Tensor:
        """Draws epsilon samples targeting q(epsilon | z)
        via an HMC run starting at the given epsilon value

        The target density is the joint log prob from log_prob_ze with z
        held fixed. Reads self.n_mcmc_samples and self.t_mcmc_burn_in,
        which are set by compile().

        Args:
            z (Dict[str, tf.Tensor]):
                parameter sample for the joint density
            epsilon (tf.Tensor):
                starting epsilon point of the chain
            step_size (float, optional):
                HMC leapfrog step size. Defaults to 0.1.
            num_leapfrog_steps (int, optional):
                number of leapfrog steps per HMC proposal. Defaults to 5.

        Returns:
            tf.Tensor:
                chain states stacked along a new leading axis of size
                self.n_mcmc_samples; when epsilon comes from
                sample((self.n_theta_draws,)) the shape is (
                    self.n_mcmc_samples,
                    self.n_theta_draws,
                    self.embedding_RV_size
                )
        """

        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            # z is closed over: only epsilon moves during the chain.
            target_log_prob_fn=lambda e: self.log_prob_ze(z, e),
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps
        )

        # Default trace_fn is kept, so sample_chain returns a structure
        # exposing `.all_states`; only the states are needed here.
        return tfp.mcmc.sample_chain(
            num_results=self.n_mcmc_samples,
            num_burnin_steps=self.t_mcmc_burn_in,
            current_state=epsilon,
            kernel=kernel
        ).all_states

    def q_z(
        self,
        z: Dict[str, tf.Tensor],
        epsilon: tf.Tensor
    ) -> tf.Tensor:
        """Runs MCMC to get epsilon samples targeting q(epsilon | z), then
        averages the conditional log prob of z over those samples

        Args:
            z (Dict[str, tf.Tensor]):
                {rv: value} parameter sample
            epsilon (tf.Tensor):
                starting point of the MCMC run

        Returns:
            tf.Tensor:
                mean over the MCMC samples of the "z" entry of
                log_prob_parts_ze, used as the Monte Carlo estimate
                for z
        """

        mcmc_epsilons = self.run_mcmc(z, epsilon)

        # Iterates over the leading (MCMC sample) axis, scoring z once
        # per epsilon.
        # NOTE(review): train_step scores the whole stack in a single
        # log_prob_parts_ze call instead - confirm whether both paths
        # are meant to differ.
        conditional_log_probs = [
            self.log_prob_parts_ze(z, e)["z"]
            for e in mcmc_epsilons
        ]

        return tf.reduce_mean(conditional_log_probs, axis=0)

    def compile(
        self,
        n_theta_draws: int,
        t_mcmc_burn_in: int,
        n_mcmc_samples: int,
        **kwargs
    ) -> None:
        """Stores the sampling hyper-parameters on the instance, then
        delegates to the Keras compilation

        Args:
            n_theta_draws (int):
                number of z draws per train step, for the Monte Carlo
                estimation of the ELBO
            t_mcmc_burn_in (int):
                number of burn-in steps of the inner MCMC epsilon run
            n_mcmc_samples (int):
                number of samples kept from the inner MCMC epsilon run
            **kwargs:
                forwarded unchanged to the parent class compile
        """

        # Read later by train_step (n_theta_draws) and run_mcmc
        # (t_mcmc_burn_in, n_mcmc_samples).
        self.n_theta_draws = n_theta_draws
        self.t_mcmc_burn_in = t_mcmc_burn_in
        self.n_mcmc_samples = n_mcmc_samples

        super().compile(**kwargs)

    def train_step(
        self,
        train_data: Tuple[Dict[str, tf.Tensor]]
    ) -> Dict[str, tf.Tensor]:
        """keras train step: one gradient update on the reverse KL loss

        Parameters
        ----------
        train_data : Tuple[Dict[str, tf.Tensor]]
            full data point; only the first element is read, and from it
            only the entry stored under the observed RV's name

        Returns
        -------
        Dict
            {"reverse_KL": loss_value}
        """

        observed = train_data[0][self.observed_rv]
        n_draws = self.n_theta_draws

        # One copy of the observation per z draw.
        # NOTE(review): presumably the leading axis of the observation has
        # size 1 so that the result lines up with the n_draws samples -
        # confirm against the data pipeline.
        repeated_x = tf.repeat(
            observed,
            repeats=(n_draws,),
            axis=0
        )

        with tf.GradientTape() as tape:
            z, epsilon = self.sample(
                sample_shape=(n_draws,),
                return_epsilon=True
            )

            # Latent draws plus the observed value: a full point of the
            # generative model.
            joint_point = dict(z)
            joint_point[self.observed_rv] = repeated_x
            log_p = self.generative_hbm.log_prob(joint_point)

            # The MCMC run is excluded from the tape: its output is
            # treated as a constant when differentiating the loss.
            with tape.stop_recording():
                epsilons_given_z = self.run_mcmc(z, epsilon)

            # Scores z against every MCMC epsilon in one call, then
            # averages over the leading (MCMC sample) axis.
            log_q_given_epsilons = self.log_prob_parts_ze(
                z, epsilons_given_z
            )["z"]
            log_q = tf.reduce_mean(log_q_given_epsilons, axis=0)

            loss = tf.reduce_mean(log_q - log_p)

        variables = self.trainable_variables
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(
            zip(grads, variables)
        )

        return {"reverse_KL": loss}
