import jax.numpy as np
import numpyro.distributions as npdist
import jax
import variationaldist as vd
import bridge_dist as bd
import momdist as md
from jax.flatten_util import ravel_pytree
import functools


# By default (mgridref_y=None) the betas are equally spaced in (0, 1); passing or training mgridref_y warps the schedule.

def initialize(dim, vdmode = 1, vdparams = None, nbridges = 0, lfsteps = 1, eps = 0.001, eps_var = 0.0001, eta = 0.9, mdparams = None, ngridb = 32,
	mgridref_y = None, exact = False, mode_eps = 'single', brparams = None, brmode = 'brvd', trainable = ('vd',)):
	"""
	Build the flat parameter vector and the fixed configuration used by compute_ratio.

	- dim is the dimension of the sampling space (forwarded to vd/bd/md.initialize).
	- vdmode selects the variational distribution and is forwarded to the vd module.
	  NOTE(review): earlier notes list 'diag' or 'full', but the default is 1 — confirm valid values in variationaldist.
	- vdparams are the parameters of the variational distribution (None -> vd.initialize(dim, vdmode)).
	- nbridges is the number of intermediate bridging densities (0 -> no transitions are run).
	- lfsteps is the number of leapfrog steps per transition (1 -> Langevin, more than 1 -> HMC).
	- eps is the leapfrog step size itself: it multiplies the gradients directly (it is NOT a log step size).
	- eps_var is the slope of the step size in beta; only used when mode_eps == 'affine'.
	- eta is the under-damped coefficient (between 0 and 1); it is passed to md.sample.
	- mdparams are the parameters of the momentum distribution (None -> md.initialize(dim)).
	- ngridb is the number of increments of the beta grid, which is then interpolated. It is capped at
	  nbridges, and ignored when mgridref_y is given (then ngridb = len(mgridref_y) - 1).
	- mgridref_y are unnormalized increments of the beta schedule: compute_ratio normalizes their cumulative
	  sum and interpolates it at the bridge positions. None -> all ones, i.e. equally spaced betas.
	- exact: True runs exact AIS (MH-corrected transitions, potential-energy weights); False runs the
	  uncorrected method (momentum-density weights).
	- mode_eps can be 'single' (one step size for every bridge) or 'affine' (eps + beta * eps_var).
	- brparams are the extra parameters of the bridging densities (None -> bd.initialize(dim, brmode)).
	- brmode: 'brvd', 'brlinear' (in beta) or an affine mode (listed as 'braffinee' in earlier notes;
	  NOTE(review): confirm the exact string accepted by bridge_dist). 'br' is only trainable when brmode != 'brvd'.
	- trainable is a collection of the parameter groups to train: any of 'vd', 'br', 'eps' (which also
	  controls 'eps_var'), 'eta', 'md', 'mgridref_y'. Whether a group is trained depends only on this
	  argument, not on whether its value was passed explicitly.

	Returns (params_flat, unflatten, params_fixed):
	- params_flat is the ravel_pytree flattening of (params_train, params_notrain);
	- unflatten maps a flat vector back to (params_train, params_notrain);
	- params_fixed = (dim, vdmode, nbridges, lfsteps, exact, mode_eps, brmode).
	"""
	params_train = {} # Trainable parameters
	params_notrain = {} # Non-trainable parameters (compute_ratio stops their gradients)

	def _group(is_trainable):
		# Dict into which a parameter group is stored.
		return params_train if is_trainable else params_notrain

	# Variational distribution parameters
	if vdparams is None:
		vdparams = vd.initialize(dim, vdmode)
	_group('vd' in trainable)['vd'] = vdparams

	# Extra parameters for the bridging densities ('brvd' has nothing of its own to train)
	if brparams is None:
		brparams = bd.initialize(dim, brmode)
	_group('br' in trainable and brmode != 'brvd')['br'] = brparams

	# Leapfrog step size; eps_var only matters for mode_eps == 'affine'
	eps_group = _group('eps' in trainable)
	eps_group['eps'] = eps
	eps_group['eps_var'] = eps_var

	# Under-damped coefficient
	_group('eta' in trainable)['eta'] = eta

	# Momentum distribution
	if mdparams is None:
		mdparams = md.initialize(dim)
	_group('md' in trainable)['md'] = mdparams

	# Everything related to betas
	if mgridref_y is not None:
		ngridb = mgridref_y.shape[0] - 1
	else:
		# A grid finer than the number of bridges adds nothing
		ngridb = min(ngridb, nbridges)
		mgridref_y = np.ones(ngridb + 1) * 1.
	params_notrain['gridref_x'] = np.linspace(0, 1, ngridb + 2)
	# Bridge positions: interior points of a uniform grid (endpoints 0 and 1 excluded)
	params_notrain['target_x'] = np.linspace(0, 1, nbridges + 2)[1:-1]
	_group('mgridref_y' in trainable)['mgridref_y'] = mgridref_y

	# Other fixed parameters
	params_fixed = (dim, vdmode, nbridges, lfsteps, exact, mode_eps, brmode)
	params_flat, unflatten = ravel_pytree((params_train, params_notrain))
	return params_flat, unflatten, params_fixed


def compute_ratio(seed, params_flat, unflatten, params_fixed, log_prob):
	"""
	Run one bridged sampling trajectory from a single seed and return its negated log-weight.

	- seed: integer seed, turned into a key with jax.random.PRNGKey.
	- params_flat, unflatten, params_fixed: as returned by initialize.
	- log_prob: target log-density, called on a single sample z.

	Procedure:
	1. Draw z0 from the variational distribution; w starts at -log q(z0).
	2. If nbridges >= 1, run one transition per beta (md.sample momentum refresh, then leapfrog).
	   For each transition w gains either U(z_new) - U(z_old) (exact=True) or
	   log m(rho_new) - log m(rho) (exact=False).
	3. Add log_prob(z_final).

	Returns (-w, (z_final, rejections)); rejections counts MH rejections (always 0 unless exact).
	Gradients are stopped through the non-trainable parameters.
	Raises NotImplementedError (at trace time) for an unknown mode_eps.
	"""
	def get_eps(params, beta, mode_eps):
		# Step size for the bridge at annealing value beta.
		if mode_eps == 'single':
			return params['eps']
		elif mode_eps == 'affine':
			return params['eps'] + beta * params['eps_var']
		else:
			raise NotImplementedError('Mode eps %s not implemented.' % mode_eps)
	params_train, params_notrain = unflatten(params_flat)
	params_notrain = jax.lax.stop_gradient(params_notrain) # Fixed parameters get no gradient
	params = {**params_train, **params_notrain} # Gets all parameters in single place
	dim, vdmode, nbridges, lfsteps, exact, mode_eps, brmode = params_fixed

	if nbridges >= 1:
		# Piecewise-linear beta schedule: normalized cumulative sum of mgridref_y (starting at 0),
		# on the grid gridref_x, interpolated at the bridge positions target_x.
		gridref_y = np.cumsum(params['mgridref_y']) / np.sum(params['mgridref_y'])
		gridref_y = np.concatenate([np.array([0.]), gridref_y])
		betas = np.interp(params['target_x'], params['gridref_x'], gridref_y)

	rejections = 0 # Returned as-is when nbridges == 0
	# Initial sample
	rng_key = jax.random.PRNGKey(seed)
	z = vd.sample_rep(rng_key, vdmode, params['vd'])
	w = -vd.log_prob(vdmode, params['vd'], z)
	# w = -vd.log_prob(vdmode, jax.lax.stop_gradient(params['vd']), z)

	def evolve_bridges(aux, i):
		# One scan step: refresh momentum, run leapfrog at betas[i], update the log-weight accumulator w.
		# aux = (z, previous momentum, w, rng_key, rejection counter).
		def U(z, beta, vdmode, vdparams, brmode, brparams):
			# Potential energy of the bridging density at beta: a geometric mix of the target and bd.log_prob.
			return -1. * (beta * log_prob(z) + (1. - beta) * bd.log_prob(z, beta, vdmode, vdparams, brmode, brparams))
		z, rho_prev, w, rng_key, rejections = aux
		_, rng_key = jax.random.split(rng_key)
		# beta = params['betas'][i] # For prev implementation without interp nor trainable
		beta = betas[i]
		eps = get_eps(params, beta, mode_eps)
		# NOTE(review): this same rng_key is used by md.sample, passed to leapfrog (which splits it),
		# and carried to the next iteration (where it is split again). JAX's PRNG guidance is not to
		# reuse a key after passing it to a random function; splitting into separate subkeys would
		# avoid this but would change results for a given seed.
		rho = md.sample(rng_key, params['eta'], rho_prev, params['md'])
		z_prev = z
		z, rho_new, reject = leapfrog(z, rho, eps, lfsteps, beta, log_prob, params['vd'], vdmode, params['md'], exact, rng_key, U, 
			brmode, params['br'])
		rejections += reject
		if exact:
			# AIS increment: -log gamma_beta(z_new) + log gamma_beta(z_prev)
			w = w + U(z, beta, vdmode, params['vd'], brmode, params['br']) - U(z_prev, beta, vdmode, params['vd'], brmode, params['br'])
		else:
			# Uncorrected transitions: increment uses the momentum density before/after leapfrog
			w = w + md.log_prob(rho_new, params['md']) - md.log_prob(rho, params['md'])
		aux = (z, rho_new, w, rng_key, rejections)
		return aux, None

	# Evolve
	if nbridges >= 1:
		_, rng_key = jax.random.split(rng_key)
		# Initial momentum (no previous momentum, hence None)
		rho_prev = md.sample(rng_key, params['eta'], None, params['md'])
		aux = (z, rho_prev, w, rng_key, 0) # Last is counter for rejections
		aux = jax.lax.scan(evolve_bridges, aux, np.arange(nbridges))[0]
		z, _, w, _, rejections = aux

	# Evaluate model at final sample
	w = w + log_prob(z)
	return -1. * w, (z, rejections)


# Alternative wrapper around compute_ratio (returns the negated value twice); can be ignored
def compute_ratio_bis(seed, params_flat, unflatten, params_fixed, log_prob):
	"""
	Return the negation of compute_ratio's first output twice, as (value, aux).

	The auxiliary (z, rejections) output of compute_ratio is discarded.
	"""
	result = compute_ratio(seed, params_flat, unflatten, params_fixed, log_prob)
	negated = -result[0]
	return negated, negated


def leapfrog(z, rho, eps, lfsteps, beta, log_prob, vdparams, vdmode, mdparams, exact, rng_key, U, brmode, brparams):
	"""
	Run a leapfrog trajectory of lfsteps steps, optionally followed by a Metropolis-Hastings test.

	- z, rho: current position and momentum.
	- eps: step size (multiplies the gradients directly).
	- lfsteps: number of leapfrog steps.
	- beta: annealing value forwarded to U.
	- log_prob: not used here (U already closes over the target); kept for interface compatibility.
	- vdparams, vdmode, brmode, brparams: forwarded to U.
	- mdparams: momentum-distribution parameters; kinetic energy K(rho) = -md.log_prob(rho, mdparams).
	- exact: if True, accept or reject the proposal with an MH test on H = U + K. On rejection the
	  starting position and the negated starting momentum are returned.
	- rng_key: split once to draw the MH uniform; unused when exact is False.
	- U: potential energy with signature U(z, beta, vdmode, vdparams, brmode, brparams).

	Returns (z, rho, reject): reject is 1 if the MH test rejected and 0 otherwise (always 0 if not exact).
	"""
	def K(rho, mdparams):
		# Kinetic energy: negative log-density of the momentum distribution
		return -1. * md.log_prob(rho, mdparams)
	U_grad = jax.grad(U, 0)
	K_grad = jax.grad(K, 0)

	def grad_U(z):
		# Gradient of the potential at the current beta and bridge parameters
		return U_grad(z, beta, vdmode, vdparams, brmode, brparams)

	# Starting state, needed by the MH test and for reverting on rejection
	z_start, rho_start = z, rho

	def full_leap(aux, i):
		z, rho = aux
		rho = rho - eps * grad_U(z)
		z = z + eps * K_grad(rho, mdparams)
		return (z, rho), None

	# Half step for momentum and full step for z
	rho = rho - eps * grad_U(z) / 2.
	z = z + eps * K_grad(rho, mdparams)

	# Remaining lfsteps - 1 full momentum/position steps
	if lfsteps > 1:
		z, rho = jax.lax.scan(full_leap, (z, rho), np.arange(lfsteps - 1))[0]

	# Final half step for momentum
	rho = rho - eps * grad_U(z) / 2.

	if not exact:
		return z, rho, 0

	# Metropolis-Hastings correction on the Hamiltonian H = U + K
	def keep_proposal(vals):
		return vals[0], vals[1], 0

	def revert_to_start(vals):
		# Only negate momentum when the proposal is rejected
		return vals[2], -vals[3], 1

	rng_key, _ = jax.random.split(rng_key)
	U_prev = U(z_start, beta, vdmode, vdparams, brmode, brparams)
	K_prev = K(rho_start, mdparams)
	U_new = U(z, beta, vdmode, vdparams, brmode, brparams)
	K_new = K(rho, mdparams)
	accepted = np.log(jax.random.uniform(rng_key)) < U_prev - U_new + K_prev - K_new
	return jax.lax.cond(accepted, keep_proposal, revert_to_start, (z, rho, z_start, rho_start))

# Batched compute_ratio: maps over the seed (axis 0) and broadcasts the remaining four arguments.
compute_ratio_vec = jax.vmap(compute_ratio, in_axes = (0,) + (None,) * 4)

@functools.partial(jax.jit, static_argnums = (2, 3, 4))
def compute_avg_ratio(seeds, params_flat, unflatten, params_fixed, log_prob):
	"""
	Average compute_ratio over a batch of seeds.

	- seeds: integer seeds, mapped over axis 0 by compute_ratio_vec.
	- params_flat, unflatten, params_fixed: as returned by initialize.
	- log_prob: target log-density.
	unflatten, params_fixed and log_prob are static arguments under jit, so they must be hashable.

	Returns (mean_ratio, (mean_ratio, rejection_rate)). The first element is presumably used as the
	training objective. rejection_rate is the mean number of MH rejections per sample divided by
	nbridges; it is 0 (not NaN) when nbridges == 0.
	"""
	nbridges = params_fixed[2]
	ratios, (_, rejections) = compute_ratio_vec(seeds, params_flat, unflatten, params_fixed, log_prob)
	mean_ratio = ratios.mean()
	# With no bridges no MH steps run; avoid a 0 / 0 = NaN rejection rate
	rejection_rate = rejections.mean() / max(nbridges, 1)
	return mean_ratio, (mean_ratio, rejection_rate)




