
from tianshou.policy import C51Policy

from .dqn import CustomDQNPolicy


class CustomC51Policy(C51Policy, CustomDQNPolicy):
    """Categorical Deep Q-Network (C51) policy mixed with ``CustomDQNPolicy``.

    Implements the Categorical DQN algorithm, arXiv:1707.06887, by combining
    tianshou's ``C51Policy`` with the project-local ``CustomDQNPolicy``
    through multiple inheritance. ``C51Policy`` is listed first, so it
    precedes ``CustomDQNPolicy`` in the method resolution order and its own
    attributes shadow same-named ones defined on ``CustomDQNPolicy``.

    NOTE(review): the parameters below are presumably those of the inherited
    constructor, as documented for tianshou's ``C51Policy``; confirm them
    against the installed tianshou version and against ``CustomDQNPolicy``.

    :param model: a model following the rules (s_B -> action_values_BA)
    :param optim: a torch.optim for optimizing the model.
    :param discount_factor: in [0, 1].
    :param num_atoms: the number of atoms in the support set of the
        value distribution. Defaults to 51.
    :param v_min: the value of the smallest atom in the support set.
        Defaults to -10.0.
    :param v_max: the value of the largest atom in the support set.
        Defaults to 10.0.
    :param estimation_step: the number of steps to look ahead.
    :param target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param reward_normalization: normalize the **returns** to Normal(0, 1).
        TODO: rename to return_normalization?
    :param is_double: use double DQN.
    :param clip_loss_grad: clip the gradient of the loss in accordance
        with nature14236; this amounts to using the Huber loss instead of
        the MSE loss.
    :param observation_space: the environment's observation space.
    :param lr_scheduler: if not None, will be called in ``policy.update()``.

    .. seealso::

        Please refer to :class:`~tianshou.policy.DQNPolicy` for a more
        detailed explanation.
    """
