import tensorflow as tf
from bayesflow.amortizers import AmortizedPosterior

from ..schedules import ConstantSchedule


class AmortizedPosteriorSC(AmortizedPosterior):
    """Amortized posterior with an additional self-consistency (SC) loss.

    In addition to the usual posterior loss, ``compute_loss`` reports an SC
    term computed on ``real_data``: the variance, across latent draws, of the
    log marginal likelihood estimate
    ``log p(theta) + log p(x | theta) - log q(theta | x)``,
    averaged over data sets and scaled by the current schedule value.
    """

    def __init__(
        self,
        prior,
        simulator,
        real_data,
        lambda_schedule=ConstantSchedule(1.0),
        n_consistency_samples=32,
        theta_clip_value_min=-float("inf"),
        theta_clip_value_max=float("inf"),
        *args,
        **kwargs,
    ):
        """Set up the amortizer and the self-consistency configuration.

        Parameters
        ----------
        prior
            Object providing ``log_prob(theta)``.
        simulator
            Object providing ``log_prob(theta, x)``.
        real_data
            Data the SC loss is evaluated on; converted to a ``float32``
            tensor. Expected as ``(num_datasets, data_dim)`` without a
            summary network, or ``(num_datasets, num_observations, data_dim)``
            with one.
        lambda_schedule
            Callable invoked with the step counter; its result weights the
            SC loss. NOTE: the default instance is created once, when the
            class is defined, and is therefore shared by all instances that
            rely on the default.
        n_consistency_samples
            Number of latent draws per data set used for the SC loss.
        theta_clip_value_min, theta_clip_value_max
            Bounds applied to the sampled parameters before evaluating the
            prior and the likelihood.
        *args, **kwargs
            Forwarded to ``AmortizedPosterior.__init__``. Because ``*args``
            follows the parameters above, positional arguments for the parent
            can only be supplied after all of them; keyword arguments are the
            safer choice.
        """
        super().__init__(*args, **kwargs)
        self.prior = prior
        self.simulator = simulator
        self.real_data = tf.convert_to_tensor(real_data, dtype=tf.float32)
        # Counts calls to `compute_loss`; drives the lambda schedule.
        self.step = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.lambda_schedule = lambda_schedule
        self.n_consistency_samples = n_consistency_samples
        self.theta_clip_value_min = theta_clip_value_min
        self.theta_clip_value_max = theta_clip_value_max

    def compute_loss(self, input_dict, **kwargs):
        """Compute the posterior loss and the weighted self-consistency loss.

        Parameters
        ----------
        input_dict
            Amortizer input, passed unchanged to ``self(...)``.
        **kwargs
            Forwarded to the amortizer call.

        Returns
        -------
        dict
            ``{"Post.Loss": ..., "SC.Loss": ...}`` where ``"SC.Loss"`` is
            already multiplied by the current schedule value.
        """
        # Increment before querying the schedule, so it sees 1 on the first call.
        self.step.assign_add(1)
        lambda_ = self.lambda_schedule(self.step)

        # Forward pass through the amortizer.
        net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
        z, log_det_J = net_out

        # Without a summary loss, add 0 for convenience.
        if self.summary_loss is not None:
            sum_loss = self.summary_loss(sum_out)
        else:
            sum_loss = 0.0

        # A dynamic latent space is a function of the summary conditions.
        if self.latent_is_dynamic:
            logpdf = self.latent_dist(sum_out).log_prob(z)
        else:
            logpdf = self.latent_dist.log_prob(z)

        posterior_loss = tf.reduce_mean(-logpdf - log_det_J) + sum_loss

        # Skip the SC computation entirely while its weight is not positive.
        if tf.greater(lambda_, 0.0):
            sc_loss = self._self_consistency_loss()
        else:
            sc_loss = tf.constant(0.0)

        return {"Post.Loss": posterior_loss, "SC.Loss": tf.multiply(lambda_, sc_loss)}

    def _tile_real_data(self):
        """Return ``real_data`` repeated along a new leading sample axis."""
        x = self.real_data
        # Rank-agnostic: one new leading axis, every existing axis untouched.
        multiples = [self.n_consistency_samples] + [1] * x.shape.rank
        return tf.tile(tf.expand_dims(x, 0), multiples)

    def _self_consistency_loss(self):
        """Return the unweighted self-consistency loss on ``real_data``."""
        n_datasets = self.real_data.shape[0]
        uses_summary_net = self.summary_net is not None

        # NOTE(review): unlike the posterior loss, this samples from
        # `self.latent_dist` directly and does not consult
        # `self.latent_is_dynamic` -- dynamic latent spaces are presumably
        # unsupported here; confirm.
        # z shape: (n_consistency_samples, n_datasets, latent_dim)
        z = self.latent_dist.sample((self.n_consistency_samples, n_datasets))

        # x_repeated shape: (n_consistency_samples, *real_data.shape)
        x_repeated = self._tile_real_data()
        if uses_summary_net:
            conditions = self.summary_net(x_repeated)
        else:
            conditions = x_repeated

        # Sampled parameters are treated as constants: no gradient flows back
        # through the sampling path.
        # theta shape: presumably (n_consistency_samples, n_datasets, n_params)
        theta = tf.stop_gradient(
            self.inference_net.inverse(z, conditions, training=False)
        )
        theta = tf.clip_by_value(
            theta,
            clip_value_min=self.theta_clip_value_min,
            clip_value_max=self.theta_clip_value_max,
        )

        # log p(theta), shape: (n_consistency_samples, n_datasets)
        log_prior = self.prior.log_prob(theta)

        # log p(x | theta), shape: (n_consistency_samples, n_datasets)
        if uses_summary_net:
            # Add an axis to theta to line up with the observation axis of
            # the data, then sum the per-observation terms.
            log_lik = self.simulator.log_prob(tf.expand_dims(theta, 2), x_repeated)
            log_lik = tf.math.reduce_sum(log_lik, axis=-1)
            sc_input_dict = {"parameters": theta, "summary_conditions": x_repeated}
        else:
            log_lik = self.simulator.log_prob(theta, x_repeated)
            sc_input_dict = {"parameters": theta, "direct_conditions": x_repeated}

        # log q(theta | x) under the current approximate posterior.
        log_post = self.log_posterior(sc_input_dict, to_numpy=False)

        # Marginal likelihood: p(x) = p(theta) * p(x | theta) / p(theta | x).
        # With an exact posterior this is constant in theta, so its variance
        # across draws measures the inconsistency.
        log_ml = log_prior + log_lik - log_post

        # Variance across the sample axis, shape: (n_datasets,)
        log_ml_var = tf.math.reduce_variance(log_ml, axis=-2)

        # Mean over data sets, scalar.
        return tf.math.reduce_mean(log_ml_var, axis=-1)
