import tensorflow as tf
from bayesflow.amortizers import AmortizedPosteriorLikelihood
from bayesflow.default_settings import DEFAULT_KEYS

from ..schedules import ConstantSchedule


class AmortizedPosteriorLikelihoodSC(AmortizedPosteriorLikelihood):
    """Joint posterior/likelihood amortizer with a self-consistency loss term.

    On top of the posterior and likelihood losses, ``compute_loss`` adds a
    penalty on the variance of the estimated log marginal likelihood

        log p(y) = log p(theta) + log p(y | theta) - log p(theta | y)

    evaluated on ``real_data`` for several parameter draws obtained by
    inverting the posterior network.

    Parameters
    ----------
    prior
        Object exposing ``log_prob(theta)``; used for the ``log p(theta)`` term.
    real_data
        Data on which the self-consistency loss is evaluated. It is converted
        to a ``tf.float32`` tensor and is indexed as a 2-D array
        ``(data_size, data_dim)`` in ``compute_loss``.
    lambda_schedule
        Callable mapping the internal step counter to the weight of the
        self-consistency loss. NOTE: the default instance is created once,
        when the class body is evaluated, and is therefore shared by every
        instance that relies on the default.
    n_consistency_samples
        Number of latent draws per data point used for the variance estimate.
    theta_clip_value_min, theta_clip_value_max
        Bounds applied to the sampled parameters via ``tf.clip_by_value``.
    output_numpy
        Stored on the instance; not read anywhere in this class.
    *args, **kwargs
        Forwarded unchanged to ``AmortizedPosteriorLikelihood.__init__``.
    """

    def __init__(
        self,
        prior,
        real_data,
        lambda_schedule=ConstantSchedule(1.0),
        n_consistency_samples=10,
        theta_clip_value_min=-float("inf"),
        theta_clip_value_max=float("inf"),
        output_numpy=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.prior = prior
        self.real_data = tf.convert_to_tensor(real_data, dtype=tf.float32)
        # Non-trainable counter of compute_loss calls; drives lambda_schedule.
        self.step = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.lambda_schedule = lambda_schedule
        self.n_consistency_samples = n_consistency_samples
        self.output_numpy = output_numpy
        self.theta_clip_value_min = theta_clip_value_min
        self.theta_clip_value_max = theta_clip_value_max

    @tf.function
    def compute_loss(self, input_dict, **kwargs):
        """Compute the posterior, likelihood and self-consistency losses.

        Parameters
        ----------
        input_dict
            Mapping that contains the posterior inputs under
            ``DEFAULT_KEYS["posterior_inputs"]`` and the likelihood inputs
            under ``DEFAULT_KEYS["likelihood_inputs"]``.
        **kwargs
            Forwarded to the posterior and likelihood amortizer calls.

        Returns
        -------
        dict
            ``{"Post.Loss": ..., "Lik.Loss": ..., "SC.Loss": ...}`` where
            ``"SC.Loss"`` is already multiplied by the scheduled weight.
        """
        # The counter is incremented before the schedule is queried, so the
        # first call evaluates the schedule at step 1.
        self.step.assign_add(1)
        lambda_ = self.lambda_schedule(self.step)

        # POSTERIOR LOSS ####
        posterior_input_dict = input_dict[DEFAULT_KEYS["posterior_inputs"]]
        net_out, sum_out = self.amortized_posterior(
            posterior_input_dict, return_summary=True, **kwargs
        )
        z, log_det_J = net_out

        # Optional summary-network loss; 0.0 keeps the sum below uniform.
        if self.amortized_posterior.summary_loss is not None:
            sum_loss = self.amortized_posterior.summary_loss(sum_out)
        else:
            sum_loss = 0.0

        # A dynamic latent distribution is a function of the summary output;
        # a static one is used directly.
        if self.amortized_posterior.latent_is_dynamic:
            logpdf = self.amortized_posterior.latent_dist(sum_out).log_prob(z)
        else:
            logpdf = self.amortized_posterior.latent_dist.log_prob(z)

        posterior_loss = tf.reduce_mean(-logpdf - log_det_J) + sum_loss

        # LIKELIHOOD LOSS ####
        likelihood_input_dict = input_dict[DEFAULT_KEYS["likelihood_inputs"]]
        z_lik, log_det_J_lik = self.amortized_likelihood(
            likelihood_input_dict, **kwargs
        )
        likelihood_loss = tf.reduce_mean(
            -self.amortized_likelihood.latent_dist.log_prob(z_lik) - log_det_J_lik
        )

        # SELF CONSISTENCY LOSS ####
        # Skipped entirely (constant zero) while the scheduled weight is not
        # strictly positive.
        if tf.greater(lambda_, 0.0):
            x = self.real_data
            n_consistency_samples = self.n_consistency_samples
            # real_data is treated as 2-D: (data_size, data_dim).
            data_size = x.shape[0]
            data_dim = x.shape[1]

            # One latent draw per (consistency sample, data point); presumably
            # shaped (n_consistency_samples, data_size, latent_dim).
            # A dedicated name is used so the posterior latent `z` above is
            # not re-bound with a differently shaped tensor inside this branch.
            # NOTE(review): latent_dist is sampled directly here, unlike the
            # dynamic-latent handling in the posterior loss -- confirm this
            # loss is only used with a static latent distribution.
            z_sc = self.amortized_posterior.latent_dist.sample(
                (n_consistency_samples, data_size)
            )

            # conditions shape: n_consistency_samples, data_size, data_dim,
            # where the data is repeated across the first dimension, that is,
            # any indexing into the first dimension returns the same data.
            x_reshaped = tf.reshape(x, (1, data_size, data_dim))
            conditions = tf.tile(x_reshaped, [n_consistency_samples, 1, 1])

            # Parameter draws from the current posterior approximation. They
            # are treated as constants (no gradient through the sampling
            # path) and clipped to the configured bounds.
            theta = tf.stop_gradient(
                self.amortized_posterior.inference_net.inverse(
                    z_sc, conditions, training=False
                )
            )
            theta = tf.clip_by_value(
                theta,
                clip_value_min=self.theta_clip_value_min,
                clip_value_max=self.theta_clip_value_max,
            )

            # log p(theta); expected shape: n_consistency_samples, data_size
            log_prior = self.prior.log_prob(theta)

            # log p(y | theta); expected shape: n_consistency_samples, data_size
            lik_dict = {
                DEFAULT_KEYS["conditions"]: theta,
                DEFAULT_KEYS["observables"]: conditions,
            }
            log_lik = self.log_likelihood(lik_dict, to_numpy=False)

            # log p(theta | y); keys come from DEFAULT_KEYS, consistent with
            # the likelihood dictionary above.
            posterior_dict = {
                DEFAULT_KEYS["parameters"]: theta,
                DEFAULT_KEYS["direct_conditions"]: conditions,
            }
            log_post = self.log_posterior(posterior_dict, to_numpy=False)

            # marginal likelihood p(y) = p(theta) * p(y | theta) / p(theta | y)
            log_ml = log_prior + log_lik - log_post

            # Variance over the consistency-sample axis (second-to-last axis
            # of log_ml); expected shape: data_size
            log_ml_var = tf.math.reduce_variance(log_ml, axis=-2)

            # Average over data points; a scalar when log_ml_var is 1-D.
            sc_loss = tf.math.reduce_mean(log_ml_var, axis=-1)
        else:
            sc_loss = tf.constant(0.0)

        return {
            "Post.Loss": posterior_loss,
            "Lik.Loss": likelihood_loss,
            "SC.Loss": sc_loss * lambda_,
        }
