import jax.numpy as np
import numpyro.distributions as npdist
import jax


def encode_params_linear(mean, logdiag):
	"""Pack the 'brlinear' bridge parameters into a params dict.

	Args:
		mean: Offset added to the mean, scaled by beta in update_params.
		logdiag: Offset added to the log-scale, scaled by beta in update_params.

	Returns:
		dict with keys 'mean' and 'logdiag'.
	"""
	return dict(mean=mean, logdiag=logdiag)

def encode_params_affine(mean1, mean2, logdiag1, logdiag2):
	"""Pack the 'braffine' bridge parameters into a params dict.

	Args:
		mean1: Constant offset added to the mean.
		mean2: Mean offset that update_params multiplies by beta.
		logdiag1: Constant offset added to the log-scale.
		logdiag2: Log-scale offset that update_params multiplies by beta.

	Returns:
		dict with keys 'mean1', 'mean2', 'logdiag1' and 'logdiag2'.
	"""
	return dict(mean1=mean1, mean2=mean2, logdiag1=logdiag1, logdiag2=logdiag2)
	
def decode_params_linear(params):
	"""Unpack a 'brlinear' params dict.

	Args:
		params: dict with at least the keys 'mean' and 'logdiag'.

	Returns:
		Tuple (mean, logdiag).

	Raises:
		KeyError: If either key is missing.
	"""
	return params['mean'], params['logdiag']

def decode_params_affine(params):
	"""Unpack a 'braffine' params dict.

	Args:
		params: dict with at least the keys 'mean1', 'mean2', 'logdiag1'
			and 'logdiag2'.

	Returns:
		Tuple (mean1, mean2, logdiag1, logdiag2).

	Raises:
		KeyError: If any of the four keys is missing.
	"""
	field_names = ('mean1', 'mean2', 'logdiag1', 'logdiag2')
	return tuple(params[name] for name in field_names)

def decode_params_vd(params):
	"""Unpack the variational-distribution params dict.

	Args:
		params: dict with at least the keys 'mean' and 'logdiag'.

	Returns:
		Tuple (mean, logdiag).

	Raises:
		KeyError: If either key is missing.
	"""
	return tuple(params[name] for name in ('mean', 'logdiag'))

def initialize(dim, mode):
	"""Create zero-initialised bridge parameters for ``mode``.

	Every parameter array is ``np.zeros(dim)``.

	Args:
		dim: Shape passed to ``np.zeros`` for each parameter array.
		mode: Bridge mode. Only 'brlinear' and 'braffine' create parameters.

	Returns:
		The dict from ``encode_params_linear`` for 'brlinear', the dict from
		``encode_params_affine`` for 'braffine', and ``None`` for any other
		mode (including 'brvd', whose branch in update_params never reads
		brparams).
	"""
	if mode == 'brlinear':
		zeros = [np.zeros(dim) for _ in range(2)]
		return encode_params_linear(*zeros)
	if mode == 'braffine':
		zeros = [np.zeros(dim) for _ in range(4)]
		return encode_params_affine(*zeros)
	# NOTE(review): unrecognised modes silently yield None instead of raising;
	# a typo in `mode` only surfaces later, in update_params.
	return None

def build(beta, vdparams, brmode, brparams):
	"""Build the diagonal Normal bridge distribution at ``beta``.

	Args:
		beta: Passed through to update_params.
		vdparams: Variational params dict ('mean', 'logdiag').
		brmode: Bridge mode understood by update_params.
		brparams: Bridge params dict for ``brmode`` (unused for 'brvd').

	Returns:
		A numpyro ``Independent(Normal(...), 1)`` distribution with loc
		``mean`` and scale ``exp(logdiag)``.

	Raises:
		NotImplementedError: Propagated from update_params for unknown modes.
	"""
	loc, log_scale = update_params(beta, vdparams, brmode, brparams)
	base = npdist.Normal(loc = loc, scale = np.exp(log_scale))
	# Reinterpret the rightmost batch axis as the event axis so log_prob
	# returns one value per sample rather than one per coordinate.
	return npdist.Independent(base, 1)

def log_prob(z, beta, vdmode, vdparams, brmode, brparams):
	"""Evaluate the bridge distribution's log-density at ``z``.

	Args:
		z: Point(s) at which to evaluate the density.
		beta: Passed through to build / update_params.
		vdmode: Variational-family selector; only its equality with 2 is
			inspected here.
		vdparams: Variational params dict.
		brmode: Bridge mode.
		brparams: Bridge params dict for ``brmode``.

	Returns:
		``build(beta, vdparams, brmode, brparams).log_prob(z)``.

	Raises:
		NotImplementedError: If ``vdmode == 2`` and ``brmode != 'vd'``, or
			(via update_params) if ``brmode`` is unknown.
	"""
	# NOTE(review): update_params accepts 'brvd', 'brlinear' and 'braffine'
	# but not 'vd', so this exemption never lets a valid bridge through. With
	# vdmode == 2 every call raises, either here or (for brmode 'vd') in
	# update_params. The intended exemption is probably 'brvd'. Before
	# changing it, confirm that vdmode-2 vdparams use the 'mean'/'logdiag'
	# layout that decode_params_vd reads.
	bridge_unsupported = vdmode == 2 and brmode != 'vd'
	if bridge_unsupported:
		raise NotImplementedError('Tuning bridges not available for vdmode 2.')
	return build(beta, vdparams, brmode, brparams).log_prob(z)

def update_params(beta, vdparams, brmode, brparams):
	"""Combine the variational params with the bridge params at ``beta``.

	The bridge modes are:
		'brvd':     mean = vd_mean,                      logdiag = vd_logdiag
		'brlinear': mean = vd_mean + beta * m,           logdiag = vd_logdiag + beta * l
		'braffine': mean = vd_mean + m1 + beta * m2,     logdiag = vd_logdiag + l1 + beta * l2

	Args:
		beta: Multiplier for the beta-dependent bridge terms (presumably a
			scalar annealing coefficient; confirm with callers).
		vdparams: Variational params dict ('mean', 'logdiag').
		brmode: One of 'brvd', 'brlinear', 'braffine'.
		brparams: Bridge params dict for ``brmode``; not read for 'brvd'.

	Returns:
		Tuple (mean, logdiag). build uses ``exp(logdiag)`` as the Normal scale.

	Raises:
		NotImplementedError: If ``brmode`` is not one of the modes above.
	"""
	base_mean, base_logdiag = decode_params_vd(vdparams)

	if brmode == 'brvd':
		return base_mean, base_logdiag

	if brmode == 'brlinear':
		slope_mean, slope_logdiag = decode_params_linear(brparams)
		return (base_mean + beta * slope_mean,
				base_logdiag + beta * slope_logdiag)

	if brmode == 'braffine':
		off_mean, slope_mean, off_logdiag, slope_logdiag = decode_params_affine(brparams)
		# Keep the original left-to-right summation order so floating-point
		# results match bit-for-bit.
		return (base_mean + off_mean + beta * slope_mean,
				base_logdiag + off_logdiag + beta * slope_logdiag)

	raise NotImplementedError('Mode bridge %s not available.' % brmode)





