import tensorflow as tf
from bayesflow.amortizers import AmortizedPosterior

class ConstantSchedule:
    """Schedule that yields the same weight regardless of the step.

    Parameters
    ----------
    value : optional, default: 1.0
        The weight handed back on every call.
    """

    def __init__(self, value=1.0):
        # Stored unmodified; __call__ returns it verbatim.
        self.value = value

    def __call__(self, step):
        """Return the stored weight; ``step`` is accepted but ignored."""
        constant = self.value
        return constant

class LinearSchedule:
    """Schedule whose weight grows proportionally to the step count.

    Calling the schedule returns ``step / max_steps`` computed in float32.
    The ratio is not clipped, so it exceeds 1 once ``step`` is larger than
    ``max_steps``.

    Parameters
    ----------
    max_steps : optional, default: 32 * 100.0
        Step count at which the returned weight equals 1.
    init_step : optional, default: 1
        Stored as a float32 value; not used by ``__call__``.
    """

    def __init__(self, max_steps=32 * 100.0, init_step=1):
        # Both settings are kept as float32 so the division in __call__
        # happens in a single dtype.
        as_float = tf.float32
        self.max_steps = tf.cast(max_steps, as_float)
        self.init_step = tf.cast(init_step, as_float)

    def __call__(self, step):
        """Return ``step / max_steps`` as a float32 value."""
        step_f32 = tf.cast(step, tf.float32)
        return step_f32 / self.max_steps

class ZeroOneSchedule:
    """Step-function schedule: weight 0 before a threshold step, 1 afterwards.

    Parameters
    ----------
    threshold_step
        First step at which the schedule returns 1.
    init_step : optional, default: 1
        Stored on the instance; not used by ``__call__``.
    """

    def __init__(self, threshold_step, init_step=1):
        self.threshold_step = threshold_step
        self.init_step = init_step

    def __call__(self, step):
        """Return 0.0 while ``step < threshold_step`` and 1.0 otherwise.

        The comparison is done with TensorFlow ops in float32 so the schedule
        accepts Python numbers as well as integer tensors/variables, and the
        result is a float32 tensor.
        """
        # NOTE(review): the boundary is inclusive (step == threshold_step
        # yields 1.0) -- confirm this matches the intended switch-on step.
        reached = tf.greater_equal(
            tf.cast(step, tf.float32),
            tf.cast(self.threshold_step, tf.float32),
        )
        return tf.cast(reached, tf.float32)

class AmortizedPosteriorSC(AmortizedPosterior):
    """Amortized posterior trained with an extra self-consistency (SC) loss.

    ``compute_loss`` returns the usual posterior loss on the training batch
    plus a second term evaluated on the fixed ``real_data``: parameters are
    drawn from the current approximate posterior, and the variance (across
    those draws) of ``log p(theta) + log p(x | theta) - log q(theta | x)`` is
    penalized. The SC term is weighted by ``lambda_schedule(step)``.

    Parameters
    ----------
    prior
        Object exposing ``log_prob(theta)``.
    simulator
        Object exposing ``log_prob(theta, x)``.
    real_data
        Data the SC term is evaluated on; converted to a float32 tensor.
        Expected to be ``(num_datasets, data_dim)`` without a summary network
        or ``(num_datasets, num_observations, data_dim)`` with one.
    lambda_schedule : callable, optional
        Maps the current step counter to the SC loss weight.
    n_consistency_samples : int, optional, default: 32
        Number of posterior draws per data set used for the SC term.
    theta_clip_value_min, theta_clip_value_max : float, optional
        Bounds applied to the drawn parameters via ``tf.clip_by_value``;
        the defaults (-inf, +inf) leave the draws unchanged.
    *args, **kwargs
        Forwarded to ``AmortizedPosterior.__init__``.
    """

    def __init__(
        self,
        prior,
        simulator,
        real_data,
        lambda_schedule=ConstantSchedule(1.0),
        n_consistency_samples=32,
        theta_clip_value_min=-float("inf"),
        theta_clip_value_max=float("inf"),
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.prior = prior
        self.simulator = simulator
        self.real_data = tf.convert_to_tensor(real_data, dtype=tf.float32)
        # Counts calls to compute_loss; it is the argument of lambda_schedule.
        self.step = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.lambda_schedule = lambda_schedule
        self.n_consistency_samples = n_consistency_samples
        self.theta_clip_value_min = theta_clip_value_min
        self.theta_clip_value_max = theta_clip_value_max

    def compute_loss(self, input_dict, **kwargs):
        """Compute the posterior loss and the weighted self-consistency loss.

        Every call increments ``self.step`` by one before the schedule is
        queried.

        Parameters
        ----------
        input_dict : dict
            Forwarded unchanged to ``self(...)`` together with ``**kwargs``.
        **kwargs
            Extra keyword arguments for the forward pass.

        Returns
        -------
        dict
            ``{"Post.Loss": ..., "SC.Loss": ...}`` where ``"SC.Loss"`` is
            already multiplied by the current schedule weight.
        """
        self.step.assign_add(1)
        # Cast so that integer-valued schedules still compare against 0.0 and
        # multiply with the float32 SC loss without a dtype clash.
        lambda_ = tf.cast(self.lambda_schedule(self.step), tf.float32)

        # Get amortizer outputs
        net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
        z, log_det_J = net_out

        if self.summary_loss is not None:
            sum_loss = self.summary_loss(sum_out)
        else:
            sum_loss = 0.0

        # Case dynamic latent space - function of summary conditions
        if self.latent_is_dynamic:
            logpdf = self.latent_dist(sum_out).log_prob(z)
        else:
            logpdf = self.latent_dist.log_prob(z)

        posterior_loss = tf.reduce_mean(-logpdf - log_det_J) + sum_loss

        # SELF CONSISTENCY LOSS
        # NOTE(review): this section calls self.latent_dist.sample/log_prob
        # directly and therefore does not cover the latent_is_dynamic case.
        if tf.greater(lambda_, 0.0):
            x = self.real_data
            n_datasets = x.shape[0]
            n_samples = self.n_consistency_samples

            # One latent draw per (consistency sample, data set) pair.
            z_sc = self.latent_dist.sample((n_samples, n_datasets))

            # Prepend a sample axis to x and repeat it n_samples times, so the
            # likelihood sees one copy of the data per posterior draw.
            x_repeated = tf.tile(
                tf.expand_dims(x, axis=0),
                [n_samples] + [1] * x.shape.rank,
            )

            # The inference network is conditioned on the summaries if a
            # summary network exists, otherwise on the raw (repeated) data.
            if self.summary_net is not None:
                data_summary = tf.expand_dims(self.summary_net(x), axis=0)
                conditions = tf.tile(data_summary, [n_samples, 1, 1])
            else:
                conditions = x_repeated

            # Posterior draws are treated as constants: no gradient flows
            # through the sampling path.
            # presumably theta is (n_samples, n_datasets, n_params) -- confirm
            theta = tf.stop_gradient(
                self.inference_net.inverse(z_sc, conditions, training=False)
            )
            theta = tf.clip_by_value(
                theta,
                clip_value_min=self.theta_clip_value_min,
                clip_value_max=self.theta_clip_value_max,
            )

            # log p(theta) and log p(x | theta); both are expected to have
            # shape (n_samples, n_datasets).
            log_prior = self.prior.log_prob(theta)
            log_lik = self.simulator.log_prob(theta, x_repeated)

            # log q(theta | x) evaluated manually via the forward pass.
            z_theta, log_det_J_theta = self.inference_net.forward(
                theta, conditions, training=False
            )
            log_post = self.latent_dist.log_prob(z_theta) + log_det_J_theta

            # marginal likelihood p(x) = p(theta) * p(x | theta) / p(theta | x)
            log_ml = log_prior + log_lik - log_post

            # Variance over the consistency samples (second-to-last axis),
            # then log of the mean over data sets.
            log_ml_var = tf.math.reduce_variance(log_ml, axis=-2)
            sc_loss = tf.math.log(tf.math.reduce_mean(log_ml_var, axis=-1))
        else:
            sc_loss = tf.constant(0.0)

        return {"Post.Loss": posterior_loss, "SC.Loss": tf.multiply(lambda_, sc_loss)}
