
import torch
# Wrapper for normflows library distributions
class NFDistributionWrapper:
    """
    Wrapper to make nf.distributions compatible with EM2C / MCMC.

    Exposes the ``sample`` / ``logpi`` / ``grad_logpi`` interface and
    delegates to the wrapped object's ``sample`` and ``log_prob`` methods.

    Parameters
    ----------
    nf_dist :
        Object providing ``sample(N)`` and ``log_prob(x)`` (presumably a
        ``normflows`` distribution -- TODO confirm against callers).
    device : str or torch.device, optional
        Device that tensors returned by :meth:`sample` are moved to.
    """

    def __init__(self, nf_dist, device="cpu"):
        self.dist = nf_dist
        self.device = device

    def sample(self, N):
        """Draw ``N`` samples from the wrapped distribution, moved to ``self.device``."""
        return self.dist.sample(N).to(self.device)

    def logpi(self, x):
        """Return the wrapped distribution's ``log_prob`` evaluated at ``x``."""
        return self.dist.log_prob(x)

    def grad_logpi(self, x):
        """
        Return the gradient of the log-density with respect to ``x``.

        The gradient is taken of ``logpi(x).sum()``. This assumes ``log_prob``
        returns one value per sample, each depending only on its own sample
        (TODO confirm for the wrapped distribution). Under that assumption the
        result equals the per-sample gradient of the log-density.

        The caller's tensor is left unmodified. The returned gradient does not
        require grad, so it can be used directly in sampler updates.
        """
        # Differentiate w.r.t. a detached leaf copy. Calling ``requires_grad_``
        # on the caller's tensor would leave it requiring grad, and every
        # subsequent update built from it would keep extending the graph.
        # ``enable_grad`` keeps this working when invoked under ``torch.no_grad()``.
        with torch.enable_grad():
            x_leaf = x.detach().requires_grad_(True)
            logp = self.logpi(x_leaf).sum()
            (grad,) = torch.autograd.grad(logp, x_leaf)
        return grad
