import argparse
import jax
import jax.numpy as np
from jax import random
from nn.ntwide import SetupType
from nn.policy import get_nnpolicy

from util.print import info
from nn.mcwide import wide_init, wide_apply_jit


def run(config):
  """Evaluate wide-network (MC) Q-values on the cartpole dataset and save them.

  Loads transitions from 'data/dataset_cartpole_small.npy', chooses next
  actions with a policy network whose weights are loaded from
  'data/policy_cartpole.npy', and for each of `config.num_runs` Q-network
  initializations calls `wide_apply_jit` (method `SetupType.SINGLE_Q`) at
  time `config.time`. Column 0 of each run's result is stacked into an array
  of shape (num_runs, num_eval_points), which is written with `np.save` to a
  path under 'data/' encoding the hyperparameters.

  Args:
    config: namespace with attributes `gamma`, `lr`, `num_layers`,
      `layer_width`, `time`, `num_runs` and `test`, as produced by the
      argparse parser in this file's `__main__` section. When `test` is true,
      the evaluation states/actions come from
      'data/dataset_cartpole_small_test.npy' instead of the training set.
  """
  dataset = np.load('data/dataset_cartpole_small.npy')

  ssize = 4  # state columns
  asize = 1  # action columns
  rsize = 1  # reward columns

  # Row layout: state | action | reward | next state; `done` is the last column.
  s = dataset[:, :ssize]
  a = dataset[:, ssize:ssize+asize].astype(int)
  r = dataset[:, ssize+asize:ssize+asize+rsize]
  snew = dataset[:, ssize+asize+rsize:ssize+asize+rsize+ssize]
  done = dataset[:, -1:]

  gamma = config.gamma

  dummy_key = random.PRNGKey(0)

  # Q-network: 4 inputs, `num_layers` hidden layers of width `layer_width`, 2 outputs.
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch = num_layers * [layer_width]
  arch_text = f'{num_layers}x{layer_width}'

  network_size = [4] + arch + [2]
  sigmaweights = np.sqrt(1.)
  sigmabiases = np.sqrt(0.1)

  # NOTE(review): `Array.device()` is a method only in some JAX releases (newer
  # ones expose `.devices()` / a `.device` property) - confirm against the
  # JAX version this project pins.
  info(f'Found devices: {jax.devices()}. Using device {sigmaweights.device()}.')

  # Policy network layer sizes: 4 -> 256 -> 2.
  pi_nn_size = [4, 256, 2]

  # Training setup
  time = config.time
  lr = config.lr
  num_runs = config.num_runs

  pinetwork, pinetwork_predict = get_nnpolicy(pi_nn_size, np.sqrt(2.), np.sqrt(0.1))

  # Initialize the policy parameter structure, then overwrite the 'nn_policy'
  # entry with the stored weights.
  piparams = pinetwork.init(random.PRNGKey(0), np.array([0., 0., 0., 0.]))
  piparams['nn_policy'] = np.load('data/policy_cartpole.npy', allow_pickle=True).item()

  info('Starting...')

  # Next action per transition: argmax of the policy output at the next state.
  anew = np.argmax(pinetwork_predict(piparams, dummy_key, snew), axis=-1, keepdims=True)

  # Evaluation points (loop-invariant): training states/actions by default,
  # or the held-out test set when --test is given.
  if config.test:
    dataset_test = np.load('data/dataset_cartpole_small_test.npy')
    sstar = dataset_test[:, :ssize]
    astar = dataset_test[:, ssize:ssize+asize].astype(int)
  else:
    sstar, astar = s, a

  # Collect one row per run and stack once at the end. Preallocating with the
  # training-set row count broke (shape mismatch) when the evaluation set had
  # a different number of rows, and `.at[].set` copied the array every run.
  q_rows = []
  for run_idx in range(num_runs):
    # The run index is passed as the first argument of `wide_init` (presumably
    # the initialization seed - confirm in nn.mcwide).
    qparams = wide_init(run_idx, network_size, sigmaweights, sigmabiases, 4)
    q_theor = wide_apply_jit(qparams, s, a, r, snew, done, anew, sstar, astar, gamma,
                             t=time, lr=lr, method=SetupType.SINGLE_Q)
    q_rows.append(q_theor[:, 0])
    info(f'Evaluated run {run_idx}')

  if q_rows:
    q_wide = np.stack(q_rows)
  else:
    # No runs requested: keep an empty (0, num_eval_points) array.
    q_wide = np.zeros((0, sstar.shape[0]))

  save_path = f'data/cartpole_single_wide_mc_{num_runs}runs-gamma{gamma}-lr{lr}-arch{arch_text}_qvals{"_test" if config.test else ""}.npy'
  info(f'Saving Wide (MC) Q-values to \'{save_path}\'...')
  np.save(save_path, q_wide)

if __name__ == '__main__':
  # Command-line entry point: parse hyperparameters and hand them to `run`.
  cli = argparse.ArgumentParser(description='Cartpole Wide (MC)')

  # (flag, default, type, help) for each valued option, in help-display order.
  valued_options = (
    ('--gamma', 1., float, 'environment discount factor'),
    ('--lr', 0.1, float, 'neural network learning rate'),
    ('--num_layers', 2, int, 'number of fully connected layers in the neural network'),
    ('--layer_width', 256, int, 'fully connected layer width for all layers'),
    ('--time', 65536, int, 'timepoint at which to evaluate the Q-values'),
    ('--num_runs', 50, int, 'number of different Q-network initializations to run'),
  )
  for flag, default_value, converter, help_text in valued_options:
    cli.add_argument(flag, default=default_value, type=converter, help=help_text)

  # Boolean switch: evaluate on the test split instead of the training split.
  cli.add_argument('--test', action='store_true', help='use the test set for evaluating. Default: train set')

  run(cli.parse_args())
