C51_Learner
======================

.. raw:: html

    <br><hr>

PyTorch
------------------------------------------

.. py:class::
  xuance.torch.learners.qlearning_family.c51_learner.C51_Learner(policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)

  :param policy: The policy that provides actions and values.
  :type policy: nn.Module
  :param optimizer: The optimizer that updates the parameters of the model.
  :type optimizer: torch.optim.Optimizer
  :param scheduler: The scheduler used for learning rate decay.
  :type scheduler: torch.optim.lr_scheduler._LRScheduler
  :param device: The device on which the calculations are performed.
  :type device: int, str or torch.device
  :param model_dir: The directory for saving or loading the model parameters.
  :type model_dir: str
  :param gamma: The discount factor.
  :type gamma: float
  :param sync_frequency: The frequency to synchronize the target networks.
  :type sync_frequency: int

.. py:function::
  xuance.torch.learners.qlearning_family.c51_learner.C51_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)

  :param obs_batch: A batch of observations sampled from experience replay buffer.
  :type obs_batch: np.ndarray
  :param act_batch: A batch of actions sampled from experience replay buffer.
  :type act_batch: np.ndarray
  :param rew_batch: A batch of rewards sampled from experience replay buffer.
  :type rew_batch: np.ndarray
  :param next_batch: A batch of next observations sampled from experience replay buffer.
  :type next_batch: np.ndarray
  :param terminal_batch: A batch of terminal data sampled from experience replay buffer.
  :type terminal_batch: np.ndarray
  :return: A dictionary of training information (the loss and the current learning rate).
  :rtype: dict

.. raw:: html

    <br><hr>

TensorFlow
------------------------------------------

.. py:class::
  xuance.tensorflow.learners.qlearning_family.c51_learner.C51_Learner(policy, optimizer, device, model_dir, gamma, sync_frequency)

  :param policy: The policy that provides actions and values.
  :type policy: tk.Model
  :param optimizer: The optimizer that updates the parameters of the model.
  :type optimizer: tk.optimizers.Optimizer
  :param device: The calculating device.
  :type device: str
  :param model_dir: The directory for saving or loading the model parameters.
  :type model_dir: str
  :param gamma: The discount factor.
  :type gamma: float
  :param sync_frequency: The frequency to synchronize the target networks.
  :type sync_frequency: int

.. py:function::
  xuance.tensorflow.learners.qlearning_family.c51_learner.C51_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)

  :param obs_batch: A batch of observations sampled from experience replay buffer.
  :type obs_batch: np.ndarray
  :param act_batch: A batch of actions sampled from experience replay buffer.
  :type act_batch: np.ndarray
  :param rew_batch: A batch of rewards sampled from experience replay buffer.
  :type rew_batch: np.ndarray
  :param next_batch: A batch of next observations sampled from experience replay buffer.
  :type next_batch: np.ndarray
  :param terminal_batch: A batch of terminal data sampled from experience replay buffer.
  :type terminal_batch: np.ndarray
  :return: A dictionary of training information (the loss and the current learning rate).
  :rtype: dict

.. raw:: html

    <br><hr>

MindSpore
------------------------------------------

.. py:class::
  xuance.mindspore.learners.qlearning_family.c51_learner.C51_Learner(policy, optimizer, scheduler, model_dir, gamma, sync_frequency)

  :param policy: The policy that provides actions and values.
  :type policy: nn.Cell
  :param optimizer: The optimizer that updates the parameters of the model.
  :type optimizer: nn.Optimizer
  :param scheduler: The tool for learning rate decay.
  :type scheduler: lr_scheduler
  :param model_dir: The directory for saving or loading the model parameters.
  :type model_dir: str
  :param gamma: The discount factor.
  :type gamma: float
  :param sync_frequency: The frequency to synchronize the target networks.
  :type sync_frequency: int

.. py:function::
  xuance.mindspore.learners.qlearning_family.c51_learner.C51_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)

  :param obs_batch: A batch of observations sampled from experience replay buffer.
  :type obs_batch: np.ndarray
  :param act_batch: A batch of actions sampled from experience replay buffer.
  :type act_batch: np.ndarray
  :param rew_batch: A batch of rewards sampled from experience replay buffer.
  :type rew_batch: np.ndarray
  :param next_batch: A batch of next observations sampled from experience replay buffer.
  :type next_batch: np.ndarray
  :param terminal_batch: A batch of terminal data sampled from experience replay buffer.
  :type terminal_batch: np.ndarray
  :return: A dictionary of training information (the loss and the current learning rate).
  :rtype: dict

.. raw:: html

    <br><hr>

Source Code
-----------------

.. tabs::

  .. group-tab:: PyTorch

    .. code-block:: python

        from xuance.torch.learners import *


        class C51_Learner(Learner):
            """Learner for a C51-style policy that outputs a distribution over supports.

            Every call to ``update`` takes one optimizer step on the cross-entropy
            between the projected target distribution and the predicted
            distribution of the taken actions, and calls ``policy.copy_target()``
            once every ``sync_frequency`` updates.
            """

            def __init__(self,
                         policy: nn.Module,
                         optimizer: torch.optim.Optimizer,
                         scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                         device: Optional[Union[int, str, torch.device]] = None,
                         model_dir: str = "./",
                         gamma: float = 0.99,
                         sync_frequency: int = 100):
                """Store the hyper-parameters and delegate the rest to ``Learner``.

                :param policy: Policy network; must provide ``target``, ``supports``,
                    ``vmin``, ``vmax``, ``deltaz`` and ``copy_target``.
                :param optimizer: Optimizer stepped once per ``update`` call.
                :param scheduler: Optional learning-rate scheduler, stepped after the optimizer.
                :param device: Device on which the batch tensors are created.
                :param model_dir: Forwarded unchanged to the base ``Learner``.
                :param gamma: Discount factor applied to the supports.
                :param sync_frequency: Number of updates between target-network copies.
                """
                # Hyper-parameters are assigned before the base initializer runs.
                self.gamma = gamma
                self.sync_frequency = sync_frequency
                super().__init__(policy, optimizer, scheduler, device, model_dir)

            def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
                """Run a single training step and return its logging information.

                :param obs_batch: Observations fed to the online policy.
                :param act_batch: Actions taken; cast to ``long`` and one-hot encoded.
                :param rew_batch: Rewards added to the discounted supports.
                :param next_batch: Next observations fed to the target policy.
                :param terminal_batch: Terminal flags; ``1 - flag`` masks the discounted supports.
                :return: ``dict`` with the keys ``"Qloss"`` and ``"learning_rate"``.
                """
                self.iterations += 1
                actions = torch.as_tensor(act_batch, device=self.device).long()
                rewards = torch.as_tensor(rew_batch, device=self.device)
                terminals = torch.as_tensor(terminal_batch, device=self.device)

                _, _, eval_z = self.policy(obs_batch)
                _, target_a, target_z = self.policy.target(next_batch)
                # dim 1 of the predicted distribution is used as the one-hot width
                # (presumably the number of discrete actions -- confirm against the policy).
                num_actions = eval_z.shape[1]

                # Pick the distribution of the taken action / the target-selected action.
                taken_mask = F.one_hot(actions, num_actions).unsqueeze(-1)
                current_dist = (eval_z * taken_mask).sum(1)
                greedy_mask = F.one_hot(target_a.detach(), num_actions).unsqueeze(-1)
                target_dist = (target_z * greedy_mask).sum(1).detach()

                # Shift the supports by the reward and discount, then clip to [vmin, vmax].
                supports = self.policy.supports
                not_done = 1 - terminals.unsqueeze(1)
                shifted = rewards.unsqueeze(1) + self.gamma * supports * not_done
                shifted = shifted.clamp(self.policy.vmin, self.policy.vmax)

                # Linear interpolation weights between every shifted and every fixed support.
                gap = (shifted.unsqueeze(-1) - supports.unsqueeze(0)).abs()
                weights = (1 - gap / self.policy.deltaz).clamp(0, 1)
                projected = torch.bmm(target_dist.unsqueeze(1), weights).squeeze(1)

                # Cross-entropy; the 1e-8 keeps the logarithm finite.
                log_pred = torch.log(current_dist + 1e-8)
                loss = -(projected * log_pred).sum(1).mean()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()

                # hard update for target network
                if self.iterations % self.sync_frequency == 0:
                    self.policy.copy_target()

                lr = self.optimizer.state_dict()['param_groups'][0]['lr']
                return {
                    "Qloss": loss.item(),
                    "learning_rate": lr
                }






  .. group-tab:: TensorFlow

    .. code-block:: python

        from xuance.tensorflow.learners import *


        class C51_Learner(Learner):
            """Learner for a C51-style policy that outputs a distribution over supports.

            Every call to ``update`` applies one gradient step on the cross-entropy
            between the projected target distribution and the predicted
            distribution of the taken actions, and calls ``policy.copy_target()``
            once every ``sync_frequency`` updates.
            """

            def __init__(self,
                         policy: tk.Model,
                         optimizer: tk.optimizers.Optimizer,
                         device: str = "cpu:0",
                         model_dir: str = "./",
                         gamma: float = 0.99,
                         sync_frequency: int = 100):
                """Store the hyper-parameters and delegate the rest to ``Learner``.

                :param policy: Policy model; must provide ``target``, ``supports``,
                    ``vmin``, ``vmax``, ``deltaz`` and ``copy_target``.
                :param optimizer: Optimizer that applies the gradients.
                :param device: Device string passed to ``tf.device`` during updates.
                :param model_dir: Forwarded unchanged to the base ``Learner``.
                :param gamma: Discount factor applied to the supports.
                :param sync_frequency: Number of updates between target-network copies.
                """
                # Hyper-parameters are assigned before the base initializer runs.
                self.gamma = gamma
                self.sync_frequency = sync_frequency
                super().__init__(policy, optimizer, device, model_dir)

            def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
                """Run a single training step and return its logging information.

                :param obs_batch: Observations fed to the online policy.
                :param act_batch: Actions taken; cast to ``int64`` and one-hot encoded.
                :param rew_batch: Rewards added to the discounted supports.
                :param next_batch: Next observations fed to the target policy.
                :param terminal_batch: Terminal flags; ``1 - flag`` masks the discounted supports.
                :return: ``dict`` with the keys ``"Qloss"`` and ``"lr"``.
                """
                self.iterations += 1
                with tf.device(self.device):
                    actions = tf.cast(tf.convert_to_tensor(act_batch), dtype=tf.int64)
                    rewards = tf.convert_to_tensor(rew_batch)
                    terminals = tf.convert_to_tensor(terminal_batch)

                    with tf.GradientTape() as tape:
                        _, _, eval_z = self.policy(obs_batch)
                        _, target_a, target_z = self.policy.target(next_batch)
                        # dim 1 of the predicted distribution is used as the one-hot depth
                        # (presumably the number of discrete actions -- confirm against the policy).
                        num_actions = eval_z.shape[1]

                        # Pick the distribution of the taken action / the target-selected action.
                        taken_mask = tf.expand_dims(tf.one_hot(actions, num_actions), axis=-1)
                        current_dist = tf.reduce_sum(eval_z * taken_mask, axis=1)
                        greedy_mask = tf.expand_dims(tf.one_hot(target_a, num_actions), axis=-1)
                        target_dist = tf.stop_gradient(tf.reduce_sum(target_z * greedy_mask, axis=1))

                        # Shift the supports by the reward and discount, then clip to [vmin, vmax].
                        supports = self.policy.supports
                        not_done = 1 - tf.expand_dims(terminals, 1)
                        shifted = tf.expand_dims(rewards, 1) + self.gamma * supports * not_done
                        shifted = tf.clip_by_value(shifted, self.policy.vmin, self.policy.vmax)

                        # Linear interpolation weights between every shifted and every fixed support.
                        gap = tf.math.abs(tf.expand_dims(shifted, -1) - tf.expand_dims(supports, 0))
                        weights = tf.clip_by_value(1 - gap / self.policy.deltaz, 0, 1)
                        projected = tf.squeeze(tf.linalg.matmul(tf.expand_dims(target_dist, 1), weights), 1)

                        # Cross-entropy; the 1e-8 keeps the logarithm finite.
                        log_pred = tf.math.log(current_dist + 1e-8)
                        loss = -tf.reduce_mean(tf.reduce_sum(projected * log_pred, axis=1))

                    variables = self.policy.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    # Variables that received no gradient are left out of the update.
                    grads_and_vars = [(g, v) for g, v in zip(gradients, variables) if g is not None]
                    self.optimizer.apply_gradients(grads_and_vars)

                    # hard update for target network
                    if self.iterations % self.sync_frequency == 0:
                        self.policy.copy_target()

                    lr = self.optimizer._decayed_lr(tf.float32)

                    # NOTE(review): the PyTorch and MindSpore listings report this value
                    # under "learning_rate"; the key is kept as "lr" here on purpose.
                    return {
                        "Qloss": loss.numpy(),
                        "lr": lr.numpy(),
                    }



  .. group-tab:: MindSpore

    .. code-block:: python

        from xuance.mindspore.learners import *
        from mindspore.ops import OneHot,Log,BatchMatMul,ExpandDims,Squeeze,ReduceSum,Abs,ReduceMean,clip_by_value


        class C51_Learner(Learner):
            """Learner for a C51-style policy that outputs a distribution over supports.

            The loss is wrapped in ``PolicyNetWithLossCell`` and optimized through
            ``nn.TrainOneStepCell``; ``policy.copy_target()`` is called once every
            ``sync_frequency`` updates.
            """

            class PolicyNetWithLossCell(nn.Cell):
                """Cell that joins the policy forward pass with the cross-entropy loss."""

                def __init__(self, backbone):
                    """
                    :param backbone: The policy cell whose third output is the predicted distribution.
                    """
                    super(C51_Learner.PolicyNetWithLossCell, self).__init__(auto_prefix=False)
                    self._backbone = backbone
                    self._onehot = OneHot()
                    self._log = Log()
                    self._bmm = BatchMatMul()
                    self._unsqueeze = ExpandDims()
                    self._squeeze = Squeeze(1)
                    self._sum = ReduceSum()
                    self._mean = ReduceMean()
                    # on/off values written by the one-hot encoding
                    self.on_value = Tensor(1.0, ms.float32)
                    self.off_value = Tensor(0.0, ms.float32)
                    # bounds used to clip the projection weights to [0, 1]
                    self.clamp_min_value = Tensor(0.0, ms.float32)
                    self.clamp_max_value = Tensor(1.0, ms.float32)

                def construct(self, x, a, projection, target_a, target_z):
                    """Compute the cross-entropy loss for one batch.

                    :param x: Observations fed to the backbone.
                    :param a: Actions taken, one-hot encoded along dim 1 of the prediction.
                    :param projection: Unclipped interpolation weights between shifted and fixed supports.
                    :param target_a: Actions selected for the next observations.
                    :param target_z: Distributions predicted for the next observations.
                    :return: The scalar loss.
                    """
                    _, _, evalZ = self._backbone(x)

                    # Pick the distribution of the taken action / the target-selected action.
                    current_dist = self._sum(evalZ * self._unsqueeze(self._onehot(a, evalZ.shape[1], self.on_value, self.off_value), -1), 1)
                    target_dist = self._sum(target_z * self._unsqueeze(self._onehot(target_a, evalZ.shape[1], self.on_value, self.off_value), -1), 1)

                    # Redistribute the target mass onto the fixed supports.
                    target_dist = self._squeeze(self._bmm(self._unsqueeze(target_dist, 1), clip_by_value(projection, self.clamp_min_value, self.clamp_max_value)))
                    # Cross-entropy; the 1e-8 keeps the logarithm finite.
                    loss = -self._mean(self._sum((target_dist * self._log(current_dist + 1e-8)), 1))

                    return loss

            def __init__(self,
                         policy: nn.Cell,
                         optimizer: nn.Optimizer,
                         scheduler: Optional[nn.exponential_decay_lr] = None,
                         model_dir: str = "./",
                         gamma: float = 0.99,
                         sync_frequency: int = 100):
                """Store the hyper-parameters and build the one-step training cell.

                :param policy: Policy cell; must provide ``target``, ``supports``,
                    ``vmin``, ``vmax``, ``deltaz`` and ``copy_target``.
                :param optimizer: Optimizer wrapped by ``nn.TrainOneStepCell``.
                :param scheduler: Callable mapping the iteration count to a learning rate.
                :param model_dir: Forwarded unchanged to the base ``Learner``.
                :param gamma: Discount factor applied to the supports.
                :param sync_frequency: Number of updates between target-network copies.
                """
                self.gamma = gamma
                self.sync_frequency = sync_frequency
                super(C51_Learner, self).__init__(policy, optimizer, scheduler, model_dir)
                # connect the feed forward network with loss function.
                self.loss_net = self.PolicyNetWithLossCell(policy)
                # define the training network
                self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer)
                # set the training network as train mode.
                self.policy_train.set_train()

                self._abs = Abs()
                self._unsqueeze = ExpandDims()

            def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
                """Run a single training step and return its logging information.

                :param obs_batch: Observations fed to the online policy.
                :param act_batch: Actions taken; converted to an ``int32`` tensor.
                :param rew_batch: Rewards added to the discounted supports.
                :param next_batch: Next observations fed to the target policy.
                :param terminal_batch: Terminal flags; ``1 - flag`` masks the discounted supports.
                :return: ``dict`` with the keys ``"Qloss"`` and ``"learning_rate"``.
                """
                self.iterations += 1
                obs_batch = Tensor(obs_batch)
                act_batch = Tensor(act_batch, ms.int32)
                rew_batch = Tensor(rew_batch)
                next_batch = Tensor(next_batch)
                ter_batch = Tensor(terminal_batch)

                # Bootstrap from the target network, as the PyTorch and TensorFlow
                # listings do; evaluating the online policy here would leave the
                # periodically synchronized target network unused.
                _, targetA, targetZ = self.policy.target(next_batch)

                # Shift the supports by the reward and discount, then clip to [vmin, vmax].
                current_supports = self.policy.supports
                next_supports = self._unsqueeze(rew_batch, 1) + self.gamma * self.policy.supports * (1 - self._unsqueeze(ter_batch, -1))
                next_supports = clip_by_value(next_supports, Tensor(self.policy.vmin, ms.float32), Tensor(self.policy.vmax, ms.float32))
                # Interpolation weights; they are clipped to [0, 1] inside the loss cell.
                projection = 1 - self._abs((self._unsqueeze(next_supports, -1) - self._unsqueeze(current_supports, 0))) / self.policy.deltaz

                loss = self.policy_train(obs_batch, act_batch, projection, targetA, targetZ)

                # hard update for target network
                if self.iterations % self.sync_frequency == 0:
                    self.policy.copy_target()

                # NOTE(review): ``scheduler`` defaults to None in ``__init__`` but is
                # called unconditionally here -- callers must always supply one.
                lr = self.scheduler(self.iterations).asnumpy()

                info = {
                    "Qloss": loss.asnumpy(),
                    "learning_rate": lr
                }

                return info
