import jax.numpy as np
import jax
from jax.flatten_util import ravel_pytree
from tqdm import tqdm
import sys
import os
import functools
import tensorflow_datasets as tfds
import numpy as onp
import data_manager as dm


def adam(step_size, b1 = 0.9, b2 = 0.999, eps = 1e-8):
	"""Adam optimizer over a flat parameter vector, with a projection step.

	Follows the (init, update, get_params) triple used by JAX's example
	optimizers. After every step, the trainable entries 'eps', 'eta' and
	'mbetas' are projected back into their allowed ranges.

	The flat vector is assumed to come from ``ravel_pytree`` applied to a
	``(trainable, non_trainable)`` pair. The layout is assumed never to change,
	so the same ``unflatten`` can be reused on every step. This keeps it usable
	as a static argument when ``update`` is jitted.

	Args:
		step_size: learning rate.
		b1: decay rate of the first-moment estimate.
		b2: decay rate of the second-moment estimate.
		eps: small constant added to the denominator for numerical stability.

	Returns:
		(init, update, get_params) functions.
	"""
	def _project_params(flat, unflatten, trainable):
		# Map the raw Adam iterate back onto the feasible set for constrained params.
		free, frozen = unflatten(flat)
		if 'eps' in trainable:
			free['eps'] = np.clip(free['eps'], 0, 0.05)
		if 'eta' in trainable:
			free['eta'] = np.clip(free['eta'], 0, 0.99)
		if 'mbetas' in trainable:
			# Mathematically max(mbetas, 0.001): keeps them strictly positive.
			free['mbetas'] = jax.nn.relu(free['mbetas'] - 0.001) + 0.001
		flat_projected, _ = ravel_pytree((free, frozen))
		return flat_projected

	def init(x0):
		# State is (params, first moment, second moment).
		return x0, np.zeros_like(x0), np.zeros_like(x0)

	def update(i, g, state, max_eps, unflatten, trainable):
		# NOTE: max_eps is accepted for interface compatibility but not used;
		# the 'eps' clip bound is fixed at 0.05 in _project_params.
		params, first, second = state
		t = i + 1
		first = b1 * first + (1 - b1) * g
		second = b2 * second + (1 - b2) * np.square(g)
		# Bias-corrected moment estimates.
		first_hat = first / (1 - np.asarray(b1, first.dtype) ** t)
		second_hat = second / (1 - np.asarray(b2, first.dtype) ** t)
		params = params - step_size * first_hat / (np.sqrt(second_hat) + eps)
		return _project_params(params, unflatten, trainable), first, second

	def get_params(state):
		return state[0]

	return init, update, get_params


@functools.partial(jax.jit, static_argnums = (1, 2))
def collect_eps_eta(params_flat, unflatten, trainable):
	"""Pull the 'eps', 'eps_var' and 'eta' values out of the flat params for logging.

	``unflatten`` and ``trainable`` are static jit arguments, so both must be
	hashable. When 'eps' is not trainable, a triple of zeros is returned.
	"""
	if 'eps' not in trainable:
		return 0., 0., 0.
	trainable_params = unflatten(params_flat)[0]
	# NOTE(review): assumes 'eps_var' and 'eta' exist whenever 'eps' is trainable.
	return trainable_params['eps'], trainable_params['eps_var'], trainable_params['eta']

@functools.partial(jax.jit, static_argnums = (1, 2))
def collect_md(params_flat, unflatten, trainable):
	"""Return a constant length-2 zero vector; all arguments are ignored.

	The signature mirrors collect_eps_eta, so ``unflatten`` and ``trainable``
	are static jit arguments and must be hashable.
	"""
	return np.zeros((2,))




def run(step_size, params_flat, unflatten, params_fixed, grad_and_loss, trainable, rng_key, iters, info, best_loss_val = 500000.):
	"""Train a flat parameter vector with projected Adam, tracking validation loss.

	Args:
		step_size: Adam learning rate.
		params_flat: initial flat parameter vector.
		unflatten: maps the flat vector back to (trainable, non-trainable) pytrees.
			It is a static argument of the jitted update, so it must be hashable.
		params_fixed: extra inputs forwarded unchanged to grad_and_loss.
		grad_and_loss: callable (seeds, x, params_flat, unflatten, params_fixed)
			returning (grad, (loss, aux)).
		trainable: hashable collection of trainable parameter names.
		rng_key: JAX PRNG key.
		iters: requested number of optimizer steps. Training runs
			int(iters / batches_per_epoch) + 1 full epochs, so the actual step
			count can exceed iters.
		info: config object; reads .dset, .bs and .eps_vec.
		best_loss_val: validation loss that must be beaten to record new best params.

	Returns:
		A tuple (losses, diverged, best_loss_val, last_loss_val, best_params,
		params_flat, best_epoch_val, epss, etas). ``diverged`` is True when a NaN
		training loss was hit. In that case ``params_flat`` holds the parameters
		that produced it. Otherwise ``params_flat`` holds the parameters after the
		final update.

	Raises:
		ValueError: if the training set holds fewer than ``info.bs`` examples.
	"""
	# Initialize optimizer
	opt_init, update, get_params = adam(step_size)
	update = jax.jit(update, static_argnums = (4, 5))
	opt_state = opt_init(params_flat)

	# Initialize trackers
	losses = []
	epss = []
	eps_vars = []  # only used for the per-epoch printout
	etas = []
	step = 0  # global optimizer step; drives Adam's bias correction across epochs
	max_eps = 1.
	best_params = params_flat
	best_epoch_val = 0
	last_loss_val = None

	# Load dataset and do data stuff
	data = dm.load_dataset(info.dset, mode = 'train')
	batches_per_epoch = int(data.shape[0] / info.bs)
	if batches_per_epoch == 0:
		raise ValueError('Training set has %i examples, fewer than batch size %i' % (data.shape[0], info.bs))
	total_epochs = int(iters / batches_per_epoch) + 1

	# Iterate epochs
	for epoch in range(total_epochs):
		print('Epoch %i / %i -- %.6f' % (epoch, total_epochs, step_size))
		# Trackers are appended together, so non-empty losses implies the others are too.
		if losses:
			print(np.mean(np.array(losses[-300:])), epss[-1], eps_vars[-1], etas[-1])

		# Every now and then check validation loss
		if epoch % 3 == 0:
			# Validate the current iterate, not the pre-update copy from the last batch.
			params_flat = get_params(opt_state)
			rng_key, rng_key_aux = jax.random.split(rng_key)
			loss_val = evaluate(params_flat, unflatten, params_fixed, grad_and_loss, rng_key_aux, info, mode = 'val', iters = 5)
			loss_val = np.mean(np.array(loss_val))

			last_loss_val = loss_val
			if loss_val < best_loss_val:
				print('Improved on loss!', best_loss_val, loss_val)
				best_params = params_flat
				best_loss_val = loss_val
				best_epoch_val = epoch
			else:
				print('Did not improve on loss...', best_loss_val, loss_val)

		# Shuffle data
		rng_key, rng_key_aux = jax.random.split(rng_key)
		data = jax.random.permutation(rng_key_aux, data)
		start_index = 0

		# Iterate within epoch
		for _ in range(batches_per_epoch):

			# Load batch, and binarize if not binarized initially
			x = data[start_index : start_index + info.bs]
			rng_key, rng_key_aux = jax.random.split(rng_key)
			x = dm.binarize(rng_key_aux, x)
			start_index += info.bs

			# Per-example seeds for grad_and_loss
			rng_key, rng_key_aux = jax.random.split(rng_key)
			seeds = jax.random.randint(rng_key_aux, (x.shape[0],), 1, 1e5)
			params_flat = get_params(opt_state)
			eps, eps_var, eta = collect_eps_eta(params_flat, unflatten, trainable)
			if info.eps_vec == 1:
				# Summarize a vector-valued eps by its mean for logging.
				eps = eps.mean()
			epss.append(eps.item())
			eps_vars.append(eps_var.item())
			etas.append(eta.item())

			# Compute gradient and loss
			grad, (loss, _) = grad_and_loss(seeds, x, params_flat, unflatten, params_fixed)
			losses.append(loss.item())
			if np.isnan(loss):
				print('Diverged')
				return losses, True, best_loss_val, last_loss_val, best_params, params_flat, best_epoch_val, epss, etas

			# Take step. Use the global step count so Adam's bias correction is not
			# reset at every epoch boundary (a per-epoch index would inflate the
			# first steps of each epoch by up to 1 / (1 - b1)).
			opt_state = update(step, grad, opt_state, max_eps, unflatten, trainable)
			step += 1

	# Return the parameters after the final update, not the pre-update copy.
	params_flat = get_params(opt_state)
	return losses, False, best_loss_val, last_loss_val, best_params, params_flat, best_epoch_val, epss, etas





def evaluate(params_flat, unflatten, params_fixed, grad_and_loss, rng_key, info, mode = 'val', iters = 5, bs = 200):
	"""Compute per-batch losses of the given parameters on a dataset split.

	Each of the ``iters`` passes draws fresh randomness for binarization and for
	the per-example seeds. Repeated passes therefore sample different
	stochastic evaluations of the same data.

	Args:
		params_flat: flat parameter vector to evaluate.
		unflatten: maps params_flat back to its pytree form (forwarded to grad_and_loss).
		params_fixed: extra inputs forwarded unchanged to grad_and_loss.
		grad_and_loss: callable (seeds, x, params_flat, unflatten, params_fixed)
			returning (grad, (loss, aux)); only the loss is used.
		rng_key: JAX PRNG key.
		info: config object; reads .dset.
		mode: dataset split passed to dm.load_dataset.
		iters: number of full passes over the split.
		bs: evaluation batch size. Trailing examples that do not fill a batch
			are skipped.

	Returns:
		List of Python floats, one per evaluated batch.
	"""
	# Load data
	ds = dm.load_dataset(info.dset, mode = mode)
	num_examples = ds.shape[0]
	numbatches = num_examples // bs
	if numbatches == 0 and num_examples > 0:
		# Split smaller than one batch: evaluate it whole instead of returning an
		# empty list (averaging an empty list yields NaN).
		bs, numbatches = num_examples, 1

	losses = []
	for _ in range(iters):
		for j in range(numbatches):
			# Load batch
			x = ds[j * bs : (j + 1) * bs]
			rng_key, rng_key_aux = jax.random.split(rng_key)
			x = dm.binarize(rng_key_aux, x)

			# Compute loss
			rng_key, rng_key_aux = jax.random.split(rng_key)
			seeds = jax.random.randint(rng_key_aux, (x.shape[0],), 1, 1e5)
			_, (loss, _) = grad_and_loss(seeds, x, params_flat, unflatten, params_fixed)
			losses.append(loss.item())
	return losses




# ######
# x = np.reshape(example['image'], (32, 28, 28))
# import matplotlib.pyplot as plt
# for i in range(10):
# 	plt.figure()
# 	plt.imshow(x[i, :, :])
# 	plt.show()
# 	print(x.shape)
# exit()
# ######
