import numpy as np
from tianshou.data import ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol, BatchWithReturnsProtocol
from tianshou.policy import DQNPolicy

from policy.multi_objective.multi_objective_policy import MultiObjectivePolicy


class MODQNPolicy(MultiObjectivePolicy, DQNPolicy):
    """DQN policy that forwards reward weights to the n-step return computation.

    The only behaviour defined here is :meth:`process_fn`, which calls
    ``self.compute_nstep_return`` with the policy's ``reward_weights`` in
    addition to the target-Q function, ``gamma``, ``n_step`` and ``rew_norm``.
    """

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> BatchWithReturnsProtocol:
        """Delegate target computation to ``self.compute_nstep_return``.

        :param batch: the sampled batch, passed through unchanged.
        :param buffer: the replay buffer, passed through unchanged.
        :param indices: the buffer indices, passed through unchanged.
        :return: the value returned by ``self.compute_nstep_return``.
        """
        # NOTE(review): ``reward_weights`` and a ``compute_nstep_return`` that
        # accepts it are presumably supplied by ``MultiObjectivePolicy`` (first
        # in the base list) -- confirm against that class.
        nstep_return_fn = self.compute_nstep_return
        batch_with_returns = nstep_return_fn(
            batch=batch, buffer=buffer, indices=indices,
            target_q_fn=self._target_q,
            gamma=self.gamma, n_step=self.n_step, rew_norm=self.rew_norm,
            reward_weights=self.reward_weights,
        )
        return batch_with_returns
