import time
from pgdrive.envs.safe_pgdrive_env import SafePGDriveEnv
import numpy as np
from ray.rllib.utils.framework import try_import_tf

import safe_rl.pg.trust_region as tro
from safe_rl.pg.agents import PPOAgent
from safe_rl.pg.buffer import CPOBuffer
from safe_rl.pg.network import count_vars, get_vars, mlp_actor_critic, \
    placeholders, placeholders_from_spaces
from safe_rl.pg.utils import values_as_sorted_list
from safe_rl.utils import core
from safe_rl.utils.logx import EpochLogger, json_dump
from safe_rl.utils.mpi_tf import MpiAdamOptimizer, sync_all_params
from safe_rl.utils.mpi_tools import proc_id, num_procs, mpi_sum

from safe_rl.pg.update_cost_model import update_cost_model, CostModelBuffer, build_cost_model_data, \
    get_accuracy_and_callback

# import tensorflow as tf

tf, _, _ = try_import_tf()



# Multi-purpose agent runner for policy optimization algorithms
# (PPO, TRPO, their primal-dual equivalents, and CPO).
def run_polopt_agent(
        env_fn,
        agent=PPOAgent(),
        actor_critic=mlp_actor_critic,
        ac_kwargs=dict(),
        seed=0,
        render=False,
        # Experience collection:
        steps_per_epoch=4000,
        epochs=50,
        # max_ep_len=1000,
        # Discount factors:
        gamma=0.99,
        lam=0.97,
        cost_gamma=0.99,
        cost_lam=0.97,
        # Policy learning:
        ent_reg=0.,
        # Cost constraints / penalties:
        # cost_lim=25,
        penalty_init=1.,
        penalty_lr=5e-2,
        # KL divergence:
        target_kl=0.01,
        # Value learning:
        vf_lr=1e-3,
        vf_iters=80,
        # Logging:
        logger=None,
        logger_kwargs=dict(),
        save_freq=1,
        tmp_file=None,
        saferl_config=None
):

    eval_env =SafePGDriveEnv(dict(environment_num=100, start_seed=0))
    # Check Safe-RL config
    saferl_config = core.check_saferl_config(saferl_config)

    cost_lim = saferl_config["cost_threshold"]

    # =========================================================================#
    #  Prepare logger, seed, and environment in this process                  #
    # =========================================================================#

    logger = EpochLogger(**logger_kwargs) if logger is None else logger
    logger.save_config(locals())

    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()

    agent.set_logger(logger)

    # =========================================================================#
    #  Create computation graph for actor and critic (not training routine)   #
    # =========================================================================#

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    ac_kwargs["adversarial"] = agent.adversarial

    ac_kwargs["cost_model"] = agent.cost_model

    obs_rgb = saferl_config.get("obs_rgb", False)

    if obs_rgb:
        import gym
        obs_space = gym.spaces.Box(0.0, 1.0, shape=(84, 84, 3),
                                   dtype=np.float32)
    else:
        obs_space = env.observation_space

    # Inputs to computation graph from environment spaces
    x_ph, a_ph = placeholders_from_spaces(
        obs_space, env.action_space
    )

    # Inputs to computation graph for batch data
    adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph = placeholders(
        *(None for _ in range(5))
    )

    # Inputs to computation graph for special purposes
    surr_cost_rescale_ph = tf.placeholder(tf.float32, shape=())
    cur_cost_ph = tf.placeholder(tf.float32, shape=())

    # Outputs from actor critic
    ac_outs = actor_critic(x_ph, a_ph, **ac_kwargs)

    if agent.adversarial:
        pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc, lambda_ = \
            ac_outs
    elif agent.cost_model:
        pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc, cost_model \
            = ac_outs
    else:
        pi, logp, logp_pi, pi_info, pi_info_phs, d_kl, ent, v, vc = ac_outs

    # Organize placeholders for zipping with data from buffer on updates
    buf_phs = [x_ph, a_ph, adv_ph, cadv_ph, ret_ph, cret_ph, logp_old_ph]
    buf_phs += values_as_sorted_list(pi_info_phs)

    # Organize symbols we have to compute at each step of acting in env
    get_action_ops = dict(pi=pi, v=v, logp_pi=logp_pi, pi_info=pi_info)

    if agent.adversarial:
        get_action_ops["lambda"] = lambda_

    # if agent.cost_model:
    #     get_action_ops["cost_model"] = cost_model

    # If agent is reward penalized, it doesn't use a separate value function
    # for costs and we don't need to include it in get_action_ops; otherwise
    # we do.
    if not (agent.reward_penalized):
        get_action_ops['vc'] = vc

    # Count variables
    var_counts = tuple(count_vars(scope) for scope in ['pi', 'vf', 'vc'])
    logger.log(
        '\nNumber of parameters: \t pi: %d, \t v: %d, \t vc: %d\n' % var_counts
    )

    # Make a sample estimate for entropy to use as sanity check
    approx_ent = tf.reduce_mean(-logp)

    # =========================================================================#
    #  Create replay buffer                                                   #
    # =========================================================================#

    # Obs/act shapes
    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())

    print("========== DEBUG MESSAGE ==========")
    print("We have {} processes.".format(num_procs()))
    print("steps_per_epoch: {}".format(steps_per_epoch))
    print("local_steps_per_epoch: {}".format(local_steps_per_epoch))
    print("========== DEBUG MESSAGE ==========")

    pi_info_shapes = {k: v.shape.as_list()[1:] for k, v in pi_info_phs.items()}
    buf = CPOBuffer(
        local_steps_per_epoch, obs_shape, act_shape, pi_info_shapes, gamma,
        lam, cost_gamma, cost_lam, saferl_config=saferl_config
    )

    # =========================================================================#
    #  Create computation graph for penalty learning, if applicable           #
    # =========================================================================#

    if agent.use_penalty:
        if not agent.adversarial:
            with tf.variable_scope('penalty'):
                # param_init = np.log(penalty_init)
                param_init = np.log(max(np.exp(penalty_init) - 1, 1e-8))
                penalty_param = tf.get_variable(
                    'penalty_param',
                    initializer=float(param_init),
                    trainable=agent.learn_penalty,
                    dtype=tf.float32
                )
            # penalty = tf.exp(penalty_param)
            penalty = tf.nn.softplus(penalty_param)
        else:
            penalty_param = lambda_
            penalty = tf.nn.softplus(penalty_param)

    if agent.learn_penalty:
        if agent.penalty_param_loss:
            with tf.control_dependencies([
                tf.print("[****] Penalty: ", penalty_param)
            ]):
                penalty_loss = -penalty_param * (cur_cost_ph - cost_lim)
        else:
            penalty_loss = -penalty * (cur_cost_ph - cost_lim)
        train_penalty = MpiAdamOptimizer(
            learning_rate=penalty_lr
        ).minimize(penalty_loss, saferl_config.get("grad_clip", 10.0))

    if agent.cost_model:
        # Build the learning scheme for classifier
        with tf.variable_scope("cost_prob"):
            cost_prob = tf.nn.sigmoid(cost_model)

        cost_model_label = tf.placeholder(tf.float32, shape=(None,))
        cost_model_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
            cost_model_label, cost_model)

        cost_model_lr = saferl_config.get("cost_model_lr", penalty_lr)

        train_cost_model = MpiAdamOptimizer(
            learning_rate=cost_model_lr
        ).minimize(cost_model_loss, saferl_config.get("grad_clip", 10.0))

    # =========================================================================#
    #  Create computation graph for policy learning                           #
    # =========================================================================#

    # Likelihood ratio
    ratio = tf.exp(logp - logp_old_ph)

    if saferl_config.get(core.USE_REJECT):
        # Product > 0 means a bad action: increase reward, increase cost
        # So in that case we reverse the advantage.

        # Note: 20200730
        #   1. This advantage improves both reward and cost in quadrants 1
        #       and 3 of the (x: Qc, y: Qr) plane.
        #   2. However, it improves Qc while sacrificing Qr in quadrants 2 and 4.
        # adv_ph = tf.where(adv_ph * cadv_ph > 0, -adv_ph, adv_ph)

        # Update V2: 20200730
        #   1. Now we set 2, 4 quadrant to zero to avoid undefined improvement.
        # adv_ph = tf.where(adv_ph * cadv_ph > 0, 0 * adv_ph, adv_ph)

        # Update V3: 20200811
        #   1. If reward can be improved, then improve, otherwise choose
        #       the direction that can improve cost
        # adv_ph = tf.where(
        #     adv_ph > 0, adv_ph, tf.where(cadv_ph < 0, -adv_ph, adv_ph)
        # )

        # Update V4: 20200813
        #   1. 2nd quad set to 0, while 4th quad set to -Ar
        # adv_ph = tf.where(
        #     adv_ph > 0,
        #     tf.where(cadv_ph > 0, 0 * adv_ph, adv_ph),  # (2nd quad, 1st quad)
        #     tf.where(cadv_ph < 0, -adv_ph, adv_ph)  # (4th quad, 3rd quad)
        # )

        # Update V5: 20200813
        #   1. We want to find out what breaks the learning in the V1 experiment.
        adv_ph = tf.where(
            adv_ph > 0,
            tf.where(cadv_ph > 0, -adv_ph, adv_ph),
            adv_ph
        )

    adv_norm = tf.norm(adv_ph)
    cadv_norm = tf.norm(cadv_ph)

    if agent.cost_model:
        adv_ph = tf.stop_gradient(
            (1 - cost_prob) * adv_ph - cost_prob * cadv_ph
        )

    # Surrogate advantage / clipped surrogate advantage
    if agent.clipped_adv and (
            (not saferl_config.get(core.USE_CTNB)) or (
            saferl_config[core.USE_REJECT] and saferl_config[core.USE_CTNB]
    )
    ):
        # min_adv = tf.where(
        #     adv_ph > 0, (1 + agent.clip_ratio) * adv_ph,
        #     (1 - agent.clip_ratio) * adv_ph
        # )
        # surr_adv = tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))

        surr_adv = tf.reduce_mean(
            tf.minimum(
                adv_ph * ratio,
                adv_ph * tf.clip_by_value(
                    ratio, 1 - agent.clip_ratio, 1 + agent.clip_ratio
                )
            )
        )
    elif agent.clipped_adv and saferl_config.get(core.USE_CTNB):
        assert not saferl_config[core.USE_REJECT]
        two_clip_c = 3.0
        surr_adv_old = tf.minimum(
            adv_ph * ratio,
            adv_ph * tf.
            clip_by_value(ratio, 1 - agent.clip_ratio, 1 + agent.clip_ratio)
        )
        surr_adv = tf.reduce_mean(
            tf.where(
                adv_ph > 0, surr_adv_old,
                tf.maximum(surr_adv_old, two_clip_c * adv_ph)
            )
        )
    else:
        surr_adv = tf.reduce_mean(ratio * adv_ph)

    # Surrogate cost
    if agent.clipped_adv and saferl_config.get(core.USE_CTNB) and (
            not saferl_config[core.USE_REJECT]
    ):
        # V1
        # min_cadv = tf.where(
        #     cadv_ph > 0, (1 + agent.clip_ratio) * cadv_ph,
        #     (1 - agent.clip_ratio) * cadv_ph
        # )
        # surr_cost = tf.reduce_mean(tf.minimum(ratio * cadv_ph, min_cadv))

        # V2
        # max_cadv = tf.where(
        #     cadv_ph < 0, (1 + agent.clip_ratio) * cadv_ph,
        #     (1 - agent.clip_ratio) * cadv_ph
        # )
        # surr_cost = tf.reduce_mean(tf.maximum(ratio * cadv_ph, max_cadv))

        # V3
        # surr_cost = tf.reduce_mean(tf.maximum(
        #     cadv_ph * ratio,
        #     cadv_ph * tf.clip_by_value(ratio, 1 - agent.clip_ratio,
        #                                1 + agent.clip_ratio)))

        # V4
        # Use negative cost-advantage
        two_clip_c = 3.0
        # surr_cost = tf.reduce_mean(
        #     tf.maximum(
        #         tf.minimum(
        #             -cadv_ph * ratio,
        #             -cadv_ph * tf.clip_by_value(ratio, 1 - agent.clip_ratio,
        #                                         1 + agent.clip_ratio)
        #         ),
        #         -two_clip_c * cadv_ph
        #     )
        # )

        # V5
        cadv_ph = -cadv_ph
        surr_cadv_old = tf.minimum(
            cadv_ph * ratio,
            cadv_ph * tf.
            clip_by_value(ratio, 1 - agent.clip_ratio, 1 + agent.clip_ratio)
        )
        surr_cost = tf.reduce_mean(
            tf.where(
                cadv_ph > 0, surr_cadv_old,
                tf.maximum(surr_cadv_old, two_clip_c * cadv_ph)
            )
        )

    else:
        surr_cost = tf.reduce_mean(ratio * cadv_ph)

    # Create policy objective function, including entropy regularization
    if agent.cost_model:
        assert not agent.objective_penalized

    #     cost_model_penalty_weight = saferl_config.get("cost_model_penalty_weight", 0.0)
    #
    #     pi_objective = (1 - cost_prob) * surr_adv - cost_prob * surr_cost \
    #                    + ent_reg * ent \
    #                    + cost_model_penalty_weight * tf.nn.sigmoid_cross_entropy_with_logits(
    #         labels=tf.zeros_like(cost_model), logits=cost_model, name="cost_model_penalty"
    #     )
    # else:
    pi_objective = surr_adv + ent_reg * ent

    # Possibly include surr_cost in pi_objective
    if agent.objective_penalized:
        pi_objective -= penalty * surr_cost
        pi_objective /= (1 + penalty)

    # Loss function for pi is negative of pi_objective
    pi_loss = -pi_objective

    # Optimizer-specific symbols
    if agent.trust_region:

        # Symbols needed for CG solver for any trust region method
        pi_params = get_vars('pi')
        v_ph, hvp = tro.hessian_vector_product(d_kl, pi_params)
        if agent.damping_coeff > 0:
            hvp += agent.damping_coeff * v_ph

        flat_g = tro.flat_grad(pi_loss, pi_params)
        # Symbols needed for CG solver for CPO only
        flat_b = tro.flat_grad(surr_cost, pi_params)
        if saferl_config.get(core.USE_CTNB):
            assert not saferl_config[core.USE_REJECT]
            policy_grad_norm = tf.norm(flat_g)
            safety_grad_norm = tf.norm(flat_b)

            if not saferl_config.get(core.USE_IPD):
                threshold = saferl_config["cost_threshold"]
            else:
                threshold = -100

            flat_g, _, _, cosine_similarity = core.fuse_two_gradient(
                policy_grad_flatten=flat_g,
                safety_grad_flatten=flat_b,
                current_cost_ph=cur_cost_ph,
                threshold=threshold,
                _check_inf=lambda x: core.check_inf(x, False)
            )

        # Symbols for getting and setting params
        get_pi_params = tro.flat_concat(pi_params)
        set_pi_params = tro.assign_params_from_flat(v_ph, pi_params)

        training_package = dict(
            flat_g=flat_g,
            flat_b=flat_b,
            v_ph=v_ph,
            hvp=hvp,
            get_pi_params=get_pi_params,
            set_pi_params=set_pi_params
        )

        if saferl_config.get(core.USE_CTNB):
            training_package.update(
                cosine_similarity=cosine_similarity,
                policy_grad_norm=policy_grad_norm,
                safety_grad_norm=safety_grad_norm
            )

    elif agent.first_order:

        # Optimizer for first-order policy optimization
        if saferl_config.get(core.USE_CTNB) and (
                not saferl_config[core.USE_REJECT]):

            if not saferl_config.get(core.USE_IPD):
                threshold = saferl_config["cost_threshold"]
            else:
                threshold = -100

            train_pi, cosine_similarity, policy_grad_norm, safety_grad_norm = \
                core.ctnb(
                    policy_loss=pi_loss,
                    safety_loss=surr_cost,
                    optimizer=MpiAdamOptimizer(learning_rate=agent.pi_lr),
                    current_cost_ph=cur_cost_ph,
                    threshold=threshold,
                    grad_clip=saferl_config.get("grad_clip", 10.0),
                    test_mode=False
                )
            training_package = dict(
                train_pi=train_pi,
                cosine_similarity=cosine_similarity,
                policy_grad_norm=policy_grad_norm,
                safety_grad_norm=safety_grad_norm
            )
        else:
            train_pi = MpiAdamOptimizer(learning_rate=agent.pi_lr
                                        ).minimize(pi_loss)

            # Prepare training package for agent
            training_package = dict(train_pi=train_pi)
    else:
        raise NotImplementedError

    # Provide training package to agent
    training_package.update(
        dict(
            pi_loss=pi_loss,
            surr_cost=surr_cost,
            d_kl=d_kl,
            target_kl=target_kl,
            cost_lim=cost_lim
        )
    )
    agent.prepare_update(training_package)

    # =========================================================================#
    #  Create computation graph for value learning                            #
    # =========================================================================#

    # Value losses
    v_loss = tf.reduce_mean((ret_ph - v) ** 2)
    vc_loss = tf.reduce_mean((cret_ph - vc) ** 2)

    # If agent uses penalty directly in reward function, don't train a separate
    # value function for predicting cost returns. (Only use one vf for r - p*c.)
    if agent.reward_penalized:
        total_value_loss = v_loss
    else:
        total_value_loss = v_loss + vc_loss

    # Optimizer for value learning
    train_vf = MpiAdamOptimizer(learning_rate=vf_lr).minimize(
        total_value_loss, grad_clip_norm=saferl_config.get("grad_clip", 10.0)
    )

    # =========================================================================#
    #  Create session, sync across procs, and set up saver                    #
    # =========================================================================#

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(
        sess, inputs={'x': x_ph}, outputs={
            'pi': pi,
            'v': v,
            'vc': vc
        }
    )

    # Create the buffer used for cost-model learning:
    if agent.cost_model:
        cost_model_dataset = CostModelBuffer(disable=not saferl_config.get("use_cost_model_buffer", False))

    # =========================================================================#
    #  Provide session to agent                                               #
    # =========================================================================#
    agent.prepare_session(sess)

    # =========================================================================#
    #  Create function for running update (called at end of each epoch)       #
    # =========================================================================#

    def update():
        """Run one end-of-epoch update from the data currently held in ``buf``.

        Steps, in order:

        1. evaluate the diagnostic tensors with the pre-update parameters;
        2. step the penalty optimizer (only when ``agent.learn_penalty``);
        3. fit and evaluate the cost model (only when ``agent.cost_model``);
        4. let the agent update the policy via ``agent.update_pi``;
        5. run ``vf_iters`` value-function optimizer steps;
        6. re-evaluate the diagnostics and log the change of each quantity.

        Everything is read from the enclosing scope and all results are pushed
        to ``logger``; nothing is returned.
        """
        # Element 0 of the logger statistics is taken as the current cost
        # (presumably the mean over stored episodes -- confirm in EpochLogger).
        cur_cost = logger.get_stats('EpCost')[0]
        assert np.isscalar(cur_cost), cur_cost

        cost_excess = cur_cost - cost_lim
        if cost_excess > 0 and agent.cares_about_cost:
            logger.log(
                'Warning! Safety constraint is already violated.', 'red'
            )

        # ---------------------------------------------------------------------
        #  Feed dict
        # ---------------------------------------------------------------------
        # buf.get() yields the batch arrays, which are paired positionally with
        # buf_phs, plus the cost data handed to the cost-model code below.
        batch_arrays, batch_cost = buf.get()
        inputs = dict(zip(buf_phs, batch_arrays))
        inputs[surr_cost_rescale_ph] = logger.get_stats('EpLen')[0]
        inputs[cur_cost_ph] = cur_cost

        # ---------------------------------------------------------------------
        #  Diagnostics before any parameter changes
        # ---------------------------------------------------------------------
        measures = {
            'LossPi': pi_loss,
            'SurrCost': surr_cost,
            'LossV': v_loss,
            'Entropy': ent,
            'adv_norm': adv_norm,
            'cadv_norm': cadv_norm,
        }
        # The optional tensors below only exist for the matching agent
        # configuration, so they are referenced strictly behind their flags.
        if not agent.reward_penalized:
            measures['LossVC'] = vc_loss
        if agent.use_penalty:
            measures['Penalty'] = penalty
        measures['ratio'] = ratio
        if agent.adversarial:
            measures['lambda'] = lambda_

        before = sess.run(measures, feed_dict=inputs)
        logger.store(**before)

        # ---------------------------------------------------------------------
        #  Penalty step
        # ---------------------------------------------------------------------
        if agent.learn_penalty:
            penalty_feed = {cur_cost_ph: cur_cost}
            if agent.adversarial:
                # The adversarial variant is additionally fed the batch
                # observations.
                penalty_feed[x_ph] = inputs[x_ph]
            sess.run(train_penalty, feed_dict=penalty_feed)

        # ---------------------------------------------------------------------
        #  Cost model step
        # ---------------------------------------------------------------------
        if agent.cost_model:
            model_length = saferl_config.get("cost_model_length", 20)

            # 1) Derive sample indices and labels from the current batch.
            sample_idx, ground_truth = build_cost_model_data(
                model_length, batch_cost
            )

            def _selected(array):
                # Restrict an array to the samples chosen for the cost model.
                return array[sample_idx]

            # 2) Add the newest samples to the cost-model dataset.
            cost_model_dataset.save(
                states=_selected(inputs[x_ph]),
                actions=_selected(inputs[a_ph]),
                labels=_selected(ground_truth),
            )

            # 3) Train the cost model on data drawn from that dataset.
            loss_history = update_cost_model(
                cost_model_dataset, sess, train_cost_model, cost_model_loss,
                cost_model_label, x_ph, a_ph
            )

            # 4) Score the trained model on the newest samples.
            pred_prob = sess.run(
                cost_prob,
                feed_dict={
                    cost_model_label: _selected(ground_truth),
                    x_ph: _selected(inputs[x_ph]),
                    a_ph: _selected(inputs[a_ph]),
                },
            )
            accuracy, recall = get_accuracy_and_callback(
                pred_prob, _selected(ground_truth)
            )

            logger.store(
                cost_model_prob=ground_truth,
                cost_model_loss=loss_history,
                cost_model_pred_prob=pred_prob,
                cost_model_accuracy=accuracy,
                cost_model_recall=recall,
                cost_model_dataset_size=len(cost_model_dataset),
            )

        # ---------------------------------------------------------------------
        #  Policy step
        # ---------------------------------------------------------------------
        agent.update_pi(inputs)

        # ---------------------------------------------------------------------
        #  Value function steps
        # ---------------------------------------------------------------------
        for _ in range(vf_iters):
            sess.run(train_vf, feed_dict=inputs)

        # ---------------------------------------------------------------------
        #  Diagnostics after the update
        # ---------------------------------------------------------------------
        # Entropy is not re-evaluated; the KL tensor is fetched instead.
        del measures['Entropy']
        measures['KL'] = d_kl

        after = sess.run(measures, feed_dict=inputs)

        # Only quantities measured both before and after get a Delta entry.
        deltas = {
            'Delta' + name: after[name] - before[name]
            for name in after
            if name in before
        }
        logger.store(KL=after['KL'], **deltas)

    # =========================================================================#
    #  Run main environment interaction loop                                  #
    # =========================================================================#

    start_time = time.time()
    o, r, d, c, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0, 0
    ep_velocity = 0
    cur_penalty = 0
    cum_cost = 0

    if saferl_config.get(core.USE_IPD_SOFT):
        ipd_ep_cost = 0

    for epoch in range(epochs):

        early_stop_count_this_iter = 0
        num_episodes_this_iter = 0

        if agent.use_penalty:
            if agent.adversarial:
                cur_penalty = sess.run(penalty, {x_ph: o[np.newaxis]})
            else:
                cur_penalty = sess.run(penalty)

        for t in range(local_steps_per_epoch):

            # Possibly render
            if render and proc_id() == 0 and t < 1000:
                env.render()

            # Get outputs from policy
            get_action_outs = sess.run(get_action_ops, feed_dict={x_ph: o[np.newaxis]})

            a = get_action_outs['pi']
            v_t = get_action_outs['v']
            vc_t = get_action_outs.get(
                'vc', 0
            )  # Agent may not use cost value func
            logp_t = get_action_outs['logp_pi']
            pi_info_t = get_action_outs['pi_info']

            if agent.adversarial:
                lambda_t = get_action_outs["lambda"]

            # if agent.cost_model:
            #     cost_prob = get_action_outs["cost_prob"]

            # Step in environment
            o2, r, d, info = env.step(a)

            # Include penalty on cost
            c = info.get('cost', 0)

            # Track cumulative cost over training
            cum_cost += c

            # save and log
            if agent.reward_penalized:
                r_total = r - cur_penalty * c
                r_total = r_total / (1 + cur_penalty)
                buf.store(o, a, r_total, v_t, 0, 0, logp_t, pi_info_t)
            else:

                if agent.adversarial:
                    buf.store(o, a, r, v_t, c, vc_t, logp_t, pi_info_t)
                else:
                    buf.store(o, a, r, v_t, c, vc_t, logp_t, pi_info_t)
            logger.store(VVals=v_t, CostVVals=vc_t)

            o = o2
            ep_ret += r
            ep_velocity += info["velocity"]
            ep_cost += c
            ep_len += 1

            # Conduct IPD
            if saferl_config.get(core.USE_IPD) or saferl_config.get(
                    core.USE_IPD_SOFT):
                if saferl_config.get(core.USE_IPD_SOFT):
                    ipd_ep_cost += (
                            c - saferl_config["cost_threshold"] /
                            saferl_config["max_ep_len"]
                    )
                    d, info = core.ipd(d, info, ipd_ep_cost, 0.0)
                else:
                    d, info = core.ipd(
                        d, info, ep_cost, saferl_config["cost_threshold"]
                    )
                if info.get("early_stop"):
                    early_stop_count_this_iter += 1

            terminal = d or (ep_len == saferl_config["max_ep_len"])
            if terminal or (t == local_steps_per_epoch - 1):

                # If trajectory didn't reach terminal state, bootstrap value
                # target(s)
                if d and not (ep_len == saferl_config["max_ep_len"]):
                    # Note: we do not count env time out as true terminal state
                    last_val, last_cval = 0, 0
                else:
                    feed_dict = {x_ph: o[np.newaxis]}
                    if agent.reward_penalized:
                        last_val = sess.run(v, feed_dict=feed_dict)
                        last_cval = 0
                    else:
                        last_val, last_cval = sess.run(
                            [v, vc], feed_dict=feed_dict
                        )
                buf.finish_path(last_val, last_cval)

                # Only save EpRet / EpLen if trajectory finished
                if terminal:
                    logger.store(EpRet=ep_ret, EpLen=ep_len, EpCost=ep_cost, EpSuccess=1 if info["arrive_dest"] else 0,
                                 EpVelocity=ep_velocity/ep_len)

                    if saferl_config.get(core.USE_IPD_SOFT):
                        logger.store(IPDEpCost=ipd_ep_cost)

                    num_episodes_this_iter += 1
                else:
                    print(
                        'Warning: trajectory cut off by epoch at %d steps.' %
                        ep_len
                    )

                # Reset environment
                o, r, d, c, ep_ret, ep_len, ep_cost = env.reset(
                ), 0, False, 0, 0, 0, 0
                ep_velocity = 0

                if saferl_config.get(core.USE_IPD_SOFT):
                    ipd_ep_cost = 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            # Save more checkpoints than only one by passing epoch
            logger.save_state({'env': env}, epoch)
            # logger.save_state({'env': env}, None)

        # =====================================================================#
        #  Run RL update                                                      #
        # =====================================================================#
        update()

        # =====================================================================#
        #  Cumulative cost calculations                                       #
        # =====================================================================#
        cumulative_cost = mpi_sum(cum_cost)
        cumulative_num_episodes_this_iter = mpi_sum(num_episodes_this_iter)
        cumulative_early_stop_count_this_iter = mpi_sum(
            early_stop_count_this_iter
        )
        cost_rate = cumulative_cost / ((epoch + 1) * steps_per_epoch)
        early_stop_rate = cumulative_early_stop_count_this_iter / \
                          cumulative_num_episodes_this_iter

        # =====================================================================#
        #  Evaluation rollouts                                                #
        # =====================================================================#
        # The training env is closed before the evaluation env is used and is
        # reset again afterwards.
        # NOTE(review): presumably only one PGDrive env can be live at a time
        # -- confirm against the PGDrive documentation.
        env.close()
        evaluation_episode_num = 20

        # Accumulators over all evaluation episodes of this epoch.
        episode_reward = 0
        success_num = 0
        episode_num = 0
        episode_cost = 0
        velocity = []
        episode_overtake = []
        eval_step = 0
        state = eval_env.reset()
        while episode_num < evaluation_episode_num:
            # Get outputs from policy for the current observation
            # (a leading batch axis is added before feeding it to the graph).
            get_action_outs = sess.run(
                get_action_ops, feed_dict={x_ph: state[np.newaxis]}
            )
            eval_step += 1
            next_state, r, done, info = eval_env.step(get_action_outs["pi"][0])
            velocity.append(info["velocity"])
            state = next_state
            episode_reward += r
            episode_cost += info["cost"]

            # An episode ends when the env reports done or when it exceeds
            # the configured maximum episode length.
            if done or eval_step > saferl_config["max_ep_len"]:
                episode_num += 1
                if info["arrive_dest"]:
                    success_num += 1
                episode_overtake.append(info["overtake_vehicle_num"])
                eval_step = 0
                # Fix: continue from the observation returned by reset().
                # Previously the return value was dropped, so the first action
                # of the next episode was computed from the stale final
                # observation of the finished episode.
                state = eval_env.reset()

        # Per-episode averages; velocity is averaged over all evaluation steps.
        eval_res = dict(
            Eval_EpRet=episode_reward / evaluation_episode_num,
            Eval_EpSuccess=success_num / evaluation_episode_num,
            Eval_EpCost=episode_cost / evaluation_episode_num,
            Eval_EpVelocity=np.mean(velocity),
        )
        eval_env.close()

        # Fix: the training env was closed above, so the observation held by
        # the rollout loop no longer matches the env. Keep the fresh
        # observation so the next epoch starts from the actual current state.
        o = env.reset()
        logger.store(**eval_res)



        # =====================================================================#
        #  Log performance and stats                                          #
        # =====================================================================#
        # The log_tabular calls are issued in a fixed order (presumably this
        # determines the column order of the output table -- confirm in
        # EpochLogger), so the sequence below must not be reshuffled.

        logger.log_tabular('Epoch', epoch)

        # Performance stats: training rollouts first, then evaluation rollouts
        for stat_key in (
                'EpRet', 'EpSuccess', 'EpVelocity', 'EpCost',
                'Eval_EpRet', 'Eval_EpSuccess', 'Eval_EpVelocity',
                'Eval_EpCost',
        ):
            logger.log_tabular(stat_key, with_min_and_max=True)

        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('CumulativeCost', cumulative_cost)
        logger.log_tabular('CostRate', cost_rate)

        # Value function values (reward and cost)
        for stat_key in ('VVals', 'CostVVals'):
            logger.log_tabular(stat_key, with_min_and_max=True)

        # Pi loss and change, then surrogate cost and change
        for stat_key in ('LossPi', 'DeltaLossPi', 'SurrCost', 'DeltaSurrCost'):
            logger.log_tabular(stat_key, average_only=True)

        # V loss and change, interleaved with ratio / advantage-norm stats
        for stat_key, mean_only in (
                ('LossV', True),
                ('ratio', False),
                ('DeltaLossV', True),
                ('adv_norm', False),
                ('cadv_norm', False),
        ):
            logger.log_tabular(stat_key, average_only=mean_only)

        if agent.adversarial:
            logger.log_tabular("lambda", average_only=False)

        if agent.cost_model:
            for stat_key in (
                    "cost_model_prob",
                    "cost_model_pred_prob",
                    "cost_model_loss",
                    "cost_model_recall",
                    "cost_model_accuracy",
                    "cost_model_dataset_size",
            ):
                logger.log_tabular(stat_key, average_only=True)

        # Vc loss and change, if applicable (reward_penalized agents don't
        # use vc)
        if not agent.reward_penalized:
            for stat_key in ('LossVC', 'DeltaLossVC'):
                logger.log_tabular(stat_key, average_only=True)

        # Penalty columns are always written: averaged when the agent tracks a
        # penalty, a constant 0 otherwise.
        penalty_tracked = agent.use_penalty or agent.save_penalty
        for stat_key in ('Penalty', 'DeltaPenalty'):
            if penalty_tracked:
                logger.log_tabular(stat_key, average_only=True)
            else:
                logger.log_tabular(stat_key, 0)

        # Anything from the agent?
        agent.log()

        # Policy stats
        for stat_key in ('Entropy', 'KL'):
            logger.log_tabular(stat_key, average_only=True)

        # Time and steps elapsed
        total_env_interacts = (epoch + 1) * steps_per_epoch
        logger.log_tabular('TotalEnvInteracts', total_env_interacts)
        elapsed_seconds = time.time() - start_time
        logger.log_tabular('Time', elapsed_seconds)

        # Modification to fit tune
        logger.log_tabular('timesteps_this_iter', steps_per_epoch)
        logger.log_tabular(
            'episodes_this_iter', cumulative_num_episodes_this_iter
        )

        # Custom stats: early-stop figures when either IPD option is enabled
        if saferl_config.get(core.USE_IPD) or saferl_config.get(
                core.USE_IPD_SOFT):
            logger.log_tabular('early_stop_rate', early_stop_rate)
            logger.log_tabular(
                'early_stop_this_iter', cumulative_early_stop_count_this_iter
            )

        if saferl_config.get(core.USE_IPD_SOFT):
            logger.log_tabular('IPDEpCost', with_min_and_max=True)

        # Write the latest progress to temporary file (process 0 only, and
        # only when a temporary file was supplied)
        if proc_id() == 0 and tmp_file is not None:
            json_dump(logger.log_current_row, tmp_file)

        # Show results!
        logger.dump_tabular()
