import tensorflow as tf
from bayesflow.amortizers import AmortizedPosterior

from src.self_consistency_real.schedules import (
    ConstantSchedule,
)


class AmortizedPosteriorSC(AmortizedPosterior):
    """Amortized posterior trained with an additional self-consistency (SC) loss.

    On top of the standard posterior loss on simulated data, this amortizer
    penalizes the variance, across approximate-posterior draws
    ``theta ~ q(theta | x)``, of the log marginal-likelihood estimate

        log p(x) = log p(theta) + log p(x | theta) - log q(theta | x)

    evaluated on ``real_data``. The SC term is weighted by
    ``lambda_schedule(step)``.
    """

    def __init__(
        self,
        prior,
        simulator,
        real_data,
        lambda_schedule=None,
        n_consistency_samples=32,
        theta_clip_value_min=-float("inf"),
        theta_clip_value_max=float("inf"),
        *args,
        **kwargs,
    ):
        """Create the amortizer.

        Parameters
        ----------
        prior
            Object exposing ``log_prob(theta)``, used for ``log p(theta)``.
        simulator
            Object exposing ``log_prob(theta, x)``, used for ``log p(x | theta)``.
        real_data
            Either a tensor/array of data or a zero-argument callable returning
            one. It must be rank 2, ``(n_datasets, data_dim)``; the SC loss tiles
            it along a new leading axis.
        lambda_schedule
            Callable mapping the training step (an int32 ``tf.Variable``) to the
            weight of the SC loss. ``None`` (default) means ``ConstantSchedule(1.0)``.
        n_consistency_samples
            Number of approximate-posterior draws per dataset used to estimate
            the variance of the log marginal likelihood.
        theta_clip_value_min, theta_clip_value_max
            Bounds applied element-wise (with broadcasting) to the posterior
            draws before the densities are evaluated. The default ±inf means
            no clipping.
        *args, **kwargs
            Forwarded to ``AmortizedPosterior.__init__``. Because ``*args`` follows
            the defaulted parameters above, base-class arguments are most
            safely passed by keyword.
        """
        super().__init__(*args, **kwargs)
        # Build the default here rather than in the signature, so that each
        # instance gets its own schedule object instead of one shared,
        # import-time instance.
        if lambda_schedule is None:
            lambda_schedule = ConstantSchedule(1.0)
        self.prior = prior
        self.simulator = simulator
        self.real_data = real_data
        # Training-step counter fed to the schedule; incremented per compute_loss call.
        self.step = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.lambda_schedule = lambda_schedule
        self.n_consistency_samples = n_consistency_samples
        self.theta_clip_value_min = theta_clip_value_min
        self.theta_clip_value_max = theta_clip_value_max

    def compute_loss(self, input_dict, **kwargs):
        """Compute the posterior loss and the weighted self-consistency loss.

        Parameters
        ----------
        input_dict : dict
            Batch passed unchanged to ``self(input_dict, return_summary=True)``.
        **kwargs
            Extra keyword arguments forwarded to that forward pass.

        Returns
        -------
        dict
            ``{"Post.Loss": posterior_loss, "SC.Loss": lambda_ * sc_loss}``.

        Notes
        -----
        ``self.step`` is incremented before the schedule is queried, so the
        first call uses step 1. NOTE(review): every call advances the step,
        including any validation-loss calls the trainer routes here — confirm
        this is intended.
        """
        self.step.assign_add(1)
        lambda_ = self.lambda_schedule(self.step)

        # Standard posterior loss on the (simulated) batch.
        net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
        z, log_det_J = net_out

        # Optional loss on the summary-network outputs; 0.0 keeps the sum simple.
        if self.summary_loss is not None:
            sum_loss = self.summary_loss(sum_out)
        else:
            sum_loss = 0.0

        # Dynamic latent space: the latent distribution is a function of the
        # summary outputs; otherwise it is a fixed distribution.
        if self.latent_is_dynamic:
            logpdf = self.latent_dist(sum_out).log_prob(z)
        else:
            logpdf = self.latent_dist.log_prob(z)

        posterior_loss = tf.reduce_mean(-logpdf - log_det_J) + sum_loss

        # Only compute the (expensive) SC term while its weight is positive.
        if tf.greater(lambda_, 0.0):
            sc_loss = self._self_consistency_loss()
        else:
            sc_loss = tf.constant(0.0)

        return {
            "Post.Loss": posterior_loss,
            "SC.Loss": tf.multiply(lambda_, sc_loss),
        }

    def _self_consistency_loss(self):
        """Return the dataset-averaged variance of the log marginal-likelihood estimate."""
        # real_data may be a fixed tensor/array or a zero-argument callable.
        x = self.real_data() if callable(self.real_data) else self.real_data

        # x: (n_datasets, data_dim) — the rank-3 tiling below requires rank 2.
        n_datasets = tf.shape(x)[0]
        n_samples = self.n_consistency_samples

        # z: (n_samples, n_datasets, latent_dim)
        z = self.latent_dist.sample((n_samples, n_datasets), to_numpy=False)

        # conditions: (n_samples, n_datasets, summary_dim) — the summary of each
        # dataset, repeated once per consistency sample.
        data_summary = tf.expand_dims(self.summary_net(x), axis=0)
        conditions = tf.tile(data_summary, [n_samples, 1, 1])

        # x_repeated: (n_samples, n_datasets, data_dim)
        x_repeated = tf.tile(tf.expand_dims(x, axis=0), [n_samples, 1, 1])

        # theta: (n_samples, n_datasets, n_params). The draws are treated as
        # constants; gradients reach the networks only through log_post below.
        theta = tf.stop_gradient(
            self.inference_net.inverse(z, conditions, training=False)
        )
        # Apply the user-configured bounds (a no-op with the default ±inf).
        # The bounds are cast so that Python floats or arrays of another dtype
        # match theta.
        theta = tf.clip_by_value(
            theta,
            tf.cast(self.theta_clip_value_min, theta.dtype),
            tf.cast(self.theta_clip_value_max, theta.dtype),
        )

        # log p(theta); presumably (n_samples, n_datasets) — depends on prior.log_prob.
        log_prior = self.prior.log_prob(theta)

        # log p(x | theta), held constant w.r.t. the network weights.
        log_lik = tf.stop_gradient(self.simulator.log_prob(theta, x_repeated))

        # log q(theta | x): flatten samples and datasets into one batch axis,
        # evaluate the approximate posterior, then restore (n_samples, n_datasets).
        sc_input_dict = {
            "parameters": tf.reshape(theta, (-1, tf.shape(theta)[-1])),
            "summary_conditions": tf.reshape(
                x_repeated, (-1, tf.shape(x_repeated)[-1])
            ),
        }
        log_post = self.log_posterior(sc_input_dict, to_numpy=False)
        log_post = tf.reshape(log_post, tf.shape(theta)[:-1])

        # log p(x) = log p(theta) + log p(x | theta) - log q(theta | x).
        # For an exact posterior this does not depend on theta, so its
        # variance across draws is zero.
        log_ml = log_prior + log_lik - log_post

        # Variance over the consistency-sample axis: (n_datasets,)
        log_ml_var = tf.math.reduce_variance(log_ml, axis=-2)

        # Scalar: mean over datasets.
        return tf.math.reduce_mean(log_ml_var, axis=-1)
