import jax
import jax.numpy as np
import numpyro.distributions as npdist
from jax.experimental import stax
from jax.experimental.stax import (BatchNorm, Conv, Dense, Flatten, Relu, Softplus, Sigmoid, Tanh)


# Encoder with a single softplus hidden layer.
def initialize(rng_key, xdim, hsize, zdim):
	"""Build a one-hidden-layer encoder and its initial parameters.

	The network is ``Dense(hsize) -> Softplus -> Dense(2 * zdim)`` and is
	initialised for a flat input of length ``xdim``.

	Args:
		rng_key: JAX PRNG key handed to the stax initialiser.
		xdim: Length of the flattened input vector.
		hsize: Width of the hidden layer.
		zdim: Latent dimension; the final layer emits ``2 * zdim`` values.

	Returns:
		A ``(params, encode)`` pair. ``encode(x, params, zdim)`` returns the
		``(mean, scale)`` computed from ``x``.
	"""
	print('One hidden')
	layers = (Dense(hsize), Softplus, Dense(int(2 * zdim)))
	init_net, apply_net = stax.serial(*layers)
	# The network is initialised for a single flat vector (no batch axis);
	# only the parameters are kept, the reported output shape is not needed.
	_, params = init_net(rng_key, (xdim,))

	def encode(x, params, zdim):
		"""Flatten ``x``, run the network and split its output in two.

		The first ``zdim`` outputs are returned as the mean; the remaining
		outputs are exponentiated and returned as the scale.
		"""
		flat = np.ravel(x)
		raw = apply_net(params, flat)
		mean, log_scale = raw[:zdim], raw[zdim:]
		return mean, np.exp(log_scale)

	return params, encode


# # Alternative encoder with two tanh hidden layers, from the IWAE paper (commented out).
# def initialize(rng_key, xdim, hsize, zdim):
# 	print('Two hidden iwae')
# 	osize = int(2 * zdim)
# 	init_fun, feed_forward = stax.serial(Dense(hsize), Tanh, Dense(hsize), Tanh, Dense(osize))
# 	output_shape, params = init_fun(rng_key, (xdim,)) # (N, H, W, C)
# 	def encode(x, params, zdim):
# 		x = np.reshape(x, (-1,))
# 		out = feed_forward(params, x)
# 		mean = out[:zdim]
# 		scale = np.exp(out[zdim:])
# 		return mean, scale
# 	return params, encode

def sample_rep(rng_key, mean, scale, zdim):
	"""Draw a reparameterised sample ``scale * eps + mean``.

	Args:
		rng_key: JAX PRNG key used to draw the standard-normal noise.
		mean: Location term added to the scaled noise.
		scale: Multiplier applied to the noise.
		zdim: Length of the noise vector that is drawn.

	Returns:
		``scale * eps + mean`` where ``eps`` is standard-normal noise of
		shape ``(zdim,)``.
	"""
	noise = jax.random.normal(rng_key, (zdim,))
	sample = scale * noise + mean
	return sample

def log_prob(z, mean, scale):
	"""Return the log-density of ``z`` under a diagonal Normal.

	A ``Normal(mean, scale)`` is wrapped in ``Independent`` with one
	reinterpreted batch dimension, so the last axis is treated as a single
	event.

	Args:
		z: Point at which the density is evaluated.
		mean: Location of the Normal.
		scale: Scale of the Normal.

	Returns:
		The result of ``log_prob(z)`` on that distribution.
	"""
	base = npdist.Normal(loc=mean, scale=scale)
	diagonal = npdist.Independent(base, reinterpreted_batch_ndims=1)
	return diagonal.log_prob(z)
