import argparse
import jax.numpy as np
import scipy.stats
from util.print import info

def _load_array(path, message):
  """Log ``message`` via ``info`` and return the array stored at ``path``."""
  info(message)
  return np.load(path)

def _qvals_path(method, num_runs, gamma, lr, arch_text, suffix):
  """Build the path of a saved Q-value file for ``method`` ('wide_mc' or 'nn')."""
  return (f'data/cartpole_single_{method}_{num_runs}runs-gamma{gamma}-lr{lr}'
          f'-arch{arch_text}_qvals{suffix}.npy')

def _check_run_count(q_values, num_runs, label):
  """Raise ValueError if ``q_values`` stores fewer than ``num_runs`` runs on axis 0.

  Without this check, indexing past the stored runs would either fail with an
  unhelpful error or (JAX may clamp out-of-bounds indices) silently reuse the
  last stored run and skew the RMSE statistics.
  """
  stored = q_values.shape[0]
  if stored < num_runs:
    raise ValueError(
        f'{label} Q-values contain {stored} runs, but num_runs={num_runs}')

def _rmse_per_run(q_values, targets, num_runs):
  """Return an array with the RMSE against ``targets`` of each of the first ``num_runs`` runs.

  Each ``q_values[seed]`` is compared element-wise (with broadcasting) to
  ``targets`` and the squared error is averaged over all elements.
  """
  return np.array([np.sqrt(np.mean(np.square(q_values[seed] - targets)))
                   for seed in range(num_runs)])

def _gaussian_nll(q_values, targets, label):
  """Per-element negative log-likelihood of ``targets`` under a Gaussian fit across runs.

  The mean and the unbiased (ddof=1) standard deviation are taken over axis 0
  of ``q_values`` (one entry per run).

  Raises:
    ValueError: if fewer than 2 runs are stored, since the ddof=1 std would be
      undefined and every NLL value would be NaN.
  """
  stored = q_values.shape[0]
  if stored < 2:
    raise ValueError(
        f'{label} NLL needs at least 2 runs to estimate a std, got {stored}')
  mean = np.mean(q_values, axis=0)
  std = np.std(q_values, axis=0, ddof=1)
  # NOTE(review): positions where all runs agree have std == 0; scipy's norm
  # with a zero scale is expected to yield NaN there — confirm if this matters.
  return -scipy.stats.norm(mean, std).logpdf(targets)

def run(config):
  """Print RMSE and NLL of NN Q-values vs. Wide (MC) linearization Q-values.

  Loads the true returns and the per-run Q-values of both methods from
  ``data/``, then prints a table with, for each method:
    * RMSE: mean ± (population) std over runs of the per-run RMSE;
    * NLL: mean ± std over positions of the Gaussian NLL of the true returns,
      where the Gaussian is fit across runs.

  Args:
    config: namespace with ``num_layers``, ``layer_width``, ``gamma``, ``lr``,
      ``num_runs`` and ``test`` attributes (as produced by this script's CLI).

  Raises:
    ValueError: if a Q-value file stores fewer than ``num_runs`` runs, or fewer
      than 2 runs (needed for the NLL's std estimate).
  """
  arch_text = f'{config.num_layers}x{config.layer_width}'
  num_runs = config.num_runs
  suffix = '_test' if config.test else ''

  load_path_true = f'data/cartpole_true_returns{suffix}.npy'
  true_returns = _load_array(
      load_path_true, f'Loading true returns from {load_path_true}...')

  load_path_wide_mc = _qvals_path(
      'wide_mc', num_runs, config.gamma, config.lr, arch_text, suffix)
  q_wide = _load_array(
      load_path_wide_mc, f'Loading Wide (MC) Q-values from \'{load_path_wide_mc}\'...')

  load_path_nn = _qvals_path(
      'nn', num_runs, config.gamma, config.lr, arch_text, suffix)
  q_nn = _load_array(
      load_path_nn, f'Loading NN Q-values from \'{load_path_nn}\'...')

  _check_run_count(q_wide, num_runs, 'Wide (MC)')
  _check_run_count(q_nn, num_runs, 'NN')

  ## RMSE (one value per run)
  rmses_preds = _rmse_per_run(q_nn, true_returns, num_runs)
  rmses_theor = _rmse_per_run(q_wide, true_returns, num_runs)

  ## NLL (one value per position, Gaussian fit over all stored runs)
  nlls_preds = _gaussian_nll(q_nn, true_returns, 'NN')
  nlls_theor = _gaussian_nll(q_wide, true_returns, 'Wide (MC)')

  ## Print results
  print(f'{"*Method*":^16s} | {"RMSE":^20s} | {"NLL":^20s}')
  print(f'{"-"*17}+{"-"*22}+{"-"*22}')
  print(f'{"NNs":^16s} |  {rmses_preds.mean():>8.3f} ±{rmses_preds.std():>7.3f}   | {nlls_preds.mean():>9.3f} ± {nlls_preds.std():>9.3f}')
  print(f'{"Linearization":^16s} |  {rmses_theor.mean():>8.3f} ±{rmses_theor.std():>7.3f}   | {nlls_theor.mean():>9.3f} ± {nlls_theor.std():>9.3f}')

if __name__ == '__main__':
  # Command-line options as (flag, add_argument kwargs); the tuple order is the
  # order in which options are registered and therefore shown by --help.
  cli_options = (
    ('--gamma', dict(default=1., type=float, help='environment discount factor')),
    ('--lr', dict(default=0.1, type=float, help='neural network learning rate')),
    ('--num_layers', dict(default=2, type=int, help='number of fully connected layers in the neural network')),
    ('--layer_width', dict(default=256, type=int, help='fully connected layer width for all layers')),
    ('--num_runs', dict(default=50, type=int, help='number of runs to get metrics for')),
    ('--test', dict(action='store_true', help='use the test set for evaluating. Default: train set')),
  )

  cli = argparse.ArgumentParser(description='Cartpole get RMSE and NLL metrics')
  for flag, options in cli_options:
    cli.add_argument(flag, **options)

  # Hand the parsed namespace straight to the metrics routine.
  run(cli.parse_args())