import argparse
import jax
import jax.numpy as np
from jax import random
from jax.experimental import host_callback
import haiku as hk

from nn.approxwide import get_nn_ntk
from nn.policy import get_nnpolicy
from util.print import info

def run(config):
  """Fit a Q-network to the fixed cartpole dataset by full-batch gradient descent.

  The Q-network is trained to minimise the mean squared Bellman residual
  ``Q(s, a) - r - gamma * Q(s', a')``, where ``a'`` is the argmax action of a
  policy network whose parameters are loaded from 'data/policy_cartpole.npy',
  and the bootstrap term is zeroed on transitions with ``done == 1``. The
  trained Q-network parameters are saved to a .npy file under 'data/'.

  Args:
    config: argparse-style namespace providing ``gamma``, ``lr``,
      ``num_layers``, ``layer_width``, ``epochs`` and ``seed``.
  """
  dataset = np.load('data/dataset_cartpole_small.npy')

  # Column layout of each transition row:
  # [ s (ssize) | a (asize) | r (rsize) | s' (ssize) | ... | done (last column) ]
  ssize = 4
  asize = 1
  rsize = 1
  num_actions = 2  # output width of both the Q-network and the policy network

  s, a, r, snew, done = (dataset[:, :ssize],
                         dataset[:, ssize:ssize+asize].astype(int),
                         dataset[:, ssize+asize:ssize+asize+rsize],
                         dataset[:, ssize+asize+rsize:ssize+asize+rsize+ssize],
                         dataset[:, -1:])

  gamma = config.gamma

  # Q-network architecture and the weight/bias scales handed to get_nn_ntk.
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch = num_layers * [layer_width]
  arch_text = f'{num_layers}x{layer_width}'

  network_size = [ssize] + arch + [num_actions]
  sigmaweights = np.sqrt(1.)
  sigmabiases = np.sqrt(0.1)

  info(f'Found devices: {jax.devices()}. Using device {sigmaweights.device()}.')

  # Policy NN
  pi_nn_size = [ssize, 256, num_actions]

  # Training Setup
  num_epochs = config.epochs
  lr = config.lr
  seed = config.seed

  qnetwork_init, qnetwork_predict = get_nn_ntk(network_size, sigmaweights, sigmabiases)
  _, qparams = qnetwork_init(random.PRNGKey(seed), (ssize,))

  pinetwork, pinetwork_predict = get_nnpolicy(pi_nn_size, np.sqrt(2.), np.sqrt(0.1))
  piparams = pinetwork.init(random.PRNGKey(0), np.zeros((ssize,)))
  # NOTE(review): allow_pickle unpickles arbitrary objects; only load trusted files.
  piparams['nn_policy'] = np.load('data/policy_cartpole.npy', allow_pickle=True).item()

  # The policy is never trained here, so its greedy action at every next state
  # is loop-invariant: compute it once instead of inside every gradient step.
  anew = pinetwork_predict(piparams, random.PRNGKey(0), snew)
  anew = np.argmax(anew, axis=-1, keepdims=True).astype(int)

  def loss(qparams: hk.Params, s, a, r, snew, anew, done):
    """Mean squared Bellman residual of the Q-network over the whole dataset."""
    q = np.take_along_axis(qnetwork_predict(qparams, s), a, axis=1)
    qnew = np.take_along_axis(qnetwork_predict(qparams, snew), anew, axis=1)

    # Terminal transitions contribute no bootstrapped successor value.
    qnew = np.where(done == 1., np.array(0.), qnew)

    # No stop_gradient on the target: gradients flow through both q and qnew.
    return np.mean(np.square(q - r - gamma*qnew))

  @jax.jit
  def update(qparams: hk.Params, s, a, r, snew, anew, done):
    """One full-batch gradient-descent step; returns (new params, pre-step loss)."""
    loss_value, grad = jax.value_and_grad(loss)(qparams, s, a, r, snew, anew, done)

    # jax.tree_util.tree_map works on all JAX versions (the jax.tree_map alias is deprecated).
    qparams = jax.tree_util.tree_map(lambda weights, grads: weights - lr * grads, qparams, grad)

    return qparams, loss_value

  def epoch_print(meta, _):
    epoch, loss = meta
    info(f'Epoch {epoch:6d}  loss: {loss:9.6f}')

  def noop(*args):
    pass

  def epoch_loop(epoch, qparams):
    qparams, qloss = update(qparams, s, a, r, snew, anew, done)

    # The loop body is traced, so a plain Python print would not run per
    # iteration; log every 1000 epochs through a host callback instead.
    jax.lax.cond(np.equal(np.remainder(epoch, 1000), 0.),
                 lambda args: noop(host_callback.id_tap(epoch_print, args)),
                 lambda _: noop(None),
                 (epoch, qloss)
                 )

    return qparams

  info('Starting training...')
  qparams = jax.lax.fori_loop(0, num_epochs, epoch_loop, qparams)

  # Report the loss of the parameters that are actually saved. The in-loop loss
  # is measured before each update, so the last one would lag by one step (and
  # there is none at all when num_epochs == 0).
  final_loss = jax.jit(loss)(qparams, s, a, r, snew, anew, done)
  epoch_print((num_epochs, final_loss), None)

  save_path = f'data/cartpole_single_nn_gamma{gamma}-lr{lr}-arch{arch_text}-seed{seed}_qparam.npy'
  info(f'Saving Q-network parameters to \'{save_path}\'...')
  np.save(save_path, qparams)

if __name__ == '__main__':
  # Command-line entry point: parse hyperparameters and train one Q-network.
  parser = argparse.ArgumentParser(description='Cartpole NNs')
  parser.add_argument('--gamma', default=1., type=float, help='environment discount factor')
  parser.add_argument('--lr', default=0.1, type=float, help='neural network learning rate')
  parser.add_argument('--num_layers', default=2, type=int, help='number of fully connected layers in the neural network')
  parser.add_argument('--layer_width', default=256, type=int, help='fully connected layer width for all layers')
  parser.add_argument('--epochs', default=65536, type=int, help='number of full batch gradient descent iterations')
  # The seed feeds random.PRNGKey and the output filename and has no default,
  # so require it: argparse then fails fast with a clear message instead of
  # run() erroring inside random.PRNGKey(None).
  parser.add_argument('--seed', type=int, required=True, help='random seed for Q-network initialization')

  args = parser.parse_args()
  run(args)