Source code for gmmvi.optimization.gmmvi

# -*- coding: utf-8 -*-
import tensorflow as tf

from gmmvi.models.gmm_wrapper import GmmWrapper
from gmmvi.experiments.target_distributions.lnpdf import LNPDF
from gmmvi.optimization.gmmvi_modules.component_stepsize_adaptation import ComponentStepsizeAdaptation
from gmmvi.optimization.gmmvi_modules.component_adaptation import ComponentAdaptation
from gmmvi.optimization.gmmvi_modules.ng_based_component_updater import NgBasedComponentUpdater
from gmmvi.optimization.gmmvi_modules.ng_estimator import NgEstimator
from gmmvi.optimization.gmmvi_modules.sample_selector import SampleSelector
from gmmvi.optimization.gmmvi_modules.weight_stepsize_adaptation import WeightStepsizeAdaptation
from gmmvi.optimization.gmmvi_modules.weight_updater import WeightUpdater
from gmmvi.optimization.sample_db import SampleDB


class GMMVI:
    """The main class of this framework, which provides the functionality to perform a complete update step for
    the GMM.

    Responsibilities for performing the necessary sub-steps (sample selection, natural gradient estimation, etc.)
    and for keeping track of data are delegated to the :py:mod:`GMMVI Modules<gmmvi.optimization.gmmvi_modules>`,
    the :py:class:`SampleDB<gmmvi.optimization.sample_db.SampleDB>` and
    :py:class:`GmmWrapper<gmmvi.models.gmm_wrapper.GmmWrapper>`. Hence, this class acts mainly as a manager between
    these components.

    Parameters:
        model: :py:class:`GmmWrapper<gmmvi.models.gmm_wrapper.GmmWrapper>`
            The (wrapped) model that we are optimizing.

        sample_db: :py:class:`SampleDB<gmmvi.optimization.sample_db.SampleDB>`
            The database for storing samples.

        temperature: float
            The temperature parameter :math:`\\beta` for weighting the model entropy :math:`H(q)` in the optimization
            problem :math:`\\arg\\max_q \\mathbb{E}_{q(x)}\\left[ \\log(\\tilde{p}(x)) \\right] + \\beta H(q)`.

        sample_selector: :py:class:`SampleSelector<gmmvi.optimization.gmmvi_modules.sample_selector.SampleSelector>`
            The SampleSelector for selecting the samples that are used during each iteration.

        num_component_adapter: :py:class:`ComponentAdaptation<gmmvi.optimization.gmmvi_modules.component_adaptation.ComponentAdaptation>`
            The ComponentAdaptation module used for adding and deleting components.

        component_stepsize_adapter: :py:class:`ComponentStepsizeAdaptation<gmmvi.optimization.gmmvi_modules.component_stepsize_adaptation.ComponentStepsizeAdaptation>`
            The ComponentStepsizeAdapter for choosing the learning rates for the component update.

        ng_estimator: :py:class:`NgEstimator<gmmvi.optimization.gmmvi_modules.ng_estimator.NgEstimator>`
            The NgEstimator for estimating the natural gradient for the component update.

        ng_based_updater: :py:class:`NgBasedComponentUpdater<gmmvi.optimization.gmmvi_modules.ng_based_component_updater.NgBasedComponentUpdater>`
            The NgBasedComponentUpdater for updating the components based on the estimated natural gradients.

        weight_stepsize_adapter: :py:class:`WeightStepsizeAdaptation<gmmvi.optimization.gmmvi_modules.weight_stepsize_adaptation.WeightStepsizeAdaptation>`
            The WeightStepsizeAdapter for choosing the learning rate for updating the mixture weights.

        weight_updater: :py:class:`WeightUpdater<gmmvi.optimization.gmmvi_modules.weight_updater.WeightUpdater>`
            The WeightUpdater for updating the mixture weights.
    """

    def __init__(
        self,
        model: GmmWrapper,
        sample_db: SampleDB,
        temperature: float,
        sample_selector: SampleSelector,
        num_component_adapter: ComponentAdaptation,
        component_stepsize_adapter: ComponentStepsizeAdaptation,
        ng_estimator: NgEstimator,
        ng_based_updater: NgBasedComponentUpdater,
        weight_stepsize_adapter: WeightStepsizeAdaptation,
        weight_updater: WeightUpdater,
    ):
        self.temperature = temperature
        self.model = model
        self.num_dimensions = self.model.num_dimensions
        self.sample_db = sample_db
        self.sample_selector = sample_selector
        self.num_component_adapter = num_component_adapter
        self.component_stepsize_adapter = component_stepsize_adapter
        self.ng_estimator = ng_estimator
        self.ng_based_updater = ng_based_updater
        self.weight_stepsize_adapter = weight_stepsize_adapter
        self.weight_updater = weight_updater
        # Number of completed calls to _run_updates. It is a tf.Variable (not a Python int) so that
        # _run_updates can increment it from inside a compiled tf.function.
        self.num_updates = tf.Variable(0, dtype=tf.int32)

        if self.sample_selector.target_distribution.safe_for_tf_graph:
            # It is fine to query the target_distribution within a tf.function(),
            # so we can compile the whole training step to a graph. The instance attribute
            # shadows the train_iter method for this object only.
            self.train_iter = tf.function(self.train_iter, experimental_follow_type_hints=True)
        else:
            # Everything not related to querying the target_distribution will be wrapped in a graph.
            # Needed when the target_distribution is not implemented in tensorflow.
            # The fixed input_signature (with unknown dimensions) lets the graph be traced once and
            # reused for varying numbers of samples instead of retracing on every new shape.
            self._run_updates = tf.function(
                self._run_updates,
                input_signature=[
                    tf.TensorSpec(shape=[None, None], dtype=tf.float32),  # samples
                    tf.TensorSpec(shape=[None], dtype=tf.int32),  # mapping
                    tf.TensorSpec(shape=[None], dtype=tf.float32),  # sample_dist_densities
                    tf.TensorSpec(shape=[None], dtype=tf.float32),  # target_lnpdfs
                    tf.TensorSpec(shape=[None, None], dtype=tf.float32),  # target_lnpdf_grads
                ],
            )

    @staticmethod
    def build_from_config(config: dict, target_distribution: LNPDF, model: GmmWrapper):
        """Create a :py:class:`GMMVI<gmmvi.optimization.gmmvi.GMMVI>` instance from a configuration dictionary.

        This static method provides a convenient way to create a :py:class:`GMMVI<gmmvi.optimization.gmmvi.GMMVI>`
        instance, based on an initial GMM (a :py:class:`wrapped model<gmmvi.models.gmm_wrapper.GmmWrapper>`), a
        :py:class:`target_distribution<gmmvi.experiments.target_distributions.lnpdf.LNPDF>` and a dictionary
        containing the types and parameters of the :py:mod:`GMMVI modules<gmmvi.optimization.gmmvi_modules>`.

        Parameters:
            config: dict
                The dictionary should contain for each :py:mod:`GMMVI module<gmmvi.optimization.gmmvi_modules>`
                an entry of the form XXX_type (a string) and XXX_config (a dict) for specifying the type of each
                module, and the module-specific hyperparameters. For example, the dictionary could contain
                sample_selector_type="component-based" and
                sample_selector_config={"desired_samples_per_component": 100, "ratio_reused_samples_to_desired": 2.}.
                Refer to the example yml-configs, or to the individual GMMVI module for the expected parameters,
                and type-strings. In addition, this method reads ``config["temperature"]`` and
                ``config["model_initialization"]["prior_mean"]`` / ``config["model_initialization"]["initial_cov"]``.

            target_distribution: :py:class:`LNPDF<gmmvi.experiments.target_distributions.lnpdf.LNPDF>`
                The (unnormalized) target distribution that we want to approximate.

            model: :py:class:`GmmWrapper<gmmvi.models.gmm_wrapper.GmmWrapper>`
                The (wrapped) model that we are optimizing.

        Returns:
            :py:class:`GMMVI<gmmvi.optimization.gmmvi.GMMVI>`: the assembled optimizer.

        Raises:
            KeyError: if one of the config entries read directly by this method is missing.
        """
        sample_db = SampleDB.build_from_config(config, model.num_dimensions)
        ng_estimator = NgEstimator.build_from_config(config, config['temperature'], model)
        ng_based_updater = NgBasedComponentUpdater.build_from_config(config, model)
        num_component_adapter = ComponentAdaptation.build_from_config(
            config,
            model,
            sample_db,
            target_distribution=target_distribution,
            prior_mean=config["model_initialization"]["prior_mean"],
            initial_cov=config["model_initialization"]["initial_cov"],
        )
        component_stepsize_adapter = ComponentStepsizeAdaptation.build_from_config(config, model)
        sample_selector = SampleSelector.build_from_config(config, model, sample_db, target_distribution)
        weight_updater = WeightUpdater.build_from_config(config, model)
        weight_stepsize_adapter = WeightStepsizeAdaptation.build_from_config(config, model)
        return GMMVI(
            model,
            sample_db,
            config['temperature'],
            sample_selector,
            num_component_adapter,
            component_stepsize_adapter,
            ng_estimator,
            ng_based_updater,
            weight_stepsize_adapter,
            weight_updater,
        )

    def train_iter(self):
        """Perform a single training iteration.

        This method does not take any parameters, nor does it return anything. However, it may have several
        effects, such as

        * drawing new samples from the :py:attr:`~model` and evaluating them on the target distribution,

        * updating the :py:attr:`gmmvi.optimization.gmmvi.GMMVI.model` parameters,

        * adapting learning rates, etc.
        """
        # Sample selection queries the target distribution; depending on safe_for_tf_graph this runs either
        # inside the compiled train_iter graph or eagerly (see __init__).
        samples, mapping, sample_dist_densities, target_lnpdfs, target_lnpdf_grads \
            = self.sample_selector.select_samples()
        self._run_updates(samples, mapping, sample_dist_densities, target_lnpdfs, target_lnpdf_grads)
        # Adding/deleting components happens after the update and receives the incremented counter.
        self.num_component_adapter.adapt_number_of_components(self.num_updates)

    def _run_updates(self, samples, mapping, sample_dist_densities, target_lnpdfs, target_lnpdf_grads):
        """Update the GMM components and mixture weights from an already-selected batch of samples.

        When the target distribution is not safe for tf graphs, this method is compiled with the
        input_signature declared in ``__init__`` (samples and target_lnpdf_grads rank-2 float32, mapping
        rank-1 int32, sample_dist_densities and target_lnpdfs rank-1 float32).

        Parameters:
            samples: the selected samples.
            mapping: passed through to the NgEstimator; presumably assigns samples to components (TODO confirm).
            sample_dist_densities: densities of the samples under the sampling distribution (as named).
            target_lnpdfs: log-densities of the samples under the target distribution.
            target_lnpdf_grads: gradients of the target log-density at the samples.
        """
        # Update components: adapt the per-component stepsizes first, so the NG update below uses them.
        new_component_stepsizes = self.component_stepsize_adapter.update_stepsize(self.model.stepsizes)
        self.model.update_stepsizes(new_component_stepsizes)
        expected_hessian_neg, expected_grad_neg = self.ng_estimator.get_expected_hessian_and_grad(
            samples, mapping, sample_dist_densities, target_lnpdfs, target_lnpdf_grads)
        self.ng_based_updater.apply_NG_update(expected_hessian_neg, expected_grad_neg, self.model.stepsizes)

        # Update weights.
        weight_stepsize = self.weight_stepsize_adapter.update_stepsize()
        self.weight_updater.update_weights(samples, sample_dist_densities, target_lnpdfs, weight_stepsize)
        self.num_updates.assign_add(1)