import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd


class HybridSummaryNetwork(tf.keras.Model):
    """Represents a network which in addition to learnable summaries embeds expert knowledge.

    The raw data are passed through ``summary_net`` and the expert summaries through a dense
    embedding network; both outputs are concatenated along the last axis.
    """

    def __init__(self, num_expert_summaries, summary_net, non_linear=False, non_linear_kwargs=None, **kwargs):
        """Creates the hybrid summary network.

        Parameters
        ----------
        num_expert_summaries : int
            Number of units of the final dense layer of the embedding network.
        summary_net          : callable
            Network applied to the raw data in ``call``.
        non_linear           : bool, optional, default: False
            If True, hidden layers described by ``non_linear_kwargs`` are added before the final dense layer.
        non_linear_kwargs    : dict or None, optional, default: None
            Required when ``non_linear`` is True. The following keys are read:
            ``hidden_dims`` - iterable with the number of units of each hidden dense layer
            ``activation``  - activation passed to each hidden dense layer
            ``dropout``     - optional dropout rate; no dropout layers are added if missing or None
        **kwargs             : dict, optional, default: {}
            Additional keyword arguments passed to the ``tf.keras.Model`` constructor.

        Raises
        ------
        ValueError
            If ``non_linear`` is True but ``non_linear_kwargs`` is None.
        """

        super().__init__(**kwargs)

        self.num_expert_summaries = num_expert_summaries
        self.summary_net = summary_net

        self.embedder = tf.keras.models.Sequential()
        if non_linear:
            # Fail early with an actionable message instead of an AttributeError on ``None.get``
            if non_linear_kwargs is None:
                raise ValueError("non_linear_kwargs must be provided when non_linear=True.")
            self._configure_hidden(non_linear_kwargs)
        self.embedder.add(tf.keras.layers.Dense(num_expert_summaries))

    def call(self, input_tuple, **kwargs):
        """Forward pass through summary and embedding networks.

        Parameters
        ----------
        input_tuple : tuple
            A pair ``(raw_data, expert_summaries)``; the first entry is passed to the summary
            network and the second one to the embedding network.
        **kwargs    : dict, optional, default: {}
            Additional keyword arguments passed to both networks.

        Returns
        -------
        out : tf.Tensor
            The summary network output and the embedding concatenated along the last axis.
        """

        raw_data, expert_summaries = input_tuple
        summary = self.summary_net(raw_data, **kwargs)
        embedding = self.embedder(expert_summaries, **kwargs)
        return tf.concat([summary, embedding], axis=-1)

    def _configure_hidden(self, non_linear_kwargs):
        """Configures the hidden layer structure of the embedding network.

        Adds one dense layer per entry of ``non_linear_kwargs['hidden_dims']``, each optionally
        followed by a dropout layer if a ``dropout`` rate is given.
        """

        # Dropout is optional: a missing key is treated the same as an explicit None
        dropout_rate = non_linear_kwargs.get('dropout')

        # Direct indexing so that a missing key is reported by name
        for hidden_dim in non_linear_kwargs['hidden_dims']:
            self.embedder.add(tf.keras.layers.Dense(hidden_dim, activation=non_linear_kwargs['activation']))
            if dropout_rate is not None:
                self.embedder.add(tf.keras.layers.Dropout(dropout_rate))


class AmortizedHybrid(tf.keras.Model):
    """Represents an amortizer which learns a conditional generative model for the summary statistics.

    Three networks are combined: a summary network, a summary learner which transforms the summary
    network outputs (given the direct conditions), and an inference network which transforms the
    parameters given the summary learner output and the direct conditions.
    """

    def __init__(self, inference_net, summary_net, summary_learner, latent_dist_params=None, latent_dist_summary=None):
        """Creates the amortizer from its three networks and the two latent distributions.

        Parameters
        ----------
        inference_net       : callable
            Network over the parameters. Must expose ``latent_dim`` as well as ``forward`` and
            ``inverse`` methods (used in ``log_posterior`` and ``sample``).
        summary_net         : callable
            Network applied to ``input_dict['summary_conditions']``.
        summary_learner     : callable
            Network over the summary network outputs. Must expose ``latent_dim`` and return a pair
            when called.
        latent_dist_params  : object or None, optional, default: None
            Latent distribution for the parameters, must provide ``log_prob`` and ``sample``.
            Defaults to a ``MultivariateNormalDiag`` with zero location.
        latent_dist_summary : object or None, optional, default: None
            Latent distribution for the summary space, must provide ``log_prob``.
            Defaults to a multivariate Student-T with ``df=100``, zero location and unit diagonal scale.
        """

        super().__init__()

        self.inference_net = inference_net
        self.summary_net = summary_net
        self.summary_learner = summary_learner
        self.latent_dim = self.inference_net.latent_dim
        self.summary_dim = self.summary_learner.latent_dim

        # Default Gaussian latent dist for parameters
        if latent_dist_params is None:
            self.latent_dist_params = tfd.MultivariateNormalDiag(loc=[0.0] * self.latent_dim)
        else:
            self.latent_dist_params = latent_dist_params

        # Default Student-T dist for summary space
        if latent_dist_summary is None:
            loc = [0.] * self.summary_dim
            scale = tf.linalg.LinearOperatorDiag([1.] * self.summary_dim)
            self.latent_dist_summary = tfd.MultivariateStudentTLinearOperator(df=100, loc=loc, scale=scale)
        else:
            self.latent_dist_summary = latent_dist_summary

    def call(self, input_dict, **kwargs):
        """Forward pass through the networks.

        Parameters
        ----------
        input_dict : dict
            Input dictionary containing the keys ``parameters``, ``summary_conditions``
            and ``direct_conditions``.
        **kwargs   : dict, optional, default: {}
            Additional keyword arguments passed to the networks.

        Returns
        -------
        (out_learner, out_inference) : tuple
            The raw outputs of the summary learner (computed on the gradient-stopped summaries)
            and of the inference network.
        """

        # Compute summary
        summary_out = self.summary_net(input_dict['summary_conditions'], **kwargs)

        # Compute z-distro of summary outputs; the gradient is stopped at the summaries, so
        # whatever loss is computed from this output does not update the summary network
        out_learner = self.summary_learner(
            tf.stop_gradient(summary_out), input_dict['direct_conditions'], **kwargs)

        # Second pass without the gradient stop: its first output conditions the inference
        # network, so gradients from the inference output can reach the summary network
        learner_condition, _ = self.summary_learner(
            summary_out, input_dict['direct_conditions'], **kwargs)

        # Create full conditions
        full_condition = tf.concat([learner_condition, input_dict['direct_conditions']], axis=-1)

        # Compute z-distro of targets
        out_inference = self.inference_net(input_dict['parameters'], full_condition, **kwargs)

        return out_learner, out_inference

    def compute_loss(self, input_dict, **kwargs):
        """Computes the losses of the amortizer given an input dictionary.

        Parameters
        ----------
        input_dict : dict
            Input dictionary containing the following mandatory keys:
            ``parameters``         - the latent model parameters over which a conditional density is learned
            ``summary_conditions`` - the conditioning variables that are first passed through a summary network
            ``direct_conditions``  - the conditioning variables that are directly passed to the networks
        **kwargs   : dict, optional, default: {}
            Additional keyword arguments passed to the networks.

        Returns
        -------
        losses : dict
            Dictionary with two scalar ``tf.Tensor`` entries:
            ``KL.P`` - the mean loss of the inference network (parameters)
            ``KL.S`` - the mean loss of the summary learner (summary space)
        """

        # Get amortizer outputs
        out_learner, out_inference = self(input_dict, **kwargs)

        # Index 0 is evaluated under the latent distribution, index 1 enters the loss additively
        # (presumably the latent code and the log-determinant of the Jacobian - confirm with the networks)
        logpdf_learner = self.latent_dist_summary.log_prob(out_learner[0])
        logpdf_params = self.latent_dist_params.log_prob(out_inference[0])

        # Compute and return both losses
        loss_learner = tf.reduce_mean(-logpdf_learner - out_learner[1])
        loss_params = tf.reduce_mean(-logpdf_params - out_inference[1])
        return {'KL.P': loss_params, 'KL.S': loss_learner}

    def _compute_inference_conditions(self, input_dict, **kwargs):
        """Builds the conditions of the inference network, with all networks in inference mode."""

        # Compute summary
        summary_out = self.summary_net(input_dict['summary_conditions'], training=False)

        # The learner is run with ``training=False`` as well, consistent with the other networks
        summary_out, _ = self.summary_learner(
            summary_out, input_dict['direct_conditions'], training=False, **kwargs)

        # Create full conditions
        return tf.concat([summary_out, input_dict['direct_conditions']], axis=-1)

    def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
        """Generates random draws from the approximate posterior given a dictionary with conditional variables.

        Parameters
        ----------
        input_dict  : dict
            Input dictionary containing the following mandatory keys:
            ``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
            ``direct_conditions``  : the conditioning variables that are directly passed to the networks
        n_samples   : int
            The number of posterior draws (samples) to obtain from the approximate posterior
        to_numpy    : bool, optional, default: True
            Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor``.
        **kwargs    : dict, optional, default: {}
            Additional keyword arguments passed to the summary learner and the inference network.
            Must not contain ``training``, which is fixed to False.

        Returns
        -------
        post_samples : tf.Tensor or np.ndarray of shape (n_data_sets, n_samples, n_params)
            The sampled parameters from the approximate posterior of each data set.
            If there is only a single data set, the leading dimension is dropped.
        """

        conditions = self._compute_inference_conditions(input_dict, **kwargs)

        # Obtain number of data sets
        n_data_sets = conditions.shape[0]

        # Draw latent samples with leading shape (n_data_sets, n_samples)
        z_samples = self.latent_dist_params.sample((n_data_sets, n_samples))

        # Obtain random draws from the approximate posterior given conditioning variables
        post_samples = self.inference_net.inverse(z_samples, conditions, training=False, **kwargs)

        # Drop the leading dimension if there is only a single data set
        if post_samples.shape[0] == 1:
            post_samples = post_samples[0]

        # Return numpy version of tensor or tensor itself
        if to_numpy:
            return post_samples.numpy()
        return post_samples

    def log_posterior(self, input_dict, to_numpy=True, **kwargs):
        """Evaluates the approximate log-posterior density of given parameters.

        Parameters
        ----------
        input_dict : dict
            Input dictionary containing the following mandatory keys:
            ``parameters``         : the parameters at which the log-density is evaluated
            ``summary_conditions`` : the conditioning variables that are first passed through a summary network
            ``direct_conditions``  : the conditioning variables that are directly passed to the networks
        to_numpy   : bool, optional, default: True
            Flag indicating whether to return the result as a ``np.ndarray`` or a ``tf.Tensor``.
        **kwargs   : dict, optional, default: {}
            Additional keyword arguments passed to the summary learner and the inference network.
            Must not contain ``training``, which is fixed to False.

        Returns
        -------
        log_post : tf.Tensor or np.ndarray
            The latent log-density of the transformed parameters plus the log-determinant
            of the Jacobian returned by the inference network.
        """

        conditions = self._compute_inference_conditions(input_dict, **kwargs)

        # Forward pass through the network
        z, log_det_J = self.inference_net.forward(
            input_dict["parameters"], conditions, training=False, **kwargs
        )
        log_post = self.latent_dist_params.log_prob(z) + log_det_J

        if to_numpy:
            return log_post.numpy()
        return log_post
