import torch
import normflows as nf
class NormalizingFlowTorch:
    """Wrapper around a ``normflows`` normalizing flow over ``dim`` variables.

    The flow is a stack of ``num_flows`` pairs of layers, each pair being an
    ``AutoregressiveRationalQuadraticSpline`` followed by an
    ``LULinearPermute``, on top of a non-trainable ``DiagGaussian`` base
    distribution. The assembled model is stored in ``self.nfm``.
    """

    def __init__(self, dim, device='cpu', *args,
                 num_flows=8, hidden_units=64, hidden_layers=3, **kwargs):
        """Build the flow, move it to ``device`` and put it in training mode.

        Args:
            dim: Size passed to every flow layer and to the base distribution.
            device: Target passed to ``.to(...)`` on the assembled flow.
            *args: Accepted for interface compatibility; ignored.
            num_flows: Number of (spline, LU-permute) layer pairs to stack.
            hidden_units: Third positional argument of each spline layer
                (hidden width of its conditioner network, per normflows docs).
            hidden_layers: Second positional argument of each spline layer
                (number of conditioner blocks, per normflows docs).
            **kwargs: Accepted for interface compatibility; ignored.
        """
        self.dim = dim

        flows = []
        for _ in range(num_flows):
            flows.append(
                nf.flows.AutoregressiveRationalQuadraticSpline(
                    dim, hidden_layers, hidden_units
                )
            )
            flows.append(nf.flows.LULinearPermute(dim))

        # The base distribution is kept fixed; only the flow layers are trained.
        q0 = nf.distributions.DiagGaussian(dim, trainable=False)

        self.nfm = nf.NormalizingFlow(q0=q0, flows=flows).to(device)
        # Leave the model in training mode; callers wanting inference
        # behaviour must switch modes on ``self.nfm`` themselves.
        self.nfm.train()

    def extract_params(self):
        """Not supported by this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def log_posterior(self, theta):
        """Return ``self.nfm.log_prob(theta)``.

        Args:
            theta: Points to evaluate. Presumably a tensor of shape
                ``(batch, dim)`` on the same device as the flow; this is not
                checked here.
        """
        return self.nfm.log_prob(theta)

    def sample(self, num=1):
        """Draw ``num`` samples from the flow.

        ``self.nfm.sample`` returns a pair; only its first element (the
        samples) is returned and the second one is discarded.
        """
        samples, _ = self.nfm.sample(num)
        return samples

    def posterior_parameters(self):
        """Return ``self.nfm.parameters()``, e.g. to hand to an optimizer."""
        return self.nfm.parameters()






