import argparse
import os
import pickle
import sys

import jax
import jax.numpy as np
from tqdm import tqdm

import aisboundingmachine as bm
import iwboundingmachine as iwbm
from model_handler import load_model
import opt
# import matplotlib.pyplot as plt


# Command-line options as (flag, type, default, help) rows. They are registered
# in exactly this order, which fixes the layout of `-h` and of the Namespace
# printed below.
_CLI_OPTIONS = (
	('-boundmode', str, 'AAIS', 'What bounding machine to use.'),
	('-N_eval', int, 25, 'Number of samples to estimate final loss.'),
	('-iters_eval', int, 20, 'Number of iters to estimate final loss.'),
	('-model', str, 'log_sonar', 'Model to use.'),
	('-N', int, 5, 'Number of samples to estimate gradient at each step.'),
	('-iters_tune', int, 5000, 'Number of iterations.'),
	('-k', int, 10, 'Number of samples for IW.'),
	('-vdmode', int, 1, 'Variational distribution: 1 is diagonal.'),
	('-nbridges', int, 10, 'Number of bridging densities.'),
	('-nbridges_ref', int, 64, 'Reference.'),
	# Original note on this option read "For madelo" (meaning unclear - TODO confirm).
	('-nbridges_max_tune', int, 64, 'Max tune.'),
	('-lfsteps', int, 1, 'Number of leapfrog steps.'),
	('-iters', int, 10000, 'Number of iterations.'),
	('-lr', float, 0.001, 'Learning rate.'),
	('-id', int, -1, 'ID.'),
	('-seed', int, 1, 'Random seed to use.'),
	('-run_cluster', int, 0, '1: true, 0: false.'),
	('-tune_md', int, 0, '1: true, 0: false.'),
	('-tune_beta', int, 0, '1: true, 0: false.'),
	('-tune_vd', int, 0, '1: true, 0: false.'),
	('-mode_eps', str, 'single', 'Can use single or affine.'),
	('-mode_br', str, 'brvd', 'Can use brvd, brlinear.'),
)

args_parser = argparse.ArgumentParser(description = 'Process arguments')
for _flag, _kind, _default, _help in _CLI_OPTIONS:
	args_parser.add_argument(_flag, type = _kind, default = _default, help = _help)
info = args_parser.parse_args()


# Load the model selected on the command line; `dim` is later handed to
# bm.initialize and `log_prob_model` to opt.run / bm.compute_avg_ratio.
log_prob_model, dim = load_model(info.model)
print(info, dim)
rng_key = jax.random.PRNGKey(info.seed)


def extrapolate(params, mode, nbridges_ref):
	"""Return a shallow copy of `params` with 'eps' and 'eps_var' rescaled.

	For `mode` equal to 'log', 'lin' or 'sqrt', both entries are multiplied by
	f(nbridges_ref) and divided by f(info.nbridges), where f is the logarithm,
	the identity or the square root respectively. Any other `mode` returns the
	copy untouched. The input dict itself is never modified.
	"""
	transforms = {'log': np.log, 'lin': lambda count: count, 'sqrt': np.sqrt}
	rescaled = dict(params)

	transform = transforms.get(mode)
	if transform is None:
		return rescaled

	numerator = transform(nbridges_ref)
	denominator = transform(info.nbridges)
	for key in ('eps', 'eps_var'):
		# Multiply first, then divide, so the floating-point result is unchanged.
		rescaled[key] = numerator * rescaled[key] / denominator
	return rescaled


def tune(vdparams, nbridges, trainable):
	"""Optimize the `trainable` parameters of the bounding machine for `nbridges` bridges.

	The machine is initialized through bm.initialize (using the lfsteps, vdmode,
	mode_eps and mode_br command-line settings) and optimized with opt.run for
	info.iters_tune iterations at learning rate info.lr, using info.N samples per
	step. The PRNG key is re-created from info.seed on every call.

	Args:
		vdparams: variational-distribution parameters passed to bm.initialize.
		nbridges: number of bridges to tune for.
		trainable: tuple of parameter names to optimize.

	Returns:
		Dict mapping every name in `trainable` to its value after optimization.
	"""
	print('Tuning %i' % nbridges)
	rng_key = jax.random.PRNGKey(info.seed)
	params_flat, unflatten, params_fixed = bm.initialize(dim = dim, vdmode = info.vdmode, nbridges = nbridges, vdparams = vdparams,
		lfsteps = info.lfsteps, trainable = trainable, mode_eps = info.mode_eps, brmode = info.mode_br)
	# Differentiate w.r.t. argument 1 (the flat parameters); arguments 2-4
	# (unflatten, params_fixed, log_prob_model) are treated as static by jit.
	grad_and_loss = jax.jit(jax.grad(bm.compute_avg_ratio, 1, has_aux = True), static_argnums = (2, 3, 4))
	losses, diverged, params_flat = opt.run(info.lr, params_flat, unflatten, params_fixed, log_prob_model, grad_and_loss, trainable, rng_key,
		info.iters_tune, info.N, info.run_cluster, extra = False)

	if diverged:
		print('============Diverged!============')
		print(info)
		print('============!!!!!!!!!============')

	if info.run_cluster == 0:
		# Local runs only: report the mean of the last 500 recorded losses.
		print(np.mean(np.array(losses[-500:])))

	# Unflatten once instead of once per trainable key; the first element of
	# the returned pair holds the trained parameters.
	params_train = unflatten(params_flat)[0]
	return {p: params_train[p] for p in trainable}


def evaluate(vdparams, nbridges, eps = None, eps_var = 0., eta = None, mgridref_y = None, md = None, vd = None, br = None):
	"""Estimate the loss of the bounding machine for `nbridges` bridges.

	No optimization happens here: the machine is initialized with the given
	parameter values and bm.compute_avg_ratio is called info.iters_eval times,
	each time with info.N_eval freshly drawn integer seeds. The PRNG key is
	re-created from info.seed on every call.

	Args:
		vdparams: variational-distribution parameters used when `vd` is None.
		nbridges: number of bridges to evaluate with.
		eps, eps_var, eta, mgridref_y: forwarded unchanged to bm.initialize.
		md: forwarded to bm.initialize as `mdparams`.
		vd: optional variational-distribution parameters overriding `vdparams`.
		br: forwarded to bm.initialize as `brparams`.

	Returns:
		Mean of the per-iteration losses.
	"""
	print('Evaluating %i' % nbridges)
	rng_key = jax.random.PRNGKey(info.seed)
	if vd is None:
		# Use the caller-supplied parameters rather than a module-level global,
		# so the `vdparams` argument is actually honoured.
		vd = vdparams
	params_flat, unflatten, params_fixed = bm.initialize(dim = dim, vdmode = info.vdmode, nbridges = nbridges, vdparams = vd, brparams = br,
		eps = eps, eps_var = eps_var, eta = eta, mgridref_y = mgridref_y, mdparams = md, lfsteps = info.lfsteps, mode_eps = info.mode_eps,
		brmode = info.mode_br)
	losses = []
	looper = range(info.iters_eval)
	if info.run_cluster == 0:
		# Show a progress bar for local (non-cluster) runs only.
		looper = tqdm(looper)
	for _ in looper:
		# NOTE(review): the carried key is consumed by randint and then split
		# again on the next iteration; JAX recommends sampling with the second
		# output of split instead. Left as is so evaluation numbers stay
		# comparable with earlier runs.
		rng_key, _unused_key = jax.random.split(rng_key)
		# Integer upper bound (randint documents integer bounds); same value as 1e6.
		seeds = jax.random.randint(rng_key, (info.N_eval,), 1, 10 ** 6)
		_, (loss, _rejections) = bm.compute_avg_ratio(seeds, params_flat, unflatten, params_fixed, log_prob_model)
		losses.append(loss.item())
	return np.mean(np.array(losses))



# Train initial VI: only the 'vd' parameters are trainable and no bridges are
# used (nbridges = 0). Two opt.run passes follow, the second with a smaller
# learning rate and a fixed 5000 iterations.
def _tail_mean(values, window = 500):
	"""Mean of the last `window` entries of `values`."""
	return np.mean(np.array(values[-window:]))


rng_key = jax.random.PRNGKey(info.seed)
trainable = ('vd',)
params_flat, unflatten, params_fixed = bm.initialize(
	dim = dim, vdmode = info.vdmode, nbridges = 0, lfsteps = info.lfsteps, trainable = trainable)
grad_and_loss = jax.jit(jax.grad(bm.compute_avg_ratio, 1, has_aux = True), static_argnums = (2, 3, 4))

# First pass: learning rate 0.01 for info.iters iterations, 10 samples per step.
losses, diverged, params_flat = opt.run(
	0.01, params_flat, unflatten, params_fixed, log_prob_model, grad_and_loss, trainable, rng_key,
	info.iters, 10, info.run_cluster, extra = False)
print(_tail_mean(losses))

# Second pass: learning rate 0.0001, continuing from the parameters above.
losses, diverged, params_flat = opt.run(
	0.0001, params_flat, unflatten, params_fixed, log_prob_model, grad_and_loss, trainable, rng_key,
	5000, 10, info.run_cluster, extra = False)

# Merge trained and fixed parameters and keep the variational distribution,
# which is the starting point for every later tune/evaluate call.
params_train, params_notrain = unflatten(params_flat)
params = {**params_train, **params_notrain}
vdparams_init = params['vd']
print('Done training initial VI')

# Baseline loss of the initial VI fit; saved with the results on cluster runs.
loss_base = _tail_mean(losses)
print(loss_base)

# Names of the parameters optimized by tune(). 'eps', 'eta' and 'eps_var' are
# always included; each remaining group is appended (in this order) only when
# its condition holds, and its message is printed at that point.
optional_groups = (
	(info.tune_md == 1, 'Tuning md', 'md'),
	(info.tune_beta == 1, 'Tuning beta', 'mgridref_y'),
	(info.tune_vd == 1, 'Tuning vd', 'vd'),
	(info.mode_br in ['brlinear', 'braffine'], 'Tuning br', 'br'),
)

trainable = ['eps', 'eta', 'eps_var']
for enabled, message, key in optional_groups:
	if enabled:
		print(message)
		trainable.append(key)
# Stored as a tuple from here on.
trainable = tuple(trainable)

# Tune for nbridges, unless too large: never tune with more than
# info.nbridges_max_tune bridges.
nbridges = min(info.nbridges, info.nbridges_max_tune)
params_tuned = tune(vdparams_init, nbridges, trainable)

# Evaluate for nbridges tuned
loss = evaluate(vdparams_init, nbridges, **params_tuned)
losses = {'tune': loss, 'log': None}

if info.nbridges > info.nbridges_ref:
	# More bridges requested than the reference count: tune at nbridges_ref,
	# rescale 'eps' and 'eps_var' with extrapolate(), then evaluate at
	# info.nbridges.
	# NOTE(review): when the clamped nbridges above already equals
	# nbridges_ref this repeats the same tune() call - confirm whether the
	# first result could be reused.
	params_tuned = tune(vdparams_init, info.nbridges_ref, trainable)
	for mode in ('log',):
		params_extrapol = extrapolate(params_tuned, mode, info.nbridges_ref)
		losses[mode] = evaluate(vdparams_init, info.nbridges, **params_extrapol)
else:
	# At or below the reference count: reuse the loss measured above.
	losses['log'] = loss

print(losses)

# Results are persisted for cluster runs only, as a pickled
# (info, loss_base, losses['log']) tuple named after id, nbridges and model.
if info.run_cluster == 1:
	dresults = '/mnt/nfs/work1/domke/tgeffner/results_ais/results/'
	name = '%i_%i_%s.pkl' % (info.id, info.nbridges, info.model)
	# Create the directory up front: otherwise a missing folder would make
	# open() fail only after the whole run has finished, losing the results.
	os.makedirs(dresults, exist_ok = True)
	with open(os.path.join(dresults, name), 'wb') as f:
		pickle.dump((info, loss_base, losses['log']), f)
	print('Done saving results.')










