import argparse
import jax
import jax.numpy as np
import matplotlib.pyplot as plt

# Enable 64-bit (double-precision / int64) types in JAX, which otherwise defaults
# to 32-bit; `run` below relies on this for its `np.int64` cast.
# NOTE(review): presumably done before the project imports below so that any
# arrays they create at import time are 64-bit as well — confirm.
jax.config.update("jax_enable_x64", True)

from envs.toy import get_dataset_linspaced
from util.print import info
import os

# Global figure styling for every plot produced by this script: a larger
# default font, and text kept as text (fonts not converted to paths) in SVG output.
plt.rcParams.update({
  'font.size': 14,
  'svg.fonttype': 'none',
})

def run(config):
  """Plot previously saved toy-environment Q-values at several training times.

  One pane is drawn per entry of ``config.times``; a time of ``-1`` is mapped
  to ``inf``. In each pane, every model family requested via
  ``config.show_wide`` / ``config.show_wide_mc`` / ``config.show_nn`` is drawn
  as its mean curve with a +/- one standard deviation band. NN curves are
  skipped at ``t = inf``. The figure is written as an SVG under ``plots/``.

  Args:
    config: parsed command-line namespace (see the argument parser in
      ``__main__``).
  """
  num_layers = config.num_layers
  layer_width = config.layer_width
  arch_text = f'{num_layers}x{layer_width}'

  times = config.times
  lr = config.lr

  dataset_size = config.dataset_size

  num_samples = config.eval_resolution
  train_range = (0, 1.25)
  reward_threshold = 1.

  eval_range = (config.eval_range_start, config.eval_range_end)

  num_runs = config.ensemble_size

  # Bail out before building the dataset when there is nothing to draw.
  if not (config.show_wide or config.show_wide_mc or config.show_nn):
    info('No experiments marked to show. Please specify at least one of: `--show_wide`, `--show_wide_mc`, `--show_nn`.')
    return

  # Only element 0 of the dataset is used here: its entries mark (as vertical
  # lines) where training transitions exist.
  dataset = list(get_dataset_linspaced(dataset_size, train_range, reward_threshold))
  train_states = dataset[0]

  # Column vector of evaluation states, shape (num_samples, 1).
  s_test = np.expand_dims(np.linspace(*eval_range, num_samples), 1)

  ##########################
  ########## Load ##########
  ##########################

  load_path_wide = f'data/toy_single_time_theor_lr{lr}-arch{arch_text}.npy'
  load_path_wide_mc = f'data/toy_single_time_theor_mc_{num_runs}runs-lr{lr}-arch{arch_text}.npy'
  load_path_nn = f'data/toy_single_time_nnensemble_{num_runs}runs-lr{lr}-arch{arch_text}.npy'
  # Each entry: (legend label, matplotlib 'tab:' colour, dict keyed by time).
  display_types = []
  if config.show_wide:
    info(f'Loading Wide (inf) Q-values from \'{load_path_wide}\'...')
    display_types += [('Lin.', 'blue', np.load(load_path_wide, allow_pickle=True).item())]
  if config.show_wide_mc:
    info(f'Loading Wide (MC) Q-values from \'{load_path_wide_mc}\'...')
    display_types += [('Lin. (MC)', 'orange', np.load(load_path_wide_mc, allow_pickle=True).item())]
  if config.show_nn:
    info(f'Loading NN ensemble Q-values from \'{load_path_nn}\'...')
    display_types += [('NNs', 'red', np.load(load_path_nn, allow_pickle=True).item())]

  ##########################
  ########## Plot ##########
  ##########################

  def build_pane(ax):
    """Draw the static reference lines, data markers and x-axis setup of one pane."""
    ax.axvline(reward_threshold, linestyle='--', color='black', linewidth=0.7)
    ax.axhline(0, linestyle='--', color='black', linewidth=0.7)
    ax.axhline(1, linestyle='--', color='black', linewidth=0.7, alpha=0.3)
    if not config.hide_data:
      for x in train_states:
        ax.axvline(x, linestyle='--', color='tab:red', linewidth=0.7, alpha=0.1)

    ax.set_xlabel('s')
    ax.set_xlim(eval_range)
    ax.set_xticks(np.arange(train_range[0], train_range[1]+0.1, 0.5))

  # squeeze=False keeps `axs` a 2-D array even for a single time; a plain
  # plt.subplots(1, 1) returns a bare Axes, which cannot be indexed below.
  fig, axs = plt.subplots(1, len(times), figsize=(len(times) * 3.5, 3), squeeze=False)
  axs = axs[0]

  for time, ax in zip(times, axs):
    build_pane(ax)
    if time == -1: time = np.inf

    for name, color, q_times in display_types:
      if name == 'NNs':
        # No NN results exist for t = inf.
        if time == np.inf: continue
        q_runs = q_times[time]  # one row per ensemble member
        if config.show_nn_individual:
          for q_run in q_runs:
            ax.plot(s_test, q_run, color='black', alpha=0.1, linestyle='--')
        mean = q_runs.mean(axis=0)
        std = q_runs.std(axis=0, ddof=1)  # sample std across ensemble members
      else:
        mean, std = q_times[time]

      ax.plot(s_test, mean, color=f'tab:{color}', label=name)
      ax.fill_between(s_test[:, 0], mean-std, mean+std, alpha=0.3, color=f'tab:{color}')

    ax.set_ylim(-0.5, 1.1)
    ax.set_title(f't = {time}')

  # Put the legend on the right-most pane that actually has labelled curves
  # (the last pane may be t = inf with only NN curves requested, i.e. empty).
  for ax in reversed(axs):
    if ax.get_legend_handles_labels()[1]:
      ax.legend()
      break
  axs[0].set_ylabel('Q')

  os.makedirs('plots', exist_ok=True)

  fig_path = f'plots/toy_single_time_compare_{num_runs}runs-lr{lr}-arch{arch_text}.svg'
  fig.savefig(fig_path)
  plt.close(fig)  # release the figure explicitly instead of leaking it in pyplot
  info(f'Saved plot to \'{fig_path}\'')


if __name__ == '__main__':
  # Command-line interface: selects which saved Q-value files to load and how to plot them.
  parser = argparse.ArgumentParser(description='Toy environment with single Q-network')
  parser.add_argument('--lr', default=0.01, type=float, help='neural network learning rate')
  parser.add_argument('--dataset_size', default=21, type=int, help='number of linearly spaced states for which transitions are generated in the dataset')
  parser.add_argument('--num_layers', default=2, type=int, help='number of fully connected layers in the neural network')
  parser.add_argument('--layer_width', default=1024, type=int, help='fully connected layer width for all layers')
  parser.add_argument('--eval_resolution', default=100, type=int, help='number of linearly spaced states to use to generate the Q-value data')
  parser.add_argument('--ensemble_size', default=50, type=int, help='number of different seeds to run replications')
  parser.add_argument('--show_nn', action='store_true', help='load and display NN Q-values')
  parser.add_argument('--show_nn_individual', action='store_true', help='additionally display the Q-values of each individual NN run (only has an effect together with --show_nn)')
  parser.add_argument('--show_wide', action='store_true', help='load and display infinitely wide predicted Q-values')
  parser.add_argument('--show_wide_mc', action='store_true', help='load and display MC estimation of wide Q-values')
  parser.add_argument('--times', default=[0, 16, 256, 2048, 65536], type=int, nargs='+', help='times at which to evaluate the Q-values (use -1 for t = inf)')

  # argparse does not apply `type` to non-string defaults, so give float defaults explicitly.
  parser.add_argument('--eval_range_start', default=0., type=float, help='lower limit of the states generated for evaluation')
  parser.add_argument('--eval_range_end', default=1.25, type=float, help='upper limit of the states generated for evaluation')
  parser.add_argument('--hide_data', action='store_true', help='hide the vertical bars that show states where transitions are present in the train set. Default: show')

  args = parser.parse_args()
  run(args)