import jax.numpy as np
import jax
import aisboundingmachine as bm
import iwboundingmachine as iwbm
from model_handler import load_model
import opt
from tqdm import tqdm
import argparse
import pickle
import sys
# import matplotlib.pyplot as plt


# Command-line interface. Every option keeps its original name, type and
# default; only help texts that were copy-pasted from unrelated options
# (-N_rej, -target_rejection, -tolerance_rejection, -eta) were corrected.
args_parser = argparse.ArgumentParser(description='Process arguments')
args_parser.add_argument('-boundmode', type=str, default='EAIS', help='What bounding machine to use.')
args_parser.add_argument('-N_eval', type=int, default=25, help='Number of samples to estimate final loss.')
args_parser.add_argument('-N_rej', type=int, default=50,
	help='Number of samples to estimate the rejection rate when tuning the step size.')
args_parser.add_argument('-iters_eval', type=int, default=20, help='Number of iters to estimate final loss.')
args_parser.add_argument('-model', type=str, default='log_sonar', help='Model to use.')
args_parser.add_argument('-N', type=int, default=5, help='Number of samples to estimate gradient at each step.')
args_parser.add_argument('-vdmode', type=int, default=1, help='Variational distribution: 1 is diagonal, 2 is full.')
args_parser.add_argument('-nbridges', type=int, default=10, help='Number of bridging densities.')
args_parser.add_argument('-nbridges_max_tune', type=int, default=64, help='Max tune.')
args_parser.add_argument('-lfsteps', type=int, default=1, help='Number of leapfrog steps.')
args_parser.add_argument('-iters', type=int, default=10000, help='Number of iterations.')
args_parser.add_argument('-lr', type=float, default=0.001, help='Learning rate.')
args_parser.add_argument('-id', type=int, default=-1, help='ID.')
args_parser.add_argument('-seed', type=int, default=1, help='Random seed to use.')
args_parser.add_argument('-target_rejection', type=float, default=0.25,
	help='Target rejection rate when tuning the step size.')
args_parser.add_argument('-tolerance_rejection', type=float, default=0.025,
	help='Accepted deviation from the target rejection rate.')
args_parser.add_argument('-eta', type=float, default=0.9,
	help='Eta parameter passed to the bounding machine.')
args_parser.add_argument('-run_cluster', type=int, default=0, help='1: true, 0: false.')
info = args_parser.parse_args()


# Load the target log-density and its dimensionality, then create the base PRNG key.
log_prob_model, dim = load_model(info.model)
print(info, dim)
rng_key = jax.random.PRNGKey(info.seed)


def tune_eps(vdparams_init, nbridges):
	"""Bisect the step size ``eps`` until the rejection rate is close to the target.

	The search starts from the bracket [0, 1]. At each step the rate reported
	by ``bm.compute_avg_ratio`` is measured at the bracket midpoint; the upper
	bound is lowered when the rate exceeds ``info.target_rejection`` and the
	lower bound is raised otherwise.

	Args:
		vdparams_init: Value forwarded as ``vdparams`` to ``bm.initialize``.
		nbridges: Value forwarded as ``nbridges`` to ``bm.initialize``.

	Returns:
		The first ``eps`` whose measured rate lies strictly within
		``info.tolerance_rejection`` of ``info.target_rejection``.

	Raises:
		NotImplementedError: If no such ``eps`` is found after 40 bisection
			steps (exception type kept for backward compatibility).
	"""
	print('Tuning epsilon for rejection %.3f' % info.target_rejection)
	max_steps = 40
	lbound = 0.
	ubound = 1.
	eps = (ubound + lbound) / 2.

	def _initialize(step_size):
		# Single place for the initialization arguments; only eps varies.
		return bm.initialize(dim = dim, vdmode = info.vdmode, nbridges = nbridges, vdparams = vdparams_init,
			eps = step_size, eta = info.eta, lfsteps = info.lfsteps, exact = True)

	# Only the unflatten callable from this first call is kept; it is reused
	# for every candidate eps below, exactly as before.
	_, unflatten, _ = _initialize(eps)

	# The key is fixed, so every bisection step scores eps on the same seeds.
	# They are loop-invariant and therefore built once instead of 40 times.
	rng_key = jax.random.PRNGKey(info.seed + 1)
	seeds = jax.random.randint(rng_key, (info.N_rej,), 1, 1e6)

	# Open acceptance interval around the target rate.
	rate_low = info.target_rejection - info.tolerance_rejection
	rate_high = info.target_rejection + info.tolerance_rejection

	for _step in range(max_steps):
		params_flat, _, params_fixed = _initialize(eps)
		_, (_, rrate) = bm.compute_avg_ratio(seeds, params_flat, unflatten, params_fixed, log_prob_model)
		print(eps, rrate)
		# Done
		if rate_low < rrate < rate_high:
			return eps
		# Else, shrink the bracket: too many rejections means eps is too large.
		if rrate > info.target_rejection:
			ubound = eps
		else:
			lbound = eps
		eps = (lbound + ubound) / 2.
	raise NotImplementedError('Did not converge to step size.')



def evaluate(vdparams_init, nbridges, eps):
	"""Average the loss reported by ``bm.compute_avg_ratio`` over several seed batches.

	Runs ``info.iters_eval`` rounds, each with ``info.N_eval`` freshly drawn
	integer seeds, prints the per-round second auxiliary output (collected as
	"rejections") and returns the mean of the per-round losses.

	Args:
		vdparams_init: Value forwarded as ``vdparams`` to ``bm.initialize``.
		nbridges: Value forwarded as ``nbridges`` to ``bm.initialize``.
		eps: Value forwarded as ``eps`` to ``bm.initialize``.

	Returns:
		Mean of the collected per-round losses.
	"""
	print('Evaluating %i' % nbridges)
	key = jax.random.PRNGKey(info.seed)
	flat, unflat, fixed = bm.initialize(
		dim = dim, vdmode = info.vdmode, nbridges = nbridges, vdparams = vdparams_init,
		eps = eps, eta = info.eta, lfsteps = info.lfsteps, exact = True)

	# Show a progress bar only when not running on the cluster.
	rounds = range(info.iters_eval)
	progress = tqdm(rounds) if info.run_cluster == 0 else rounds

	round_losses, round_rejs = [], []
	for _round in progress:
		# NOTE(review): the key used for randint is the same one that is split
		# on the next round; JAX guidance is to draw from a separate subkey.
		# Kept as is so evaluation numbers stay reproducible.
		key, _ = jax.random.split(key)
		batch_seeds = jax.random.randint(key, (info.N_eval,), 1, 1e6)
		_, (batch_loss, batch_rej) = bm.compute_avg_ratio(batch_seeds, flat, unflat, fixed, log_prob_model)
		round_rejs.append(batch_rej.item())
		round_losses.append(batch_loss.item())
	print(round_rejs)
	return np.mean(np.array(round_losses))



# Stage 1: train the initial variational distribution (no bridging densities).
rng_key = jax.random.PRNGKey(info.seed)
trainable = ('vd',)
params_flat, unflatten, params_fixed = bm.initialize(
	dim = dim,
	vdmode = info.vdmode,
	nbridges = 0,
	lfsteps = info.lfsteps,
	trainable = trainable)
# Gradient with respect to the flat parameters (argument 1); the remaining
# non-array arguments are marked static for jit.
grad_and_loss = jax.jit(
	jax.grad(bm.compute_avg_ratio, argnums = 1, has_aux = True),
	static_argnums = (2, 3, 4))

# NOTE(review): the first argument of opt.run is hard-coded (0.01, then 0.0001)
# and info.lr is never read in this script -- confirm that is intended.
losses, diverged, params_flat = opt.run(
	0.01, params_flat, unflatten, params_fixed, log_prob_model, grad_and_loss, trainable, rng_key,
	info.iters, 10, info.run_cluster, extra = False)
print(np.mean(np.array(losses[-500:])))

# Second pass with the smaller first argument, starting from the result above.
losses, diverged, params_flat = opt.run(
	0.0001, params_flat, unflatten, params_fixed, log_prob_model, grad_and_loss, trainable, rng_key,
	5000, 10, info.run_cluster, extra = True)

# Merge trainable and non-trainable parameters and keep the 'vd' entry.
params_train, params_notrain = unflatten(params_flat)
params = dict(params_train)
params.update(params_notrain)
vdparams_init = params['vd']
print('Done training initial VI')

# Baseline: mean of the last 500 losses of the second pass.
loss_base = np.mean(np.array(losses[-500:]))
print(loss_base)

# Stage 2: tune the step size with nbridges clamped to [8, 128].
# NOTE(review): info.nbridges_max_tune is not used for the upper bound.
nbridges = min(max(info.nbridges, 8), 128)
eps = tune_eps(vdparams_init, nbridges)

# Stage 3: evaluate with the requested (unclamped) nbridges and the tuned eps.
loss = evaluate(vdparams_init, info.nbridges, eps)
print(loss)

# Stage 4: on the cluster, persist the arguments, baseline and final loss.
if info.run_cluster == 1:
	dresults = './'
	name = '%i_%i_%s.pkl' % (info.id, info.nbridges, info.model)
	with open(dresults + name, 'wb') as f:
		pickle.dump((info, loss_base, loss), f)
	print('Done saving results.')










