# ===== Diverse Init ver. =====
# from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

# import numpy as np
# import torch as th
# from gym import spaces
# from torch.nn import functional as F

# from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
# from stable_baselines3.common.noise import ActionNoise
# from stable_baselines3 import DDPG
# from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
# from stable_baselines3.common.utils import (
#     get_linear_fn,
#     get_parameters_by_name,
#     polyak_update,
# )
# from stable_baselines3.common.policies import BasePolicy
# from rl_zoo3.gamid.policies import GamidPolicy, MlpPolicy
# from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

# from sklearn.mixture import GaussianMixture
# from scipy.stats import multivariate_normal
# from scipy import linalg

# from torch.distributions import Categorical

# SelfGamid = TypeVar("SelfGamid", bound="Gamid")


# class Gamid(DDPG):

#     policy_aliases: Dict[str, Type[BasePolicy]] = {
#         "MlpPolicy": MlpPolicy,
#     }

#     def __init__(
#         self,
#         policy: Union[str, Type[GamidPolicy]],
#         env: Union[GymEnv, str],
#         learning_rate: Union[float, Schedule] = 0.001,
#         buffer_size: int = 1000000,
#         learning_starts: int = 100,
#         batch_size: int = 100,
#         tau: float = 0.005,
#         gamma: float = 0.99,
#         train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
#         gradient_steps: int = -1,
#         action_noise: Optional[ActionNoise] = None,
#         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
#         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
#         optimize_memory_usage: bool = False,
#         policy_delay: int = 1,
#         target_policy_noise: float = 0.1,
#         target_noise_clip: float = 0.0,
#         n_actors: int = 1,
#         n_critics: int = 1,
#         temperature_initial: float = 0.9,
#         temperature_final: float = 0.5,
#         temperature_fraction: float = 0.5,
#         exploration_fraction: float = 0.1,
#         exploration_initial_eps: float = 0.5,
#         exploration_final_eps: float = 0.1,
#         actors_loss_fn: Optional[str] = None,
#         tensorboard_log: Optional[str] = None,
#         policy_kwargs: Optional[Dict[str, Any]] = None,
#         verbose: int = 0,
#         seed: Optional[int] = None,
#         device: Union[th.device, str] = "auto",
#         _init_setup_model: bool = False,
#     ):

#         # Number of actors and the actor loss function setting
#         self.n_actors = n_actors
#         self.actors_loss_fn = actors_loss_fn

#         self._n_calls = 0

#         self.temperature_initial = temperature_initial
#         self.temperature_final = temperature_final
#         self.temperature_fraction = temperature_fraction

#         self.temperature = 0.0
#         self.temperature_schedule = None

#         # Epsilon greedy selection of actors
#         self.exploration_initial_eps = exploration_initial_eps
#         self.exploration_final_eps = exploration_final_eps
#         self.exploration_fraction = exploration_fraction

#         self.exploration_rate = 0.0
#         self.exploration_schedule = None

#         super().__init__(
#             policy=policy,
#             env=env,
#             learning_rate=learning_rate,
#             buffer_size=buffer_size,
#             learning_starts=learning_starts,
#             batch_size=batch_size,
#             tau=tau,
#             gamma=gamma,
#             train_freq=train_freq,
#             gradient_steps=gradient_steps,
#             action_noise=action_noise,
#             replay_buffer_class=replay_buffer_class,
#             replay_buffer_kwargs=replay_buffer_kwargs,
#             optimize_memory_usage=optimize_memory_usage,
#             tensorboard_log=tensorboard_log,
#             policy_kwargs=policy_kwargs,
#             verbose=verbose,
#             seed=seed,
#             device=device,
#             _init_setup_model=False,
#         )

#         self.policy_delay = policy_delay
#         self.target_noise_clip = target_noise_clip
#         self.target_policy_noise = target_policy_noise

#         self.policy_kwargs["n_actors"] = n_actors
#         self.policy_kwargs["n_critics"] = n_critics

#         self.actor, self.actor_target = None, None
#         self.critic, self.critic_target = None, None

#         self._setup_model()

#     def _setup_model(self) -> None:
#         self._setup_lr_schedule()
#         self.set_random_seed(self.seed)

#         # Use DictReplayBuffer if needed
#         if self.replay_buffer_class is None:
#             if isinstance(self.observation_space, spaces.Dict):
#                 self.replay_buffer_class = DictReplayBuffer
#             else:
#                 self.replay_buffer_class = ReplayBuffer

#         elif self.replay_buffer_class == HerReplayBuffer:
#             assert (
#                 self.env is not None
#             ), "You must pass an environment when using `HerReplayBuffer`"

#             # If using offline sampling, we need a classic replay buffer too
#             if self.replay_buffer_kwargs.get("online_sampling", True):
#                 replay_buffer = None
#             else:
#                 replay_buffer = DictReplayBuffer(
#                     self.buffer_size,
#                     self.observation_space,
#                     self.action_space,
#                     device=self.device,
#                     optimize_memory_usage=self.optimize_memory_usage,
#                 )

#             self.replay_buffer = HerReplayBuffer(
#                 self.env,
#                 self.buffer_size,
#                 device=self.device,
#                 replay_buffer=replay_buffer,
#                 **self.replay_buffer_kwargs,
#             )

#         if self.replay_buffer is None:
#             self.replay_buffer = self.replay_buffer_class(
#                 self.buffer_size,
#                 self.observation_space,
#                 self.action_space,
#                 device=self.device,
#                 n_envs=self.n_envs,
#                 optimize_memory_usage=self.optimize_memory_usage,
#                 **self.replay_buffer_kwargs,
#             )

#         self.policy = self.policy_class(  # pytype:disable=not-instantiable
#             self.observation_space,
#             self.action_space,
#             self.lr_schedule,
#             **self.policy_kwargs,  # pytype:disable=not-instantiable
#         )
#         self.policy = self.policy.to(self.device)

#         # Convert train freq parameter to TrainFreq object
#         self._convert_train_freq()

#         self._create_aliases()
#         # Running mean and running var
#         self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
#         self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
#         self.actor_batch_norm_stats_target = get_parameters_by_name(
#             self.actor_target, ["running_"]
#         )
#         self.critic_batch_norm_stats_target = get_parameters_by_name(
#             self.critic_target, ["running_"]
#         )

#         self.temperature_schedule = get_linear_fn(
#             self.temperature_initial, self.temperature_final, self.temperature_fraction
#         )

#         self.exploration_schedule = get_linear_fn(
#             self.exploration_initial_eps,
#             self.exploration_final_eps,
#             self.exploration_fraction,
#         )

#         self.greedy_actor_count = th.zeros(self.n_actors).to(self.device)

#     def _create_aliases(self) -> None:
#         self.actor = self.policy.actor
#         self.actor_target = self.policy.actor_target
#         self.critic = self.policy.critic
#         self.critic_target = self.policy.critic_target

#     def _on_step(self) -> None:
#         """
#         Update the temperature and exploration rate, and log the greedy-actor spread.
#         This method is called in ``collect_rollouts()`` after each step in the environment.
#         """
#         self._n_calls += 1

#         self.temperature = self.temperature_schedule(self._current_progress_remaining)
#         self.exploration_rate = self.exploration_schedule(
#             self._current_progress_remaining
#         )

#         actions = self.policy._predict(
#             th.tensor(self._last_obs).to(self.device), deterministic=True
#         )
#         q_values = []
#         for action in actions:
#             q_values.append(
#                 self.policy.critic_target(
#                     th.tensor(self._last_obs).to(self.device),
#                     action,
#                 )
#             )
#         actions = th.stack(list(actions), dim=0)
#         # Convert list to tensor
#         q_values = th.FloatTensor(q_values).squeeze(dim=1)
#         # Min Q from among multiple critics
#         q_values, _ = th.min(q_values, dim=-1)

#         _, actor_idx = th.max(q_values, dim=-1)
#         actor_idx = actor_idx.to(self.device)

#         self.greedy_actor_count[actor_idx] += 1
#         actor_spread = Categorical(
#             probs=self.greedy_actor_count / self.greedy_actor_count.sum()
#         ).entropy()

#         self.logger.record("train/temperature", self.temperature)
#         self.logger.record("train/epsilon", self.exploration_rate)
#         self.logger.record("train/actor_spread", actor_spread.item())

#     def predict(
#         self,
#         observation: Union[np.ndarray, Dict[str, np.ndarray]],
#         state: Optional[Tuple[np.ndarray, ...]] = None,
#         episode_start: Optional[np.ndarray] = None,
#         deterministic: bool = False,
#         exploration_rate: float = 0,
#     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
#         """
#         Get the policy action from an observation (and optional hidden state).
#         Includes sugar-coating to handle different observations (e.g. normalizing images).

#         :param observation: the input observation
#         :param state: The last hidden states (can be None, used in recurrent policies)
#         :param episode_start: The last masks (can be None, used in recurrent policies)
#             this corresponds to the beginning of episodes,
#             where the hidden states of the RNN must be reset.
#         :param deterministic: Whether or not to return deterministic actions.
#         :param exploration_rate: Exploration rate forwarded to the policy's ``predict``.
#         :return: the model's action and the next hidden state
#             (used in recurrent policies)
#         """
#         return self.policy.predict(
#             observation, state, episode_start, deterministic, exploration_rate
#         )

#     def _sample_action(
#         self,
#         learning_starts: int,
#         action_noise=None,
#         n_envs: int = 1,
#     ):
#         """
#         Sample an action according to the exploration policy.
#         This is either done by sampling the probability distribution of the policy,
#         or sampling a random action (from a uniform distribution over the action space)
#         or by adding noise to the deterministic output.

#         :param action_noise: Action noise that will be used for exploration
#             Required for deterministic policy (e.g. TD3). This can also be used
#             in addition to the stochastic policy for SAC.
#         :param learning_starts: Number of steps before learning for the warm-up phase.
#         :param n_envs:
#         :return: action to take in the environment
#             and scaled action that will be stored in the replay buffer.
#             The two differ when the action space is not normalized (bounds are not [-1, 1]).
#         """
#         # Select action randomly or according to policy
#         if self.num_timesteps < learning_starts and not (
#             self.use_sde and self.use_sde_at_warmup
#         ):
#             # Warmup phase
#             unscaled_action = np.array(
#                 [self.action_space.sample() for _ in range(n_envs)]
#             )
#         else:
#             # Note: when using continuous actions,
#             # we assume that the policy uses tanh to scale the action
#             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
#             unscaled_action, _ = self.predict(
#                 self._last_obs,
#                 deterministic=False,
#                 exploration_rate=self.exploration_rate,
#             )

#         # Rescale the action from [low, high] to [-1, 1]
#         if isinstance(self.action_space, spaces.Box):
#             scaled_action = self.policy.scale_action(unscaled_action)

#             # Add noise to the action (improve exploration)
#             if action_noise is not None:
#                 scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

#             # We store the scaled action in the buffer
#             buffer_action = scaled_action
#             # scaled_action = np.power(scaled_action, 2) - np.power(scaled_action, 3) + 1
#             action = self.policy.unscale_action(scaled_action)
#         else:
#             # Discrete case, no need to normalize or clip
#             buffer_action = unscaled_action
#             action = buffer_action

#         return action, buffer_action

#     def log_loss(self, action_1, action_2):
#         """
#         Loss: negative log of the L2 distance between actions (offset by 0.01).
#         """
#         # loss = th.exp(-th.norm(action_1 - action_2, dim=0) ** 2 / 0.01).mean()
#         loss = -th.log(th.norm(action_1 - action_2, dim=1) + 0.01).mean()
#         return loss

#     def mse_loss(self, action_1, action_2):
#         """
#         Loss: negative mean L2 distance between actions (not a true MSE).
#         """
#         loss = -th.norm(action_1 - action_2, p=2, dim=1).mean()
#         return loss

#     def train(self, gradient_steps: int, batch_size: int = 100) -> None:
#         # Switch to train mode (this affects batch norm / dropout)
#         self.policy.set_training_mode(True)

#         # Update learning rate according to lr schedule
#         self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

#         actor_losses, diversity_losses, critic_losses, distances_means = [], [], [], []
#         for _ in range(gradient_steps):

#             self._n_updates += 1
#             # Sample replay buffer
#             replay_data = self.replay_buffer.sample(
#                 batch_size, env=self._vec_normalize_env
#             )

#             with th.no_grad():
#                 # Select action according to policy
#                 next_actions_all = th.stack(
#                     self.actor_target(replay_data.next_observations), dim=0
#                 )

#                 # print(f"{next_actions_all}")
#                 # input()

#                 next_q_values_all = th.stack(
#                     [
#                         th.cat(
#                             self.critic_target(
#                                 replay_data.next_observations, next_actions
#                             ),
#                             dim=1,
#                         )
#                         for next_actions in next_actions_all
#                     ],
#                     dim=0,
#                 ).to(self.device)

#                 # print(f"{next_q_values_all}")
#                 # input()

#                 next_q_values_all, _ = th.min(next_q_values_all, dim=-1)
#                 # print(f"{next_q_values_all}")
#                 # input()

#                 next_actors = th.argmax(next_q_values_all, dim=0).unsqueeze(dim=1)
#                 # print(f"{next_actors}")
#                 # input()

#                 next_actors = next_actors.expand(
#                     -1, self.action_space.shape[0]
#                 ).unsqueeze(dim=1)
#                 # next_actors = th.cat((next_actors, next_actors), dim=1).unsqueeze(dim=1)
#                 # print(f"{next_actors}")
#                 # input()

#                 next_actions_all = th.stack(
#                     self.actor_target(replay_data.next_observations), dim=1
#                 )
#                 # print(f"{next_actions_all}")
#                 # input()

#                 next_actions = th.gather(
#                     next_actions_all, dim=1, index=next_actors.long()
#                 ).squeeze(1)
#                 # print(f"{next_actions}")
#                 # input()

#                 noise = replay_data.actions.clone().data.normal_(
#                     0, self.target_policy_noise
#                 )
#                 noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
#                 next_actions = (next_actions + noise).clamp(-1, 1)

#                 # Compute the next Q-values: min over all critics targets
#                 next_q_values = th.cat(
#                     self.critic_target(replay_data.next_observations, next_actions),
#                     dim=1,
#                 )
#                 # print(f"{next_q_values}")
#                 # input()

#                 next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
#                 # print(f"{next_q_values}")
#                 # input()

#                 target_q_values = (
#                     replay_data.rewards
#                     + (1 - replay_data.dones) * self.gamma * next_q_values
#                 )

#                 # For the actor losses
#                 # mu_all_target = self.actor_target(replay_data.observations)
#                 mu_all_target = self.actor(replay_data.observations)

#             # Get current Q-values estimates for each critic network
#             current_q_values = self.critic(
#                 replay_data.observations, replay_data.actions
#             )

#             # Compute critic loss
#             critic_loss = sum(
#                 F.mse_loss(current_q, target_q_values) for current_q in current_q_values
#             )
#             critic_losses.append(critic_loss.item())

#             # Optimize the critics
#             self.critic.optimizer.zero_grad()
#             critic_loss.backward()
#             self.critic.optimizer.step()

#             # Delayed policy updates
#             if self._n_updates % self.policy_delay == 0:
#                 # mu_all = self.actor(replay_data.observations)

#                 # dpg_loss, diversity_loss, distances_mean = 0, 0, 0
#                 # for targ_idx in range(self.n_actors):
#                 #     # Compute actor loss
#                 #     dpg_loss += -self.critic.q1_forward(
#                 #         replay_data.observations, mu_all[targ_idx]
#                 #     ).mean()
#                 #     for idx in range(self.n_actors):
#                 #         if targ_idx == idx:
#                 #             continue
#                 #         # # Compute diversity loss
#                 #         # diversity_loss += (1.0 / (self.n_actors - 1)) * self.mse_loss(
#                 #         #     mu_all_target[targ_idx], mu_all[idx]
#                 #         # )
#                 #         # diversity_loss += (1.0 / (self.n_actors - 1)) * th.exp(
#                 #         #     th.norm
#                 #         # (mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #         # ).mean()
#                 #         distances_mean += th.norm(
#                 #             mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1
#                 #         ).mean()
#                 #         # diversity_loss += th.exp(
#                 #         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #         # ).mean()
#                 #         # diversity_loss += -th.log(
#                 #         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #         #     ** 2
#                 #         # ).mean()
#                 #         diversity_loss += (
#                 #             th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #             ** 2
#                 #         ).mean()

#                 mu_all = self.actor(replay_data.observations)
#                 mu_all = th.stack(list(mu_all), dim=0).to(self.device)

#                 mu_dpg = self.actor(replay_data.observations)

#                 with th.no_grad():
#                     mu_all_target = th.stack(list(mu_all_target), dim=0).to(self.device)

#                 # mu_all_target[greedy_actor, th.arange(self.batch_size)] = 2 * th.ones(
#                 #     self.action_space.shape[0]
#                 # ).to(self.device)
#                 # mu_all[greedy_actor, th.arange(self.batch_size)] = 2 * th.ones(
#                 #     self.action_space.shape[0]
#                 # ).to(self.device)

#                 dpg_loss, diversity_loss, distances_mean = 0, 0, 0
#                 for targ_idx in range(self.n_actors):
#                     # Compute actor loss
#                     dpg_loss += -self.critic.q1_forward(
#                         replay_data.observations, mu_dpg[targ_idx]
#                     ).mean()
#                     for idx in range(self.n_actors):
#                         if targ_idx == idx:
#                             continue
#                         # If idx or targ_idx == greedy_actor
#                         # # Compute diversity loss
#                         # diversity_loss += (1.0 / (self.n_actors - 1)) * self.mse_loss(
#                         #     mu_all_target[targ_idx], mu_all[idx]
#                         # )
#                         # diversity_loss += (1.0 / (self.n_actors - 1)) * th.exp(
#                         #     th.norm
#                         # (mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         # ).mean()
#                         distances_mean += th.norm(
#                             mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1
#                         ).mean()

#                         # diversity_loss += th.exp(
#                         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         # ).mean()
#                         # diversity_loss += -th.log(
#                         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         #     ** 2
#                         # ).mean()
#                         # diversity_loss += (
#                         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         #     ** 2
#                         # ).mean()

#                         diversity_loss += (
#                             100
#                             * th.exp(
#                                 -th.norm(
#                                     (mu_all_target[targ_idx] - mu_all[idx]),
#                                     p=2,
#                                     dim=1,
#                                 )
#                             ).mean()
#                         )
#                         # diversity_loss += th.norm(
#                         #     (mu_all_target[targ_idx] - mu_all[idx]),
#                         #     p=2,
#                         #     dim=1,
#                         # ).mean()

#                 distances_mean /= self.n_actors

#                 if self.num_timesteps < self.learning_starts:
#                     actor_loss = diversity_loss
#                 else:
#                     actor_loss = dpg_loss

#                 actor_losses.append(actor_loss.item())
#                 diversity_losses.append(diversity_loss.item())
#                 distances_means.append(distances_mean.item())

#                 # Optimize the actor
#                 self.actor.optimizer.zero_grad()
#                 actor_loss.backward()
#                 self.actor.optimizer.step()

#                 polyak_update(
#                     self.critic.parameters(), self.critic_target.parameters(), self.tau
#                 )
#                 polyak_update(
#                     self.actor.parameters(), self.actor_target.parameters(), self.tau
#                 )
#                 # Copy running stats, see GH issue #996
#                 polyak_update(
#                     self.critic_batch_norm_stats,
#                     self.critic_batch_norm_stats_target,
#                     1.0,
#                 )
#                 polyak_update(
#                     self.actor_batch_norm_stats,
#                     self.actor_batch_norm_stats_target,
#                     1.0,
#                 )

#         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
#         if len(actor_losses) > 0:
#             self.logger.record("train/actor_loss", np.mean(actor_losses))
#             self.logger.record("train/diversity_loss", np.mean(diversity_losses))
#             self.logger.record("train/distances_mean", np.mean(distances_means))
#         self.logger.record("train/critic_loss", np.mean(critic_losses))

#     def learn(
#         self: SelfGamid,
#         total_timesteps: int,
#         callback: MaybeCallback = None,
#         log_interval: int = 4,
#         tb_log_name: str = "Gamid",
#         reset_num_timesteps: bool = True,
#         progress_bar: bool = False,
#     ) -> SelfGamid:
#         return super().learn(
#             total_timesteps=total_timesteps,
#             callback=callback,
#             log_interval=log_interval,
#             tb_log_name=tb_log_name,
#             reset_num_timesteps=reset_num_timesteps,
#             progress_bar=progress_bar,
#         )

# ===== CAGrad ver. =====

# from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

# import numpy as np
# import torch as th
# from gym import spaces
# from torch.nn import functional as F

# from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
# from stable_baselines3.common.noise import ActionNoise
# from stable_baselines3 import DDPG
# from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
# from stable_baselines3.common.utils import (
#     get_linear_fn,
#     get_parameters_by_name,
#     polyak_update,
# )
# from stable_baselines3.common.policies import BasePolicy
# from rl_zoo3.gamid.policies import GamidPolicy, MlpPolicy
# from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

# from sklearn.mixture import GaussianMixture
# from scipy.stats import multivariate_normal
# from scipy import linalg

# from torch.distributions import Categorical

# SelfGamid = TypeVar("SelfGamid", bound="Gamid")


# class Gamid(DDPG):

#     policy_aliases: Dict[str, Type[BasePolicy]] = {
#         "MlpPolicy": MlpPolicy,
#     }

#     def __init__(
#         self,
#         policy: Union[str, Type[GamidPolicy]],
#         env: Union[GymEnv, str],
#         learning_rate: Union[float, Schedule] = 0.001,
#         buffer_size: int = 1000000,
#         learning_starts: int = 100,
#         batch_size: int = 100,
#         tau: float = 0.005,
#         gamma: float = 0.99,
#         train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
#         gradient_steps: int = -1,
#         action_noise: Optional[ActionNoise] = None,
#         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
#         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
#         optimize_memory_usage: bool = False,
#         policy_delay: int = 1,
#         target_policy_noise: float = 0.1,
#         target_noise_clip: float = 0.0,
#         n_actors: int = 1,
#         n_critics: int = 1,
#         temperature_initial: float = 0.9,
#         temperature_final: float = 0.5,
#         temperature_fraction: float = 0.5,
#         exploration_fraction: float = 0.1,
#         exploration_initial_eps: float = 0.5,
#         exploration_final_eps: float = 0.1,
#         actors_loss_fn: Optional[str] = None,
#         tensorboard_log: Optional[str] = None,
#         policy_kwargs: Optional[Dict[str, Any]] = None,
#         verbose: int = 0,
#         seed: Optional[int] = None,
#         device: Union[th.device, str] = "auto",
#         _init_setup_model: bool = False,
#     ):

#         # Number of actors and the actor loss function setting
#         self.n_actors = n_actors
#         self.actors_loss_fn = actors_loss_fn

#         self._n_calls = 0

#         self.temperature_initial = temperature_initial
#         self.temperature_final = temperature_final
#         self.temperature_fraction = temperature_fraction

#         self.temperature = 0.0
#         self.temperature_schedule = None

#         # Epsilon greedy selection of actors
#         self.exploration_initial_eps = exploration_initial_eps
#         self.exploration_final_eps = exploration_final_eps
#         self.exploration_fraction = exploration_fraction

#         self.exploration_rate = 0.0
#         self.exploration_schedule = None

#         self.cagrad_c = self.temperature_initial

#         super().__init__(
#             policy=policy,
#             env=env,
#             learning_rate=learning_rate,
#             buffer_size=buffer_size,
#             learning_starts=learning_starts,
#             batch_size=batch_size,
#             tau=tau,
#             gamma=gamma,
#             train_freq=train_freq,
#             gradient_steps=gradient_steps,
#             action_noise=action_noise,
#             replay_buffer_class=replay_buffer_class,
#             replay_buffer_kwargs=replay_buffer_kwargs,
#             optimize_memory_usage=optimize_memory_usage,
#             tensorboard_log=tensorboard_log,
#             policy_kwargs=policy_kwargs,
#             verbose=verbose,
#             seed=seed,
#             device=device,
#             _init_setup_model=False,
#         )

#         self.policy_delay = policy_delay
#         self.target_noise_clip = target_noise_clip
#         self.target_policy_noise = target_policy_noise

#         self.policy_kwargs["n_actors"] = n_actors
#         self.policy_kwargs["n_critics"] = n_critics

#         self.actor, self.actor_target = None, None
#         self.critic, self.critic_target = None, None

#         self._setup_model()

#     def _setup_model(self) -> None:
#         self._setup_lr_schedule()
#         self.set_random_seed(self.seed)

#         # Use DictReplayBuffer if needed
#         if self.replay_buffer_class is None:
#             if isinstance(self.observation_space, spaces.Dict):
#                 self.replay_buffer_class = DictReplayBuffer
#             else:
#                 self.replay_buffer_class = ReplayBuffer

#         elif self.replay_buffer_class == HerReplayBuffer:
#             assert (
#                 self.env is not None
#             ), "You must pass an environment when using `HerReplayBuffer`"

#             # If using offline sampling, we need a classic replay buffer too
#             if self.replay_buffer_kwargs.get("online_sampling", True):
#                 replay_buffer = None
#             else:
#                 replay_buffer = DictReplayBuffer(
#                     self.buffer_size,
#                     self.observation_space,
#                     self.action_space,
#                     device=self.device,
#                     optimize_memory_usage=self.optimize_memory_usage,
#                 )

#             self.replay_buffer = HerReplayBuffer(
#                 self.env,
#                 self.buffer_size,
#                 device=self.device,
#                 replay_buffer=replay_buffer,
#                 **self.replay_buffer_kwargs,
#             )

#         if self.replay_buffer is None:
#             self.replay_buffer = self.replay_buffer_class(
#                 self.buffer_size,
#                 self.observation_space,
#                 self.action_space,
#                 device=self.device,
#                 n_envs=self.n_envs,
#                 optimize_memory_usage=self.optimize_memory_usage,
#                 **self.replay_buffer_kwargs,
#             )

#         self.policy = self.policy_class(  # pytype:disable=not-instantiable
#             self.observation_space,
#             self.action_space,
#             self.lr_schedule,
#             **self.policy_kwargs,  # pytype:disable=not-instantiable
#         )
#         self.policy = self.policy.to(self.device)

#         # Convert train freq parameter to TrainFreq object
#         self._convert_train_freq()

#         self._create_aliases()
#         # Running mean and running var
#         self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
#         self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
#         self.actor_batch_norm_stats_target = get_parameters_by_name(
#             self.actor_target, ["running_"]
#         )
#         self.critic_batch_norm_stats_target = get_parameters_by_name(
#             self.critic_target, ["running_"]
#         )

#         self.temperature_schedule = get_linear_fn(
#             self.temperature_initial, self.temperature_final, self.temperature_fraction
#         )

#         self.exploration_schedule = get_linear_fn(
#             self.exploration_initial_eps,
#             self.exploration_final_eps,
#             self.exploration_fraction,
#         )

#         self.greedy_actor_count = th.zeros(self.n_actors).to(self.device)

#     def _create_aliases(self) -> None:
#         self.actor = self.policy.actor
#         self.actor_target = self.policy.actor_target
#         self.critic = self.policy.critic
#         self.critic_target = self.policy.critic_target

#     def _compute_gradient(
#         self, losses, retain_graph: bool = True, allow_unused: bool = False
#     ):
#         """Compute the gradient."""
#         grad = []
#         for loss in losses:
#             grad.append(
#                 tuple(
#                     _grad.contiguous()
#                     for _grad in th.autograd.grad(
#                         loss,
#                         self.actor.parameters(),
#                         retain_graph=retain_graph,
#                         allow_unused=allow_unused,
#                     )
#                 )
#             )
#         return grad

#     def _set_gradient(self, grads):
#         """Set the gradients of the policy."""
#         idx = 0
#         for param in self.actor.parameters():
#             if param.requires_grad:
#                 num_param_elements = th.numel(param.grad)
#                 modified_grad = grads[idx : idx + num_param_elements]
#                 modified_grad = modified_grad.view_as(param.grad)

#                 param.grad = modified_grad  # Set the modified gradient

#                 idx += num_param_elements

#     def cagrad(self, grad_vec):
#         """Conflict-Averse Gradient Descent (CAGrad)."""
#         grads = grad_vec
#         grad_0 = grad_vec[0] + grad_vec[1]

#         w = th.ones(2, 1, requires_grad=True)
#         w_opt = th.optim.SGD([w], lr=2, momentum=0.5)

#         c = (grad_0.norm() * self.cagrad_c).to(grad_vec.device)

#         w_best = None
#         obj_best = np.inf
#         for i in range(21):
#             w_opt.zero_grad()
#             ww = th.softmax(w, 0).to(grad_vec.device)

#             gw = ww.t().mm(grad_vec)
#             g0 = grad_0.view(1, -1).to(grad_vec.device)

#             obj = (gw.mm(g0.t()) + c * gw.norm()).sum()

#             if obj.item() < obj_best:
#                 obj_best = obj.item()
#                 w_best = w.clone()
#             if i < 20:
#                 obj.backward()
#                 w_opt.step()

#         # print(f"w_best: {w_best}")
#         ww = th.softmax(w_best, 0).to(grad_vec.device)
#         gw = (ww.t().mm(grad_vec)).to(grad_vec.device)

#         gw_norm = gw.norm()

#         lmbda = gw_norm / c
#         g = (grad_0 + gw / lmbda).view(-1, 1).to(grads.device)

#         return g, ww[1].item()

#     def _on_step(self) -> None:
#         """
#         Update the exploration rate and target network if needed.
#         This method is called in ``collect_rollouts()`` after each step in the environment.
#         """
#         self._n_calls += 1

#         self.temperature = self.temperature_schedule(self._current_progress_remaining)
#         self.exploration_rate = self.exploration_schedule(
#             self._current_progress_remaining
#         )

#         actions = self.policy._predict(
#             th.tensor(self._last_obs).to(self.device), deterministic=True
#         )
#         q_values = []
#         for action in actions:
#             q_values.append(
#                 self.policy.critic_target(
#                     th.tensor(self._last_obs).to(self.device),
#                     action,
#                 )
#             )
#         actions = th.stack(list(actions), dim=0)
#         # Convert list to tensor
#         q_values = th.FloatTensor(q_values).squeeze(dim=1)
#         # Min Q from among multiple critics
#         q_values, _ = th.min(q_values, dim=-1)

#         _, actor_idx = th.max(q_values, dim=-1)
#         actor_idx = actor_idx.to(self.device)

#         self.greedy_actor_count[actor_idx] += 1
#         actor_spread = Categorical(
#             probs=self.greedy_actor_count / self.greedy_actor_count.sum()
#         ).entropy()

#         self.logger.record("train/temperature", self.temperature)
#         self.logger.record("train/epsilon", self.exploration_rate)
#         self.logger.record("train/actor_spread", actor_spread.item())

#     def predict(
#         self,
#         observation: Union[np.ndarray, Dict[str, np.ndarray]],
#         state: Optional[Tuple[np.ndarray, ...]] = None,
#         episode_start: Optional[np.ndarray] = None,
#         deterministic: bool = False,
#         exploration_rate: float = 0,
#     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
#         """
#         Get the policy action from an observation (and optional hidden state).
#         Includes sugar-coating to handle different observations (e.g. normalizing images).

#         :param observation: the input observation
#         :param state: The last hidden states (can be None, used in recurrent policies)
#         :param episode_start: The last masks (can be None, used in recurrent policies)
#             this corresponds to the beginning of episodes,
#             where the hidden states of the RNN must be reset.
#         :param deterministic: Whether or not to return deterministic actions.
#         :param exploration_rate: Exploration rate forwarded to ``self.policy.predict``.
#         :return: the model's action and the next hidden state
#             (used in recurrent policies)
#         """
#         return self.policy.predict(
#             observation, state, episode_start, deterministic, exploration_rate
#         )

#     def _sample_action(
#         self,
#         learning_starts: int,
#         action_noise=None,
#         n_envs: int = 1,
#     ):
#         """
#         Sample an action according to the exploration policy.
#         This is either done by sampling the probability distribution of the policy,
#         or sampling a random action (from a uniform distribution over the action space)
#         or by adding noise to the deterministic output.

#         :param action_noise: Action noise that will be used for exploration
#             Required for deterministic policy (e.g. TD3). This can also be used
#             in addition to the stochastic policy for SAC.
#         :param learning_starts: Number of steps before learning for the warm-up phase.
#         :param n_envs:
#         :return: action to take in the environment
#             and scaled action that will be stored in the replay buffer.
#             The two differ when the action space is not normalized (bounds are not [-1, 1]).
#         """
#         # Select action randomly or according to policy
#         if self.num_timesteps < learning_starts and not (
#             self.use_sde and self.use_sde_at_warmup
#         ):
#             # Warmup phase
#             unscaled_action = np.array(
#                 [self.action_space.sample() for _ in range(n_envs)]
#             )
#         else:
#             # Note: when using continuous actions,
#             # we assume that the policy uses tanh to scale the action
#             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
#             unscaled_action, _ = self.predict(
#                 self._last_obs,
#                 deterministic=False,
#                 exploration_rate=self.exploration_rate,
#             )

#         # Rescale the action from [low, high] to [-1, 1]
#         if isinstance(self.action_space, spaces.Box):
#             scaled_action = self.policy.scale_action(unscaled_action)

#             # Add noise to the action (improve exploration)
#             if action_noise is not None:
#                 scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

#             # We store the scaled action in the buffer
#             buffer_action = scaled_action
#             # scaled_action = np.power(scaled_action, 2) - np.power(scaled_action, 3) + 1
#             action = self.policy.unscale_action(scaled_action)
#         else:
#             # Discrete case, no need to normalize or clip
#             buffer_action = unscaled_action
#             action = buffer_action

#         return action, buffer_action

#     def log_loss(self, action_1, action_2):
#         """
#         Loss: Log distance between actions.
#         """
#         # loss = th.exp(-th.norm(action_1 - action_2, dim=0) ** 2 / 0.01).mean()
#         loss = -th.log(th.norm(action_1 - action_2, dim=1) + 0.01).mean()
#         return loss

#     def mse_loss(self, action_1, action_2):
#         """
#         Loss: negative mean L2 (Euclidean) distance between actions.
#         """
#         loss = -th.norm(action_1 - action_2, p=2, dim=1).mean()
#         return loss

#     def train(self, gradient_steps: int, batch_size: int = 100) -> None:
#         # Switch to train mode (this affects batch norm / dropout)
#         self.policy.set_training_mode(True)

#         # Update learning rate according to lr schedule
#         self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

#         actor_losses, diversity_losses, critic_losses, distances_means = [], [], [], []
#         for _ in range(gradient_steps):

#             self._n_updates += 1
#             # Sample replay buffer
#             replay_data = self.replay_buffer.sample(
#                 batch_size, env=self._vec_normalize_env
#             )

#             with th.no_grad():
#                 # Select action according to policy
#                 next_actions_all = th.stack(
#                     self.actor_target(replay_data.next_observations), dim=0
#                 )

#                 # print(f"{next_actions_all}")
#                 # input()

#                 next_q_values_all = th.stack(
#                     [
#                         th.cat(
#                             self.critic_target(
#                                 replay_data.next_observations, next_actions
#                             ),
#                             dim=1,
#                         )
#                         for next_actions in next_actions_all
#                     ],
#                     dim=0,
#                 ).to(self.device)

#                 # print(f"{next_q_values_all}")
#                 # input()

#                 next_q_values_all, _ = th.min(next_q_values_all, dim=-1)
#                 # print(f"{next_q_values_all}")
#                 # input()

#                 next_actors = th.argmax(next_q_values_all, dim=0).unsqueeze(dim=1)
#                 # print(f"{next_actors}")
#                 # input()

#                 next_actors = next_actors.expand(
#                     -1, self.action_space.shape[0]
#                 ).unsqueeze(dim=1)
#                 # next_actors = th.cat((next_actors, next_actors), dim=1).unsqueeze(dim=1)
#                 # print(f"{next_actors}")
#                 # input()

#                 next_actions_all = th.stack(
#                     self.actor_target(replay_data.next_observations), dim=1
#                 )
#                 # print(f"{next_actions_all}")
#                 # input()

#                 next_actions = th.gather(
#                     next_actions_all, dim=1, index=next_actors.long()
#                 ).squeeze(1)
#                 # print(f"{next_actions}")
#                 # input()

#                 noise = replay_data.actions.clone().data.normal_(
#                     0, self.target_policy_noise
#                 )
#                 noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
#                 next_actions = (next_actions + noise).clamp(-1, 1)

#                 # Compute the next Q-values: min over all critics targets
#                 next_q_values = th.cat(
#                     self.critic_target(replay_data.next_observations, next_actions),
#                     dim=1,
#                 )
#                 # print(f"{next_q_values}")
#                 # input()

#                 next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
#                 # print(f"{next_q_values}")
#                 # input()

#                 target_q_values = (
#                     replay_data.rewards
#                     + (1 - replay_data.dones) * self.gamma * next_q_values
#                 )

#                 # For the actor losses
#                 # mu_all_target = self.actor_target(replay_data.observations)
#                 mu_all_target = self.actor(replay_data.observations)

#             # Get current Q-values estimates for each critic network
#             current_q_values = self.critic(
#                 replay_data.observations, replay_data.actions
#             )

#             # Compute critic loss
#             critic_loss = sum(
#                 F.mse_loss(current_q, target_q_values) for current_q in current_q_values
#             )
#             critic_losses.append(critic_loss.item())

#             # Optimize the critics
#             self.critic.optimizer.zero_grad()
#             critic_loss.backward()
#             self.critic.optimizer.step()

#             # Delayed policy updates
#             if self._n_updates % self.policy_delay == 0:
#                 # mu_all = self.actor(replay_data.observations)

#                 # dpg_loss, diversity_loss, distances_mean = 0, 0, 0
#                 # for targ_idx in range(self.n_actors):
#                 #     # Compute actor loss
#                 #     dpg_loss += -self.critic.q1_forward(
#                 #         replay_data.observations, mu_all[targ_idx]
#                 #     ).mean()
#                 #     for idx in range(self.n_actors):
#                 #         if targ_idx == idx:
#                 #             continue
#                 #         # # Compute diversity loss
#                 #         # diversity_loss += (1.0 / (self.n_actors - 1)) * self.mse_loss(
#                 #         #     mu_all_target[targ_idx], mu_all[idx]
#                 #         # )
#                 #         # diversity_loss += (1.0 / (self.n_actors - 1)) * th.exp(
#                 #         #     th.norm
#                 #         # (mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #         # ).mean()
#                 #         distances_mean += th.norm(
#                 #             mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1
#                 #         ).mean()
#                 #         # diversity_loss += th.exp(
#                 #         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #         # ).mean()
#                 #         # diversity_loss += -th.log(
#                 #         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #         #     ** 2
#                 #         # ).mean()
#                 #         diversity_loss += (
#                 #             th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                 #             ** 2
#                 #         ).mean()

#                 mu_all = self.actor(replay_data.observations)
#                 mu_dpg = self.actor(replay_data.observations)

#                 with th.no_grad():
#                     q_values = []
#                     for mu in mu_all:
#                         # Greedy actor
#                         q_values.append(
#                             self.critic.q1_forward(replay_data.observations, mu)
#                         )
#                     # Convert list to tensor
#                     q_values = th.tensor(th.stack(q_values)).squeeze()
#                     _, greedy_actor = th.max(q_values, dim=-2)
#                     greedy_actor = greedy_actor.to(self.device)

#                 mu_all = th.stack(list(mu_all), dim=0).to(self.device)

#                 with th.no_grad():
#                     mu_all_target = th.stack(list(mu_all_target), dim=0).to(self.device)

#                 # mu_all_target[greedy_actor, th.arange(self.batch_size)] = 2 * th.ones(
#                 #     self.action_space.shape[0]
#                 # ).to(self.device)
#                 # mu_all[greedy_actor, th.arange(self.batch_size)] = 2 * th.ones(
#                 #     self.action_space.shape[0]
#                 # ).to(self.device)

#                 dpg_loss, diversity_loss, distances_mean = 0, 0, 0
#                 for targ_idx in range(self.n_actors):
#                     # Compute actor loss
#                     dpg_loss += -self.critic.q1_forward(
#                         replay_data.observations, mu_dpg[targ_idx]
#                     ).mean()
#                     for idx in range(self.n_actors):
#                         if targ_idx == idx:
#                             continue
#                         # If idx or targ_idx == greedy_actor
#                         # # Compute diversity loss
#                         # diversity_loss += (1.0 / (self.n_actors - 1)) * self.mse_loss(
#                         #     mu_all_target[targ_idx], mu_all[idx]
#                         # )
#                         # diversity_loss += (1.0 / (self.n_actors - 1)) * th.exp(
#                         #     th.norm
#                         # (mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         # ).mean()
#                         distances_mean += th.norm(
#                             mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1
#                         ).mean()

#                         # diversity_loss += th.exp(
#                         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         # ).mean()
#                         # diversity_loss += -th.log(
#                         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         #     ** 2
#                         # ).mean()
#                         # diversity_loss += (
#                         #     th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
#                         #     ** 2
#                         # ).mean()
#                         mask = (
#                             th.logical_not(
#                                 th.logical_or(
#                                     greedy_actor.eq(idx), greedy_actor.eq(targ_idx)
#                                 )
#                             )
#                             .float()
#                             .to(self.device)
#                         )
#                         if targ_idx == idx:
#                             continue

#                         diversity_loss += (
#                             mask
#                             * 100
#                             * th.exp(
#                                 -th.norm(
#                                     (mu_all_target[targ_idx] - mu_all[idx]),
#                                     p=2,
#                                     dim=1,
#                                 )
#                             )
#                         ).mean()

#                 distances_mean /= self.n_actors
#                 # diversity_loss = -th.log(1 / diversity_loss)

#                 actor_loss = (1 - self.cagrad_c) * dpg_loss
#                 actor_losses.append(actor_loss.item())

#                 if self.cagrad_c == 0:
#                     self.actor.optimizer.zero_grad()
#                     actor_loss.backward()
#                 else:
#                     self.actor.optimizer.zero_grad()
#                     actor_loss.backward(retain_graph=True)

#                     self.actor.optimizer.zero_grad()

#                     diversity_loss = self.cagrad_c * diversity_loss
#                     diversity_loss.backward(retain_graph=True)

#                     grad = self._compute_gradient([diversity_loss, dpg_loss])
#                     grad_vec = th.cat(
#                         list(
#                             map(
#                                 lambda x: th.nn.utils.parameters_to_vector(x).unsqueeze(
#                                     0
#                                 ),
#                                 grad,
#                             )
#                         ),
#                         dim=0,
#                     )
#                     regularized_cagrad, _ = self.cagrad(grad_vec)
#                     # regularized_cagrad = th.clip(
#                     #     regularized_cagrad, min=None, max=100.0
#                     # )
#                     regularized_cagrad = th.nan_to_num(regularized_cagrad, nan=0.0)

#                     self._set_gradient(regularized_cagrad)

#                 diversity_losses.append(diversity_loss.item())
#                 distances_means.append(distances_mean.item())

#                 # Optimize the actor
#                 self.actor.optimizer.step()

#                 polyak_update(
#                     self.critic.parameters(), self.critic_target.parameters(), self.tau
#                 )
#                 polyak_update(
#                     self.actor.parameters(), self.actor_target.parameters(), self.tau
#                 )
#                 # Copy running stats, see GH issue #996
#                 polyak_update(
#                     self.critic_batch_norm_stats,
#                     self.critic_batch_norm_stats_target,
#                     1.0,
#                 )
#                 polyak_update(
#                     self.actor_batch_norm_stats,
#                     self.actor_batch_norm_stats_target,
#                     1.0,
#                 )

#         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
#         if len(actor_losses) > 0:
#             self.logger.record("train/actor_loss", np.mean(actor_losses))
#             self.logger.record("train/diversity_loss", np.mean(diversity_losses))
#             self.logger.record("train/distances_mean", np.mean(distances_means))
#         self.logger.record("train/critic_loss", np.mean(critic_losses))

#     def learn(
#         self: SelfGamid,
#         total_timesteps: int,
#         callback: MaybeCallback = None,
#         log_interval: int = 4,
#         tb_log_name: str = "Gamid",
#         reset_num_timesteps: bool = True,
#         progress_bar: bool = False,
#     ) -> SelfGamid:
#         return super().learn(
#             total_timesteps=total_timesteps,
#             callback=callback,
#             log_interval=log_interval,
#             tb_log_name=tb_log_name,
#             reset_num_timesteps=reset_num_timesteps,
#             progress_bar=progress_bar,
#         )

# ===== KL div ver =====
from typing import Any, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union

import io
import os
import pathlib
import numpy as np
import scipy
import torch as th
from gym import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3 import DDPG
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import (
    get_linear_fn,
    get_parameters_by_name,
    polyak_update,
)
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import recursive_getattr, save_to_zip_file
from rl_zoo3.gamid.policies import GamidPolicy, MlpPolicy
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.type_aliases import (
    GymEnv,
    MaybeCallback,
    RolloutReturn,
    Schedule,
    TrainFreq,
    TrainFrequencyUnit,
)
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps

SelfGamid = TypeVar("SelfGamid", bound="Gamid")


class Gamid(DDPG):

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[GamidPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 0.001,
        buffer_size: int = 1000000,
        learning_starts: int = 100,
        batch_size: int = 100,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
        gradient_steps: int = -1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        policy_delay: int = 1,
        target_policy_noise: float = 0.1,
        target_noise_clip: float = 0.0,
        n_actors: int = 1,
        n_critics: int = 1,
        temperature_initial: float = 0.9,
        temperature_final: float = 0.5,
        temperature_fraction: float = 0.5,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 0.5,
        exploration_final_eps: float = 0.1,
        actors_loss_fn: Optional[str] = None,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        save_path: Optional[str] = "",
    ):
        """
        Store the Gamid-specific settings, initialise the ``DDPG`` parent and
        build the model.

        ``_init_setup_model=False`` is passed to the parent and ``_setup_model``
        is called explicitly at the end, after ``n_actors`` / ``n_critics`` have
        been placed in ``policy_kwargs``.

        :param policy: forwarded unchanged to ``DDPG.__init__``
        :param env: forwarded unchanged to ``DDPG.__init__``
        :param learning_rate: forwarded unchanged to ``DDPG.__init__``
        :param buffer_size: forwarded unchanged to ``DDPG.__init__``
        :param learning_starts: forwarded unchanged to ``DDPG.__init__``
        :param batch_size: forwarded unchanged to ``DDPG.__init__``
        :param tau: forwarded unchanged to ``DDPG.__init__``
        :param gamma: forwarded unchanged to ``DDPG.__init__``
        :param train_freq: forwarded unchanged to ``DDPG.__init__``
        :param gradient_steps: forwarded unchanged to ``DDPG.__init__``
        :param action_noise: forwarded unchanged to ``DDPG.__init__``
        :param replay_buffer_class: forwarded unchanged to ``DDPG.__init__``
        :param replay_buffer_kwargs: forwarded unchanged to ``DDPG.__init__``
        :param optimize_memory_usage: forwarded unchanged to ``DDPG.__init__``
        :param policy_delay: stored as ``self.policy_delay`` after the parent
            constructor returns
        :param target_policy_noise: stored as ``self.target_policy_noise`` after
            the parent constructor returns
        :param target_noise_clip: stored as ``self.target_noise_clip`` after the
            parent constructor returns
        :param n_actors: stored as ``self.n_actors`` and written to
            ``policy_kwargs["n_actors"]``
        :param n_critics: written to ``policy_kwargs["n_critics"]``
        :param temperature_initial: first argument of the ``get_linear_fn`` call
            that builds ``self.temperature_schedule`` in ``_setup_model``
        :param temperature_final: second argument of that ``get_linear_fn`` call
        :param temperature_fraction: third argument of that ``get_linear_fn`` call
        :param exploration_fraction: third argument of the ``get_linear_fn`` call
            that builds ``self.exploration_schedule`` in ``_setup_model``
        :param exploration_initial_eps: first argument of that ``get_linear_fn`` call
        :param exploration_final_eps: second argument of that ``get_linear_fn`` call
        :param actors_loss_fn: stored as ``self.actors_loss_fn``
        :param tensorboard_log: forwarded unchanged to ``DDPG.__init__``
        :param policy_kwargs: forwarded unchanged to ``DDPG.__init__``
        :param verbose: forwarded unchanged to ``DDPG.__init__``
        :param seed: forwarded unchanged to ``DDPG.__init__``
        :param device: forwarded unchanged to ``DDPG.__init__``
        :param save_path: stored as ``self.save_path``
        """
        # Multi-actor configuration
        self.n_actors = n_actors
        self.actors_loss_fn = actors_loss_fn

        # Call counters
        self._n_calls = 0
        self._n_trains = 0

        # Temperature: the schedule placeholder is replaced in _setup_model
        self.temperature_initial = temperature_initial
        self.temperature_final = temperature_final
        self.temperature_fraction = temperature_fraction
        self.temperature_schedule = None
        self.temperature = 0.0

        # Epsilon greedy selection of actors: the schedule placeholder is
        # replaced in _setup_model
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.exploration_schedule = None
        self.exploration_rate = 0.0

        self.occupancy = None
        self.save_path = save_path

        # History lists; each one starts as its own empty list
        self.std_history, self.std_alt_history = [], []
        self.q_val_history, self.q_val_norm_history = [], []
        self.q_fn_history = []
        self.act_samp_history, self.act_pdf_history = [], []
        self.actor_loc_history = []

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=False,
        )

        # These are assigned only after the parent constructor has run
        self.policy_delay = policy_delay
        self.target_policy_noise = target_policy_noise
        self.target_noise_clip = target_noise_clip

        # The policy is built from policy_kwargs, so the network counts go there
        self.policy_kwargs.update(n_actors=n_actors, n_critics=n_critics)

        # Placeholders, filled by _create_aliases() during _setup_model()
        self.actor = self.actor_target = None
        self.critic = self.critic_target = None

        self._setup_model()

    def _setup_model(self) -> None:
        """
        Create the replay buffer, the policy and the linear schedules.

        In order: learning-rate schedule and seeding, replay-buffer selection and
        instantiation, policy construction (moved to ``self.device``),
        train-frequency conversion, actor/critic aliases, lookup of the
        ``running_*`` parameters of the online and target networks, and finally
        the temperature and exploration schedules.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        if self.replay_buffer_class is None:
            # No class requested: pick one from the observation-space type
            self.replay_buffer_class = (
                DictReplayBuffer
                if isinstance(self.observation_space, spaces.Dict)
                else ReplayBuffer
            )
        elif self.replay_buffer_class == HerReplayBuffer:
            assert (
                self.env is not None
            ), "You must pass an environment when using `HerReplayBuffer`"

            # Offline sampling additionally needs a classic (dict) replay buffer
            wrapped_buffer = None
            if not self.replay_buffer_kwargs.get("online_sampling", True):
                wrapped_buffer = DictReplayBuffer(
                    self.buffer_size,
                    self.observation_space,
                    self.action_space,
                    device=self.device,
                    optimize_memory_usage=self.optimize_memory_usage,
                )

            self.replay_buffer = HerReplayBuffer(
                self.env,
                self.buffer_size,
                device=self.device,
                replay_buffer=wrapped_buffer,
                **self.replay_buffer_kwargs,
            )

        # Any case that did not already create a buffer above
        if self.replay_buffer is None:
            self.replay_buffer = self.replay_buffer_class(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                device=self.device,
                n_envs=self.n_envs,
                optimize_memory_usage=self.optimize_memory_usage,
                **self.replay_buffer_kwargs,
            )

        policy = self.policy_class(  # pytype:disable=not-instantiable
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = policy.to(self.device)

        # Convert train freq parameter to TrainFreq object
        self._convert_train_freq()

        self._create_aliases()

        # Running mean and running var of the online and target networks
        running_stat_keys = ["running_"]
        self.actor_batch_norm_stats = get_parameters_by_name(
            self.actor, running_stat_keys
        )
        self.critic_batch_norm_stats = get_parameters_by_name(
            self.critic, running_stat_keys
        )
        self.actor_batch_norm_stats_target = get_parameters_by_name(
            self.actor_target, running_stat_keys
        )
        self.critic_batch_norm_stats_target = get_parameters_by_name(
            self.critic_target, running_stat_keys
        )

        # Linear schedules built from the values stored in __init__
        self.temperature_schedule = get_linear_fn(
            self.temperature_initial,
            self.temperature_final,
            self.temperature_fraction,
        )
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

    def _create_aliases(self) -> None:
        """Bind the policy's actor/critic networks and their targets onto ``self``."""
        policy = self.policy
        self.actor, self.actor_target = policy.actor, policy.actor_target
        self.critic, self.critic_target = policy.critic, policy.critic_target

    def _moving_average(self, arr, n=10):
        """
        Return the mean of the last ``n`` entries of ``arr``.

        :param arr: one-dimensional sequence of numeric values (e.g. a metric history)
        :param n: window length, must be at least 1
        :return: mean of the trailing window, or ``0`` when ``arr`` is empty.
            When fewer than ``n`` values are available the mean of all of them
            is returned.
        :raises ValueError: if ``n`` is smaller than 1
        """
        if n < 1:
            raise ValueError(f"Window length n must be >= 1, got {n}")
        if len(arr) == 0:
            return 0
        # Average the trailing slice directly. The previous
        # ``np.convolve(arr, ones(n) / n, mode="valid")[-1]`` silently returned
        # ``sum(arr) / n`` whenever ``len(arr) < n``, under-reporting short histories.
        return np.mean(np.asarray(arr)[-n:])

    def _get_gamid_std(self):
        """
        Sample 100 actions for ``self._last_obs`` and summarise their spread and value.

        Each action comes from :meth:`_sample_action` (with ``learning_starts=0`` and
        ``self.action_noise``) and is scored with ``self.critic.q1_forward``.

        NOTE(review): the first element returned by ``_sample_action`` (the action
        sent to the environment) is what gets scored here, not the buffer action --
        confirm this is intended when the action space is not already in [-1, 1].

        :return: tuple ``(action_std, q_mean)`` where ``action_std`` is the mean over
            action dimensions of the sample standard deviation, and ``q_mean`` is the
            mean of the sampled Q-values after dividing them by their maximum.
        """
        n_samples = 100
        sampled_actions = []
        sampled_q = []
        with th.no_grad():
            obs_tensor, _vectorized = self.policy.obs_to_tensor(self._last_obs)
            for _ in range(n_samples):
                env_action, _buffer_action = self._sample_action(
                    learning_starts=0, action_noise=self.action_noise
                )
                action_tensor = th.tensor(env_action).to(self.device)
                q_estimate = self.critic.q1_forward(
                    obs_tensor.to(th.float32), action_tensor.to(th.float32)
                )
                sampled_actions.append(action_tensor)
                sampled_q.append(q_estimate.item())
            # Scale the Q-values by their largest sample
            q_tensor = th.FloatTensor(sampled_q)
            q_tensor = q_tensor / q_tensor.max()
        action_spread = th.std(th.cat(sampled_actions, dim=0), dim=0).mean().item()
        return action_spread, q_tensor.mean().item()

    # def _compute_mean_q_values(self, observations):
    #     actions_sampled, q_values_sampled = [], []
    #     observation = observations.cpu().detach().numpy()[0]
    #     with th.no_grad():
    #         for _ in range(50):
    #             unscaled_action, _ = self.predict(
    #                 observation, exploration_rate=self.exploration_rate
    #             )
    #             action = self.policy.scale_action(unscaled_action)

    #             # Add noise to the action (improve exploration)
    #             if self.action_noise is not None:
    #                 action = np.clip(action + self.action_noise(), -1, 1)

    #             actions_sampled.append(action)
    #             action = th.tensor(action).to(self.device)
    #             q_values = self.critic.q1_forward(
    #                 observations[0].unsqueeze(0).to(th.float32),
    #                 action.unsqueeze(0).to(th.float32),
    #             )
    #             q_values_sampled.append(q_values.item())
    #     q_values_sampled = q_values_sampled / np.max(q_values_sampled)

    #     return np.mean(np.std(actions_sampled)), np.mean(q_values_sampled)

    def _compute_mean_q_values(self, observations, n_observations=20, n_samples=100):
        """
        Estimate action spread and Q-value statistics of the current behaviour policy.

        For each of the first ``n_observations`` entries of ``observations``,
        ``n_samples`` actions are drawn with :meth:`predict` (using
        ``self.exploration_rate``), rescaled with ``self.policy.scale_action``,
        perturbed by ``self.action_noise`` when it is set, and scored with
        ``self.critic.q1_forward``. The four statistics are printed and returned.

        :param observations: tensor of observations indexed along its first axis
        :param n_observations: maximum number of observations to evaluate
        :param n_samples: number of actions drawn per observation (at least 1)
        :return: tuple ``(std, std_alt, mean_q, mean_q_norm)``, each averaged over
            the evaluated observations:
            ``std`` is the mean over action dimensions of the per-dimension standard
            deviation, ``std_alt`` is the RMS deviation of all action entries from
            their per-dimension mean, ``mean_q`` is the mean sampled Q-value and
            ``mean_q_norm`` is the mean of the min-max normalised Q-values.
        """
        action_stds, action_rms_devs, q_means, q_norm_means = [], [], [], []
        observations_numpy = observations.cpu().detach().numpy()
        with th.no_grad():
            for i, observation in enumerate(observations_numpy[:n_observations]):
                # The critic input for this observation does not change between samples
                obs_tensor = observations[i].unsqueeze(0).to(th.float32)
                actions_obs, q_values_obs = [], []
                for _ in range(n_samples):
                    unscaled_action, _ = self.predict(
                        observation, exploration_rate=self.exploration_rate
                    )
                    action = self.policy.scale_action(unscaled_action)

                    # Add noise to the action (improve exploration)
                    if self.action_noise is not None:
                        action = np.clip(action + self.action_noise(), -1, 1)

                    actions_obs.append(action)
                    action_tensor = th.tensor(action).to(self.device)
                    q_value = self.critic.q1_forward(
                        obs_tensor, action_tensor.to(th.float32)
                    )
                    q_values_obs.append(q_value.item())

                actions_arr = np.array(actions_obs)
                q_arr = np.array(q_values_obs)

                # Min-max normalise the Q-values. When every sample has the same
                # value the range is zero; report zeros instead of 0/0 = NaN.
                q_min = np.min(q_arr)
                q_range = np.max(q_arr) - q_min
                if q_range > 0:
                    q_norm = (q_arr - q_min) / q_range
                else:
                    q_norm = np.zeros_like(q_arr)

                centred = actions_arr - np.mean(actions_arr, axis=0)
                action_rms_devs.append(np.mean(centred**2) ** 0.5)
                action_stds.append(np.mean(np.std(actions_arr, axis=0)))
                q_means.append(np.mean(q_arr))
                q_norm_means.append(np.mean(q_norm))

        stats = (
            np.mean(action_stds),
            np.mean(action_rms_devs),
            np.mean(q_means),
            np.mean(q_norm_means),
        )
        print(*stats)
        return stats

    def _plot_q_values(self, observations):
        """
        Collect data for plotting the Q-function against sampled actions.

        Only the first entry of ``observations`` is used. For it, the method records
        the outputs of ``self.actor``, evaluates ``self.critic.q1_forward`` on 100
        evenly spaced scalar actions in [-1, 1], draws 100 actions via
        :meth:`predict` (plus ``self.action_noise`` when set) and evaluates a normal
        pdf (std 0.1, centred on the sample mean) at those actions.
        The actor outputs are printed before returning.

        NOTE(review): the action grid is scalar, so this presumably only makes
        sense for one-dimensional action spaces -- confirm with callers.

        :param observations: tensor of observations, iterated along its first axis
        :return: tuple ``(q_values_sampled, actions_sampled, actions_pdf, actor_loc)``
        """
        action_grid = np.linspace(-1.0, 1.0, 100)
        q_values_sampled, actions_sampled, actor_loc = [], [], []
        with th.no_grad():
            for observation in observations[:1]:
                # Output of every actor head for this observation
                head_outputs = self.actor(observation.unsqueeze(0))
                actor_loc.append([head.cpu().item() for head in head_outputs])

                # Q-value along the action grid
                q_curve = []
                for grid_point in action_grid:
                    grid_action = th.tensor(grid_point).to(self.device)
                    q_estimate = self.critic.q1_forward(
                        observation.unsqueeze(0).to(th.float32),
                        grid_action.unsqueeze(0).unsqueeze(0).to(th.float32),
                    )
                    q_curve.append(q_estimate.item())
                q_values_sampled.append(q_curve)

                # Actions the behaviour policy would take
                for _ in range(100):
                    unscaled_action, _ = self.predict(
                        observation.detach().cpu().numpy(),
                        exploration_rate=self.exploration_rate,
                    )
                    sampled = self.policy.scale_action(unscaled_action)

                    # Add noise to the action (improve exploration)
                    if self.action_noise is not None:
                        sampled = np.clip(sampled + self.action_noise(), -1, 1)

                    actions_sampled.append(sampled)

                sample_mean = np.mean(actions_sampled, axis=0)
                actions_pdf = scipy.stats.norm(sample_mean, 0.1).pdf(actions_sampled)
        print(actor_loc)
        return q_values_sampled, actions_sampled, actions_pdf, actor_loc

    def _on_step(self) -> None:
        """
        Refresh the temperature and exploration rate from their schedules and log them.

        Both schedules are evaluated at ``self._current_progress_remaining``.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        self._n_calls += 1

        progress_remaining = self._current_progress_remaining
        self.temperature = self.temperature_schedule(progress_remaining)
        self.exploration_rate = self.exploration_schedule(progress_remaining)

        logged = (
            ("train/temperature", self.temperature),
            ("train/epsilon", self.exploration_rate),
        )
        for key, value in logged:
            self.logger.record(key, value)

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        exploration_rate: float = 0,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).

        This is a thin wrapper: all arguments are forwarded positionally, in the
        order of this signature, to ``self.policy.predict``.

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this correspond to beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param exploration_rate: Exploration rate handed to the policy (the value
            logged as ``train/epsilon`` during training). Its exact effect is
            defined by the policy's ``predict`` -- presumably an epsilon-style
            random choice; confirm in the policy class.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        return self.policy.predict(
            observation, state, episode_start, deterministic, exploration_rate
        )

    def _sample_action(
        self,
        learning_starts: int,
        action_noise=None,
        n_envs: int = 1,
    ):
        """
        Pick the next action for rollout collection.

        During warm-up (``self.num_timesteps < learning_starts``, unless SDE is used
        and enabled at warm-up) one action per environment is drawn from
        ``self.action_space.sample()``. Otherwise the action comes from
        :meth:`predict` on ``self._last_obs`` with the current exploration rate.
        For ``Box`` action spaces the action is rescaled to [-1, 1], optionally
        perturbed with ``action_noise`` and clipped, then unscaled again for the
        environment.

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Optional callable returning noise that is added to the
            scaled action.
        :param n_envs: Number of random actions to draw during warm-up.
        :return: tuple ``(action, buffer_action)``: the action to take in the
            environment and the scaled action to store in the replay buffer.
            The two differ when the action space bounds are not [-1, 1].
        """
        use_random_action = self.num_timesteps < learning_starts and (
            not self.use_sde or not self.use_sde_at_warmup
        )
        if use_random_action:
            # Warmup phase: uniform samples from the action space
            random_actions = [self.action_space.sample() for _ in range(n_envs)]
            unscaled_action = np.array(random_actions)
        else:
            # Note: when using continuous actions,
            # we assume that the policy uses tanh to scale the action
            unscaled_action, _ = self.predict(
                self._last_obs,
                deterministic=False,
                exploration_rate=self.exploration_rate,
            )

        if not isinstance(self.action_space, spaces.Box):
            # Discrete case, no need to normalize or clip
            return unscaled_action, unscaled_action

        # Rescale the action from [low, high] to [-1, 1]
        buffer_action = self.policy.scale_action(unscaled_action)

        # Add noise to the action (improve exploration)
        if action_noise is not None:
            buffer_action = np.clip(buffer_action + action_noise(), -1, 1)

        # The scaled action goes to the buffer, the unscaled one to the environment
        return self.policy.unscale_action(buffer_action), buffer_action

    def log_loss(self, action_1, action_2):
        """
        Negative mean log-distance between two batches of actions.

        The distance is the norm of ``action_1 - action_2`` along dim 1, offset by
        0.01 before the logarithm so identical actions stay finite.

        :return: scalar tensor; smaller when the actions are farther apart.
        """
        distance = th.norm(action_1 - action_2, dim=1)
        mean_log_distance = th.log(distance + 0.01).mean()
        return -mean_log_distance

    def mse_loss(self, action_1, action_2):
        """
        Negative mean Euclidean distance between two batches of actions.

        Despite the name this is not a mean squared error: it is the L2 norm of
        ``action_1 - action_2`` along dim 1, averaged over the batch and negated.

        :return: scalar tensor; smaller when the actions are farther apart.
        """
        gap = action_1 - action_2
        mean_distance = th.norm(gap, p=2, dim=1).mean()
        return -mean_distance

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` optimisation steps on batches from the replay buffer.

        Every step updates the critic towards a target built from the target-actor
        head whose action has the highest (min-over-critics) target Q-value, with
        clipped target-policy noise added. Every ``self.policy_delay`` updates the
        actor is optimised on
        ``(1 - temperature) * dpg_loss + temperature * diversity_loss`` and the
        target networks are Polyak-updated.

        The diversity loss is ``log(1 / sum_pairs mean(exp(||mu_i - mu_j||)))`` over
        all ordered pairs of distinct actor heads. With a single actor head there is
        no pair, and the diversity term is defined as zero.

        :param gradient_steps: number of gradient steps to perform
        :param batch_size: number of transitions sampled per gradient step
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rate according to lr schedule
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

        actor_losses, diversity_losses, critic_losses = [], [], []
        for _ in range(gradient_steps):
            self._n_updates += 1
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(
                batch_size, env=self._vec_normalize_env
            )

            with th.no_grad():
                # Single forward pass of the target actor; stacking on dim 0 puts the
                # actor-head axis first so it can be iterated head by head.
                next_actions_all = th.stack(
                    self.actor_target(replay_data.next_observations), dim=0
                )

                # Target Q-values of every head's action, one column per target critic
                next_q_values_all = th.stack(
                    [
                        th.cat(
                            self.critic_target(
                                replay_data.next_observations, next_actions
                            ),
                            dim=1,
                        )
                        for next_actions in next_actions_all
                    ],
                    dim=0,
                ).to(self.device)

                # Pessimistic value per head: min over critics
                next_q_values_all, _ = th.min(next_q_values_all, dim=-1)

                # Index of the best head per sample, expanded over the action
                # dimensions so it can serve as a gather index.
                next_actors = th.argmax(next_q_values_all, dim=0).unsqueeze(dim=1)
                next_actors = next_actors.expand(
                    -1, self.action_space.shape[0]
                ).unsqueeze(dim=1)

                # Re-use the actions computed above (head axis moved to dim 1)
                # instead of running the target actor a second time.
                next_actions = th.gather(
                    next_actions_all.transpose(0, 1),
                    dim=1,
                    index=next_actors.long(),
                ).squeeze(1)

                # Target policy smoothing: clipped Gaussian noise
                noise = replay_data.actions.clone().data.normal_(
                    0, self.target_policy_noise
                )
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (next_actions + noise).clamp(-1, 1)

                # Compute the next Q-values: min over all critics targets
                next_q_values = th.cat(
                    self.critic_target(replay_data.next_observations, next_actions),
                    dim=1,
                )
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)

                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            # Reference actions for the diversity loss.
            # NOTE(review): these come from the online actor with gradients enabled
            # (not from the target actor), so the diversity term backpropagates
            # through both operands -- confirm this is intended.
            mu_all_target = self.actor(replay_data.observations)

            # Get current Q-values estimates for each critic network
            current_q_values = self.critic(
                replay_data.observations, replay_data.actions
            )

            # Compute critic loss
            critic_loss = sum(
                F.mse_loss(current_q, target_q_values) for current_q in current_q_values
            )
            critic_losses.append(critic_loss.item())

            # Optimize the critics
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Delayed policy updates
            if self._n_updates % self.policy_delay == 0:
                mu_all = self.actor(replay_data.observations)

                dpg_loss, diversity_loss = 0, 0
                for targ_idx in range(self.n_actors):
                    # Deterministic policy gradient loss of this head
                    dpg_loss += -self.critic.q1_forward(
                        replay_data.observations, mu_all[targ_idx]
                    ).mean()
                    for idx in range(self.n_actors):
                        if targ_idx == idx:
                            continue
                        diversity_loss += th.exp(
                            th.norm(mu_all_target[targ_idx] - mu_all[idx], p=2, dim=1)
                        ).mean()

                if self.n_actors > 1:
                    diversity_loss = th.log(1 / diversity_loss)
                else:
                    # No pair of heads to compare: the accumulator is still the
                    # integer 0 and ``1 / 0`` would raise ZeroDivisionError.
                    diversity_loss = th.zeros((), device=self.device)

                actor_loss = th.add(
                    (1 - self.temperature) * dpg_loss, self.temperature * diversity_loss
                )

                actor_losses.append(actor_loss.item())
                diversity_losses.append(diversity_loss.item())

                # Optimize the actor
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()

                polyak_update(
                    self.critic.parameters(), self.critic_target.parameters(), self.tau
                )
                polyak_update(
                    self.actor.parameters(), self.actor_target.parameters(), self.tau
                )
                # Copy running stats, see GH issue #996
                polyak_update(
                    self.critic_batch_norm_stats,
                    self.critic_batch_norm_stats_target,
                    1.0,
                )
                polyak_update(
                    self.actor_batch_norm_stats,
                    self.actor_batch_norm_stats_target,
                    1.0,
                )
        self._n_trains += 1

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if len(actor_losses) > 0:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
            self.logger.record("train/diversity_loss", np.mean(diversity_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))

    def learn(
        self: SelfGamid,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "Gamid",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfGamid:
        """
        Train the agent by delegating to the parent class ``learn``.

        Every argument is forwarded unchanged by keyword; this override only
        changes the default ``tb_log_name`` to ``"Gamid"`` and narrows the
        annotated return type.

        :param total_timesteps: forwarded to the parent ``learn``
        :param callback: forwarded to the parent ``learn``
        :param log_interval: forwarded to the parent ``learn``
        :param tb_log_name: forwarded to the parent ``learn``
        :param reset_num_timesteps: forwarded to the parent ``learn``
        :param progress_bar: forwarded to the parent ``learn``
        :return: whatever the parent ``learn`` returns (annotated as the model itself)
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        In addition, the diagnostic histories kept on the model (``std_history``,
        ``std_alt_history``, ``q_val_history``, ``q_val_norm_history``,
        ``q_fn_history``, ``act_samp_history``, ``act_pdf_history``,
        ``actor_loc_history``) are written with ``np.save`` to side-car files named
        ``<path>_std``, ``<path>_std_alt``, ``<path>_q``, ``<path>_q_norm``,
        ``<path>_q_fn``, ``<path>_act_samp``, ``<path>_act_pdf`` and
        ``<path>_act_loc``. Side-car files are only written when ``path`` is a
        string or path-like object; for file objects they are skipped with a warning
        because no file name can be derived from the stream.

        :param path: path to the file where the rl agent should be saved
        :param exclude: name of parameters that should be excluded in addition to the default ones
        :param include: name of parameters that might be excluded but should be included anyway
        """
        import warnings

        # Copy parameter list so we don't mutate the original dict
        data = self.__dict__.copy()

        # Exclude is union of specified parameters (if any) and standard exclusions
        if exclude is None:
            exclude = []
        exclude = set(exclude).union(self._excluded_save_params())

        # Do not exclude params if they are specifically included
        if include is not None:
            exclude = exclude.difference(include)

        state_dicts_names, torch_variable_names = self._get_torch_save_params()
        all_pytorch_variables = state_dicts_names + torch_variable_names
        for torch_var in all_pytorch_variables:
            # We need to get only the name of the top most module as we'll remove that
            var_name = torch_var.split(".")[0]
            # Any params that are in the save vars must not be saved by data
            exclude.add(var_name)

        # Remove parameter entries of parameters which are to be excluded
        for param_name in exclude:
            data.pop(param_name, None)

        # Build dict of torch variables
        pytorch_variables = None
        if torch_variable_names is not None:
            pytorch_variables = {
                name: recursive_getattr(self, name) for name in torch_variable_names
            }

        # Build dict of state_dicts
        params_to_save = self.get_parameters()

        # Save custom data next to the archive
        if isinstance(path, (str, os.PathLike)):
            base_path = os.fspath(path)
            histories = (
                ("_std", self.std_history),
                ("_std_alt", self.std_alt_history),
                ("_q", self.q_val_history),
                ("_q_norm", self.q_val_norm_history),
                ("_q_fn", self.q_fn_history),
                ("_act_samp", self.act_samp_history),
                ("_act_pdf", self.act_pdf_history),
                ("_act_loc", self.actor_loc_history),
            )
            for suffix, history in histories:
                np.save(f"{base_path}{suffix}", history)
        else:
            # Formatting a stream into an f-string yields its repr, which is not a
            # usable file name (and is invalid on some platforms).
            warnings.warn(
                "save() received a file object: diagnostic history files were not written."
            )

        save_to_zip_file(
            path, data=data, params=params_to_save, pytorch_variables=pytorch_variables
        )

    def save_occupancy(self, episode):
        """
        Write ``self.occupancy`` to ``<save_path>/occupancy/<episode>.npy``.

        The ``occupancy`` folder is created under ``self.save_path`` if needed.

        :param episode: value used as the file name (without extension)
        """
        target_dir = os.path.join(self.save_path, "occupancy")
        target_file = os.path.join(target_dir, f"{episode}.npy")
        os.makedirs(target_dir, exist_ok=True)
        np.save(target_file, self.occupancy)

    def collect_rollouts(
        self,
        env,
        callback,
        train_freq,
        replay_buffer,
        action_noise=None,
        learning_starts=0,
        log_interval=None,
    ):
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        In addition to the usual rollout loop, when the string ``"maze"`` occurs in
        ``self.save_path`` the method accumulates a visit count of the current
        observation (via ``counter``) into ``self.occupancy`` at every step and
        writes it to disk with ``save_occupancy`` every 10 episodes.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param replay_buffer: Buffer the collected transitions are stored in
        :param log_interval: Log data every ``log_interval`` episodes
        :return: ``RolloutReturn`` with the number of collected steps (multiplied by
            the number of environments), the number of finished episodes and a
            flag that is ``False`` when a callback asked to stop training
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if env.num_envs > 1:
            assert (
                train_freq.unit == TrainFrequencyUnit.STEP
            ), "You must use only one env when doing episodic training."

        # Vectorize action noise if needed
        if (
            action_noise is not None
            and env.num_envs > 1
            and not isinstance(action_noise, VectorizedActionNoise)
        ):
            action_noise = VectorizedActionNoise(action_noise, env.num_envs)

        if self.use_sde:
            self.actor.reset_noise(env.num_envs)

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(
            train_freq, num_collected_steps, num_collected_episodes
        ):
            if (
                self.use_sde
                and self.sde_sample_freq > 0
                and num_collected_steps % self.sde_sample_freq == 0
            ):
                # Sample a new noise matrix
                self.actor.reset_noise(env.num_envs)

            # Select action randomly or according to policy
            actions, buffer_actions = self._sample_action(
                learning_starts, action_noise, env.num_envs
            )

            # Rescale and perform action
            new_obs, rewards, dones, infos = env.step(actions)
            # print("step done")
            # img = env.render("rgb_array")

            # plt.imshow(img)
            # plt.show()
            # Occupancy tracking for maze runs: count the observation the action
            # was taken from (self._last_obs), not the new observation.
            # NOTE(review): assumes self.save_path is a string -- confirm it is
            # never None here.
            if "maze" in self.save_path:
                cur_obs, _ = self.policy.obs_to_tensor(self._last_obs)
                # self.occupancy is accumulated in place once it is an ndarray;
                # any other value is replaced by the first count.
                if isinstance(self.occupancy, np.ndarray):
                    self.occupancy += counter(cur_obs, self.device)
                else:
                    self.occupancy = counter(cur_obs, self.device)

            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # Give access to local variables
            # (the local names in this method are therefore visible to callbacks)
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            # Note that this early return happens before the transition is stored.
            if callback.on_step() is False:
                return RolloutReturn(
                    num_collected_steps * env.num_envs,
                    num_collected_episodes,
                    continue_training=False,
                )

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, dones)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(
                replay_buffer, buffer_actions, new_obs, rewards, dones, infos
            )

            self._update_current_progress_remaining(
                self.num_timesteps, self._total_timesteps
            )

            # Update the temperature and exploration-rate schedules and log them
            # (see ``_on_step``); the target networks are updated in ``train``.
            self._on_step()

            for idx, done in enumerate(dones):
                if done:
                    # Update stats
                    num_collected_episodes += 1
                    self._episode_num += 1

                    if action_noise is not None:
                        # With several envs only the noise of the finished env is reset
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)

                    # Log training infos
                    if (
                        log_interval is not None
                        and self._episode_num % log_interval == 0
                    ):
                        self._dump_logs()

                    if "maze" in self.save_path:
                        if self._episode_num % 10 == 0:
                            # Save occupancy
                            self.save_occupancy(self._episode_num)

        callback.on_rollout_end()

        return RolloutReturn(
            num_collected_steps * env.num_envs,
            num_collected_episodes,
            continue_training,
        )


def build_anchor(size, device):
    """
    Build a ``(2 * size, 2 * size, 2)`` grid of unit-cell centre coordinates.

    Entry ``[i, j]`` holds ``(y[j], x[i])`` where ``y`` runs from ``-size + 0.5``
    up to ``size - 0.5`` and ``x`` runs from ``size - 0.5`` down to ``-size + 0.5``.

    :param size: half-width of the grid in cells
    :param device: torch device on which the grid is created
    :return: float tensor of cell centres located on ``device``
    """
    y = th.arange(-size, size, device=device) + 0.5
    x = th.arange(size, -size, -1.0, device=device) - 0.5
    x, y = th.meshgrid(x, y, indexing="ij")
    # The inputs already live on ``device``; the former unconditional ``.cuda()``
    # ignored the argument and failed on machines without a GPU.
    return th.stack([y, x], dim=-1)


def counter(obs, device):
    """
    Count, for every cell of a 24x24 anchor grid, how many observations fall in it.

    ``obs`` is flattened to two dimensions (keeping its last axis) and compared
    with the cell centres from ``build_anchor(12, device)``; an observation is
    inside a cell when both of its first two coordinates are strictly within 0.5
    of the centre.

    NOTE(review): the subtraction broadcasts the last axis of ``obs`` against the
    two anchor coordinates, so ``obs`` presumably has a last dimension of 2 (or 1)
    -- confirm against the maze environment.

    :param obs: tensor of observations
    :param device: torch device used to build the anchor grid
    :return: float numpy array with one count per grid cell
    """
    grid = build_anchor(12, device)
    flat_obs = obs.reshape(-1, obs.shape[-1])
    offsets = th.abs(flat_obs[None, None, :, :] - grid[:, :, None, :])
    within_first = offsets[..., 0] < 0.5
    within_second = offsets[..., 1] < 0.5
    inside = th.logical_and(within_first, within_second)
    return inside.sum(axis=-1).float().detach().cpu().numpy()
