import jax
import jax.numpy as np
import nn
from jax.nn import sigmoid
import numpyro.distributions as npdist
from jax.experimental import stax
from jax.experimental.stax import (BatchNorm, Conv, ConvTranspose, Dense, Flatten, Relu, Softplus, Sigmoid, Tanh)


# Decoder with a single hidden layer (softplus activation).
def initialize(rng_key, xdim, hsize, zdim):
	"""Build a one-hidden-layer decoder network and its initial parameters.

	The network is Dense(hsize) -> Softplus -> Dense(xdim), initialised for
	an input of shape (zdim,).

	Args:
		rng_key: JAX PRNG key passed to the stax initialiser.
		xdim: width of the output layer.
		hsize: width of the hidden layer.
		zdim: size of the latent input vector.

	Returns:
		A pair (params, decode), where params are the initial network
		parameters and decode(z, params) runs the network on z.

	Side effects:
		Prints 'One hidden' and the output shape reported by the initialiser.
	"""
	print('One hidden')
	layers = (Dense(hsize), Softplus, Dense(xdim))
	net_init, net_apply = stax.serial(*layers)
	output_shape, params = net_init(rng_key, (zdim,))
	print(output_shape)

	def decode(z, params):
		"""Run the decoder on z and return the result shaped (1, 28, 28, 1)."""
		# The target shape is fixed, so the network output must contain
		# exactly 28 * 28 = 784 elements for the reshape to succeed.
		return np.reshape(net_apply(params, z), (1, 28, 28, 1))

	return params, decode


# # Alternative decoder from the IWAE paper: two tanh hidden layers (currently disabled).
# def initialize(rng_key, xdim, hsize, zdim):
# 	print('Two hidden iwae')
# 	init_fun, feed_forward = stax.serial(Dense(hsize), Tanh, Dense(hsize), Tanh, Dense(xdim))
# 	out_shape, params = init_fun(rng_key, (zdim,))
# 	print(out_shape)
# 	def decode(z, params):
# 		out = feed_forward(params, z)
# 		return np.reshape(out, (1, 28, 28, 1))
# 	return params, decode

def log_prob(x, logits):
	"""Return the Bernoulli log-likelihood of x, summed over all elements.

	Args:
		x: observed values scored under the Bernoulli distribution.
		logits: logits parameterising the Bernoulli distribution.

	Returns:
		The sum of the element-wise log-probabilities.
	"""
	elementwise = npdist.Bernoulli(logits=logits).log_prob(x)
	return elementwise.sum()

def log_prior(z, zdim):
	"""Return the log-density of z under a zdim-dimensional standard normal.

	Args:
		z: latent value to score.
		zdim: number of latent dimensions; sets the size of the zero mean
			and unit scale vectors.

	Returns:
		The log-probability of z, with the last axis treated as a single
		event (Independent with one reinterpreted batch dimension).
	"""
	base = npdist.Normal(loc=np.zeros(zdim), scale=np.ones(zdim))
	prior = npdist.Independent(base, 1)
	return prior.log_prob(z)


