# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
import torch.optim as optim
from itertools import chain
import lightning as L
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split


from rsl_rl.modules import ActorCritic, GatedActorCriticWithINV, P4RLAsymmetricActorCritic
from rsl_rl.modules.rnd import RandomNetworkDistillation
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import string_to_callable, TrainErrorLogger, get_parameter_grad_norm
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
from rsl_rl.addons.invdynamics.inv_dynamics_utils import INVModelsEnsemble
from typing import Type, Dict, List, Tuple
from isaaclab.utils.buffers import CircularBuffer



def resolve_optimizer(optimizer_name) -> Type[optim.Optimizer]:
    """Retrieve an optimizer class from ``torch.optim`` given its string name.

    Args:
        optimizer_name: Name of an optimizer class exposed by ``torch.optim`` (e.g. ``"Adam"``).

    Returns:
        The optimizer class (not an instance).

    Raises:
        AttributeError: If ``torch.optim`` has no attribute with the given name.
        ValueError: If the attribute exists but is not an optimizer class
            (for example the ``lr_scheduler`` submodule).
    """
    optimizer_class = getattr(optim, optimizer_name)
    # ``torch.optim`` also exposes submodules and helpers; reject anything that cannot be
    # instantiated as an optimizer so the failure happens here instead of deep inside training.
    if not (isinstance(optimizer_class, type) and issubclass(optimizer_class, optim.Optimizer)):
        raise ValueError(f"'{optimizer_name}' is not an optimizer class in torch.optim.")
    return optimizer_class

class PPO:
    """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""

    policy: ActorCritic
    """The actor critic module."""

    def __init__(
        self,
        policy,
        num_learning_epochs=5,
        num_mini_batches=4,
        clip_param=0.2,
        gamma=0.99,
        lam=0.95,
        value_loss_coef=1.0,
        entropy_coef=0.01,
        learning_rate=0.001,
        max_grad_norm=1.0,
        use_clipped_value_loss=True,
        schedule="adaptive",
        desired_kl=0.01,
        device="cpu",
        normalize_advantage_per_mini_batch=False,
        optimizer: str = "Adam", 
        pretrained_module_lr_factor: float|None = None,
        # RND parameters
        rnd_cfg: dict | None = None,
        # Symmetry parameters
        symmetry_cfg: dict | None = None,
        # Distributed training parameters
        multi_gpu_cfg: dict | None = None,
        # P4RL: inv dynamics module parameters (only for data collection, not for RL policy learning)
        inv_dynamics_cfg: dict | None = None,
    ):
        # device-related parameters
        self.device = device
        self.is_multi_gpu = multi_gpu_cfg is not None
        # Multi-GPU parameters
        if multi_gpu_cfg is not None:
            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
            self.gpu_world_size = multi_gpu_cfg["world_size"]
        else:
            self.gpu_global_rank = 0
            self.gpu_world_size = 1

        # P4RL: inv dynamics module components
        if inv_dynamics_cfg is not None:
            self.inv_ensemble = INVModelsEnsemble(inv_dynamics_cfg, device=self.device)
            # note that inv_ensemble is not necessarily a list of models, it can be a single model as well
            # if ensemble_size is 1, a single model will be used, and the prediction error will be intrinsic reward
            # if the ensemble size is greater than 1, the intrinsic reward will be the disagreement between the models

            # if inv_dynamics_cfg["reward_scale"] is not greater than 0, it means intrinsic reward is not used
            # so we set inv_reward_scale to None 
            self.inv_reward_scale = inv_dynamics_cfg["reward_scale"] if inv_dynamics_cfg["reward_scale"] > 0 else None
            self.inv_reward_max = inv_dynamics_cfg["reward_max"]
            self.inv_input_timesteps = inv_dynamics_cfg["input_timesteps"]
        else:
            self.inv_ensemble = None
            self.inv_reward_scale = None

        # RND components
        if rnd_cfg is not None:
            # Extract parameters used in ppo
            rnd_lr = rnd_cfg.pop("learning_rate", 1e-3)
            # Create RND module
            self.rnd = RandomNetworkDistillation(device=self.device, **rnd_cfg)
            # Create RND optimizer
            params = self.rnd.predictor.parameters()
            self.rnd_optimizer = optim.Adam(params, lr=rnd_lr)
        else:
            self.rnd = None
            self.rnd_optimizer = None

        # Symmetry components
        if symmetry_cfg is not None:
            # Check if symmetry is enabled
            use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
            # Print that we are not using symmetry
            if not use_symmetry:
                print("Symmetry not used for learning. We will use it for logging instead.")
            # If function is a string then resolve it to a function
            if isinstance(symmetry_cfg["data_augmentation_func"], str):
                symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
            # Check valid configuration
            if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]):
                raise ValueError(
                    "Data augmentation enabled but the function is not callable:"
                    f" {symmetry_cfg['data_augmentation_func']}"
                )
            # Store symmetry configuration
            self.symmetry = symmetry_cfg
        else:
            self.symmetry = None

        # PPO components
        self.policy = policy
        self.policy.to(self.device)
        # Create optimizer
        self.setup_optimizers(optimizer, learning_rate, pretrained_module_lr_factor)
        # Create rollout storage
        self.storage: RolloutStorage = None  # type: ignore
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss
        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate
        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

    def setup_optimizers(self, optimizer_str: str, learning_rate: float, pretrained_module_lr_factor: float|None):
        """
        P4RL: Setup the optimizer for the PPO algorithm. Note that we use SEPARATE optimizers for actor and critic 
        to support updating only actor or critic.
        This is necessary because if the shared optimizer is used to optimize a loss only dependant on the actor/critic, 
        the zero gradients of the other part will cause the momentum to be problematic.
        """
        optimizer_class: Type[optim.Adam] = resolve_optimizer(optimizer_str)
        if pretrained_module_lr_factor is None:
            actor_params = [
                {"params": self.policy.actor.parameters()}, 
                {"params": self.policy.exploration_params}]
            critic_params = self.policy.critic.parameters()
        else:
            # actor params
            actor_pretrained_params = list(self.policy.get_actor_pretrained_params())
            actor_pretrained_ids = set(map(id, actor_pretrained_params))
            actor_params = [
                {"params": actor_pretrained_params, "lr": learning_rate * pretrained_module_lr_factor},
                {"params": [p for p in self.policy.actor.parameters() if id(p) not in actor_pretrained_ids]},
                {"params": self.policy.exploration_params},
            ]

            # critic params
            critic_pretrained_params = list(self.policy.get_critic_pretrained_params())
            critic_pretrained_ids = set(map(id, critic_pretrained_params))
            critic_params = [
                {"params": critic_pretrained_params, "lr": learning_rate * pretrained_module_lr_factor},
                {"params": [p for p in self.policy.critic.parameters() if id(p) not in critic_pretrained_ids]},
            ]

        self.actor_optimizer = optimizer_class(actor_params, lr=learning_rate)
        self.critic_optimizer = optimizer_class(critic_params, lr=learning_rate)


    def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
        # create rollout storage
        self.storage = RolloutStorage(
            training_type,
            num_envs,
            num_transitions_per_env,
            obs,
            actions_shape,
            self.device,
        )

    def act(self, obs):
        if self.policy.is_recurrent:
            self.transition.hidden_states = self.policy.get_hidden_states()
        # compute the actions and values
        self.transition.actions = self.policy.act(obs).detach()
        self.transition.values = self.policy.evaluate(obs).detach()
        self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.policy.action_mean.detach()
        self.transition.action_sigma = self.policy.action_std.detach()
        # need to record obs before env.step()
        self.transition.observations = obs
        return self.transition.actions

    def process_env_step(self, obs, rewards, dones, extras, buffers: None | Dict[str, CircularBuffer]=None):
        # update the normalizers
        self.policy.update_normalization(obs)
        if self.rnd:
            self.rnd.update_normalization(obs)

        # Record the rewards and dones
        # Note: we clone here because later on we bootstrap the rewards based on timeouts
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        self.intrinsic_rewards = torch.zeros_like(self.transition.rewards)

        # Compute the intrinsic rewards and add to extrinsic rewards
        if self.rnd:
            # Compute the intrinsic rewards
            self.intrinsic_rewards += self.rnd.get_intrinsic_reward(obs)
            # Add intrinsic rewards to extrinsic rewards
            self.transition.rewards += self.intrinsic_rewards

        # P4RL: compute inference reward of inverse dynamics model
        if self.inv_reward_scale:
            assert buffers is not None, "Buffers must be provided for inverse dynamics model, which means the save_trajectories_prob need to be greater than 0.0"
            intrinsic_rewards = self.get_intrinsic_reward_from_inv_dynamics_model(buffers, dones) # TODO
            self.intrinsic_rewards += intrinsic_rewards
            # Add intrinsic rewards to extrinsic rewards
            self.transition.rewards += self.intrinsic_rewards

        # P4RL: compute inference reward of inverse dynamics model
        if self.inv_reward_scale:
            assert buffers is not None, "Buffers must be provided for inverse dynamics model, which means the save_trajectories_prob need to be greater than 0.0"
            intrinsic_rewards = self.get_intrinsic_reward_from_inv_dynamics_model(buffers, dones) # TODO
            self.intrinsic_rewards += intrinsic_rewards
            # Add intrinsic rewards to extrinsic rewards
            self.transition.rewards += self.intrinsic_rewards

        # Bootstrapping on time outs
        if "time_outs" in extras:
            self.transition.rewards += self.gamma * torch.squeeze(
                self.transition.values * extras["time_outs"].unsqueeze(1).to(self.device), 1
            )

        # record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.policy.reset(dones)


    def get_intrinsic_reward_from_inv_dynamics_model(self, buffers: Dict[str, CircularBuffer], dones: torch.Tensor):
        """Compute the intrinsic reward from the inverse dynamics model."""
        if self.inv_ensemble is None:
            raise RuntimeError("Inverse dynamics model is not initialized.")
        # Get the intrinsic reward from the inverse dynamics model
        x = buffers["inv_input_buffer"].buffer[:, -self.inv_input_timesteps:]
        a = buffers["action_hist_buffer"].buffer[:, -self.inv_input_timesteps:] # NOTE: this action index should be -1, because in RL loop after env.step(), the variables are a_t and s_tp1
        
        # be careful with the following intrinsic reward calculation.
        intrinsic_rewards = torch.clip(self.inv_reward_scale*self.inv_ensemble.get_intrinsic_reward(x, a, dones), max=self.inv_reward_max)
        # Add intrinsic rewards to extrinsic rewards
        self.transition.rewards += intrinsic_rewards
        return intrinsic_rewards

    def compute_returns(self, obs):
        # compute value for the last step
        last_values = self.policy.evaluate(obs).detach()
        self.storage.compute_returns(
            last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
        )

    def update(self, update_actor=True, update_critic=True):  # noqa: C901
        """Update the policy and value function using the collected transitions."""

        mean_value_loss = 0
        mean_surrogate_loss = 0
        mean_entropy = 0
        mean_grad_norm_before_clipping = 0

        # -- RND loss
        if self.rnd:
            mean_rnd_loss = 0
        else:
            mean_rnd_loss = None
        # -- Symmetry loss
        if self.symmetry:
            mean_symmetry_loss = 0
        else:
            mean_symmetry_loss = None

        # generator for mini batches
        if self.policy.is_recurrent:
            generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)

        num_updates = 0

        # iterate over batches
        for (
            obs_batch,
            actions_batch,
            target_values_batch,
            advantages_batch,
            returns_batch,
            old_actions_log_prob_batch,
            old_mu_batch,
            old_sigma_batch,
            hid_states_batch,
            masks_batch,
        ) in generator:

            # number of augmentations per sample
            # we start with 1 and increase it if we use symmetry augmentation
            num_aug = 1
            # original batch size
            # we assume policy group is always there and needs augmentation
            original_batch_size = obs_batch["policy"].shape[0]

            # check if we should normalize advantages per mini batch
            if self.normalize_advantage_per_mini_batch:
                with torch.no_grad():
                    advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8)

            # Perform symmetric augmentation
            if self.symmetry and self.symmetry["use_data_augmentation"]:
                # augmentation using symmetry
                data_augmentation_func = self.symmetry["data_augmentation_func"]
                # returned shape: [batch_size * num_aug, ...]
                obs_batch, actions_batch = data_augmentation_func(  # TODO: needs changes on the isaac lab side
                    obs=obs_batch,
                    actions=actions_batch,
                    env=self.symmetry["_env"],
                )
                # compute number of augmentations per sample
                # we assume policy group is always there and needs augmentation
                num_aug = int(obs_batch["policy"].shape[0] / original_batch_size)
                # repeat the rest of the batch
                # -- actor
                old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)
                # -- critic
                target_values_batch = target_values_batch.repeat(num_aug, 1)
                advantages_batch = advantages_batch.repeat(num_aug, 1)
                returns_batch = returns_batch.repeat(num_aug, 1)

            # Recompute actions log prob and entropy for current batch of transitions
            # Note: we need to do this because we updated the policy with the new parameters
            # -- actor
            self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
            actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
            # -- critic
            value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
            # -- entropy
            # we only keep the entropy of the first augmentation (the original one)
            mu_batch = self.policy.action_mean[:original_batch_size]
            sigma_batch = self.policy.action_std[:original_batch_size]
            entropy_batch = self.policy.entropy[:original_batch_size]

            # if isinstance(self.actor_critic, GatedActorCriticWithINV):
            #     self.last_gating_values = self.actor_critic.gating_value[:original_batch_size]

            # KL
            if self.desired_kl is not None and self.schedule == "adaptive" and update_actor:
                with torch.inference_mode():
                    kl = torch.sum(
                        torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
                        + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
                        / (2.0 * torch.square(sigma_batch))
                        - 0.5,
                        axis=-1,
                    ) # type: ignore
                    kl_mean = torch.mean(kl)

                    # Reduce the KL divergence across all GPUs
                    if self.is_multi_gpu:
                        torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
                        kl_mean /= self.gpu_world_size

                    # Update the learning rate
                    # Perform this adaptation only on the main process
                    # TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
                    #       then the learning rate should be the same across all GPUs.
                    if self.gpu_global_rank == 0:
                        if kl_mean > self.desired_kl * 2.0:
                            old_learning_rate = self.learning_rate
                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                            scaling_factor = self.learning_rate / old_learning_rate
                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                            old_learning_rate = self.learning_rate
                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
                            scaling_factor = self.learning_rate / old_learning_rate
                        else:
                            scaling_factor = 1.0

                    # Update the learning rate for all GPUs
                    if self.is_multi_gpu:
                        lr_tensor = torch.tensor(self.learning_rate, device=self.device)
                        torch.distributed.broadcast(lr_tensor, src=0)
                        self.learning_rate = lr_tensor.item()

                    # Update the learning rate for all parameter groups
                    for param_group in self.actor_optimizer.param_groups:
                        param_group["lr"] *= scaling_factor
                    for param_group in self.critic_optimizer.param_groups: # also update critic optimizer learning rate might help stabilize the training
                        param_group["lr"] *= scaling_factor

            # Surrogate loss
            ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
            )
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

            # Value function loss
            if self.use_clipped_value_loss:
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
                    -self.clip_param, self.clip_param
                )
                value_losses = (value_batch - returns_batch).pow(2)
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = (returns_batch - value_batch).pow(2).mean()

            actor_loss = surrogate_loss - self.entropy_coef * entropy_batch.mean()
            critic_loss = self.value_loss_coef * value_loss

            loss = torch.tensor(0.0, device=self.device)
            if update_actor:
                loss += actor_loss
            if update_critic:
                loss += critic_loss

            # Symmetry loss
            if self.symmetry:
                # obtain the symmetric actions
                # if we did augmentation before then we don't need to augment again
                if not self.symmetry["use_data_augmentation"]:
                    data_augmentation_func = self.symmetry["data_augmentation_func"]
                    obs_batch, _ = data_augmentation_func(obs=obs_batch, actions=None, env=self.symmetry["_env"])
                    # compute number of augmentations per sample
                    num_aug = int(obs_batch.shape[0] / original_batch_size)

                # actions predicted by the actor for symmetrically-augmented observations
                mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())

                # compute the symmetrically augmented actions
                # note: we are assuming the first augmentation is the original one.
                #   We do not use the action_batch from earlier since that action was sampled from the distribution.
                #   However, the symmetry loss is computed using the mean of the distribution.
                action_mean_orig = mean_actions_batch[:original_batch_size]
                _, actions_mean_symm_batch = data_augmentation_func(
                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"]
                )

                # compute the loss (we skip the first augmentation as it is the original one)
                mse_loss = torch.nn.MSELoss()
                symmetry_loss = mse_loss(
                    mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:]
                )


                # --------------------P4RL EXPERIMENTAL---------------------
                # if symmetry_loss>0.1 and num_updates>0:
                #     break # stop using this whole batch of data for further update because it is already unstable
                # --------------------P4RL EXPERIMENTAL---------------------


                # add the loss to the total loss
                if self.symmetry["use_mirror_loss"]:
                    loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss
                else:
                    symmetry_loss = symmetry_loss.detach()

            # Random Network Distillation loss
            # TODO: Move this processing to inside RND module.
            if self.rnd:
                # extract the rnd_state
                # TODO: Check if we still need torch no grad. It is just an affine transformation.
                with torch.no_grad():
                    rnd_state_batch = self.rnd.get_rnd_state(obs_batch[:original_batch_size])
                    rnd_state_batch = self.rnd.state_normalizer(rnd_state_batch)
                # predict the embedding and the target
                predicted_embedding = self.rnd.predictor(rnd_state_batch)
                target_embedding = self.rnd.target(rnd_state_batch).detach()
                # compute the loss as the mean squared error
                mseloss = torch.nn.MSELoss()
                rnd_loss = mseloss(predicted_embedding, target_embedding)

            # -- For RND
            if self.rnd:
                self.rnd_optimizer.zero_grad()  # type: ignore
                rnd_loss.backward()
            if self.rnd_optimizer:
                self.rnd_optimizer.step()

            # Compute the gradients
            # -- For PPO

            if update_actor:
                self.actor_optimizer.zero_grad()
            if update_critic:
                self.critic_optimizer.zero_grad()

            if update_actor or update_critic:
                loss.backward()


            # Collect gradients from all GPUs
            if self.is_multi_gpu:
                self.reduce_parameters()

            # Apply the gradients
            # -- For PPO
            # P4RL: clip the gradients separately for actor and critic to decouple their updates, theoretically this should be better
            # The original implementation clips the gradients of the whole model, but it might also bring the benefit that
            # when the value loss is large, the actor gradients are clipped to very small values, so the actor does not 
            # change much until the value loss is at a reasonable level.
            # nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
            # nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
            # (unimplemented) clip that of noise std as well 

            grad_norm_before_clipping = get_parameter_grad_norm(self.policy.parameters())

            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)

            # grad_norm_1 = get_parameter_grad_norm(self.policy.parameters())

            if update_actor:
                self.actor_optimizer.step()
            if update_critic:
                self.critic_optimizer.step()

            # Store the losses
            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()
            mean_entropy += entropy_batch.mean().item()
            mean_grad_norm_before_clipping += grad_norm_before_clipping.item()
            # -- RND loss
            if mean_rnd_loss is not None:
                mean_rnd_loss += rnd_loss.item()
            # -- Symmetry loss
            if mean_symmetry_loss is not None:
                mean_symmetry_loss += symmetry_loss.item()

            num_updates += 1

        # -- For PPO
        # num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        mean_entropy /= num_updates
        mean_grad_norm_before_clipping /= num_updates
        # -- For RND
        if mean_rnd_loss is not None:
            mean_rnd_loss /= num_updates
        # -- For Symmetry
        if mean_symmetry_loss is not None:
            mean_symmetry_loss /= num_updates
        # -- Clear the storage
        self.storage.clear()

        # construct the loss dictionary
        loss_dict = {
            "value_function": mean_value_loss,
            "surrogate": mean_surrogate_loss,
            "entropy": mean_entropy,
            "grad_norm_before_clipping": mean_grad_norm_before_clipping,
            "num_updates": num_updates,
        }
        if self.rnd:
            loss_dict["rnd"] = mean_rnd_loss
        if self.symmetry:
            loss_dict["symmetry"] = mean_symmetry_loss

        return loss_dict

    """
    Helper functions
    """

    def broadcast_parameters(self):
        """Broadcast model parameters to all GPUs."""
        # obtain the model parameters on current GPU
        model_params = [self.policy.state_dict()]
        if self.rnd:
            model_params.append(self.rnd.predictor.state_dict())
        # broadcast the model parameters
        torch.distributed.broadcast_object_list(model_params, src=0)
        # load the model parameters on all GPUs from source GPU
        self.policy.load_state_dict(model_params[0])
        if self.rnd:
            self.rnd.predictor.load_state_dict(model_params[1])

    def reduce_parameters(self):
        """Collect gradients from all GPUs and average them.

        This function is called after the backward pass to synchronize the gradients across all GPUs.
        """
        # Create a tensor to store the gradients
        grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
        if self.rnd:
            grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
        all_grads = torch.cat(grads)

        # Average the gradients across all GPUs
        torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
        all_grads /= self.gpu_world_size

        # Get all parameters
        all_params = self.policy.parameters()
        if self.rnd:
            all_params = chain(all_params, self.rnd.parameters())

        # Update the gradients for all parameters with the reduced gradients
        offset = 0
        for param in all_params:
            if param.grad is not None:
                numel = param.numel()
                # copy data back from shared buffer
                param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
                # update the offset for the next parameter
                offset += numel
