import os
import sys
import time
from contextlib import contextmanager
from collections import deque
import scipy as sp
import scipy.stats

import gym
from mpi4py import MPI
import tensorflow as tf
import numpy as np

# Raise numpy's summarization threshold to the maximum so printed arrays are shown in full
# instead of being abbreviated with "...".
np.set_printoptions(threshold=sys.maxsize)
from tqdm import tqdm
import causal_irl.algorithms.common.tf_utils as tf_util
from causal_irl.algorithms.common.tf_utils import total_episode_reward_logger
from stable_baselines.common import explained_variance, zipsame, dataset, fmt_row, colorize, \
    SetVerbosity, TensorboardWriter
from causal_irl.algorithms.common.base_class import ActorCriticRLModel
from stable_baselines import logger
from stable_baselines.common.mpi_adam import MpiAdam
from stable_baselines.common.cg import conjugate_gradient
from causal_irl.algorithms.common.policies import ActorCriticPolicy
from stable_baselines.common.misc_util import flatten_lists
from causal_irl.algorithms.common.runners import traj_segment_generator
from stable_baselines.trpo_mpi.utils import add_vtarg_and_adv

from causal_irl.envs.my_observation_wrapper import MyObservationWrapper


class TRPO(ActorCriticRLModel):
    """
    Trust Region Policy Optimization (https://arxiv.org/abs/1502.05477)

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param gamma: (float) the discount value
    :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon)
    :param max_kl: (float) the Kullback-Leibler loss threshold
    :param cg_iters: (int) the number of iterations for the conjugate gradient calculation
    :param lam: (float) GAE factor
    :param entcoeff: (float) the weight for the entropy loss
    :param cg_damping: (float) the conjugate gradient damping factor
    :param vf_stepsize: (float) the value function stepsize
    :param vf_iters: (int) the value function's number iterations for learning
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
        WARNING: this logging can take a lot of space quickly
    :param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
        If None (default), use random seed. Note that if you want completely deterministic
        results, you must set `n_cpu_tf_sess` to 1.
    :param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations
        If None, the number of cpu of the current machine will be used.
    """

    def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, lam=0.98,
                 entcoeff=0.0, cg_damping=1e-2, vf_stepsize=3e-4, vf_iters=3, verbose=0, tensorboard_log=None,
                 _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False, val_interval=None,
                 save_path=None,
                 exp_code="000",
                 seed=None, n_cpu_tf_sess=1):
        """
        Record the hyper-parameters on the instance and, unless disabled, build the model.

        Arguments described in the class docstring are stored unchanged. The remaining ones are:

        :param val_interval: (int) ``learn`` runs validation rollouts whenever ``num_timesteps`` is a
            multiple of this value; None disables validation
        :param save_path: (str) directory that ``learn`` joins with the checkpoint file name when it
            saves the best model
        :param exp_code: (str) prefix of the checkpoint file name
        """
        super().__init__(
            policy=policy,
            env=env,
            verbose=verbose,
            requires_vec_env=False,
            _init_setup_model=_init_setup_model,
            policy_kwargs=policy_kwargs,
            seed=seed,
            n_cpu_tf_sess=n_cpu_tf_sess,
        )

        # Adversarial modes are switched off in this constructor.
        self.using_airl, self.using_gail = False, False

        # Policy / value-function optimisation hyper-parameters.
        self.timesteps_per_batch = timesteps_per_batch
        self.cg_iters, self.cg_damping = cg_iters, cg_damping
        self.gamma, self.lam = gamma, lam
        self.max_kl = max_kl
        self.vf_iters, self.vf_stepsize = vf_iters, vf_stepsize
        self.entcoeff = entcoeff

        # Tensorboard configuration.
        self.tensorboard_log = tensorboard_log
        self.full_tensorboard_log = full_tensorboard_log

        # Discriminator (GAIL) settings, fixed to these defaults here.
        self.hidden_size_adversary = 100
        self.adversary_entcoeff = 1e-3
        self.expert_dataset = None
        self.val_interval = val_interval
        self.g_step, self.d_step = 1, 1
        self.d_stepsize = 3e-4

        # Best validation scores seen so far and where to checkpoint the best model.
        self.best_val = np.inf
        self.best_ret = -np.inf
        self.save_path = save_path
        self.exp_code = exp_code

        # Expert-data cursor, filled in by _initialize_dataloader() in AIRL mode.
        self.expert_idx = self.start_idx = None

        # Regularisation / noise levels read by learn(); all disabled by default.
        self.dropout_rate = 0.0
        self.action_noise = 0.0
        self.ob_noise = 0.0

        # Graph, session and compiled TF callables; assigned in setup_model().
        self.graph = self.sess = None
        self.policy_pi = self.loss_names = None
        self.assign_old_eq_new = None
        self.compute_losses = self.compute_lossandgrad = None
        self.compute_fvp = self.compute_vflossandgrad = None
        self.d_adam = self.vfadam = None
        self.get_flat = self.set_from_flat = None
        self.timed = self.allmean = None
        self.nworkers = self.rank = None
        self.reward_giver = None
        self.step = self.proba_step = self.initial_state = None
        self.params = self.summary = None

        self.fusion = None
        # Meant to keep the most recent 20 training iterations
        # (only referenced by commented-out code in learn()).
        self.traj_buffer = None

        if _init_setup_model:
            self.setup_model()

    def _get_pretrain_placeholders(self):
        """
        Return the tensors needed to pretrain the current policy.

        :return: (tuple) the observation placeholder, a freshly created action placeholder and the
            policy output: ``policy`` for discrete action spaces, ``deterministic_action`` otherwise
        """
        pi = self.policy_pi
        actions_ph = pi.pdtype.sample_placeholder([None])
        discrete = isinstance(self.action_space, gym.spaces.Discrete)
        policy_output = pi.policy if discrete else pi.deterministic_action
        return pi.obs_ph, actions_ph, policy_output

    def setup_model(self):
        """
        Build the TensorFlow graph, the session and the callable functions used by ``learn``.

        Inside a fresh ``tf.Graph`` this creates the current policy (``self.policy_pi``), a second
        copy of it under the ``oldpi`` variable scope, the surrogate / KL / entropy / value losses,
        the Fisher-vector product, the MPI Adam optimizers and the tensorboard summaries. When
        ``self.using_gail`` or ``self.using_airl`` is set, the matching discriminator is built as
        well and stored in ``self.reward_giver``.
        """
        # Imported here rather than at module level to prevent import loops.
        from causal_irl.algorithms.gail.adversary import TransitionClassifier
        from causal_irl.algorithms.airl.airl_adversary import AIRL_Discriminator
        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the TRPO model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.set_random_seed(self.seed)
                self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess, graph=self.graph)

                # The GAIL discriminator is created before the policy networks.
                if self.using_gail:
                    self.reward_giver = TransitionClassifier(self.observation_space, self.action_space,
                                                             self.hidden_size_adversary,
                                                             entcoeff=self.adversary_entcoeff, normalize=True)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                             None, reuse=False, **self.policy_kwargs)
                # The AIRL discriminator takes the policy as an argument, so it is built after it.
                if self.using_airl:
                    self.reward_giver = AIRL_Discriminator(self.observation_space, self.action_space, self.policy_pi,
                                                           normalize=False)

                # Network for old policy
                with tf.variable_scope("oldpi", reuse=False):
                    old_policy = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                             None, reuse=False, **self.policy_kwargs)

                with tf.variable_scope("loss", reuse=False):
                    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
                    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

                    observation = self.policy_pi.obs_ph
                    action = self.policy_pi.pdtype.sample_placeholder([None])
                    # Extra policy input; learn() feeds self.dropout_rate for it.
                    rate_ph = self.policy_pi.rate_ph

                    # KL divergence between the old and the current policy, and the current entropy.
                    kloldnew = old_policy.proba_distribution.kl(self.policy_pi.proba_distribution)
                    ent = self.policy_pi.proba_distribution.entropy()
                    meankl = tf.reduce_mean(kloldnew)
                    meanent = tf.reduce_mean(ent)
                    entbonus = self.entcoeff * meanent

                    # Mean squared error between the value prediction and the empirical return.
                    vferr = tf.reduce_mean(tf.square(self.policy_pi.value_flat - ret))

                    # advantage * pnew / pold
                    ratio = tf.exp(self.policy_pi.proba_distribution.logp(action) -
                                   old_policy.proba_distribution.logp(action))
                    surrgain = tf.reduce_mean(ratio * atarg)

                    optimgain = surrgain + entbonus
                    losses = [optimgain, meankl, entbonus, surrgain, meanent]
                    self.loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

                    dist = meankl

                    # Split the trainable variables by name: var_list drops value/Q variables,
                    # vf_var_list drops policy-head and log-std variables.
                    all_var_list = tf_util.get_trainable_vars("model")
                    var_list = [v for v in all_var_list if "/vf" not in v.name and "/q/" not in v.name]
                    vf_var_list = [v for v in all_var_list if "/pi" not in v.name and "/logstd" not in v.name]

                    self.get_flat = tf_util.GetFlat(var_list, sess=self.sess)
                    self.set_from_flat = tf_util.SetFromFlat(var_list, sess=self.sess)

                    # Slice the flat tangent vector into per-variable tensors matching var_list.
                    klgrads = tf.gradients(dist, var_list)
                    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
                    shapes = [var.get_shape().as_list() for var in var_list]
                    start = 0
                    tangents = []
                    for shape in shapes:
                        var_size = tf_util.intprod(shape)
                        tangents.append(tf.reshape(flat_tangent[start: start + var_size], shape))
                        start += var_size
                    # Gradient-vector product: sum over variables of <grad KL, tangent>.
                    gvp = tf.add_n([tf.reduce_sum(grad * tangent)
                                    for (grad, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
                    # Fisher vector products
                    fvp = tf_util.flatgrad(gvp, var_list)

                    tf.summary.scalar('entropy_loss', meanent)
                    tf.summary.scalar('policy_gradient_loss', optimgain)
                    # The value-function loss is vferr; surrgain is the policy surrogate and is
                    # already part of 'policy_gradient_loss'.
                    tf.summary.scalar('value_function_loss', vferr)
                    tf.summary.scalar('approximate_kullback-leibler', meankl)
                    tf.summary.scalar('loss', optimgain + meankl + entbonus + surrgain + meanent)

                    # Copies every global variable of the "model" scope into the "oldpi" scope.
                    self.assign_old_eq_new = \
                        tf_util.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                          zipsame(tf_util.get_globals_vars("oldpi"),
                                                                  tf_util.get_globals_vars("model"))])

                    # Both policies have their own observation placeholder, hence the two
                    # observation inputs in every compiled function below.
                    self.compute_losses = tf_util.function([observation, old_policy.obs_ph, action, atarg, rate_ph],
                                                           losses)
                    self.compute_fvp = tf_util.function(
                        [flat_tangent, observation, old_policy.obs_ph, action, atarg, rate_ph],
                        fvp)
                    self.compute_vflossandgrad = tf_util.function([observation, old_policy.obs_ph, ret, rate_ph],
                                                                  tf_util.flatgrad(vferr, vf_var_list))

                    @contextmanager
                    def timed(msg):
                        """Print ``msg`` and the elapsed time of the wrapped block (rank 0, verbose >= 1 only)."""
                        if self.rank == 0 and self.verbose >= 1:
                            print(colorize(msg, color='magenta'))
                            start_time = time.time()
                            yield
                            print(colorize("done in {:.3f} seconds".format((time.time() - start_time)),
                                           color='magenta'))
                        else:
                            yield

                    def allmean(arr):
                        """Return the element-wise mean of ``arr`` over all MPI workers."""
                        assert isinstance(arr, np.ndarray)
                        out = np.empty_like(arr)
                        MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                        out /= self.nworkers
                        return out

                    tf_util.initialize(sess=self.sess)

                    # Make every worker start from rank 0's policy parameters.
                    th_init = self.get_flat()
                    MPI.COMM_WORLD.Bcast(th_init, root=0)
                    self.set_from_flat(th_init)

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self.vfadam = MpiAdam(vf_var_list, sess=self.sess)
                    if self.using_gail or self.using_airl:
                        self.d_adam = MpiAdam(self.reward_giver.get_trainable_variables(), sess=self.sess)
                        self.d_adam.sync()
                    self.vfadam.sync()

                with tf.variable_scope("input_info", reuse=False):
                    tf.summary.scalar('discounted_rewards', tf.reduce_mean(ret))
                    tf.summary.scalar('learning_rate', tf.reduce_mean(self.vf_stepsize))
                    tf.summary.scalar('advantage', tf.reduce_mean(atarg))
                    tf.summary.scalar('kl_clip_range', tf.reduce_mean(self.max_kl))

                    if self.full_tensorboard_log:
                        tf.summary.histogram('discounted_rewards', ret)
                        tf.summary.histogram('learning_rate', self.vf_stepsize)
                        tf.summary.histogram('advantage', atarg)
                        tf.summary.histogram('kl_clip_range', self.max_kl)
                        if tf_util.is_image(self.observation_space):
                            tf.summary.image('observation', observation)
                        else:
                            tf.summary.histogram('observation', observation)

                self.timed = timed
                self.allmean = allmean

                self.step = self.policy_pi.step
                self.proba_step = self.policy_pi.proba_step
                self.initial_state = self.policy_pi.initial_state

                self.params = tf_util.get_trainable_vars("model") + tf_util.get_trainable_vars("oldpi")
                if self.using_gail or self.using_airl:
                    self.params.extend(self.reward_giver.get_trainable_variables())

                self.summary = tf.summary.merge_all()

                # Built last so that the merged summary op above can be one of its outputs.
                self.compute_lossandgrad = \
                    tf_util.function([observation, old_policy.obs_ph, action, atarg, ret, rate_ph],
                                     [self.summary, tf_util.flatgrad(optimgain, var_list)] + losses)

    def _initialize_dataloader(self):
        """Initialize dataloader."""
        if self.using_airl:
            self.expert_idx = np.random.permutation(self.expert_dataset[0].shape[0])
            self.start_idx = 0
            return
        batchsize = self.timesteps_per_batch // self.d_step
        self.expert_dataset.init_dataloader(batchsize)

    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="TRPO",
              reset_num_timesteps=True):
        def calculate_lprobs(obs, actions, policy):
            nlogprob = policy.proba_obs_action(obs, actions)
            return -nlogprob

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn()

            with self.sess.as_default():
                callback.on_training_start(locals(), globals())

                seg_gen = traj_segment_generator(self.policy_pi, self.env, self.timesteps_per_batch,
                                                 reward_giver=self.reward_giver,
                                                 gail=self.using_gail, airl=self.using_airl, callback=callback)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards

                true_reward_buffer = None
                if self.using_gail or self.using_airl:
                    true_reward_buffer = deque(maxlen=40)

                    self._initialize_dataloader()

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])
                pbar = tqdm(total_timesteps)
                while True:
                    if timesteps_so_far >= total_timesteps:
                        break
                    logger.log("********** Iteration %i ************" % iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(self.compute_fvp(vec, *fvpargs, self.dropout_rate, sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()

                        # Stop training early (triggered by the callback)
                        if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                            break
                        observation, action, lprob = seg["observations"], seg["actions"], np.expand_dims(seg["lprobs"],
                                                                                                         axis=1)
                        nobs = observation[1:]
                        nobs = np.r_[nobs, 0.0 * np.expand_dims(np.ones_like(nobs[0]), axis=0)]
                        # if self.fusion:
                        #     curr_observation = observation
                        #     curr_action = action
                        #     curr_nobs = nobs
                        #     if len(self.traj_buffer) > 0:
                        #         pidxs = np.random.randint(0, len(self.traj_buffer))
                        #         past_obs, past_acs, past_nobs = self.traj_buffer[pidxs]
                        #         past_lprobs = np.expand_dims(calculate_lprobs(past_obs, past_acs, self.policy_pi),
                        #                                      axis=1)
                        #         observation = np.concatenate([observation, past_obs])
                        #         action = np.concatenate([action, past_acs])
                        #         lprob = np.concatenate([lprob, past_lprobs])
                        #         nobs = np.concatenate([nobs, past_nobs])
                        #     if len(self.traj_buffer) == 20:
                        #         self.traj_buffer.pop(0)
                        #     self.traj_buffer.append((curr_observation, curr_action, curr_nobs))

                        lprob_expert_all = np.zeros_like(lprob)
                        nobs_expert_all = np.zeros_like(nobs)
                        if self.using_airl:
                            lprob_expert_all = np.expand_dims(
                                calculate_lprobs(self.expert_dataset[0], self.expert_dataset[1], self.policy_pi),
                                axis=1)
                            nobs_expert_all = self.expert_dataset[0][1:]
                            nobs_expert_all = np.r_[
                                nobs_expert_all, 0.0 * np.expand_dims(np.ones_like(nobs_expert_all[0]), axis=0)]
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))

                        atarg, tdlamret = seg["adv"], seg["tdlamret"]
                        vpredbefore = seg["vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / (
                                    atarg.std() + 1e-8)  # standardized advantage function estimate

                        # true_rew is the reward without discount
                        if writer is not None:
                            total_episode_reward_logger(self.episode_reward,
                                                        seg["true_rewards"].reshape(
                                                            (self.n_envs, -1)),
                                                        seg["dones"].reshape((self.n_envs, -1)),
                                                        writer, self.num_timesteps)
                        args = seg["observations"], seg["observations"], seg["actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret, self.dropout_rate, sess=self.sess,
                                                                                      options=run_options,
                                                                                      run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret, self.dropout_rate, sess=self.sess,
                                                                                options=run_options,
                                                                                run_metadata=run_metadata)

                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(fisher_vector_product, grad, cg_iters=self.cg_iters,
                                                             verbose=self.rank == 0 and self.verbose >= 1)
                            if not np.isfinite(stepdir).all:
                                continue
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(self.compute_losses(*args, self.dropout_rate, sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log("Got non-finite value of losses -- bad!")
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log("violated KL constraint. shrinking step.")
                                elif improve < 0:
                                    logger.log("surrogate didn't improve. shrinking step.")
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

                            for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
                                logger.record_tabular(loss_name, loss_val)

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches((seg["observations"], seg["tdlamret"]),
                                                                         include_final_partial_batch=False,
                                                                         batch_size=128,
                                                                         shuffle=True):
                                    grad = self.allmean(self.compute_vflossandgrad(mbob, mbob, mbret, self.dropout_rate, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                        # ----------------- Validating G ----------------
                        if self.val_interval is not None:
                            if self.num_timesteps % self.val_interval == 0:
                                # ob_val, ac_val = self.expert_dataset.get_next_batch(split='val')
                                # nlog = self.policy_pi.proba_obs_action(ob_val, ac_val)
                                # curr_val = np.mean(-nlog)
                                # logger.log(curr_val)
                                segs = [seg_gen.__next__() for _ in range(5)]
                                curr_rets = [np.mean(seg["ep_true_rets"]) for seg in segs]
                                curr_ret = np.mean(curr_rets)
                                logger.log(curr_ret)
                                if curr_ret > self.best_ret:
                                    logger.log("saving model at timestep {}".format(self.num_timesteps))
                                    self.best_ret = curr_ret
                                    if self.using_gail:
                                        model_name = 'gail' + \
                                                     ('' if self.dropout_rate == 0.0 else str(self.dropout_rate)) + \
                                                     ('' if self.action_noise == 0.0 else '-action-noise{}'.format(
                                                         self.action_noise)) + \
                                                     ('' if self.ob_noise == 0.0 else '-ob-noise{}'.format(
                                                         self.ob_noise))
                                        self.save(os.path.join(self.save_path,
                                                               self.exp_code + "_best_{}_{}_traj{}_{}_{}".format(
                                                                   model_name,
                                                                   self.env.unwrapped.spec.id,
                                                                   self.expert_dataset.num_traj,
                                                                   self.env.mode if isinstance(self.env,
                                                                                               MyObservationWrapper) else 'testing',
                                                                   self.seed)))
                                    elif self.using_airl:
                                        self.save(os.path.join(self.save_path,
                                                               self.exp_code + "_best_airl_{}_traj{}_{}_{}".format(
                                                                   self.env.unwrapped.spec.id,
                                                                   -1.0,  # self.expert_dataset.num_traj
                                                                   self.env.mode if isinstance(self.env,
                                                                                               MyObservationWrapper) else 'testing',
                                                                   self.seed)))
                                    else:
                                        self.save(os.path.join(self.save_path,
                                                               self.exp_code + "_best_trpo_{}_{}_{}".format(
                                                                   self.env.unwrapped.spec.id,
                                                                   self.env.mode if isinstance(self.env,
                                                                                               MyObservationWrapper) else 'testing',
                                                                   self.seed
                                                               )))
                    # Stop training early (triggered by the callback)
                    if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                        break

                    logger.record_tabular("explained_variance_tdlam_before",
                                          explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches((observation, action),
                                                                      include_final_partial_batch=False,
                                                                      batch_size=batch_size,
                                                                      shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch(split='train')
                            ob_expert = ob_expert + np.random.randn(*self.observation_space.shape) * self.action_noise
                            ac_expert = ac_expert + np.random.randn(*self.action_space.shape) * self.ob_noise
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space, gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]

                            *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    elif self.using_airl:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing AIRL Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        # assert len(observation) == self.timesteps_per_batch
                        batch_size = 32
                        start_idx = 0
                        idx = np.random.permutation(len(observation))
                        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
                        for i in range(self.d_step):
                            # batch_size = self.timesteps_per_batch // self.d_step
                            batch_idx = idx[start_idx: batch_size + start_idx]
                            start_idx += batch_size
                            if start_idx + batch_size > len(observation):
                                start_idx = 0
                            expert_idx = self.expert_idx[self.start_idx: self.start_idx + batch_size]
                            self.start_idx += batch_size
                            if self.start_idx + batch_size > self.expert_dataset[0].shape[0]:
                                self.start_idx = 0
                            ob_batch, ac_batch, lprob_batch, nob_batch = observation[batch_idx, :], action[batch_idx,
                                                                                                    :], lprob[batch_idx,
                                                                                                        :], nobs[
                                                                                                            batch_idx,
                                                                                                            :]
                            ob_expert, ac_expert, lprob_expert, nob_expert = self.expert_dataset[0][expert_idx, :], \
                                                                             self.expert_dataset[1][expert_idx,
                                                                             :], lprob_expert_all[expert_idx,
                                                                                 :], nobs_expert_all[expert_idx, :]
                            # rand_actions = np.random.rand(ob_expert.shape[0], 1)
                            # lprob_random = np.expand_dims(calculate_lprobs(ob_expert, rand_actions, self.policy_pi)[:, None], axis=1)
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space, gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]

                            self.reward_giver.debug(ob_batch, ac_batch, nob_batch, lprob_batch,
                                                    ob_expert, ac_expert, nob_expert, lprob_expert)
                            *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, lprob_batch, nob_batch,
                                                                             ob_expert, ac_expert, lprob_expert,
                                                                             nob_expert)
                            self.d_adam.update(self.allmean(grad), self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean", np.mean(reward_buffer))
                        logger.record_tabular("LogprobMean", np.mean(lprob))
                        logger.record_tabular("LogprobExpertMean", np.mean(lprob_expert_all))
                        # logger.record_tabular("LogprobMax", np.max(lprob))
                    if self.using_gail or self.using_airl:
                        logger.record_tabular("EpTrueRewMean", np.mean(true_reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    pbar.update(current_it_timesteps)
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        callback.on_training_end()
        return self

    def save(self, save_path, cloudpickle=False):
        """
        Save the current hyperparameters, bookkeeping attributes and model parameters.

        The attributes are gathered into a dictionary, the model parameters are read
        with ``self.get_parameters()``, and both are handed to ``self._save_to_file``.

        :param save_path: (str or file-like) destination, forwarded unchanged to ``_save_to_file``
        :param cloudpickle: (bool) forwarded unchanged to ``_save_to_file``
        """
        data = {
            "gamma": self.gamma,
            "timesteps_per_batch": self.timesteps_per_batch,
            "max_kl": self.max_kl,
            "cg_iters": self.cg_iters,
            "lam": self.lam,
            "entcoeff": self.entcoeff,
            "cg_damping": self.cg_damping,
            "vf_stepsize": self.vf_stepsize,
            "vf_iters": self.vf_iters,
            "hidden_size_adversary": self.hidden_size_adversary,
            "adversary_entcoeff": self.adversary_entcoeff,
            "g_step": self.g_step,
            "d_step": self.d_step,
            "d_stepsize": self.d_stepsize,
            "using_gail": self.using_gail,
            "using_airl": self.using_airl,
            "verbose": self.verbose,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "n_cpu_tf_sess": self.n_cpu_tf_sess,
            "seed": self.seed,
            "val_interval": self.val_interval,
            "best_val": self.best_val,
            "best_ret": self.best_ret,
            "num_timesteps": self.num_timesteps,
            "save_path": self.save_path,
            "exp_code": self.exp_code,
            # Stored under the attribute's real name so that a loader which restores
            # attributes by key brings back ``dropout_rate`` itself.
            "dropout_rate": self.dropout_rate,
            # Legacy misspelled key, kept so that anything still reading
            # "dropout_date" from saved data keeps working.
            "dropout_date": self.dropout_rate,
            "action_noise": self.action_noise,
            "ob_noise": self.ob_noise,
            "_vectorize_action": self._vectorize_action,
            "policy_kwargs": self.policy_kwargs
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path, data=data, params=params_to_save, cloudpickle=cloudpickle)
