import torch

from .distribution import Distribution


def get_prior(args):
    mean = torch.zeros(
        [1, int(args.latent_dim * 2)], 
        device=args.device
    )
    covar = torch.zeros(
        [1, int(args.latent_dim * 2)], 
        device=args.device
    )

    prior = Distribution(mean, covar)
    return prior

