import tensorflow as tf
from math import pi

class MORE():
    """Gaussian search-distribution update under a KL trust region.

    NOTE(review): the name presumably refers to MORE (Model-based Relative
    Entropy stochastic search) -- confirm against the project's docs.

    The updated Gaussian is defined through its natural parameters
    (see ``get_natural_params``)::

        new_lin       = (eta * old_lin_term  + reward_lin_term)  / (eta + omega)
        new_precision = (eta * old_precision + reward_quad_term) / (eta + omega)

    ``more_step`` searches for the multiplier ``eta`` whose updated Gaussian
    has KL(new || old) close to the bound ``eps``.
    """

    def __init__(self, dim, eta_offset, omega_offset, constrain_entropy):
        """Store problem constants.

        Args:
            dim: dimensionality of the Gaussian; used as ``d`` in the KL and
                entropy formulas.
            eta_offset: accepted but not stored (``more_step`` receives its own
                ``eta_offset`` argument).
            omega_offset: stored as ``self._omega_offset``; not read by any
                method of this class.
            constrain_entropy: stored as ``self._constrain_entropy``; not read
                by any method of this class.
        """
        self._omega_offset = tf.constant(omega_offset)
        self._constrain_entropy = tf.constant(constrain_entropy)
        self._dim = tf.cast(tf.constant(dim), tf.float32)
        # d * log(2 * pi)
        self._dual_const_part = dim * tf.math.log(2 * pi)
        # 0.5 * d * (log(2 * pi) + 1), the dimension-only part of a Gaussian entropy
        self._entropy_const_part = 0.5 * (self._dual_const_part + dim)
        self.regularization = 1e-12

    def get_natural_params(self, eta, omega, old_lin_term, old_precision, reward_lin_term, reward_quad_term):
        """Return the updated natural parameters ``(new_lin, new_precision)``.

        Both are the eta-weighted old parameter plus the reward term, divided
        by ``eta + omega``.
        """
        new_lin = (eta * old_lin_term + reward_lin_term) / (eta + omega)
        new_precision = (eta * old_precision + reward_quad_term) / (eta + omega)
        return new_lin, new_precision

    @tf.function
    def _dual(self, eta, eta_off, omega_off, eps, old_lin_term, old_precision, reward_lin_term, reward_quad_term,
              old_term):
        """Evaluate the dual function (as named) for the given multipliers.

        The natural parameters are built from ``eta_off``/``omega_off``, while the
        linear part of the dual uses ``eta``. Not called by any other method of
        this class.
        """
        new_lin, new_precision = self.get_natural_params(eta_off, omega_off, old_lin_term, old_precision,
                                                         reward_lin_term, reward_quad_term)
        new_mean = tf.reshape(tf.linalg.solve(new_precision, tf.expand_dims(new_lin, 1)), [-1])

        # NOTE(review): a "- omega * beta" term was commented out of this expression.
        dual = eta * eps + eta * old_term
        dual += 0.5 * (eta_off + omega_off) * (self._dual_const_part - tf.math.log(tf.linalg.det(new_precision))
                                               + tf.tensordot(new_lin, new_mean, axes=1))
        return dual

    def _updated_gaussian_and_kl(self, new_lin, new_precision, old_inv_chol, kl_const_part, old_mean):
        """Derive the updated Gaussian from natural parameters and its KL to the old one.

        Args:
            new_lin: 1-D natural linear term of the new distribution.
            new_precision: new precision matrix; must be positive definite
                (``tf.linalg.cholesky`` is applied to it).
            old_inv_chol: inverse of the old covariance factor C
                (C @ C^T = old covariance).
            kl_const_part: ``log det(old covariance) - d``.
            old_mean: 1-D old mean.

        Returns:
            ``(new_mean, new_logdet, new_inv_chol, kl)``, where
            - ``new_logdet`` is log det of the new covariance;
            - ``new_inv_chol`` is L^{-1} for L = cholesky(new_precision);
            - ``kl`` = KL(new || old)
              = 0.5 * (tr(S_old^-1 S_new) + diff^T S_old^-1 diff - d
                       + log det S_old - log det S_new).
        """
        new_mean = tf.reshape(tf.linalg.solve(new_precision, tf.expand_dims(new_lin, 1)), [-1])
        chol_precision = tf.linalg.cholesky(new_precision)
        # log det(new_cov) = -log det(new_precision) = -2 * sum(log(diag(L)))
        new_logdet = -2 * tf.reduce_sum(tf.math.log(tf.linalg.tensor_diag_part(chol_precision)))
        new_inv_chol = tf.linalg.inv(chol_precision)
        # tr(S_old^-1 S_new) = ||L^-1 C^-T||_F^2
        trace_term = tf.square(tf.norm(new_inv_chol @ tf.transpose(old_inv_chol)))
        diff = old_mean - new_mean
        kl = 0.5 * (kl_const_part - new_logdet + trace_term
                    + tf.reduce_sum(tf.square(tf.linalg.matvec(old_inv_chol, diff))))
        return new_mean, new_logdet, new_inv_chol, kl

    def _safe_bracketing_search(self, eps, omega, lower_bound, upper_bound, old_lin_term, old_precision,
                                old_chol_precision_T, reward_lin_term, reward_quad_term, kl_const_part, old_mean):
        """Coarse bisection on eta that tolerates non-positive-definite precisions.

        Returns ``(lower_bound, upper_bound)``. Both are equal when an eta with
        ``|eps - kl| < 0.1 * eps`` was found. Otherwise the search stops at the
        first eta whose KL is >= eps (that eta becomes the new lower bound), or
        once eta is within 0.1 of either bound.

        NOTE(review): ``more_step`` is wrapped in ``tf.function``, so this loop is
        presumably traced; in graph mode a failed Cholesky would surface only
        when the graph runs and is then not caught by the ``except`` below.
        TODO confirm how ``more_step`` is meant to be executed (eagerly?).
        """
        eta = 0.5 * (upper_bound + lower_bound)
        while tf.math.minimum(upper_bound - eta, eta - lower_bound) > 1e-1:
            try:
                kl = self.kl(eta, omega, old_lin_term, old_precision, old_chol_precision_T,
                             reward_lin_term, reward_quad_term, kl_const_part, old_mean)
            except tf.errors.InvalidArgumentError:
                # Cholesky/solve failed: new precision is not pd => the optimal eta must be larger
                lower_bound = eta
                eta = 0.5 * (upper_bound + lower_bound)
                continue

            if tf.abs(eps - kl) < 1e-1 * eps:
                # we indicate that we already found a sufficiently good eta, by setting:
                lower_bound = upper_bound = eta
                break

            if eps > kl:
                upper_bound = eta
            else:
                lower_bound = eta
                # we found a lower bound that is large enough to ensure pd precision matrices => we are done
                break
            eta = 0.5 * (upper_bound + lower_bound)
        return lower_bound, upper_bound

    def _bracketing_search(self, eps, omega, lower_bound, upper_bound, old_lin_term, old_precision,
                           old_chol_precision_T, reward_lin_term, reward_quad_term, kl_const_part, old_mean):
        """Fine bisection on eta inside a bracket where the new precision is known to be pd.

        Bisects (at most 500 iterations, down to a half-width of 5e-6) looking
        for ``|eps - kl| < 0.01 * eps``.

        Returns:
            ``upper_bound`` if its KL is below ``eps``; otherwise the eta that
            met the tolerance; otherwise ``-1.`` to indicate failure (which can
            happen if the initial ``upper_bound`` was too small).

        NOTE(review): the final check prefers ``upper_bound`` whenever its KL is
        below eps, even if a tighter ``eta_opt`` was found -- confirm that
        choosing the more conservative eta is intended.
        """
        eta_opt = -1.
        eta = 0.5 * (upper_bound + lower_bound)
        i = 0
        while tf.math.minimum(upper_bound - eta, eta - lower_bound) > 0.000005 and i < 500:
            i += 1
            kl = self.kl(eta, omega, old_lin_term, old_precision, old_chol_precision_T,
                         reward_lin_term, reward_quad_term, kl_const_part, old_mean)

            if tf.abs(eps - kl) < 1e-2 * eps:
                eta_opt = eta
                break

            # a larger eta keeps the new distribution closer to the old one
            if eps > kl:
                upper_bound = eta
            else:
                lower_bound = eta

            eta = 0.5 * (upper_bound + lower_bound)
        kl = self.kl(upper_bound, omega, old_lin_term, old_precision, old_chol_precision_T,
                     reward_lin_term, reward_quad_term, kl_const_part, old_mean)
        if kl < eps:
            return upper_bound

        return eta_opt

    def kl(self, eta_off, omega_off, old_lin_term, old_precision, old_inv_chol,
           reward_lin_term, reward_quad_term, kl_const_part, old_mean):
        """Return KL(new || old) for the update with multipliers ``(eta_off, omega_off)``.

        Assumes eta is large enough that the new precision is pd; otherwise
        ``tf.linalg.cholesky`` fails.
        """
        new_lin, new_precision = self.get_natural_params(eta_off, omega_off, old_lin_term, old_precision,
                                                         reward_lin_term, reward_quad_term)
        _, _, _, kl = self._updated_gaussian_and_kl(new_lin, new_precision, old_inv_chol, kl_const_part, old_mean)
        return kl

    def _get_distribution(self, eta_off, omega_off, old_lin_term, old_precision, old_inv_chol,
                          reward_lin_term, reward_quad_term, kl_const_part, old_mean):
        """Return ``(new_mean, new_cov, kl, entropy)`` of the updated Gaussian.

        Assumes eta is large enough that the new precision is pd. ``kl`` is
        KL(new || old); ``entropy`` is the differential entropy of the new
        Gaussian, 0.5 * (log det S_new + d * (log(2 pi) + 1)).
        """
        new_lin, new_precision = self.get_natural_params(eta_off, omega_off, old_lin_term, old_precision,
                                                         reward_lin_term, reward_quad_term)
        new_mean, new_logdet, new_inv_chol, kl = self._updated_gaussian_and_kl(
            new_lin, new_precision, old_inv_chol, kl_const_part, old_mean)
        # S_new = (L L^T)^-1 = L^-T L^-1
        new_cov = tf.transpose(new_inv_chol) @ new_inv_chol
        entropy = 0.5 * (new_logdet + self._dim * (tf.math.log(2 * pi) + 1))
        return new_mean, new_cov, kl, entropy

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32), tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32), tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None], dtype=tf.float32),  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32), tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None], dtype=tf.float32)
                                  ])
    def more_step(self, eps, beta, eta_offset, omega_offset, old_mean, old_chol, old_cov, reward_quad_term, reward_lin_term):
        """Perform one KL-constrained update of the Gaussian.

        Args:
            eps: KL bound.
            beta: not used by this method.
            eta_offset: lower clamp applied to the found eta.
            omega_offset: omega passed to the natural-parameter update.
            old_mean: 1-D old mean.
            old_chol: factor C of the old covariance (C @ C^T = old_cov).
            old_cov: old covariance, returned unchanged on failure.
            reward_quad_term, reward_lin_term: quadratic / linear reward model terms.

        Returns:
            ``(new_mean, new_cov, success, eta, 0., kl, entropy)``; on failure
            ``(old_mean, old_cov, False, -1., 0., -1., -1.)``.
        """
        self._succ = False

        # precompute terms that do not depend on eta
        old_logdet = 2 * tf.reduce_sum(tf.math.log(tf.linalg.tensor_diag_part(old_chol)))
        old_inv_chol = tf.linalg.inv(old_chol)
        old_precision = tf.transpose(old_inv_chol) @ old_inv_chol
        old_lin_term = tf.linalg.matvec(old_precision, old_mean)
        kl_const_part = old_logdet - self._dim

        # We first need to refine our lower bound such that it ensures that the new precision is positive definite
        # We use a "safe" bracketing search that does not use @tf.function such that it can catch exceptions when
        # the cholesky decomposition fails
        lower_bound = tf.constant(0.)
        upper_bound = tf.constant(1e20)
        new_lower, new_upper = self._safe_bracketing_search(eps, omega_offset, lower_bound, upper_bound, old_lin_term,
                                                            old_precision, old_inv_chol,
                                                            reward_lin_term, reward_quad_term, kl_const_part, old_mean)
        if new_lower == new_upper:
            opt_eta = new_lower
        else:
            # The following faster bracketing search assumes that the cholesky decomposition is successful for all
            # etas in [lowerbound, upperbound], which we already ensured
            opt_eta = self._bracketing_search(eps, omega_offset, new_lower, new_upper, old_lin_term,
                                              old_precision, old_inv_chol,
                                              reward_lin_term, reward_quad_term, kl_const_part, old_mean)
        if opt_eta == -1.:
            return old_mean, old_cov, False, opt_eta, 0., -1., -1.
        else:
            eta_off = tf.maximum(opt_eta, eta_offset)
            new_mean, new_covar, kl, entropy = self._get_distribution(eta_off, omega_offset, old_lin_term,
                                                                      old_precision, old_inv_chol,
                                                                      reward_lin_term, reward_quad_term,
                                                                      kl_const_part, old_mean)
            return new_mean, new_covar, True, eta_off, 0., kl, entropy

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32)])
    def gradient_step(self, old_mean, old_chol, delta_lin, delta_precision, stepsize):
        """Take a plain step of size ``stepsize`` in natural-parameter space.

        Args:
            old_mean: old mean as a column, shape ``[d, 1]``.
            old_chol: factor C of the old covariance (C @ C^T = old covariance).
            delta_lin: step direction for the linear term, shape ``[d, 1]``.
            delta_precision: step direction for the precision, shape ``[d, d]``.
            stepsize: scalar step size.

        Returns:
            ``(new_mean, new_cov, True, 0, 0., kl, 0.)`` with ``new_mean`` flattened
            to shape ``[d]`` and ``kl`` = KL(new || old).
        """
        old_inv_chol = tf.linalg.inv(old_chol)
        old_precision = tf.transpose(old_inv_chol) @ old_inv_chol
        old_lin = old_precision @ old_mean

        new_lin = old_lin + stepsize * delta_lin
        new_precision = old_precision + stepsize * delta_precision

        old_logdet = 2 * tf.reduce_sum(tf.math.log(tf.linalg.tensor_diag_part(old_chol)))
        kl_const_part = old_logdet - self._dim
        # Flatten the [d, 1] column vectors: subtracting a [d] mean from a [d, 1]
        # mean would broadcast to a [d, d] matrix and corrupt the KL.
        new_mean, _, _, kl = self._updated_gaussian_and_kl(tf.reshape(new_lin, [-1]), new_precision, old_inv_chol,
                                                           kl_const_part, tf.reshape(old_mean, [-1]))

        eta_off = 0
        return new_mean, tf.linalg.inv(new_precision), True, eta_off, 0., kl, 0.
