import tensorflow as tf
from bayesflow.amortizers import AmortizedPosteriorLikelihood
from bayesflow.default_settings import DEFAULT_KEYS
import numpy as np
from schedules import BatchCyclingSchedule, LinearSchedule, ConstantSchedule, ZeroOneSchedule

class AmortizedPosteriorLikelihoodSC(AmortizedPosteriorLikelihood):
    """Joint posterior/likelihood amortizer with an added self-consistency loss.

    ``compute_loss`` returns three terms:

    * ``"Post.Loss"`` -- negative log-density of the posterior flow on the
      simulated ``posterior_inputs`` (plus the summary loss, if any);
    * ``"Lik.Loss"`` -- negative log-density of the likelihood flow on the
      simulated ``likelihood_inputs``;
    * ``"SC.Loss"`` -- a self-consistency penalty on ``real_data``, scaled by
      ``lambda_schedule(step)``. For every observation,
      ``n_consistency_samples`` parameter draws are taken from the posterior
      flow, and the variance (over draws) of
      ``log_prior + log_lik - log_post`` is averaged over observations.
    """

    def __init__(
        self,
        prior,
        real_data,
        lambda_schedule=None,
        n_consistency_samples=10,
        theta_clip_value_min=-float("inf"),
        theta_clip_value_max=float("inf"),
        output_numpy=True,
        *args,
        direct_conditions_dim=784,
        **kwargs,
    ):
        """Create the amortizer.

        Args:
            prior: Object whose ``log_prob(dict, to_numpy=False)`` supplies the
                prior log-density used in the self-consistency term.
            real_data: Either array-like observations (converted once to a
                ``tf.float32`` tensor) or a callable that receives the step
                variable and returns a batch of observations.
            lambda_schedule: Callable mapping the step counter to the weight of
                the self-consistency loss. ``None`` (the default) creates a new
                ``ConstantSchedule(1.0)`` for this instance, so schedule objects
                are never shared between amortizers via a default argument.
            n_consistency_samples: Number of posterior draws per observation in
                the self-consistency term.
            theta_clip_value_min: Stored on the instance.
            theta_clip_value_max: Stored on the instance.
                NOTE(review): the self-consistency term clips sampled
                parameters to the fixed range [-1, 1] and does not read either
                clip attribute.
            output_numpy: Stored on the instance; not read within this class.
            *args: Forwarded to ``AmortizedPosteriorLikelihood.__init__``.
            direct_conditions_dim: Size of the last axis of the all-zero
                ``"direct_conditions"`` tensor passed to ``prior.log_prob``.
                Keyword-only; the default 784 reproduces the previously
                hard-coded value.
            **kwargs: Forwarded to ``AmortizedPosteriorLikelihood.__init__``.
        """
        super().__init__(*args, **kwargs)
        if lambda_schedule is None:
            lambda_schedule = ConstantSchedule(1.0)
        self.prior = prior
        if callable(real_data):
            self.real_data = real_data
        else:
            self.real_data = tf.convert_to_tensor(real_data, dtype=tf.float32)
        # Incremented once per compute_loss call; passed to the lambda schedule
        # and, when real_data is callable, to real_data.
        self.step = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.lambda_schedule = lambda_schedule
        self.n_consistency_samples = n_consistency_samples
        self.output_numpy = output_numpy
        self.theta_clip_value_min = theta_clip_value_min
        self.theta_clip_value_max = theta_clip_value_max
        self.direct_conditions_dim = direct_conditions_dim

    @tf.function
    def compute_loss(self, input_dict, **kwargs):
        """Compute posterior, likelihood and weighted self-consistency losses.

        Increments ``self.step`` by one on every call, then evaluates
        ``lambda_schedule`` at the new step.

        Args:
            input_dict: Mapping holding the posterior and likelihood inputs
                under ``DEFAULT_KEYS["posterior_inputs"]`` and
                ``DEFAULT_KEYS["likelihood_inputs"]``.
            **kwargs: Forwarded to the posterior and likelihood amortizer calls
                (not to the self-consistency computation).

        Returns:
            dict with ``"Post.Loss"``, ``"Lik.Loss"`` and ``"SC.Loss"``; the
            latter is already multiplied by the current lambda.
        """
        self.step.assign_add(1)
        lambda_ = self.lambda_schedule(self.step)

        posterior_loss = self._posterior_nll(
            input_dict[DEFAULT_KEYS["posterior_inputs"]], **kwargs
        )
        likelihood_loss = self._likelihood_nll(
            input_dict[DEFAULT_KEYS["likelihood_inputs"]], **kwargs
        )

        # Only evaluate the self-consistency term when its weight is positive.
        # The condition is a tensor, so AutoGraph converts this `if` into a
        # graph conditional inside tf.function.
        if tf.greater(lambda_, 0.0):
            sc_loss = self._self_consistency_loss()
        else:
            sc_loss = tf.constant(0.0)

        return {
            "Post.Loss": posterior_loss,
            "Lik.Loss": likelihood_loss,
            "SC.Loss": sc_loss * lambda_,
        }

    def _posterior_nll(self, posterior_inputs, **kwargs):
        """Mean negative log-density of the posterior flow plus summary loss."""
        net_out, sum_out = self.amortized_posterior(
            posterior_inputs, return_summary=True, **kwargs
        )
        z, log_det_J = net_out

        if self.amortized_posterior.summary_loss is not None:
            sum_loss = self.amortized_posterior.summary_loss(sum_out)
        else:
            sum_loss = 0.0

        # A "dynamic" latent distribution is built from the summary outputs.
        if self.amortized_posterior.latent_is_dynamic:
            latent_dist = self.amortized_posterior.latent_dist(sum_out)
        else:
            latent_dist = self.amortized_posterior.latent_dist
        logpdf = latent_dist.log_prob(z)

        return tf.reduce_mean(-logpdf - log_det_J) + sum_loss

    def _likelihood_nll(self, likelihood_inputs, **kwargs):
        """Mean negative log-density (latent log-prob + log|det J|) of the likelihood flow."""
        z_lik, log_det_J_lik = self.amortized_likelihood(likelihood_inputs, **kwargs)
        return tf.reduce_mean(
            -self.amortized_likelihood.latent_dist.log_prob(z_lik) - log_det_J_lik
        )

    def _sc_posterior_conditions(self, x, n_samples):
        """Build posterior-flow conditions with leading shape (n_samples, data_size).

        With a summary network, every observation is repeated ``n_samples``
        times, summarised as one flat batch, and reshaped back to
        ``(n_samples, data_size, summary_dim)``. Without one, the raw
        observations are repeated along a new leading axis.
        """
        data_size = x.shape[0]
        summary_net = self.amortized_posterior.summary_net
        if summary_net is not None:
            # The reshape below requires rank-4 observations.
            x_reshaped = tf.reshape(
                x, (1, data_size, x.shape[1], x.shape[2], x.shape[3])
            )
            x_repeated = tf.tile(x_reshaped, [n_samples, 1, 1, 1, 1])
            x_collapsed = tf.reshape(
                x_repeated, (-1, x.shape[1], x.shape[2], x.shape[3])
            )
            summaries = summary_net(x_collapsed)
            return tf.reshape(summaries, (n_samples, data_size, summaries.shape[1]))

        # The reshape below requires rank-3 observations; tf.tile needs one
        # multiple per axis of its rank-4 input.
        x_reshaped = tf.reshape(x, (1, data_size, x.shape[1], x.shape[2]))
        return tf.tile(x_reshaped, [n_samples, 1, 1, 1])

    def _self_consistency_loss(self):
        """Return the unweighted self-consistency loss on ``real_data``.

        For each of the ``data_size`` observations, ``n_consistency_samples``
        parameter vectors are drawn from the posterior flow (under
        ``stop_gradient``), clipped to [-1, 1], and scored under the prior,
        the likelihood network and the posterior network. The variance over
        draws of ``log_prior + log_lik - log_post`` is averaged over
        observations.
        """
        x = self.real_data(self.step) if callable(self.real_data) else self.real_data
        n_samples = self.n_consistency_samples
        data_size = x.shape[0]

        # No gradient flows through the sampled latents or the resulting theta;
        # gradients come only from the density evaluations below.
        z = tf.stop_gradient(
            self.amortized_posterior.latent_dist.sample((n_samples, data_size))
        )
        conditions = self._sc_posterior_conditions(x, n_samples)
        theta = tf.stop_gradient(
            self.amortized_posterior.inference_net.inverse(
                z, conditions, training=False
            )
        )
        # NOTE(review): fixed [-1, 1] range; theta_clip_value_min/max are unused.
        theta_norm = tf.clip_by_value(theta, clip_value_min=-1.0, clip_value_max=1.0)

        # Prior log-density; "direct_conditions" is an all-zero placeholder.
        prior_dict = {
            "parameters": theta_norm,
            "direct_conditions": tf.zeros(
                [n_samples, data_size, self.direct_conditions_dim]
            ),
        }
        log_prior = self.prior.log_prob(prior_dict, to_numpy=False)

        # Likelihood-network log-density at the sampled parameters. By the
        # change-of-variables formula this is log p_z(z) + log|det J|, as in
        # _likelihood_nll; dropping the Jacobian term would not give a density.
        # NOTE(review): argument order (theta_norm, posterior conditions) is
        # kept unchanged; confirm it matches how amortized_likelihood is trained.
        z_lik, log_det_J_lik = self.amortized_likelihood.inference_net(
            theta_norm, conditions
        )
        log_lik = (
            self.amortized_likelihood.latent_dist.log_prob(z_lik) + log_det_J_lik
        )

        # Posterior log-density: observations repeated n_samples times with a
        # trailing singleton axis, flattened into one batch of
        # n_samples * data_size rows alongside the matching flattened theta.
        x_batch = tf.reshape(x, (1, data_size, x.shape[1], x.shape[2], 1))
        x_batch = tf.tile(x_batch, [n_samples, 1, 1, 1, 1])
        x_batch = tf.reshape(
            x_batch, (n_samples * data_size, x.shape[1], x.shape[2], 1)
        )
        theta_flat = tf.reshape(theta_norm, (n_samples * data_size, -1))
        log_post = self.log_posterior(
            {"parameters": theta_flat, "summary_conditions": x_batch},
            to_numpy=False,
        )
        log_post = tf.reshape(log_post, (n_samples, data_size))

        # By Bayes' rule, prior + likelihood - posterior equals the log marginal
        # likelihood when all densities are exact, so it should not vary across
        # draws: penalise its variance over draws (axis -2), averaged over
        # observations.
        log_ml = log_prior + log_lik - log_post
        log_ml_var = tf.math.reduce_variance(log_ml, axis=-2)
        return tf.math.reduce_mean(log_ml_var, axis=-1)
