from optimization.VIPS import MoreVIPS
import tensorflow as tf

class VonVIPS(MoreVIPS):
    """MoreVIPS subclass whose component update is driven by autodiff Hessians.

    The Hessians and gradients are those of
    ``target_uld(samples) - model.log_density_tf(samples)`` with respect to
    the samples, and are handed to ``_gmm_learner.update_components_VON``.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the parent constructor."""
        super().__init__(**kwargs)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def get_expected_Hessian_and_grad(self, samples):
        """Differentiate the target/model log-ratio twice w.r.t. ``samples``.

        Args:
            samples: rank-2 float32 tensor (fixed by the input signature).

        Returns:
            A tuple ``(Hessians, gradients)``: the batch Jacobian of the
            gradients w.r.t. ``samples``, and the gradients themselves. No
            averaging over samples is performed here.
        """
        # The outer tape records the gradient computation itself so that it
        # can be differentiated a second time.
        with tf.GradientTape() as outer_tape:
            outer_tape.watch(samples)
            with tf.GradientTape() as inner_tape:
                inner_tape.watch(samples)
                target_lnpdfs = self.target_uld(samples)
                model_lnpdfs = self.model.log_density_tf(samples)
                log_ratio = target_lnpdfs - model_lnpdfs
            gradients = inner_tape.gradient(log_ratio, samples)
        Hessians = outer_tape.batch_jacobian(gradients, samples)
        return Hessians, gradients

    def update_components(self):
        """Run the VON component update on the currently active samples.

        Returns:
            A list with one ``(0, kl, entropy)`` tuple per pair of values
            returned by ``_gmm_learner.update_components_VON``.
        """
        sample_Hessians, sample_gradients = self.get_expected_Hessian_and_grad(self.active_samples)
        kls, entropies = self._gmm_learner.update_components_VON(
            self.kl_bound_cmp,
            self.active_samples_weights,
            sample_Hessians,
            sample_gradients,
            with_grad=self.withgrad,
        )
        results = []
        for kl, entropy in zip(kls, entropies):
            results.append((0, kl, entropy))
        return results


class VognVIPS(VonVIPS):
    """VonVIPS variant that replaces exact Hessians by gradient outer products.

    With ``exploit_Bayesian=False`` the Hessian estimate is the negative outer
    product of the gradient of
    ``target_uld(samples) - model.log_density_tf(samples)``.

    With ``exploit_Bayesian=True`` the target is split into per-datum
    likelihood terms and a prior: the outer-product approximation is applied
    to each likelihood term, while the prior and model contributions are
    handled exactly.
    """

    def __init__(self, exploit_Bayesian=False, **kwargs):
        """Initialise the learner.

        Args:
            exploit_Bayesian: if true, ``target_distribution`` must provide
                ``log_likelihood`` and ``prior_std``; both are cached here.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        # Zero-argument super() follows the MRO from this class, so
        # VonVIPS.__init__ is no longer bypassed.
        super().__init__(**kwargs)
        self.exploit_Bayesian = exploit_Bayesian
        if exploit_Bayesian:
            self._target_likelihoods = self.target_distribution.log_likelihood
            self._target_prior_std = self.target_distribution.prior_std

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def get_expected_Hessian_and_grad(self, samples):
        """Return approximate Hessians and gradients for ``samples``.

        Args:
            samples: rank-2 float32 tensor (fixed by the input signature).

        Returns:
            A tuple ``(Hessians, gradients)``.
        """
        # exploit_Bayesian is a plain Python attribute, so the branch is
        # resolved when the function is traced.
        if self.exploit_Bayesian:
            return self._bayesian_hessian_and_grad(samples)
        return self._outer_product_hessian_and_grad(samples)

    def _outer_product_hessian_and_grad(self, samples):
        """Gradient of the target/model log-ratio and its negative outer product."""
        # A single gradient call is made, so the tape need not be persistent.
        with tf.GradientTape() as tape:
            tape.watch(samples)
            model_lnpdfs = self.model.log_density_tf(samples)
            val = self.target_uld(samples) - model_lnpdfs
        gradients = tape.gradient(val, samples)
        Hessians = -tf.expand_dims(gradients, 2) @ tf.expand_dims(gradients, 1)
        return Hessians, gradients

    def _bayesian_hessian_and_grad(self, samples):
        """Per-datum outer products for the likelihood, exact prior and model terms."""
        # The prior terms below have the form of a zero-mean isotropic
        # Gaussian prior with standard deviation prior_std.
        prior_precision = 1 / (self._target_prior_std ** 2)

        # The outer tape is queried twice (likelihood Jacobian and model
        # Hessian), so it must be persistent; the inner tape is used once.
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(samples)
            with tf.GradientTape() as model_tape:
                model_tape.watch(samples)
                model_lnpdfs = self.model.log_density_tf(samples)
            model_gradients = model_tape.gradient(model_lnpdfs, samples)
            # The indexing below requires a rank-2 result here, presumably
            # (sample, datum) -- confirm against log_likelihood.
            data_likelihoods = self._target_likelihoods(samples)

        likelihood_gradients_per_datum = tape.batch_jacobian(data_likelihoods, samples)
        # Negative sum over axis 1 of the per-datum gradient outer products.
        likelihood_EF = -tf.reduce_sum(
            tf.expand_dims(likelihood_gradients_per_datum, 3)
            @ tf.expand_dims(likelihood_gradients_per_datum, 2),
            axis=1,
        )
        likelihood_gradients = tf.reduce_sum(likelihood_gradients_per_datum, axis=1)
        prior_gradients = -(prior_precision * samples)
        gradients = likelihood_gradients + prior_gradients - model_gradients

        model_Hessians = tape.batch_jacobian(model_gradients, samples)
        Hessians = likelihood_EF - model_Hessians - prior_precision * tf.eye(self.num_dimensions)
        return Hessians, gradients

