from numpyro.contrib.einstein.steinvi import SteinVI
from numpyro.infer.autoguide import AutoDelta

class SVGD(SteinVI):
    """``SteinVI`` preconfigured with an ``AutoDelta`` guide built from the model.

    The constructor builds ``AutoDelta(model, **guide_kwargs)`` itself and
    forwards it, together with the remaining settings, to
    ``SteinVI.__init__``. Two of the forwarded values are fixed here rather
    than exposed to the caller: the constant ``1`` and
    ``1.0 / num_stein_particles``.
    """

    def __init__(
        self,
        model,
        optim,
        kernel_fn,
        num_stein_particles=10,
        repulsion_temperature=1.0,
        non_mixture_guide_params_fn=lambda name: False,
        enum=True,
        guide_kwargs=None,
        **static_kwargs,
    ):
        """Build the ``AutoDelta`` guide and initialise the parent ``SteinVI``.

        :param model: Model passed both to ``AutoDelta`` and to ``SteinVI``.
        :param optim: Optimizer forwarded unchanged to ``SteinVI``.
        :param kernel_fn: Kernel forwarded unchanged to ``SteinVI``.
        :param num_stein_particles: Number of Stein particles; also used as
            the divisor of the seventh positional argument passed to the
            parent, so it must be non-zero (zero raises ``ZeroDivisionError``).
        :param repulsion_temperature: Forwarded unchanged to ``SteinVI``.
        :param non_mixture_guide_params_fn: Predicate on a parameter name,
            forwarded unchanged to ``SteinVI``. The default returns ``False``
            for every name.
        :param enum: Forwarded unchanged to ``SteinVI``.
        :param guide_kwargs: Optional mapping of keyword arguments for the
            ``AutoDelta`` constructor. ``None`` (the default) means no extra
            arguments.
        :param static_kwargs: Extra keyword arguments forwarded unchanged to
            ``SteinVI``.
        """
        # A ``None`` sentinel replaces the former ``{}`` default so that no
        # single dict instance is shared between constructor calls.
        if guide_kwargs is None:
            guide_kwargs = {}

        super().__init__(
            model,
            AutoDelta(model, **guide_kwargs),
            optim,
            kernel_fn,
            num_stein_particles,
            # NOTE(review): presumably SteinVI's ELBO-particle count, fixed to
            # a single particle -- confirm against the installed SteinVI
            # signature.
            1,
            # NOTE(review): presumably SteinVI's loss temperature, scaled by
            # the particle count -- confirm against the installed SteinVI
            # signature.
            1.0 / float(num_stein_particles),
            repulsion_temperature,
            non_mixture_guide_params_fn,
            enum,
            **static_kwargs,
        )