import argparse
import os

import jax
import jax.numpy as np

jax.config.update("jax_enable_x64", True)

from envs.toy import get_dataset_linspaced
from nn.ntwide import SetupType, nt_wide_predict_lr, nt_wide_predict_mean_cov
from util.print import info

def run(config):
  """Evaluate the wide-network (theoretical) Q-value predictions and save them.

  Builds a fully connected architecture description from ``config``, generates
  the linearly spaced toy dataset, logs the critical learning rate returned by
  ``nt_wide_predict_lr`` and, for every requested time, stores the predicted
  mean together with the square root of the covariance diagonal on an
  evaluation grid. The results are written with ``np.save`` as a dict keyed by
  time; the output directory is created if it does not exist yet.

  Args:
    config: Namespace-like object providing ``num_layers``, ``layer_width``,
      ``times``, ``lr``, ``dataset_size``, ``gamma``, ``eval_resolution``,
      ``eval_range_start`` and ``eval_range_end``. A time of ``-1`` is
      replaced by ``np.inf`` before prediction.
  """
  sigmaweights = np.sqrt(1.)
  sigmabiases = np.sqrt(0.1)
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch_text = f'{num_layers}x{layer_width}'
  # Layer sizes: 1, then `num_layers` entries of `layer_width`, then 1.
  network_size = [1] + num_layers * [layer_width] + [1]

  # NOTE(review): `.device()` is invoked as a method here; confirm the installed
  # JAX version still exposes it that way (newer releases may expose `device`
  # as a property instead).
  info(f'Found devices: {jax.devices()}. Using device {sigmaweights.device()}.')

  times = config.times
  lr = config.lr

  dataset_size = config.dataset_size
  gamma = config.gamma

  num_samples = config.eval_resolution
  train_range = (0, 1.25)
  reward_threshold = 1.

  eval_range = (config.eval_range_start, config.eval_range_end)

  dataset = list(get_dataset_linspaced(dataset_size, train_range, reward_threshold))
  # Overwrite the second dataset entry with int64 zeros shaped like the first
  # (presumably a single action index 0 for every state -- confirm in envs.toy).
  dataset[1] = np.zeros_like(dataset[0], dtype=np.int64)
  anew = np.zeros_like(dataset[0], dtype=np.int64)

  # Evaluation grid as a column vector, paired with all-zero int64 entries.
  s_test = np.expand_dims(np.linspace(*eval_range, num_samples), 1)
  a_test = np.zeros_like(s_test, dtype=np.int64)

  # Predict learning rate

  lr_critical = nt_wide_predict_lr(network_size, sigmaweights, sigmabiases, *dataset, anew, gamma, method=SetupType.SINGLE_Q)
  info(f'Critical learning rate: {lr_critical:.6f}')

  ###################################
  ########## Wide (Theor.) ##########
  ###################################

  # Maps each requested time (with -1 replaced by inf) to a pair
  # (squeezed mean, sqrt of the covariance diagonal).
  q_time_wide = {}
  for time in times:
    if time == -1:
      time = np.inf
    mean, cov = nt_wide_predict_mean_cov(network_size, sigmaweights, sigmabiases, *dataset, anew, s_test, a_test, gamma, t=time, lr=lr, method=SetupType.SINGLE_Q)
    q_time_wide[time] = (np.squeeze(mean), np.sqrt(np.diag(cov)))

  save_path_wide = f'data/toy_single_time_theor_lr{lr}-arch{arch_text}.npy'
  # np.save does not create missing directories; make sure the target exists so
  # the computed results are not lost to a FileNotFoundError at the very end.
  os.makedirs(os.path.dirname(save_path_wide), exist_ok=True)
  info(f'Saving Wide (inf) Q-values to \'{save_path_wide}\'...')
  # The dict is stored as a pickled object array, so reading it back requires
  # `allow_pickle=True`.
  np.save(save_path_wide, q_time_wide)


if __name__ == '__main__':
  # Script entry point: declare the command-line options and hand the parsed
  # namespace to run().
  cli = argparse.ArgumentParser(description='Toy environment with single Q-network')

  # (flag, add_argument keyword arguments), registered in this order.
  option_specs = [
    ('--gamma', dict(default=1., type=float, help='environment discount factor')),
    ('--lr', dict(default=0.01, type=float, help='neural network learning rate')),
    ('--dataset_size', dict(default=21, type=int, help='number of linearly spaced states for which transitions are generated in the dataset')),
    ('--num_layers', dict(default=2, type=int, help='number of fully connected layers in the neural network')),
    ('--layer_width', dict(default=1024, type=int, help='fully connected layer width for all layers')),
    ('--eval_resolution', dict(default=100, type=int, help='number of linearly spaced states to use to generate the Q-value data')),
    ('--times', dict(default=[0, 16, 256, 2048, 65536], type=int, nargs='+', help='times at which to evaluate the Q-values')),
    ('--eval_range_start', dict(default=0, type=float, help='lower limit of the states generated for evaluation')),
    ('--eval_range_end', dict(default=1.25, type=float, help='upper limit of the states generated for evaluation')),
  ]
  for flag, spec in option_specs:
    cli.add_argument(flag, **spec)

  run(cli.parse_args())