""" Source code for the RPO algorithm.
Under review at ICLR 2022.
Please do not distribute. """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import sys
import time

# Python 2/3 compatibility: `xrange` does not exist on Python 3, so alias it
# to the builtin `range` for code in this module that uses the Python 2 name.
if sys.version_info.major > 2:
  xrange = range

import joblib
import numpy as np
import tensorflow as tf
from gym import spaces

from rpto.data_server.data_server import PPODataServer
from rpto.utils import logger


def as_func(obj):
  """Turn a constant or a schedule into a one-argument callable.

  Args:
    obj: either a real number (int/float/numpy scalar, but not bool), which
      becomes a constant schedule ignoring its argument, or a callable, which
      is returned unchanged.

  Returns:
    A callable taking a single argument (in `RPOLearner.train` this is the
    remaining-progress fraction) and returning the scheduled value.

  Raises:
    TypeError: if `obj` is neither a real number nor callable.
  """
  import numbers
  # bool is a subclass of int; a True/False "learning rate" is almost
  # certainly a caller mistake, so it is rejected rather than wrapped.
  if isinstance(obj, numbers.Real) and not isinstance(obj, bool):
    return lambda _frac: obj
  if callable(obj):
    return obj
  raise TypeError(
    'Expected a number or a callable schedule, got {!r}.'.format(obj))


def average_tf(grads):
  """Average gradients position-wise across several gradient lists.

  Args:
    grads: a sequence of gradient lists (presumably one list per replica /
      tower), all aligned so that index i refers to the same variable.

  Returns:
    A list with one entry per position: the mean over replicas of the stacked
    gradients, or None where the first replica's gradient is None.
  """
  def _mean_or_none(per_replica):
    # Only the first replica is checked for a missing gradient.
    if per_replica[0] is None:
      return None
    return tf.reduce_mean(tf.stack(per_replica), 0)

  return [_mean_or_none(column) for column in zip(*grads)]


class RPOLearner:
  """Learner for the RPO algorithm.

  The learner consumes two data streams built from `memory_sim` and
  `memory_tar` (named *sim* and *tar* throughout; presumably a source /
  simulated environment and the target environment -- confirm with callers).
  A single policy network under the variable scope 'model' is instantiated on
  each stream with shared weights and trained with a PPO-style clipped
  objective:

    loss = pg_loss_sim + rpo_coef * pg_loss_tar
           - entropy_bonus + vf_coef * vf_loss

  The value function is fitted on *sim* data only; the *tar* policy-gradient
  term uses the "relative" advantage R_tar - OLDVPRED_tar.

  After construction the instance exposes the closures `train_batch`, `save`,
  `load` and `reset` (built in `_build_ops`) plus the `train` loop.
  """

  def __init__(self, sess,
               batch_size,
               ob_space,
               ac_space,
               memory_sim,
               memory_tar,
               policy,
               policy_config,
               ent_coef=1e-2,
               n_v=1,
               vf_coef=0.5,
               rpo_coef=1.0,
               max_grad_norm=0.5,
               log_interval=100,
               save_interval=0,
               gpu_id=-1,
               total_timesteps=5e7,
               adv_normalize=True,
               merge_pi=False,
               rnn=False,
               rollout_len=1,
               hs_len=64,
               reward_weights=None,
               optimizer_type='adam',
               init_model_path=None,
               log_dir='training_log/'):
    """Build the training graph, initialize variables and configure logging.

    Args:
      sess: tf.Session used to run every op of this learner.
      batch_size: minibatch size drawn from each data server per update.
      ob_space: observation space, forwarded to the data servers and policy.
      ac_space: action space; when it is a gym `spaces.Tuple`, `merge_pi`
        may be honored.
      memory_sim: replay memory backing the *sim* PPODataServer.
      memory_tar: replay memory backing the *tar* PPODataServer.
      policy: callable building the network. It is called with ob_space,
        ac_space, sess, n_v, input_data, reuse, scope_name and
        **policy_config; the result must expose `.pd` and `.vf`.
      policy_config: extra keyword arguments forwarded to `policy`.
      ent_coef: entropy-bonus coefficient: a float, or a list with one
        coefficient per entry of the batch-averaged entropy tensor.
      n_v: forwarded to the data servers and policy (presumably the number of
        value heads).
      vf_coef: value-loss coefficient.
      rpo_coef: weight of the *tar* policy-gradient term.
      max_grad_norm: global-norm clipping threshold, or None for no clipping.
      log_interval: log every `log_interval` updates (and at update 1).
      save_interval: checkpoint every `save_interval` updates (and at update
        1); 0 disables checkpointing.
      gpu_id: GPU index used for graph construction; < 0 builds on CPU.
      total_timesteps: sample budget; the number of updates is
        total_timesteps // batch_size.
      adv_normalize: scale the centered *tar* relative advantage to unit RMS.
      merge_pi: sum negative log-probs over the last axis before forming the
        ratio; only honored when `ac_space` is a `spaces.Tuple`.
      rnn: whether the policy is recurrent; enables `hs_len`.
      rollout_len: accepted but not used by this class.
      hs_len: hidden-state length, only used when `rnn` is True.
      reward_weights: optional list of weights; when given, returns, values
        and advantages are linearly combined over their last axis with these
        weights before the losses are formed.
      optimizer_type: 'adam' or 'sgd'.
      init_model_path: optional joblib file with initial parameters.
      log_dir: directory passed to `logger.configure`.

    Raises:
      ValueError: if `optimizer_type` is neither 'adam' nor 'sgd', or if a
        list `ent_coef` does not match the number of entropy terms.
    """
    self.sess = sess
    self.ob_space = ob_space
    self.ac_space = ac_space
    self.policy = policy
    self.ent_coef = ent_coef
    self.n_v = n_v
    self.reward_weights = reward_weights
    self.vf_coef = vf_coef
    self.rpo_coef = rpo_coef
    self.log_interval = log_interval
    self.save_interval = save_interval
    self.total_timesteps = total_timesteps
    self.adv_normalize = adv_normalize
    self.merge_pi = merge_pi and isinstance(self.ac_space, spaces.Tuple)
    self.rnn = rnn
    # hs_len is only meaningful for recurrent policies.
    self.hs_len = hs_len if self.rnn else None
    self._train_op = []
    self.batch_size = batch_size

    # Scalars fed on every `train_batch` call.
    self.LR = tf.placeholder(tf.float32, [])
    self.CLIPRANGE = tf.placeholder(tf.float32, [])
    self._data_server_sim = PPODataServer(
      memory=memory_sim,
      ob_space=ob_space,
      ac_space=ac_space,
      n_v=n_v,
      batch_size=batch_size,
      rnn=self.rnn,
      hs_len=self.hs_len,
      gpu_id=gpu_id)  # whether define prefetch on GPU
    self._data_server_tar = PPODataServer(
      memory=memory_tar,
      ob_space=ob_space,
      ac_space=ac_space,
      n_v=n_v,
      batch_size=batch_size,
      rnn=self.rnn,
      hs_len=self.hs_len,
      gpu_id=gpu_id)

    device = '/cpu:0' if gpu_id < 0 else '/gpu:%d' % gpu_id
    with tf.device(device):
      input_data_sim = self._data_server_sim.input_data
      input_data_tar = self._data_server_tar.input_data
      # Both instantiations live in scope 'model' so the weights are shared:
      # the first call may create the variables (AUTO_REUSE), the second must
      # reuse them.
      model_sim_input = policy(
        ob_space=ob_space,
        ac_space=ac_space,
        sess=self.sess,
        n_v=n_v,
        input_data=input_data_sim,
        reuse=tf.AUTO_REUSE,
        scope_name='model',
        **policy_config)
      model_tar_input = policy(
        ob_space=ob_space,
        ac_space=ac_space,
        sess=self.sess,
        n_v=n_v,
        input_data=input_data_tar,
        reuse=True,
        scope_name='model',
        **policy_config)
      loss, _, losses = self.build_loss(model_sim_input,
                                        model_tar_input,
                                        input_data_sim,
                                        input_data_tar)
      self.params = tf.trainable_variables(scope='model')
      self.params_vf = tf.trainable_variables(scope='model/vf')
      self.param_norm = tf.global_norm(self.params)
      grads = tf.gradients(loss, self.params)

    self.model = model_sim_input
    self.losses = losses

    # Clip by global norm (on CPU) and pair gradients with their variables.
    with tf.device('/cpu:0'):
      if max_grad_norm is not None:
        grads, self.grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
      else:
        self.grad_norm = tf.global_norm(grads)
      grads_and_vars = list(zip(grads, self.params))

    if optimizer_type == 'adam':
      self.optimizer = tf.train.AdamOptimizer(learning_rate=self.LR, epsilon=1e-5)
    elif optimizer_type == 'sgd':
      self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.LR)
    else:
      # ValueError (still a BaseException subclass) so that ordinary
      # `except Exception` handlers can catch a bad configuration.
      raise ValueError('Unknown optimizer type: {!r}.'.format(optimizer_type))
    self._train_op = self.optimizer.apply_gradients(grads_and_vars)
    self._build_ops()
    tf.global_variables_initializer().run(session=self.sess)
    logger.configure(dir=log_dir,
                     format_strs=['stdout', 'log', 'tensorboard', 'csv'])
    if init_model_path is not None:
      logger.log('Training starts from', init_model_path)
      self.load(load_path=init_model_path)
      logger.log('Restore initial model successfully.')

  def _build_ops(self):
    """Create assign/reset ops and the train/save/load/reset closures."""
    # Placeholders + assign ops used to overwrite model parameters (load).
    self.new_params = [tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.params]
    self.param_assign_ops = [p.assign(new_p) for p, new_p in zip(self.params, self.new_params)]
    # Same for the optimizer's own variables (e.g. Adam moments).
    self.opt_params = self.optimizer.variables()
    self.new_opt_params = [tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.opt_params]
    self.opt_param_assign_ops = [p.assign(new_p) for p, new_p in zip(self.opt_params, self.new_opt_params)]
    self.reset_optimizer_op = tf.variables_initializer(self.optimizer.variables())

    # Order must match `self.losses + [self.grad_norm, self.param_norm]`.
    self.loss_names = [
      'policy_loss',
      'value_loss',
      'policy_entropy',
      'approxkl',
      'clipfrac',
      'mean_return_sim',
      'explained_var_sim',
      'mean_return_tar',
      'grad_norm',
      'param_norm',
    ]

    def train_batch(lr, cliprange):
      """Run one optimizer step; return the values named in loss_names."""
      td_map = {self.LR: lr, self.CLIPRANGE: cliprange}
      # The trailing train-op result is dropped by the slice.
      return self.sess.run(
        self.losses + [self.grad_norm, self.param_norm, self._train_op],
        td_map
      )[0:len(self.loss_names)]

    def save(save_path):
      """Dump the current trainable parameter values with joblib."""
      ps = self.sess.run(self.params)
      joblib.dump(ps, save_path)

    def load(load_path):
      """Assign parameter values previously written by `save`."""
      loaded_params = joblib.load(load_path)
      self.sess.run(self.param_assign_ops,
                    feed_dict={p: v for p, v in zip(self.new_params, loaded_params)})

    def reset():
      """Re-initialize the optimizer's variables."""
      self.sess.run(self.reset_optimizer_op)

    self.train_batch = train_batch
    self.save = save
    self.load = load
    self.reset = reset

  def build_loss(self, model_sim_input, model_tar_input,
                 input_data_sim, input_data_tar):
    """Build the RPO objective and its diagnostics.

    Args:
      model_sim_input: policy instance built on the *sim* input stream.
      model_tar_input: policy instance built on the *tar* input stream.
      input_data_sim: *sim* batch tensors (uses R, OLDVPRED, A, OLDNEGLOGPAC).
      input_data_tar: *tar* batch tensors (same fields).

    Returns:
      (loss, vf_loss, stats): the total objective to minimize, the value loss
      alone, and the list of eight tensors logged under the first eight
      entries of `self.loss_names`.

    Raises:
      ValueError: if `self.ent_coef` is a list whose length differs from the
        number of entropy terms.
    """
    def _reward_shaping(x):
      # Weighted sum over the last axis with `reward_weights`
      # (x presumably [batch, n_v] -> [batch]; confirm with the data server).
      return tf.squeeze(tf.matmul(x, np.asarray([self.reward_weights], dtype=np.float32), transpose_b=True))
    if self.reward_weights is not None:
      rwd_shape_func = _reward_shaping
    else:
      rwd_shape_func = lambda x: x

    relative_adv = rwd_shape_func(input_data_tar.R - input_data_tar.OLDVPRED)
    adv = rwd_shape_func(input_data_sim.R - input_data_sim.OLDVPRED)

    mean_return_sim = tf.reduce_mean(rwd_shape_func(input_data_sim.R))
    mean_return_tar = tf.reduce_mean(rwd_shape_func(input_data_tar.R))

    # Each batch must be scored by the network instance built on its *own*
    # observations; scoring the sim actions with `model_tar_input.pd` would
    # pair every action with an unrelated target observation.
    neglogpac_sim = model_sim_input.pd.neglogp(input_data_sim.A)
    neglogpac_tar = model_tar_input.pd.neglogp(input_data_tar.A)
    # Reduce mean at the batch dimension; kept separate from the weighted
    # bonus so the logged entropy is the actual policy entropy.
    raw_entropy = tf.reduce_mean(model_sim_input.pd.entropy(), axis=0)
    if self.merge_pi:
      ratio_sim = tf.exp(tf.reduce_sum(input_data_sim.OLDNEGLOGPAC - neglogpac_sim, axis=-1))
      ratio_tar = tf.exp(tf.reduce_sum(input_data_tar.OLDNEGLOGPAC - neglogpac_tar, axis=-1))
    else:
      ratio_sim = tf.exp(input_data_sim.OLDNEGLOGPAC - neglogpac_sim)
      ratio_tar = tf.exp(input_data_tar.OLDNEGLOGPAC - neglogpac_tar)

    # Center (and optionally RMS-normalize) the relative advantage.
    # NOTE(review): the *sim* advantage `adv` is neither centered nor
    # normalized; confirm this asymmetry is intended by the algorithm.
    relative_adv = relative_adv - tf.reduce_mean(relative_adv, axis=0)
    if self.adv_normalize:
      relative_adv = relative_adv / tf.sqrt(tf.reduce_mean(tf.square(relative_adv), axis=0) + 1e-8)

    # Clipped value loss, fitted on sim data only.
    vpred = model_sim_input.vf
    vpredclipped = input_data_sim.OLDVPRED + tf.clip_by_value(model_sim_input.vf - input_data_sim.OLDVPRED,
                                                              -self.CLIPRANGE, self.CLIPRANGE)
    vf_losses1 = tf.square(rwd_shape_func(vpred - input_data_sim.R))
    vf_losses2 = tf.square(rwd_shape_func(vpredclipped - input_data_sim.R))
    vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

    # PPO clipped surrogate on sim data.
    pg_losses_sim1 = -adv * ratio_sim
    pg_losses_sim2 = -adv * tf.clip_by_value(ratio_sim, 1.0 - self.CLIPRANGE,
                                             1.0 + self.CLIPRANGE)
    pg_loss_sim = tf.reduce_mean(tf.maximum(pg_losses_sim1, pg_losses_sim2))

    # PPO clipped surrogate on tar data with the relative advantage.
    pg_losses_tar1 = -relative_adv * ratio_tar
    pg_losses_tar2 = -relative_adv * tf.clip_by_value(ratio_tar, 1.0 - self.CLIPRANGE,
                                                      1.0 + self.CLIPRANGE)
    pg_loss_tar = tf.reduce_mean(tf.maximum(pg_losses_tar1, pg_losses_tar2))

    if isinstance(self.ent_coef, list):
      entropy_list = tf.unstack(raw_entropy, axis=0)
      if len(entropy_list) != len(self.ent_coef):
        raise ValueError('Lengths of ent and ent_coef mismatch.')
      print('ent_coef: {}'.format(self.ent_coef))
      entropy_bonus = tf.add_n([e * ec for e, ec in zip(entropy_list, self.ent_coef)])
    else:
      entropy_bonus = tf.reduce_sum(raw_entropy) * self.ent_coef
    policy_entropy = tf.reduce_sum(raw_entropy)

    loss = (pg_loss_sim + pg_loss_tar * self.rpo_coef - entropy_bonus + vf_loss * self.vf_coef)

    approxkl = .5 * tf.reduce_mean(tf.square(neglogpac_tar - input_data_tar.OLDNEGLOGPAC))
    # Explained variance per value output on the *raw* returns, so numerator
    # and denominator share units even when reward_weights is set.
    _, var_res = tf.nn.moments(vpred - input_data_sim.R, axes=[0], keep_dims=True)
    _, var_return_sim_raw = tf.nn.moments(input_data_sim.R, axes=[0], keep_dims=True)
    explained_var_sim = tf.reduce_mean(1 - var_res / var_return_sim_raw)
    clipfrac = tf.reduce_mean(
      tf.to_float(tf.greater((ratio_tar - 1.0) * tf.sign(relative_adv), self.CLIPRANGE)))
    # NOTE: 'policy_loss' reports only the tar surrogate term.
    return loss, vf_loss, [pg_loss_tar,
                           vf_loss,
                           policy_entropy,
                           approxkl,
                           clipfrac,
                           mean_return_sim,
                           explained_var_sim,
                           mean_return_tar]

  def train(self, learning_rate=1e-5, clip_range=0.2):
    """Run the training loop for total_timesteps // batch_size updates.

    Args:
      learning_rate: float, or callable `frac -> lr` where `frac` decreases
        linearly from 1.0 at the first update to 1/nupdates at the last.
      clip_range: float, or callable `frac -> clip range` (same convention).
    """
    lr = as_func(learning_rate)
    cliprange = as_func(clip_range)
    nupdates = int(self.total_timesteps // self.batch_size)
    mblossvals = []
    tfirststart = time.time()
    tstart = time.time()
    total_samples = self._data_server_sim._replay_mem._unroll_num * self._data_server_sim._replay_mem._unroll_len
    last_log_update = 0

    for update in xrange(1, nupdates + 1):
      frac = 1.0 - (update - 1.0) / nupdates
      mblossvals.append(self.train_batch(lr(frac), cliprange(frac)))

      # logging stuff
      if update % self.log_interval == 0 or update == 1:
        lossvals = np.mean(mblossvals, axis=0)
        mblossvals = []
        tnow = time.time()
        # Guard against a zero-length window on coarse clocks.
        window = max(tnow - tstart, 1e-8)
        # Count the updates actually run since the previous log line.
        consuming_fps = int(self.batch_size * (update - last_log_update) / window)
        last_log_update = update
        time_elapsed = tnow - tfirststart
        total_samples_now = self._data_server_sim._replay_mem._unroll_num * self._data_server_sim._replay_mem._unroll_len
        receiving_fps = (total_samples_now - total_samples) / window
        total_samples = total_samples_now
        tstart = time.time()

        scope = ''
        logger.logkvs({
          scope + "nupdates": update,
          scope + "total_timesteps": update * self.batch_size,
          scope + "all_consuming_fps": consuming_fps,
          scope + 'time_elapsed': time_elapsed,
          scope + "total_samples": total_samples,
          scope + "receiving_fps": receiving_fps,
          scope + "ave_sim_episode_rwd": self._data_server_sim._replay_mem.stat_ave_reward(),
          scope + "ave_tar_episode_rwd": self._data_server_tar._replay_mem.stat_ave_reward(),
          })
        logger.logkvs({scope + lossname: lossval for lossname, lossval
                       in zip(self.loss_names, lossvals)})
        logger.dumpkvs()

      if self.save_interval and (update % self.save_interval == 0 or update == 1) and logger.get_dir():
        checkdir = osp.join(logger.get_dir(), 'checkpoints')
        os.makedirs(checkdir, exist_ok=True)
        savepath = osp.join(checkdir, '%.5i' % update)
        logger.log('Saving log to', savepath)
        self.save(savepath)
