import jax.numpy as np
import jax
import aisboundingmachine as bm
import hisboundingmachine as hbm
import hislboundingmachine as hrbm
import iwboundingmachine as iwbm
from model_handler import load_model
import opt
from tqdm import tqdm
import argparse
import pickle
import sys

# Command-line options for the experiment. The parsed namespace is exposed as
# `info` and read by the rest of the script.
args_parser = argparse.ArgumentParser(description='Process arguments')
# -boundmode is restricted to the modes the script dispatches on further down.
# Without `choices`, an unknown mode would only fail with a NameError after the
# whole plain VI stage had already been trained.
args_parser.add_argument('-boundmode', type=str, default='AIS',
	choices=('AIS', 'IW', 'HIS-es', 'HIS-ev', 'HISLR-es', 'HISLR-ev'),
	help='What bounding machine to use.')
args_parser.add_argument('-model', type=str, default='log_sonar', help='Model to use.')
args_parser.add_argument('-N', type=int, default=5, help='Number of samples to estimate gradient at each step.')
args_parser.add_argument('-k', type=int, default=10, help='Number of samples for IW.')
args_parser.add_argument('-vdmode', type=int, default=1, help='Variational distribution: 1 is diagonal.')
args_parser.add_argument('-nbridges', type=int, default=10, help='Number of bridging densities.')
args_parser.add_argument('-lfsteps', type=int, default=1, help='Number of leapfrog steps.')
# -iters drives the initial plain VI run, -iters_tune the optimization of the
# selected bound; the help texts now say so instead of being identical.
args_parser.add_argument('-iters', type=int, default=5000,
	help='Number of iterations for the initial plain VI training.')
args_parser.add_argument('-iters_tune', type=int, default=5000,
	help='Number of iterations for optimizing the selected bound.')
args_parser.add_argument('-lr', type=float, default=0.001, help='Learning rate.')
args_parser.add_argument('-id', type=int, default=-1, help='ID.')
args_parser.add_argument('-seed', type=int, default=1, help='Random seed to use.')
args_parser.add_argument('-run_cluster', type=int, default=0, help='1: true, 0: false.')
args_parser.add_argument('-mode_eps', type=str, default='single', help='Can use single or affine.')
args_parser.add_argument('-brmode', type=str, default='brvd', help='Can use vd or linear.')
info = args_parser.parse_args()


# Load the target log-density together with its dimensionality, and build the
# PRNG key from the user-supplied seed.
log_prob_model, dim = load_model(info.model)
print(info, dim)
rng_key = jax.random.PRNGKey(info.seed)


# Stage 1: train a plain variational distribution.
# Only the 'vd' parameters are trained, with vdmode fixed to 1 and no bridging
# densities, regardless of the -vdmode / -nbridges flags.
trainable = ('vd',)
params_flat_VI, unflatten_VI, params_fixed_VI = bm.initialize(
	dim=dim,
	vdmode=1,
	nbridges=0,
	trainable=trainable,
)
# Differentiate w.r.t. positional argument 1; arguments 2-4 are static for jit.
grad_and_loss = jax.jit(
	jax.grad(bm.compute_avg_ratio, argnums=1, has_aux=True),
	static_argnums=(2, 3, 4),
)
# The literals 0.01 and 10 sit in the positions where the later opt.run call
# passes info.lr and info.N.
losses_VI, diverged, params_flat_VI = opt.run(
	0.01,
	params_flat_VI,
	unflatten_VI,
	params_fixed_VI,
	log_prob_model,
	grad_and_loss,
	trainable,
	rng_key,
	info.iters,
	10,
	info.run_cluster,
)
# The trained 'vd' entry is what the bounding machines are initialised from.
vdparams_init = unflatten_VI(params_flat_VI)[0]['vd']
# Report the mean over the last (up to) 500 recorded losses.
final_vi_losses = np.array(losses_VI[-500:])
print(np.mean(final_vi_losses))

print('Done training plain VI.')


# Stage 2 setup: build the parameters and the jitted gradient function for the
# requested bound. Every machine is initialised from the plain VI solution
# (vdparams_init). A boundmode matching none of the branches assigns nothing.
mode = info.boundmode

if mode in ('HIS-es', 'HIS-ev'):
	# The two HIS variants differ only in eps_vec: False for '-es', True for '-ev'.
	trainable = ('vd', 'eps', 'beta_0')
	params_flat, unflatten, params_fixed = hbm.initialize(
		dim=dim,
		vdmode=info.vdmode,
		vdparams=vdparams_init,
		nbridges=info.nbridges,
		trainable=trainable,
		eps_vec=mode.endswith('ev'),
	)
	grad_and_loss = jax.jit(
		jax.grad(hbm.compute_avg_ratio, argnums=1, has_aux=True),
		static_argnums=(2, 3, 4),
	)
elif mode in ('HISLR-es', 'HISLR-ev'):
	# Same '-es' / '-ev' split as above, using the hislboundingmachine module.
	trainable = ('vd', 'eps', 'rd')
	params_flat, unflatten, params_fixed = hrbm.initialize(
		dim=dim,
		vdmode=info.vdmode,
		vdparams=vdparams_init,
		nbridges=info.nbridges,
		trainable=trainable,
		eps_vec=mode.endswith('ev'),
	)
	grad_and_loss = jax.jit(
		jax.grad(hrbm.compute_avg_ratio, argnums=1, has_aux=True),
		static_argnums=(2, 3, 4),
	)
elif mode == 'AIS':
	# AIS is the only mode that consumes -lfsteps, -mode_eps and -brmode.
	trainable = ('vd', 'eps', 'eta')
	params_flat, unflatten, params_fixed = bm.initialize(
		dim=dim,
		vdmode=info.vdmode,
		vdparams=vdparams_init,
		nbridges=info.nbridges,
		lfsteps=info.lfsteps,
		trainable=trainable,
		mode_eps=info.mode_eps,
		brmode=info.brmode,
	)
	grad_and_loss = jax.jit(
		jax.grad(bm.compute_avg_ratio, argnums=1, has_aux=True),
		static_argnums=(2, 3, 4),
	)
elif mode == 'IW':
	# IW trains only 'vd' and takes k instead of nbridges.
	trainable = ('vd',)
	params_flat, unflatten, params_fixed = iwbm.initialize(
		dim=dim,
		vdmode=info.vdmode,
		vdparams=vdparams_init,
		k=info.k,
		trainable=trainable,
	)
	grad_and_loss = jax.jit(
		jax.grad(iwbm.compute_avg_bound, argnums=1, has_aux=True),
		static_argnums=(2, 3, 4),
	)


# Stage 2: optimise the selected bound with the user-supplied learning rate,
# iteration budget (-iters_tune) and per-step sample count (-N).
# NOTE(review): rng_key is the same key already passed to the plain VI run;
# confirm that reusing it here is intended.
losses, diverged, params_flat = opt.run(
	info.lr,
	params_flat,
	unflatten,
	params_fixed,
	log_prob_model,
	grad_and_loss,
	trainable,
	rng_key,
	info.iters_tune,
	info.N,
	info.run_cluster,
)


# Report the mean over the last (up to) 500 recorded losses.
final_losses = np.array(losses[-500:])
print(np.mean(final_losses))
print('Done with optimization.')

# Cluster runs pickle (info, losses, diverged) to the results directory;
# any other run just prints the mean of the last (up to) 500 losses.
if info.run_cluster != 0:
	dresults = '/mnt/nfs/work1/domke/tgeffner/results_ais/results/'
	# File name encodes run id, bound mode, k, number of bridges and leapfrog steps.
	name = '%i_%s_%i_%i_%i.pkl' % (info.id, info.boundmode, info.k, info.nbridges, info.lfsteps)
	results_path = dresults + name
	with open(results_path, 'wb') as f:
		pickle.dump((info, losses, diverged), f)
	print('Done saving results.')
else:
	print(np.mean(np.array(losses[-500:])))
	# Optional quick look at the loss curve:
	# import matplotlib.pyplot as plt
	# plt.figure()
	# plt.plot(losses)
	# plt.show()





