import argparse
import math
import os

import haiku as hk
import jax
import jax.numpy as np
from jax import random

# Enable 64-bit types. JAX defaults to 32-bit precision, so without this flag
# the int64 index arrays requested in run() would be created as 32-bit instead.
jax.config.update("jax_enable_x64", True)

from envs.toy import get_dataset_linspaced
from nn.approxwide import get_nn_ntk
from util.print import info

def run(config):
  """Train an ensemble of Q-networks on the toy dataset and save their Q-values.

  For each ensemble member a network is initialised from its own PRNG seed and
  trained with full-batch gradient descent on the summed squared Bellman
  residual. Q-values over an evaluation grid are recorded at every requested
  time, and the collected results are written to
  ``data/toy_single_time_nnensemble_<runs>runs-lr<lr>-arch<arch>.npy``.

  Args:
    config: Namespace-like object providing ``num_layers``, ``layer_width``,
      ``times``, ``lr``, ``dataset_size``, ``gamma``, ``eval_resolution``,
      ``eval_range_start``, ``eval_range_end`` and ``ensemble_size``.
      ``times`` are cumulative epoch counts at which Q-values are recorded;
      they must be non-negative and non-decreasing.

  Raises:
    ValueError: If ``config.times`` contains a negative value or is not sorted
      in non-decreasing order.
  """
  sigmaweights = np.sqrt(1.)
  sigmabiases = np.sqrt(0.1)
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch = num_layers * [layer_width]
  arch_text = f'{num_layers}x{layer_width}'
  network_size = [1] + arch + [1]

  # `Array.device()` was removed from newer JAX releases in favour of
  # `Array.devices()`; fall back to the old method when the new one is absent.
  if hasattr(sigmaweights, 'devices'):
    device = next(iter(sigmaweights.devices()))
  else:
    device = sigmaweights.device()
  info(f'Found devices: {jax.devices()}. Using device {device}.')

  times = list(config.times)
  # Training proceeds incrementally from one time to the next, so an unsorted
  # or negative schedule would attach Q-values to the wrong training time.
  if any(t < 0 for t in times) or any(b < a for a, b in zip(times, times[1:])):
    raise ValueError(f'times must be non-negative and non-decreasing, got {times}')
  lr = config.lr

  dataset_size = config.dataset_size
  gamma = config.gamma

  num_samples = config.eval_resolution
  train_range = (0, 1.25)
  reward_threshold = 1.

  eval_range = (config.eval_range_start, config.eval_range_end)

  num_runs = config.ensemble_size

  # The dataset is unpacked positionally into update(); presumably it is the
  # tuple (s, a, r, snew, done) -- confirm against get_dataset_linspaced.
  dataset = list(get_dataset_linspaced(dataset_size, train_range, reward_threshold))
  # Overwrite the actions with index 0 everywhere; the network built below has
  # a single output, so 0 is the only valid column for take_along_axis.
  dataset[1] = np.zeros_like(dataset[0], dtype=np.int64)
  anew = np.zeros_like(dataset[0], dtype=np.int64)

  s_test = np.expand_dims(np.linspace(*eval_range, num_samples), 1)
  a_test = np.zeros_like(s_test, dtype=np.int64)

  #####################################
  ########## Ensemble of NNs ##########
  #####################################

  qnetwork_init, qnetwork_predict = get_nn_ntk(network_size, sigmaweights, sigmabiases)

  def loss(qparams: hk.Params, s, a, r, snew, done):
    """Return the summed squared Bellman residual over the whole batch.

    `qparams` is used for both the prediction and the bootstrapped target, so
    the gradient of this loss flows through both terms.
    """
    q = qnetwork_predict(qparams, s)
    q = np.take_along_axis(q, a, axis=1)

    qnew = qnetwork_predict(qparams, snew)
    qnew = np.take_along_axis(qnew, anew, axis=1)

    # Terminal transitions do not bootstrap from the next state.
    qnew = np.where(done == 1., np.array(0.), qnew)

    return np.sum(np.square(q - r - gamma*qnew))

  loss_and_grad = jax.value_and_grad(loss)

  @jax.jit
  def update(qparams: hk.Params, s, a, r, snew, done):
    """Apply one plain gradient-descent step; return (new params, loss)."""
    loss_value, grad = loss_and_grad(qparams, s, a, r, snew, done)
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias has been removed from recent JAX releases.
    qparams_new = jax.tree_util.tree_map(
      lambda weights, grads: weights - lr * grads, qparams, grad)
    return qparams_new, loss_value

  def epoch_loop(epoch, meta):
    """fori_loop body: one full-batch update, logging the loss at `epoch`."""
    qparams, loss_epochs = meta

    qparams, qloss = update(qparams, *dataset)
    loss_epochs = loss_epochs.at[epoch].set(qloss)

    return (qparams, loss_epochs)

  # Factor converting requested times into epochs. An alternative scaling,
  # 1 / math.sqrt(lr), is intentionally left disabled.
  t_multiplier = 1

  q_time_nn = {time: np.zeros((num_runs, num_samples)) for time in times}

  # Create the output directory up front so a missing folder fails fast
  # instead of after the whole ensemble has been trained.
  save_path_nn = f'data/toy_single_time_nnensemble_{num_runs}runs-lr{lr}-arch{arch_text}.npy'
  os.makedirs(os.path.dirname(save_path_nn), exist_ok=True)

  info(f'Starting training... ')
  for run_idx in range(num_runs):
    _, qparams = qnetwork_init(random.PRNGKey(run_idx), (1,))

    loss_epochs_times = []
    previous_time = 0
    for time in times:
      # Only train for the epochs elapsed since the previous checkpoint.
      run_for_epochs = math.ceil((time - previous_time) * t_multiplier)
      previous_time = time
      if run_for_epochs > 0:
        qparams, loss_epochs_i = jax.lax.fori_loop(
          0, run_for_epochs, epoch_loop, (qparams, np.zeros((run_for_epochs,))))
        loss_epochs_times.append(loss_epochs_i)

      q_i = qnetwork_predict(qparams, s_test)
      q_i = np.take_along_axis(q_i, a_test, axis=1)
      q_time_nn[time] = q_time_nn[time].at[run_idx].set(q_i[:, 0])

    # Every stored segment is non-empty, so the last entry of the last segment
    # is the final training loss; -1 marks "no training happened".
    final_loss = loss_epochs_times[-1][-1] if loss_epochs_times else -1
    info(f'Finished run {run_idx:3d}, final loss: {final_loss:= 9.4f}')

  info(f'Saving NN ensemble Q-values to \'{save_path_nn}\'...')
  # The dict is written as a pickled object array, so reading it back needs
  # np.load(..., allow_pickle=True).
  np.save(save_path_nn, q_time_nn)


if __name__ == '__main__':
  # Command-line entry point: declare the experiment options, parse them and
  # hand the resulting namespace to run().
  cli = argparse.ArgumentParser(description='Toy environment with single Q-network')

  # One (flag, add_argument keyword arguments) entry per option, listed in the
  # order they should appear in --help.
  option_table = [
    ('--gamma', dict(default=1., type=float, help='environment discount factor')),
    ('--lr', dict(default=0.01, type=float, help='neural network learning rate')),
    ('--dataset_size', dict(
      default=21, type=int,
      help='number of linearly spaced states for which transitions are generated in the dataset')),
    ('--num_layers', dict(
      default=2, type=int,
      help='number of fully connected layers in the neural network')),
    ('--layer_width', dict(
      default=1024, type=int,
      help='fully connected layer width for all layers')),
    ('--eval_resolution', dict(
      default=100, type=int,
      help='number of linearly spaced states to use to generate the Q-value data')),
    ('--ensemble_size', dict(
      default=50, type=int,
      help='number of different seeds to run replications')),
    ('--times', dict(
      default=[0, 16, 256, 2048, 65536], type=int, nargs='+',
      help='times at which to evaluate the Q-values')),
    ('--eval_range_start', dict(
      default=0, type=float,
      help='lower limit of the states generated for evaluation')),
    ('--eval_range_end', dict(
      default=1.25, type=float,
      help='upper limit of the states generated for evaluation')),
  ]
  for flag, spec in option_table:
    cli.add_argument(flag, **spec)

  run(cli.parse_args())