import jax.numpy as np
import numpyro.distributions as npdist
import jax
import variationaldist as vd
import reversaldist as rd
import momdist as md
from jax.flatten_util import ravel_pytree
import functools


# Gaussian tools
def build_gaussian(dim, scale):
	"""
	Zero-mean diagonal Gaussian over R^dim with standard deviation `scale`
	on every coordinate; the last axis is reinterpreted as the event axis.
	"""
	mean = np.zeros(dim)
	std = scale * np.ones(dim)
	base = npdist.Normal(loc = mean, scale = std)
	return npdist.Independent(base, 1)

def log_prob_gaussian(rho, dim, scale):
	"""Log-density of `rho` under build_gaussian(dim, scale) (summed over the event axis)."""
	return build_gaussian(dim, scale).log_prob(rho)

def sample_gaussian(rng_key, dim, scale):
	"""Draw one length-`dim` vector of independent N(0, scale^2) entries using `rng_key`."""
	standard = jax.random.normal(rng_key, (dim,))
	return scale * standard



def initialize(dim, vdmode = 1, vdparams = None, rdparams = None, nbridges = 0, eps = 0.001, beta_0 = 0.9, mdparams = None, 
	eps_vec = False, trainable = ('vd',)):
	"""
	Build the flat parameter vector for the model.

	- dim is the dimension of the latent variable z.
	- vdmode selects the variational distribution and is passed to vd.initialize.
	  NOTE(review): older docs listed 'diag' / 'full', but the default is 1; confirm accepted values.
	- vdparams are the variational distribution parameters; if None, vd.initialize(dim, vdmode) is used.
	- rdparams are the reversal model parameters; if None, rd.initialize(dim) is used.
	- nbridges is the number of leapfrog steps (bridging densities).
	- eps is the leapfrog step size. compute_ratio multiplies by it directly, so it is NOT a log step size.
	- beta_0 is currently unused; it is kept only for backward compatibility.
	- mdparams are the momentum distribution parameters; if None, md.initialize(dim) is used.
	- eps_vec: if True, eps is broadcast to a per-dimension vector of length dim.
	- trainable is a collection naming which groups ('vd', 'eps', 'md', 'rd') are trained.
	  Every other group goes into the non-trainable dict, whether or not its value was given.

	Returns (params_flat, unflatten, params_fixed):
	- unflatten(params_flat) == (params_train, params_notrain);
	- params_fixed == (dim, vdmode, nbridges).
	"""
	params_train = {} # Has all trainable parameters
	params_notrain = {} # Non trainable parameters

	def _group(name):
		# Dict that stores parameter group `name`, depending on whether it is trained.
		return params_train if name in trainable else params_notrain

	# Variational distribution parameters
	_group('vd')['vd'] = vd.initialize(dim, vdmode) if vdparams is None else vdparams

	# Leapfrog step-size (eps_var is stored with eps; it is not read by compute_ratio)
	if eps_vec:
		eps = np.ones(dim) * eps
	_group('eps')['eps'] = eps
	_group('eps')['eps_var'] = 0.

	# Momentum distribution
	_group('md')['md'] = md.initialize(dim) if mdparams is None else mdparams

	# Reversal model: honor rdparams the same way as vdparams / mdparams
	# (it used to be silently ignored).
	_group('rd')['rd'] = rd.initialize(dim) if rdparams is None else rdparams

	# Other fixed parameters.
	# ravel_pytree flattens dicts in sorted-key order, so insertion order does not matter.
	params_fixed = (dim, vdmode, nbridges)
	params_flat, unflatten = ravel_pytree((params_train, params_notrain))
	return params_flat, unflatten, params_fixed


def compute_ratio(seed, params_flat, unflatten, params_fixed, log_prob):
	"""
	Negative log importance weight for a single sample.

	1. Draw z_0 from the variational distribution and rho_0 from N(0, I).
	2. Evolve (z, rho) with nbridges leapfrog steps on U(z) = -log_prob(z).
	3. Score the final pair.

	- seed: integer seed used to build the PRNG key.
	- params_flat, unflatten: flat parameters and their inverse, as returned by initialize.
	- params_fixed: (dim, vdmode, nbridges), as returned by initialize.
	- log_prob: target log-density of z (must be differentiable with jax.grad).

	Returns (-w, (z_T, None)), where
	  w = log_prob(z_T) + rd.log_prob(rho_T, params['rd'], z_T)
	      - vd.log_prob(z_0) - log N(rho_0; 0, I).
	"""
	params_train, params_notrain = unflatten(params_flat)
	# Non-trainable parameters must not receive gradients.
	params_notrain = jax.lax.stop_gradient(params_notrain)
	params = {**params_train, **params_notrain} # Gets all parameters in single place
	dim, vdmode, nbridges = params_fixed

	# Initial sample. Split first so that z and rho use independent keys;
	# consuming a key and then splitting it again is a documented JAX pitfall.
	key_z, key_rho = jax.random.split(jax.random.PRNGKey(seed))
	z = vd.sample_rep(key_z, vdmode, params['vd'])
	rho = sample_gaussian(key_rho, dim, scale = 1.)

	w = -vd.log_prob(vdmode, params['vd'], z)
	w = w - log_prob_gaussian(rho, dim, scale = 1.)

	# Evolve. nbridges comes from params_fixed and is a static Python int.
	if nbridges >= 1:
		eps = params['eps']

		def U(x):
			# Potential energy: negative target log-density.
			return -1. * log_prob(x)
		U_grad = jax.grad(U)

		def leapfrog(carry, _):
			# grad_z is U_grad(z) at the current z. The gradient computed at the
			# end of a step is reused at the start of the next step.
			z, rho, grad_z = carry
			rho = rho - eps * grad_z / 2.
			z = z + eps * rho
			grad_z = U_grad(z)
			rho = rho - eps * grad_z / 2.
			return (z, rho, grad_z), None

		(z, rho, _), _ = jax.lax.scan(leapfrog, (z, rho, U_grad(z)), None, length = nbridges)

	# Evaluate model at final sample
	w = w + log_prob(z)
	w = w + rd.log_prob(rho, params['rd'], z)
	return -1. * w, (z, None)



# Batched compute_ratio: vectorized over the leading axis of `seed` only;
# params_flat, unflatten, params_fixed and log_prob are shared across the batch.
compute_ratio_vec = jax.vmap(compute_ratio, (0, None, None, None, None))

@functools.partial(jax.jit, static_argnums = (2, 3, 4))
def compute_avg_ratio(seeds, params_flat, unflatten, params_fixed, log_prob):
	"""
	Mean of compute_ratio over a batch of seeds (jit-compiled).

	- seeds: array of integer seeds; compute_ratio_vec maps over its leading axis.
	- unflatten, params_fixed and log_prob are static jit arguments, so they must be hashable.

	Returns (mean, (mean, None)).
	NOTE(review): this (loss, aux) shape presumably feeds jax.value_and_grad(..., has_aux = True);
	confirm with callers.
	"""
	ratios, _ = compute_ratio_vec(seeds, params_flat, unflatten, params_fixed, log_prob)
	avg = ratios.mean()
	return avg, (avg, None)




