from jax.lib import xla_bridge
platform = xla_bridge.get_backend().platform

# numpyro is needed on every backend; it is imported here (after the backend
# has been queried above) and pinned to the GPU only when one was detected.
import numpyro

if 'gpu' in platform:
	numpyro.set_platform('gpu')

import tensorflow_datasets as tfds
import jax
import jax.numpy as np
import aisboundingmachine as bm
import iwboundingmachine as iwbm
import opt
from timeit import default_timer as timer
import tensorflow as tf
from tqdm import tqdm
import argparse
import pickle
from jax.flatten_util import ravel_pytree
import tensorflow_datasets as tfds
from timeit import default_timer as timer
import sys

# Report the active XLA backend between two separator lines so it is easy to
# spot at the top of the run log.
_banner = '========================'
print(_banner)
print(xla_bridge.get_backend().platform)
print(_banner)



# Command-line options, one row per flag: (flag, type, default, help text).
# Rows are registered in the order written here, which is also the order in
# which they appear in the generated --help output.
_ARG_SPECS = (
	('-boundmode', str, 'AIS', 'What bounding machine to use.'),
	('-dset', str, 'mnist', 'mnist or fashion_mnist.'),
	('-bs', int, 64, 'Batch size.'),
	('-nbridges', int, 8, 'Number of bridging densities.'),
	('-k', int, 8, 'Number of samples for IW.'),
	('-lfsteps', int, 1, 'Number of leapfrog steps.'),
	# Author's note: for a bs of 64, 100000 iters gives 128 epochs.
	('-iters_base', int, 130000, 'Number of iterations for base VAE.'),
	('-iters_tune', int, 5000, 'Number of iterations for tuning with AAIS.'),
	('-lr', float, 0.0001, 'Learning rate.'),
	('-id', int, -1, 'ID.'),
	('-zdim', int, 64, 'Dimension of latent space.'),
	('-hsize', int, 450, 'Size of hidden layer.'),
	('-seed', int, 1, 'Random seed to use.'),
	# The integer flags below are used as booleans (see their help strings).
	('-run_cluster', int, 0, '1: true, 0: false.'),
	('-tune_md', int, 0, 'Tune moment distribution.'),
	('-tune_beta', int, 0, 'Tune betas.'),
	('-tune_decoder', int, 0, 'Tune decoder.'),
	('-tune_encoder', int, 0, 'Tune encoder.'),
	('-eps_vec', int, 0, '0: single eps, 1: vector eps.'),
	('-mode_eps', str, 'single', 'single or affine.'),
)

args_parser = argparse.ArgumentParser(description = 'Process arguments')
for _flag, _type, _default, _help in _ARG_SPECS:
	args_parser.add_argument(_flag, type = _type, default = _default, help = _help)

# Parsed options; read as attributes (info.dset, info.zdim, ...) below.
info = args_parser.parse_args()



# Input dimensionality handed to the model initializer
# (presumably flattened 28x28 images -- confirm against the data loader).
xdim = 28 * 28

# Seed both random-number sources (JAX and TensorFlow) from the same CLI seed.
rng_key = jax.random.PRNGKey(info.seed)
tf.random.set_seed(info.seed)
print('zdim', info.zdim)


# The output file name encodes the architecture tag, hidden size and latent size.
# Alternative tag kept from the original author (one hidden layer, relus):
# prefix = 'onehidden_%i_%i_' % (info.hsize, info.zdim)
prefix = 'iwae_%i_%i_' % (info.hsize, info.zdim) # This is IWAE, tanh activation

# Any dataset whose name contains 'letters' is saved under the fixed suffix
# 'letters'; every other dataset uses its own name as the suffix.
if 'letters' in info.dset:
	name = '%sparams_encoder_decoder_letters.pkl' % (prefix)
else:
	name = '%sparams_encoder_decoder_%s.pkl' % (prefix, info.dset)


# Only the encoder and decoder parameter groups are trained in this script;
# no bridging densities are requested (nbridges = 0).
trainable = ('encoder', 'decoder')
params_flat, unflatten, params_fixed = bm.initialize(
	info.zdim,
	xdim,
	hsize = info.hsize,
	trainable = trainable,
	nbridges = 0,
)

# Differentiate bm.compute_avg_ratio with respect to its positional argument 2.
# With has_aux = True the transformed function returns (gradient, aux) instead
# of the gradient alone.
_avg_ratio_grad = jax.grad(bm.compute_avg_ratio, 2, has_aux = True)

# Positional arguments 3 and 4 are static for jit: a change in either one
# triggers a recompilation.
grad_and_loss = jax.jit(_avg_ratio_grad, static_argnums = (3, 4))


print('Training base VAE')
start = timer()

# Staged training: opt.run is called once per entry of the schedule below, each
# time for info.iters_base iterations. The first stage starts from the freshly
# initialized params_flat; every later stage starts from the best_params
# returned by the stage before it and is told the best loss seen so far.
# NOTE(review): the same rng_key is handed to every stage. If opt.run draws its
# randomness from this key, all stages consume the same random stream. Deriving
# a per-stage key (e.g. with jax.random.split) would avoid that, but it is left
# unchanged here so that results stay reproducible -- confirm intent.
base_lr_schedule = (0.001, 0.0005, 0.0002, 0.0001, 0.00005)

best_params = params_flat
# Empty for the first stage so that opt.run falls back to its own default for
# best_loss_val, exactly as when the keyword is omitted.
stage_kwargs = {}
for stage_lr in base_lr_schedule:
	losses_base, diverged, best_loss_val, last_loss_val, best_params, _, _, _, _ = opt.run(
		stage_lr, best_params, unflatten, params_fixed, grad_and_loss, trainable,
		rng_key, info.iters_base, info, **stage_kwargs)
	stage_kwargs = {'best_loss_val': best_loss_val}

# Report the wall-clock time of the whole schedule once.
end = timer()
print('Base took', end - start)
print('Done training base VAE')

# Evaluate the best parameters on the test split and report the mean of the
# returned losses. Note that `np` is jax.numpy in this file.
print('Evaluating')
_test_losses = opt.evaluate(
	best_params, unflatten, params_fixed, grad_and_loss, rng_key, info,
	mode = 'test', iters = 20)
best_loss_test = np.array(_test_losses).mean()
print('Best loss test', info.dset, best_loss_test)

# unflatten returns a pair; only the first element (the trainable parameter
# groups, keyed by name) is needed here.
params_train, _ = unflatten(best_params)
params_encoder, params_decoder = params_train['encoder'], params_train['decoder']

# Persist the trained weights as an (encoder, decoder) tuple.
with open(name, 'wb') as f:
	pickle.dump((params_encoder, params_decoder), f)













