import argparse
import jax
import jax.numpy as np

from util.print import info
from nn.approxwide import get_nn_ntk


def run(config):
  """Evaluate saved Q-network parameters on the cartpole dataset and save the greedy Q-values.

  For each seed in ``range(config.num_runs)`` this loads the pickled Q-network
  parameters from ``data/`` (file name built from gamma, lr, architecture and
  seed). It then evaluates the network on every state row of the dataset and
  records the maximum network output (the Q-value of the argmax action). The
  resulting ``num_runs`` lists of per-state values are written with ``np.save``.

  Args:
    config: parsed CLI arguments. Reads ``test`` (use the test dataset and add a
      ``_test`` suffix to the output file), ``gamma`` and ``lr`` (used only in
      file names), ``num_layers`` and ``layer_width`` (network architecture and
      file names), and ``num_runs`` (number of seeds to evaluate).
  """
  # Same suffix selects the input dataset and tags the output file.
  suffix = '_test' if config.test else ''
  dataset = np.load(f'data/dataset_cartpole_small{suffix}.npy')

  # The first `state_size` columns are fed to the network; any remaining
  # columns are not used here.
  state_size = 4
  states = dataset[:, :state_size]

  gamma = config.gamma
  lr = config.lr
  num_runs = config.num_runs

  # NN: `num_layers` hidden layers of width `layer_width`, `state_size` inputs
  # and 2 outputs (presumably one Q-value per cartpole action).
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch = num_layers * [layer_width]
  arch_text = f'{num_layers}x{layer_width}'
  network_size = [state_size] + arch + [2]
  sigmaweights = np.sqrt(1.)
  sigmabiases = np.sqrt(0.1)

  # `Array.device` is a method in older JAX releases and a property in newer
  # ones; support both so this log line does not crash.
  device = sigmaweights.device
  if callable(device):
    device = device()
  info(f'Found devices: {jax.devices()}. Using device {device}.')

  _, qnetwork_predict = get_nn_ntk(network_size, sigmaweights, sigmabiases)

  info('Starting...')

  q_preds_all = []
  # Named `seed` (not `run`) so the loop variable does not shadow this function.
  for seed in range(num_runs):
    load_path = f'data/cartpole_single_nn_gamma{gamma}-lr{lr}-arch{arch_text}-seed{seed}_qparam.npy'
    info(f'Loading Q-network parameters from \'{load_path}\'...')
    # NOTE(review): allow_pickle=True unpickles the file contents; only load
    # parameter files produced by trusted training runs.
    qparams = np.load(load_path, allow_pickle=True)
    # Q-value of the greedy action == maximum over the network outputs.
    q_preds_all.append(
        [np.max(qnetwork_predict(qparams, state)).item() for state in states])

  save_path = f'data/cartpole_single_nn_{num_runs}runs-gamma{gamma}-lr{lr}-arch{arch_text}_qvals{suffix}.npy'
  info(f'Saving NN Q-values to \'{save_path}\'...')
  np.save(save_path, q_preds_all)

if __name__ == '__main__':
  cli = argparse.ArgumentParser(description='Cartpole evaluate NN Q-values')
  # Typed options as (flag, default, type, help), registered in this order.
  _typed_options = (
      ('--gamma', 1., float, 'environment discount factor'),
      ('--lr', 0.1, float, 'neural network learning rate'),
      ('--num_layers', 2, int, 'number of fully connected layers in the neural network'),
      ('--layer_width', 256, int, 'fully connected layer width for all layers'),
      ('--num_runs', 50, int, 'number of different Q-network initializations to evaluate'),
  )
  for flag, default_value, value_type, help_text in _typed_options:
    cli.add_argument(flag, default=default_value, type=value_type, help=help_text)
  # Boolean switch: evaluate on the test split instead of the train split.
  cli.add_argument('--test', action='store_true', help='use the test set for evaluating. Default: train set')

  run(cli.parse_args())
