import jax.numpy as np
import jax
import encoder
import decoder
import momdist as md
from jax.flatten_util import ravel_pytree
import functools


def initialize(zdim, xdim, hsize = 50, encoderparams = None, decoderparams = None, nbridges = 0, lfsteps = 1, eps = 0.0, eps_var = 0.0, eta = 0.9,
	mdparams = None, ngridb = 15, mgridref_y = None, epsvec = 0, mode_eps = 'single', trainable = ('encoder', 'decoder'), seed = 1):
	"""Build the flat parameter vector and the fixed configuration of the model.

	Each parameter group is placed in a "trainable" dict when its group name
	appears in `trainable`, otherwise in a "non-trainable" dict. The groups are
	'encoder', 'decoder', 'eps' (covers both eps and eps_var), 'eta', 'md' and
	'mbetas'. The two dicts are flattened together with `ravel_pytree`.

	Args:
		zdim, xdim, hsize: latent size, data size and hidden size forwarded to
			encoder.initialize / decoder.initialize.
		encoderparams, decoderparams: explicit network parameters; when None the
			freshly initialized ones are used.
		nbridges: number of annealing bridges; mbetas has nbridges + 1 entries.
		lfsteps: leapfrog steps per bridge.
		eps, eps_var: leapfrog step size; eps_var is only used in 'affine' mode.
		eta: under-damped coefficient passed to md.sample.
		mdparams: momentum-distribution parameters; md.initialize(zdim) when None.
		ngridb, mgridref_y, epsvec: accepted but not used by this function; kept
			so existing call sites keep working.
		mode_eps: 'single' or 'affine' (see compute_ratio).
		trainable: collection of trainable group names. A single string is
			treated as one group name rather than as a sequence of characters.
		seed: seed of the PRNG key used to draw the default encoder/decoder
			parameters. It defaults to 1, the value previously hard-coded.

	Returns:
		(params_flat, unflatten, params_fixed), where
		unflatten(params_flat) == (params_train, params_notrain) and
		params_fixed == (zdim, xdim, hsize, nbridges, lfsteps, mode_eps, encode_fn, decode_fn).
	"""
	# `'eps' in 'some string'` would be a substring test; normalize to a tuple.
	if isinstance(trainable, str):
		trainable = (trainable,)

	params_train = {} # Has all trainable parameters
	params_notrain = {} # Non trainable parameters

	def _assign(group, name, value):
		# Route `name` to the trainable or frozen dict according to its group.
		target = params_train if group in trainable else params_notrain
		target[name] = value

	rng_key = jax.random.PRNGKey(seed)

	# Encoder parameters (initialize is always called: encode_fn is needed).
	rng_key, rng_key_aux = jax.random.split(rng_key)
	encoderparams_init, encode_fn = encoder.initialize(rng_key_aux, xdim, hsize, zdim)
	_assign('encoder', 'encoder', encoderparams_init if encoderparams is None else encoderparams)

	# Decoder parameters (initialize is always called: decode_fn is needed).
	rng_key, rng_key_aux = jax.random.split(rng_key)
	decoderparams_init, decode_fn = decoder.initialize(rng_key_aux, xdim, hsize, zdim)
	_assign('decoder', 'decoder', decoderparams_init if decoderparams is None else decoderparams)

	# Leapfrog step-size, eps_var only used in 'affine' mode
	_assign('eps', 'eps', eps)
	_assign('eps', 'eps_var', eps_var)

	# Under-damped coefficient
	_assign('eta', 'eta', eta)

	# Momentum distribution
	_assign('md', 'md', md.initialize(zdim) if mdparams is None else mdparams)

	# Unnormalized annealing increments; compute_ratio turns them into betas.
	mbetas = np.ones(nbridges + 1) * 1.
	_assign('mbetas', 'mbetas', mbetas)

	# Other fixed parameters
	params_fixed = (zdim, xdim, hsize, nbridges, lfsteps, mode_eps, encode_fn, decode_fn)
	params_flat, unflatten = ravel_pytree((params_train, params_notrain))
	return params_flat, unflatten, params_fixed


def compute_ratio(seed, x, params_flat, unflatten, params_fixed):
	"""Return the negative log importance weight for a single observation x.

	Draws z from the encoder, then (when nbridges >= 1) runs `nbridges` bridges.
	Each bridge resamples the momentum with md.sample and then runs `leapfrog`
	on the bridged potential. The accumulated log weight is

		w = -encoder.log_prob(z0)
		    + sum_b [md.log_prob(rho_after_leapfrog_b) - md.log_prob(rho_before_leapfrog_b)]
		    + decoder.log_prior(z_final) + decoder.log_prob(x | z_final).

	Args:
		seed: integer seed for jax.random.PRNGKey (one per x when vmapped).
		x: one observation; a leading axis of size 1 is added before use.
		params_flat: flat parameter vector produced by `initialize`.
		unflatten: inverse of the flattening; returns (params_train, params_notrain).
		params_fixed: (zdim, xdim, hsize, nbridges, lfsteps, mode_eps, encode_fn, decode_fn).

	Returns:
		(-w, (x, w)), where x carries the added leading axis. The shape of w
		follows encoder/decoder log_prob (presumably scalar or (1,) - TODO confirm).

	Raises:
		NotImplementedError: if mode_eps is neither 'single' nor 'affine'
			(only reached when nbridges >= 1).
	"""
	def get_eps(params, beta, mode_eps):
		# Step size for the bridge at annealing level beta.
		if mode_eps == 'single':
			return params['eps']
		elif mode_eps == 'affine':
			return params['eps'] + beta * params['eps_var']
		else:
			raise NotImplementedError('Mode eps %s not implemented.' % mode_eps)

	x = x[None, :]
	params_train, params_notrain = unflatten(params_flat)
	# Frozen groups contribute to the value but receive no gradient.
	params_notrain = jax.lax.stop_gradient(params_notrain)
	params = {**params_train, **params_notrain} # Gets all parameters in single place
	zdim, xdim, hsize, nbridges, lfsteps, mode_eps, encode_fn, decode_fn = params_fixed

	# Normalized cumulative schedule; betas[-1] == 1 is the final target,
	# evaluated explicitly after the bridges.
	betas = np.cumsum(params['mbetas']) / np.sum(params['mbetas'])

	# Initial sample from the encoder and its negative log density.
	rng_key = jax.random.PRNGKey(seed)
	rng_key, rng_key_aux = jax.random.split(rng_key)
	enc_mean, enc_scale = encode_fn(x, params['encoder'], zdim)
	z = encoder.sample_rep(rng_key_aux, enc_mean, enc_scale, zdim)
	w = -encoder.log_prob(z, enc_mean, enc_scale)

	def U(z, x, enc_mean, enc_scale, beta, decoderparams):
		# Bridged potential:
		# -(beta * log p(z) + beta * log p(x|z) + (1 - beta) * log q(z|x)).
		params_out = decode_fn(z, decoderparams)
		out = beta * decoder.log_prior(z, zdim)
		out = out + beta * decoder.log_prob(x, params_out)
		out = out + (1. - beta) * encoder.log_prob(z, enc_mean, enc_scale)
		return -1. * out

	def evolve_bridges(aux, i):
		# One bridge: resample momentum, integrate, and add the momentum log-ratio.
		z, rho_prev, w, rng_key = aux
		rng_key, rng_key_aux = jax.random.split(rng_key)

		beta = betas[i]
		eps = get_eps(params, beta, mode_eps)
		rho = md.sample(rng_key_aux, params['eta'], rho_prev, params['md'])
		z, rho_new = leapfrog(z, rho, x, eps, lfsteps, beta, enc_mean, enc_scale, params['decoder'], params['md'], U)
		w = w + md.log_prob(rho_new, params['md']) - md.log_prob(rho, params['md'])
		# Only the carry is needed, so there are no per-step scan outputs.
		return (z, rho_new, w, rng_key), None

	# Evolve
	if nbridges >= 1:
		rng_key, rng_key_aux = jax.random.split(rng_key)
		rho_prev = md.sample(rng_key_aux, params['eta'], None, params['md'])
		aux = (z, rho_prev, w, rng_key)
		(z, _, w, _), _ = jax.lax.scan(evolve_bridges, aux, np.arange(nbridges))

	# Evaluate model at final sample
	params_out = decode_fn(z, params['decoder'])
	w = w + decoder.log_prior(z, zdim)
	w = w + decoder.log_prob(x, params_out)
	return -1. * w, (x, w)



def leapfrog(z, rho, x, eps, lfsteps, beta, enc_mean, enc_scale, decoderparams, mdparams, U):
	"""Integrate Hamiltonian dynamics with `lfsteps` leapfrog steps.

	The potential is U(z, x, enc_mean, enc_scale, beta, decoderparams), with
	every argument except z held fixed. The kinetic energy is
	-md.log_prob(rho, mdparams). The scheme is: half momentum kick, position
	drift, then (lfsteps - 1) rounds of full kick + drift, then a closing half
	kick. At least one step is always taken, so lfsteps < 1 behaves like 1.

	Returns:
		(z, rho): position and momentum after integration.
	"""
	grad_potential = jax.grad(U, 0)
	grad_kinetic = jax.grad(lambda momentum, md_params: -1. * md.log_prob(momentum, md_params), 0)

	def force(position):
		# dU/dz at `position`, with the bridge arguments bound.
		return grad_potential(position, x, enc_mean, enc_scale, beta, decoderparams)

	def velocity(momentum):
		# dK/drho at `momentum`.
		return grad_kinetic(momentum, mdparams)

	def kick_then_drift(state, _):
		pos, mom = state
		mom = mom - eps * force(pos)
		pos = pos + eps * velocity(mom)
		return (pos, mom), None

	# Opening half kick, then the first drift.
	rho = rho - eps * force(z) / 2.
	z = z + eps * velocity(rho)

	# Interior full steps.
	if lfsteps > 1:
		(z, rho), _ = jax.lax.scan(kick_then_drift, (z, rho), np.arange(lfsteps - 1))

	# Closing half kick.
	rho = rho - eps * force(z) / 2.
	return z, rho


# Batched compute_ratio: maps seeds and xs along axis 0 and broadcasts
# params_flat, unflatten and params_fixed. One sample z is drawn per x.
compute_ratio_vec = jax.vmap(
	compute_ratio,
	in_axes = (0, 0, None, None, None),
)


@functools.partial(jax.jit, static_argnums = (3, 4))
def compute_avg_ratio(seeds, xs, params_flat, unflatten, params_fixed):
	"""Average compute_ratio over a batch, using one z sample per data point.

	`unflatten` and `params_fixed` are static jit arguments, so they must be hashable.

	Returns:
		(mean_ratio, (mean_ratio, x_batch)), where x_batch is the per-point x
		returned by compute_ratio (each with its added leading axis).
	"""
	per_point, (x_batch, _) = compute_ratio_vec(seeds, xs, params_flat, unflatten, params_fixed)
	mean_ratio = per_point.mean()
	return mean_ratio, (mean_ratio, x_batch)


