import torch

from pfgmpp_kernel import sample_noise


def ipfm_sampler(
    net, *, latents, class_labels=None, init_sigma=2.5, min_sigma=0.002, D="inf", nsteps=1, **sampling_kwargs
):
    """Sample with a few-step "denoise, then re-noise" loop over a linear sigma schedule.

    The sigma grid is ``torch.linspace(init_sigma, min_sigma, nsteps + 1)``. Starting from
    noise drawn by ``sample_noise`` at ``init_sigma``, each step calls ``net`` at the current
    grid sigma to obtain a denoised estimate. On every step except the last, it then adds
    fresh noise at the next grid sigma. The network is evaluated only at the first ``nsteps``
    grid points, so ``min_sigma`` serves as the schedule's endpoint and is never passed to
    ``net``.

    Args:
        net: Callable invoked as ``net(x, sigma, class_labels, **sampling_kwargs)``, where
            ``sigma`` holds one value per batch element with shape ``(batch, 1, ..., 1)``
            (same rank as ``x``). Its output is added to noise built from ``latents``, so it
            presumably returns a tensor shaped like ``x`` -- confirm against the network used.
        latents: Tensor fed to ``sample_noise`` for the initial draw. Later steps use
            ``torch.randn_like(latents)`` instead, so ``latents`` is presumably standard
            normal -- TODO confirm against ``sample_noise``.
        class_labels: Passed through unchanged as ``net``'s third positional argument.
        init_sigma: First (largest) sigma of the schedule.
        min_sigma: Final grid point of the schedule (never evaluated, see above).
        D: Forwarded to ``sample_noise``. Presumably the PFGM++ augmentation dimension,
            with ``"inf"`` meaning the diffusion limit -- verify in ``pfgmpp_kernel``.
        nsteps: Number of network evaluations; must be at least 1.
        **sampling_kwargs: Extra keyword arguments forwarded to every ``net`` call.

    Returns:
        The output of the final ``net`` call.

    Raises:
        ValueError: If ``nsteps < 1``.
    """
    # Without at least one step the loop never assigns x_denoised.
    if nsteps < 1:
        raise ValueError(f"nsteps must be >= 1, got {nsteps}")

    # Read the default-dtype schedule back as Python floats once, instead of
    # calling .item() inside the loop; the values are identical.
    sigmas = torch.linspace(init_sigma, min_sigma, nsteps + 1).tolist()

    x_noised = sample_noise(latents=latents, sigma=init_sigma, D=D)
    for i, (curr_sigma, next_sigma) in enumerate(zip(sigmas[:-1], sigmas[1:])):
        # One sigma per batch element, with trailing singleton dims so it
        # broadcasts against x of any rank (identical to (B,1,1,1) for 4-D x).
        # Built directly on x's device to avoid a CPU allocation plus a transfer.
        sigma_shape = (x_noised.shape[0],) + (1,) * (x_noised.ndim - 1)
        sigma_batch = torch.full(sigma_shape, curr_sigma, device=x_noised.device)
        x_denoised = net(x_noised, sigma_batch, class_labels, **sampling_kwargs)
        if i < nsteps - 1:
            # Re-noise the current estimate to the next grid sigma using fresh noise.
            x_noised = x_denoised + sample_noise(latents=torch.randn_like(latents), sigma=next_sigma, D=D)
    return x_denoised
