import jax.numpy as np
import jax
import encoder
import decoder
import momdist as md
from jax.flatten_util import ravel_pytree
import functools


def initialize(zdim, xdim, hsize = 50, encoderparams = None, decoderparams = None, nbridges = 200, lfsteps = 1, eps = 0.05, eta = 0.9):
	"""Assemble the flat parameter vector and the fixed settings of the sampler.

	The encoder and decoder are initialised only to obtain their apply
	functions; the freshly initialised parameters are discarded and the ones
	supplied by the caller are stored in the non-trainable group instead.

	Args:
		zdim, xdim, hsize: forwarded to encoder.initialize / decoder.initialize
			(zdim is also forwarded to md.initialize).
		encoderparams: encoder parameters, stored under 'encoder' (required).
		decoderparams: decoder parameters, stored under 'decoder' (required).
		nbridges: number of bridges; nbridges + 1 entries of 'mbetas' are created.
		lfsteps: number of leapfrog steps, stored in the fixed settings.
		eps: initial value of the trainable leapfrog step-size.
		eta: initial value of the trainable damping coefficient.

	Returns:
		(params_flat, unflatten, params_fixed) where params_flat / unflatten come
		from ravel_pytree applied to (trainable, non-trainable) and params_fixed
		is (zdim, xdim, hsize, nbridges, lfsteps, encode_fn, decode_fn).

	Raises:
		NotImplementedError: if encoderparams or decoderparams is None.
	"""
	key = jax.random.PRNGKey(1)

	# Encoder: keep the apply function, drop the initial parameters.
	_, encode_fn = encoder.initialize(key, xdim, hsize, zdim)
	if encoderparams is None:
		raise NotImplementedError('Should pass the encoder parameters.')

	# Decoder: same treatment as the encoder.
	_, decode_fn = decoder.initialize(key, xdim, hsize, zdim)
	if decoderparams is None:
		raise NotImplementedError('Should pass the decoder parameters.')

	# Parameters that compute_ratio wraps in stop_gradient.
	frozen = {'encoder': encoderparams, 'decoder': decoderparams}

	# Parameters that remain differentiable.
	trainable = {
		'eps': eps, # Leapfrog step-size
		'eta': eta, # Damping coefficient
		'md': md.initialize(zdim), # Momentum distribution
		'mbetas': np.ones(nbridges + 1) * 1., # Unnormalised bridge increments
	}

	params_flat, unflatten = ravel_pytree((trainable, frozen))
	params_fixed = (zdim, xdim, hsize, nbridges, lfsteps, encode_fn, decode_fn)
	return params_flat, unflatten, params_fixed


def compute_ratio(seed, x, params_flat, unflatten, params_fixed):
	"""Compute the log-weight for a single datapoint x and count rejected moves.

	Args:
		seed: integer used to build the PRNG key for this datapoint.
		x: one datapoint; a leading axis of size one is added before use.
		params_flat: flat parameter vector produced by initialize.
		unflatten: inverse of the flattening, returns (trainable, non-trainable).
		params_fixed: tuple (zdim, xdim, hsize, nbridges, lfsteps, encode_fn, decode_fn).

	Returns:
		(w, rejections): the accumulated log-weight and the number of rejected
		leapfrog proposals (0 when nbridges < 1, since no proposal is made).
	"""
	x = x[None, :]
	params_train, params_notrain = unflatten(params_flat)
	# No gradient flows into the encoder / decoder parameters.
	params_notrain = jax.lax.stop_gradient(params_notrain)
	params = {**params_train, **params_notrain} # Gets all parameters in single place
	zdim, xdim, hsize, nbridges, lfsteps, encode_fn, decode_fn = params_fixed

	# Normalised cumulative sum of mbetas; the last entry equals one.
	betas = np.cumsum(params['mbetas']) / np.sum(params['mbetas'])

	# Initial sample from the encoder; w starts at minus its log-density.
	rng_key = jax.random.PRNGKey(seed)
	rng_key, rng_key_aux = jax.random.split(rng_key)
	enc_mean, enc_scale = encode_fn(x, params['encoder'], zdim)
	z = encoder.sample_rep(rng_key_aux, enc_mean, enc_scale, zdim)
	w = -encoder.log_prob(z, enc_mean, enc_scale)

	def U(z, x, enc_mean, enc_scale, beta, decoderparams):
		# Negative log of the bridging density:
		# beta * (log prior + log likelihood) + (1 - beta) * log encoder density.
		params_out = decode_fn(z, decoderparams)
		out = beta * decoder.log_prior(z, zdim)
		out = out + beta * decoder.log_prob(x, params_out)
		out = out + (1. - beta) * encoder.log_prob(z, enc_mean, enc_scale)
		return -1. * out

	def evolve_bridges(aux, i):
		# One bridge: refresh momentum, run leapfrog with accept/reject, update w.
		z, rho_prev, w, rng_key, rejections = aux

		beta = betas[i]
		eps = params['eps']

		rng_key, rng_key_aux = jax.random.split(rng_key)
		rho = md.sample(rng_key_aux, params['eta'], rho_prev, params['md'])

		z_prev = z
		rng_key, rng_key_aux = jax.random.split(rng_key)
		z, rho_new, reject = leapfrog(z, rho, x, eps, lfsteps, beta, enc_mean, enc_scale, params['decoder'], params['md'], U, rng_key_aux)
		rejections += reject

		w = w + U(z, x, enc_mean, enc_scale, beta, params['decoder']) -  U(z_prev, x, enc_mean, enc_scale, beta, params['decoder'])
		aux = (z, rho_new, w, rng_key, rejections)
		return aux, None

	# Bound before the branch so the return below is valid even when
	# nbridges < 1 (previously this raised UnboundLocalError).
	rejections = 0

	# Evolve
	if nbridges >= 1:
		rng_key, rng_key_aux = jax.random.split(rng_key)
		rho_prev = md.sample(rng_key_aux, params['eta'], None, params['md'])
		aux = (z, rho_prev, w, rng_key, rejections) # Last is counter for rejections
		aux = jax.lax.scan(evolve_bridges, aux, np.arange(nbridges))[0]
		z, _, w, _, rejections = aux

	# Evaluate model at final sample
	params_out = decode_fn(z, params['decoder'])
	w = w + decoder.log_prior(z, zdim)
	w = w + decoder.log_prob(x, params_out)
	return w, rejections



def leapfrog(z, rho, x, eps, lfsteps, beta, enc_mean, enc_scale, decoderparams, mdparams, U, rng_key):
	"""Run lfsteps leapfrog steps from (z, rho) followed by an accept/reject test.

	Args:
		z, rho: starting position and momentum.
		x, enc_mean, enc_scale, beta, decoderparams: extra arguments passed to U.
		eps: leapfrog step-size.
		lfsteps: number of leapfrog steps (Python int; decides whether a scan runs).
		mdparams: parameters passed to md.log_prob.
		U: callable U(z, x, enc_mean, enc_scale, beta, decoderparams), differentiated
			with respect to its first argument.
		rng_key: PRNG key used for the accept/reject draw.

	Returns:
		(z, rho, reject): the proposal with reject = 0 when accepted, otherwise the
		starting position with negated starting momentum and reject = 1.
	"""
	def K(rho, mdparams):
		return -1. * md.log_prob(rho, mdparams)

	grad_U = jax.grad(U, 0)
	grad_K = jax.grad(K, 0)

	def dU(pos):
		return grad_U(pos, x, enc_mean, enc_scale, beta, decoderparams)

	def dK(mom):
		return grad_K(mom, mdparams)

	def kick_and_drift(state, _):
		# Full momentum update followed by a full position update.
		pos, mom = state
		mom = mom - eps * dU(pos)
		pos = pos + eps * dK(mom)
		return (pos, mom), None

	z_start, rho_start = z, rho

	# Opening half step for the momentum, then a full step for the position.
	rho = rho - eps * dU(z) / 2.
	z = z + eps * dK(rho)

	# Remaining lfsteps - 1 interleaved full steps.
	if lfsteps > 1:
		(z, rho), _ = jax.lax.scan(kick_and_drift, (z, rho), np.arange(lfsteps - 1))

	# Closing half step for the momentum.
	rho = rho - eps * dU(z) / 2.

	# Accept/reject using the second half of the split key.
	_, key_mh = jax.random.split(rng_key)
	u_old = U(z_start, x, enc_mean, enc_scale, beta, decoderparams)
	k_old = K(rho_start, mdparams)
	u_new = U(z, x, enc_mean, enc_scale, beta, decoderparams)
	k_new = K(rho, mdparams)
	accepted = np.log(jax.random.uniform(key_mh)) < u_old - u_new + k_old - k_new

	def keep_proposal(state):
		return (state[0], state[1], 0)

	def restore_start(state):
		# Only negate momentum when reject sample
		return (state[2], -state[3], 1)

	z, rho, reject = jax.lax.cond(accepted, keep_proposal, restore_start, (z, rho, z_start, rho_start))
	return z, rho, reject


# Batched compute_ratio: seeds and xs are mapped over their leading axis (one
# sample z per x); the parameters and fixed settings are shared across the batch.
compute_ratio_vec = jax.vmap(
	compute_ratio,
	in_axes = (0, 0, None, None, None),
)


@functools.partial(jax.jit, static_argnums = (3, 4))
def compute_avg_ratio(seeds, xs, params_flat, unflatten, params_fixed):
	"""Average compute_ratio_vec's outputs over the batch (one sample per x).

	unflatten and params_fixed are static arguments of the jitted function.

	Returns:
		(mean of the per-datapoint ratios, mean of the per-datapoint rejection counts).
	"""
	per_example = compute_ratio_vec(seeds, xs, params_flat, unflatten, params_fixed)
	log_ratios, reject_counts = per_example
	return np.mean(log_ratios), np.mean(reject_counts)


