"""PPO with auxiliary loss"""

from dowel import tabular

from garage.torch.algos import PPO


class PPOAux(PPO):
    """PPO whose loss is augmented by a weighted, externally supplied term.

    The auxiliary term is provided via ``_set_auxiliary_obj`` and, scaled by
    ``auxiliary_obj_coeff``, is added to whatever the parent
    ``PPO._compute_loss_with_adv`` returns. While no auxiliary term is set
    (the initial state), the parent's result is returned unchanged.

    Args:
        env_spec: Environment specification, forwarded to ``PPO``.
        policy: Policy being optimized, forwarded to ``PPO``.
        value_function: Value function, forwarded to ``PPO``.
        sampler: Sampler, forwarded to ``PPO``.
        auxiliary_obj_coeff (float): Multiplier applied to the auxiliary
            term before it is added to the parent's loss.
        **kwargs: Additional keyword arguments forwarded to ``PPO``.
    """

    def __init__(
        self,
        env_spec,
        policy,
        value_function,
        sampler,
        auxiliary_obj_coeff=1.0,
        **kwargs
    ):
        # kwargs must be unpacked. Passing the dict positionally would bind it
        # to PPO's next positional parameter, and the intended keyword options
        # would never reach the parent.
        super().__init__(env_spec, policy, value_function, sampler, **kwargs)
        # None means "no auxiliary term"; see _compute_loss_with_adv.
        self._auxiliary_obj = None
        self._auxiliary_obj_coeff = auxiliary_obj_coeff

    def _set_auxiliary_obj(self, aux):
        """Set the auxiliary term used by subsequent loss computations.

        Args:
            aux: The auxiliary term, or None to disable it. NOTE(review):
                presumably a scalar torch tensor attached to the autograd
                graph — confirm with the caller.
        """
        self._auxiliary_obj = aux

    def _compute_loss_with_adv(self, obs, actions, rewards, advantages):
        """Return the parent PPO loss plus the weighted auxiliary term.

        Args:
            obs: Observations, forwarded to the parent.
            actions: Actions, forwarded to the parent.
            rewards: Rewards, forwarded to the parent.
            advantages: Advantages, forwarded to the parent.

        Returns:
            The parent's result alone if no auxiliary term has been set;
            otherwise the parent's result plus
            ``auxiliary_obj * auxiliary_obj_coeff``.
        """
        # NOTE(review): in garage the parent returns a loss to be minimized
        # (the negated surrogate objective), so the auxiliary term is
        # minimized too. Confirm this is the intended sign.
        ppo_loss = super()._compute_loss_with_adv(
            obs, actions, rewards, advantages
        )
        if self._auxiliary_obj is None:
            # Nothing has been set yet; fall back to plain PPO instead of
            # failing on None * coeff.
            return ppo_loss
        # NOTE(review): if the same graph-attached tensor is reused across
        # several optimizer steps, a second backward() through it will fail
        # unless the caller refreshes it (or retains the graph) — verify.
        return ppo_loss + self._auxiliary_obj * self._auxiliary_obj_coeff
