from omnisafe.algorithms.on_policy.naive_lagrange import PPOLag

from typing import Dict, Generator, Optional, Tuple, Union

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize

from ucrl.common.type_aliases import RolloutBufferHSamples, DictRolloutBufferHSamples, RolloutBufferCSamples, DictRolloutBufferCSamples

class RolloutBufferH(RolloutBuffer):
    """
    Rollout buffer used in on-policy algorithms with augmented hidden state like PPO-H.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    In addition to the standard rollout data, each transition may carry a hidden
    observation, the full stacked hidden state, a feasibility log score and the
    estimated log-score value. When those are supplied, a second GAE(lambda)
    stream is computed over the log scores alongside the reward stream.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param full_hidden_obs_shape: Per-env shape of the full stacked hidden state
        (presumably ``(num_layers, hidden_size)`` for a stacked GRU -- TODO confirm).
        Its last entry is also used as the size of the per-step ``hidden_obs`` vector.
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    hidden_obs: np.ndarray
    full_hidden_obs: np.ndarray
    log_scores: np.ndarray
    log_score_advantages: np.ndarray
    log_score_returns: np.ndarray
    log_score_values: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        full_hidden_obs_shape: Tuple,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Set before super().__init__(): reset() reads it, and the parent
        # constructor is expected to call reset() to allocate storage.
        self.full_hidden_obs_shape = full_hidden_obs_shape
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

    def reset(self) -> None:
        """Allocate zeroed hidden-state / log-score storage, then reset the base rollout storage."""
        n_steps, n_envs = self.buffer_size, self.n_envs
        self.hidden_obs = np.zeros((n_steps, n_envs, self.full_hidden_obs_shape[-1]), dtype=np.float32)
        # Stored env-first: (step, env, *full_hidden_obs_shape).
        self.full_hidden_obs = np.zeros((n_steps, n_envs, *self.full_hidden_obs_shape), dtype=np.float32)
        self.log_scores = np.zeros((n_steps, n_envs), dtype=np.float32)
        self.log_score_returns = np.zeros((n_steps, n_envs), dtype=np.float32)
        self.log_score_values = np.zeros((n_steps, n_envs), dtype=np.float32)
        self.log_score_advantages = np.zeros((n_steps, n_envs), dtype=np.float32)
        # ``returns`` and the other standard arrays are left to the parent reset.
        super().reset()

    def compute_returns_and_advantage(
        self,
        last_values: th.Tensor,
        dones: np.ndarray,
        last_log_score_values: Optional[th.Tensor] = None,
    ) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        When ``last_log_score_values`` is given, the same recursion is run in lockstep
        over ``log_scores`` / ``log_score_values`` to fill ``log_score_advantages``
        and ``log_score_returns``. Otherwise only the parent's reward GAE is computed.

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        :param last_log_score_values: state log_score value estimation for the last step (one for each env)
        """
        if last_log_score_values is None:
            # No log-score stream: plain reward GAE from the parent class.
            super().compute_returns_and_advantage(last_values, dones)
            return

        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()
        last_log_score_values = last_log_score_values.clone().cpu().numpy().flatten()
        # Cast to float32 so the arithmetic is well defined for bool arrays or plain lists,
        # and matches the dtype of ``episode_starts`` used for every other step.
        last_non_terminal = 1.0 - np.asarray(dones, dtype=np.float32)

        last_gae_lam, last_log_score_gae_lam = 0.0, 0.0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                # Bootstrap from the values of the state reached after the final stored step.
                next_non_terminal = last_non_terminal
                next_values, next_log_score_values = last_values, last_log_score_values
            else:
                # An episode start at step + 1 cuts the bootstrap for that env.
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
                next_log_score_values = self.log_score_values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            delta_log_score = (
                self.log_scores[step]
                + self.gamma * next_log_score_values * next_non_terminal
                - self.log_score_values[step]
            )
            decay = self.gamma * self.gae_lambda * next_non_terminal
            last_gae_lam = delta + decay * last_gae_lam
            last_log_score_gae_lam = delta_log_score + decay * last_log_score_gae_lam
            self.advantages[step] = last_gae_lam
            self.log_score_advantages[step] = last_log_score_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values
        self.log_score_returns = self.log_score_advantages + self.log_score_values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        hidden_obs: Optional[np.ndarray] = None,
        full_hidden_obs: Optional[np.ndarray] = None,
        log_scores: Optional[np.ndarray] = None,
        log_score_value: Optional[th.Tensor] = None,
    ) -> None:
        """
        Store one transition for every env at the current position.

        The four hidden-state arguments must be passed all together or not at all.
        Without them, the transition is delegated to the parent ``add`` and the
        hidden-state / log-score slots keep their zero initialisation.

        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param hidden_obs: Hidden observation
        :param full_hidden_obs: Full set of hidden obs for stacked GRU. Axes 0 and 1 are
            swapped before storage, so the env axis is expected in position 1.
        :param log_scores: feasibility log scores
        :param log_score_value: estimated log score value of the current augmented state following the current policy
        """
        extras = (hidden_obs, full_hidden_obs, log_scores, log_score_value)
        no_extras = all(x is None for x in extras)
        assert no_extras or all(x is not None for x in extras), (
            "hidden_obs, full_hidden_obs, log_scores and log_score_value must be passed together or not at all"
        )

        if no_extras:
            super().add(obs, action, reward, episode_start, value, log_prob)
            return

        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.hidden_obs[self.pos] = np.array(hidden_obs)
        # Move the env axis first to match the (step, env, ...) storage layout.
        self.full_hidden_obs[self.pos] = np.array(full_hidden_obs.swapaxes(0, 1))
        self.log_scores[self.pos] = np.array(log_scores)
        self.log_score_values[self.pos] = log_score_value.clone().cpu().numpy().flatten()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferHSamples, None, None]:
        """
        Yield shuffled minibatches over all ``buffer_size * n_envs`` stored transitions.

        On the first call the (step, env, ...) arrays are flattened in place to
        (step * env, ...).

        :param batch_size: minibatch size; ``None`` yields everything in a single batch.
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "hidden_obs",
                "full_hidden_obs",
                "actions",
                "log_scores",
                "values",
                "log_score_values",
                "log_probs",
                "advantages",
                "log_score_advantages",
                "returns",
                "log_score_returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RolloutBufferHSamples:
        """Gather the flattened transitions at ``batch_inds`` into torch tensors (``env`` is unused)."""
        data = (
            self.observations[batch_inds],
            self.hidden_obs[batch_inds],
            # Flattened storage is (batch, *full_hidden_obs_shape); swapping axes 0 and 1
            # puts the batch axis second, undoing the swap applied in ``add``.
            self.full_hidden_obs[batch_inds].swapaxes(0, 1),
            self.actions[batch_inds],
            self.log_scores[batch_inds].flatten(),
            self.values[batch_inds].flatten(),
            self.log_score_values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.log_score_advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.log_score_returns[batch_inds].flatten(),
        )
        return RolloutBufferHSamples(*tuple(map(self.to_torch, data)))

class DictRolloutBufferH(RolloutBufferH):
    """
    Dict Rollout buffer used in on-policy algorithms like PPO-H.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param full_hidden_obs_shape: Per-env shape of the full stacked hidden state
        (presumably ``(num_layers, hidden_size)`` for a stacked GRU -- TODO confirm).
        Its last entry is also used as the size of the per-step ``hidden_obs`` vector.
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Dict
    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
    observations: Dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        full_hidden_obs_shape: Tuple,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Needed by reset() below to size the hidden-state arrays.
        self.full_hidden_obs_shape = full_hidden_obs_shape

        # super(RolloutBuffer, self) skips RolloutBufferH.__init__ and RolloutBuffer.__init__
        # and calls the next initializer in the MRO directly; storage is allocated by reset().
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        """Allocate zeroed per-key observation storage plus all rollout / hidden-state arrays."""
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
        self.hidden_obs = np.zeros((self.buffer_size, self.n_envs, self.full_hidden_obs_shape[-1]), dtype=np.float32)
        self.full_hidden_obs = np.zeros((self.buffer_size, self.n_envs, *self.full_hidden_obs_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)

        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_scores = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_score_returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_score_values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_score_advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        # Skip the flat-array resets of the parents; only the base position/full flags remain.
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        hidden_obs: Optional[np.ndarray] = None,
        full_hidden_obs: Optional[np.ndarray] = None,
        log_scores: Optional[np.ndarray] = None,
        log_score_value: Optional[th.Tensor] = None,
    ) -> None:
        """
        Store one transition for every env at the current position.

        The four hidden-state arguments must be passed all together or not at all.
        When they are omitted, the hidden-state / log-score slots are not written
        and keep the zeros from ``reset()``.

        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param hidden_obs: Hidden observation
        :param full_hidden_obs: Full set of hidden obs for stacked GRU. Axes 0 and 1 are
            swapped before storage, so the env axis is expected in position 1.
        :param log_scores: feasibility log scores
        :param log_score_value: estimated log score value of the current augmented state following the current policy
        """
        extras = (hidden_obs, full_hidden_obs, log_scores, log_score_value)
        has_extras = all(x is not None for x in extras)
        assert has_extras or all(x is None for x in extras), (
            "hidden_obs, full_hidden_obs, log_scores and log_score_value must be passed together or not at all"
        )

        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        if has_extras:
            self.hidden_obs[self.pos] = np.array(hidden_obs)
            # Move the env axis first to match the (step, env, ...) storage layout.
            self.full_hidden_obs[self.pos] = np.array(full_hidden_obs.swapaxes(0, 1))
            self.log_scores[self.pos] = np.array(log_scores)
            self.log_score_values[self.pos] = log_score_value.clone().cpu().numpy().flatten()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: Optional[int] = None,
    ) -> Generator[DictRolloutBufferHSamples, None, None]:
        """
        Yield shuffled minibatches over all ``buffer_size * n_envs`` stored transitions.

        On the first call every (step, env, ...) array, including each observation key,
        is flattened in place to (step * env, ...).

        :param batch_size: minibatch size; ``None`` yields everything in a single batch.
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["hidden_obs", "full_hidden_obs", "actions", "log_scores", "values", "log_score_values", "log_probs", "advantages", "log_score_advantages", "returns", "log_score_returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> DictRolloutBufferHSamples:
        """Gather the flattened transitions at ``batch_inds`` into torch tensors (``env`` is unused)."""
        return DictRolloutBufferHSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            hidden_obs=self.to_torch(self.hidden_obs[batch_inds]),
            # Undo the env-first swap from ``add``: batch axis goes back to position 1.
            full_hidden_obs=self.to_torch(self.full_hidden_obs[batch_inds].swapaxes(0, 1)),
            actions=self.to_torch(self.actions[batch_inds]),
            log_scores=self.to_torch(self.log_scores[batch_inds].flatten()),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_score_values=self.to_torch(self.log_score_values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            log_score_advantages=self.to_torch(self.log_score_advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
            log_score_returns=self.to_torch(self.log_score_returns[batch_inds].flatten()),
        )


class RolloutBufferC(RolloutBuffer):
    """
    Rollout buffer used in on-policy algorithms like PPO-Lag.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    In addition to the standard rollout data, each transition may carry a negative
    cost and its estimated value. When those are supplied, a second GAE(lambda)
    stream is computed over the negative costs alongside the reward stream.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    neg_costs: np.ndarray
    neg_cost_advantages: np.ndarray
    neg_cost_returns: np.ndarray
    neg_cost_values: np.ndarray

    def reset(self) -> None:
        """Allocate zeroed negative-cost storage, then reset the base rollout storage."""
        n_steps, n_envs = self.buffer_size, self.n_envs
        self.neg_costs = np.zeros((n_steps, n_envs), dtype=np.float32)
        self.neg_cost_returns = np.zeros((n_steps, n_envs), dtype=np.float32)
        self.neg_cost_values = np.zeros((n_steps, n_envs), dtype=np.float32)
        self.neg_cost_advantages = np.zeros((n_steps, n_envs), dtype=np.float32)
        # ``returns`` and the other standard arrays are left to the parent reset.
        super().reset()

    def compute_returns_and_advantage(
        self,
        last_values: th.Tensor,
        dones: np.ndarray,
        last_neg_cost_values: Optional[th.Tensor] = None,
    ) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        When ``last_neg_cost_values`` is given, the same recursion is run in lockstep
        over ``neg_costs`` / ``neg_cost_values`` to fill ``neg_cost_advantages`` and
        ``neg_cost_returns``. Otherwise only the parent's reward GAE is computed.

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        :param last_neg_cost_values: state neg_cost value estimation for the last step (one for each env)
        """
        if last_neg_cost_values is None:
            # No cost stream: plain reward GAE from the parent class.
            super().compute_returns_and_advantage(last_values, dones)
            return

        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()
        last_neg_cost_values = last_neg_cost_values.clone().cpu().numpy().flatten()
        # Cast to float32 so the arithmetic is well defined for bool arrays or plain lists,
        # and matches the dtype of ``episode_starts`` used for every other step.
        last_non_terminal = 1.0 - np.asarray(dones, dtype=np.float32)

        last_gae_lam, last_neg_cost_gae_lam = 0.0, 0.0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                # Bootstrap from the values of the state reached after the final stored step.
                next_non_terminal = last_non_terminal
                next_values, next_neg_cost_values = last_values, last_neg_cost_values
            else:
                # An episode start at step + 1 cuts the bootstrap for that env.
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
                next_neg_cost_values = self.neg_cost_values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            delta_neg_cost = (
                self.neg_costs[step]
                + self.gamma * next_neg_cost_values * next_non_terminal
                - self.neg_cost_values[step]
            )
            decay = self.gamma * self.gae_lambda * next_non_terminal
            last_gae_lam = delta + decay * last_gae_lam
            last_neg_cost_gae_lam = delta_neg_cost + decay * last_neg_cost_gae_lam
            self.advantages[step] = last_gae_lam
            self.neg_cost_advantages[step] = last_neg_cost_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values
        self.neg_cost_returns = self.neg_cost_advantages + self.neg_cost_values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        neg_costs: Optional[np.ndarray] = None,
        neg_cost_value: Optional[th.Tensor] = None,
    ) -> None:
        """
        Store one transition for every env at the current position.

        ``neg_costs`` and ``neg_cost_value`` must be passed together or not at all.
        Without them, the transition is delegated to the parent ``add`` and the
        negative-cost slots keep their zero initialisation.

        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param neg_costs: negative costs
        :param neg_cost_value: estimated negative cost value of the current augmented state following the current policy
        """
        no_extras = neg_costs is None and neg_cost_value is None
        assert no_extras or (neg_costs is not None and neg_cost_value is not None), (
            "neg_costs and neg_cost_value must be passed together or not at all"
        )

        if no_extras:
            super().add(obs, action, reward, episode_start, value, log_prob)
            return

        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.neg_costs[self.pos] = np.array(neg_costs)
        self.neg_cost_values[self.pos] = neg_cost_value.clone().cpu().numpy().flatten()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferCSamples, None, None]:
        """
        Yield shuffled minibatches over all ``buffer_size * n_envs`` stored transitions.

        On the first call the (step, env, ...) arrays are flattened in place to
        (step * env, ...).

        :param batch_size: minibatch size; ``None`` yields everything in a single batch.
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "neg_costs",
                "values",
                "neg_cost_values",
                "log_probs",
                "advantages",
                "neg_cost_advantages",
                "returns",
                "neg_cost_returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RolloutBufferCSamples:
        """Gather the flattened transitions at ``batch_inds`` into torch tensors (``env`` is unused)."""
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.neg_costs[batch_inds].flatten(),
            self.values[batch_inds].flatten(),
            self.neg_cost_values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.neg_cost_advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.neg_cost_returns[batch_inds].flatten(),
        )
        return RolloutBufferCSamples(*tuple(map(self.to_torch, data)))


class DictRolloutBufferC(RolloutBufferC):
    """
    Dict Rollout buffer used in constrained on-policy algorithms like PPO-Lag.
    Extends ``RolloutBufferC`` to use dictionary observations.

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action. In addition to rewards and
    values, each step stores a negative cost and an estimated negative-cost value,
    with matching ``neg_cost_advantages`` / ``neg_cost_returns`` arrays.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space (must be a ``spaces.Dict``)
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Dict
    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
    observations: Dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Skip RolloutBuffer.__init__ (and therefore RolloutBufferC.__init__) and go
        # straight to the next class in the MRO, mirroring SB3's DictRolloutBuffer.
        # NOTE(review): any attributes RolloutBufferC.__init__ would set are not set
        # here — confirm nothing else in RolloutBufferC relies on them.
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        """Allocate zeroed ``(buffer_size, n_envs, ...)`` storage and rewind the write position."""
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        self.neg_costs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.neg_cost_returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.neg_cost_values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.neg_cost_advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Arrays are back in (buffer_size, n_envs) layout, so get() must flatten again.
        self.generator_ready = False
        # Rewinds the write position (pos/full) in the base buffer.
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        neg_costs: Optional[np.ndarray] = None,
        neg_cost_value: Optional[th.Tensor] = None,
    ) -> None:
        """
        Store one step (one entry per parallel env) at the current write position.

        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param neg_costs: negative costs. Must be given together with
            ``neg_cost_value``. If both are omitted, zeros are stored.
        :param neg_cost_value: estimated negative costs value of the current state following the current policy
        """
        assert (neg_costs is None) == (neg_cost_value is None), (
            "neg_costs and neg_cost_value must be given together or both omitted"
        )

        if neg_costs is None:
            # No cost signal supplied: store zeros so the (n_envs,) slots stay filled.
            # Empty arrays cannot be used here: they do not broadcast into an
            # (n_envs,) slot, and a NumPy array has no ``.clone()``.
            neg_costs = np.zeros(self.n_envs, dtype=np.float32)
            neg_cost_value = th.zeros(self.n_envs)

        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.neg_costs[self.pos] = np.array(neg_costs)
        self.neg_cost_values[self.pos] = neg_cost_value.clone().cpu().numpy().flatten()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: Optional[int] = None,
    ) -> Generator[DictRolloutBufferCSamples, None, None]:
        """
        Yield shuffled minibatches covering every stored transition.

        The buffer must be full. On the first call after a reset, each
        observation entry and each per-step array is converted with
        ``swap_and_flatten``.

        :param batch_size: Transitions per minibatch; ``None`` yields one batch with everything.
        :return: Generator of ``DictRolloutBufferCSamples``.
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = [
                "actions",
                "neg_costs",
                "values",
                "neg_cost_values",
                "log_probs",
                "advantages",
                "neg_cost_advantages",
                "returns",
                "neg_cost_returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> DictRolloutBufferCSamples:
        """
        Index the flattened buffers at ``batch_inds`` and convert them to tensors.

        :param batch_inds: Indices into the flattened buffer arrays.
        :param env: Unused; kept for signature compatibility with the base class.
        :return: ``DictRolloutBufferCSamples``. Observations stay a dict keyed like
            the observation space; per-step scalars are flattened to 1-D.
        """
        return DictRolloutBufferCSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            neg_costs=self.to_torch(self.neg_costs[batch_inds].flatten()),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_neg_cost_values=self.to_torch(self.neg_cost_values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            neg_cost_advantages=self.to_torch(self.neg_cost_advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
            neg_cost_returns=self.to_torch(self.neg_cost_returns[batch_inds].flatten()),
        )
