import jax.numpy as np
import jax
from jax.flatten_util import ravel_pytree
import functools
import encoder
import decoder


# For now, betas equally spaced between 0 and 1.

def initialize(zdim, xdim, hsize = 50, encoderparams = None, decoderparams = None, k = 10, trainable = ('vd',), seed = 1):
	"""Build the flat parameter vector and the fixed configuration of the model.

	Args:
		zdim: forwarded to `encoder.initialize` / `decoder.initialize`
			(presumably the latent dimension -- confirm against those modules).
		xdim: forwarded to `encoder.initialize` / `decoder.initialize`
			(presumably the data dimension -- confirm against those modules).
		hsize: forwarded to `encoder.initialize` / `decoder.initialize`.
		encoderparams: encoder parameters to use. When None, the freshly
			initialised ones returned by `encoder.initialize` are used.
		decoderparams: decoder parameters to use. When None, the freshly
			initialised ones returned by `decoder.initialize` are used.
		k: stored in the fixed configuration; `compute_bound` uses it as the
			number of log weights it draws per data point.
		trainable: container of names. 'encoder' and/or 'decoder' mark the
			corresponding parameters as trainable. With the default
			('vd',) neither name matches, so both parameter sets end up in
			the non-trainable group.
		seed: integer seed for the PRNG key used to initialise the networks.
			The default (1) reproduces the previously hard-coded behaviour.

	Returns:
		A tuple `(params_flat, unflatten, params_fixed)` where `params_flat`
		is the raveled `(params_train, params_notrain)` pair, `unflatten`
		maps it back to that pair, and `params_fixed` is
		`(zdim, xdim, hsize, k, encode_fn, decode_fn)`.
	"""
	params_train = {} # Has all trainable parameters
	params_notrain = {} # Non trainable parameters

	rng_key = jax.random.PRNGKey(seed)

	# Encoder parameters. The initialiser is always called, even when
	# encoderparams is supplied, because it also provides encode_fn.
	rng_key, rng_key_aux = jax.random.split(rng_key)
	encoderparams_init, encode_fn = encoder.initialize(rng_key_aux, xdim, hsize, zdim)
	if encoderparams is None:
		encoderparams = encoderparams_init
	if 'encoder' in trainable:
		params_train['encoder'] = encoderparams
	else:
		params_notrain['encoder'] = encoderparams

	# Decoder parameters. Same reasoning: decode_fn comes from the initialiser.
	rng_key, rng_key_aux = jax.random.split(rng_key)
	decoderparams_init, decode_fn = decoder.initialize(rng_key_aux, xdim, hsize, zdim)
	if decoderparams is None:
		decoderparams = decoderparams_init
	if 'decoder' in trainable:
		params_train['decoder'] = decoderparams
	else:
		params_notrain['decoder'] = decoderparams

	# Other fixed parameters - these are always fixed
	params_fixed = (zdim, xdim, hsize, k, encode_fn, decode_fn)
	# Both groups are flattened together; compute_log_weight separates them
	# again and blocks gradients through the non-trainable group.
	params_flat, unflatten = ravel_pytree((params_train, params_notrain))
	return params_flat, unflatten, params_fixed

def compute_log_weight(seed, x, params_flat, unflatten, params_fixed):
	"""Return the log importance weight of one latent sample for one data point.

	Args:
		seed: integer used to build the PRNG key for sampling z.
		x: a single data point; a leading axis of size one is added before
			it is passed to the encoder and decoder.
		params_flat: flat parameter vector produced by `initialize`.
		unflatten: function mapping `params_flat` back to the
			`(params_train, params_notrain)` pair.
		params_fixed: tuple `(zdim, xdim, hsize, k, encode_fn, decode_fn)`;
			only `zdim`, `encode_fn` and `decode_fn` are used here.

	Returns:
		`-log q(z) + log_prior(z) + log_prob(x | decoder output)`, where the
		log q term is evaluated with gradients blocked through the encoder's
		mean and scale.
	"""
	batch = x[None, :]
	trainable_part, frozen_part = unflatten(params_flat)
	# Parameters in the non-trainable group must not receive gradients.
	all_params = {**trainable_part, **jax.lax.stop_gradient(frozen_part)}
	zdim, _, _, _, encode_fn, decode_fn = params_fixed

	# Draw z from the encoder distribution for this data point.
	q_mean, q_scale = encode_fn(batch, all_params['encoder'], zdim)
	z = encoder.sample_rep(jax.random.PRNGKey(seed), q_mean, q_scale, zdim)

	# For drep: the density is evaluated at detached mean/scale, so the only
	# gradient path from this term to the encoder goes through z itself.
	detached_mean = jax.lax.stop_gradient(q_mean)
	detached_scale = jax.lax.stop_gradient(q_scale)
	log_q = encoder.log_prob(z, detached_mean, detached_scale)

	decoder_out = decode_fn(z, all_params['decoder'])
	log_prior = decoder.log_prior(z, zdim)
	log_lik = decoder.log_prob(batch, decoder_out)

	return (-log_q + log_prior) + log_lik

def compute_bound(seed, x, params_flat, unflatten, params_fixed):
	"""Compute the drep surrogate loss and the bound value for one data point.

	Args:
		seed: integer used to build the PRNG key from which the k per-sample
			seeds are drawn.
		x: a single data point, forwarded to `compute_log_weight`.
		params_flat: flat parameter vector produced by `initialize`.
		unflatten: function mapping `params_flat` back to the parameter pair.
		params_fixed: tuple `(zdim, xdim, hsize, k, encode_fn, decode_fn)`;
			only `k` is read here, the tuple is forwarded unchanged.

	Returns:
		A pair `(-loss_grad, -loss_loss)`:
		1- `loss_grad`: the loss to compute the drep estimator, i.e. the
		   squared (gradient-blocked) normalised weights times the log
		   weights, one term per sample.
		2- `loss_loss`: the loss to plot, i.e. the log of the mean weight
		   (tracking IWELBO), computed with a max-shift for stability.
	"""
	_, _, _, k, _, _ = params_fixed
	rng_key = jax.random.PRNGKey(seed)
	# One integer seed per sample; integer bounds for an integer sampler.
	seeds = jax.random.randint(rng_key, (k,), 1, 1000000)
	log_ws = jax.vmap(compute_log_weight, in_axes = (0, None, None, None, None))(seeds, x, params_flat, unflatten, params_fixed)

	# Shift by the maximum before exponentiating to avoid overflow; the
	# shifted weights and their sum are reused by both outputs.
	max_log_w = np.max(log_ws)
	ws_shifted = np.exp(log_ws - max_log_w)
	sum_ws_shifted = np.sum(ws_shifted)

	# Drep gradient loss
	ws_normalized = ws_shifted / sum_ws_shifted
	# NOTE(review): the per-sample terms are returned without summing over
	# the k samples; compute_avg_bound takes a mean over everything -- confirm
	# that the resulting 1/k scaling of the surrogate is intended.
	loss_grad = np.square(jax.lax.stop_gradient(ws_normalized)) * log_ws # For drep
	# loss_grad = jax.lax.stop_gradient(ws_normalized) * log_ws # For rep

	# Loss loss (tracking IWELBO): log(1/k) + max + log(sum(exp(shifted)))
	loss_loss = np.log(1. / k) + max_log_w + np.log(sum_ws_shifted)
	return -1. * loss_grad, -1. * loss_loss

# Batched version of compute_bound: maps over the leading axis of the first two
# arguments (seed and x), while params_flat, unflatten and params_fixed are
# shared across the batch.
compute_bound_vec = jax.vmap(compute_bound, in_axes = (0, 0, None, None, None))

@functools.partial(jax.jit, static_argnums = (3, 4))
def compute_avg_bound(seeds, xs, params_flat, unflatten, params_fixed):
	"""Average the per-example losses from `compute_bound_vec` over a batch.

	`unflatten` and `params_fixed` (positions 3 and 4) are marked static for
	jit, so they must be hashable.

	Args:
		seeds: one seed per example, mapped over together with `xs`.
		xs: batch of data points, mapped over the leading axis.
		params_flat: flat parameter vector shared by the whole batch.
		unflatten: function mapping `params_flat` back to the parameter pair.
		params_fixed: fixed configuration tuple produced by `initialize`.

	Returns:
		`(surrogate, (bound, None))`: the mean of the gradient surrogate
		terms, and an auxiliary pair holding the mean of the bound values
		(presumably consumed with `has_aux=True` -- confirm against caller).
	"""
	grad_terms, bound_values = compute_bound_vec(seeds, xs, params_flat, unflatten, params_fixed)
	aux = (np.mean(bound_values), None)
	return np.mean(grad_terms), aux






