from typing import Dict, Union, Literal, Tuple, Optional

import tensorflow as tf
import tensorflow_probability as tfp

from ...set_transformer.models import (
    SetTransformer
)
from ...normalizing_flow.bijectors import (
    ConditionalNFChain
)

tfd = tfp.distributions
tfb = tfp.bijectors


class ADAVFamily(tf.keras.Model):
    """Amortized dual variational family automatically built
    from a generative Hierarchical Bayesian Model.

    One SetTransformer (hierarchical encoder) is created for every
    hierarchy level >= 1, and one conditional density estimator
    (conditional normalizing flow -> reshape -> link function) is
    created for every key of the generative_hbm whose hierarchy is >= 1.
    The (first) key at hierarchy 0 is treated as the observed data.
    """

    # Loss variants accepted by `compile` / dispatched in `train_step`
    _SUPPORTED_TRAIN_METHODS = (
        "forward_KL",
        "reverse_KL",
        "unregularized_ELBO",
        "exp_MAP_regression",
        "exp_affine_unregularized_ELBO",
    )

    def __init__(
        self,
        generative_hbm: Union[
            tfd.JointDistributionNamed,
            tfd.JointDistributionCoroutine
        ],
        set_transforer_kwargs: Dict,
        conditional_nf_chain_kwargs: Dict,
        hierarchies: Dict,
        link_functions: Dict,
        **kwargs
    ):
        """Automatically constructs the amortized dual variational
        family from the input generative_hbm

        Parameters
        ----------
        generative_hbm : Union[
            tfd.JointDistributionNamed,
            tfd.JointDistributionCoroutine
        ]
            generative Hierarchical Bayesian Model
            on which to perform inference
        set_transforer_kwargs : Dict
            for hierarchical encoder
            (the misspelled name is part of the public signature
            and is kept for backward compatibility)
        conditional_nf_chain_kwargs : Dict
            for conditional density estimators
        hierarchies : Dict
            dict of {key: hierarchy} for all keys
            in the generative_hbm
        link_functions : Dict
            dict of {key: tfb.Bijector} for all keys
            in the generative_hbm
        **kwargs
            forwarded to tf.keras.Model.__init__
        """
        super().__init__(**kwargs)

        self.generative_hbm = generative_hbm
        self.hierarchies = hierarchies
        self.link_functions = link_functions

        self.analyse_generative_hbm_graph()
        self.build_architecture(
            set_transforer_kwargs=set_transforer_kwargs,
            conditional_nf_chain_kwargs=conditional_nf_chain_kwargs
        )

    def analyse_generative_hbm_graph(self) -> None:
        """Creates internals related to hierarchies

        Sets
        ----
        max_hierarchy : int
            highest hierarchy value found in `hierarchies`
        keys_per_hierarchy : Dict
            {h: [keys whose hierarchy is h]} for h in 1..max_hierarchy
            (a level with no key maps to an empty list)
        input_key : str
            first key found at hierarchy 0, treated as the observed data
        """
        self.max_hierarchy = max(self.hierarchies.values())
        self.keys_per_hierarchy = {
            h: [
                key
                for key, value in self.hierarchies.items()
                if value == h
            ]
            for h in range(self.max_hierarchy + 1)
        }
        # Hierarchy 0 is the data: remove it so that only
        # parameters get density estimators
        self.input_key = self.keys_per_hierarchy.pop(0)[0]

    def build_architecture(
        self,
        set_transforer_kwargs: Dict,
        conditional_nf_chain_kwargs: Dict
    ) -> None:
        """Creates parametric hierarchical encoder
        and conditional density estimators

        Parameters
        ----------
        set_transforer_kwargs : Dict
            for hierarchical encoder
        conditional_nf_chain_kwargs : Dict
            for conditional density estimators
        """
        self.set_transformers = {}
        self.conditional_density_estimators = {}
        self.exp_conditional_affine_density_estimators = {}

        # JointDistributionNamed exposes a dict-like event_shape while
        # JointDistributionCoroutine exposes a namedtuple-like one.
        # Duck-typing on `_asdict` (instead of comparing exact types)
        # also handles subclasses such as the AutoBatched variants.
        event_shapes = self.generative_hbm.event_shape
        if hasattr(event_shapes, "_asdict"):
            event_shapes = event_shapes._asdict()

        for h in range(1, self.max_hierarchy + 1):
            self.set_transformers[h] = SetTransformer(
                **set_transforer_kwargs,
                attention_axes=(-3,)
            )
            embedding_size = self.set_transformers[h].embedding_size

            # Leading dims of a level-h parameter kept as batch dims;
            # the remaining dims are flattened into a single event axis
            n_batch_dims = self.max_hierarchy - h

            for key in self.keys_per_hierarchy[h]:
                # event shape on the input side of the link function,
                # i.e. what the reshaper must output
                link_input_shape = (
                    self.link_functions[key]
                    .inverse_event_shape(event_shapes[key])
                )

                batch_shape = link_input_shape[:n_batch_dims]
                event_size = (
                    link_input_shape[n_batch_dims:].num_elements()
                )
                latent_shape = batch_shape + (event_size,)

                latent_distribution = tfd.Independent(
                    tfd.Normal(
                        loc=tf.zeros(latent_shape),
                        scale=1.0
                    ),
                    reinterpreted_batch_ndims=n_batch_dims + 1
                )

                reshaper = tfb.Reshape(
                    event_shape_out=link_input_shape,
                    event_shape_in=latent_shape
                )

                # The name is used as the bijector_kwargs key
                # when conditioning on encodings
                nf = ConditionalNFChain(
                    event_size=event_size,
                    conditional_event_size=embedding_size,
                    name=f"nf_{key}",
                    **conditional_nf_chain_kwargs
                )

                # tfb.Chain applies its bijectors right to left:
                # nf -> reshaper -> link function
                self.conditional_density_estimators[key] = (
                    tfd.TransformedDistribution(
                        distribution=latent_distribution,
                        bijector=tfb.Chain(
                            bijectors=[
                                self.link_functions[key],
                                reshaper,
                                nf
                            ]
                        )
                    )
                )

                # ! EXPERIMENTAL - needs to be refactored
                # assumes that conditional_nf_chain.bijectors[-1]
                # is a ConditionalAffine bijector
                self.exp_conditional_affine_density_estimators[key] = (
                    tfd.TransformedDistribution(
                        distribution=latent_distribution,
                        bijector=tfb.Chain(
                            bijectors=[
                                self.link_functions[key],
                                reshaper,
                                nf.bijectors[-1]
                            ]
                        )
                    )
                )

    @staticmethod
    def _encodings_batch_size(
        encodings: Dict
    ) -> Union[int, tf.Tensor]:
        """Returns the leading (batch) dimension of the hierarchy-1
        encodings: the static value when known, otherwise the dynamic
        one (e.g. inside a tf.function traced with an unknown batch size)
        """
        batch_size = encodings[1].shape[0]
        if batch_size is None:
            batch_size = tf.shape(encodings[1])[0]
        return batch_size

    def encode_data(
        self,
        x: tf.Tensor
    ) -> Dict:
        """Encodes input data x
        via stacked SetTransformers

        Parameters
        ----------
        x : tf.Tensor
            shape:
                (batch_size,)
                + (
                    Card(P_p),
                    ...,
                    Card(P_0)
                )
                + Shape(x)

        Returns
        -------
        Dict
            encodings from various hierarchies
            shapes: {
                hierarchy:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + (e_hierarchy,)
            }
        """
        # The encoders work on the input side of the data's link function
        z = self.link_functions[self.input_key].inverse(x)
        encodings = {}
        for h in range(1, self.max_hierarchy + 1):
            z = self.set_transformers[h](z)
            # drop the size-1 axis left at position -2
            z = tf.squeeze(z, axis=-2)
            encodings[h] = z

        return encodings

    def sample_parameters_conditioned_to_encodings(
        self,
        encodings: Dict
    ) -> Dict:
        """sample a single point from conditional density estimators

        Parameters
        ----------
        encodings : Dict
            encodings from various hierarchies
            shapes: {
                hierarchy:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + (e_hierarchy,)
            }

        Returns
        -------
        Dict
            sample
            shapes: {
                key:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + Shape(key)
            }
        """
        batch_size = self._encodings_batch_size(encodings)
        sample = {}
        for h in range(1, self.max_hierarchy + 1):
            for key in self.keys_per_hierarchy[h]:
                sample[key] = (
                    self
                    .conditional_density_estimators[key]
                    .sample(
                        (batch_size,),
                        bijector_kwargs={
                            f"nf_{key}": dict(
                                conditional_input=encodings[h]
                            )
                        }
                    )
                )

        return sample

    def sample_parameters_conditioned_to_data(
        self,
        x: tf.Tensor
    ) -> Dict:
        """Wrapper for encode_data
        followed by sample_parameters_conditioned_to_encodings
        """
        encodings = self.encode_data(x)

        return self.sample_parameters_conditioned_to_encodings(encodings)

    def parameters_log_prob_conditioned_to_encodings(
        self,
        parameters: Dict,
        encodings: Dict
    ) -> tf.Tensor:
        """Calculate posterior log prob for the parameters

        Parameters
        ----------
        parameters : Dict
            shapes: {
                key:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + Shape(key)
            }
            keys at hierarchy 0 (the data) are ignored
        encodings : Dict
            encodings from various hierarchies
            shapes: {
                hierarchy:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + (e_hierarchy,)
            }

        Returns
        -------
        tf.Tensor
            shape: (batch_size,)
            sum over all parameter keys of their conditional log prob
        """
        log_prob = 0
        for h in range(1, self.max_hierarchy + 1):
            for key in self.keys_per_hierarchy[h]:
                log_prob += (
                    self
                    .conditional_density_estimators[key]
                    .log_prob(
                        parameters[key],
                        bijector_kwargs={
                            f"nf_{key}": dict(
                                conditional_input=encodings[h]
                            )
                        }
                    )
                )

        return log_prob

    def parameters_log_prob_conditioned_to_data(
        self,
        parameters: Dict,
        x: tf.Tensor
    ) -> tf.Tensor:
        """Wrapper for encode_data
        followed by parameters_log_prob_conditioned_to_encodings
        """
        encodings = self.encode_data(x)

        return self.parameters_log_prob_conditioned_to_encodings(
            parameters=parameters,
            encodings=encodings
        )

    def exp_affine_MAP_regression_conditioned_to_encodings(
        self,
        encodings: Dict
    ) -> Dict:
        """# ! EXPERIMENTAL - needs to be refactored
        Assumes conditional_nf_chains.bijectors[-1] to be a
        ConditionalAffine bijector, from which we retrieve the shift
        regressor and apply it to the encodings. The regressed shift is
        then pushed through the remaining (reshape, link) bijectors.

        Parameters
        ----------
        encodings : Dict
            encodings from various hierarchies
            shapes: {
                hierarchy:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + (e_hierarchy,)
            }

        Returns
        -------
        Dict
            map values
            shapes: {
                key:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + Shape(key)
            }
        """
        map_values = {}
        for h in range(1, self.max_hierarchy + 1):
            for key in self.keys_per_hierarchy[h]:
                # chain bijectors are [link_function, reshaper, nf]
                chain_bijectors = (
                    self
                    .conditional_density_estimators[key]
                    .bijector
                    .bijectors
                )
                link_and_reshape = tfb.Chain(
                    bijectors=chain_bijectors[:-1]
                )
                conditional_affine = chain_bijectors[-1].bijectors[-1]
                map_values[key] = link_and_reshape(
                    conditional_affine.shift(encodings[h])
                )
        return map_values

    def exp_affine_sample_parameters_conditioned_to_encodings(
        self,
        encodings: Dict
    ) -> Dict:
        """# ! EXPERIMENTAL - needs to be refactored
        sample a single point from conditional affine density estimators

        Parameters
        ----------
        encodings : Dict
            encodings from various hierarchies
            shapes: {
                hierarchy:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + (e_hierarchy,)
            }

        Returns
        -------
        Dict
            sample
            shapes: {
                key:
                    (batch_size,)
                    + (
                        Card(P_p),
                        ...,
                        Card(P_hierarchy)
                    )
                    + Shape(key)
            }
        """
        batch_size = self._encodings_batch_size(encodings)
        sample = {}
        for h in range(1, self.max_hierarchy + 1):
            for key in self.keys_per_hierarchy[h]:
                estimator = (
                    self.exp_conditional_affine_density_estimators[key]
                )
                # the conditional affine bijector is last in the chain
                # list; its name is the bijector_kwargs routing key
                affine_name = estimator.bijector.bijectors[-1].name
                sample[key] = estimator.sample(
                    (batch_size,),
                    bijector_kwargs={
                        affine_name: dict(
                            conditional_input=encodings[h]
                        )
                    }
                )

        return sample

    def compile(
        self,
        train_method: Literal[
            "forward_KL",
            "reverse_KL",
            "unregularized_ELBO",
            "exp_MAP_regression",
            "exp_affine_unregularized_ELBO"
        ],
        n_theta_draws_per_x: int,
        **kwargs
    ) -> None:
        """Wrapper for Keras compilation, additionally
        specifying hyper-parameters

        Parameters
        ----------
        train_method : Literal[
            "forward_KL",
            "reverse_KL",
            "unregularized_ELBO",
            "exp_MAP_regression",
            "exp_affine_unregularized_ELBO"
        ]
            defines which loss to use during training
        n_theta_draws_per_x : int
            for Monte Carlo estimation of the ELBO
            (not used if train_method is "forward_KL"
            or "exp_MAP_regression")
        **kwargs
            forwarded to tf.keras.Model.compile

        Raises
        ------
        NotImplementedError
            train_method not in [
                "forward_KL",
                "reverse_KL",
                "unregularized_ELBO",
                "exp_MAP_regression",
                "exp_affine_unregularized_ELBO"
            ]
        """
        if train_method not in self._SUPPORTED_TRAIN_METHODS:
            raise NotImplementedError(
                f"unrecognized train method {train_method}"
            )
        self.train_method = train_method
        self.n_theta_draws_per_x = n_theta_draws_per_x

        super().compile(**kwargs)

    def _loss_for_train_method(
        self,
        parameters: Dict,
        x: tf.Tensor
    ) -> tf.Tensor:
        """Computes the scalar training loss selected by
        self.train_method (meant to run inside a GradientTape)

        Parameters
        ----------
        parameters : Dict
            full sample from the generative_hbm (data + parameters)
        x : tf.Tensor
            observed data, i.e. parameters[self.input_key]

        Returns
        -------
        tf.Tensor
            scalar loss

        Raises
        ------
        NotImplementedError
            unknown self.train_method
        """
        method = self.train_method

        if method == "forward_KL":
            return tf.reduce_mean(
                - self.parameters_log_prob_conditioned_to_data(
                    parameters=parameters,
                    x=x
                )
            )

        if method == "exp_MAP_regression":
            encodings = self.encode_data(x=x)
            map_values = (
                self.exp_affine_MAP_regression_conditioned_to_encodings(
                    encodings=encodings
                )
            )
            p = self.generative_hbm.log_prob(
                **map_values,
                **{self.input_key: x}
            )
            return tf.reduce_mean(-p)

        if method in (
            "reverse_KL",
            "unregularized_ELBO",
            "exp_affine_unregularized_ELBO"
        ):
            # Monte Carlo: n_theta_draws_per_x parameter draws per datum
            repeated_x = tf.repeat(
                x,
                repeats=(self.n_theta_draws_per_x,),
                axis=0
            )
            encodings = self.encode_data(x=repeated_x)
            if method == "exp_affine_unregularized_ELBO":
                sampler = (
                    self.exp_affine_sample_parameters_conditioned_to_encodings
                )
            else:
                sampler = self.sample_parameters_conditioned_to_encodings
            parameters_sample = sampler(encodings=encodings)
            p = self.generative_hbm.log_prob(
                **parameters_sample,
                **{self.input_key: repeated_x}
            )
            if method == "reverse_KL":
                q = self.parameters_log_prob_conditioned_to_encodings(
                    parameters=parameters_sample,
                    encodings=encodings
                )
                return tf.reduce_mean(q - p)
            return tf.reduce_mean(-p)

        raise NotImplementedError(
            f"unrecognized train method {method}"
        )

    def _variables_to_train(self) -> list:
        """Model variables followed by the density estimators'
        variables, without duplicates (first occurrence order kept)

        Keras may already include variables of tf.Module attributes
        (such as the density estimators) in self.trainable_variables;
        de-duplicating by reference guarantees no variable is passed
        to the optimizer, and therefore updated, twice.
        """
        unique_variables = {}
        variable_groups = [self.trainable_variables] + [
            estimator.trainable_variables
            for estimator in self.conditional_density_estimators.values()
        ]
        for group in variable_groups:
            for variable in group:
                unique_variables.setdefault(variable.ref(), variable)
        return list(unique_variables.values())

    def train_step(
        self,
        data: Tuple[Dict]
    ) -> Dict:
        """Keras train step

        Parameters
        ----------
        data : Tuple[Dict]
            data from the generative_hbm
            various keys will be used depending
            on the train_method

        Returns
        -------
        Dict
            {loss_type: value} depending on
            the train method
        """
        x = data[0][self.input_key]
        with tf.GradientTape() as tape:
            loss = self._loss_for_train_method(
                parameters=data[0],
                x=x
            )

        trainable_vars = self._variables_to_train()
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {self.train_method: loss}
