from jax.lib import xla_bridge
# Look up the active XLA backend first; numpyro is imported only afterwards.
platform = xla_bridge.get_backend().platform

# numpyro is needed on every backend; it is switched to the GPU platform only
# when the backend name contains 'gpu'.
import numpyro
if 'gpu' in platform:
	numpyro.set_platform('gpu')
import tensorflow_datasets as tfds
import jax
import jax.numpy as np
import aisboundingmachine as bm
import iwboundingmachine as iwbm
import eaisboundingmachine as ebm
import opt
from timeit import default_timer as timer
import tensorflow as tf
from tqdm import tqdm
import argparse
import pickle
from jax.flatten_util import ravel_pytree
import tensorflow_datasets as tfds
from timeit import default_timer as timer
import sys
import data_manager as dm

# Report the active XLA backend, framed by two separator lines (one item per line).
print('========================', xla_bridge.get_backend().platform, '========================', sep = '\n')



# Command-line interface of the script. All options use a single leading dash and
# the parsed namespace is stored in the module-level `info`, which the rest of the
# script (and the helpers it is passed to) reads.
args_parser = argparse.ArgumentParser(description = 'Process arguments')

# --- Tuning stage -------------------------------------------------------------
# Only 'AIS' and 'IW' are handled further down in this script, so any other value
# is rejected here instead of failing later with an undefined name.
args_parser.add_argument('-boundmode', type = str, default = 'AIS', choices = ['AIS', 'IW'], help = 'What bounding machine to use.')
args_parser.add_argument('-dset', type = str, default = 'mnist', help = 'mnist or fashion_mnist.')
args_parser.add_argument('-bs', type = int, default = 100, help = 'Batch size.')
args_parser.add_argument('-nbridges', type = int, default = 8, help = 'Number of bridging densities.')
args_parser.add_argument('-k', type = int, default = 8, help = 'Number of samples for IW.')
# NOTE(review): help text duplicates -k, and -k_eval is not read directly in this script.
args_parser.add_argument('-k_eval', type = int, default = 8, help = 'Number of samples for IW.')
args_parser.add_argument('-lfsteps', type = int, default = 1, help = 'Number of leapfrog steps.')
args_parser.add_argument('-iters_base', type = int, default = 60000, help = 'Number of iterations for base VAE.')
args_parser.add_argument('-iters_tune', type = int, default = 80000, help = 'Number of iterations for tuning with AAIS.')
args_parser.add_argument('-lr', type = float, default = 0.0001, help = 'Learning rate.')
args_parser.add_argument('-id', type = int, default = -1, help = 'ID.')
args_parser.add_argument('-zdim', type = int, default = 64, help = 'Dimension of latent space.')
args_parser.add_argument('-hsize', type = int, default = 450, help = 'Size of hidden layer.')
args_parser.add_argument('-seed', type = int, default = 1, help = 'Random seed to use.')
args_parser.add_argument('-run_cluster', type = int, default = 0, help = '1: true, 0: false.')
# The -tune_* switches are compared against 1 when the trainable groups are selected.
args_parser.add_argument('-tune_md', type = int, default = 0, help = 'Tune moment distribution.')
args_parser.add_argument('-tune_beta', type = int, default = 0, help = 'Tune betas.')
args_parser.add_argument('-tune_decoder', type = int, default = 0, help = 'Tune decoder.')
args_parser.add_argument('-tune_encoder', type = int, default = 0, help = 'Tune encoder.')
args_parser.add_argument('-eps_vec', type = int, default = 0, help = '0: single eps, 1: vector eps.')
args_parser.add_argument('-mode_eps', type = str, default = 'single', help = 'single or affine.')

# --- Final evaluation stage ---------------------------------------------------
# These values configure the evaluation bounding machine built after tuning.
args_parser.add_argument('-nbridges_eval', type = int, default = 2000, help = 'Number of bridging densities for evaluation.')
args_parser.add_argument('-lfsteps_eval', type = int, default = 16, help = 'Number of leapfrog steps for evaluation.')
args_parser.add_argument('-iters_eval', type = int, default = 3, help = 'Number of passes of dset.')
args_parser.add_argument('-eps_eval', type = float, default = 0.05, help = 'Integrators step size.')
args_parser.add_argument('-eta_eval', type = float, default = 0.8, help = 'Under-damping coefficient.')
info = args_parser.parse_args()


def evaluate(params_flat, unflatten, params_fixed, ratio_and_rej, rng_key, info, mode = 'val', iters = 15):
	"""Run `ratio_and_rej` over a dataset split and collect the per-batch results.

	Each pass walks the split in consecutive batches of 100 examples (a trailing
	partial batch is skipped). Every batch is binarized with a fresh PRNG key and
	`ratio_and_rej` is called with one integer seed per example.

	Args:
		params_flat, unflatten, params_fixed: forwarded unchanged to `ratio_and_rej`.
		ratio_and_rej: callable invoked as
			`ratio_and_rej(seeds, x, params_flat, unflatten, params_fixed)`; it must
			return two values that each support `.item()`.
		rng_key: JAX PRNG key; it is split twice per batch.
		info: parsed arguments; `info.dset` and `info.nbridges_eval` are read here.
		mode: split name handed to `dm.load_dataset`.
		iters: number of full passes over the split.

	Returns:
		`(losses, rejections)`: two lists with one scalar per batch per pass.
	"""
	# Load data
	ds = dm.load_dataset(info.dset, mode = mode)
	bs = 100
	numbatches = ds.shape[0] // bs

	losses = []
	rejections = []
	for i in range(iters):
		print(i, iters)
		for j in range(numbatches):
			# Load batch j and binarize it with its own key.
			x = ds[j * bs : (j + 1) * bs]
			rng_key, rng_key_aux = jax.random.split(rng_key)
			x = dm.binarize(rng_key_aux, x)

			# One integer seed in [1, 100000) per example; an int bound is used
			# instead of the float literal 1e5.
			rng_key, rng_key_aux = jax.random.split(rng_key)
			seeds = jax.random.randint(rng_key_aux, (x.shape[0],), 1, 100000)
			loss, rej = ratio_and_rej(seeds, x, params_flat, unflatten, params_fixed)
			losses.append(loss.item())
			rejections.append(rej.item())

		# Running summary after each pass. The rejection figure is normalised by
		# nbridges_eval, the same divisor the caller applies to the returned
		# rejections, so the progress output and the final number agree.
		print(np.array(losses).mean(), np.array(rejections).mean() / info.nbridges_eval)
	return losses, rejections



# Seed TensorFlow and JAX from the same -seed argument.
tf.random.set_seed(info.seed)
rng_key = jax.random.PRNGKey(info.seed)

# Input dimension handed to the bounding machines (presumably flattened 28 x 28 images).
xdim = 28 * 28
print('zdim', info.zdim)



print('Loading base VAE')
dir_params = './params_saved_base_all/'

# The checkpoint prefix encodes an architecture tag, the hidden size and the latent size.
# prefix = '%sonehidden_%i_%i_' % (dir_params, info.hsize, info.zdim) # This is for onehidden
prefix = f'{dir_params}iwae_{info.hsize:d}_{info.zdim:d}_' # This is IWAE, tanh activation

# Any dataset name containing 'letters' maps to the single '..._letters.pkl' file;
# every other dataset uses its own name in the file name.
name = prefix + 'params_encoder_decoder_' + ('letters' if 'letters' in info.dset else info.dset) + '.pkl'

print(name)
# The pickle holds a pair: encoder parameters first, decoder parameters second.
with open(name, 'rb') as f:
	encoderparams, decoderparams = pickle.load(f)
print(len(encoderparams), len(decoderparams))

print('Done loading')

# Build the tuple of parameter-group names that tuning may update. A group is
# included when its -tune_* switch equals 1; the order of the names is fixed.
if info.boundmode == 'IW':
	# IW: nothing is trainable unless the encoder and/or decoder switch is set.
	trainable = tuple(
		group
		for group, flag in (('encoder', info.tune_encoder), ('decoder', info.tune_decoder))
		if flag == 1
	)

if info.boundmode == 'AIS':
	# AIS: 'eps' and 'eta' are always trainable; the remaining groups are optional.
	trainable = ('eps', 'eta') + tuple(
		group
		for group, flag in (
			('md', info.tune_md),
			('mbetas', info.tune_beta),
			('encoder', info.tune_encoder),
			('decoder', info.tune_decoder),
		)
		if flag == 1
	)


# Create the bounding machine selected by -boundmode together with its jitted
# gradient function. In both cases jax.grad differentiates with respect to
# positional argument 2 (presumably the flat parameter vector - confirm against
# bm / iwbm), returns auxiliary output, and arguments 3 and 4 are static for jit.
if info.boundmode == 'AIS':
	print('Tuning AIS %i' % (info.nbridges))
	params_flat, unflatten, params_fixed = bm.initialize(info.zdim, xdim, encoderparams = encoderparams, decoderparams = decoderparams, 
		hsize = info.hsize, nbridges = info.nbridges, epsvec = info.eps_vec, trainable = trainable, mode_eps = info.mode_eps)
	grad_and_loss = jax.jit(jax.grad(bm.compute_avg_ratio, 2, has_aux = True), static_argnums = (3, 4))
elif info.boundmode == 'IW':
	print('Tuning IW %i' % (info.k))
	params_flat, unflatten, params_fixed = iwbm.initialize(info.zdim, xdim, encoderparams = encoderparams, decoderparams = decoderparams, 
		hsize = info.hsize, k = info.k, trainable = trainable)
	grad_and_loss = jax.jit(jax.grad(iwbm.compute_avg_bound, 2, has_aux = True), static_argnums = (3, 4))
else:
	# Without this branch an unknown mode leaves params_flat / grad_and_loss
	# undefined and the script only fails later with a NameError.
	raise ValueError("Unsupported -boundmode %r; expected 'AIS' or 'IW'." % (info.boundmode,))

# Cluster runs (-run_cluster 1) require a GPU backend. When none is active the
# script reports it through the shell, stdout and stderr, and stops before tuning.
platform = xla_bridge.get_backend().platform
if info.run_cluster == 1 and 'gpu' not in platform:
	import os
	os.system('echo No gpu...')
	print('Not in gpu...')
	print('Not in gpu...', file = sys.stderr)
	# sys.exit() instead of the site-provided exit() helper, which is meant for
	# interactive sessions and is missing under `python -S`. The exit status is
	# unchanged (no argument, i.e. 0).
	sys.exit()


# Run the tuning loop and measure its wall-clock duration. opt.run returns nine
# values; they are unpacked in the order it yields them.
start = timer()
(losses_tune, diverged_tune, best_loss_val, last_loss_val, best_params,
	last_params, best_epoch_val, epss, etas) = opt.run(
	info.lr, params_flat, unflatten, params_fixed, grad_and_loss, trainable, rng_key, info.iters_tune, info)
end = timer()
print('Tune took', end - start)
print('Done Training')

# Evaluate `best_params` on the test split (50 passes) and reduce the result to
# its mean, which is what gets stored and printed.
best_loss_test = np.mean(np.array(
	opt.evaluate(best_params, unflatten, params_fixed, grad_and_loss, rng_key, info, mode = 'test', iters = 50)))
print(best_loss_test)

print('Done evaluating')



########### Tuning and the test bound are done. Next: evaluation with the `ebm` machine.
# The PRNG key is rebuilt from the seed, so this stage starts from the same key
# as the beginning of the script.
rng_key = jax.random.PRNGKey(info.seed)
#### Load parameters
dir_params = './'

# Encoder: re-read the base checkpoint and keep only its first element; the
# second element of the pair is discarded.
print('Loading encoder')
print(name)
with open(name, 'rb') as f:
	encoderparams, _ = pickle.load(f)

# Decoder: taken from the best tuned parameters.
# NOTE(review): assumes unflatten(best_params)[0] has a 'decoder' entry - confirm
# this also holds when -tune_decoder is 0.
print('Loading decoder')
decoderparams = unflatten(best_params)[0]['decoder']
print('Done loading')

# Evaluation bounding machine, configured from the *_eval arguments. This rebinds
# params_flat, unflatten and params_fixed.
params_flat, unflatten, params_fixed = ebm.initialize(
	info.zdim, xdim,
	encoderparams = encoderparams, decoderparams = decoderparams,
	hsize = info.hsize,
	nbridges = info.nbridges_eval, lfsteps = info.lfsteps_eval,
	eps = info.eps_eval, eta = info.eta_eval)
ratio_and_rej = jax.jit(ebm.compute_avg_ratio, static_argnums = (3, 4))

# Sweep the test split, then reduce the per-batch lists to means. The rejection
# mean is additionally divided by the number of evaluation bridges.
loss_test_EAIS, rej_test_EAIS = evaluate(
	params_flat, unflatten, params_fixed, ratio_and_rej, rng_key, info,
	mode = 'test', iters = info.iters_eval)
loss_test_EAIS = np.array(loss_test_EAIS).mean()
rej_test_EAIS = np.array(rej_test_EAIS).mean() / info.nbridges_eval
print('test loss', loss_test_EAIS)
print('test rr', rej_test_EAIS)



if info.run_cluster == 0:
	# Local run: plot the tuning curves instead of writing results to disk.
	import matplotlib.pyplot as plt

	def smooth(y, n = 300):
		"""Return the length-`n` moving average of `y` ('valid' convolution)."""
		window = np.ones(n)
		return np.convolve(y, window, mode = 'valid') / n

	# Three stacked axes: raw and smoothed tuning losses, then epss, then etas.
	fig, (ax1, ax2, ax3) = plt.subplots(3)
	ax1.plot(losses_tune)
	ax1.plot(smooth(np.array(losses_tune)))
	ax2.plot(epss)
	ax3.plot(etas)
	plt.show()
else:
	# Cluster run: pickle the arguments and all collected metrics into the
	# current directory; the file name is built from id, mode, bridges and steps.
	dresults = './'
	name = f'{info.id:d}_{info.boundmode}_{info.nbridges:d}_{info.lfsteps:d}.pkl'
	results = (info, losses_tune, best_loss_test, best_loss_val, best_epoch_val,
		diverged_tune, loss_test_EAIS, rej_test_EAIS, epss, etas)
	with open(dresults + name, 'wb') as f:
		pickle.dump(results, f)
	print('Done saving results.')












