import argparse
import jax
import jax.numpy as np

jax.config.update("jax_enable_x64", True)

from envs.toy import get_dataset_linspaced
from nn.ntwide import SetupType
from nn.mcwide import wide_init, wide_apply_jit
from util.print import info

def run(config):
  """Evaluate the wide (MC) single Q-network over an ensemble of seeds and save the result.

  For each seed in ``range(config.ensemble_size)`` the parameters are created
  with ``wide_init`` and ``wide_apply_jit`` is evaluated at every requested
  time on ``config.eval_resolution`` linearly spaced test states. The mean and
  sample standard deviation (``ddof=1``) across seeds are collected in a dict
  keyed by time and written to disk with ``np.save``.

  Args:
    config: namespace providing ``num_layers``, ``layer_width``, ``times``,
      ``lr``, ``dataset_size``, ``gamma``, ``eval_resolution``,
      ``eval_range_start``, ``eval_range_end`` and ``ensemble_size``. A time
      of ``-1`` is evaluated as ``np.inf``; repeated times are evaluated once.

  Returns:
    None. The result is written to
    ``data/toy_single_time_theor_mc_<runs>runs-lr<lr>-arch<LxW>.npy``.
  """
  import os  # local import: only needed to create the output directory

  sigmaweights = np.sqrt(1.)
  sigmabiases = np.sqrt(0.1)
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch_text = f'{num_layers}x{layer_width}'
  network_size = [1] + num_layers * [layer_width] + [1]

  # `device` is a method on older JAX arrays; newer releases appear to expose
  # it as a property instead (NOTE(review): confirm against the installed JAX
  # version). Handle both so the log line never aborts the run.
  device = sigmaweights.device
  if callable(device):
    device = device()
  info(f'Found devices: {jax.devices()}. Using device {device}.')

  lr = config.lr
  gamma = config.gamma
  num_samples = config.eval_resolution
  num_runs = config.ensemble_size

  train_range = (0, 1.25)
  reward_threshold = 1.
  eval_range = (config.eval_range_start, config.eval_range_end)

  # Map the -1 sentinel to infinity once and drop repeated times (keeping
  # order): a duplicated time would otherwise be aggregated twice, and the
  # second pass would try to call `.mean` on an already-reduced tuple.
  eval_times = list(dict.fromkeys(np.inf if t == -1 else t for t in config.times))

  dataset = list(get_dataset_linspaced(config.dataset_size, train_range, reward_threshold))
  # Entry 1 is replaced by zeros shaped like entry 0; `anew` is zeros shaped
  # like entry 3 (presumably actions / next actions -- verify against
  # `get_dataset_linspaced`). Both are seed-independent, so build them once.
  dataset[1] = np.zeros_like(dataset[0], dtype=np.int64)
  anew = np.zeros_like(dataset[3], dtype=np.int64)

  s_test = np.expand_dims(np.linspace(*eval_range, num_samples), 1)
  a_test = np.zeros_like(s_test, dtype=np.int64)

  ###############################
  ########## Wide (MC) ##########
  ###############################

  # One (num_runs, num_samples) buffer per evaluation time; row i holds seed i.
  q_time_wide_mc = {time: np.zeros((num_runs, num_samples)) for time in eval_times}

  for i in range(num_runs):
    info(f'Evaluating seed {i}...')
    qparams = wide_init(i, network_size, sigmaweights, sigmabiases, size_state=1)

    for time in eval_times:
      q_i = wide_apply_jit(qparams, *dataset, anew, s_test, a_test, gamma, method=SetupType.SINGLE_Q, t=time, lr=lr)
      q_time_wide_mc[time] = q_time_wide_mc[time].at[i, :].set(q_i[:, 0])

  # Reduce over the seed axis. With a single seed the ddof=1 estimate divides
  # by zero, so the standard deviation is NaN in that case.
  q_time_wide_mc = {
    time: (samples.mean(axis=0), samples.std(axis=0, ddof=1))
    for time, samples in q_time_wide_mc.items()
  }

  save_path_wide = f'data/toy_single_time_theor_mc_{num_runs}runs-lr{lr}-arch{arch_text}.npy'
  # Create the output directory up front so a missing `data/` folder does not
  # discard the results after all seeds have been evaluated.
  os.makedirs(os.path.dirname(save_path_wide), exist_ok=True)
  info(f'Saving Wide (MC) Q-values to \'{save_path_wide}\'...')
  # The payload is a dict, so it is presumably stored as a pickled object
  # array and needs `allow_pickle=True` when loading -- TODO confirm.
  np.save(save_path_wide, q_time_wide_mc)


if __name__ == '__main__':
  # Command-line entry point: declare the options as an ordered table, register
  # them in that order, then pass the parsed namespace straight to `run`.
  cli = argparse.ArgumentParser(description='Toy environment with single Q-network')

  option_table = (
    ('--gamma', dict(default=1., type=float, help='environment discount factor')),
    ('--lr', dict(default=0.01, type=float, help='neural network learning rate')),
    ('--dataset_size', dict(default=21, type=int, help='number of linearly spaced states for which transitions are generated in the dataset')),
    ('--num_layers', dict(default=2, type=int, help='number of fully connected layers in the neural network')),
    ('--layer_width', dict(default=1024, type=int, help='fully connected layer width for all layers')),
    ('--eval_resolution', dict(default=100, type=int, help='number of linearly spaced states to use to generate the Q-value data')),
    ('--ensemble_size', dict(default=50, type=int, help='number of different seeds to run replications')),
    # `run` treats a time of -1 as infinity.
    ('--times', dict(default=[0, 16, 256, 2048, 65536], type=int, nargs='+', help='times at which to evaluate the Q-values')),
    ('--eval_range_start', dict(default=0, type=float, help='lower limit of the states generated for evaluation')),
    ('--eval_range_end', dict(default=1.25, type=float, help='upper limit of the states generated for evaluation')),
  )
  for flag, spec in option_table:
    cli.add_argument(flag, **spec)

  run(cli.parse_args())