import os
import re

from absl import app
from absl import flags
import tensorflow as tf
from tf_agents.environments import tf_py_environment
from tf_agents.policies import py_tf_eager_policy
import tqdm

import evaluation
from dataset import load_tfrecord_dataset_sequence
import wrappers
import gsf
import tf_utils
import utils


# Names of the ProcGen games, in alphabetical order.
PROCGEN_ENVS = [
    'bigfish',
    'bossfight',
    'caveflyer',
    'chaser',
    'climber',
    'coinrun',
    'dodgeball',
    'fruitbot',
    'heist',
    'jumper',
    'leaper',
    'maze',
    'miner',
    'ninja',
    'plunder',
    'starpilot',
]

FLAGS = flags.FLAGS

# --- Environment / experiment identification -------------------------------
# `env_name` is parsed in main() as 'procgen-<game>-<num_train_levels>-<x>'.
flags.DEFINE_string('env_name', 'procgen-bigfish-200-0', 'Env name.')
# The help text used to be a copy of the `env_name` one ('Env name.').
flags.DEFINE_string('task_name', '', 'Task name.')
flags.DEFINE_enum('algo_name', 'gsf', ['gsf'], 'Algorithm.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('action_repeat', 8,
                     '(optional) action repeat used when instantiating env.')
flags.DEFINE_integer('frame_stack', 3,
                     '(optional) frame stack used when instantiating env.')

# --- Dataset ----------------------------------------------------------------
flags.DEFINE_integer('max_timesteps', 100_000,
                     'Size of dataset to load (typically 100k).')
flags.DEFINE_integer('ckpt_timesteps', 100_000,
                     'Checkpoint timesteps to load dataset from.')

# --- Optimisation schedule, logging and checkpointing -----------------------
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('num_updates', 100_000, 'Num updates.')
flags.DEFINE_integer('num_eval_episodes', 10, 'Num eval episodes.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('save_interval', 10_000, 'Model save interval.')
flags.DEFINE_integer('eval_interval', 1000, 'Evaluation interval.')
flags.DEFINE_string('save_dir', '/tmp/save/', 'Saving directory.')
flags.DEFINE_boolean('eager', False, 'Execute functions eagerly.')

# --- Algorithm hyper-parameters ---------------------------------------------
flags.DEFINE_float('f_reg', 0.1, 'BR regularization.')
flags.DEFINE_float('reward_bonus', 5.0, 'CQL style reward bonus.')
flags.DEFINE_boolean(
    'numpy_dataset', False,
    'If true, saves and loads the data into/from NumPy arrays with shards')
flags.DEFINE_integer('num_data_augs', 0,
                     'Number of DrQ-style data augs in pixel space.')
flags.DEFINE_enum(
    'obs_type', 'pixels', ['pixels', 'state'],
    'Type of observations to write in the dataset (`state` or `pixels`)')
flags.DEFINE_boolean('load_bc', False,
                     ('Whether to pre-load the online policy'
                      ' corresponding to the dataset, or train from scratch.'))
flags.DEFINE_integer('dataset_size', 100_000, 'Num transitions to train on.')
flags.DEFINE_string('rep_learn_keywords', 'CQL',
                    'Representation learning method list')
flags.DEFINE_integer('n_step_returns', 1, 'N-step returns.')
# Doubles as the number of encoder pre-training steps run by main().
flags.DEFINE_integer('pretrain', 200_000,
                     ('Pretrain our algorithm using contrastive loss.'
                      ' If `>0`, splits pre-training and FQI. '
                      'If `0`, optimize encoder and FQI losses jointly'))
flags.DEFINE_integer('n_quantiles', 5, 'Number of GVF quantiles.')
flags.DEFINE_float('temp', 0.1, 'NCE softmax temperature.')


def _make_procgen_env(game, num_levels):
  """Builds a single-instance ProcGen environment wrapped for TF-Agents.

  Args:
    game: ProcGen game name (the second field of `--env_name`).
    num_levels: Value forwarded as `num_levels` to the ProcGen wrapper.

  Returns:
    A `TFPyEnvironment` around one ProcGen environment with reward
    normalization disabled and `start_level=0`.
  """
  # NOTE(review): the original code referenced `procgen_wrappers`, which this
  # file never imports (NameError at runtime). `wrappers` is the imported,
  # otherwise unused module assumed to provide this class -- confirm.
  return tf_py_environment.TFPyEnvironment(
      wrappers.TFAgentsParallelProcGenEnv(
          1,
          normalize_rewards=False,  # no normalization for evaluation
          env_name=game,
          num_levels=num_levels,
          start_level=0))


def _checkpoint_step(checkpoint_path):
  """Extracts the step number from a `.../ckpt-<N>` checkpoint path.

  Only the last path component is inspected, so a `ckpt-` substring inside a
  parent directory name cannot be mistaken for the checkpoint number.

  Args:
    checkpoint_path: Path returned by `CheckpointManager.restore_or_initialize`.

  Returns:
    The integer `<N>` of the checkpoint.

  Raises:
    ValueError: If no `ckpt-<digits>` pattern is present in the file name.
  """
  match = re.search(r'ckpt-(\d+)', os.path.basename(checkpoint_path))
  if match is None:
    raise ValueError(
        'Cannot parse a step number from checkpoint path %r' % checkpoint_path)
  return int(match.group(1))


def _log_scalars(writer, prefix, info_dict, step):
  """Writes the mean of every entry of `info_dict` as `<prefix>/<key>`."""
  with writer.as_default():
    for name, value in info_dict.items():
      tf.summary.scalar(f'{prefix}/{name}', tf.reduce_mean(value), step=step)


def main(_):
  """Trains a GSF model from an offline dataset and evaluates it on ProcGen.

  The run consists of an optional encoder pre-training phase (`--pretrain`
  steps with `train_target='encoder'`) followed by `--num_updates` update
  steps (`train_target='rl'` after pre-training, `'both'` otherwise). Every
  `--eval_interval` steps the model is evaluated on the environment built
  with the training-level count and on the one built with `num_levels=0`.

  Raises:
    ValueError: If `--env_name` is not of the form
      `procgen-<game>-<num_levels>-<suffix>`, or if a restored checkpoint path
      does not contain a step number.
  """
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  print('Num GPUs Available: ', len(tf.config.list_physical_devices('GPU')))

  # Only ProcGen environments are supported. Fail early with a clear message
  # instead of a NameError on `env` further down.
  if not FLAGS.env_name.startswith('procgen'):
    raise ValueError(
        'Unsupported env_name %r: expected procgen-<game>-<num_levels>-<x>' %
        FLAGS.env_name)
  name_parts = FLAGS.env_name.split('-')
  if len(name_parts) != 4:
    raise ValueError(
        'Malformed env_name %r: expected procgen-<game>-<num_levels>-<x>' %
        FLAGS.env_name)
  _, game, train_levels, _ = name_parts

  print('Test env: %s' % FLAGS.env_name)
  print('Train env: %s' % FLAGS.env_name)
  env = _make_procgen_env(game, int(train_levels))
  env_all = _make_procgen_env(game, 0)

  # The environments above are built with the level count as given; only the
  # value handed to the model falls back to '200' when it is 0.
  if int(train_levels) == 0:
    train_levels = '200'

  # The per-algorithm module aliases (`bcq = bcq_pixel`, ...) that used to
  # live here referenced modules this file never imports and were never read
  # afterwards, so they only produced a NameError and have been removed.

  print('Loading dataset')

  # Placeholder glob: must be pointed at the actual TFRecord shards.
  dataset_path = 'PATH/TO/SHARDS/*.tfrecord*'

  dataset_iter = tf_utils.create_data_iterator(
      dataset_path,
      FLAGS.batch_size,
      shuffle_buffer_size=FLAGS.batch_size,
      obs_to_float=False)

  # NOTE(review): the global seed is set after the dataset iterator is built;
  # if `create_data_iterator` derives its shuffle seed from the global seed,
  # the data order is not controlled by --seed -- confirm.
  tf.random.set_seed(FLAGS.seed)

  # `xm_parameters` is not defined as a flag in this file; fall back to None
  # when no other module defines it.
  # NOTE(review): assumes `make_hparam_string` accepts None -- confirm.
  hparam_str = utils.make_hparam_string(
      getattr(FLAGS, 'xm_parameters', None),
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.env_name,
      ckpt_timesteps=FLAGS.ckpt_timesteps,
      rep_learn_keywords=FLAGS.rep_learn_keywords)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  pretrain = FLAGS.pretrain > 0

  action_dim = env.action_spec().maximum.item() + 1
  # NOTE(review): `train_levels` is passed as a string, as in the original
  # code -- confirm whether GSF expects an int.
  model = gsf.GSF(
      env.observation_spec(),
      env.action_spec(),
      target_entropy=-action_dim,
      f_reg=FLAGS.f_reg,
      reward_bonus=FLAGS.reward_bonus,
      num_augmentations=FLAGS.num_data_augs,
      env_name=FLAGS.env_name,
      rep_learn_keywords=FLAGS.rep_learn_keywords,
      batch_size=FLAGS.batch_size,
      n_quantiles=FLAGS.n_quantiles,
      temp=FLAGS.temp,
      num_training_levels=train_levels)

  # Checkpointing is set up unconditionally: the main training loop below
  # saves through `manager` even when pre-training is disabled, which used to
  # raise a NameError with --pretrain=0.
  model_save_path = os.path.join(FLAGS.save_dir, 'weights', hparam_str)
  checkpoint = tf.train.Checkpoint(**model.model_dict)
  tf_step_counter = tf.Variable(0, dtype=tf.int32)
  manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_save_path,
      max_to_keep=1,
      checkpoint_interval=FLAGS.save_interval,
      step_counter=tf_step_counter)

  # Load the checkpoint in case it exists.
  state = manager.restore_or_initialize()
  if state is not None:
    timesteps_already_done = _checkpoint_step(state)
    print('Loaded model from timestep %d' % timesteps_already_done)
  else:
    print('Training from scratch')
    timesteps_already_done = 0

  tf_step_counter.assign(timesteps_already_done)

  if pretrain:
    # NOTE(review): pre-training always restarts from step 0, even after a
    # restore, and shares checkpoint numbers with the main loop -- confirm
    # this is the intended resume behaviour.
    print('Pretraining')
    for i in tqdm.tqdm(range(FLAGS.pretrain)):
      info_dict = model.update_step(dataset_iter, train_target='encoder')
      if i % FLAGS.log_interval == 0:
        _log_scalars(summary_writer, 'pretrain', info_dict, i)

      # `manager.save` is called every step; the `checkpoint_interval` passed
      # to the manager decides whether a checkpoint is actually written.
      tf_step_counter.assign(i)
      manager.save(checkpoint_number=i)

  tf_step_counter.assign(timesteps_already_done)

  train_target = 'rl' if pretrain else 'both'
  for i in tqdm.tqdm(range(timesteps_already_done, FLAGS.num_updates)):
    # The update runs under the summary writer so that any summaries emitted
    # inside `update_step` go to the training log.
    with summary_writer.as_default():
      info_dict = model.update_step(dataset_iter, train_target=train_target)
    if i % FLAGS.log_interval == 0:
      _log_scalars(summary_writer, 'training', info_dict, i)

    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(env, model)
      average_returns_all, average_length_all = evaluation.evaluate(
          env_all, model)

      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns-200', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length-200', average_length, step=i + 1)
        tf.summary.scalar(
            'evaluation/returns-all', average_returns_all, step=i + 1)
        tf.summary.scalar(
            'evaluation/length-all', average_length_all, step=i + 1)

    tf_step_counter.assign(i)
    manager.save(checkpoint_number=i)

# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == '__main__':
  app.run(main)