import abc
import warnings
from copy import deepcopy
from typing import Optional
from tqdm import tqdm
from scipy.optimize import root_scalar
import numpy as np

import torch
import torch.utils.tensorboard

from nfmc_jax.flows.base import FlowInterface, TorchFlowInterface, PPInterface, GPInterface
from nfmc_jax.DLA.beta import BetaHandler, SingleStageBetaHandler, ESSBetaHandler
from nfmc_jax.DLA.debug import MultiStageDebugger
from nfmc_jax.DLA.optim import ParticleAdagrad, ParticleRMSProp, ParticleAdam, ParticleGradientDescent, ParticleLineSearch, \
    IdentityScheduler, ExponentialDecayScheduler, CosineAnnealingScheduler
from nfmc_jax.DLA.posterior import DifferentiableTemperedPosterior, TorchPosterior


class DLA(abc.ABC):
    """
    Abstract base class for DLA samplers.

    Stores the flow interface, the tempered posterior, the beta handler and an optional debugger, and provides the
    logging and beta-update helpers shared by concrete samplers. Subclasses implement :meth:`run`.
    """

    def __init__(self,
                 interface: FlowInterface,
                 posterior: DifferentiableTemperedPosterior,
                 beta_handler: Optional[BetaHandler] = None,
                 debugger: Optional[MultiStageDebugger] = None):
        """
        DLA class.

        :param interface: object used to train a normalizing flow, obtain latent samples and compute gradients for
            sample updates.
        :param posterior: object used to obtain gradients of U w.r.t. samples by evaluating the gradient of the
            prior and likelihood functions w.r.t. samples.
        :param beta_handler: object that controls the beta schedule. When omitted, a fresh
            ``SingleStageBetaHandler`` is created for this instance.
        :param debugger: optional debugger used to log scalars, particles and flow samples. Logging is disabled when
            this is None.
        """
        # A handler instance used as the default argument would be created once at definition time and shared
        # (together with its state) by every sampler constructed without an explicit handler.
        if beta_handler is None:
            beta_handler = SingleStageBetaHandler()

        self.interface = interface
        self.posterior = posterior
        self.debugger = debugger
        self._beta_handler = beta_handler

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> torch.Tensor:
        """Run the sampler. Must be implemented by subclasses."""
        raise NotImplementedError

    def log_iteration_data(self, x, latent_samples, scalars: Optional[dict] = None):
        """
        Send the state of the current iteration to the debugger. Does nothing when no debugger is attached.

        :param x: current particles, passed to ``debugger.add_particles``.
        :param latent_samples: latent samples that are mapped through ``interface.inverse`` and logged as flow samples.
        :param scalars: optional mapping of additional scalar names to values to log.
        """
        if scalars is None:
            scalars = dict()
        if self.debug:
            # Log data
            # The posterior is called with None because a cached tensor of log probabilities is used.
            self.debugger.add_scalar('log_likelihood', self.posterior.log_likelihood(None).mean())
            self.debugger.add_scalar('log_prior', self.posterior.log_prior(None).mean())
            for key, value in scalars.items():
                self.debugger.add_scalar(key, value)
            self.debugger.add_particles(x)
            try:
                flow_samples = self.interface.inverse(z=latent_samples)
                if torch.isinf(flow_samples).any():
                    warnings.warn('A flow particle contains inf.')
                if torch.isnan(flow_samples).any():
                    warnings.warn('A flow particle contains nan.')
                self.debugger.add_flow_samples(flow_samples)
            except NotImplementedError:
                # Interfaces without an inverse map simply skip flow sample logging.
                pass
            self.debugger.step()

    @property
    def finished(self):
        """Whether the beta handler reports that the schedule has finished."""
        return self._beta_handler.finished

    @property
    def debug(self):
        """True when a debugger is attached."""
        return self.debugger is not None

    @property
    def beta(self):
        """Current beta value, as reported by the beta handler."""
        return self._beta_handler.beta

    def set_new_beta(self, **kwargs):
        """
        Advance the beta handler by one step.

        For an ``ESSBetaHandler`` the particles must be supplied as the keyword argument ``x``; the log-likelihood,
        log-prior and flow log-density at ``x`` are then computed here and passed to the handler together with all
        keyword arguments. Any other handler receives the keyword arguments unchanged.

        :return: whatever the handler's ``step`` returns.
        :raises ValueError: if an ``ESSBetaHandler`` is used and ``x`` is missing.
        """
        if isinstance(self._beta_handler, ESSBetaHandler):
            if 'x' not in kwargs:
                raise ValueError("x must be provided as a keyword argument")
            x = kwargs['x']
            log_q = self.interface.logq(x)
            # Drop cached log probabilities so that they are evaluated at x.
            self.posterior.clear_cache()
            return self._beta_handler.step(
                log_likelihood=self.posterior.log_likelihood(x),
                log_prior=self.posterior.log_prior(x),
                log_q=log_q,
                **kwargs
            )
        else:
            return self._beta_handler.step(**kwargs)

    def create_beta_stage_animation(self, beta_stage, **animate_kwargs):
        """
        Ask the debugger to animate the given beta stage. Does nothing when no debugger is attached.

        :param beta_stage: stage index shown in the animation title.
        :param animate_kwargs: keyword arguments forwarded to ``debugger.animate``.
        """
        if self.debug:
            self.debugger.animate(ax_title=f'[Stage {beta_stage}] Beta: {self.beta}', **animate_kwargs)


class LatentSDLMC(DLA):
    def __init__(self,
                 interface: TorchFlowInterface,
                 surrogate_interface: TorchFlowInterface,
                 posterior: DifferentiableTemperedPosterior,
                 beta_handler: Optional[BetaHandler] = None,
                 debugger: Optional[MultiStageDebugger] = None):
        """
        Surrogate-based DLMC, used for gradient-free sampling.

        :param interface: flow interface passed on to the base class.
        :param surrogate_interface: additional flow interface, used to construct target surrogates.
        :param posterior: tempered posterior passed on to the base class.
        :param beta_handler: object that controls the beta schedule. When omitted, a fresh
            ``SingleStageBetaHandler`` is created for this instance.
        :param debugger: optional debugger passed on to the base class.
        """
        # Create the default handler per instance; a handler used as the default argument value would be shared
        # (with its state) by all samplers constructed without an explicit handler.
        if beta_handler is None:
            beta_handler = SingleStageBetaHandler()
        super().__init__(interface, posterior, beta_handler, debugger)
        self.surrogate_interface = surrogate_interface

    def IMH(self, x):
        """
        Apply one independent Metropolis-Hastings step, using samples from the flow as proposals.

        :param x: current particles. Accepted proposals are written into this tensor in place.
        :return: tuple ``(x, accept, logp)`` with the updated particles, the boolean acceptance mask and the tempered
            posterior values of the updated particles.
        """
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)
        logq_old = self.interface.logq(x)

        sample, logq_new = self.interface.sample_with_logq(len(x))
        self.posterior.clear_cache()
        logp_new = self.posterior.value(self.beta, sample)

        # Log acceptance ratio of an independence sampler: p(x') q(x) / (p(x) q(x')).
        logr = logp_new + logq_old - logp_old - logq_new

        # Draw the uniforms on the device of the log-ratio, so that the comparison also works off the CPU.
        accept = torch.log(torch.rand(len(x), device=logr.device)) < logr
        x[accept] = sample[accept]
        logp_old[accept] = logp_new[accept]
        return x, accept, logp_old

    def pearson(self, xi, x0):
        """
        Absolute Pearson correlation between two particle ensembles, computed separately for every column.

        :param xi: current ensemble; rows are particles.
        :param x0: reference ensemble with the same shape as ``xi``.
        :return: tensor with one absolute correlation coefficient per column.
        """
        reference_centred = x0 - torch.mean(x0, dim=0)
        current_centred = xi - torch.mean(xi, dim=0)

        cross_term = torch.sum(current_centred * reference_centred, dim=0)
        reference_norm = torch.sum(reference_centred ** 2, dim=0) ** 0.5
        current_norm = torch.sum(current_centred ** 2, dim=0) ** 0.5

        correlation = cross_term / (reference_norm * current_norm)
        return torch.abs(correlation)

    def tune_metropolis(self, scale, acc_rate, mh_step):
        """
        Adapt the Metropolis proposal scale towards an acceptance rate of 0.234.

        The update is applied to the logarithm of the scale, which keeps the scale positive, and its magnitude
        shrinks as ``mh_step`` grows.

        :param scale: current proposal scale (tensor).
        :param acc_rate: acceptance rate of the previous step.
        :param mh_step: index of the current Metropolis step.
        :return: updated proposal scale.
        """
        log_scale = torch.log(scale)
        correction = (acc_rate - 0.234) / (mh_step + 1.0)
        return torch.exp(log_scale + correction)

    def CMH(self, x, scale, correlation_threshold=0.01):
        """
        Random-walk Metropolis-Hastings in the latent space of the flow.

        Steps are repeated for as long as the per-dimension correlation with the initial latent positions keeps
        dropping by more than ``correlation_threshold`` in over 90% of the dimensions.

        :param x: current particles. Accepted proposals are written into this tensor in place.
        :param scale: initial proposal scale. It is re-tuned from the second step on, which requires a tensor.
        :param correlation_threshold: minimum drop in correlation that counts as progress.
        :return: tuple ``(x, acceptance_rate, replaced_fraction, scale, cmh_step)`` with the updated particles, the
            list of per-step acceptance rates, the fraction of particles replaced at least once, the final scale and
            the number of steps taken.
        """
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)

        z, logj_forward_old = self.interface.forward_with_logj(x)
        z = z.reshape(len(z), -1)
        old_z = deepcopy(z)
        old_r = 2.0  # larger than any correlation, so the first comparison always registers a drop
        cmh_step = 0

        acceptance_rate = []
        replaced = torch.zeros(len(x), dtype=torch.bool, device=x.device)
        while True:
            cmh_step += 1
            if cmh_step > 1:
                scale = self.tune_metropolis(scale, acceptance_rate[-1], cmh_step)

            z_new = z + torch.randn_like(z) * scale
            x_new, logj_backward_new = self.interface.inverse_with_logj(z_new)

            self.posterior.clear_cache()
            logp_new = self.posterior.value(self.beta, x_new)

            # The old state stores the forward log-Jacobian, i.e. minus the backward one used for the new state.
            logr = (logp_new + logj_backward_new) - (logp_old - logj_forward_old)

            # Draw the uniforms on the device of the log-ratio, so that the comparison also works off the CPU.
            accept = torch.log(torch.rand(len(x), device=logr.device)) < logr
            acceptance_rate.append(torch.sum(accept).item() / len(x))
            replaced = replaced + accept

            x[accept] = x_new[accept]
            logp_old[accept] = logp_new[accept]
            z[accept] = z_new[accept]
            logj_forward_old[accept] = -logj_backward_new[accept]

            new_r = self.pearson(z, old_z)
            delta_r = (old_r - new_r) > correlation_threshold
            if torch.mean(delta_r.bool().float()) > 0.9:
                old_r = new_r
            else:
                break

        return x, acceptance_rate, torch.sum(replaced).item() / len(x), scale, cmh_step

    def systematic_resampling(self, weights):
        """
        Systematic resampling of particle indices.

        A single uniform offset defines evenly spaced positions in [0, 1); each position selects the first particle
        whose cumulative weight exceeds it.

        :param weights: 1D tensor of normalized particle weights.
        :return: long tensor of resampled particle indices with the same length as ``weights``.
        """
        num_particles = weights.shape[0]
        if num_particles == 0:
            return torch.zeros(0, dtype=torch.long)

        positions = (torch.rand(1) + torch.arange(num_particles)) / num_particles
        cumulative_sum = torch.cumsum(weights, dim=0).detach().cpu()

        # right=True returns the first index whose cumulative weight is strictly greater than the position.
        resampling_idx = torch.searchsorted(cumulative_sum, positions.to(cumulative_sum.dtype), right=True)

        # Rounding can leave the final cumulative weight at or below the last position; select the last particle in
        # that case instead of indexing past the end of the weights.
        return torch.clamp(resampling_idx, max=num_particles - 1)

    def tune_hmc(self, step_size, acc_rate, hmc_step):
        """
        Adapt per-particle HMC step sizes towards an acceptance rate of 0.8.

        Each step size is updated in log space, then averaged with the mean of the updated step sizes.

        :param step_size: tensor of current step sizes.
        :param acc_rate: acceptance rate of the previous step.
        :param hmc_step: index of the current HMC step; divides the size of the update.
        :return: tensor of updated step sizes.
        """
        log_step = torch.log(step_size)
        updated = torch.exp(log_step + (acc_rate - 0.8) / hmc_step)
        ensemble_mean = torch.mean(updated)
        return 0.5 * (updated + ensemble_mean)

    def hmc_tuning_metric(self, z_leap, z_init, log_mh, mass, num_leapfrog):
        """
        Log of the tuning metric used to weight HMC parameters.

        The value is the log squared jump distance, plus the log-MH ratio capped at zero, plus the log of the number
        of leapfrog steps.

        :param z_leap: latent space samples after a single HMC update at step i-1.
        :param z_init: initial latent space locations at step i-1.
        :param log_mh: log-MH acceptance ratio after a single HMC update.
        :param mass: diagonal mass matrix elements for step i-1 (not used in the computation).
        :param num_leapfrog: number of leapfrog steps assigned to each particle at step i-1, shape (n,).
        :return: log of the tuning metric, used to define weights for HMC parameters at step i.
        """
        squared_jump = torch.sum((z_leap - z_init) ** 2, dim=1)
        capped_log_mh = torch.clamp(log_mh, min=None, max=0.0)
        return torch.log(squared_jump) + capped_log_mh + torch.log(num_leapfrog)

    def hmc_step_size_regression(self,
                                 deltaE: torch.Tensor,
                                 step_sizes: torch.Tensor,
                                 n_epochs: int = 2000,
                                 lr: float = 1e-3):
        """
        Fit ``|deltaE| ~ w * step_size ** 2 + b`` by minimizing the L1 loss and derive a maximum step size from it.

        :param deltaE: absolute change in energy after leapfrog update.
        :param step_sizes: leapfrog step sizes.
        :param n_epochs: number of training epochs for median regression.
        :param lr: learning rate for median regression training.
        :return: new maximum leapfrog step size, i.e. the step size at which the fitted line equals ``-log(0.9)``.
            The result is NaN when the fitted coefficients make the argument of the square root negative.
        """
        # The regression inputs do not change between epochs, so they are built once. They are detached because only
        # the regression coefficients are trained.
        features = (step_sizes.reshape(step_sizes.shape[0], 1) ** 2).detach()
        targets = torch.abs(deltaE).reshape(deltaE.shape[0], 1).detach()

        model = torch.nn.Linear(in_features=1, out_features=1)
        loss_fun = torch.nn.L1Loss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for _ in range(n_epochs):
            # Reset the gradients; otherwise every step would use the sum of the gradients of all previous epochs.
            optimizer.zero_grad()
            loss = loss_fun(model(features), targets)
            loss.backward()
            optimizer.step()

        w = model.weight.item()
        b = model.bias.item()

        print(f'w = {w}')
        print(f'b = {b}')

        return np.sqrt((-np.log(0.9) - b) / w)

    def hmc_pre_tuning(self, x, z, mass, step_size_max, leapfrog_max):
        """
        Pre-tuning algorithm in Buchholz et al. (2020).

        Every particle receives a random step size and a random number of leapfrog steps, a trial trajectory is
        integrated, and the (step size, leapfrog) pairs are resampled in proportion to the tuning metric.
        Note, currently not implementing the pre-tuning at every step, so no quantile regression.

        :param x: particles; passed to ``tuning_leapfrog`` for the trial trajectory.
        :param z: latent positions of the particles; passed to ``tuning_leapfrog`` for the trial trajectory.
        :param mass: diagonal mass matrix elements, forwarded to the helpers.
        :param step_size_max: upper bound of the uniformly drawn step sizes.
        :param leapfrog_max: largest number of leapfrog steps that can be drawn.
        :return: tuple of the resampled step sizes and the empirical probabilities of 1..leapfrog_max leapfrog steps.
        """
        num_particles = x.shape[0]
        hmc_step_size = torch.rand((num_particles,)) * step_size_max
        num_leapfrog = torch.randint(low=1, high=leapfrog_max + 1, size=(num_particles,))
        momenta = torch.randn_like(z)

        z_start = deepcopy(z)

        def hamiltonian(x_cur, z_cur, p_cur):
            # Kinetic energy minus the tempered log posterior and the backward log-Jacobian.
            self.posterior.clear_cache()
            logp = self.posterior.value(self.beta, x_cur)
            logj_backward = self.interface.logj_backward(z_cur)
            return 0.5 * torch.sum(p_cur ** 2, dim=1) - logp - logj_backward

        energy_start = hamiltonian(x, z, momenta)
        x, z, momenta = self.tuning_leapfrog(x=x, z=z, momenta=momenta, mass=mass, hmc_step_size=hmc_step_size,
                                             num_leapfrog=num_leapfrog)
        energy_end = hamiltonian(x, z, momenta)

        log_mh = torch.clamp(energy_start - energy_end, min=None, max=0.0)
        log_metric = self.hmc_tuning_metric(z_leap=z, z_init=z_start, log_mh=log_mh, mass=mass,
                                            num_leapfrog=num_leapfrog)

        # Normalize the metric into resampling weights.
        log_metric -= torch.logsumexp(log_metric, dim=0)
        weights = torch.exp(log_metric)
        weights /= weights.sum()
        chosen = self.systematic_resampling(weights)

        # Empirical distribution of the number of leapfrog steps among the resampled particles.
        resampled_leapfrog = num_leapfrog[chosen]
        probs_leapfrog = torch.zeros(leapfrog_max)
        for n_steps in range(1, leapfrog_max + 1):
            probs_leapfrog[n_steps - 1] = int((resampled_leapfrog == n_steps).sum())
        probs_leapfrog /= probs_leapfrog.sum()

        return hmc_step_size[chosen], probs_leapfrog

    def hmc_ft_tuning(self, hmc_step_size, num_leapfrog, z_leap, z_init, log_mh, mass):
        # FT tuning algorithm in Buchholz et al. (2020)
        log_tuning_metric = self.hmc_tuning_metric(z_leap=z_leap, z_init=z_init, log_mh=log_mh, mass=mass,
                                                   num_leapfrog=num_leapfrog)
        log_tuning_metric -= torch.logsumexp(log_tuning_metric, dim=0)
        tuning_weights = torch.exp(log_tuning_metric)
        tuning_weights /= tuning_weights.sum()
        resampling_idx = self.systematic_resampling(tuning_weights)

        hmc_step_size = hmc_step_size[resampling_idx]
        num_leapfrog = num_leapfrog[resampling_idx]
        hmc_step_size += torch.randn_like(hmc_step_size) * 0.015
        hmc_step_size = torch.clamp(hmc_step_size, min=1.0e-6, max=None)

        '''
        for i in range(hmc_step_size.shape[0]):
            hmc_step_size[i] = torch.nn.init.trunc_normal_(hmc_step_size[i], mean=hmc_step_size[i], std=0.0015,
                                                           a=0.0, b=100.0)
        '''
        delta_leapfrog = torch.tensor([-1, 0, 1], dtype=torch.long)
        perturb_leapfrog_idx = torch.distributions.categorical.Categorical(probs=torch.ones(3) / 3.0).sample(
            (num_leapfrog.shape[0],))
        num_leapfrog += delta_leapfrog[perturb_leapfrog_idx]
        num_leapfrog = torch.clamp(num_leapfrog, min=1)

        return hmc_step_size, num_leapfrog

    def tuning_leapfrog(self, x, z, momenta, mass, hmc_step_size, num_leapfrog):
        """
        Leapfrog integrator for a particle ensemble in which every particle has its own number of steps.

        Particle ``i`` is integrated for ``num_leapfrog[i]`` steps with step size ``hmc_step_size[i]``. The tensors
        ``x``, ``z`` and ``momenta`` are updated in place.

        :param x: particles.
        :param z: latent positions of the particles.
        :param momenta: momenta, same shape as ``z``.
        :param mass: diagonal mass matrix elements (not used by the integrator).
        :param hmc_step_size: per-particle step sizes, shape (n,).
        :param num_leapfrog: per-particle numbers of leapfrog steps, shape (n,).
        :return: tuple ``(x, z, momenta)`` after integration.
        """
        max_L = int(torch.amax(num_leapfrog))
        active = torch.arange(x.shape[0])
        for leap in range(1, max_L + 1):
            # A particle takes part in step `leap` as long as it has been assigned at least `leap` steps. A strict
            # comparison would give every particle one step fewer than requested and would leave no active particle
            # in the last iteration.
            still_running = torch.where(num_leapfrog >= leap)[0]
            active = active[still_running]
            num_leapfrog = num_leapfrog[still_running]
            step = hmc_step_size[active].reshape(active.shape[0], 1)

            # First half step for the momenta.
            grad_x_logp = self.interface.grad_x_logq(x[active])
            grad_z_logp = self.interface.grad_z_logp(z[active], grad_wrt_x=grad_x_logp) + self.interface.grad_z_logj(
                z[active])
            momenta[active] = momenta[active] + 0.5 * step * grad_z_logp

            # Full step for the latent positions, then map back to the particles.
            z[active] = z[active] + step * momenta[active]
            x[active] = self.interface.inverse(z[active])

            # Second half step for the momenta.
            grad_x_logp = self.interface.grad_x_logq(x[active])
            grad_z_logp = self.interface.grad_z_logp(z[active], grad_wrt_x=grad_x_logp) + self.interface.grad_z_logj(
                z[active])
            momenta[active] = momenta[active] + 0.5 * step * grad_z_logp

        return x, z, momenta

    def integrate_leapfrog(self, x, z, momenta, mass, hmc_step_size, num_leapfrog):
        """
        Leapfrog integrator for a particle ensemble in which all particles take the same number of steps.

        :param x: particles.
        :param z: latent positions of the particles.
        :param momenta: momenta, same shape as ``z``.
        :param mass: diagonal mass matrix elements (not used by the integrator).
        :param hmc_step_size: per-particle step sizes, shape (n,).
        :param num_leapfrog: number of leapfrog steps (int).
        :return: tuple ``(x, z, momenta)`` after integration.
        """
        step = hmc_step_size.reshape(hmc_step_size.shape[0], 1)
        half_step = 0.5 * step

        def latent_gradient(x_cur, z_cur):
            # Gradient w.r.t. the latent positions, including the log-Jacobian term.
            grad_x = self.interface.grad_x_logq(x_cur)
            return self.interface.grad_z_logp(z_cur, grad_wrt_x=grad_x) + self.interface.grad_z_logj(z_cur)

        for _ in range(num_leapfrog):
            momenta = momenta + half_step * latent_gradient(x, z)

            z = z + step * momenta
            x = self.interface.inverse(z)

            momenta = momenta + half_step * latent_gradient(x, z)

        return x, z, momenta

    def surrogate_hmc(self,
                      x: torch.Tensor,
                      hmc_step_size: torch.Tensor,
                      num_leapfrog: int = 10,
                      hmc_retrain_prob_decay: float = 10.0,
                      correlation_threshold: float = 0.7,
                      surrogate_mh: bool = False,
                      surrogate_train_kwargs: Optional[dict] = None):
        """
        HMC in the latent space of the flow, with leapfrog gradients taken from the surrogate.

        :param x: current particles.
        :param hmc_step_size: per-particle leapfrog step sizes, shape (n,).
        :param num_leapfrog: number of leapfrog steps per HMC update.
        :param hmc_retrain_prob_decay: decay rate of the surrogate re-training probability.
        :param correlation_threshold: mean correlation with the initial latent positions below which the updates stop.
        :param surrogate_mh: if True, the MH ratio uses the surrogate log-density and the surrogate is not trained;
            otherwise the tempered posterior is used.
        :param surrogate_train_kwargs: keyword arguments for ``surrogate_interface.train_flow`` when re-training.
        :return: tuple ``(x, logp, acceptance_rate, replaced_fraction, hmc_step, old_z, z_leap, log_mh_leap, mass,
            hmc_step_size)``, where ``old_z`` are the initial latent positions, ``z_leap`` and ``log_mh_leap`` are
            the proposals and log-MH ratios of the first update and ``mass`` is the inverse latent variance.
        """
        # The default None cannot be unpacked with ** when the surrogate is re-trained.
        if surrogate_train_kwargs is None:
            surrogate_train_kwargs = dict()

        if surrogate_mh:
            logp_old = self.surrogate_interface.logq(x)
        else:
            self.posterior.clear_cache()
            logp_old = self.posterior.value(self.beta, x)
            self.surrogate_interface.train_flow(x, logp_old)

        z, logj_forward_old = self.interface.forward_with_logj(x)
        z = z.reshape(len(z), -1)
        old_z = deepcopy(z)
        old_x = deepcopy(x)
        hmc_step = 0
        mass = 1.0 / torch.var(z, dim=0)

        x_train = x.clone()
        logp_train = logp_old.clone()
        logj_train = logj_forward_old.clone()

        acceptance_rate = []
        replaced = torch.zeros(len(x), dtype=torch.bool, device=x.device)
        # Only a single HMC update is performed; the loop body is written so that it also supports several updates.
        for _ in range(1):
            hmc_step += 1

            # The stored forward log-Jacobian is minus the backward one.
            current_logp = logp_train - logj_train

            # Retrain the surrogate with diminishing probability. The comparison needs a uniform draw on [0, 1).
            retrain_prob = np.exp(-(hmc_step - 1.0) / hmc_retrain_prob_decay)
            if np.random.rand() < retrain_prob and not surrogate_mh:
                self.surrogate_interface.train_flow(x_train, logp_train, **surrogate_train_kwargs)

            if hmc_step > 1:
                hmc_step_size = self.tune_hmc(hmc_step_size, acceptance_rate[-1], hmc_step)

            momenta = torch.randn_like(z)
            current_hamiltonian = 0.5 * torch.sum(momenta ** 2, dim=1) - current_logp
            z_new = deepcopy(z)
            x_new = deepcopy(x)

            step = hmc_step_size.reshape(hmc_step_size.shape[0], 1)
            for _leap in range(num_leapfrog):
                grad_x_logp = self.surrogate_interface.grad_x_logq(x_new)
                grad_z_logp = self.interface.grad_z_logp(z_new, grad_wrt_x=grad_x_logp) + self.interface.grad_z_logj(
                    z_new)
                momenta = momenta + 0.5 * step * grad_z_logp

                z_new = z_new + step * momenta
                x_new = self.interface.inverse(z_new)

                grad_x_logp = self.surrogate_interface.grad_x_logq(x_new)
                grad_z_logp = self.interface.grad_z_logp(z_new, grad_wrt_x=grad_x_logp) + self.interface.grad_z_logj(
                    z_new)
                momenta = momenta + 0.5 * step * grad_z_logp

            if surrogate_mh:
                logp_new = self.surrogate_interface.logq(x_new)
            else:
                self.posterior.clear_cache()
                logp_new = self.posterior.value(self.beta, x_new)
            logj_backward_new = self.interface.logj_backward(z_new)
            new_hamiltonian = 0.5 * torch.sum(momenta ** 2, dim=1) - logp_new - logj_backward_new

            logr = current_hamiltonian - new_hamiltonian
            # Draw the uniforms on the device of the log-ratio, so that the comparison also works off the CPU.
            accept = torch.log(torch.rand(len(x), device=logr.device)) < logr
            acceptance_rate.append(torch.sum(accept).item() / len(x))
            replaced = replaced + accept

            if hmc_step == 1:
                # Proposals and log-MH ratios of the first update are returned for step size tuning.
                z_leap = z_new
                log_mh_leap = logr

            old_x[accept] = x_new[accept]
            logp_old[accept] = logp_new[accept]
            z[accept] = z_new[accept]
            logj_forward_old[accept] = -logj_backward_new[accept]

            x_train = old_x.clone()
            logp_train = logp_old.clone()
            logj_train = logj_forward_old.clone()

            x = old_x.clone()

            if surrogate_mh:
                break
            new_r = self.pearson(z, old_z)
            if torch.mean(new_r) < correlation_threshold or hmc_step > 10:
                break

        return x, logp_old, acceptance_rate, torch.sum(replaced).item() / len(x), hmc_step, old_z, z_leap, log_mh_leap, mass, hmc_step_size

    def tune_ulm(self,
                 step_size: torch.Tensor,
                 log_mh: torch.Tensor):
        """
        Adapt per-particle underdamped Langevin step sizes towards an acceptance probability of 0.98.

        Each step size is updated in log space by the difference between the acceptance probability and 0.98, then
        averaged with the mean of the updated step sizes.

        :param step_size: tensor of current step sizes.
        :param log_mh: tensor of log-MH ratios; values above zero are capped at zero.
        :return: tensor of updated step sizes.
        """
        # Cap the log-MH ratio (not the step size) at zero, so that the acceptance probability is min(1, exp(log_mh)).
        capped_log_mh = torch.clamp(log_mh, min=None, max=0.0)
        particle_scales = torch.exp(torch.log(step_size) + (torch.exp(capped_log_mh) - 0.98))
        return 0.5 * (particle_scales + torch.mean(particle_scales))

    def generate_alum_noise(self,
                            ndim: int,
                            step_size: torch.Tensor,
                            gamma: float,
                            alpha: torch.Tensor,
                            taylor: bool):
        """
        Generate the noise terms for ALUM version of underdamped Langevin.

        For every particle a 3x3 covariance matrix is built from its step size and its ``alpha``, and ``ndim``
        independent draws are taken from the corresponding zero-mean Gaussian.

        :param ndim: number of dimensions, i.e. number of draws per particle.
        :param step_size: per-particle step sizes, indexed in the same order as ``alpha``.
        :param gamma: friction coefficient.
        :param alpha: per-particle fractions of the step, shape (n,).
        :param taylor: if True, use the polynomial expressions for the covariances, otherwise the ones based on
            exponentials.
        :return: tuple ``(e_xh, e_vh, e_xah)`` of tensors with shape (n, ndim), holding the first, second and third
            component of the draws.
        """
        num_particles = alpha.shape[0]
        e_xh = torch.zeros((num_particles, ndim))
        e_vh = torch.zeros((num_particles, ndim))
        e_xah = torch.zeros((num_particles, ndim))

        for i, a in enumerate(alpha):
            h = step_size[i]

            if taylor:
                cov_xtxt = 2.0 * (gamma * h) ** 3 / 3.0
                cov_xtvt = gamma * h ** 2 * (1.0 - gamma * h)
                cov_vtvt = 2.0 * gamma * h * (
                            1.0 - gamma * h + 2.0 * (gamma * h) ** 2 / 3.0)
                cov_xtxat = a ** 2 * gamma * h ** 3 * (1.0 - a / 3.0)
                cov_vtxat = a ** 2 * gamma * h ** 2 * (1.0 - gamma * h)
                cov_xatxat = 2.0 * (a * gamma * h) ** 3 / 3.0
            else:
                cov_xtxt = (2.0 * gamma * h - 3.0 + 4.0 * np.exp(-gamma * h) - np.exp(
                    -2.0 * gamma * h)) / gamma ** 2
                cov_xtvt = (4.0 * np.sinh(gamma * h / 2.0) ** 2 * np.exp(-gamma * h)) / gamma
                cov_vtvt = (1.0 - np.exp(-2.0 * gamma * h))
                cov_xtxat = (2.0 * a * gamma * h - 2.0 - 4.0 * np.exp(-gamma * h) * np.sinh(
                    a * gamma * h / 2.0) ** 2 + 2.0 * np.exp(-a * gamma * h)) / gamma ** 2
                cov_vtxat = 4.0 * np.sinh(a * gamma * h / 2.0) ** 2 * np.exp(
                    -gamma * h) / gamma
                cov_xatxat = (2.0 * a * gamma * h - 3.0 + 4.0 * np.exp(-a * gamma * h) - np.exp(
                    - 2.0 * a * gamma * h)) / gamma ** 2

            cov_alum = torch.tensor([[cov_xtxt, cov_xtvt, cov_xtxat],
                                     [cov_xtvt, cov_vtvt, cov_vtxat],
                                     [cov_xtxat, cov_vtxat, cov_xatxat]], dtype=torch.float32)

            # sample((ndim,)) already has shape (ndim, 3). Squeezing it would drop the leading axis for ndim == 1 and
            # break the column indexing below.
            noise_vec = torch.distributions.MultivariateNormal(loc=torch.zeros(cov_alum.shape[0]),
                                                               covariance_matrix=cov_alum).sample((ndim,))
            e_xh[i, :] = noise_vec[:, 0]
            e_vh[i, :] = noise_vec[:, 1]
            e_xah[i, :] = noise_vec[:, 2]

        return e_xh, e_vh, e_xah

    def underdamped_langevin(self,
                             x: torch.Tensor,
                             v: torch.Tensor,
                             step_size: torch.Tensor,
                             gamma: float = 2.0,
                             taylor: bool = True):
        """
        One underdamped Langevin update (ALUM scheme) in the latent space of the flow, driven by surrogate gradients.

        :param x: current particles.
        :param v: current velocities, one row per particle.
        :param step_size: per-particle step sizes, shape (n,).
        :param gamma: friction coefficient.
        :param taylor: if True, use polynomial approximations of the exponential integrators and noise covariances.
        :return: tuple ``(x, v, log_mh)`` with the updated particles, the updated velocities and the difference
            between the initial and the updated Hamiltonian.
        """
        # TODO: Add step-size tuning based on MH acceptance?
        z, logj_forward_old = self.interface.forward_with_logj(x)

        initial_hamiltonian = -self.surrogate_interface.logq(x) + logj_forward_old + torch.sum(v ** 2, dim=1) / 2.0

        if taylor:
            psi0 = lambda h: 1.0 - (gamma * h) + (gamma * h) ** 2 / 2.0 - (gamma * h) ** 3 / 6.0
            psi1 = lambda h: h * (1 - 0.5 * (gamma * h) + (gamma * h) ** 2 / 6.0)
        else:
            psi0 = lambda h: torch.exp(-gamma * h)
            psi1 = lambda h: (1.0 - torch.exp(-gamma * h)) / gamma

        num_particles = x.shape[0]
        # Detached copy of the step sizes; equivalent to torch.tensor(step_size) without the copy-construct warning.
        step_values = step_size.clone().detach()
        step_column = step_size.reshape(num_particles, 1)

        # ALUM update scheme
        alpha = torch.clamp(torch.rand(size=(num_particles,)), min=1.0e-6, max=1.0)
        e_zh, e_vh, e_zah = self.generate_alum_noise(x.shape[1], step_size, gamma, alpha, taylor)

        # Intermediate point at which the gradient is evaluated.
        z_e = z + psi1(alpha * step_size).reshape(v.shape[0], 1) * v + e_zah
        x_e = self.interface.inverse(z_e)
        grad_x_e_U = -self.surrogate_interface.grad_x_logq(x_e)
        # The gradient w.r.t. x is taken at x_e, so it is mapped to the latent space at z_e (not at z), consistent
        # with the log-Jacobian term.
        grad_z_e_U = self.interface.grad_z_logp(z_e, grad_wrt_x=grad_x_e_U) + self.interface.grad_z_logj(z_e)

        remaining_step = step_size - alpha * step_size
        z += psi1(step_values).reshape(num_particles, 1) * v - step_column * psi1(
            remaining_step).reshape(num_particles, 1) * grad_z_e_U + e_zh
        x, logj_backward_new = self.interface.inverse_with_logj(z)
        v = psi0(step_values).reshape(num_particles, 1) * v - step_column * psi0(
            remaining_step).reshape(v.shape[0], 1) * grad_z_e_U + e_vh

        updated_hamiltonian = -self.surrogate_interface.logq(x) - logj_backward_new + torch.sum(v ** 2, dim=1) / 2.0
        log_mh = initial_hamiltonian - updated_hamiltonian

        return x, v, log_mh

    def initialize_optimizer(self,
                             particles: torch.Tensor,
                             particle_logp: torch.Tensor,
                             particle_logq: torch.Tensor,
                             logp_func: callable,
                             logq_func: callable,
                             step: float,
                             optimizer: str,
                             optim_scheduler: str,
                             exp_decay_rate: float,
                             cos_T_max: int,
                             cos_lr_min: float):
        """
        Create the particle optimizer and its step size scheduler.

        :param particles: particles handed to the optimizer.
        :param particle_logp: log-density values of the particles (line search only).
        :param particle_logq: flow log-density values of the particles (line search only).
        :param logp_func: callable handed to the line search optimizer.
        :param logq_func: callable handed to the line search optimizer.
        :param step: learning rate of the optimizer.
        :param optimizer: one of 'adam', 'grad_descent' or 'line_search'.
        :param optim_scheduler: one of 'identity', 'exp' or 'cosine'.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for the cosine scheduler.
        :return: tuple ``(particle_optimizer, particle_scheduler)``.
        :raises ValueError: if ``optimizer`` or ``optim_scheduler`` is not one of the supported names.
        """
        # Validate both names up front, so that a typo yields a clear error instead of an UnboundLocalError.
        valid_optimizers = ('adam', 'grad_descent', 'line_search')
        valid_schedulers = ('identity', 'exp', 'cosine')
        if optimizer not in valid_optimizers:
            raise ValueError(f"Unknown optimizer '{optimizer}'. Expected one of {valid_optimizers}.")
        if optim_scheduler not in valid_schedulers:
            raise ValueError(f"Unknown optim_scheduler '{optim_scheduler}'. Expected one of {valid_schedulers}.")

        if optimizer == 'adam':
            particle_optimizer = ParticleAdam(particles=particles, lr=step)
        elif optimizer == 'grad_descent':
            particle_optimizer = ParticleGradientDescent(particles=particles, lr=step)
        else:  # 'line_search'
            self.posterior.clear_cache()
            particle_optimizer = ParticleLineSearch(particles=particles,
                                                    particle_logp=particle_logp,
                                                    particle_logq=particle_logq,
                                                    posterior=self.posterior,
                                                    interface=self.interface,
                                                    logp_func=logp_func,
                                                    logq_func=logq_func,
                                                    lr=step)

        if optim_scheduler == 'identity':
            particle_scheduler = IdentityScheduler(particle_optimizer)
        elif optim_scheduler == 'exp':
            particle_scheduler = ExponentialDecayScheduler(particle_optimizer, exp_decay_rate)
        else:  # 'cosine'
            particle_scheduler = CosineAnnealingScheduler(particle_optimizer, cos_T_max, cos_lr_min)

        return particle_optimizer, particle_scheduler

    def stochastic_calibration(self,
                               particles: torch.Tensor,
                               step: float,
                               beta: float,
                               taylor: bool = False):
        """
        Calibrate the factor ``zeta`` that scales the particle density gradient in the DLMC update.

        With ``taylor=False`` a scalar root search looks for the ``zeta`` at which the mean change of the posterior
        value under the deterministic update equals that of a stochastic update with Gaussian noise. With
        ``taylor=True`` a closed-form ratio of gradient norms is used instead.

        :param particles: current particles.
        :param step: step size of the update (float or tensor).
        :param beta: beta value at which the posterior is evaluated.
        :param taylor: whether to use the closed-form expression instead of the root search.
        :return: calibrated zeta; 1.0 (standard DLMC) if the root search fails or does not converge.
        """
        self.posterior.clear_cache()
        # as_tensor makes the square root work for float and tensor step sizes alike.
        eta = torch.randn_like(particles) * torch.sqrt(torch.as_tensor(2.0 * step))
        grad_U = -self.posterior.gradient(beta, particles)
        grad_U = self.interface.grad_z_logp(particles, grad_wrt_x=grad_U) + self.interface.grad_z_logj(particles)
        grad_V = deepcopy(particles)

        if not taylor:

            self.posterior.clear_cache()
            current_logp = self.posterior.value(beta, particles)
            stochastic = particles - step * grad_U + eta
            self.posterior.clear_cache()
            delta_U_stochastic = self.posterior.value(beta, stochastic) - current_logp

            def objective(zeta):
                modified_dlmc = particles - step * grad_U + step * zeta * grad_V
                self.posterior.clear_cache()
                delta_U_dlmc = self.posterior.value(beta, modified_dlmc) - current_logp
                return float(torch.mean(delta_U_dlmc) - torch.mean(delta_U_stochastic))

            try:
                # root_scalar needs a bracket or a starting point, otherwise it raises before evaluating the
                # objective. The search starts from standard DLMC (zeta = 1).
                sol = root_scalar(
                    f=objective,
                    x0=1.0
                )

                if not sol.converged:
                    warnings.warn("Zeta search did not converge. Setting to 1 (standard DLMC).")
                    calib_zeta = 1.0
                else:
                    calib_zeta = sol.root
                    print(f'Stochastic calibration zeta: {calib_zeta}')
            except Exception:
                # Calibration is best effort: any failure falls back to standard DLMC. KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                print('Zeta could not be determined. Setting to 1 (standard DLMC).')
                calib_zeta = 1.0

        else:

            calib_zeta = torch.sum(torch.linalg.norm(eta * grad_U, dim=1) ** 2) / torch.sum(
                step * torch.linalg.norm(grad_U * grad_V, dim=1) ** 2)
            print(f'Taylor calibration zeta: {calib_zeta}')

        return calib_zeta

    def upsample(self, x, z, logp, num_upsample):
        """
        Extend the ensemble with new samples drawn from the flow.

        :param x: current particles.
        :param z: latent positions of the current particles.
        :param logp: log-density values of the current particles.
        :param num_upsample: number of samples to add.
        :return: tuple ``(x, z, logp, logq)`` for the extended ensemble, where ``logq`` is the standard normal
            log-density of all latent positions.
        """
        fresh_x = self.interface.sample(num_upsample)
        fresh_z = self.interface.forward(fresh_x)
        x = torch.cat([x, fresh_x])
        z = torch.cat([z, fresh_z])
        _, fresh_logj_backward = self.interface.inverse_with_logj(fresh_z)

        # Tempered posterior of the new samples plus the backward log-Jacobian.
        self.posterior.clear_cache()
        fresh_logp = self.posterior.value(self.beta, fresh_x) + fresh_logj_backward
        logp = torch.cat([logp, fresh_logp])

        latent_dim = z.shape[1]
        base_distribution = torch.distributions.MultivariateNormal(loc=torch.zeros(latent_dim),
                                                                   covariance_matrix=torch.eye(latent_dim))
        logq = base_distribution.log_prob(z)
        return x, z, logp, logq

    def run(self,
            x: torch.Tensor,
            surrogate_method: str = 'GP',
            step: float = 1e-2,
            burnin_optimizer: str = "adam",
            optimizer: str = "adam",
            optim_scheduler: str = "identity",
            exp_decay_rate: float = 0.999,
            cos_T_max: int = 100,
            cos_lr_min: float = 0.0,
            burnin_optim_steps: int = 1,
            optim_steps: int = 1,
            atol: float = 1e-15,
            max_burnin: int = 20,
            burnin_thresh: float = 0.5,
            burnin_upsample: int = 0,
            virial_threshold: float = 1.0,
            upsample_schedule: dict = None,
            max_iterations: int = 100,
            IMHstep: int = 0,
            do_CMH: bool = True,
            do_surrogate_hmc: bool = True,
            do_surrogate_ulm: bool = True,
            hmc_step_max: float = 0.1,
            leapfrog_max: int = 100,
            hmc_retrain_prob_decay: float = 10.0,
            mcmc_correlation_threshold: float = 0.7,
            stochastic_calibration: bool = False,
            taylor_zeta: bool = True,
            use_tqdm: bool = False,
            flow_train_kwargs: dict = None,
            surrogate_train_kwargs: dict = None,
            underdamped_kwargs: dict = None,
            animate_kwargs: dict = None) -> torch.Tensor:
        """
        Run DLA with a surrogate target, followed by optional IMH / CMH / surrogate HMC / surrogate ULM moves.

        For every beta level the method runs up to max_burnin + max_iterations iterations. Each iteration re-trains
        the flow on the current particles, re-fits the surrogate, takes up to burnin_optim_steps / optim_steps
        particle-optimizer steps in latent space and then applies the enabled MCMC moves. After the last beta level
        the flow is trained once more and importance weights of the particles are computed.

        :param x: prior samples with shape (n, d). A deep copy is taken, the caller's tensor is not modified.
        :param surrogate_method: whether to use GP or NF to construct target surrogates.
        :param step: base step size for sample updates. It is reset at the start of every beta level and rescaled in
            every iteration by mean latent std / mean gradient norm.
        :param burnin_optimizer: particle optimizer to use while burnin is active (adam, grad_descent or line_search).
        :param optimizer: particle optimizer to use for DLMC updates after burnin (adam, grad_descent or line_search).
        :param optim_scheduler: optimization step size scheduler (identity, exp or cosine).
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for cosine scheduler.
        :param burnin_optim_steps: maximum number of optimizer steps to take at each burnin iteration.
        :param optim_steps: maximum number of optimizer steps to take at each post-burnin DLMC iteration.
        :param atol: tolerance parameter when checking for convergence (compared against the max-norm of the update).
        :param max_burnin: maximum number of burnin iterations for each beta level.
        :param burnin_thresh: validated by no check and not referenced anywhere in this method body.
        :param burnin_upsample: number of samples to add after burnin.
        :param virial_threshold: threshold * ndim factor at which we determine DLMC has virialized.
        :param upsample_schedule: dictionary whose keys are post-burnin iteration counts; when the counter matches a
            key, the next value (taken in insertion order) is the number of particles to add.
        :param max_iterations: maximum number of post-burnin DLA iterations for each beta level.
        :param IMHstep: number of IMH steps to apply after DLMC.
        :param do_CMH: whether to apply CMH steps after DLMC.
        :param do_surrogate_hmc: whether to do surrogate HMC after DLMC.
        :param do_surrogate_ulm: whether to do surrogate ULM after DLMC.
        :param hmc_step_max: initial maximum step-size; the per-particle HMC and ULM step sizes are both initialized to
            hmc_step_max / d ** 0.25.
        :param leapfrog_max: number of leapfrog steps passed to surrogate HMC; also the cap on ULM steps per iteration.
        :param hmc_retrain_prob_decay: decay rate of the re-training probability during surrogate HMC.
        :param mcmc_correlation_threshold: correlation threshold for terminating MCMC.
        :param stochastic_calibration: whether to do stochastic calibration of the particle density gradient.
        :param taylor_zeta: whether to use the linearized correction for stochastic calibration.
        :param use_tqdm: use a tqdm progress bar.
        :param flow_train_kwargs: keyword arguments for interface.train_flow.
        :param surrogate_train_kwargs: keyword arguments for surrogate_interface.train_flow.
        :param underdamped_kwargs: keyword arguments for underdamped_langevin.
        :param animate_kwargs: keyword arguments for create_beta_stage_animation.

        :return: tuple (x, logw): the final particles and their log importance weights logp - logq, normalized with
            logsumexp. NOTE(review): the return annotation says torch.Tensor but a 2-tuple is returned.
        :raises ValueError: if any of the validated arguments is outside its allowed set or range.
        """
        # ---- Argument validation -------------------------------------------------------------------------------
        if surrogate_method not in ('GP', 'NF'):
            raise ValueError("surrogate_method must be either GP or NF.")
        if step <= 0.0:
            raise ValueError("step must be positive")
        if burnin_optimizer not in ("adam", "grad_descent", "line_search"):
            raise ValueError("burnin_optimizer must be one of adam, grad_descent or line_search")
        if optimizer not in ("adam", "grad_descent", "line_search"):
            raise ValueError("optimizer must be one of adam, grad_descent or line_search")
        if optim_scheduler not in ("identity", "exp", "cosine"):
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")
        if exp_decay_rate <= 0.0 or exp_decay_rate > 1.0:
            raise ValueError("exp_decay rate must be between 0 and 1")
        if cos_T_max <= 0:
            raise ValueError("cos_T_max must be positive")
        if cos_lr_min < 0:
            raise ValueError("cos_lr_min must be greater than or equal to 0")
        if burnin_optim_steps < 1:
            raise ValueError("burnin_optim_steps must be at least 1")
        if optim_steps < 1:
            raise ValueError("optim_steps must be at least 1")
        if atol <= 0.0:
            raise ValueError("atol must be positive")
        if virial_threshold < 1.0:
            raise ValueError("virial_threshold must be greater than or equal to 1.")
        if max_iterations <= 0:
            raise ValueError("max_iterations must be positive")
        if hmc_step_max <= 0.0:
            raise ValueError("hmc_step_max must be greater than 0.")
        if leapfrog_max < 1:
            raise ValueError("leapfrog_max must be greater than or equal to 1.")
        # NOTE(review): the message below contains the typo "grater"; it is a runtime string and is left untouched.
        if hmc_retrain_prob_decay <= 0.0:
            raise ValueError("hmc_retrain_prob_decay must be grater than 0.")
        if mcmc_correlation_threshold <= 0.0 or mcmc_correlation_threshold >= 1.0:
            raise ValueError("mcmc_correlation_threshold must be between 0 and 1.")
        if self.interface.flow is None:
            warnings.warn("Flow is None. Did you forget to create the flow using create_flow?")
        # Replace missing keyword-argument dictionaries with fresh empty ones.
        if flow_train_kwargs is None:
            flow_train_kwargs = dict()
        if surrogate_train_kwargs is None:
            surrogate_train_kwargs = dict()
        if underdamped_kwargs is None:
            underdamped_kwargs = dict()
        if animate_kwargs is None:
            animate_kwargs = dict()

        x = deepcopy(x)  # Make a copy of the initial particles, these will be changed during DLA
        initial_step = step

        latent_samples = torch.randn_like(x)  # Only for visualization purposes

        # Initialize CMHscale at theoretical optimum
        CMHscale = torch.tensor(2.38 / x.shape[1] ** 0.5)

        # Per-particle step sizes and call counters for the surrogate HMC / ULM moves.
        hmc_step_size = hmc_step_max / x.shape[1] ** 0.25 * torch.ones(x.shape[0])
        hmc_iter = 0
        ulm_step_size = hmc_step_max / x.shape[1] ** 0.25 * torch.ones(x.shape[0])
        ulm_iter = 0

        beta_stage = 0
        old_beta = self.beta
        # ---- Outer loop: one pass per beta level ---------------------------------------------------------------
        while self.beta <= 1:

            update_norm = torch.inf  # Just so we define this even if no iterations are executed.
            step = initial_step  # Reset the step size for each beta level.
            dlmc_stage = 0  # Number of post-burnin iterations executed at this beta level.
            num_burnin = 0  # Number of burnin iterations executed at this beta level.
            burnin = True
            upsample_number = 0  # Position of the next entry to consume from upsample_schedule.
            total_iterations = max_burnin + max_iterations

            for iteration in (pbar := (tqdm(range(total_iterations)) if use_tqdm else range(total_iterations))):
                # Clear the cached log_likelihood, log_prior, and gradients
                self.posterior.clear_cache()

                if dlmc_stage > max_iterations:
                    print(f'Reached maximum number of post-burnin DLMC iterations.')
                    break

                # Advance the counter of the current phase. When the burnin budget is exhausted, switch to the
                # post-burnin phase and skip the remainder of this iteration.
                if not burnin:
                    dlmc_stage += 1
                elif burnin:
                    num_burnin += 1
                    if num_burnin == max_burnin:
                        print(f'Reached maximum number of burnin iterations.')
                        burnin = False
                        continue

                # Upsample once right after burnin has ended; num_burnin is set to inf afterwards so that this
                # branch is not entered again.
                # NOTE(review): hmc_step_size is extended for the new particles but ulm_step_size is not — confirm
                # that underdamped_langevin does not need a per-particle step size of matching length.
                if num_burnin == max_burnin and burnin is False:
                    if burnin_upsample > 0:
                        print(f'Drawing {burnin_upsample} new particles from current q for post-burnin DLMC.')
                        x, z, current_logp, current_logq = self.upsample(x, z, current_logp, burnin_upsample)
                        num_burnin = torch.inf
                        hmc_step_size = torch.cat(
                            [hmc_step_size, torch.mean(hmc_step_size) * torch.ones(burnin_upsample)])

                # Scheduled upsampling: keys are matched against dlmc_stage, values are consumed in order.
                if upsample_schedule is not None:
                    if dlmc_stage in list(upsample_schedule.keys()):
                        print(f'Drawing {list(upsample_schedule.values())[upsample_number]} new samples from q.')
                        num_upsample = list(upsample_schedule.values())[upsample_number]
                        x, z, current_logp, current_logq = self.upsample(x, z, current_logp, num_upsample)
                        upsample_number += 1
                        hmc_step_size = torch.cat([hmc_step_size, torch.mean(hmc_step_size) * torch.ones(num_upsample)])

                # Re-fit the flow to the current particles and map them to latent space.
                self.interface.train_flow(x, **flow_train_kwargs)
                z = self.interface.forward(x)
                _, logj_backward = self.interface.inverse_with_logj(z)

                #if iteration == 0:
                self.posterior.clear_cache()
                train_logp = self.posterior.value(self.beta, x)
                # Latent-space target: data-space log posterior plus the log-Jacobian of the inverse map.
                current_logp = train_logp + logj_backward
                # Standard normal log density of the latent particles.
                current_logq = torch.distributions.MultivariateNormal(loc=torch.zeros(z.shape[1]),
                                                                      covariance_matrix=torch.eye(z.shape[1])).log_prob(z)

                # Fit the surrogate. The GP is re-fit every iteration on (x, train_logp); the NF surrogate is only
                # fit on the first iteration of a beta level, on z with weights from the change in beta.
                if surrogate_method == 'GP':
                    self.surrogate_interface.train_flow(x, train_logp, **surrogate_train_kwargs)
                elif surrogate_method == 'NF':
                    if iteration == 0:
                        logw = (self.beta - old_beta) * self.posterior.log_likelihood(x)
                        logw -= torch.logsumexp(logw, dim=0)
                        train_weights = torch.exp(logw)
                        train_weights /= torch.sum(train_weights)
                        # NOTE(review): writes the weights to a file in the current working directory.
                        np.savetxt('./train_weights.txt', train_weights)
                        self.surrogate_interface.train_flow(z, weights=train_weights, **surrogate_train_kwargs)

                # Burnin ends early once the summed virial falls inside the accepted window.
                # NOTE(review): the lower bound compares against z.shape[0] (number of particles) while the upper
                # bound uses z.shape[1] (dimension) — confirm this is intended.
                if burnin:
                    grad_x_U = -self.surrogate_interface.grad_x_logq(x)
                    virial = torch.mean(torch.abs(x * grad_x_U), dim=0)
                    if torch.sum(virial) <= virial_threshold * z.shape[1] and torch.sum(virial) > z.shape[0]:
                        print(f'Reached virial threshold during burnin.')
                        print(f'sum(<zgradU>)={torch.sum(virial)}')
                        burnin = False
                        num_burnin = max_burnin
                        continue

                # Update samples
                x_before_update = deepcopy(x)

                # Latent-space gradient from the surrogate; used here only to scale the step size.
                grad_x_U = -self.surrogate_interface.grad_x_logq(x)
                grad_z_U = self.interface.grad_z_logp(z, grad_wrt_x=grad_x_U) + self.interface.grad_z_logj(z)
                grad = grad_z_U - z
                grad_norm = torch.linalg.norm(grad, dim=1)
                sample_sigma = torch.mean(torch.std(z, dim=0))
                particle_step = step * sample_sigma / torch.mean(grad_norm)
                #particle_step = step / z.shape[1] ** 0.25


                # Callbacks handed to initialize_optimizer; both take latent-space inputs.
                def logp_func(val):
                    x_val = self.interface.inverse(val)
                    return self.posterior.value(self.beta, x_val)
                def logq_func(val):
                    x_val = self.interface.inverse(val)
                    return self.interface.logq(x_val)


                '''
                def logp_func(val):
                    x_val, logj_val = self.interface.inverse_with_logj(val)
                    return self.surrogate_interface.logq(x_val) + logj_val
                logq_func = lambda val: torch.distributions.MultivariateNormal(loc=torch.zeros(val.shape[1]),
                                                                               covariance_matrix=torch.eye(
                                                                                   val.shape[1])).log_prob(val)
                '''

                # Pick the optimizer and step budget of the current phase.
                if burnin:
                    optim_method = burnin_optimizer
                    optim_updates = burnin_optim_steps
                elif not burnin:
                    optim_method = optimizer
                    optim_updates = optim_steps

                particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=z,
                                                                                   step=particle_step,
                                                                                   particle_logp=current_logp,
                                                                                   particle_logq=current_logq,
                                                                                   logp_func=logp_func,
                                                                                   logq_func=logq_func,
                                                                                   optimizer=optim_method,
                                                                                   optim_scheduler=optim_scheduler,
                                                                                   exp_decay_rate=exp_decay_rate,
                                                                                   cos_T_max=cos_T_max,
                                                                                   cos_lr_min=cos_lr_min)

                # Inner optimizer loop. It stops early as soon as a step moves the summed virial further away
                # from the dimension x.shape[1] than it was before the step.
                grad_x_U = -self.surrogate_interface.grad_x_logq(x)
                old_virial = torch.mean(torch.abs(x * grad_x_U), dim=0)
                optim_iters = 0
                for i in range(optim_updates):
                    optim_iters += 1
                    if stochastic_calibration:
                        calib_zeta = self.stochastic_calibration(particles=z,
                                                                 step=particle_step,
                                                                 beta=self.beta,
                                                                 taylor=taylor_zeta)
                    else:
                        calib_zeta = 1.0

                    grad_x_U = -self.surrogate_interface.grad_x_logq(x)
                    grad_z_U = self.interface.grad_z_logp(z, grad_wrt_x=grad_x_U) + self.interface.grad_z_logj(z)
                    grad = grad_z_U - calib_zeta * z
                    particle_optimizer.step(grad)
                    particle_scheduler.step()
                    # presumably the optimizer updates z in place (z is never reassigned here) — verify
                    x = self.interface.inverse(z)

                    grad_x_U = -self.surrogate_interface.grad_x_logq(x)
                    new_virial = torch.mean(torch.abs(x * grad_x_U), dim=0)
                    if torch.abs(torch.sum(new_virial) - x.shape[1]) > torch.abs(torch.sum(old_virial) - x.shape[1]):
                        break
                    else:
                        old_virial = new_virial

                print(f'DLMC optim iters = {optim_iters}')

                update_values = x - x_before_update
                update_norm = torch.max(torch.abs(update_values))  # Using max norm as a worst case precaution

                # Check convergence
                if update_norm < atol:
                    warnings.warn('Particle updates are too small. Stopping.')
                    break

                # ---- Optional IMH moves ------------------------------------------------------------------------
                if IMHstep:
                    acc_rate_IMH = []
                    replaced_IMH = torch.zeros(len(x), dtype=torch.bool, device=x.device)
                    for _ in range(IMHstep):
                        x, accept, train_logp = self.IMH(x)
                        replaced_IMH = replaced_IMH + accept
                        acc_rate_IMH.append(torch.sum(accept).item() / len(x))
                        total_acc_rate_IMH = torch.sum(replaced_IMH).item() / len(x)

                # ---- Optional CMH moves (flow is re-trained first; CMHscale is carried across iterations) -------
                if do_CMH:
                    self.interface.train_flow(x, **flow_train_kwargs)
                    cmh_output = self.CMH(x=x,
                                          scale=CMHscale,
                                          correlation_threshold=mcmc_correlation_threshold)
                    x, acc_rate_CMH, total_acc_rate_CMH, CMHscale, cmh_steps = cmh_output

                # ---- Optional surrogate HMC --------------------------------------------------------------------
                if do_surrogate_hmc:
                    self.interface.train_flow(x, **flow_train_kwargs)
                    # Without IMH, train_logp still refers to the particles before the update, so recompute it.
                    if not IMHstep:
                        self.posterior.clear_cache()
                        train_logp = self.posterior.value(self.beta, x)
                    self.surrogate_interface.train_flow(x, train_logp, **surrogate_train_kwargs)

                    num_leapfrog = leapfrog_max
                    hmc_iter += 1
                    z = self.interface.forward(x)
                    z_before_hmc = deepcopy(z)
                    # Correlation of the particles with themselves: the starting value for the loop condition.
                    correlation = self.pearson(z_before_hmc, z_before_hmc)

                    # From the second HMC call onwards, tune the step size with the last recorded acceptance rate.
                    if hmc_iter > 1:
                        hmc_step_size = self.tune_hmc(hmc_step_size, acc_rate_hmc[-1:].pop(), 1)

                    hmc_count = 0
                    # hmc_count < 1 limits this loop to at most one surrogate_hmc call per iteration.
                    while torch.mean(correlation) > mcmc_correlation_threshold and hmc_count < 1:
                        hmc_count += 1
                        surrogate_mh = False

                        hmc_output = self.surrogate_hmc(x=x,
                                                        hmc_step_size=hmc_step_size,
                                                        num_leapfrog=num_leapfrog,
                                                        hmc_retrain_prob_decay=hmc_retrain_prob_decay,
                                                        correlation_threshold=mcmc_correlation_threshold,
                                                        surrogate_mh=surrogate_mh,
                                                        surrogate_train_kwargs=surrogate_train_kwargs)
                        x, train_logp, acc_rate_hmc, total_acc_rate_hmc, hmc_steps, z_init, z_leap, log_mh_leap, old_mass, hmc_step_size = hmc_output
                        z = self.interface.forward(x)
                        correlation = self.pearson(z, z_before_hmc)

                # ---- Optional surrogate underdamped Langevin ---------------------------------------------------
                if do_surrogate_ulm:
                    #self.interface.train_flow(x, **flow_train_kwargs)
                    z = self.interface.forward(x)
                    #self.posterior.clear_cache()
                    #train_logp = self.posterior.value(self.beta, x)
                    self.surrogate_interface.train_flow(x, train_logp, **surrogate_train_kwargs)
                    z_before_ulm = deepcopy(z)
                    v = torch.randn_like(x)  # Fresh velocities for every ULM phase.

                    correlation = self.pearson(z_before_ulm, z_before_ulm)
                    ulm_iter += 1

                    '''
                    if ulm_iter > 1:
                        ulm_step_size = self.tune_ulm(ulm_step_size, log_mh_ulm)
                    '''
                    ulm_steps = 0
                    # Step until the particles have decorrelated from their starting point or leapfrog_max steps
                    # have been taken.
                    while torch.mean(correlation) > mcmc_correlation_threshold and ulm_steps < leapfrog_max:
                        ulm_steps += 1
                        x, v, log_mh_ulm = self.underdamped_langevin(x=x, v=v, step_size=ulm_step_size,
                                                                     **underdamped_kwargs)
                        z = self.interface.forward(x)
                        correlation = self.pearson(z, z_before_ulm)

                # Progress bar bookkeeping (pbar is only a tqdm object when use_tqdm is set).
                if use_tqdm:
                    pbar.set_description(f'[Stage {beta_stage}] Beta: {self.beta}')
                    if IMHstep:
                        pbar.set_postfix(IMH=total_acc_rate_IMH)
                    if do_CMH:
                        pbar.set_postfix(CMH=total_acc_rate_CMH, steps=cmh_steps)
                    if do_surrogate_hmc:
                        pbar.set_postfix(HMC=torch.mean(torch.tensor(acc_rate_hmc)), steps=hmc_count)
                    if do_surrogate_ulm:
                        pbar.set_postfix(ulm_steps=ulm_steps, r=torch.mean(correlation))

            self.log_iteration_data(x, latent_samples)

            # Make the animation
            self.create_beta_stage_animation(beta_stage, **animate_kwargs)

            # Compute new beta
            old_beta = self.beta
            self.set_new_beta(x=x)
            if self.finished:
                break
            # update_norm is still inf only if no iteration at this beta level reached the sample update.
            if torch.isinf(update_norm):
                warnings.warn("Update norm is infinite. Stopping.")
                break

            beta_stage += 1
            if self.debugger is not None:
                self.debugger.stage_step()

        # Final fit of the flow and self-normalized log importance weights of the returned particles.
        self.interface.train_flow(x, **flow_train_kwargs)
        logq = self.interface.logq(x)
        self.posterior.clear_cache()
        logp = self.posterior.value(self.beta, x)
        logw = logp - logq
        logw -= torch.logsumexp(logw, dim=0)

        return x, logw


class LatentDLA(DLA):
    def __init__(self,
                 interface: TorchFlowInterface,
                 burnin_interface: TorchFlowInterface,
                 posterior: DifferentiableTemperedPosterior,
                 beta_handler: BetaHandler = SingleStageBetaHandler(),
                 debugger: Optional[MultiStageDebugger] = None):
        """
        LatentDLA class.

        :param interface: flow interface, forwarded to DLA.__init__.
        :param burnin_interface: second flow interface, stored on the instance as burnin_interface.
        :param posterior: tempered posterior, forwarded to DLA.__init__.
        :param beta_handler: beta handler, forwarded to DLA.__init__.
        :param debugger: optional debugger, forwarded to DLA.__init__.
        """
        # Everything except the burnin interface is handled by the base class.
        super().__init__(interface=interface,
                         posterior=posterior,
                         beta_handler=beta_handler,
                         debugger=debugger)
        self.burnin_interface = burnin_interface

    def IMH(self, x, nf_interface):
        """
        Apply one independence Metropolis-Hastings step to all particles, using the flow as the proposal.

        :param x: current particles; accepted proposals are written into this tensor in place.
        :param nf_interface: flow interface providing logq(x) and sample_with_logq(n).

        :return: tuple (x, accept) with the updated particles and the boolean acceptance mask.
        """
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)
        logq_old = nf_interface.logq(x)
        sample, logq_new = nf_interface.sample_with_logq(len(x))
        self.posterior.clear_cache()
        logp_new = self.posterior.value(self.beta, sample)

        # Log acceptance ratio for an independence proposal: p(x') q(x) / (p(x) q(x')).
        logr = logp_new + logq_old - logp_old - logq_new

        # Draw the uniforms on the device of logr so that the comparison also works for non-CPU particles.
        log_u = torch.log(torch.rand(len(x), device=logr.device))
        accept = log_u < logr
        x[accept] = sample[accept]
        return x, accept

    def CMH(self, x, scale, step, nf_interface):
        """
        Apply `step` random-walk Metropolis-Hastings updates in the latent space of the flow.

        :param x: current particles; accepted proposals are written into this tensor in place.
        :param scale: scale of the Gaussian perturbation added to the latent particles.
        :param step: number of Metropolis-Hastings updates.
        :param nf_interface: flow interface providing forward_with_logj(x) and inverse_with_logj(z).

        :return: tuple (x, acceptance_rate, replaced_fraction): the updated particles, the list of per-update
            acceptance fractions, and the fraction of particles accepted at least once.
        """
        # TODO: implement proposal scale and step size tuning ala PyMC3
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)

        z, logj_forward_old = nf_interface.forward_with_logj(x)
        z = z.reshape(len(z), -1)

        acceptance_rate = []
        replaced = torch.zeros(len(x), dtype=torch.bool, device=x.device)
        for _ in range(step):
            z_new = z + torch.randn_like(z) * scale
            x_new, logj_backward_new = nf_interface.inverse_with_logj(z_new)

            self.posterior.clear_cache()
            logp_new = self.posterior.value(self.beta, x_new)

            # Latent-space target is logp plus the backward log-Jacobian; for the current state the backward
            # log-Jacobian is stored as the negative forward one.
            logr = (logp_new + logj_backward_new) - (logp_old - logj_forward_old)

            # Draw the uniforms on the device of logr so that the comparison also works for non-CPU particles.
            accept = torch.log(torch.rand(len(x), device=logr.device)) < logr
            acceptance_rate.append(torch.sum(accept).item() / len(x))
            replaced = replaced | accept

            # Carry the accepted proposals (and their cached quantities) into the next update.
            x[accept] = x_new[accept]
            logp_old[accept] = logp_new[accept]
            z[accept] = z_new[accept]
            logj_forward_old[accept] = -logj_backward_new[accept]

        return x, acceptance_rate, torch.sum(replaced).item() / len(x)

    def generate_alum_noise(self,
                            ndim: int,
                            step_size: float,
                            gamma: float,
                            alpha: torch.Tensor,
                            taylor: bool):
        """
        Generate the correlated noise terms for the ALUM version of underdamped Langevin.

        For every entry of alpha a 3x3 covariance matrix is built and ndim independent zero-mean Gaussian
        3-vectors are drawn from it, one per dimension.

        :param ndim: number of dimensions; must be at least 1.
        :param step_size: integrator step size.
        :param gamma: friction coefficient.
        :param alpha: 1D tensor with one value per particle.
        :param taylor: if True use the polynomial covariance expressions, otherwise the exponential ones.

        :return: tuple (e_xh, e_vh, e_xah) of tensors with shape (len(alpha), ndim), holding the first, second and
            third component of the drawn noise vectors.
        """
        # Generate the noise terms for ALUM version of underdamped Langevin.
        e_xh = torch.zeros((alpha.shape[0], ndim))
        e_vh = torch.zeros((alpha.shape[0], ndim))
        e_xah = torch.zeros((alpha.shape[0], ndim))

        for i, a in enumerate(alpha):

            if taylor:
                cov_xtxt = 2.0 * (gamma * step_size) ** 3 / 3.0
                cov_xtvt = gamma * step_size ** 2 * (1.0 - gamma * step_size)
                cov_vtvt = 2.0 * gamma * step_size * (1.0 - gamma * step_size + 2.0 * (gamma * step_size) ** 2 / 3.0)
                cov_xtxat = a ** 2 * gamma * step_size ** 3 * (1.0 - a / 3.0)
                cov_vtxat = a ** 2 * gamma * step_size ** 2 * (1.0 - gamma * step_size)
                cov_xatxat = 2.0 * (a * gamma * step_size) ** 3 / 3.0
            else:
                cov_xtxt = (2.0 * gamma * step_size - 3.0 + 4.0 * np.exp(-gamma * step_size) - np.exp(
                    -2.0 * gamma * step_size)) / gamma ** 2
                cov_xtvt = (4.0 * np.sinh(gamma * step_size / 2.0) ** 2 * np.exp(-gamma * step_size)) / gamma
                cov_vtvt = (1.0 - np.exp(-2.0 * gamma * step_size))
                cov_xtxat = (2.0 * a * gamma * step_size - 2.0 - 4.0 * np.exp(-gamma * step_size) * np.sinh(
                    a * gamma * step_size / 2.0) ** 2 + 2.0 * np.exp(-a * gamma * step_size)) / gamma ** 2
                cov_vtxat = 4.0 * np.sinh(a * gamma * step_size / 2.0) ** 2 * np.exp(
                    -gamma * step_size) / gamma
                cov_xatxat = (2.0 * a * gamma * step_size - 3.0 + 4.0 * np.exp(-a * gamma * step_size) - np.exp(
                    - 2.0 * a * gamma * step_size)) / gamma ** 2

            cov_alum = torch.tensor([[cov_xtxt, cov_xtvt, cov_xtxat],
                                     [cov_xtvt, cov_vtvt, cov_vtxat],
                                     [cov_xtxat, cov_vtxat, cov_xatxat]], dtype=torch.float32)

            # sample((ndim,)) already has shape (ndim, 3). It must not be squeezed: for ndim == 1 a squeeze would
            # drop the leading axis and break the column indexing below.
            noise_vec = torch.distributions.MultivariateNormal(loc=torch.zeros(cov_alum.shape[0]),
                                                               covariance_matrix=cov_alum).sample((ndim,))
            e_xh[i, :] = noise_vec[:, 0]
            e_vh[i, :] = noise_vec[:, 1]
            e_xah[i, :] = noise_vec[:, 2]

        return e_xh, e_vh, e_xah

    def underdamped_langevin(self,
                             x: torch.Tensor,
                             v: torch.Tensor,
                             nf_interface: TorchFlowInterface,
                             step_size: float = 1e-3,
                             gamma: float = 2.0,
                             taylor: bool = True):
        """
        Take one underdamped Langevin step (ALUM update scheme) in the latent space of the flow.

        The gradient is evaluated once, at an intermediate point z_e that lies a random fraction alpha of the step
        ahead of the current latent position, and is then used to update both position and velocity.

        :param x: current particles in data space.
        :param v: current latent-space velocities, one row per particle.
        :param nf_interface: flow interface providing forward, inverse, grad_z_logp and grad_z_logj.
        :param step_size: integrator step size.
        :param gamma: friction coefficient.
        :param taylor: passed on to generate_alum_noise to select the covariance expressions.

        :return: tuple (x, v) with the updated data-space particles and latent-space velocities.
        """
        # TODO: Add step-size tuning based on MH acceptance?
        self.posterior.clear_cache()
        z = nf_interface.forward(x)

        def psi0(h):
            return torch.exp(-gamma * h)

        def psi1(h):
            return (1.0 - torch.exp(-gamma * h)) / gamma

        # ALUM update scheme
        alpha = torch.rand(size=(z.shape[0],))
        e_zh, e_vh, e_zah = self.generate_alum_noise(z.shape[1], step_size, gamma, alpha, taylor)

        # Intermediate evaluation point, a fraction alpha of the step ahead.
        z_e = z + psi1(alpha * step_size).reshape(v.shape[0], 1) * v + e_zah
        x_e = nf_interface.inverse(z_e)
        self.posterior.clear_cache()
        grad_x_e_U = -self.posterior.gradient(self.beta, x_e)
        # The data-space gradient was taken at x_e = inverse(z_e), so it is pulled back to latent space at z_e as
        # well, the same point at which the log-Jacobian term is evaluated.
        grad_z_e_U = nf_interface.grad_z_logp(z_e, grad_wrt_x=grad_x_e_U) + nf_interface.grad_z_logj(z_e)

        z += psi1(torch.tensor(step_size)) * v - step_size * psi1(
            step_size - alpha * step_size).reshape(z.shape[0], 1) * grad_z_e_U + e_zh
        x = nf_interface.inverse(z)
        v = psi0(torch.tensor(step_size)) * v - step_size * psi0(
            step_size - alpha * step_size).reshape(v.shape[0], 1) * grad_z_e_U + e_vh

        return x, v

    def initialize_optimizer(self,
                             particles: torch.Tensor,
                             nf_interface: TorchFlowInterface,
                             grad_sq: torch.Tensor,
                             avg_grad: torch.Tensor,
                             particle_logp: torch.Tensor,
                             particle_logq: torch.Tensor,
                             logp_func: callable,
                             logq_func: callable,
                             step: float,
                             optimizer: str,
                             optim_scheduler: str,
                             exp_decay_rate: float,
                             cos_T_max: int,
                             cos_lr_min: float):
        """
        Build the particle optimizer and its step size scheduler.

        :param particles: particles handed to the optimizer.
        :param nf_interface: flow interface; only used by the line_search optimizer.
        :param grad_sq: passed to adagrad as grad_sq and to line_search as accumulated_grad.
        :param avg_grad: passed to rmsprop as avg_grad.
        :param particle_logp: log target values of the particles; only used by line_search.
        :param particle_logq: log q values of the particles; only used by line_search.
        :param logp_func: callable evaluating the log target; only used by line_search.
        :param logq_func: callable evaluating log q; only used by line_search.
        :param step: learning rate handed to the optimizer.
        :param optimizer: one of adagrad, rmsprop, adam, grad_descent or line_search.
        :param optim_scheduler: one of identity, exp or cosine.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for the cosine scheduler.

        :return: tuple (particle_optimizer, particle_scheduler).
        :raises ValueError: if optimizer or optim_scheduler is not one of the supported names.
        """
        if optimizer == 'adagrad':
            particle_optimizer = ParticleAdagrad(particles=particles, grad_sq=grad_sq, lr=step)
        elif optimizer == 'rmsprop':
            particle_optimizer = ParticleRMSProp(particles=particles, avg_grad=avg_grad, lr=step)
        elif optimizer == 'adam':
            particle_optimizer = ParticleAdam(particles=particles, lr=step)
        elif optimizer == 'grad_descent':
            particle_optimizer = ParticleGradientDescent(particles=particles, lr=step)
        elif optimizer == 'line_search':
            self.posterior.clear_cache()
            particle_optimizer = ParticleLineSearch(particles=particles,
                                                    accumulated_grad=grad_sq,
                                                    particle_logp=particle_logp,
                                                    particle_logq=particle_logq,
                                                    posterior=self.posterior,
                                                    interface=nf_interface,
                                                    logp_func=logp_func,
                                                    logq_func=logq_func,
                                                    lr=step)
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise ValueError("optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")

        if optim_scheduler == 'identity':
            particle_scheduler = IdentityScheduler(particle_optimizer)
        elif optim_scheduler == 'exp':
            particle_scheduler = ExponentialDecayScheduler(particle_optimizer, exp_decay_rate)
        elif optim_scheduler == 'cosine':
            particle_scheduler = CosineAnnealingScheduler(particle_optimizer, cos_T_max, cos_lr_min)
        else:
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")

        return particle_optimizer, particle_scheduler

    def stochastic_calibration(self,
                               particles: torch.Tensor,
                               nf_interface: TorchFlowInterface,
                               step: float,
                               beta: float,
                               latent: bool,
                               taylor: bool = False):
        """
        Calibrate the factor zeta that multiplies the particle density gradient in the DLMC update.

        With taylor=True zeta is given in closed form from the drawn noise and the two gradients. Otherwise zeta is
        the root of the difference between the mean change of posterior.value under the zeta-modified deterministic
        update and under a stochastic (Langevin-type) update; if the root search fails, zeta falls back to 1.

        :param particles: current particles (latent-space particles when latent is True).
        :param nf_interface: flow interface providing grad_z_logp, grad_z_logj and grad_x_logq.
        :param step: step size of the update; a Python float or a tensor.
        :param beta: inverse temperature passed to the posterior.
        :param latent: whether the particles live in latent space.
        :param taylor: whether to use the linearized (closed form) calibration.

        :return: calibrated zeta; a tensor in the Taylor case, a float otherwise.
        """
        self.posterior.clear_cache()
        # torch.sqrt only accepts tensors, so a float step is converted first (tensors pass through unchanged).
        eta = torch.randn_like(particles) * torch.sqrt(torch.as_tensor(2.0 * step))
        # NOTE(review): particles are passed to the posterior unchanged even when latent is True — confirm that
        # the posterior expects latent-space inputs in that case.
        grad_U = -self.posterior.gradient(beta, particles)
        if latent:
            grad_U = nf_interface.grad_z_logp(particles, grad_wrt_x=grad_U) + nf_interface.grad_z_logj(particles)
            # In latent space q is treated as a standard normal, so the gradient of V is the particles themselves.
            grad_V = deepcopy(particles)
        else:
            grad_V = -nf_interface.grad_x_logq(particles)

        if taylor:
            calib_zeta = torch.sum(torch.linalg.norm(eta * grad_U, dim=1) ** 2) / torch.sum(
                step * torch.linalg.norm(grad_U * grad_V, dim=1) ** 2)
            print(f'Taylor calibration zeta: {calib_zeta}')
            return calib_zeta

        self.posterior.clear_cache()
        current_logp = self.posterior.value(beta, particles)
        stochastic = particles - step * grad_U + eta
        self.posterior.clear_cache()
        delta_U_stochastic = self.posterior.value(beta, stochastic) - current_logp

        def objective(zeta):
            modified_dlmc = particles - step * grad_U + step * zeta * grad_V
            self.posterior.clear_cache()
            delta_U_dlmc = self.posterior.value(beta, modified_dlmc) - current_logp
            return float(torch.mean(delta_U_dlmc) - torch.mean(delta_U_stochastic))

        try:
            # root_scalar cannot select a solver without a bracket or a starting point. Use the secant method,
            # started from standard DLMC (zeta = 1) and the update without the density gradient (zeta = 0).
            sol = root_scalar(
                f=objective,
                method='secant',
                x0=1.0,
                x1=0.0
            )
        except Exception:
            # Best effort: any failure of the search falls back to standard DLMC.
            print('Zeta could not be determined. Setting to 1 (standard DLMC).')
            return 1.0

        if not sol.converged:
            warnings.warn("Zeta search did not converge. Setting to 1 (standard DLMC).")
            return 1.0

        calib_zeta = sol.root
        print(f'Stochastic calibration zeta: {calib_zeta}')
        return calib_zeta

    def run(self,
            x: torch.Tensor,
            main_step: float = 1e-2,
            burnin_step: float = 1e-2,
            burnin_optimizer: str = "adam",
            optimizer: str = "adam",
            optim_scheduler: str = "identity",
            exp_decay_rate: float = 0.999,
            cos_T_max: int = 100,
            cos_lr_min: float = 0.0,
            burnin_optim_steps: int = 1,
            optim_steps: int = 1,
            atol: float = 1e-15,
            max_burnin: int = 20,
            burnin_thresh: float = 0.5,
            post_burnin_mh: bool = True,
            num_upsample: int = 0,
            upsample_schedule: dict = None,
            step_schedule: dict = None,
            max_iterations: int = 100,
            latent: bool = True,
            IMHstep: int = 0,
            CMHstep: int = 0,
            CMHscale: float = 0.1,
            ULMstep: int = 0,
            stochastic_calibration: bool = False,
            taylor_zeta: bool = True,
            post_ulm_steps: int = 100,
            post_ulm_batch: int = 10,
            underdamped_kwargs: dict = None,
            use_tqdm: bool = False,
            train_kwargs: dict = None,
            animate_kwargs: dict = None) -> tuple:
        """
        Run DLA with the particle updates performed in the latent space of the flow.

        For every beta level the particles first go through a burn-in phase that uses ``self.burnin_interface`` and
        then through the main DLMC phase that uses ``self.interface``. In each iteration the flow is trained on the
        particles, the particles are mapped to latent space, updated by the selected particle optimizer and mapped
        back. Optional IMH / CMH / underdamped Langevin steps follow each update, and a block of underdamped Langevin
        steps is applied at the end of every beta level.

        :param x: prior samples with shape (n, d).
        :param main_step: step size for sample updates after burnin.
        :param burnin_step: step size for sample updates during burnin.
        :param burnin_optimizer: particle optimizer to use during burnin.
        :param optimizer: particle optimizer to use for DLMC updates.
        :param optim_scheduler: optimization step size scheduler.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for cosine scheduler.
        :param burnin_optim_steps: number of optimizer steps to take at each burnin iteration.
        :param optim_steps: number of optimizer steps to take at each DLMC iteration.
        :param atol: tolerance parameter when checking for convergence.
        :param max_burnin: maximum number of burnin iterations.
        :param burnin_thresh: in data space, the gradient norm threshold, in latent space, the virial threshold at which
            we end burnin.
        :param post_burnin_mh: if False, IMHstep is set to 0 once burnin has ended.
        :param num_upsample: number of extra samples to draw from q after burnin.
        :param upsample_schedule: optional dict. Its keys are the post-burnin DLMC iteration counts at which new
            samples are drawn from q, its values (taken in order) are the numbers of samples to draw.
        :param step_schedule: optional dict. Its values (taken in order) are the step sizes to use after each
            scheduled upsampling. If None, main_step is used.
        :param max_iterations: maximum number of DLA iterations for each beta level.
        :param latent: whether to do DLMC update in latent space. In this method it is only used to check the
            interface type.
        :param IMHstep: number of IMH steps to apply after DLMC.
        :param CMHstep: number of CMH steps to apply after DLMC.
        :param CMHscale: scale factor for CMH Gaussian proposal.
        :param ULMstep: number of ULM steps to apply after DLMC.
        :param stochastic_calibration: whether to do stochastic calibration of the particle density gradient.
        :param taylor_zeta: whether to use the linearized correction for stochastic calibration.
        :param post_ulm_steps: number of post-burndown underdamped Langevin steps.
        :param post_ulm_batch: number of underdamped Langevin steps to take before re-fitting q.
        :param underdamped_kwargs: keyword arguments for underdamped Langevin.
        :param use_tqdm: use a tqdm progress bar.
        :param train_kwargs: keyword arguments for interface.train_flow.
        :param animate_kwargs: keyword arguments for visualizer.animate.

        :return: tuple ``(x, x_iter, qx_iter, var_iter, mean_iter, mh_iter)``. ``x`` holds the final particles. The
            remaining entries are dicts keyed by ``'init'`` and by the iteration number (as a string): particles after
            the optimizer update, samples drawn from q, per-dimension particle variance, per-dimension particle mean
            and particles after IMH. Iteration numbers restart at every beta level, so later levels overwrite the
            entries of earlier ones.
        :raises ValueError: if an argument fails the checks at the top of the method.
        """
        if main_step <= 0.0 or burnin_step <= 0.0:
            raise ValueError("step sizes must be positive")
        if burnin_optimizer not in ("adagrad", "rmsprop", "adam", "grad_descent", "line_search"):
            raise ValueError("burnin_optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")
        if optimizer not in ("adagrad", "rmsprop",  "adam", "grad_descent", "line_search"):
            raise ValueError("optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")
        if optim_scheduler not in ("identity", "exp", "cosine"):
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")
        if exp_decay_rate <= 0.0 or exp_decay_rate > 1.0:
            raise ValueError("exp_decay rate must be between 0 and 1")
        if cos_T_max <= 0:
            raise ValueError("cos_T_max must be positive")
        if cos_lr_min < 0:
            raise ValueError("cos_lr_min must be greater than or equal to 0")
        if burnin_optim_steps < 1:
            raise ValueError("burnin_optim_steps must be at least 1")
        if optim_steps < 1:
            raise ValueError("optim_steps must be at least 1")
        if burnin_optimizer == "line_search" and burnin_optim_steps != 1:
            raise ValueError("If using line_search optimizer for burnin, burnin_optim_steps must be 1.")
        if optimizer == "line_search" and optim_steps != 1:
            raise ValueError("If using line_search optimizer, optim_steps must be 1.")
        if atol <= 0.0:
            raise ValueError("atol must be positive")
        if max_iterations <= 0:
            raise ValueError("max_iterations must be positive")
        if self.interface.flow is None:
            warnings.warn("Flow is None. Did you forget to create the flow using create_flow?")
        if train_kwargs is None:
            train_kwargs = dict()
        if underdamped_kwargs is None:
            underdamped_kwargs = dict()
        if animate_kwargs is None:
            animate_kwargs = dict()
        if latent and not isinstance(self.interface, TorchFlowInterface):
            raise ValueError("Latent DLA is supported only for TorchFlowInterfaces.")

        x = deepcopy(x)  # Make a copy of the initial particles, these will be changed during DLA
        initial_step = burnin_step

        latent_samples = torch.randn_like(x)  # Only for visualization purposes

        beta_stage = 0
        # NOTE(review): old_beta is never updated in this method, so (self.beta - old_beta) below is always
        # self.beta. Confirm that this is intended.
        old_beta = 0.0

        # Store independent copies of the initial particles. The tensor bound to x can be modified in place later
        # on (the IMH implementations visible in this file assign into it), which would silently change an aliased
        # 'init' entry.
        x_iter = {'init': deepcopy(x)}
        mh_iter = {'init': deepcopy(x)}
        qx_iter = {'init': None}
        mean_iter = {'init': torch.mean(x, dim=0)}
        var_iter = {'init': torch.var(x, dim=0)}

        while self.beta <= 1:

            update_norm = torch.inf  # Just so we define this even if no iterations are executed.
            step = initial_step * torch.ones(len(x)).reshape(-1, 1) # Reset the step size for each beta level.
            dlmc_stage = 0
            num_burnin = 0
            burnin = True
            nf_interface = self.burnin_interface
            upsample_number = 0

            total_iterations = max_burnin + max_iterations

            for iteration in (pbar := (tqdm(range(total_iterations)) if use_tqdm else range(total_iterations))):
                # Clear the cached log_likelihood, log_prior, and gradients
                self.posterior.clear_cache()

                if dlmc_stage > max_iterations:
                    print(f'Reached maximum number of post-burnin DLMC iterations.')
                    break

                # Count the iteration towards the phase (burnin / post-burnin) it belongs to.
                if not burnin:
                    dlmc_stage += 1
                elif burnin:
                    num_burnin += 1
                    if num_burnin == max_burnin:
                        print(f'Reached maximum number of burnin iterations.')
                        burnin = False


                if num_burnin == 1:
                    # Do initial update of the particles along the likelihood gradient
                    sample_sigma = torch.mean(torch.std(x, dim=0))
                    particle_step = step

                    current_logp = self.posterior.value(self.beta, x)
                    current_logq = self.posterior.log_prior(x)
                    logp_func = lambda val: self.posterior.value(self.beta, val)
                    logq_func = lambda val: self.posterior.log_prior(val)
                    particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=x,
                                                                                       nf_interface=nf_interface,
                                                                                       step=particle_step,
                                                                                       particle_logp=current_logp,
                                                                                       particle_logq=current_logq,
                                                                                       grad_sq=torch.zeros_like(x),
                                                                                       avg_grad=torch.zeros_like(x),
                                                                                       logp_func=logp_func,
                                                                                       logq_func=logq_func,
                                                                                       optimizer=burnin_optimizer,
                                                                                       optim_scheduler=optim_scheduler,
                                                                                       exp_decay_rate=exp_decay_rate,
                                                                                       cos_T_max=cos_T_max,
                                                                                       cos_lr_min=cos_lr_min)

                    for i in range(burnin_optim_steps):
                        self.posterior.clear_cache()
                        grad = -(self.beta - old_beta) * self.posterior.log_likelihood_gradient(x)
                        particle_optimizer.step(grad)
                        particle_scheduler.step()

                # Train flow
                nf_interface.train_flow(x, **train_kwargs)

                if burnin:
                    # Burnin ends early once the summed data-space virial is positive and at most
                    # d * burnin_thresh. The latent-space virial is only printed.
                    self.posterior.clear_cache()
                    grad_x_U = -self.posterior.gradient(self.beta, x)

                    z = nf_interface.forward(x)
                    grad_z_U = nf_interface.grad_z_logp(z, grad_wrt_x=grad_x_U) + nf_interface.grad_z_logj(z)
                    z_virial = torch.mean(z * grad_z_U, dim=0)
                    sum_z_virial = torch.sum(z_virial)
                    x_virial = torch.mean(x * grad_x_U, dim=0)
                    sum_x_virial = torch.sum(x_virial)
                    print(f'zVirial = {sum_z_virial}')
                    print(f'Virial = {sum_x_virial}')
                    if sum_x_virial <= x.shape[1] * burnin_thresh and sum_x_virial > 0:
                        print(f'Reached virial threshold during burnin.')
                        burnin = False
                        num_burnin = max_burnin

                if num_burnin == max_burnin and burnin is False:
                    # Transition from burnin to the main DLMC phase. Setting num_burnin to inf makes sure this
                    # branch is entered only once per beta level.
                    nf_interface = self.interface
                    num_burnin = torch.inf
                    if not post_burnin_mh:
                        # NOTE(review): this overwrites the argument, so IMH also stays disabled during the
                        # burnin of all following beta levels. Confirm that this is intended.
                        IMHstep = 0
                    if num_upsample > 0:
                        print(f'Drawing {num_upsample} new particles from current q for post-burnin DLMC.')
                        new_x = nf_interface.sample(num_upsample)
                        x = torch.cat([x, new_x], dim=0)
                        accumulated_grad = torch.zeros_like(x)
                        sum_grad = torch.zeros_like(x)
                    # Switch to the main step size regardless of whether new particles were drawn. Doing this only
                    # when upsampling would leave main_step without effect for num_upsample == 0.
                    step = main_step * torch.ones(len(x)).reshape(-1, 1)
                    x, accept = self.IMH(x, nf_interface=nf_interface)
                    nf_interface.train_flow(x, **train_kwargs)

                if upsample_schedule is not None:
                    if dlmc_stage in upsample_schedule:
                        # The schedule values are consumed in order, one per scheduled upsampling.
                        num_upsample = list(upsample_schedule.values())[upsample_number]
                        print(f'Drawing {num_upsample} new samples from q.')
                        new_x = nf_interface.sample(num_upsample)
                        x = torch.cat([x, new_x], dim=0)
                        accumulated_grad = torch.zeros_like(x)
                        sum_grad = torch.zeros_like(x)
                        if step_schedule is None:
                            step = main_step * torch.ones(len(x)).reshape(-1, 1)
                        else:
                            new_step = list(step_schedule.values())[upsample_number]
                            step = new_step * torch.ones(len(x)).reshape(-1, 1)
                        upsample_number += 1
                        x, accept = self.IMH(x, nf_interface=nf_interface)
                        nf_interface.train_flow(x, **train_kwargs)

                # Update samples
                self.posterior.clear_cache()
                grad_x_U = -self.posterior.gradient(self.beta, x)
                x_virial = torch.mean(x * grad_x_U, dim=0)
                sum_x_virial = torch.sum(x_virial)
                print(f'<sum_x_virial> = {sum_x_virial}')

                z = nf_interface.forward(x)
                grad_z_U = nf_interface.grad_z_logp(z, grad_wrt_x=grad_x_U) + nf_interface.grad_z_logj(z)
                z_virial = torch.mean(z * grad_z_U, dim=0)
                sum_z_virial = torch.sum(z_virial)
                print(f'<sum_z_virial> = {sum_z_virial}')

                x_before_update = deepcopy(x)

                z = nf_interface.forward(x)
                particle_step = step

                self.posterior.clear_cache()
                current_logp = self.posterior.value(self.beta, x)
                current_logq = nf_interface.logq(x)

                # The optimizer works on latent particles, so both callbacks map their argument back to data
                # space before evaluating.
                def logp_func(val):
                    self.posterior.clear_cache()
                    x_val = nf_interface.inverse(val)
                    return self.posterior.value(self.beta, x_val)
                def logq_func(val):
                    x_val = nf_interface.inverse(val)
                    return nf_interface.logq(x_val)

                if burnin:
                    optim_method = burnin_optimizer
                    optim_updates = burnin_optim_steps
                elif not burnin:
                    optim_method = optimizer
                    optim_updates = optim_steps

                if iteration == 0:
                    accumulated_grad = torch.zeros_like(x)
                    sum_grad = torch.zeros_like(x)
                grad_sq = accumulated_grad

                particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=z,
                                                                                   nf_interface=nf_interface,
                                                                                   step=particle_step,
                                                                                   particle_logp=current_logp,
                                                                                   particle_logq=current_logq,
                                                                                   grad_sq=grad_sq,
                                                                                   avg_grad=sum_grad/(iteration + 1),
                                                                                   logp_func=logp_func,
                                                                                   logq_func=logq_func,
                                                                                   optimizer=optim_method,
                                                                                   optim_scheduler=optim_scheduler,
                                                                                   exp_decay_rate=exp_decay_rate,
                                                                                   cos_T_max=cos_T_max,
                                                                                   cos_lr_min=cos_lr_min)

                for i in range(optim_updates):

                    if stochastic_calibration:
                        calib_zeta = self.stochastic_calibration(particles=z,
                                                                 nf_interface=nf_interface,
                                                                 step=particle_step,
                                                                 beta=self.beta,
                                                                 latent=True,
                                                                 taylor=taylor_zeta)
                    else:
                        calib_zeta = 1.0

                    self.posterior.clear_cache()
                    grad_x_U = -self.posterior.gradient(self.beta, x)
                    grad_z_U = nf_interface.grad_z_logp(z, grad_wrt_x=grad_x_U) + nf_interface.grad_z_logj(z)
                    grad = grad_z_U - calib_zeta * z
                    particle_optimizer.step(grad)
                    particle_scheduler.step()
                    # Map the updated latent particles back to data space.
                    x = nf_interface.inverse(z)

                    # Add in Adagrad step decay if needed.
                    if optim_method == 'line_search':
                        ls_tracker = particle_optimizer.step_counter
                        avg_steps = 0.5 * ((ls_tracker - 1.05) + torch.mean(ls_tracker - 1.05, dim=0)).reshape(-1, 1)
                        step = step / 10 ** avg_steps

                #if optim_method == 'adagrad':
                #    accumulated_grad = particle_optimizer.grad_sq
                #else:
                accumulated_grad = torch.zeros_like(x)
                sum_grad += grad ** 2

                update_values = x - x_before_update
                update_norm = torch.max(torch.abs(update_values))  # Using max norm as a worst case precaution

                x_iter[f'{iteration}'] = deepcopy(x)

                if IMHstep:
                    acc_rate_IMH = []
                    replaced_IMH = torch.zeros(len(x), dtype=torch.bool, device=x.device)

                    for _ in range(IMHstep):
                        x, accept = self.IMH(x, nf_interface=nf_interface)
                        replaced_IMH = replaced_IMH + accept
                        acc_rate_IMH.append(torch.sum(accept).item() / len(x))
                        total_acc_rate_IMH = torch.sum(replaced_IMH).item() / len(x)

                    mh_iter[f'{iteration}'] = deepcopy(x)

                if CMHstep:
                    x, acc_rate_CMH, total_acc_rate_CMH = self.CMH(x, CMHscale, CMHstep, nf_interface=nf_interface)

                if ULMstep:
                    v = torch.randn_like(x)
                    for _ in range(ULMstep):
                        x, v = self.underdamped_langevin(x=x, v=v, nf_interface=nf_interface, **underdamped_kwargs)

                self.log_iteration_data(x, latent_samples)

                if use_tqdm:
                    pbar.set_description(f'[Stage {beta_stage}] Beta: {self.beta}')
                    if IMHstep:
                        pbar_IMH = float(total_acc_rate_IMH)
                    else:
                        pbar_IMH = None
                    pbar.set_postfix(accept_IMH=pbar_IMH)

                qx_iter[f'{iteration}'] = nf_interface.sample(x.shape[0])
                mean_iter[f'{iteration}'] = torch.mean(x, dim=0)
                var_iter[f'{iteration}'] = torch.var(x, dim=0)

                # Check convergence
                if update_norm < atol:
                    warnings.warn('Particle updates are too small. Stopping.')
                    break

            # Burndown underdamped Langevin
            for iteration in (pbar := (tqdm(range(post_ulm_steps)) if use_tqdm else range(post_ulm_steps))):
                # Re-fit q and draw fresh velocities at the start of every batch of post_ulm_batch steps.
                if iteration % post_ulm_batch == 0:
                    nf_interface.train_flow(x, **train_kwargs)
                    v = torch.randn_like(x)
                x, v = self.underdamped_langevin(x=x, v=v, nf_interface=nf_interface, **underdamped_kwargs)

            # Make the animation
            self.create_beta_stage_animation(beta_stage, **animate_kwargs)

            # Compute new beta
            self.set_new_beta(x=x)
            if self.finished:
                break
            # update_norm is still the Python float torch.inf if no iteration assigned it; torch.isinf only accepts
            # tensors, so convert first.
            if torch.isinf(torch.as_tensor(update_norm)):
                warnings.warn("Update norm is infinite. Stopping.")
                break

            beta_stage += 1
            # The debugger is optional and defaults to None.
            if self.debugger is not None:
                self.debugger.stage_step()
        return x, x_iter, qx_iter, var_iter, mean_iter, mh_iter


class RegularDLA(DLA):
    """
    DLA variant whose particle updates are performed directly in data space.

    A separate flow interface (``burnin_interface``) is used during the burn-in phase of every beta level, while
    ``interface`` is used afterwards.
    """

    def __init__(self,
                 interface: TorchFlowInterface,
                 burnin_interface: TorchFlowInterface,
                 posterior: DifferentiableTemperedPosterior,
                 beta_handler: Optional[BetaHandler] = None,
                 debugger: Optional[MultiStageDebugger] = None):
        """
        :param interface: flow interface used after burnin.
        :param burnin_interface: flow interface used during burnin.
        :param posterior: object used to evaluate the tempered posterior and its gradients.
        :param beta_handler: object handling the beta levels. A new SingleStageBetaHandler is created if None.
        :param debugger: optional debugger.
        """
        # A default instance written in the signature is created once and shared by every object that is
        # constructed without an explicit handler, so the default is created here instead.
        if beta_handler is None:
            beta_handler = SingleStageBetaHandler()
        super().__init__(interface, posterior, beta_handler, debugger)
        self.burnin_interface = burnin_interface

    def IMH(self, x, nf_interface):
        """
        Apply one independence Metropolis-Hastings step with the flow as the proposal distribution.

        :param x: current particles. Accepted proposals are written into this tensor in place.
        :param nf_interface: flow interface used to draw proposals and to evaluate log q.
        :return: tuple (x, accept), where accept is a boolean mask of the replaced particles.
        """
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)
        logq_old = nf_interface.logq(x)
        sample, logq_new = nf_interface.sample_with_logq(len(x))
        self.posterior.clear_cache()
        logp_new = self.posterior.value(self.beta, sample)

        # Log acceptance ratio of the independence sampler.
        logr = logp_new + logq_old - logp_old - logq_new

        # Draw the uniform numbers on the device of the particles so that the comparison also works for
        # particles that do not live on the default device.
        accept = torch.log(torch.rand(len(x), device=x.device)) < logr
        x[accept] = sample[accept]
        return x, accept

    def initialize_optimizer(self,
                             particles: torch.Tensor,
                             nf_interface: TorchFlowInterface,
                             grad_sq: torch.Tensor,
                             avg_grad: torch.Tensor,
                             particle_logp: torch.Tensor,
                             particle_logq: torch.Tensor,
                             logp_func: callable,
                             logq_func: callable,
                             step: float,
                             optimizer: str,
                             optim_scheduler: str,
                             exp_decay_rate: float,
                             cos_T_max: int,
                             cos_lr_min: float):
        """
        Create the particle optimizer and its step size scheduler.

        :param particles: particles handed to the optimizer.
        :param nf_interface: flow interface, used by the line search optimizer.
        :param grad_sq: accumulated squared gradients, used by adagrad and line search.
        :param avg_grad: average gradient statistic, used by rmsprop.
        :param particle_logp: log p of the particles, used by line search.
        :param particle_logq: log q of the particles, used by line search.
        :param logp_func: function evaluating log p, used by line search.
        :param logq_func: function evaluating log q, used by line search.
        :param step: step size (learning rate).
        :param optimizer: one of adagrad, rmsprop, adam, grad_descent or line_search.
        :param optim_scheduler: one of identity, exp or cosine.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for the cosine scheduler.
        :return: tuple (particle_optimizer, particle_scheduler).
        :raises ValueError: if optimizer or optim_scheduler is not one of the supported names.
        """

        if optimizer == 'adagrad':
            particle_optimizer = ParticleAdagrad(particles=particles, grad_sq=grad_sq, lr=step)
        elif optimizer == 'rmsprop':
            particle_optimizer = ParticleRMSProp(particles=particles, avg_grad=avg_grad, lr=step)
        elif optimizer == 'adam':
            particle_optimizer = ParticleAdam(particles=particles, lr=step)
        elif optimizer == 'grad_descent':
            particle_optimizer = ParticleGradientDescent(particles=particles, lr=step)
        elif optimizer == 'line_search':
            self.posterior.clear_cache()
            particle_optimizer = ParticleLineSearch(particles=particles,
                                                    accumulated_grad=grad_sq,
                                                    particle_logp=particle_logp,
                                                    particle_logq=particle_logq,
                                                    posterior=self.posterior,
                                                    interface=nf_interface,
                                                    logp_func=logp_func,
                                                    logq_func=logq_func,
                                                    lr=step)
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise ValueError("optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")

        if optim_scheduler == 'identity':
            particle_scheduler = IdentityScheduler(particle_optimizer)
        elif optim_scheduler == 'exp':
            particle_scheduler = ExponentialDecayScheduler(particle_optimizer, exp_decay_rate)
        elif optim_scheduler == 'cosine':
            particle_scheduler = CosineAnnealingScheduler(particle_optimizer, cos_T_max, cos_lr_min)
        else:
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")

        return particle_optimizer, particle_scheduler

    def run(self,
            x: torch.Tensor,
            main_step: float = 1e-2,
            burnin_step: float = 1e-2,
            burnin_optimizer: str = "adam",
            optimizer: str = "adam",
            optim_scheduler: str = "identity",
            exp_decay_rate: float = 0.999,
            cos_T_max: int = 100,
            cos_lr_min: float = 0.0,
            burnin_optim_steps: int = 1,
            optim_steps: int = 1,
            atol: float = 1e-15,
            target_ess_fraction: float = 0.5,
            max_burnin: int = 20,
            burnin_thresh: float = 0.5,
            num_upsample: int = 0,
            max_iterations: int = 100,
            IMHstep: int = 0,
            use_tqdm: bool = False,
            train_kwargs: dict = None,
            animate_kwargs: dict = None) -> tuple:
        """
        Run DLA with particle updates in data space.

        :param x: prior samples with shape (n, d).
        :param main_step: step size for sample updates after burnin.
        :param burnin_step: step size for sample updates during burnin.
        :param burnin_optimizer: particle optimizer to use during burnin.
        :param optimizer: particle optimizer to use after burnin.
        :param optim_scheduler: optimization step size scheduler.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for cosine scheduler.
        :param burnin_optim_steps: number of optimizer steps to take at each burnin iteration.
        :param optim_steps: number of optimizer steps to take at each post-burnin iteration.
        :param atol: tolerance parameter when checking for convergence.
        :param target_ess_fraction: ESS target when determining new beta levels. In this method it is only
            validated.
        :param max_burnin: maximum number of burnin iterations.
        :param burnin_thresh: burnin ends once the summed virial is positive and at most d * burnin_thresh.
        :param num_upsample: if positive, the particles are replaced by this many samples from q after burnin.
        :param max_iterations: maximum number of DLA iterations for each beta level.
        :param IMHstep: number of IMH steps to apply after each update during burnin.
        :param use_tqdm: use a tqdm progress bar.
        :param train_kwargs: keyword arguments for interface.train_flow.
        :param animate_kwargs: keyword arguments for visualizer.animate.
        :return: tuple ``(x, x_iter, mh_iter)``: the final particles, a dict with copies of the particles after
            each optimizer update and a dict with copies of the particles after the IMH stage of each iteration.
            The dicts are keyed by ``'init'`` and by the iteration number (as a string); iteration numbers restart
            at every beta level.
        :raises ValueError: if an argument fails the checks at the top of the method.
        """
        if not 0.0 <= target_ess_fraction <= 1.0:
            raise ValueError("target_ess_fraction must be in the [0, 1] interval")
        if main_step <= 0.0 or burnin_step <= 0.0:
            raise ValueError("main_step and burnin_step must be positive")
        # The following values would otherwise only fail later, after the flow has already been trained
        # (unknown names in initialize_optimizer, missing gradient for an empty optimizer loop).
        if burnin_optimizer not in ("adagrad", "rmsprop", "adam", "grad_descent", "line_search"):
            raise ValueError("burnin_optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")
        if optimizer not in ("adagrad", "rmsprop", "adam", "grad_descent", "line_search"):
            raise ValueError("optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")
        if optim_scheduler not in ("identity", "exp", "cosine"):
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")
        if burnin_optim_steps < 1:
            raise ValueError("burnin_optim_steps must be at least 1")
        if optim_steps < 1:
            raise ValueError("optim_steps must be at least 1")
        if atol <= 0.0:
            raise ValueError("atol must be positive")
        if max_iterations <= 0:
            raise ValueError("max_iterations must be positive")
        if self.interface.flow is None:
            warnings.warn("Flow is None. Did you forget to create the flow using create_flow?")
        if train_kwargs is None:
            train_kwargs = dict()
        if animate_kwargs is None:
            animate_kwargs = dict()

        x = deepcopy(x)  # Make a copy of the initial particles, these will be changed during DLA

        # Not used below; the draw is kept so that the random number stream stays the same.
        latent_samples = torch.randn_like(x)  # Only for visualization purposes
        x_iter = {'init': deepcopy(x)}
        mh_iter = {'init': deepcopy(x)}

        beta_stage = 0
        while self.beta <= 1:
            update_norm = torch.inf  # Just so we define this even if no iterations are executed.
            dlmc_stage = 0
            num_burnin = 0
            step = burnin_step
            burnin = True
            nf_interface = self.burnin_interface
            # IMH only runs during burnin, so no acceptance rate may be available for the progress bar.
            total_acc_rate_IMH = None

            total_iterations = max_burnin + max_iterations
            for iteration in (pbar := (tqdm(range(total_iterations)) if use_tqdm else range(total_iterations))):
                # Clear the cached log_likelihood, log_prior, and gradients
                self.posterior.clear_cache()

                if dlmc_stage > max_iterations:
                    print(f'Reached maximum number of post-burnin DLMC iterations.')
                    break

                # Count the iteration towards the phase (burnin / post-burnin) it belongs to.
                if not burnin:
                    dlmc_stage += 1
                else:
                    num_burnin += 1
                    if num_burnin == max_burnin:
                        print(f'Reached maximum number of burnin iterations.')
                        burnin = False

                if burnin:
                    self.posterior.clear_cache()
                    grad_x_U = -self.posterior.gradient(self.beta, x)
                    x_virial = torch.mean(x * grad_x_U, dim=0)
                    sum_x_virial = torch.sum(x_virial)
                    print(f'Virial = {sum_x_virial}')
                    if sum_x_virial <= x.shape[1] * burnin_thresh and sum_x_virial > 0:
                        print(f'Reached virial threshold during burnin.')
                        burnin = False
                        num_burnin = max_burnin

                if num_burnin == max_burnin and burnin is False:
                    # Transition from burnin to the main phase. Setting num_burnin to inf makes sure this branch
                    # is entered only once per beta level.
                    nf_interface = self.interface
                    num_burnin = torch.inf
                    if num_upsample > 0:
                        print(f'Drawing {num_upsample} new particles from current q for post-burnin DLMC.')
                        # NOTE(review): this replaces the particles instead of appending to them - confirm that
                        # this is intended.
                        x = nf_interface.sample(num_upsample)
                        accumulated_grad = torch.zeros_like(x)
                        sum_grad = torch.zeros_like(x)
                    step = main_step
                    nf_interface.train_flow(x, **train_kwargs)

                # Train flow
                nf_interface.train_flow(x, **train_kwargs)

                # Update samples
                x_before_update = deepcopy(x)

                current_logp = self.posterior.value(self.beta, x)
                current_logq = nf_interface.logq(x)
                def logp_func(val):
                    self.posterior.clear_cache()
                    return self.posterior.value(self.beta, val)
                logq_func = lambda val: nf_interface.logq(val)

                if burnin:
                    optim_method = burnin_optimizer
                    optim_updates = burnin_optim_steps
                else:
                    optim_method = optimizer
                    optim_updates = optim_steps

                if iteration == 0:
                    accumulated_grad = torch.zeros_like(x)
                    sum_grad = torch.zeros_like(x)

                particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=x,
                                                                                   nf_interface=nf_interface,
                                                                                   step=step,
                                                                                   particle_logp=current_logp,
                                                                                   particle_logq=current_logq,
                                                                                   grad_sq=accumulated_grad,
                                                                                   avg_grad=sum_grad / (iteration + 1),
                                                                                   logp_func=logp_func,
                                                                                   logq_func=logq_func,
                                                                                   optimizer=optim_method,
                                                                                   optim_scheduler=optim_scheduler,
                                                                                   exp_decay_rate=exp_decay_rate,
                                                                                   cos_T_max=cos_T_max,
                                                                                   cos_lr_min=cos_lr_min)

                for i in range(optim_updates):
                    self.posterior.clear_cache()
                    grad_x_U = -self.posterior.gradient(self.beta, x)
                    grad_x_V = -nf_interface.grad_x_logq(x)
                    grad = grad_x_U - grad_x_V
                    particle_optimizer.step(grad)
                    particle_scheduler.step()

                    if optim_method == 'adagrad':
                        accumulated_grad = particle_optimizer.grad_sq
                    else:
                        accumulated_grad = torch.zeros_like(x)
                sum_grad += grad ** 2

                # Snapshot of the particles after the optimizer update and before IMH. It is a copy and is not
                # overwritten later: storing x itself would make the entry follow all subsequent in-place
                # modifications of x.
                x_iter[f'{iteration}'] = deepcopy(x)

                if torch.isnan(x).any():
                    warnings.warn('A particle is contains nan. Stopping.')
                    break

                if torch.isinf(x).any():
                    warnings.warn('A particle is contains inf. Stopping.')
                    break

                update_values = x - x_before_update
                update_norm = torch.max(torch.abs(update_values))  # Using max norm as a worst case precaution

                if IMHstep and burnin:
                    acc_rate_IMH = []
                    replaced_IMH = torch.zeros(len(x), dtype=torch.bool, device=x.device)
                    for _ in range(IMHstep):
                        x, accept = self.IMH(x, nf_interface=nf_interface)
                        replaced_IMH = replaced_IMH + accept
                        acc_rate_IMH.append(torch.sum(accept).item() / len(x))
                        total_acc_rate_IMH = torch.sum(replaced_IMH).item() / len(x)

                mh_iter[f'{iteration}'] = deepcopy(x)

                if use_tqdm:
                    pbar.set_description(f'[Stage {beta_stage}] Beta: {self.beta}')
                    # Only report an acceptance rate once IMH has actually been run in this beta level.
                    if IMHstep and total_acc_rate_IMH is not None:
                        pbar.set_postfix(IMH=total_acc_rate_IMH)

                # Check convergence
                if update_norm < atol:
                    warnings.warn('Particle updates are too small. Stopping.')
                    break
                if np.isnan(update_norm):
                    warnings.warn('Particle update is nan. Stopping.')
                    break

            # Make the animation
            self.create_beta_stage_animation(beta_stage, **animate_kwargs)

            # Compute new beta
            self.set_new_beta(x=x)
            if self.finished:
                break
            if np.isinf(update_norm):
                warnings.warn("Update norm is infinite. Stopping.")
                break

            beta_stage += 1
            # The debugger is optional and defaults to None.
            if self.debugger is not None:
                self.debugger.stage_step()
        return x, x_iter, mh_iter

class NFMH(DLA):
    """
    Sampler that moves particles with independent Metropolis-Hastings (IMH) proposals drawn from a normalizing flow.

    Two flow interfaces are used: ``burnin_interface`` supplies proposals during the burn-in iterations of each beta
    stage and ``interface`` takes over afterwards.
    """

    def __init__(self,
                 interface: TorchFlowInterface,
                 burnin_interface: TorchFlowInterface,
                 posterior: DifferentiableTemperedPosterior,
                 beta_handler: BetaHandler = SingleStageBetaHandler(),
                 debugger: Optional[MultiStageDebugger] = None):
        """
        :param interface: flow interface used for proposals after burn-in.
        :param burnin_interface: flow interface used for proposals during burn-in.
        :param posterior: object whose ``value(beta, x)`` is used as the target log density in the IMH test.
        :param beta_handler: forwarded to the base class.
        :param debugger: optional debugger, forwarded to the base class.
        """
        super().__init__(interface, posterior, beta_handler, debugger)
        self.burnin_interface = burnin_interface

    def IMH(self, x, nf_interface):
        """
        Apply one independent Metropolis-Hastings step with proposals drawn from ``nf_interface``.

        :param x: current particles. Accepted proposals are written into this tensor in place.
        :param nf_interface: object providing ``logq(x)`` and ``sample_with_logq(n)``.
        :return: tuple ``(x, accept)`` where ``accept`` is a boolean mask of the replaced particles.
        """
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)
        logq_old = nf_interface.logq(x)
        sample, logq_new = nf_interface.sample_with_logq(len(x))
        self.posterior.clear_cache()
        logp_new = self.posterior.value(self.beta, sample)

        # Log acceptance ratio of an independence sampler: target ratio times the reversed proposal ratio.
        logr = logp_new + logq_old - logp_old - logq_new

        # Draw the uniforms on the particles' device so that the comparison and the mask bookkeeping in run()
        # (which allocates on x.device) do not mix devices.
        accept = torch.log(torch.rand(len(x), device=x.device)) < logr
        x[accept] = sample[accept]
        return x, accept

    def run(self,
            x: torch.Tensor,
            target_ess_fraction: float = 0.5,
            max_burnin: int = 20,
            num_upsample: int = 0,
            max_iterations: int = 100,
            IMHstep: int = 1,
            use_tqdm: bool = False,
            train_kwargs: dict = None,
            animate_kwargs: dict = None) -> torch.Tensor:
        """
        Run NFMH.

        :param x: prior samples with shape (n, d).
        :param target_ess_fraction: ESS target when determining new beta levels. It is only validated here.
        :param max_burnin: number of burn-in iterations per beta level (proposals come from burnin_interface).
        :param num_upsample: if positive, number of particles drawn from the post-burnin flow to replace x once
            burn-in ends.
        :param max_iterations: maximum number of post-burnin iterations for each beta level.
        :param IMHstep: number of IMH steps applied per iteration (at least 1).
        :param use_tqdm: use a tqdm progress bar.
        :param train_kwargs: keyword arguments for interface.train_flow.
        :param animate_kwargs: keyword arguments for create_beta_stage_animation.
        :return: torch.Tensor with posterior samples.
        :raises ValueError: if IMHstep, target_ess_fraction or max_iterations are out of range.
        """
        if IMHstep < 1:
            raise ValueError("IMHstep must be at least 1")
        if not 0.0 <= target_ess_fraction <= 1.0:
            raise ValueError("target_ess_fraction must be in the [0, 1] interval")
        if max_iterations <= 0:
            raise ValueError("max_iterations must be positive")
        if self.interface.flow is None:
            warnings.warn("Flow is None. Did you forget to create the flow using create_flow?")
        if train_kwargs is None:
            train_kwargs = dict()
        if animate_kwargs is None:
            animate_kwargs = dict()

        x = deepcopy(x)  # Make a copy of the initial particles, these will be changed during the run

        beta_stage = 0
        while self.beta <= 1:
            # Sentinel: stays infinite only if no IMH update completed in this stage (or the particles became
            # nan/inf), in which case the annealing loop is stopped below.
            update_norm = torch.inf
            dlmc_stage = 0
            num_burnin = 0
            burnin = True
            nf_interface = self.burnin_interface

            total_iterations = max_burnin + max_iterations
            for iteration in (pbar := (tqdm(range(total_iterations)) if use_tqdm else range(total_iterations))):
                # Clear the cached log_likelihood, log_prior, and gradients
                self.posterior.clear_cache()

                if dlmc_stage > max_iterations:
                    print(f'Reached maximum number of post-burnin DLMC iterations.')
                    break

                if burnin:
                    num_burnin += 1
                    if num_burnin == max_burnin:
                        print(f'Reached maximum number of burnin iterations.')
                        burnin = False
                else:
                    dlmc_stage += 1

                if num_burnin == max_burnin and burnin is False:
                    # One-off switch from the burn-in flow to the main flow. Setting num_burnin to inf makes sure
                    # this branch is not entered again within the stage.
                    nf_interface = self.interface
                    num_burnin = torch.inf
                    if num_upsample > 0:
                        print(f'Drawing {num_upsample} new particles from current q for post-burnin DLMC.')
                        x = nf_interface.sample(num_upsample)
                    # NOTE(review): together with the unconditional call below, the flow is trained twice on the
                    # switch-over iteration - confirm this is intended.
                    nf_interface.train_flow(x, **train_kwargs)

                # Train flow
                nf_interface.train_flow(x, **train_kwargs)

                if torch.isnan(x).any():
                    warnings.warn('A particle is contains nan. Stopping.')
                    update_norm = torch.inf  # Invalid particles: also stop the annealing loop.
                    break

                if torch.isinf(x).any():
                    warnings.warn('A particle is contains inf. Stopping.')
                    update_norm = torch.inf  # Invalid particles: also stop the annealing loop.
                    break

                # Snapshot used to measure how far the IMH steps moved the particles (IMH edits x in place).
                x_before_update = x.detach().clone()

                replaced_IMH = torch.zeros(len(x), dtype=torch.bool, device=x.device)
                for _ in range(IMHstep):
                    x, accept = self.IMH(x, nf_interface=nf_interface)
                    replaced_IMH = replaced_IMH | accept
                total_acc_rate_IMH = torch.sum(replaced_IMH).item() / len(x)

                # Max norm of the particle displacement; a plain float so that np.isinf below works on any device.
                update_norm = float(torch.max(torch.abs(x.detach() - x_before_update)))

                if use_tqdm:
                    pbar.set_description(f'[Stage {beta_stage}] Beta: {self.beta}')
                    pbar.set_postfix(IMH=total_acc_rate_IMH)

            # Make the animation
            self.create_beta_stage_animation(beta_stage, **animate_kwargs)

            # Compute new beta
            self.set_new_beta(x=x)
            if self.finished:
                break
            if np.isinf(update_norm):
                warnings.warn("Update norm is infinite. Stopping.")
                break

            beta_stage += 1
            if self.debugger is not None:  # debugger is Optional
                self.debugger.stage_step()
        return x

class PPDLMC(DLA):
    def __init__(self,
                 interface: TorchFlowInterface,
                 posterior: DifferentiableTemperedPosterior,
                 beta_handler: BetaHandler = SingleStageBetaHandler(),
                 debugger: Optional[MultiStageDebugger] = None):
        """
        Set up PPDLMC by forwarding all arguments to the base class.

        Note that ``interface`` here is for latent space transformations only. The PPInterface is assigned
        internally, mainly because it needs to be re-defined for every sample set, unlike the standard flow
        interfaces.
        """
        super().__init__(interface=interface,
                         posterior=posterior,
                         beta_handler=beta_handler,
                         debugger=debugger)

    def IMH(self, x, pp_sigma):
        """
        Apply one independent Metropolis-Hastings step with proposals drawn around the current particles.

        A proposal is generated for every particle by picking a random particle of ``x`` and adding Gaussian noise
        scaled by ``pp_sigma``. The proposal log density is evaluated with a PPInterface built from ``x`` and
        ``pp_sigma``.

        :param x: current particles. Accepted proposals are written into this tensor in place.
        :param pp_sigma: noise scale of the proposal, also passed to PPInterface.
        :return: tuple ``(x, accept)`` where ``accept`` is a boolean mask of the replaced particles.
        """
        # IMH replaces the particles. How to implement this with ParticleAdam?
        self.posterior.clear_cache()
        pp_interface = PPInterface(x, pp_sigma)
        logp_old = self.posterior.value(self.beta, x)
        logq_old = pp_interface.logq(x)

        num_particles = x.shape[0]
        kde_idx = torch.randint(low=0, high=num_particles, size=(num_particles,))
        sample = x[kde_idx] + pp_sigma * torch.randn_like(x)
        logq_new = pp_interface.logq(sample)
        self.posterior.clear_cache()
        logp_new = self.posterior.value(self.beta, sample)

        # Log acceptance ratio of an independence sampler: target ratio times the reversed proposal ratio.
        logr = logp_new + logq_old - logp_old - logq_new

        # Draw the uniforms on the particles' device; callers accumulate the mask into a tensor allocated on
        # x.device, so a default-device draw would mix devices for non-CPU particles.
        accept = torch.log(torch.rand(len(x), device=x.device)) < logr
        x[accept] = sample[accept]
        self.posterior.clear_cache()
        return x, accept

    def CMH(self, x, scale, step):
        """
        Apply ``step`` correlated Metropolis-Hastings updates using Gaussian random-walk proposals in the space
        reached through ``self.interface.forward_with_logj``.

        :param x: current particles. Accepted proposals are written into this tensor in place.
        :param scale: multiplier of the standard normal noise added to the transformed particles.
        :param step: number of proposal/accept rounds.
        :return: tuple ``(x, acceptance_rate, replaced_fraction)``: the particles, a list with the acceptance
            fraction of every round, and the fraction of particles replaced at least once.
        """
        # CMH replaces the particles. How to implement this with ParticleAdam?
        self.posterior.clear_cache()
        logp_old = self.posterior.value(self.beta, x)

        z, logj_forward_old = self.interface.forward_with_logj(x)
        z = z.reshape(len(z), -1)

        num_particles = len(x)
        acceptance_rate = []
        replaced = torch.zeros(num_particles, dtype=torch.bool, device=x.device)
        for _ in range(step):
            z_new = z + torch.randn_like(z) * scale
            x_new, logj_backward_new = self.interface.inverse_with_logj(z_new)

            self.posterior.clear_cache()
            logp_new = self.posterior.value(self.beta, x_new)

            # Both sides are logp plus the log-Jacobian of the inverse map; for the current state the latter is
            # stored as the negated forward log-Jacobian.
            logr = (logp_new + logj_backward_new) - (logp_old - logj_forward_old)

            # Uniforms are drawn on x.device so they can be combined with ``replaced`` (allocated on x.device).
            accept = torch.log(torch.rand(num_particles, device=x.device)) < logr
            acceptance_rate.append(torch.sum(accept).item() / num_particles)
            replaced = replaced | accept

            # Carry the accepted proposals over as the new current state for the next round.
            x[accept] = x_new[accept]
            logp_old[accept] = logp_new[accept]
            z[accept] = z_new[accept]
            logj_forward_old[accept] = -logj_backward_new[accept]

        self.posterior.clear_cache()
        return x, acceptance_rate, torch.sum(replaced).item() / num_particles

    def stochastic_calibration(self,
                               particles: torch.Tensor,
                               pp_interface: FlowInterface,
                               step: float,
                               beta: float,
                               latent: bool,
                               lower_zeta: float = -100.0,
                               upper_zeta: float = 100.0,
                               taylor: bool = False):
        """
        Determine the multiplier ``zeta`` of the density-gradient term in the deterministic particle update.

        Without ``taylor``, zeta is the root (searched within [lower_zeta, upper_zeta]) of the difference between
        the mean change of ``posterior.value`` under the update ``particles - step * grad_U + step * zeta * grad_V``
        and under the noisy update ``particles - step * grad_U + eta``, with ``eta`` normal noise of scale
        ``sqrt(2 * step)``. If the search fails or does not converge, 1.0 (standard DLMC) is returned.

        With ``taylor``, zeta is computed in closed form from ``eta``, ``grad_U`` and ``grad_V``.

        :param particles: particles at which the gradients are evaluated.
        :param pp_interface: object providing ``grad_x_logq``.
        :param step: step size of the particle update (Python float or tensor).
        :param beta: beta value passed to the posterior.
        :param latent: if True, grad_U is mapped through ``interface.grad_z_logp`` and ``interface.grad_z_logj``.
        :param lower_zeta: lower end of the bracket for the root search.
        :param upper_zeta: upper end of the bracket for the root search.
        :param taylor: use the closed-form expression instead of the root search.
        :return: the calibrated zeta (a float from the root search, a tensor from the Taylor expression).
        """
        grad_V = -pp_interface.grad_x_logq(particles)

        self.posterior.clear_cache()
        # ``** 0.5`` works for both Python floats and tensors, whereas torch.sqrt only accepts tensors.
        eta = torch.randn_like(particles) * (2.0 * step) ** 0.5
        # NOTE(review): with latent=True the posterior is evaluated directly at ``particles`` here and below -
        # confirm that these are the coordinates the posterior expects.
        grad_U = -self.posterior.gradient(beta, particles)
        if latent:
            grad_U = self.interface.grad_z_logp(particles, grad_wrt_x=grad_U) + self.interface.grad_z_logj(particles)

        if taylor:
            calib_zeta = torch.sum(torch.linalg.norm(eta * grad_U, dim=1) ** 2) / torch.sum(
                step * torch.linalg.norm(grad_U * grad_V, dim=1) ** 2)
            print(f'Taylor calibration zeta: {calib_zeta}')
            return calib_zeta

        self.posterior.clear_cache()
        current_logp = self.posterior.value(beta, particles)

        # Reference: change of the posterior value under the noisy update.
        stochastic = particles - step * grad_U + eta
        self.posterior.clear_cache()
        delta_U_stochastic = self.posterior.value(beta, stochastic) - current_logp
        mean_delta_U_stochastic = torch.mean(delta_U_stochastic)
        print(f'<delta_U_s> = {torch.mean(delta_U_stochastic)}')

        # For comparison only: change under the uncalibrated (zeta = 1) deterministic update.
        standard_dlmc = particles - step * grad_U + step * grad_V
        self.posterior.clear_cache()
        delta_U_dlmc = self.posterior.value(beta, standard_dlmc) - current_logp
        print(f'<delta_U_dlmc> = {torch.mean(delta_U_dlmc)}')

        def objective(zeta):
            # Mismatch between the mean change under the zeta-scaled update and under the noisy update.
            modified_dlmc = particles - step * grad_U + step * zeta * grad_V
            self.posterior.clear_cache()
            delta_U_modified = self.posterior.value(beta, modified_dlmc) - current_logp
            return float(torch.mean(delta_U_modified) - mean_delta_U_stochastic)

        print(f'objective(upper_zeta) = {objective(upper_zeta)}')
        print(f'objective(lower_zeta) = {objective(lower_zeta)}')

        try:
            # A bracket (or a starting point) is required for root_scalar to select a solver. The search raises
            # if the objective has the same sign at both ends of the bracket.
            sol = root_scalar(f=objective, bracket=[lower_zeta, upper_zeta])
        except Exception:
            # Best effort: fall back to standard DLMC, but let KeyboardInterrupt/SystemExit propagate.
            print('Zeta could not be determined. Setting to 1 (standard DLMC).')
            return 1.0

        if not sol.converged:
            warnings.warn(f"Zeta search did not converge. Setting to 1 (standard DLMC).")
            return 1.0

        calib_zeta = sol.root
        print(f'Stochastic calibration zeta: {calib_zeta}')
        return calib_zeta

    def initialize_optimizer(self,
                             particles: torch.Tensor,
                             nf_interface: TorchFlowInterface,
                             grad_sq: torch.Tensor,
                             avg_grad: torch.Tensor,
                             particle_logp: torch.Tensor,
                             particle_logq: torch.Tensor,
                             logp_func: callable,
                             logq_func: callable,
                             step: float,
                             optimizer: str,
                             optim_scheduler: str,
                             exp_decay_rate: float,
                             cos_T_max: int,
                             cos_lr_min: float):
        """
        Build the particle optimizer and its step size scheduler.

        :param particles: particles handed to the optimizer.
        :param nf_interface: flow interface, used by the line search optimizer only.
        :param grad_sq: passed as ``grad_sq`` to adagrad and as ``accumulated_grad`` to line search.
        :param avg_grad: passed to rmsprop.
        :param particle_logp: passed to the line search optimizer.
        :param particle_logq: passed to the line search optimizer.
        :param logp_func: passed to the line search optimizer.
        :param logq_func: passed to the line search optimizer.
        :param step: learning rate of the optimizer.
        :param optimizer: one of adagrad, rmsprop, adam, grad_descent or line_search.
        :param optim_scheduler: one of identity, exp or cosine.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: passed to the cosine scheduler.
        :param cos_lr_min: passed to the cosine scheduler.
        :return: tuple ``(particle_optimizer, particle_scheduler)``.
        :raises ValueError: if ``optimizer`` or ``optim_scheduler`` is not one of the supported names.
        """
        # Fail early with a clear message instead of an UnboundLocalError further down.
        if optimizer not in ("adagrad", "rmsprop", "adam", "grad_descent", "line_search"):
            raise ValueError("optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")
        if optim_scheduler not in ("identity", "exp", "cosine"):
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")

        if optimizer == 'adagrad':
            particle_optimizer = ParticleAdagrad(particles=particles, grad_sq=grad_sq, lr=step)
        elif optimizer == 'rmsprop':
            particle_optimizer = ParticleRMSProp(particles=particles, avg_grad=avg_grad, lr=step)
        elif optimizer == 'adam':
            particle_optimizer = ParticleAdam(particles=particles, lr=step)
        elif optimizer == 'grad_descent':
            particle_optimizer = ParticleGradientDescent(particles=particles, lr=step)
        else:  # 'line_search'
            self.posterior.clear_cache()
            particle_optimizer = ParticleLineSearch(particles=particles,
                                                    accumulated_grad=grad_sq,
                                                    particle_logp=particle_logp,
                                                    particle_logq=particle_logq,
                                                    posterior=self.posterior,
                                                    interface=nf_interface,
                                                    logp_func=logp_func,
                                                    logq_func=logq_func,
                                                    lr=step)

        if optim_scheduler == 'identity':
            particle_scheduler = IdentityScheduler(particle_optimizer)
        elif optim_scheduler == 'exp':
            particle_scheduler = ExponentialDecayScheduler(particle_optimizer, exp_decay_rate)
        else:  # 'cosine'
            particle_scheduler = CosineAnnealingScheduler(particle_optimizer, cos_T_max, cos_lr_min)

        return particle_optimizer, particle_scheduler

    def run(self,
            x: torch.Tensor,
            pp_bw_factor: float = 1.0,
            step: float = 1e-2,
            optimizer: str = "adam",
            optim_scheduler: str = "identity",
            exp_decay_rate: float = 0.999,
            cos_T_max: int = 100,
            cos_lr_min: float = 0.0,
            optim_steps: int = 1,
            atol: float = 1e-15,
            max_burnin: int = 20,
            burnin_grad_thresh: float = 0.5,
            num_upsample: int = 0,
            max_iterations: int = 100,
            latent: bool = True,
            stochastic_calibration: bool = True,
            lower_zeta: float = -10.0,
            upper_zeta: float = 10.0,
            taylor_zeta: bool = False,
            IMHstep: int = 0,
            CMHstep: int = 0,
            CMHscale: float = 0.1,
            use_tqdm: bool = False,
            train_kwargs: dict = None,
            animate_kwargs: dict = None) -> torch.Tensor:
        """
        Run PPDLMC.

        :param x: prior samples with shape (n, d).
        :param pp_bw_factor: multiplicative factor for the PP KDE bandwidth.
        :param step: step size for sample updates.
        :param optimizer: particle optimizer to use for DLMC updates.
        :param optim_scheduler: optimization step size scheduler.
        :param exp_decay_rate: decay rate for the exponential scheduler.
        :param cos_T_max: total number of epochs for the cosine scheduler.
        :param cos_lr_min: minimum learning rate for cosine scheduler.
        :param optim_steps: number of optimizer steps to take at each DLMC iteration.
        :param atol: tolerance parameter when checking for convergence.
        :param max_burnin: maximum number of burnin iterations.
        :param burnin_grad_thresh: <grad(logq)/grad(logp)> threshold at which we end burnin.
        :param num_upsample: number of extra samples to draw from q after burnin.
        :param max_iterations: maximum number of DLMC iterations for each beta level post-burnin.
        :param latent: whether to do latent space DLMC.
        :param stochastic_calibration: whether to apply stochastic calibration to DLMC.
        :param lower_zeta: lower limit on zeta in stochastic calibration.
        :param upper_zeta: upper limit on zeta in stochastic calibration.
        :param taylor_zeta: whether top use the Taylor-derived expression for zeta.
        :param IMHstep: number of IMH steps to apply after DLMC.
        :param CMHstep: number of CMH steps to apply after DLMC.
        :param use_tqdm: use a tqdm progress bar.
        :param train_kwargs: keyword arguments for interface.train_flow.
        :param animate_kwargs: keyword arguments for visualizer.animate.

        :return: torch.Tensor with posterior samples.
        """
        if pp_bw_factor <= 0.0:
            raise ValueError("pp_bw_factor must be positive")
        if step <= 0.0:
            raise ValueError("step must be positive")
        if optimizer not in ("adagrad", "rmsprop", "adam", "grad_descent", "line_search"):
            raise ValueError("optimizer must be one of adagrad, rmsprop, adam, grad_descent or line_search")
        if optim_scheduler not in ("identity", "exp", "cosine"):
            raise ValueError("optim_scheduler must be one of identity, exp or cosine")
        if exp_decay_rate <= 0.0 or exp_decay_rate > 1.0:
            raise ValueError("exp_decay rate must be between 0 and 1")
        if cos_T_max <= 0:
            raise ValueError("cos_T_max must be positive")
        if cos_lr_min < 0:
            raise ValueError("cos_lr_min must be greater than or equal to 0")
        if optim_steps < 1:
            raise ValueError("optim_steps must be at least 1")
        if atol <= 0.0:
            raise ValueError("atol must be positive")
        if max_iterations <= 0:
            raise ValueError("max_iterations must be positive")
        if latent and self.interface.flow is None:
            warnings.warn("Flow is None, whilst latent is True. Did you forget to create the flow using create_flow?")
        if train_kwargs is None:
            train_kwargs = dict()
        if animate_kwargs is None:
            animate_kwargs = dict()
        if latent and not isinstance(self.interface, TorchFlowInterface):
            raise ValueError("Latent PPDLMC is supported only for TorchFlowInterfaces.")

        x = deepcopy(x)  # Make a copy of the initial particles, these will be changed during DLA
        initial_step = step

        latent_samples = torch.randn_like(x)  # Only for visualization purposes

        beta_stage = 0
        old_beta = 0.0
        while self.beta <= 1:
            update_norm = torch.inf  # Just so we define this even if no iterations are executed.
            step = initial_step  # Reset the step size for each beta level.
            dlmc_stage = 0
            num_burnin = 0
            burnin = True
            total_iterations = max_burnin + max_iterations
            for iteration in (pbar := (tqdm(range(total_iterations)) if use_tqdm else range(total_iterations))):

                if dlmc_stage > max_iterations:
                    print(f'Reached maximum number of post-burnin DLMC iterations.')
                    break

                if not burnin:
                    dlmc_stage += 1
                elif burnin:
                    num_burnin += 1
                    if num_burnin == max_burnin:
                        print(f'Reached maximum number of burnin iterations.')
                        burnin = False

                if num_burnin == 1:
                    # Do initial update of the particles along the likelihood gradient
                    sample_sigma = torch.mean(torch.std(x, dim=0))
                    particle_step = step * sample_sigma

                    current_logp = self.posterior.value(self.beta, x)
                    current_logq = self.posterior.log_prior(x)
                    logp_func = lambda val: self.posterior.value(self.beta, val)
                    logq_func = lambda val: self.posterior.log_prior(val)

                    particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=x,
                                                                                       nf_interface=self.interface,
                                                                                       step=step,
                                                                                       particle_logp=current_logp,
                                                                                       particle_logq=current_logq,
                                                                                       grad_sq=torch.zeros_like(x),
                                                                                       avg_grad=torch.zeros_like(x),
                                                                                       logp_func=logp_func,
                                                                                       logq_func=logq_func,
                                                                                       optimizer=optimizer,
                                                                                       optim_scheduler=optim_scheduler,
                                                                                       exp_decay_rate=exp_decay_rate,
                                                                                       cos_T_max=cos_T_max,
                                                                                       cos_lr_min=cos_lr_min)

                    for i in range(optim_steps):
                        self.posterior.clear_cache()
                        grad = -(self.beta - old_beta) * self.posterior.log_likelihood_gradient(x)
                        particle_optimizer.step(grad)
                        particle_scheduler.step()

                if burnin:
                    self.posterior.clear_cache()
                    scott_factor = np.sqrt((x.shape[0]) ** (-1.0 / (x.shape[1] + 4.0)))
                    pp_sigma = scott_factor * torch.std(x, dim=0, unbiased=False) * pp_bw_factor
                    pp_interface = PPInterface(x, pp_sigma)
                    grad_x_U = -self.posterior.gradient(self.beta, x)
                    grad_x_V = -pp_interface.grad_x_logq(x)
                    grad_ratio = torch.mean(torch.linalg.norm(grad_x_V, dim=1) / torch.linalg.norm(grad_x_U, dim=1))
                    '''
                    if grad_ratio >= burnin_grad_thresh:
                        print(f'Reached gradient norm threshold during burnin.')
                        burnin = False
                        num_burnin = max_burnin
                    '''

                if num_burnin == max_burnin and burnin is False:
                    if num_upsample > 0:
                        print(f'Drawing {num_upsample} new particles from current q for post-burnin DLMC.')
                        scott_factor = np.sqrt((x.shape[0]) ** (-1.0 / (x.shape[1] + 4.0)))
                        pp_sigma = scott_factor * torch.std(x, dim=0, unbiased=False) * pp_bw_factor
                        kde_idx = torch.randint(low=0, high=x.shape[0], size=(num_upsample,))
                        new_particles = x[kde_idx] + pp_sigma * torch.randn_like(x[kde_idx])
                        x = torch.cat([x, new_particles])
                        grad_sq = torch.zeros_like(x)
                        accumulated_grad = torch.zeros_like(x)
                        num_burnin = torch.inf

                if latent:
                    # Train flow
                    self.interface.train_flow(x, **train_kwargs)

                # Update samples
                x_before_update = deepcopy(x)

                if iteration == 0:
                    accumulated_grad = torch.zeros_like(x)
                    sum_grad = torch.zeros_like(x)

                if latent:

                    z = self.interface.forward(x)
                    particle_step = step

                    def logp_func(val):
                        self.posterior.clear_cache()
                        x_val, logj_backward = self.interface.inverse_with_logj(val)
                        return self.posterior.value(self.beta, x_val) + logj_backward
                    def logq_func(val):
                        return torch.distributions.MultivariateNormal(loc=torch.zeros(val.shape[1]),
                                                                      covariance_matrix=torch.eye(val.shape[1])).log_prob(val)

                    particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=x,
                                                                                       nf_interface=self.interface,
                                                                                       step=step,
                                                                                       particle_logp=current_logp,
                                                                                       particle_logq=current_logq,
                                                                                       grad_sq=accumulated_grad,
                                                                                       avg_grad=sum_grad / (
                                                                                                   iteration + 1),
                                                                                       logp_func=logp_func,
                                                                                       logq_func=logq_func,
                                                                                       optimizer=optimizer,
                                                                                       optim_scheduler=optim_scheduler,
                                                                                       exp_decay_rate=exp_decay_rate,
                                                                                       cos_T_max=cos_T_max,
                                                                                       cos_lr_min=cos_lr_min)

                    for i in range(optim_steps):
                        scott_factor = np.sqrt((z.shape[0]) ** (-1.0 / (z.shape[1] + 4.0)))
                        pp_sigma = scott_factor * torch.std(z, dim=0, unbiased=False) * pp_bw_factor
                        pp_interface = PPInterface(z, pp_sigma)

                        if stochastic_calibration:
                            calib_zeta = self.stochastic_calibration(particles=z,
                                                                     pp_interface=pp_interface,
                                                                     step=particle_step,
                                                                     beta=self.beta,
                                                                     latent=True,
                                                                     lower_zeta=lower_zeta,
                                                                     upper_zeta=upper_zeta,
                                                                     taylor=taylor_zeta)
                        else:
                            calib_zeta = 1.0

                        self.posterior.clear_cache()
                        grad_x_U = -self.posterior.gradient(self.beta, x)
                        grad_z_U = self.interface.grad_z_logp(z, grad_wrt_x=grad_x_U) + self.interface.grad_z_logj(z)
                        grad_z_V = -pp_interface.grad_x_logq(z)
                        grad = grad_z_U - calib_zeta * grad_z_V
                        particle_optimizer.step(grad)
                        particle_scheduler.step()
                        x = self.interface.inverse(z)

                    #particle_step = particle_optimizer.lr

                else:

                    scott_factor = np.sqrt((x.shape[0]) ** (-1.0 / (x.shape[1] + 4.0)))
                    pp_sigma = scott_factor * torch.std(x, dim=0, unbiased=False) * pp_bw_factor
                    pp_interface = PPInterface(x, pp_sigma)

                    def logp_func(val):
                        self.posterior.clear_cache()
                        x_val, logj_backward = self.interface.inverse_with_logj(val)
                        return self.posterior.value(self.beta, x_val) + logj_backward
                    def logq_func(val):
                        return pp_interface.logq(val)

                    particle_optimizer, particle_scheduler = self.initialize_optimizer(particles=x,
                                                                                       nf_interface=self.interface,
                                                                                       step=step,
                                                                                       particle_logp=current_logp,
                                                                                       particle_logq=current_logq,
                                                                                       grad_sq=accumulated_grad,
                                                                                       avg_grad=sum_grad / (
                                                                                               iteration + 1),
                                                                                       logp_func=logp_func,
                                                                                       logq_func=logq_func,
                                                                                       optimizer=optimizer,
                                                                                       optim_scheduler=optim_scheduler,
                                                                                       exp_decay_rate=exp_decay_rate,
                                                                                       cos_T_max=cos_T_max,
                                                                                       cos_lr_min=cos_lr_min)

                    for i in range(optim_steps):
                        scott_factor = np.sqrt((x.shape[0]) ** (-1.0 / (x.shape[1] + 4.0)))
                        pp_sigma = scott_factor * torch.std(x, dim=0, unbiased=False) * pp_bw_factor
                        pp_interface = PPInterface(x, pp_sigma)

                        if stochastic_calibration:
                            calib_zeta = self.stochastic_calibration(particles=x,
                                                                     pp_interface=pp_interface,
                                                                     step=particle_step,
                                                                     beta=self.beta,
                                                                     latent=False,
                                                                     lower_zeta=lower_zeta,
                                                                     upper_zeta=upper_zeta,
                                                                     taylor=taylor_zeta)
                        else:
                            calib_zeta = 1.0

                        self.posterior.clear_cache()
                        grad_x_U = -self.posterior.gradient(self.beta, x)
                        grad_x_V = -pp_interface.grad_x_logq(x)
                        grad = grad_x_U - calib_zeta * grad_x_V
                        particle_optimizer.step(grad)
                        particle_scheduler.step()

                    particle_step = particle_optimizer.lr

                update_values = x - x_before_update
                update_norm = torch.max(torch.abs(update_values))  # Using max norm as a worst case precaution

                # Independent Metropolis-Hastings updates
                if IMHstep and burnin:
                    acc_rate_IMH = []
                    replaced_IMH = torch.zeros(len(x), dtype=torch.bool, device=x.device)
                    for _ in range(IMHstep):
                        scott_factor = np.sqrt((x.shape[0]) ** (-1.0 / (x.shape[1] + 4.0)))
                        pp_sigma = scott_factor * torch.std(x, dim=0, unbiased=False) * pp_bw_factor
                        x, accept = self.IMH(x, pp_sigma)
                        replaced_IMH = replaced_IMH + accept
                        acc_rate_IMH.append(torch.sum(accept).item() / len(x))

                if CMHstep:
                    x, acc_rate_CMH, total_acc_rate_CMH = self.CMH(x, CMHscale, CMHstep)

                self.log_iteration_data(x, latent_samples)

                if IMHstep:
                    print(f'Independent MH acceptance rate: {acc_rate_IMH}')
                    print(f'Independent MH total acceptance rate: {torch.sum(replaced_IMH).item() / len(x)}')
                if CMHstep:
                    print(f'Correlated MH acceptance rate: {acc_rate_CMH}')
                    print(f'Correlated MH total acceptance rate: {total_acc_rate_CMH}')
                if use_tqdm:
                    pbar.set_description(f'[Stage {beta_stage}] Beta: {self.beta}')
                    pbar.set_postfix(update_norm=float(update_norm))

                # Check convergence
                if update_norm < atol:
                    warnings.warn('Particle updates are too small. Stopping.')
                    break

            # Make the animation
            self.create_beta_stage_animation(beta_stage, **animate_kwargs)

            # Compute new beta
            old_beta = self.beta
            self.set_new_beta(x=x)
            if self.finished:
                break
            if torch.isinf(update_norm):
                warnings.warn("Update norm is infinite. Stopping.")
                break

            beta_stage += 1
            self.debugger.stage_step()
        return x


if __name__ == '__main__':
    # Example script: runs RegularDLA with a SINF flow on a funnel likelihood and a Gaussian prior.
    import torch.distributions
    from nfmc_jax.flows.base import SINFInterface, RealNVPInterface, MAFInterface, RQNSFInterface
    from nfmc_jax.utils.torch_distributions import gaussian_log_prob, Funnel

    # Seed torch's RNG before any sampling happens
    torch.manual_seed(0)

    # Number of dimensions and number of particles
    n_dim, n_samples = 6, 100

    # Location and scale of the Gaussian prior
    prior_loc, prior_scale = 0, 3

    # NOTE: these two values are not referenced anywhere below in this script
    likelihood_loc, likelihood_scale = 3, 0.25

    # Starting particles: drawn from a Normal that shares the prior's loc and scale
    prior_dist = torch.distributions.Normal(loc=prior_loc, scale=prior_scale)
    initial_samples = prior_dist.sample((n_samples, n_dim))

    # --- Alternative flow interfaces (uncomment exactly one pair to use it instead of SINF) ---

    # MAF
    # flow_interface = MAFInterface(n_dim=n_dim)
    # flow_interface.create_flow()

    # RealNVP
    # flow_interface = RealNVPInterface(n_dim=n_dim)
    # flow_interface.create_flow()

    # RQNSF
    # flow_interface = RQNSFInterface(n_dim=n_dim)
    # flow_interface.create_flow()

    # SINF (the interface actually used by this example)
    flow_interface = SINFInterface()
    flow_interface.create_flow(x_train=initial_samples, iteration=5, alpha=(0, 0.98))

    # Torch posterior: funnel log-likelihood combined with a Gaussian log-prior
    funnel_dist = Funnel(n_dim)

    def funnel_log_likelihood(x):
        return funnel_dist.log_prob(x)

    def gaussian_log_prior(x):
        return gaussian_log_prob(x, loc=prior_loc, scale=prior_scale)

    posterior = TorchPosterior(
        log_likelihood=funnel_log_likelihood,
        log_prior=gaussian_log_prior
    )

    # Jax
    # posterior = JaxPosterior(
    #     log_likelihood=...,
    #     log_prior=...
    # )

    # Debugger configuration; funnel samples are handed over as static data under the 'likelihood' key
    debugger_static_data = {'likelihood': funnel_dist.sample(500)}
    debugger = MultiStageDebugger(save_raw_data=True, animate=True, static_data=debugger_static_data)

    dla = RegularDLA(
        interface=flow_interface,
        posterior=posterior,
        debugger=debugger  # You can skip this if you don't need debug functionalities
    )

    # Keyword arguments passed through `run` as animate_kwargs / train_kwargs
    animation_options = {'dpi': 200}
    training_options = {
        # Optional training settings, left disabled in this example:
        # n_epochs=10,
        # lr_Psi=5e-4,
        # lr_A=5e-4,
        # lr_A_decay=0.99,
        # lr_Psi_decay=0.99
    }

    dla.run(
        initial_samples,
        step=5e-2,
        max_iterations=1500,
        use_tqdm=True,
        animate_kwargs=animation_options,
        train_kwargs=training_options
    )
