from optimization.GMMLearner import GMMLearner
import numpy as np
from util.utils import get_effective_samples
from util.ConfigDict import ConfigDict
from recording.Recorder import RecorderKeys as rec
from util.Buffer import RingBuffer
import tensorflow as tf
from optimization.SampleDB import SampleDB
from math import pi
import time
from pathlib import Path

class MoreVIPS:
    """Fit a Gaussian mixture model (GMM) to an unnormalized target density.

    All GMM parameter updates are delegated to ``GMMLearner``:
    ``update_components_MORE`` for the components (each with its own KL bound)
    and ``update_weights_closed_form`` for the mixture weights. Each training
    iteration (see ``train_iter``) optionally adds/removes components, refreshes
    the set of active samples, recomputes their importance weights, adapts the
    per-component KL bounds and then performs the updates.
    """

    @staticmethod
    def get_default_config():
        """Return the default, finalized ``ConfigDict`` of hyperparameters."""
        c = ConfigDict(
            reused_samples_per_component=100,
            desired_samples_per_component=10,
            num_initial_samples=0,
            component_optimizer="gradientMORE",
            train_epochs=int(1e5),
            # Component Updates
            component_kl_bound=0.01,
            # Mixture Updates
            weight_kl_bound=0.01,
            max_components=5000,
            adaptable=True,
            add_iters=3,
            del_iters=10,
            del_threshold=1e-6,
            mmd_alpha=20,
            mmd_rate=50,
        )
        c.finalize_adding()
        return c

    def __init__(self, config, savepath, target_distribution,  initial_w, initial_m, initial_c,  target_sample=0, recorder=0, seed=0, withgrad=True):
        """Build the GMM learner, recorder modules and the initial sample database.

        Args:
            config: configuration as produced by ``get_default_config``;
                ``finalize_modifying()`` is called on it, so the caller's
                object is frozen as a side effect.
            savepath: directory (str or ``pathlib.Path``) for the ``.npz``
                model dumps written by ``train``; created if missing.
            target_distribution: object whose ``log_density`` method evaluates
                the (unnormalized) target log-density of a batch of samples.
            initial_w, initial_m, initial_c: initial mixture weights, means and
                covariance parameters; cast to float32 and passed to
                ``GMMLearner.initialize_model``. The dimension is taken from
                ``initial_m.shape[-1]`` and the component count from ``len(initial_m)``.
            target_sample: stored as ``self.target_sample``; not used in this class.
            recorder: recorder object; any falsy value (default ``0``) disables recording.
            seed: not used in this class.
            withgrad: if True, rewards also carry target-minus-model gradients
                and the component update is run with ``with_grad=True``.
        """
        self.c = config
        self.savepath = savepath
        Path(self.savepath).mkdir(parents=True, exist_ok=True)
        self.c.finalize_modifying()
        # build model
        self.withgrad = tf.convert_to_tensor(withgrad, dtype=tf.bool)
        self.recorder = recorder
        self.num_dimensions = initial_m.shape[-1]
        self._gmm_learner = GMMLearner(dim=self.num_dimensions, surrogate_reg_fact=1e-14,
                                       eta_offset=1.0, omega_offset=0.0, constrain_entropy=False, withgrad=self.withgrad)
        self._gmm_learner.initialize_model(initial_w.astype(np.float32),
                                           initial_m.astype(np.float32),
                                           initial_c.astype(np.float32))
        self.target_distribution = target_distribution
        self._target_uld = target_distribution.log_density

        # Log-density offsets cycled through by the component-adding heuristic.
        self.adding_hyperparameter = tf.constant([1000., 500., 200., 100., 50.], dtype=tf.float32)
        self.target_sample = target_sample
        self.num_weight_updates = 0
        self.num_comp_updates = 0
        # One trust-region (KL) bound per component, adapted in adapt_kl_bounds.
        self.kl_bound_cmp = [self.c.component_kl_bound for _ in range(len(initial_m))]

        # build recording
        if self.c.adaptable:
            self._weights_buffer = RingBuffer(
                self.c.del_iters, self.model.num_components)
        if self.recorder:
            self._recorder = recorder
            self._recorder.initialize_module(rec.INITIAL)
            self._recorder(rec.INITIAL, "MORE VIPS", config)
            self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
            self._recorder.initialize_module(
                rec.WEIGHTS_UPDATE, self.c.train_epochs)
            self._recorder.initialize_module(
                rec.COMPONENT_UPDATE, self.c.train_epochs, self.model.num_components)
            self._recorder.initialize_module(rec.DRE, self.c.train_epochs)

        # Seed the database; with no existing samples every component gets
        # max(1, num_initial_samples) fresh samples.
        self.sample_base = SampleDB(dim=initial_m.shape[-1])
        new_samples, new_target_lnpdfs, new_target_grads, mapping = self.sample_where_needed(self.model,
                                                                                             tf.zeros((0, self.num_dimensions)),
                                                                                             tf.zeros((0)),
                                                                                             self.c.num_initial_samples)
        self.sample_base.add_samples(new_samples, self.model.means, self.model.chol_covar,
                                     new_target_lnpdfs, new_target_grads, mapping)
        self.addheuristic_ite = -1
        # Per-component history of the rewards returned by the weight update,
        # indexed by component index. Pre-allocated, so this assumes the model
        # never holds more than 20000 components at once.
        self.mean_history = [[] for _ in range(20000)]

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def target_uld(self, samples):
        """Evaluate the target's (unnormalized) log-density on a 2-D float32 batch."""
        return self._target_uld(samples)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def get_target_grads(self, samples):
        """Return ``(d target_uld / d samples, target_uld(samples))``."""
        with tf.GradientTape(persistent=False) as gfg:
            gfg.watch(samples)
            target = self.target_uld(samples)
        return gfg.gradient(target, samples), target

    def _get_reward_grad(self, samples, active_target_lnpdfs, active_target_grads):
        """Return rewards with gradients, concatenated along axis 1.

        Column 0 is ``target_lnpdf - model_logpdf``; the remaining columns are
        ``target_grad - model_logpdf_grad``. Assumes one row per sample — TODO
        confirm against ``log_density_and_grad``.
        """
        model_logpdfs, model_logpdfs_grad = self.model.log_density_and_grad(samples)
        rewards = active_target_lnpdfs - model_logpdfs
        target_grad = active_target_grads - model_logpdfs_grad
        rewards_grad = tf.concat((tf.expand_dims(rewards, axis=1), target_grad), axis=1)
        return rewards_grad

    def _get_reward(self, samples, active_target_lnpdfs):
        """Return ``target_lnpdf - model_logpdf`` for each sample."""
        model_logpdfs = self.model.log_density_tf(samples)
        return active_target_lnpdfs - model_logpdfs

    @property
    def model(self):
        """The mixture model currently held by the GMM learner."""
        return self._gmm_learner.model

    def train(self):
        """Run ``c.train_epochs`` training iterations.

        Before iteration ``i`` the current mixture (weights, means, covariances,
        a wall-clock timestamp and the number of stored samples) is saved to
        ``<savepath>/gmm_dump_<i>.npz``. This happens for the first 100
        iterations and every 50th iteration afterwards. After each iteration,
        2000 fresh model samples are used to print the average target
        log-density and the estimate ``mean(target) - mean(model log-density)``.
        """
        dump_dir = Path(self.savepath)
        for i in range(self.c.train_epochs):
            if self.recorder:
                self._recorder(rec.TRAIN_ITER, i)
            if i < 100 or i % 50 == 0:
                # Path join works whether savepath was given as str or Path.
                np.savez(dump_dir / f"gmm_dump_{i}.npz",
                         weights=np.exp(self.model.log_weights.numpy()), means=self.model.means.numpy(),
                         covs=self.model.covars.numpy(), timestamps=time.time(),
                         fevals=len(self.sample_base.samples))
            self.train_iter(i)

            self.test_samples = self.model.sample(2000)
            model_lnpdfs = self.model.log_density(self.test_samples)
            mean_reward = np.mean(self.target_uld(self.test_samples))
            print("Checkpoint {:3d} | FEVALS: {:10d} | avg. sample logpdf: {:05.05f} | ELBO: {:05.05f}".format(
                i, len(self.sample_base.samples), mean_reward, mean_reward - np.mean(model_lnpdfs)))

    def train_iter(self, i):
        """Perform training iteration ``i``.

        If ``adaptable`` and ``i > 1``:
        - every ``add_iters`` iterations a component is added, as long as the
          model has fewer than ``max_components``;
        - every ``del_iters`` iterations the deletion heuristic runs.

        Then the active samples are refreshed, importance weights and KL bounds
        are updated, and the components are updated. The mixture weights are
        updated only when there are at least two components.
        """
        if self.c.adaptable and i > 1 and i % self.c.add_iters == 0:
            if self.model.num_components < self.c.max_components:
                self._add_heuristic()
                self.kl_bound_cmp.append(self.c.component_kl_bound)
        if self.c.adaptable and i > 1 and i % self.c.del_iters == 0:
            self._del_heuristic()

        self.select_active_samples()
        self.recompute_importance_weights()
        self.adapt_kl_bounds()
        c_res = self.update_components()

        if self.model.num_components > 1:
            w_res = self.update_weight()
            if self.recorder:
                self._recorder(rec.WEIGHTS_UPDATE, w_res)

        if self.c.adaptable:
            self._weights_buffer.write(
                tf.exp(self.model.log_weights))
        if self.recorder:
            self._recorder(rec.COMPONENT_UPDATE, c_res)
            self._recorder(rec.MODEL, self.model, i)

    def sample_where_needed(self, model, samples, oldsamples_pdf, num_desired_samples=None):
        """Draw new samples from each component of ``model`` and evaluate the target on them.

        Component ``k`` receives ``max(1, ceil(num_desired_samples - n_eff_k))``
        new samples. ``n_eff_k`` is the effective number of the existing
        ``samples`` for that component, computed by ``get_effective_samples``
        from the component log-densities and ``oldsamples_pdf``; it is 0 when
        ``samples`` is empty. ``num_desired_samples`` defaults to
        ``c.desired_samples_per_component``.

        Returns:
            ``(new_samples, new_target_lnpdfs, new_target_grads, mapping)``.
            ``mapping[j]`` is the index of the per-component batch (returned by
            ``sample_from_components``) that sample ``j`` came from.
        """
        if num_desired_samples is None:
            num_desired_samples = self.c.desired_samples_per_component
        if len(samples) == 0:
            num_effective_samples = tf.zeros((model.num_components))
        else:
            model_logpdfs = model.component_log_densities(samples)
            num_effective_samples = get_effective_samples(model_logpdfs, oldsamples_pdf)
        num_additional_samples = tf.cast(tf.math.maximum(1, tf.math.ceil(num_desired_samples - num_effective_samples)),
                                         tf.int32)
        new_samples = model.sample_from_components(num_additional_samples)
        mapping = tf.repeat(tf.range(len(new_samples)), [len(batch) for batch in new_samples])
        new_samples = tf.concat(new_samples, axis=0)
        new_target_grads, new_target_lnpdfs = self.get_target_grads(new_samples)
        return new_samples, new_target_lnpdfs, new_target_grads, mapping

    def select_active_samples(self):
        """Refresh the active sample set used by the component and weight updates.

        Starting from the newest (up to) ``reused_samples_per_component *
        num_components`` stored samples, tops them up via
        ``sample_where_needed`` and stores the new samples. The newest
        ``reused + new`` samples in the database then become the active
        samples, together with their background log-pdfs, target log-densities
        and target gradients. Finally the rewards are recomputed.
        """
        oldsamples_pdf, samples, _, _ = self.sample_base.get_newest_samples(self.c.reused_samples_per_component*self.model.num_components)

        total = len(samples)
        new_samples, new_target_lnpdfs, new_target_grads, mapping = self.sample_where_needed(self.model, samples,
                                                                                             oldsamples_pdf)
        self.sample_base.add_samples(new_samples, self.model.means, self.model.chol_covar,
                                     new_target_lnpdfs, new_target_grads, mapping)
        total += len(new_samples)
        oldsamples_pdf, samples, target_lnpdfs, target_grads = self.sample_base.get_newest_samples(total)
        self.active_samples = samples
        self.active_samples_background_pdf = oldsamples_pdf
        self.active_target_lnpdfs = target_lnpdfs
        self.active_target_grads = target_grads
        self.update_rewards()

    def update_rewards(self):
        """Recompute ``active_samples_rewards`` under the current model (with gradients if ``withgrad``)."""
        if not self.withgrad:
            self.active_samples_rewards = self._get_reward(self.active_samples, self.active_target_lnpdfs)
        else:
            self.active_samples_rewards = self._get_reward_grad(self.active_samples, self.active_target_lnpdfs, self.active_target_grads)

    def recompute_importance_weights(self):
        """Set self-normalized importance weights of the active samples for every component.

        The log-weight is ``component log-density - background log-pdf``,
        normalized over axis 1. Presumably the layout is
        (num_components, num_samples) — TODO confirm against
        ``component_log_densities``.
        """
        log_pdfs = self.model.component_log_densities(self.active_samples)
        log_weights = log_pdfs - self.active_samples_background_pdf
        log_weights -= tf.reduce_logsumexp(log_weights, axis=1, keepdims=True)
        weights = tf.exp(log_weights)
        # Already normalized in log space; the second division only corrects round-off.
        self.active_samples_weights = weights / tf.reduce_sum(weights, axis=1, keepdims=True)

    def adapt_kl_bounds(self):
        """Adapt each component's KL bound from its two most recent rewards.

        If the latest reward did not increase, the bound is shrunk by 0.8
        (floor 0.01); otherwise it is grown by 1.2 (cap 0.2).
        """
        for i, history in enumerate(self.mean_history):
            if len(history) < 2:
                continue
            previous, latest = history[-2][0], history[-1][0]
            if previous >= latest:
                self.kl_bound_cmp[i] = tf.math.maximum(0.8 * self.kl_bound_cmp[i], 0.01)
            else:
                self.kl_bound_cmp[i] = tf.math.minimum(1.2 * self.kl_bound_cmp[i], 0.2)

    def update_components(self):
        """Update all components with per-component KL bounds.

        Returns:
            A list with one ``(update_counter, kl, entropy)`` tuple per component.
        """
        self.num_comp_updates += 1
        if not self.withgrad:
            # The learner expects a 2-D reward array; without gradients it has one column.
            rewards = tf.expand_dims(self.active_samples_rewards, 1)
        else:
            rewards = self.active_samples_rewards
        kls, entropies = self._gmm_learner.update_components_MORE(self.active_samples, self.active_samples_weights,
                                                                  rewards, self.kl_bound_cmp,
                                                                  with_grad=self.withgrad)
        return [(self.num_comp_updates, kl, entropy) for kl, entropy in zip(kls, entropies)]

    def update_weight(self):
        """Update the mixture weights in closed form and record per-component rewards.

        The component rewards are the importance-weighted sums of the
        per-sample rewards. The ``del_rewards`` returned by the learner are
        appended to ``mean_history``, where they drive KL-bound adaptation and
        component deletion.

        Returns:
            ``(update_counter, kl, entropy, del_rewards, target_lnpdfs,
            entropies, expected_log_responsibilities)``. The last three are
            zero placeholders.
        """
        self.num_weight_updates += 1
        # Components were just updated, so rewards and importance weights are stale.
        self.update_rewards()
        self.recompute_importance_weights()
        entropies = expected_log_responsibilities = target_lnpdfs = tf.zeros(self.model.num_components)
        if self.withgrad:
            # Column 0 holds the reward; the other columns are gradients.
            active_rewards = self.active_samples_rewards[:,0]
        else:
            active_rewards = self.active_samples_rewards
        rewards = tf.linalg.matvec(self.active_samples_weights, active_rewards)
        kl, entropy, del_rewards = self._gmm_learner.update_weights_closed_form(rewards, rewards_are_logits=False)

        for n in range(self.model.num_components):
            self.mean_history[n].append([del_rewards[n]])

        return self.num_weight_updates, kl, entropy, del_rewards, target_lnpdfs, entropies, expected_log_responsibilities

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  ])
    def _add_heuristic_tf(self, hyperparameter, model_log_densities, des_entropy, target_lnpdfs):
        """Choose the sample index for a new component's mean and an isotropic covariance.

        The model log-density is clipped from below at ``max - hyperparameter``.
        The index of the sample with the largest ``target - clipped model``
        value is returned. The covariance ``c * I`` is scaled so that a
        Gaussian with that covariance has entropy ``des_entropy``: the entropy
        is ``0.5 * D * (log(2*pi) + 1) + 0.5 * D * log(c)``.
        """
        max_logdensity = tf.reduce_max(model_log_densities)
        rewards = target_lnpdfs - tf.maximum(max_logdensity-hyperparameter, model_log_densities)
        new_mean_idx = tf.argmax(rewards)
        H_unscaled = 0.5 * self.num_dimensions * (tf.math.log(2. * pi) + 1)
        c = tf.math.exp((2 * (des_entropy - H_unscaled)) / self.num_dimensions)
        new_cov = c * tf.eye(self.num_dimensions)
        return new_mean_idx, new_cov

    def _add_heuristic(self):
        """Add one component, placed at a stored sample chosen by ``_add_heuristic_tf``.

        The clipping offset cycles through ``adding_hyperparameter``. The
        desired entropy is the model's average component entropy. The new
        component starts with weight 1e-29.
        """
        samples, target_lnpdfs = self.sample_base.get_random_sample(int(1e5))
        self.addheuristic_ite += 1
        hp_idx = self.addheuristic_ite % len(self.adding_hyperparameter)
        model_log_densities = self.model.log_density_tf(samples)
        des_entropy = self.model.get_average_entropy()
        new_mean_idx, new_cov = self._add_heuristic_tf(self.adding_hyperparameter[hp_idx],
                                                       model_log_densities, des_entropy, target_lnpdfs)
        new_mean = samples[new_mean_idx]
        new_idx = self._gmm_learner.add_component(1e-29, new_mean, new_cov)
        self._weights_buffer.add_entry(new_idx)

    def _del_heuristic(self):
        """Heuristic for component deletion.

        A component is removed when both of these hold:
        - its weight is below ``del_threshold``;
        - its relative reward improvement
          ``(h[-1] - h[-del_iters]) / |h[-del_iters]|`` is below 0.01.

        Components with at most ``del_iters`` recorded rewards are never
        removed. Afterwards every reward history is truncated to its last
        ``del_iters + 6`` entries.

        Returns:
            True if at least one component was removed.
        """
        window = self.c.del_iters
        reward_improvements = []
        for history in self.mean_history:
            if len(history) > window:
                reference = history[-window][0]
                reward_improvements.append((history[-1][0] - reference) / tf.abs(reference))
            else:
                # Not enough history yet: never counts as stagnating.
                # (np.inf, not np.Inf, which was removed in NumPy 2.0.)
                reward_improvements.append(np.inf)
        stagnating_indices = np.where(np.array(reward_improvements) < 0.01)[0]
        low_weight_indices = np.where(np.exp(self.model.log_weights) < self.c.del_threshold)[0]
        del_idx = set(low_weight_indices) & set(stagnating_indices)

        # Delete from the highest index down; otherwise earlier deletions would
        # shift the indices of later ones.
        for idx in sorted(del_idx, reverse=True):
            self._gmm_learner.remove_component(idx)
            self._weights_buffer.del_entry(idx)
            del self.mean_history[idx]
            del self.kl_bound_cmp[idx]

        self.mean_history = [entry[-window - 6:] for entry in self.mean_history]
        return len(del_idx) > 0
