import numpy as np
from distributions.GMM import GMM
from optimization.MORE import MORE
from optimization.LeastSquares import QuadFunc
import tensorflow as tf


class GMMLearner:

    def __init__(self, dim, surrogate_reg_fact, eta_offset, omega_offset, constrain_entropy, withgrad=True):
        """Store the configuration and build the surrogate and the component learner.

        Args:
            dim: dimensionality passed to ``QuadFunc`` and ``MORE``.
            surrogate_reg_fact: regularization factor passed to ``QuadFunc``; it is
                also kept to scale the per-component regularizer in ``_upd_comp``.
            eta_offset: offset passed to ``MORE`` and reused by the update steps.
            omega_offset: offset passed to ``MORE`` and reused by the update steps.
            constrain_entropy: flag passed to the ``MORE`` constructor.
            withgrad: passed to ``QuadFunc`` and kept as ``self.withgrad``.
        """
        # Plain configuration values.
        self._dim = dim
        self._surrogate_reg_fact = surrogate_reg_fact
        self._eta_offset = eta_offset
        self._omega_offset = omega_offset
        self._constrain_entropy = constrain_entropy
        self.withgrad = withgrad

        # State that is filled in later (the model via initialize_model()).
        self._model = None
        self._weight_learner = None
        self._components_learners = []

        # Helpers: the quadratic surrogate and one MORE learner shared by all components.
        self.surrogate = QuadFunc(dim, surrogate_reg_fact, normalize=False,
                                  normalize_output=False, withgrad=withgrad)
        self._component_learner = MORE(dim, eta_offset, omega_offset, constrain_entropy)

        ##with parallel_backend("loky", inner_max_num_threads=1):
         #   self._parallel_pool = Parallel(n_jobs=-1)

    def initialize_model(self, weights, means, covars):
        """Create the mixture model by passing the arguments straight to ``GMM``.

        The new model replaces whatever ``self._model`` held before.
        """
        self._model = GMM(weights, means, covars)

    def get_relevant_active_samples(self, weights, min_mass=0.999):
        """Select the highest-weighted samples whose cumulative weight exceeds ``min_mass``.

        Args:
            weights: 1-D tensor of sample weights.
            min_mass: cumulative weight the selected samples have to exceed.

        Returns:
            A tuple ``(relevant_indices, count)``: the indices into ``weights`` of
            the selected samples, ordered by decreasing weight, and their number
            as an int32 scalar tensor. When the cumulative weight never exceeds
            ``min_mass`` (e.g. unnormalized or empty weights), all samples are
            selected instead of failing on an out-of-range index.
        """
        order = tf.argsort(-weights)
        cumulated_weights = tf.cumsum(tf.gather(weights, order))
        # Positions (in sorted order) at which the accumulated mass is above the threshold.
        crossings = tf.where(cumulated_weights > min_mass)[:, 0]
        num_relevant = tf.cond(
            tf.size(crossings) > 0,
            lambda: crossings[0] + 1,
            # No crossing at all: fall back to every sample.
            lambda: tf.shape(weights, out_type=tf.int64)[0])
        return order[:num_relevant], tf.cast(num_relevant, tf.int32)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def get_weights_and_targets(self, my_weights, my_rewards, cur_chol):
        """Build the flat regression targets and the matching tiled sample weights.

        Column 0 of ``my_rewards`` is read as the function values and the
        remaining columns as gradients. The gradients are right-multiplied by
        ``cur_chol`` and flattened so that all samples' first component comes
        first, then all second components, and so on.

        Returns:
            ``(weights, targets)``. If ``self.surrogate._no_first_order`` is set,
            the targets are the flattened gradients only and the weights are
            tiled ``self._dim`` times; otherwise the function values are
            prepended and the weights are tiled ``self._dim + 1`` times.
        """
        values = my_rewards[:, 0]
        projected_grads = tf.matmul(my_rewards[:, 1:], cur_chol)
        flat_grads = tf.reshape(tf.transpose(projected_grads), [-1])
        if self.surrogate._no_first_order:
            targets = flat_grads
            repeats = self._dim
        else:
            targets = tf.concat([values, flat_grads], axis=0)
            repeats = self._dim + 1
        return tf.tile(my_weights, [repeats]), targets

    def update_components_VON(self, kl_bound, weights, all_Hessians, all_gradients, with_grad):
        """Update every mixture component from weighted Hessians and gradients.

        For component ``i`` the negated ``weights[i]``-weighted sums of
        ``all_Hessians`` and ``all_gradients`` (summed over the leading axis)
        are handed to ``_upd_comp_VON`` together with ``kl_bound[i]``. After
        each component, its entry in ``model.l2_regularizers`` is halved (but
        kept at least 1) on success and multiplied by 10 (capped at 1e10) on
        failure. Finally all components of the model are replaced at once.

        Args:
            kl_bound: indexable per-component bounds.
            weights: indexable per-component sample weights.
            all_Hessians: per-sample Hessians (assumed shape
                ``(num_samples, dim, dim)`` - TODO confirm against caller).
            all_gradients: per-sample gradients (assumed shape
                ``(num_samples, dim)`` - TODO confirm against caller).
            with_grad: flag forwarded to ``_upd_comp_VON``.

        Returns:
            ``[kls, entropies]``, two lists with one entry per component. A
            component whose update raised contributes ``0.`` and ``nan`` and
            keeps its current parameters.
        """
        model = self.model
        kls, entropies, means, covars, chols = [], [], [], [], []
        for i in range(model.num_components):
            cur_chol = model.chol_covar[i]
            cur_mean = model.means[i]
            cur_cov = model.covars[i]
            regularizer = model.l2_regularizers[i]
            try:
                G_hat = -tf.reduce_sum(tf.reshape(weights[i], shape=[-1, 1, 1]) * all_Hessians, 0)
                g_hat = -tf.reduce_sum(tf.reshape(weights[i], shape=[-1, 1]) * all_gradients, 0)
                success, new_mean, new_covar, kl, entropy, _eta, _omega = self._upd_comp_VON(
                    cur_mean, cur_chol, cur_cov, regularizer, kl_bound[i],
                    G_hat, tf.expand_dims(g_hat, 1), with_grad)
                new_chol = tf.linalg.cholesky(new_covar)
            except Exception as exc:
                # Best effort: a numerical failure in one component must not abort the
                # whole update. Only Exception is caught so that KeyboardInterrupt and
                # SystemExit still propagate.
                print("debug this: update of component {} failed: {!r}".format(i, exc))
                success = False
                new_mean, new_covar, new_chol = cur_mean, cur_cov, cur_chol
                kl, entropy = 0., np.nan

            chols.append(new_chol)
            kls.append(kl)
            entropies.append(entropy)
            means.append(new_mean)
            covars.append(new_covar)

            # Relax the regularizer after a success, tighten it after a failure.
            if success:
                new_regularizer = tf.maximum(0.5 * regularizer, 1.)
            else:
                new_regularizer = tf.minimum(1e10, 10 * regularizer)
            model.l2_regularizers.assign(
                tf.tensor_scatter_nd_update(model.l2_regularizers, [[i]], [new_regularizer]))
        model.replace_components(means, covars, chols)
        return [kls, entropies]

    @tf.function(input_signature=[
                                  tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.bool)
                                  ])
    def _upd_comp_VON(self, cur_mean, cur_chol, cur_cov,
                  regularizer, kl_bound, expected_Hessian, expected_gradient, withgrad=True):
        """Run one MORE step for a single component from an expected Hessian and gradient.

        The quadratic reward term is ``expected_Hessian`` and the linear term is
        ``expected_Hessian @ cur_mean - expected_gradient``.

        Args:
            cur_mean: rank-1 float32 tensor, current mean.
            cur_chol: rank-2 float32 tensor, current Cholesky factor.
            cur_cov: rank-2 float32 tensor, current covariance.
            regularizer: float32 scalar; accepted for signature parity, not used here.
            kl_bound: float32 scalar forwarded to ``more_step``.
            expected_Hessian: rank-2 float32 tensor.
            expected_gradient: rank-2 float32 tensor, expected to be a single column.
            withgrad: bool scalar; accepted for signature parity, not used here.

        Returns:
            ``(success, new_mean, new_covar, kl, entropy, eta, omega)``. When the
            step was not successful, the current mean and covariance are returned
            and ``kl`` and ``entropy`` are 0.
        """
        reward_quad = expected_Hessian
        # Squeeze only the column axis: a bare tf.squeeze would also drop the
        # leading axis for one-dimensional problems and yield a scalar.
        reward_lin = tf.squeeze(reward_quad @ tf.expand_dims(cur_mean, 1) - expected_gradient, axis=1)
        new_mean, new_covar, success, eta, omega, kl, entropy = self._component_learner.more_step(
            kl_bound, -1,
            self._eta_offset, self._omega_offset,
            cur_mean,
            cur_chol,
            cur_cov,
            reward_quad, reward_lin)
        if not success:
            kl = entropy = 0.
            new_mean = cur_mean
            new_covar = cur_cov
        return (success, new_mean, new_covar, kl, entropy, eta, omega)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.bool)
                                  ])
    def _upd_comp(self, my_weights, cur_mean, cur_chol, cur_cov,
                  regularizer, kl_bound, my_samples, my_rewards, withgrad):
        """Fit the quadratic surrogate for one component and take a MORE step with it.

        Args:
            my_weights: rank-1 float32 tensor of sample weights.
            cur_mean: rank-1 float32 tensor, current mean.
            cur_chol: rank-2 float32 tensor, current Cholesky factor.
            cur_cov: rank-2 float32 tensor, current covariance.
            regularizer: float32 scalar; multiplied by ``self._surrogate_reg_fact``
                before it is handed to the surrogate fit.
            kl_bound: float32 scalar forwarded to ``more_step``.
            my_samples: rank-2 float32 tensor of samples.
            my_rewards: rank-2 float32 tensor; with ``withgrad`` it is split by
                ``get_weights_and_targets`` into values (column 0) and gradients.
            withgrad: bool scalar selecting which targets the surrogate is fit to.

        Returns:
            ``(success, new_mean, new_covar, kl, entropy, eta, omega)``. When the
            step was not successful, the current mean and covariance are returned
            and ``kl`` and ``entropy`` are 0.
        """

        regularizer = self._surrogate_reg_fact * regularizer
        if withgrad:
            # Fit to the stacked value/gradient targets with correspondingly tiled weights.
            my_weights, targets = self.get_weights_and_targets(my_weights, my_rewards, cur_chol)
            reward_quad, reward_lin, const_term, o_std = self.surrogate.fit(regularizer,
                                                                            tf.shape(my_samples)[0],
                                                                            my_samples, targets, my_weights,
                                                                            cur_mean, cur_chol)
        else:
            # Fit directly to the rewards as given.
            reward_quad, reward_lin, const_term, o_std = self.surrogate.fit(regularizer,
                                                                            tf.shape(my_samples)[0],
                                                                            my_samples, my_rewards, my_weights,
                                                                            cur_mean, cur_chol)
        # The offsets are divided by the fourth value returned by the fit
        # (presumably the output standard deviation - confirm against QuadFunc.fit).
        eta_offset = self._eta_offset / o_std
        omega_offset = self._omega_offset / o_std

        new_mean, new_covar, success, eta, omega, kl, entropy = self._component_learner.more_step(kl_bound, -1,
                                                                                  eta_offset, omega_offset,
                                                                                  cur_mean,
                                                                                  cur_chol,
                                                                                  cur_cov,
                                                                                  reward_quad, reward_lin)
        if not success:
            # Keep the current component unchanged and report a zero step.
            kl = entropy = 0.
            new_mean = cur_mean
            new_covar = cur_cov
        return (success, new_mean, new_covar, kl, entropy, eta, omega)


    def update_components_MORE(self, samples, weights, rewards, kl_bound, with_grad):
        """Update every mixture component with a surrogate-based MORE step.

        For component ``i`` the samples selected by
        ``get_relevant_active_samples(weights[i])`` are shuffled and truncated to
        at most ``50 * model.num_dimensions`` entries; their weights are
        renormalized to sum to one and passed, together with the matching
        samples and rewards, to ``_upd_comp``. After each component, its entry
        in ``model.l2_regularizers`` is halved (but kept at least 1) on success
        and multiplied by 10 (capped at 1e10) on failure. Finally all
        components of the model are replaced at once.

        Args:
            samples: tensor of samples, indexed along its first axis.
            weights: indexable per-component sample weights.
            rewards: tensor of rewards, indexed along its first axis like ``samples``.
            kl_bound: indexable per-component bounds.
            with_grad: flag forwarded to ``_upd_comp``.

        Returns:
            ``[kls, entropies]``, two lists with one entry per component. A
            component whose update raised contributes ``0.`` and ``nan`` and
            keeps its current parameters.
        """
        model = self.model
        kls, entropies, means, covars, chols = [], [], [], [], []
        for i in range(model.num_components):
            cur_chol = model.chol_covar[i]
            cur_mean = model.means[i]
            cur_cov = model.covars[i]
            regularizer = model.l2_regularizers[i]
            try:
                indices, _ = self.get_relevant_active_samples(weights[i])
                # Random subset of the relevant samples, bounded by 50 per dimension.
                indices = tf.random.shuffle(indices)[:model.num_dimensions * 50]
                my_samples = tf.gather(samples, indices)
                my_weights = tf.gather(weights[i], indices)
                my_weights /= tf.reduce_sum(my_weights)
                my_rewards = tf.gather(rewards, indices)
                success, new_mean, new_covar, kl, entropy, _eta, _omega = self._upd_comp(
                    my_weights, cur_mean, cur_chol, cur_cov, regularizer, kl_bound[i],
                    my_samples, my_rewards, with_grad)
                new_chol = tf.linalg.cholesky(new_covar)
            except Exception as exc:
                # Best effort: a numerical failure in one component must not abort the
                # whole update. Only Exception is caught so that KeyboardInterrupt and
                # SystemExit still propagate.
                print("debug this: update of component {} failed: {!r}".format(i, exc))
                success = False
                new_mean, new_covar, new_chol = cur_mean, cur_cov, cur_chol
                kl, entropy = 0., np.nan

            chols.append(new_chol)
            kls.append(kl)
            entropies.append(entropy)
            means.append(new_mean)
            covars.append(new_covar)

            # Relax the regularizer after a success, tighten it after a failure.
            if success:
                new_regularizer = tf.maximum(0.5 * regularizer, 1.)
            else:
                new_regularizer = tf.minimum(1e10, 10 * regularizer)
            model.l2_regularizers.assign(
                tf.tensor_scatter_nd_update(model.l2_regularizers, [[i]], [new_regularizer]))
        model.replace_components(means, covars, chols)
        return [kls, entropies]


    def update_weights_closed_form(self, rewards, rewards_are_logits=False):
        """Replace the mixture log-weights using a closed-form expression.

        Unless ``rewards_are_logits`` is set, the unnormalized log-weights are
        ``(eta_offset * old_log_weights + rewards) / (eta_offset + omega_offset)``;
        otherwise ``rewards`` is used as is. The result is normalized, floored at
        ``log(1e-30)``, normalized again and written to the model.

        Returns:
            ``(kl, entropy, unnormalized_weights)``: the sum of
            ``p_new * (log p_new - old_log_weights)``, the entropy of the new
            weights, and the unnormalized log-weights computed above.
        """
        prev_log_weights = self.model.log_weights
        if rewards_are_logits:
            logits = rewards
        else:
            eta, omega = self._eta_offset, self._omega_offset
            logits = (eta * prev_log_weights + rewards) / (eta + omega)

        # Normalize, clamp tiny probabilities from below, then normalize once more.
        normalized = logits - tf.reduce_logsumexp(logits)
        floored = tf.math.maximum(normalized, np.log(1e-30))
        new_log_probs = floored - tf.reduce_logsumexp(floored)

        new_probs = tf.exp(new_log_probs)
        kl = tf.reduce_sum(new_probs * (new_log_probs - prev_log_weights))
        entropy = -tf.reduce_sum(new_probs * new_log_probs)

        self.model.replace_weights(new_log_probs)
        return kl, entropy, logits

    def update_weights(self, rewards, kl_bound, entropy_loss_bound=-1):
        """Update the mixture weights through ``self._weight_learner.reps_step``.

        The new probabilities are only written to the model when the weight
        learner reports success.

        NOTE(review): ``Categorical`` is not imported in the visible part of this
        file, and ``self._weight_learner`` is only ever set to ``None`` in
        ``__init__`` - this method looks like it cannot run as is; confirm
        whether it is still in use.

        Returns:
            ``(kl, entropy, last_eta, last_omega, " ")`` where ``kl`` and
            ``entropy`` are taken from the model's weight distribution after the
            update and the eta/omega values from the weight learner.
        """
        # Snapshot of the weights before the update.
        old_dist = Categorical(self._model.weight_distribution.probabilities)
        entropy_bound = old_dist.entropy() - entropy_loss_bound

        new_probabilities = self._weight_learner.reps_step(kl_bound, entropy_bound, old_dist, rewards)
        if self._weight_learner.success:
            self._model.weight_distribution.probabilities = new_probabilities

        kl = self._model.weight_distribution.kl(old_dist)
        entropy = self._model.weight_distribution.entropy()
        return kl, entropy, self._weight_learner.last_eta, self._weight_learner.last_omega, " "

    def add_component(self, initial_weight, initial_mean, initial_covar):
        """Add a component to the model and return the new number of components.

        ``initial_mean`` and ``initial_covar`` are converted with ``.numpy()``
        before they are handed to the model.
        """
        gmm = self._model
        mean_array = initial_mean.numpy()
        covar_array = initial_covar.numpy()
        gmm.add_component(initial_weight, mean_array, covar_array)
        # No per-component learner is registered here; the single learner
        # created in __init__ is used for every component.
        return gmm.num_components

    def remove_component(self, idx):
        """Remove component ``idx`` from the model and return the remaining count."""
        gmm = self._model
        gmm.remove_component(idx)
        # There is no per-component learner to drop alongside the component.
        return gmm.num_components

    @property
    def model(self):
        """The current mixture model (``None`` until ``initialize_model`` is called)."""
        return self._model