""" Source code for the RTO algorithm.
Under review at ICLR 2022.
Please do not distribute. """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import sys
import time

# Python 3 has no `xrange`; alias it to `range` so the loops below run on both.
if sys.version_info.major > 2:
  xrange = range

import joblib
import numpy as np
import tensorflow as tf
from gym import spaces

from rpto.data_server.data_server import RTODataServer
from rpto.utils import logger


def as_func(obj):
  """Turn a constant into a schedule function, or pass a callable through.

  Args:
    obj: either a real number (int, float or numpy scalar), which the
      returned function yields for every input, or a callable schedule
      that is returned unchanged.

  Returns:
    A one-argument callable.

  Raises:
    TypeError: if ``obj`` is neither a real number nor callable.
  """
  # Local import keeps this small helper self-contained.
  import numbers

  # numbers.Real also covers ints and numpy scalars such as np.float32,
  # which are not subclasses of the builtin float.
  if isinstance(obj, numbers.Real):
    return lambda x: obj
  if not callable(obj):
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise TypeError('expected a number or a callable, got %r' % (obj,))
  return obj


def average_tf(grads):
  """Average gradients element-wise across several gradient lists.

  Args:
    grads: a sequence of equally long gradient lists (one per replica).
      Position ``i`` of every list refers to the same variable.

  Returns:
    A list with one entry per position: the mean over replicas of the
    stacked tensors, or None when the first replica's gradient there is None.
  """
  def _mean_or_none(per_replica):
    # Only the first entry is inspected for None, matching the original.
    if per_replica[0] is None:
      return None
    return tf.reduce_mean(tf.stack(per_replica), 0)

  return [_mean_or_none(column) for column in zip(*grads)]


class RTOLearner:
  """Learner that fits a dynamics model's parameters with the RTO objective.

  Only the trainable variables under the ``'dynamics'`` scope receive
  gradient updates. The policy built under the ``'model'`` scope provides
  the value predictions used in the losses, and its parameters can be
  restored from a checkpoint via ``init_model_path``. Training batches come
  from an ``RTODataServer`` wrapping ``memory``.
  """

  def __init__(self, sess,
               batch_size,
               ob_space,
               ac_space,
               dynamics_space,
               memory,
               policy,
               policy_config,
               dynamics,
               dynamics_config,
               ent_coef=1e-2,
               n_v=1,
               vf_coef=0.5,
               rpo_coef=1.0,
               rto_sl_coef=1.0,
               max_grad_norm=0.5,
               log_interval=100,
               save_interval=0,
               gpu_id=-1,
               total_timesteps=5e7,
               adv_normalize=True,
               merge_pi=False,
               rnn=False,
               rollout_len=1,
               hs_len=64,
               reward_weights=None,
               optimizer_type='adam',
               init_model_path=None):
    """Build the data pipeline, the policy/dynamics graphs and the train op.

    Args:
      sess: TF session used to run every op built here.
      batch_size: number of samples per training batch.
      ob_space, ac_space, dynamics_space: gym spaces forwarded to the data
        server (``ob_space``/``ac_space`` also go to the policy constructor).
      memory: replay memory wrapped by ``RTODataServer``.
      policy, policy_config: policy constructor and its extra kwargs. It is
        built three times under scope ``'model'`` with shared variables.
      dynamics, dynamics_config: dynamics-model constructor and its kwargs,
        built under scope ``'dynamics'``.
      rto_sl_coef: weight of the supervised next-state term in the dynamics
        gradient (see ``get_grads``).
      max_grad_norm: global-norm clipping threshold, or None to disable.
      log_interval: log every this many updates (and on the first update).
      save_interval: checkpoint every this many updates; 0 disables saving.
      gpu_id: GPU index used for graph construction; negative means CPU.
      total_timesteps: total samples to consume. ``train`` runs
        ``total_timesteps // batch_size`` updates.
      rnn, hs_len: ``hs_len`` is forwarded to the data server only when
        ``rnn`` is true; otherwise None is forwarded.
      optimizer_type: ``'adam'`` or ``'sgd'``.
      init_model_path: optional joblib file with policy (``'model'``)
        parameter values to restore after initialization.
      ent_coef, n_v, vf_coef, rpo_coef, adv_normalize, merge_pi,
      reward_weights: stored on the instance (``n_v`` is also forwarded to
        the data server and the policy). They are otherwise unused here.
      rollout_len: accepted for interface compatibility; unused here.

    Raises:
      ValueError: if ``optimizer_type`` is not recognized.
    """
    self.sess = sess
    self.ob_space = ob_space
    self.ac_space = ac_space
    self.policy = policy
    self.dynamics = dynamics
    self.ent_coef = ent_coef
    self.n_v = n_v
    self.reward_weights = reward_weights
    self.vf_coef = vf_coef
    self.rpo_coef = rpo_coef
    self.rto_sl_coef = rto_sl_coef
    self.log_interval = log_interval
    self.save_interval = save_interval
    self.total_timesteps = total_timesteps
    self.adv_normalize = adv_normalize
    self.merge_pi = merge_pi and isinstance(self.ac_space, spaces.Tuple)
    self.rnn = rnn
    self.hs_len = hs_len if self.rnn else None
    self._train_op = []
    self.batch_size = batch_size

    # Prepare dataset and learner placeholders.
    self.LR = tf.placeholder(tf.float32, [])
    self.CLIPRANGE = tf.placeholder(tf.float32, [])
    # NOTE(review): the data server always receives gpu_id=0, regardless of
    # the `gpu_id` argument. Confirm this is intended.
    self._data_server_tar = RTODataServer(
      memory=memory,
      ob_space=ob_space,
      ac_space=ac_space,
      dynamics_space=dynamics_space,
      n_v=n_v,
      batch_size=batch_size,
      rnn=self.rnn,
      hs_len=self.hs_len,
      gpu_id=0)

    # Build the graphs on the requested device.
    device = '/cpu:0' if gpu_id < 0 else '/gpu:%d' % gpu_id
    with tf.device(device):
      input_data_tar = self._data_server_tar.input_data
      # The three policy instances share variables under scope 'model'.
      # They differ only in which next-state input they consume.
      model = policy(
        ob_space=ob_space,
        ac_space=ac_space,
        sess=self.sess,
        n_v=n_v,
        input_data=input_data_tar,
        reuse=tf.AUTO_REUSE,
        next_s_as_input=False,
        batch_size=batch_size,
        scope_name='model',
        **policy_config)
      model_next_s_input = policy(
        ob_space=ob_space,
        ac_space=ac_space,
        sess=self.sess,
        n_v=n_v,
        input_data=input_data_tar,
        reuse=True,
        next_s_as_input=True,
        diff_next_s_as_input=False,
        batch_size=batch_size,
        scope_name='model',
        **policy_config)
      model_diff_next_s_input = policy(
        ob_space=ob_space,
        ac_space=ac_space,
        sess=self.sess,
        n_v=n_v,
        input_data=input_data_tar,
        reuse=True,
        next_s_as_input=False,
        diff_next_s_as_input=True,
        batch_size=batch_size,
        scope_name='model',
        **policy_config)
      dynamics_model = dynamics(
        sess=self.sess,
        reuse=tf.AUTO_REUSE,
        scope_name='dynamics',
        **dynamics_config)

      self.model_params = tf.trainable_variables(scope='model')
      self.dynamics_params = tf.trainable_variables(scope='dynamics')
      grads, rto_losses = self.get_grads(model_next_s_input, model_diff_next_s_input, input_data_tar)

    self.model = model
    self.dynamics_model = dynamics_model
    # The first three dynamics variables are fetched alongside the losses so
    # their values are logged. The order must match 'm_cart', 'm_pole' and
    # 'half_pole_len' in `loss_names` (_build_ops).
    # NOTE(review): this assumes the dynamics model has at least three
    # trainable variables.
    self.losses = [*rto_losses, self.dynamics_params[0], self.dynamics_params[1], self.dynamics_params[2]]

    # Optionally clip on CPU, then pair each gradient with its dynamics variable.
    with tf.device('/cpu:0'):
      if max_grad_norm is not None:
        grads, self.grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
      else:
        self.grad_norm = tf.global_norm(grads)
      grads = list(zip(grads, self.dynamics_params))

    if optimizer_type == 'adam':
      self.optimizer = tf.train.AdamOptimizer(learning_rate=self.LR, epsilon=1e-5)
    elif optimizer_type == 'sgd':
      self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.LR)
    else:
      # ValueError (not BaseException) so ordinary `except Exception`
      # handlers can catch it.
      raise ValueError('Unknown optimizer type: %r.' % (optimizer_type,))
    self._train_op = self.optimizer.apply_gradients(grads)
    self._build_ops()
    tf.global_variables_initializer().run(session=self.sess)
    logger.configure(dir='training_log/',
                     format_strs=['stdout', 'log', 'tensorboard', 'csv'])
    if init_model_path is not None:
      logger.log('Training starts from', init_model_path)
      self.load_policy_model(load_path=init_model_path)
      logger.log('Restore initial model successfully.')

  def _build_ops(self):
    """Create assign/reset ops and the session-bound helper closures.

    Sets ``self.train_batch``, ``self.save``, ``self.load_policy_model`` and
    ``self.reset`` as closures over the graph built in ``__init__``.
    """
    # Placeholders + assign ops for overwriting the policy ('model') variables.
    self.new_params = [tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.model_params]
    self.param_assign_ops = [p.assign(new_p) for p, new_p in zip(self.model_params, self.new_params)]
    # Same for the optimizer's own variables.
    self.opt_params = self.optimizer.variables()
    self.new_opt_params = [tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.opt_params]
    self.opt_param_assign_ops = [p.assign(new_p) for p, new_p in zip(self.opt_params, self.new_opt_params)]
    self.reset_optimizer_op = tf.variables_initializer(self.opt_params)

    # Must line up with self.losses: three RTO losses followed by the first
    # three dynamics variables.
    self.loss_names = [
      'rto_loss',
      'rto_loss2',
      'rto_loss_supervised',
      'm_cart',
      'm_pole',
      'half_pole_len',
    ]

    def train_batch(lr, cliprange):
      """Run one optimizer step and return the values named in loss_names."""
      td_map = {self.LR: lr, self.CLIPRANGE: cliprange}
      fetches = self.losses + [self.grad_norm, self._train_op]
      return self.sess.run(fetches, td_map)[0:len(self.loss_names)]

    def save(save_path):
      """Dump the current dynamics-variable values to ``save_path`` via joblib."""
      ps = self.sess.run(self.dynamics_params)
      joblib.dump(ps, save_path)

    def load_policy_model(load_path):
      """Assign policy ('model') variables from a joblib file, matched by position."""
      # NOTE(review): joblib.load unpickles the file; only load trusted checkpoints.
      loaded_params = joblib.load(load_path)
      self.sess.run(self.param_assign_ops,
                    feed_dict={p: v for p, v in zip(self.new_params, loaded_params)})

    def reset():
      """Re-initialize the optimizer's variables."""
      self.sess.run(self.reset_optimizer_op)

    self.train_batch = train_batch
    self.save = save
    self.load_policy_model = load_policy_model
    self.reset = reset

  def get_grads(self, model_next_s_input, model_diff_next_s_input, input_data_tar):
    """Build the RTO losses and the gradient w.r.t. the dynamics parameters.

    The value-matching loss on the ``DIFF_NEXT_X`` policy is differentiated
    w.r.t. ``DIFF_NEXT_X`` with ``tf.gradients``. The result is then chained
    through ``DYNAMICS_GRADS``, presumably d(next_state)/d(dynamics params)
    (TODO confirm against the data server). A supervised term pulling
    ``DIFF_NEXT_X`` towards ``NEXT_X`` is added, weighted by ``rto_sl_coef``.

    Returns:
      ``(grads, losses)``. ``grads`` is a list with one squeezed tensor per
      entry of the gradient's first dimension. ``losses`` is
      ``[loss, loss2, loss_supervised]``.

    Raises:
      ValueError: if ``loss`` does not depend on ``DIFF_NEXT_X``.
    """
    dynamics_grads = input_data_tar.DYNAMICS_GRADS  # (bs, dynamics_dim, state_dim)
    next_v = input_data_tar.NEXT_V2
    not_done = 1.0 - tf.cast(input_data_tar.DONE, tf.float32)

    loss = tf.reduce_mean(0.5 * tf.square(next_v - model_diff_next_s_input.vf) * not_done)
    loss2 = tf.reduce_mean(0.5 * tf.square(next_v - model_next_s_input.vf) * not_done)
    loss_supervised = tf.reduce_mean(
      0.5 * tf.square(input_data_tar.NEXT_X - input_data_tar.DIFF_NEXT_X))

    next_s_grads = tf.gradients(loss, input_data_tar.DIFF_NEXT_X)  # [(bs, state_dim)]
    # tf.gradients always returns a list; an unconnected input shows up as a
    # None *entry*, so checking the list itself would never fire.
    if next_s_grads[0] is None:
      raise ValueError('make sure that the network takes next_state as input')
    grads = tf.reduce_sum(
      tf.multiply(dynamics_grads, tf.expand_dims(next_s_grads[0], axis=1)),
      axis=-1)
    grads_supervised = tf.reduce_mean(
      tf.multiply(dynamics_grads,
                  tf.expand_dims(input_data_tar.DIFF_NEXT_X - input_data_tar.NEXT_X, axis=1)),
      axis=-1)
    # NOTE(review): `loss` is already a batch mean, so next_s_grads carries a
    # 1/bs factor, and the reduce_mean over axis 0 below applies a second one.
    # The supervised term, by contrast, matches d(loss_supervised)/d(params).
    # Confirm the extra 1/bs on the RTO term is intended (assumes vf and
    # NEXT_V2 are per-sample vectors).
    grads = tf.reduce_mean(grads + self.rto_sl_coef * grads_supervised, axis=0)

    return [tf.squeeze(g) for g in
            tf.split(grads, num_or_size_splits=grads.shape[0])], [loss, loss2, loss_supervised]

  def _num_received_samples(self):
    """Return ``_unroll_num * _unroll_len`` of the replay memory (sample counter)."""
    mem = self._data_server_tar._replay_mem
    return mem._unroll_num * mem._unroll_len

  def train(self, learning_rate=1e-5, clip_range=0.2):
    """Run ``total_timesteps // batch_size`` updates, with logging and checkpoints.

    Args:
      learning_rate: float, or a callable mapping the remaining-progress
        fraction (1.0 at the first update, decreasing linearly towards 0)
        to a learning rate.
      clip_range: same convention as ``learning_rate``; fed to CLIPRANGE.
    """
    lr = as_func(learning_rate)
    cliprange = as_func(clip_range)
    nupdates = int(self.total_timesteps // self.batch_size)
    mblossvals = []
    tfirststart = time.time()
    tstart = time.time()
    total_samples = self._num_received_samples()

    for update in xrange(1, nupdates + 1):
      frac = 1.0 - (update - 1.0) / nupdates
      lrnow = lr(frac)
      cliprangenow = cliprange(frac)
      mblossvals.append(self.train_batch(lrnow, cliprangenow))

      # Logging: average the losses collected since the last log.
      if update % self.log_interval == 0 or update == 1:
        lossvals = np.mean(mblossvals, axis=0)
        mblossvals = []
        tnow = time.time()
        consuming_fps = int(
          self.batch_size * min(update, self.log_interval) / (tnow - tstart)
        )
        time_elapsed = tnow - tfirststart
        total_samples_now = self._num_received_samples()
        receiving_fps = (total_samples_now - total_samples) / (tnow - tstart)
        total_samples = total_samples_now
        tstart = time.time()

        scope = ''
        logger.logkvs({
          scope + "nupdates": update,
          scope + "total_timesteps": update * self.batch_size,
          scope + "all_consuming_fps": consuming_fps,
          scope + 'time_elapsed': time_elapsed,
          scope + "total_samples": total_samples,
          scope + "receiving_fps": receiving_fps,
          scope + "ave_tar_episode_rwd": self._data_server_tar._replay_mem.stat_ave_reward(),
          })
        logger.logkvs({scope + lossname: lossval for lossname, lossval
                       in zip(self.loss_names, lossvals)})
        logger.dumpkvs()

      # Checkpointing (disabled when save_interval is 0 or there is no log dir).
      if self.save_interval and (update % self.save_interval == 0 or update == 1) and logger.get_dir():
        checkdir = osp.join(logger.get_dir(), 'checkpoints')
        os.makedirs(checkdir, exist_ok=True)
        savepath = osp.join(checkdir, '%.5i' % update)
        logger.log('Saving log to', savepath)
        self.save(savepath)