import jax.numpy as np
import numpyro.distributions as npdist
import jax


def encode_params(mean, logdiag, a_mean, a_logdiag):
	"""Pack the four variational parameters into a dict keyed by name.

	Keys are inserted in the order 'mean', 'logdiag', 'a_mean', 'a_logdiag'.
	This is the inverse of decode_params.
	"""
	return dict(mean=mean, logdiag=logdiag, a_mean=a_mean, a_logdiag=a_logdiag)

def decode_params(params):
	"""Unpack a parameter dict into a (mean, logdiag, a_mean, a_logdiag) tuple.

	Extra keys in ``params`` are ignored. A missing key raises KeyError.
	This is the inverse of encode_params.
	"""
	field_names = ('mean', 'logdiag', 'a_mean', 'a_logdiag')
	return tuple(params[name] for name in field_names)

def initialize(dim):
	"""Create the starting parameter dict with every entry set to zeros.

	Each of mean, logdiag, a_mean and a_logdiag is a separate ``np.zeros(dim)``
	array. With all-zero parameters, build() yields loc 0 and scale exp(0) = 1
	for any z.
	"""
	mean, logdiag, a_mean, a_logdiag = (np.zeros(dim) for _ in range(4))
	return encode_params(mean, logdiag, a_mean, a_logdiag)

def build(params, z):
	"""Construct the z-conditioned diagonal Normal distribution.

	The distribution is defined by:
	    loc   = mean + a_mean * z
	    scale = exp(logdiag + a_logdiag * z)
	So 'logdiag' is the log of the standard deviation, not of the variance.
	Wrapping in Independent(..., 1) makes the last axis an event dimension,
	so log_prob sums over it. Setting z to 0 recovers the unconditioned
	Normal(mean, exp(logdiag)), assuming the a_* entries are finite.

	params: dict with keys 'mean', 'logdiag', 'a_mean', 'a_logdiag'.
	z: conditioning value. Presumably a scalar or an array that broadcasts
	    with the parameter arrays; TODO confirm against callers.
	"""
	mean, logdiag, a_mean, a_logdiag = decode_params(params)
	loc = mean + a_mean * z
	scale = np.exp(logdiag + a_logdiag * z)
	base = npdist.Normal(loc=loc, scale=scale)
	return npdist.Independent(base, 1)

def log_prob(rho, params, z):
	"""Return the log-density of ``rho`` under build(params, z).

	The density comes from the Independent wrapper, so it is summed over the
	last axis of ``rho``.
	"""
	return build(params, z).log_prob(rho)



