import jax.numpy as np
import numpyro.distributions as npdist
import jax
import variationaldist as vd
import momdist as md
from jax.flatten_util import ravel_pytree
import functools


# Gaussian tools
def build_gaussian(dim, scale):
	"""
	Return a zero-mean diagonal Gaussian as a numpyro distribution.

	- dim is the length of the event vector.
	- scale is the per-coordinate standard deviation (a scalar, or anything broadcastable to shape (dim,)).
	The dim axis is reinterpreted as a single event, so log_prob sums over it.
	"""
	loc = np.zeros(dim)
	std = np.ones(dim) * scale
	per_coord = npdist.Normal(loc = loc, scale = std)
	return npdist.Independent(per_coord, 1)

def log_prob_gaussian(rho, dim, scale):
	"""Return the log density of rho under build_gaussian(dim, scale), summed over the dim coordinates."""
	return build_gaussian(dim, scale).log_prob(rho)

def sample_gaussian(rng_key, dim, scale):
	"""
	Draw one zero-mean Gaussian vector of shape (dim,).

	A standard normal draw from rng_key is multiplied by scale. scale may be a scalar or broadcastable to (dim,).
	"""
	noise_shape = (dim,)
	return jax.random.normal(rng_key, shape = noise_shape) * scale



def initialize(dim, vdmode = 1, vdparams = None, nbridges = 0, eps = 0.02, beta_0 = 0.9, mdparams = None, trainable = ('vd',), eps_vec = False):
	"""
	Build the flat parameter vector used by compute_ratio.

	- dim is the dimensionality of the latent variable.
	- vdmode selects the variational distribution. It is passed through unchanged to vd.initialize, and is returned in params_fixed.
	  NOTE(review): older docs listed 'diag' / 'full', but the default is 1. Confirm the accepted values in variationaldist.
	- vdparams are the parameters of the variational distribution. If None, they are created with vd.initialize(dim, vdmode).
	- nbridges is the number of bridging densities (leapfrog + cool-down steps in compute_ratio).
	- eps is the leapfrog step size itself, not its log. If eps_vec is True, it is broadcast to a vector of length dim.
	- beta_0 is the initial inverse temperature.
	- mdparams is currently unused.
	- trainable is a collection naming which of 'vd', 'eps' and 'beta_0' are trained.
	  Every group not listed is stored as non-trainable, and compute_ratio stops its gradient.
	  'eps_var' (fixed at 0.) always goes in the same group as 'eps'.

	Returns (params_flat, unflatten, params_fixed):
	- unflatten(params_flat) gives back (params_train, params_notrain).
	- params_fixed is (dim, vdmode, nbridges).
	"""
	params_train = {} # Has all trainable parameters
	params_notrain = {} # Non trainable parameters

	def group_for(name):
		# Dict that parameter group `name` is stored in.
		return params_train if name in trainable else params_notrain

	# Variational distribution parameters
	if vdparams is None:
		vdparams = vd.initialize(dim, vdmode)
	group_for('vd')['vd'] = vdparams

	# Leapfrog step-size (and its companion eps_var, kept in the same group)
	if eps_vec:
		eps = eps * np.ones(dim)
	eps_group = group_for('eps')
	eps_group['eps'] = eps
	eps_group['eps_var'] = 0.

	# Initial inverse temperature
	group_for('beta_0')['beta_0'] = beta_0

	# Other fixed parameters
	params_fixed = (dim, vdmode, nbridges)
	params_flat, unflatten = ravel_pytree((params_train, params_notrain))
	return params_flat, unflatten, params_fixed


def compute_ratio(seed, params_flat, unflatten, params_fixed, log_prob):
	"""
	Run one bridged leapfrog chain and return the negated log ratio.

	Procedure:
	- Sample z from the variational distribution.
	- Sample an initial momentum rho with scale 1/sqrt(beta_0).
	- Run nbridges leapfrog steps on U(z) = -log_prob(z).
	- After each step, rescale rho by sqrt(beta_{k-1}) / sqrt(beta_k), where sqrt(beta_k) follows a
	  quadratic cool-down schedule from sqrt(beta_0) at k = 0 to 1 at k = nbridges.

	The accumulated log ratio is
		w = log_prob(z_N) + log N(rho_N; 0, I) - log q(z_0) - log N(rho_0; 0, I/beta_0) + log|det(rescaling)|.

	Args:
		seed: integer seed for jax.random.PRNGKey.
		params_flat: flat parameter vector produced by initialize.
		unflatten: maps params_flat back to (params_train, params_notrain).
			Gradients are stopped through params_notrain.
		params_fixed: tuple (dim, vdmode, nbridges).
		log_prob: target log density. It must be differentiable with jax.grad.

	Returns:
		(-w, (z, zs)):
		- z is the final sample.
		- zs stacks the sample after each bridge along a leading axis of length nbridges.
		  That axis is empty when nbridges == 0.
	"""
	params_train, params_notrain = unflatten(params_flat)
	params_notrain = jax.lax.stop_gradient(params_notrain)
	params = {**params_train, **params_notrain} # Gets all parameters in single place
	dim, vdmode, nbridges = params_fixed
	beta_0 = params['beta_0']
	eps = params['eps']

	# Initial sample. Split first so z and rho use distinct keys; previously the
	# key already consumed by sample_rep was split again for rho (key reuse).
	rng_key_z, rng_key_rho = jax.random.split(jax.random.PRNGKey(seed))
	z = vd.sample_rep(rng_key_z, vdmode, params['vd'])
	rho_scale = 1. / np.sqrt(beta_0)
	rho = sample_gaussian(rng_key_rho, dim, scale = rho_scale)
	w = -vd.log_prob(vdmode, params['vd'], z)
	w = w - log_prob_gaussian(rho, dim, scale = rho_scale)

	# Gradient of the potential U(z) = -log_prob(z); built once, outside the scan body.
	U_grad = jax.grad(lambda x: -1. * log_prob(x))
	inv_sqrt_b0 = 1. / np.sqrt(beta_0)

	def evolve_bridges(carry, i):
		# carry = (z, rho, U_grad(z), sqrt(beta_{k-1})); bridge index k = i + 1 runs 1..nbridges.
		z, rho, grad_z, sqrt_beta_prev = carry
		k = i + 1.

		# Leapfrog. grad_z is the gradient at the current z, computed at the end of
		# the previous step, so it is not re-evaluated here.
		rho = rho - eps * grad_z / 2.
		z = z + eps * rho
		grad_z = U_grad(z)
		rho = rho - eps * grad_z / 2.

		# Cool down: sqrt(beta_k) equals 1 when k == nbridges.
		sqrt_beta_k = 1. / ((1. - inv_sqrt_b0) * k ** 2 / (1. * nbridges ** 2) + inv_sqrt_b0)
		rho = rho * sqrt_beta_prev / sqrt_beta_k

		return (z, rho, grad_z, sqrt_beta_k), z

	# Evolve
	if nbridges >= 1:
		carry = (z, rho, U_grad(z), np.sqrt(beta_0))
		(z, rho, _, _), zs = jax.lax.scan(evolve_bridges, carry, np.arange(nbridges))
		# The per-step rescalings telescope to a total factor sqrt(beta_0) / 1 on every
		# momentum coordinate. The leapfrog updates are volume-preserving shears, so the
		# log-Jacobian is dim * log(sqrt(beta_0)).
		w = w + dim * np.log(beta_0) / 2.
	else:
		# No bridges: rho is never rescaled, so no Jacobian term is added.
		# zs is an empty stack so the return structure matches the scan case.
		zs = np.zeros((0,) + np.shape(z), dtype = np.result_type(z))

	# Evaluate model at final sample
	w = w + log_prob(z)
	w = w + log_prob_gaussian(rho, dim, scale = 1.)
	return -1. * w, (z, zs)



# Batched compute_ratio: vectorize over the seed (axis 0); the other four arguments are shared across the batch.
compute_ratio_vec = jax.vmap(compute_ratio, in_axes = (0,) + (None,) * 4)

@functools.partial(jax.jit, static_argnums = (2, 3, 4))
def compute_avg_ratio(seeds, params_flat, unflatten, params_fixed, log_prob):
	"""
	Average compute_ratio over a batch of seeds (jitted).

	unflatten, params_fixed and log_prob are static jit arguments, so they must be hashable.

	Returns (mean_ratio, (mean_ratio, zs)):
	- mean_ratio is the batch mean of the negated log ratios.
	- zs holds the per-seed bridge samples stacked along axis 0.
	NOTE(review): the (value, aux) shape presumably targets jax.value_and_grad(..., has_aux=True). Confirm with the caller.
	"""
	ratios, (_, zs) = compute_ratio_vec(seeds, params_flat, unflatten, params_fixed, log_prob)
	mean_ratio = ratios.mean()
	return mean_ratio, (mean_ratio, zs)




