import sys
import os
import json
import argparse
import pickle
import torch
import numpy as np
from pathlib import Path
from shutil import rmtree
from mpi4py import MPI
from risk.common.mpi_logger import LoggerMPI
from risk.common.experience_buffer import ExperienceBuffer
from risk.common.utils import get_env_object, get_network_object
from risk.common.mpi_data_utils import mpi_sum, average_grads, sync_weights, mpi_statistics_scalar, \
    mpi_avg, collect_dict_of_lists, mpi_gather_objects, print_now


# CPU/GPU usage regulation: limit this process to a single thread.  More than one thread can be assigned here,
# but 1 is probably best in most cases.
# NOTE(review): OMP_NUM_THREADS is set after torch and numpy have been imported above -- presumably the explicit
# torch.set_num_threads(1) call is what limits torch; confirm the environment variable still takes effect.
os.environ['OMP_NUM_THREADS'] = '1'
torch.set_num_threads(1)


class PolicyOptimizer(object):

    def __init__(self, config):
        """
        Policy optimization agent that trains with either a policy-gradient or a surrogate objective,
        combined with configurable variance reduction measures.

        :param config: dictionary of settings; it is stored as-is and completed in place by process_config().
        """
        comm = MPI.COMM_WORLD
        self.id = comm.Get_rank()
        self.config = config
        self.process_config()  # relies on self.id and self.config being set
        # Attributes below are populated later (see train/test and the initialize_* methods):
        self.mode = ''
        self.env, self.obs = None, None
        self.pi_network, self.v_network = None, None
        self.pi_optimizer, self.v_optimizer = None, None
        self.buffer, self.logger = None, None
        self.epsilon = 1.e-6
        # Seed torch and numpy differently on every worker, as a function of rank and configured seed:
        seed_base = self.id + 1 + self.config['seed'] * comm.Get_size()
        torch.manual_seed(seed_base * 2000)
        np.random.seed(seed_base * 5000)

    def process_config(self):
        """
        Process the configuration in place, filling in missing values and checking consistency.

        Besides setting defaults, this scales 'test_random_base' by (seed + 1), joins 'model_folder' and
        'log_folder' onto the current working directory, and (on rank 0 only) creates both folders, deleting
        any existing ones first unless 'use_prior_nets' is set.

        Raises AssertionError if the variance-reduction options are inconsistent with each other.
        """
        # General training configuration:
        self.config.setdefault('seed', 0)                     # Random seed parameter
        self.config.setdefault('training_frames', int(5e7))   # Number of frames to use for training run
        self.config.setdefault('max_ep_length', -1)           # Maximum episode length (< 0 means no max)
        self.config.setdefault('batch_size', 30000)           # Average number of experiences to base an update on
        self.config.setdefault('pi_lr', .0003)                # Policy optimizer learning rate
        self.config.setdefault('v_lr', .001)                  # Value optimizer learning rate
        self.config.setdefault('train_v_iter', 80)            # Value updates per epoch
        self.config.setdefault('train_pi_iter', 80)           # Policy updates per epoch
        self.config.setdefault('gamma', 0.99)                 # Discount factor gamma
        self.config.setdefault('lambda', 0.95)                # GAE factor lambda
        self.config.setdefault('clip', 0.2)                   # Clip factor for policy trust regions (< 0 means none)
        self.config.setdefault('v_clip', -1)                  # Clip factor for value trust regions (< 0 means none)
        self.config.setdefault('max_kl', -1)                  # KL criteria for early stopping (< 0 means ignore)
        # Variance reduction measures:
        self.config.setdefault('reward_to_go', True)          # Whether or not to use reward-to-go
        self.config.setdefault('gae', True)                   # Whether or not to use generalized advantage estimation
        self.config.setdefault('trust_regions', True)         # Whether or not to use trust regions
        self.config.setdefault('surrogate', False)            # Whether or not to use surrogate objective
        if self.config['gae'] or self.config['trust_regions'] or self.config['surrogate']:
            assert self.config['reward_to_go'], "Given rest of configuration, must use reward-to-go."
            assert 'v_network' in self.config, "Given rest of configuration, must include value network."
        if self.config['surrogate']:
            assert self.config['trust_regions'], "Surrogate objective should only be used with trust regions."
        if not self.config['trust_regions']:
            assert self.config['train_pi_iter'] == 1, "Only update policy once per batch if no trust regions."
            assert self.config['max_kl'] < 0, "Only use early stopping in conjunction with trust regions."
        self.config.setdefault('full_kl', True)               # Whether to use full KL estimation or approximation
        # Testing configuration:
        self.config.setdefault('test_iteration', -1)          # Test latest network
        self.config.setdefault('test_episodes', 1000)
        self.config.setdefault('test_random_base', 100000)
        self.config['test_random_base'] = self.config['test_random_base']*(self.config['seed'] + 1)
        self.config.setdefault('render', False)
        # Logging and storage configurations:
        self.config.setdefault('log_info', True)
        self.config.setdefault('checkpoint_every', int(1e7))
        self.config.setdefault('evaluation_every', -1)        # By default, don't run evaluations
        # Default number of evaluation episodes: roughly one batch worth of maximum-length episodes.  Without
        # an episode-length cap (max_ep_length <= 0) that estimate is meaningless (negative, or a division by
        # zero), so fall back to a single evaluation round instead.
        if self.config['max_ep_length'] > 0:
            default_evaluation_episodes = max(1, int(self.config['batch_size'] / self.config['max_ep_length']))
        else:
            default_evaluation_episodes = 1
        self.config.setdefault('evaluation_episodes', default_evaluation_episodes)
        self.config.setdefault('use_prior_nets', False)       # Whether to pick up where previous training left off
        self.config.setdefault('model_folder', '../../output/rl_training')
        self.config.setdefault('log_folder', '../../logs/rl_training')
        self.config['model_folder'] = os.path.join(os.getcwd(), self.config['model_folder'])
        self.config['log_folder'] = os.path.join(os.getcwd(), self.config['log_folder'])
        if sys.platform[:3] == 'win':
            self.config['model_folder'] = self.config['model_folder'].replace('/', '\\')
            self.config['log_folder'] = self.config['log_folder'].replace('/', '\\')
        if self.id == 0:  # only one worker touches the file system
            if not self.config['use_prior_nets']:  # start a fresh training run
                if os.path.isdir(self.config['log_folder']):
                    rmtree(self.config['log_folder'], ignore_errors=True)
                if os.path.isdir(self.config['model_folder']):
                    rmtree(self.config['model_folder'], ignore_errors=True)
            Path(self.config['model_folder']).mkdir(parents=True, exist_ok=True)
            Path(self.config['log_folder']).mkdir(parents=True, exist_ok=True)

    def train(self):
        """
        Train the network(s).

        Initializes the environment, networks, optimizers and logger, then repeats until 'training_frames'
        steps have been accumulated: collect a batch of episodes, update the network(s), optionally run an
        evaluation, log, and (on rank 0) save the latest weights plus periodic numbered checkpoints.
        """
        # Initialize relevant objects:
        self.mode = 'train'
        self.initialize_env()
        last_checkpoint = self.initialize_networks()
        self.initialize_optimizers()
        self.obs = self.env.reset()
        self.initialize_logger()
        # Run training:
        # last_checkpoint starts at -1 in initialize_networks, so max(..., 0) makes a fresh run start at step 0.
        total_steps = max(last_checkpoint * self.config['checkpoint_every'], 0)
        last_evaluation = total_steps // self.config['evaluation_every'] if total_steps > 0 else -1
        while total_steps < self.config['training_frames']:
            # Collect data:
            self.pi_network.train()
            if self.v_network is not None:
                self.v_network.train()
            self.buffer = ExperienceBuffer()  # reset experience buffer
            steps_current, mean_steps, step_tracker, all_episode_summaries = 0, 0, [0], {}
            # Stop collecting once the step count is within half of the mean per-iteration increment of batch_size.
            while steps_current < self.config['batch_size'] - mean_steps / 2:
                episode_summary = self.run_trajectory()
                # presumably the buffer step count summed over all workers -- verify against mpi_sum
                steps_current = mpi_sum(MPI.COMM_WORLD, self.buffer.steps)
                step_tracker.append(steps_current - sum(step_tracker))  # steps added by this iteration
                mean_steps = sum(step_tracker[1:])/(len(step_tracker[1:]))
                # NOTE(review): steps_current was refreshed just above, so this branch is only taken if no steps
                # have been recorded yet; the else branch also copes with an empty summary dictionary.
                if steps_current == 0:  # first iteration
                    all_episode_summaries = {k: [v] for k, v in episode_summary.items()}
                else:
                    all_episode_summaries = self.concatenate_dict_of_lists(all_episode_summaries, episode_summary)
            previous_steps = total_steps
            total_steps += steps_current
            # Update network(s), as prescribed by config:
            losses = self.update_network()
            # Update logging, save the model:
            local_steps = self.buffer.steps  # read before run_evaluation, which replaces self.buffer
            if total_steps // self.config['evaluation_every'] > last_evaluation:
                evaluation = self.run_evaluation()
                last_evaluation += 1
            else:
                evaluation = None
            self.update_logging(all_episode_summaries, losses, evaluation, local_steps, previous_steps)
            if self.id == 0:  # only rank 0 writes model files
                self.pi_network.save(os.path.join(self.config['model_folder'], 'model-latest.pt'))
                if self.v_network is not None:
                    self.v_network.save(os.path.join(self.config['model_folder'], 'value-latest.pt'))
                if total_steps // self.config['checkpoint_every'] > last_checkpoint:  # Periodically keep checkpoint
                    last_checkpoint += 1
                    # File suffix is the nominal step count of the checkpoint, not the exact total_steps.
                    suffix = str(int(last_checkpoint * self.config['checkpoint_every']))
                    self.pi_network.save(os.path.join(self.config['model_folder'], 'model-' + suffix + '.pt'))
                    if self.v_network is not None:
                        self.v_network.save(os.path.join(self.config['model_folder'], 'value-' + suffix + '.pt'))

    def initialize_env(self):
        """  Build the environment from the configuration.  Reset is left to the caller to support CrowdNav.  """
        env = get_env_object(self.config)
        self.env = env
        if self.mode != 'test':
            return
        # In test mode, seed the environment from the worker rank and the configured random base:
        test_seed = self.id*self.config['test_episodes'] + self.config['test_random_base']
        env.seed(test_seed)

    def initialize_networks(self):
        """
        Create the policy network (and the value network, if configured), optionally restore stored weights,
        and synchronize the parameters across workers.

        With 'use_prior_nets' set: in test mode with a non-None 'test_iteration', the "latest" files
        (test_iteration < 0) or the files of that iteration are restored; otherwise the most recent numbered
        checkpoint reported by find_last_checkpoint() is restored, provided its step count is positive.

        :return: -1 if no checkpoint was looked up; otherwise the value from find_last_checkpoint(), divided by
                 'checkpoint_every' when a checkpoint was actually restored.
        """
        last_checkpoint = -1
        self.pi_network = get_network_object(self.config['pi_network'], self.env)  # env needed only for CrowdNav
        if 'v_network' in self.config:
            self.v_network = get_network_object(self.config['v_network'])
        if self.config['use_prior_nets']:
            if self.mode == 'test' and self.config['test_iteration'] is not None:
                if self.config['test_iteration'] < 0:
                    self._restore_networks('latest')
                else:
                    self._restore_networks(str(self.config['test_iteration']))
            else:
                last_checkpoint = self.find_last_checkpoint()
                if last_checkpoint > 0:
                    self._restore_networks(str(last_checkpoint))
                    # Convert the step count in the file name into a checkpoint index:
                    last_checkpoint = last_checkpoint / self.config['checkpoint_every']
        sync_weights(MPI.COMM_WORLD, self.pi_network.parameters())
        if self.v_network is not None:
            sync_weights(MPI.COMM_WORLD, self.v_network.parameters())
        return last_checkpoint

    def _restore_networks(self, suffix):
        """  Restore the policy (and, if present, value) network from the model files tagged with suffix.  """
        folder = self.config['model_folder']
        self.pi_network.restore(os.path.join(folder, 'model-' + suffix + '.pt'))
        if self.v_network is not None:
            self.v_network.restore(os.path.join(folder, 'value-' + suffix + '.pt'))

    def find_last_checkpoint(self):
        """
        Find the most recent numbered policy checkpoint in the model folder.

        Numbered checkpoints are the files train() writes as 'model-<steps>.pt'; 'model-latest.pt' and any
        other file names are ignored.

        :return: the largest <steps> value found (int), or -1 if the folder cannot be listed or holds no
                 numbered checkpoint.
        """
        prefix, extension = 'model-', '.pt'
        try:
            file_names = os.listdir(self.config['model_folder'])
        except OSError:  # folder missing or unreadable: nothing to resume from
            return -1
        step_counts = []
        for file_name in file_names:
            if not (file_name.startswith(prefix) and file_name.endswith(extension)):
                continue
            tag = file_name[len(prefix):-len(extension)]
            if tag.isdecimal():  # skips 'latest' and anything else that is not a plain step count
                step_counts.append(int(tag))
        return max(step_counts, default=-1)

    def initialize_optimizers(self):
        """  Create one Adam optimizer per network, with the learning rates taken from the configuration.  """
        adam = torch.optim.Adam
        self.pi_optimizer = adam(params=self.pi_network.parameters(), lr=self.config['pi_lr'])
        if self.v_network is None:
            return  # no value network, hence no value optimizer
        self.v_optimizer = adam(params=self.v_network.parameters(), lr=self.config['v_lr'])

    def initialize_logger(self):
        """
        Create the logger for the configured log folder and record the policy network graph.
        NOTE(review): LoggerMPI presumably restricts actual writing to one process -- confirm in LoggerMPI.
        """
        logger = LoggerMPI(self.config['log_folder'])
        self.logger = logger
        logger.log_graph(self.obs, self.pi_network)

    def run_trajectory(self, random_seed=None, env=None):
        """
        Run one episode with the current network(s) and return a summary of it.

        :param random_seed: if given, the environment is re-seeded with it before the end-of-episode reset
                            (used for testing).
        :param env: environment to run in; defaults to self.env.
        :return: dictionary with 'episode_reward', 'episode_length', 'episode_mean_value' (None without a
                 value network) and, if 'log_info' is set, the per-episode sum of every entry of the
                 environment's info dictionaries.

        In 'train' mode the finished episode and its target values are also added to self.buffer.
        """
        episode_buffer, episode_info = np.array([]).reshape(0, 9), {}
        if env is None:
            env = self.env
        num_frames = 0
        while True:
            policy, value, action, log_prob, gauss_action, gauss_log_prob = self.forward_pass()
            # Step the environment that was passed in (previously self.env was stepped regardless of `env`,
            # while rendering, seeding and resetting already used `env`).
            output_obs, reward, done, info = env.step(action)
            if self.config['log_info']:
                self.concatenate_dict_of_lists(episode_info, info)  # extends episode_info in place
            if self.config['render'] and self.mode == 'test':
                env.render()
            num_frames += 1
            if num_frames == self.config['max_ep_length']:  # never true for a negative (disabled) maximum
                done = True
            episode_buffer = self.update_episode_buffer(episode_buffer, action, reward, policy, log_prob, gauss_action,
                                                        gauss_log_prob, value, done)
            self.obs = output_obs
            if done:
                if random_seed is not None:
                    env.seed(random_seed)  # for testing
                self.obs = env.reset()
                if self.mode == 'train':
                    q_values = self.compute_target_values(episode_buffer[:, 2])  # column 2 holds the rewards
                    self.buffer.update(episode_buffer, q_values)
                break
        episode_summary = {'episode_reward': np.sum(episode_buffer[:, 2]), 'episode_length': num_frames,
                           'episode_mean_value': np.mean(episode_buffer[:, 7]) if 'v_network' in self.config else None,
                           **{k: sum(v) for k, v in episode_info.items()}}
        return episode_summary

    def forward_pass(self):
        """
        Run the network(s) on the current observation and pick an action.
        For continuous action spaces, policy will be a tuple of mean, std.  The value is None when there is
        no value network.  Sampling is switched off in 'test' mode.
        """
        policy = self.pi_network.forward_with_processing(self.obs)
        value = None
        if self.v_network is not None:
            value = self.v_network.forward_with_processing(self.obs).detach().numpy()[0]
        action, log_prob, gauss_action, gauss_log_prob = self.pi_network.get_action_and_log_prob(
            policy, self.mode == 'test')
        return policy, value, action, log_prob, gauss_action, gauss_log_prob

    def update_episode_buffer(self, episode_buffer, action, reward, policy,
                              log_prob, gauss_action, gauss_log_prob, value, done):
        """
        Append the current step to the episode buffer and return the extended buffer.

        Each row holds, in order: observation (self.obs), action, reward, policy, log_prob, gauss_action,
        gauss_log_prob, value, done.  A continuous policy is stored as its mean and std concatenated.
        """
        if self.pi_network.config['discrete']:
            policy_to_store = np.squeeze(policy.detach().numpy())
        else:
            mean, std = policy[0], policy[1]
            policy_to_store = np.concatenate((mean.detach().numpy(), std.detach().numpy()))
        fields = [self.obs, action, reward, policy_to_store, log_prob, gauss_action, gauss_log_prob, value, done]
        row = np.array(fields, dtype=object).reshape(1, 9)
        return np.concatenate((episode_buffer, row))

    def compute_target_values(self, rewards):
        """
        Compute value-function targets from the rewards of one episode (without bootstrapping).

        :param rewards: one-dimensional sequence of per-step rewards (numeric or object dtype).
        :return: float array of the same length.  With 'reward_to_go' every entry is the discounted sum of
                 the rewards from that step onward; otherwise every entry is the discounted return of the
                 full episode.
        """
        rewards = np.asarray(rewards, dtype=float)  # episode buffers store rewards with object dtype
        gamma = self.config['gamma']
        trajectory_length = rewards.shape[0]
        if not self.config['reward_to_go']:  # Return full-episode discounted return at each time step
            discounts = np.power(gamma, np.arange(trajectory_length))
            return np.ones(rewards.shape) * np.sum(rewards * discounts)
        # Reward-to-go via a single backward pass, G_t = r_t + gamma * G_{t+1}, instead of re-summing the
        # tail of the episode for every step.
        target_values = np.zeros((trajectory_length,))
        running_return = 0.0
        for step in range(trajectory_length - 1, -1, -1):
            running_return = rewards[step] + gamma * running_return
            target_values[step] = running_return
        return target_values

    def update_network(self):
        """
        Updates the central network based on processing from the workers.
        Currently no mini-batching, might eventually include.

        Order of operations: fit the value network (if configured) for 'train_v_iter' steps, refresh the
        stored value estimates, estimate and standardize advantages, then update the policy for up to
        'train_pi_iter' steps with optional early stopping on the mean KL divergence.

        :return: dictionary with the lists 'pi_losses', 'v_losses' and 'entropies' (one entry per iteration).
        """
        self.pi_network.eval()
        if self.v_network is not None:
            self.v_network.eval()
        # Update value network:
        observations = torch.from_numpy(np.vstack(self.buffer.observations)).float()
        # Value estimates currently in the buffer; passed as old_values to compute_value_loss:
        values = torch.from_numpy(self.buffer.values.astype(float)).float()
        pi_losses, v_losses, entropies = [], [], []
        if 'v_network' in self.config:
            for i in range(self.config['train_v_iter']):  # update value function
                self.v_optimizer.zero_grad()
                target_values = torch.from_numpy(self.buffer.q_values.astype(float)).float()
                v_loss = self.compute_value_loss(observations, target_values, values)
                v_losses.append(v_loss.item())
                v_loss.backward()
                average_grads(MPI.COMM_WORLD, self.v_network.parameters())
                # Only rank 0 applies the step; sync_weights presumably distributes the result to the others.
                if self.id == 0:
                    self.v_optimizer.step()
                sync_weights(MPI.COMM_WORLD, self.v_network.parameters())
            self.update_values(observations)  # buffer values now come from the updated value network
        # Update advantage estimates, standardizing across workers:
        self.estimate_advantages()
        mean_adv, std_adv = mpi_statistics_scalar(MPI.COMM_WORLD, self.buffer.advantages)
        # NOTE(review): no guard against std_adv == 0 here -- confirm whether mpi_statistics_scalar protects this.
        self.buffer.advantages = (self.buffer.advantages - mean_adv) / std_adv
        advantages = torch.from_numpy(self.buffer.advantages.astype(float)).float()
        # Update policy network:
        actions = self.buffer.actions
        log_probs = torch.from_numpy(self.buffer.log_probs.astype(float)).float()
        for i in range(self.config['train_pi_iter']):
            self.pi_optimizer.zero_grad()
            pi_loss, entropy, kld = self.compute_policy_loss(observations, actions, advantages, log_probs)
            # compute_policy_loss returns unreduced losses; reduce them to a scalar here.
            if self.config['surrogate']:
                pi_loss = torch.mean(pi_loss, dim=0)  # assumes same number of experiences per worker
            else:
                number_of_episodes = np.sum(self.buffer.dones)
                pi_loss = torch.sum(pi_loss, dim=0) / number_of_episodes
            # Losses are recorded before the KL check, so the list includes the iteration that triggers a stop.
            pi_losses.append(pi_loss.item())
            entropies.append(entropy.item())
            mean_kld = mpi_avg(MPI.COMM_WORLD, kld)
            # Early stopping applies only when max_kl is positive; the break happens before backward/step.
            if mean_kld > self.config['max_kl'] > 0:
                if self.id == 0:
                    print_now('Policy KL divergence exceeds limit; stopping update at step %d.' % i)
                break
            pi_loss.backward()
            average_grads(MPI.COMM_WORLD, self.pi_network.parameters())
            if self.id == 0:
                self.pi_optimizer.step()
            sync_weights(MPI.COMM_WORLD, self.pi_network.parameters())
        return {'pi_losses': pi_losses, 'v_losses': v_losses, 'entropies': entropies}

    def compute_value_loss(self, observations, target_values, old_values):
        """
        Mean squared error of the value network against the targets.  With a positive 'v_clip', each
        sample uses the larger of the plain error and the error of the prediction clipped to within
        v_clip of old_values.
        """
        predictions = self.v_network(observations).view(-1)
        plain_losses = torch.pow(predictions - target_values, 2)
        clip = self.config['v_clip']
        if not clip > 0:
            return torch.mean(plain_losses, dim=0)
        clipped_predictions = old_values + torch.clamp(predictions - old_values, -clip, clip)
        clipped_losses = torch.pow(clipped_predictions - target_values, 2)
        return torch.mean(torch.max(plain_losses, clipped_losses), dim=0)

    def update_values(self, observations):
        """  Re-evaluate the value network on the observations; store the flattened estimates in the buffer.  """
        predictions = self.v_network(observations)
        flat_values = predictions.view(-1).detach()
        self.buffer.values = flat_values.numpy()

    def estimate_advantages(self):
        """
        Estimate advantages from the buffered target values and store them in self.buffer.advantages.

        Without reward-to-go, or without a value network, the target values themselves are used; with a
        value network either generalized advantage estimation ('gae') or the plain difference between
        target values and value estimates is used.
        """
        if not self.config['reward_to_go']:
            self.buffer.advantages = self.buffer.q_values
        elif 'v_network' not in self.config:
            # Copy so that the advantages do not alias the target values.  (The original called deepcopy,
            # which is never imported in this file and therefore raised NameError.)
            self.buffer.advantages = np.copy(self.buffer.q_values)
        elif self.config['gae']:
            self.buffer.advantages = self.estimate_generalized_advantage()
        else:
            self.buffer.advantages = self.buffer.q_values - self.buffer.values

    def estimate_generalized_advantage(self):
        """
        Generalized advantage estimation over the buffered rewards, value estimates and episode ends.

        Episodes are delimited by the non-zero entries of buffer.dones; the value after the final step of an
        episode is taken to be 0 (no bootstrapping).  Steps after the last episode end keep an advantage of 0.

        :return: float array with one advantage per buffered step.
        """
        gamma = self.config['gamma']
        decay = gamma * self.config['lambda']
        rewards = np.asarray(self.buffer.rewards, dtype=float)
        values = np.asarray(self.buffer.values, dtype=float)
        dones = np.asarray(self.buffer.dones).astype(int)
        gae = np.zeros(rewards.shape)
        terminals = np.nonzero(dones)[0]
        if terminals.size == 0:  # no completed episode in the buffer
            return gae
        # Single backward pass using A_t = delta_t + gamma * lambda * A_{t+1}, restarted at every episode
        # end, instead of re-summing the discounted deltas for each step.
        running_advantage, next_value = 0.0, 0.0
        for step in range(int(terminals[-1]), -1, -1):
            if dones[step]:  # last step of an episode: nothing carries over from the following episode
                running_advantage, next_value = 0.0, 0.0
            delta = rewards[step] + gamma * next_value - values[step]
            running_advantage = delta + decay * running_advantage
            gae[step] = running_advantage
            next_value = values[step]
        return gae

    def compute_policy_loss(self, observations, actions, advantages, old_log_probs):
        """
        Compute policy loss and entropy.

        :param observations: batch of observations fed to the policy network.
        :param actions: actions taken (used as indices for discrete policies, stacked for continuous ones).
        :param advantages: advantage estimate per experience.
        :param old_log_probs: log-probabilities of the actions under the policy that collected the data.
        :return: tuple (pi_loss, entropy, kld); pi_loss is left unreduced (the caller averages or sums it),
                 entropy is a mean over the batch, and kld comes from compute_kld.
        """
        new_policies = self.pi_network(observations)
        # Stays all zeros unless trust regions are enabled below:
        clipped_new_log_probs = torch.zeros(old_log_probs.size())
        if self.pi_network.config['discrete']:
            actions_one_hot = torch.from_numpy(
                np.eye(self.pi_network.config['action_dim'])[np.squeeze(actions)]).float()
            # Floor the probabilities at epsilon so the logarithms below stay finite:
            new_policies = torch.masked_fill(new_policies, new_policies < self.epsilon, self.epsilon)
            new_log_probs = torch.sum(torch.log(new_policies) * actions_one_hot, dim=1)
            if self.config['trust_regions']:
                # Clamp the log-ratio, i.e. keep the probability ratio within [1 - clip, 1 + clip]:
                log_diff = torch.clamp(new_log_probs - old_log_probs,
                                       np.log(1.0 - self.config['clip']), np.log(1.0 + self.config['clip']))
                clipped_new_log_probs = old_log_probs + log_diff
            # Mean of -p*log(p) over every entry (batch and actions), not a per-sample sum over actions:
            entropy = -torch.mean(new_policies * torch.log(new_policies))
        else:
            new_means, new_stds = new_policies
            new_dist = torch.distributions.Normal(new_means, new_stds)
            if self.pi_network.config['squashed']:
                gauss_actions_torch = torch.from_numpy(np.vstack(self.buffer.gauss_actions)).float()
                new_gauss_log_probs = new_dist.log_prob(gauss_actions_torch).sum(dim=-1)
                old_gauss_log_probs = torch.from_numpy(self.buffer.gauss_log_probs.astype(float)).float()
                # Re-express the stored pre-squash actions as mean + std * noise, with the noise detached,
                # so that the tanh correction below depends on the new means and stds.
                # NOTE(review): this line uses buffer.gauss_actions directly while the line above stacks it
                # with np.vstack first -- confirm both see the same shape.
                new_noise = (self.buffer.gauss_actions - new_means.detach().numpy()) / new_stds.detach().numpy()
                diff_gauss_actions = new_means + new_stds * torch.from_numpy(new_noise).float()
                # Log-determinant term of the tanh squashing, with epsilon inside the log for stability:
                correction = torch.sum(torch.log(1 - torch.tanh(diff_gauss_actions) ** 2 + self.epsilon), dim=-1)
                new_log_probs = new_gauss_log_probs - correction
                if self.config['trust_regions']:
                    # Clipping is applied to the Gaussian log-probabilities; the correction is added afterwards.
                    log_diff = torch.clamp(new_gauss_log_probs - old_gauss_log_probs,
                                           np.log(1.0 - self.config['clip']), np.log(1.0 + self.config['clip']))
                    clipped_new_log_probs = old_gauss_log_probs + log_diff - correction
                entropy = -torch.mean(new_gauss_log_probs - correction)  # approximate!
            else:
                actions_torch = torch.from_numpy(np.vstack(actions)).float()
                new_log_probs = torch.sum(new_dist.log_prob(actions_torch), dim=-1)
                # Closed-form Gaussian entropy, log(std) + 0.5*log(2*pi*e), averaged over all entries:
                entropy = torch.mean(torch.log(new_stds) + .5 * np.log(2 * np.pi * np.e))
                if self.config['trust_regions']:
                    log_diff = torch.clamp(new_log_probs - old_log_probs,
                                           np.log(1.0 - self.config['clip']), np.log(1.0 + self.config['clip']))
                    clipped_new_log_probs = old_log_probs + log_diff
        if self.config['surrogate']:  # trust regions assumed
            # Element-wise maximum of the unclipped and clipped negated objectives:
            pi_losses_1 = -advantages * torch.exp(new_log_probs - old_log_probs)
            pi_losses_2 = -advantages * torch.exp(clipped_new_log_probs - old_log_probs)
            pi_loss = torch.max(pi_losses_1, pi_losses_2)
        else:
            if self.config['trust_regions']:
                # Same element-wise maximum, applied to log-probabilities instead of probability ratios:
                pi_losses_1 = -advantages * new_log_probs
                pi_losses_2 = -advantages * clipped_new_log_probs
                pi_loss = torch.max(pi_losses_1, pi_losses_2)
            else:
                pi_loss = -advantages * new_log_probs  # plain policy gradient
        kld = self.compute_kld(new_policies, old_log_probs, new_log_probs)
        return pi_loss, entropy, kld

    def compute_kld(self, policy_predictions, old_log_probs, new_log_probs):
        """
        Compute the KL divergence used for early stopping.

        With 'full_kl' the divergence is computed from the stored and the newly predicted policies
        (categorical or Gaussian); otherwise the mean difference of log-probabilities is returned.
        """
        if not self.config['full_kl']:  # use approximation from Spinning Up
            return (old_log_probs - new_log_probs).mean().item()
        stored_policies = self.buffer.policies
        if self.pi_network.config['discrete']:
            new_policies = policy_predictions.detach().numpy()
            per_action = new_policies * (np.log(new_policies) - np.log(stored_policies))
            return np.mean(np.sum(per_action, axis=1))
        # Continuous case: stored policies hold means and stds side by side.
        mu_old, sigma_old = np.split(stored_policies, 2, axis=1)
        mu_new = policy_predictions[0].detach().numpy()
        sigma_new = policy_predictions[1].detach().numpy()
        if not self.pi_network.config['log_std_net']:
            # Repeat the new std along a new leading axis, once per stored experience:
            sigma_new = np.repeat(np.expand_dims(sigma_new, 0), sigma_old.shape[0], axis=0)
        var_old = sigma_old ** 2
        var_new = sigma_new ** 2
        per_dimension = np.log(sigma_new/sigma_old) + 0.5*(((mu_new-mu_old)**2 + var_old)/(var_new + 1.e-8) - 1)
        return np.mean(np.sum(per_dimension, axis=1))

    def run_evaluation(self):
        """
        Run evaluation episodes in 'test' mode until the episode count summed over workers reaches
        'evaluation_episodes', then switch back to 'train' mode and return the evaluation results.

        :return: dictionary combining compute_metrics() and process_evaluation_info() of the collected data.
        """
        # Switch to evaluation:
        self.mode = 'test'
        self.obs = self.env.reset()  # make sure environment is reset
        self.pi_network.eval()
        if self.v_network is not None:
            self.v_network.eval()
        # Run evaluation episodes:
        local_episode_summaries = {}
        local_episodes = total_episodes = 0
        target_episodes = self.config['evaluation_episodes']
        while total_episodes < target_episodes:
            self.buffer = ExperienceBuffer()  # reset experience buffer
            summary = self.run_trajectory()
            # Starting from an empty dictionary this creates one single-element list per key:
            local_episode_summaries = self.concatenate_dict_of_lists(local_episode_summaries, summary)
            local_episodes += 1
            total_episodes = int(mpi_sum(MPI.COMM_WORLD, local_episodes))
        # Put back to resume training:
        self.mode = 'train'
        self.obs = self.env.reset()
        # Collect, process, return data:
        episode_data = collect_dict_of_lists(MPI.COMM_WORLD, local_episode_summaries)
        results = dict(self.compute_metrics(episode_data))
        results.update(self.process_evaluation_info(episode_data))
        return results

    def compute_metrics(self, episode_data):
        """
        Compute the metrics to be tracked as learning progresses.

        :param episode_data: dictionary of per-episode lists; must contain 'episode_reward'.
        :return: {'mean': mean episode reward}; the mean is NaN if no episode was recorded.
        """
        episode_rewards = episode_data['episode_reward']
        if len(episode_rewards) == 0:
            return {'mean': float('nan')}
        # Average over the episodes (the original divided by the number of dictionary keys instead).
        return {'mean': sum(episode_rewards) / len(episode_rewards)}

    def update_logging(self, episode_summaries, losses, evaluation, steps, previous_steps):
        """
        Send the results of the most recent update to the logger.

        Episode summaries are logged only for keys present on every worker; losses are logged as mean
        values and evaluation results (if any) as scalars.  The logger is flushed at the end.
        """
        keys_per_worker = mpi_gather_objects(MPI.COMM_WORLD, list(episode_summaries.keys()))
        for key in self.find_common(keys_per_worker):
            self.logger.log_mean_value('Performance/' + key, episode_summaries[key], steps, previous_steps)
        for name, series in losses.items():
            self.logger.log_mean_value('Losses/' + name, series, steps, previous_steps)
        if evaluation is not None:
            for name, value in evaluation.items():
                self.logger.log_scalar('Evaluation/' + name, value, steps, previous_steps)
        self.logger.flush()

    def test(self):
        """
        Run testing episodes with fixed random seeds until the episode count summed over workers reaches
        'test_episodes', then collect the per-episode summaries, store them and return them.
        """
        # Set up for testing:
        self.mode = 'test'
        self.initialize_env()
        self.initialize_networks()
        self.obs = self.env.reset()
        self.pi_network.eval()
        if self.v_network is not None:
            self.v_network.eval()
        # Run testing episodes:
        local_episode_summaries = {}
        local_episodes = total_episodes = 0
        while total_episodes < self.config['test_episodes']:
            self.buffer = ExperienceBuffer()  # reset experience buffer
            # Seed for the reset that follows this episode: distinct per worker and per episode.
            random_seed = self.id*self.config['test_episodes'] + self.config['test_random_base'] + local_episodes + 1
            summary = self.run_trajectory(int(random_seed))
            # Starting from an empty dictionary this creates one single-element list per key:
            local_episode_summaries = self.concatenate_dict_of_lists(local_episode_summaries, summary)
            local_episodes += 1
            total_episodes = int(mpi_sum(MPI.COMM_WORLD, local_episodes))
            if self.id == 0:
                print_now(str(total_episodes) + ' episodes complete.')
        # Collect, process, save data:
        test_output = collect_dict_of_lists(MPI.COMM_WORLD, local_episode_summaries)
        self.store_test_results(test_output)
        return test_output

    def store_test_results(self, test_output):
        """  Pickle the test results to 'test_results.pkl' in the model folder (rank 0 only).  """
        if self.id != 0:
            return
        destination = os.path.join(self.config['model_folder'], 'test_results.pkl')
        with open(destination, 'wb') as opened_test:
            pickle.dump(test_output, opened_test)

    @staticmethod
    def process_evaluation_info(episode_data):
        """  Processes evaluation info, returning dictionary of mean values """
        mean_info = {}
        for k, v in episode_data.items():
            if k[:5] == 'info_':
                mean_info[k] = sum(v) / len(v)
        return mean_info

    @staticmethod
    def concatenate_dict_of_arrays(base_dict, new_dict):
        """  Collect a dictionary of numpy arrays  """
        for k in new_dict:
            base_dict[k] = np.concatenate((base_dict[k], new_dict[k]))
        return base_dict

    @staticmethod
    def concatenate_dict_of_lists(base_dict, new_dict):
        """  Collect a dictionary of lists  """
        for k in new_dict:
            if k in base_dict:
                base_dict[k].append(new_dict[k])
            else:
                base_dict[k] = [new_dict[k]]
        return base_dict

    @staticmethod
    def find_common(list_of_lists):
        """  Returns members common to each list in a list of lists  """
        common = set(list_of_lists[0])
        for item in list_of_lists[1:]:
            common = common.intersection(set(item))
        return sorted(list(common))


if __name__ == '__main__':
    # Runs PolicyOptimizer training or testing for a given input configuration file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='Configuration file to run', required=True)
    parser.add_argument('--mode', default='train', required=False, help='mode ("train" or "test")')
    in_args = parser.parse_args()
    config_path = os.path.join(os.getcwd(), in_args.config)
    with open(config_path, 'r') as f1:
        config1 = json.load(f1)
    run_training = in_args.mode.lower() == 'train'  # any other mode value runs testing
    if not run_training:
        config1['use_prior_nets'] = True  # testing loads stored networks
    pg_object = PolicyOptimizer(config1)
    if run_training:
        pg_object.train()
    else:
        pg_object.test()
