Reinforcement Learning: Proximal Policy Optimization (PPO)
===============================================================
Author: `lucifer2859 <https://github.com/lucifer2859>`_

Translator: `LiutaoYu <https://github.com/LiutaoYu>`_

This tutorial applies a spiking neural network to reproduce `ppo.py <https://github.com/lucifer2859/Policy-Gradients/blob/master/ppo.py>`_.
Please make sure that you have read the original tutorial and the corresponding code before proceeding.

Here, we use the same method as in the previous DQN tutorial to make the SNN output floating-point numbers.
We set the firing threshold of a neuron to infinity so that it never fires, and we use its final membrane potential as the output: the mean of the action distribution for the actor and the state value for the critic.
Such neurons are easy to implement in the ``SpikingJelly`` framework: inherit from the LIF neuron ``neuron.LIFNode`` and override the ``forward`` function.

.. code-block:: python

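    import torch
    from spikingjelly.activation_based import neuron

    class NonSpikingLIFNode(neuron.LIFNode):
        def forward(self, dv: torch.Tensor):
            # Integrate the input only; firing and reset are skipped,
            # so the neuron never spikes and self.v is the output.
            self.neuronal_charge(dv)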
            # self.neuronal_fire()
            # self.neuronal_reset()
            return self.v
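
The effect of skipping the fire and reset steps can be illustrated without any framework.
The following is a minimal sketch in plain Python, not the ``SpikingJelly`` implementation: it assumes the charging rule ``v <- v + (x - (v - v_reset)) / tau`` with ``v_reset = 0``, and the function names ``lif_charge`` and ``run_non_spiking_lif`` are hypothetical.

.. code-block:: python

    def lif_charge(v, x, tau=2.0, v_reset=0.0):
        """One leaky charging step (assumed rule): v <- v + (x - (v - v_reset)) / tau."""
        return v + (x - (v - v_reset)) / tau

    def run_non_spiking_lif(inputs, tau=2.0):
        """Integrate the inputs without firing or resetting; return the final potential."""
        v = 0.0
        for x in inputs:
            v = lif_charge(v, x, tau)
        return v

    T = 16
    # A constant input of 1.0 drives the potential towards 1.0 without any spike.
    final_v = run_non_spiking_lif([1.0] * T)
    print(final_v)

Under these assumptions the final potential is a real number that depends smoothly on the inputs, which is why it can serve as a continuous network output.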

The Spiking Actor-Critic Network has a simple structure: an input layer, an IF neuron layer, and a NonSpikingLIF neuron layer,
connected by fully connected (linear) layers.
The IF neuron layer acts as an encoder that converts the CartPole state variables into spikes,
and the NonSpikingLIF neuron layer can be regarded as the decision-making unit.

.. code-block:: python

    class ActorCritic(nn.Module):
        def __init__(self, num_inputs, num_outputs, hidden_size, T=16, std=0.0):
            super(ActorCritic, self).__init__()
            
            self.critic = nn.Sequential(
                nn.Linear(num_inputs, hidden_size),
                neuron.IFNode(),
                nn.Linear(hidden_size, 1),
                NonSpikingLIFNode(tau=2.0)
            )
            
            self.actor = nn.Sequential(
                nn.Linear(num_inputs, hidden_size),
                neuron.IFNode(),
                nn.Linear(hidden_size, num_outputs),
                NonSpikingLIFNode(tau=2.0)
            )

            self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)

            self.T = T
            
        def forward(self, x):
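            # Present the same input for T time steps; only the final
            # membrane potentials of the output layers are read out.
            for t in range(self.T):
                self.critic(x)
                self.actor(x)
            value = self.critic[-1].v
            mu    = self.actor[-1].v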
            std   = self.log_std.exp().expand_as(mu)
            dist  = Normal(mu, std)
            return dist, value
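
The encoding role of the IF neuron layer can be illustrated with a single neuron.
The following is a minimal sketch in plain Python, assuming an IF neuron with threshold ``1.0`` and a hard reset to ``0.0``; the function ``if_spike_train`` is a hypothetical helper, not part of ``SpikingJelly``.
It shows that a larger constant input produces more spikes within the same ``T`` time steps.

.. code-block:: python

    def if_spike_train(x, T=16, v_threshold=1.0, v_reset=0.0):
        """Drive one IF neuron with a constant input x for T steps; return its 0/1 spikes."""
        v = v_reset
        spikes = []
        for _ in range(T):
            v += x                      # charge: pure integration, no leak
            fired = v >= v_threshold    # fire once the potential reaches the threshold
            spikes.append(int(fired))
            if fired:
                v = v_reset             # hard reset after a spike
        return spikes

    # The spike count grows with the input strength.
    for x in (0.25, 0.5, 1.0):
        print(x, sum(if_spike_train(x)))

In this sketch, a non-positive input never reaches the threshold and therefore produces no spikes at all.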


Training the network
---------------------------
The code in this part is almost the same as in the ANN version.
Note, however, that the SNN version uses the ``Observation`` returned by ``env`` as the network input.

The training code of the SNN version is shown below.
During training, we save the model parameters that achieve the highest reward.

.. code-block:: python

    # Generalized Advantage Estimation (GAE); ``tau`` is the GAE smoothing factor (lambda)
    def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
        values = values + [next_value]
        gae = 0
        returns = []
        for step in reversed(range(len(rewards))):
            delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
            gae = delta + gamma * tau * masks[step] * gae
            returns.insert(0, gae + values[step])
        return returns

    # Proximal Policy Optimization Algorithm
    # Arxiv: "https://arxiv.org/abs/1707.06347"
    def ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantage):
        batch_size = states.size(0)
        ids = np.random.permutation(batch_size)
        ids = np.split(ids[:batch_size // mini_batch_size * mini_batch_size], batch_size // mini_batch_size)
        for i in range(len(ids)):
            yield states[ids[i], :], actions[ids[i], :], log_probs[ids[i], :], returns[ids[i], :], advantage[ids[i], :]

    def ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantages, clip_param=0.2):
        for _ in range(ppo_epochs):
            for state, action, old_log_probs, return_, advantage in ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantages):
                dist, value = model(state)
                functional.reset_net(model)
                entropy = dist.entropy().mean()
                new_log_probs = dist.log_prob(action)

                ratio = (new_log_probs - old_log_probs).exp()
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage

                actor_loss  = - torch.min(surr1, surr2).mean()
                critic_loss = (return_ - value).pow(2).mean()

                loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    while step_idx < max_steps:

        log_probs = []
        values    = []
        states    = []
        actions   = []
        rewards   = []
        masks     = []
        entropy = 0

        for _ in range(num_steps):
            state = torch.FloatTensor(state).to(device)
            dist, value = model(state)
            functional.reset_net(model)

            action = dist.sample()
            next_state, reward, done, _ = envs.step(torch.max(action, 1)[1].cpu().numpy())

            log_prob = dist.log_prob(action)
            entropy += dist.entropy().mean()
            
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
            masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))
            
            states.append(state)
            actions.append(action)
            
            state = next_state
            step_idx += 1
            
            if step_idx % 100 == 0:
                test_reward = test_env()
                print('Step: %d, Reward: %.2f' % (step_idx, test_reward))
                writer.add_scalar('Spiking-PPO-' + env_name + '/Reward', test_reward, step_idx)

        next_state = torch.FloatTensor(next_state).to(device)
        _, next_value = model(next_state)
        functional.reset_net(model)
        returns = compute_gae(next_value, rewards, masks, values)

        returns   = torch.cat(returns).detach()
        log_probs = torch.cat(log_probs).detach()
        values    = torch.cat(values).detach()
        states    = torch.cat(states)
        actions   = torch.cat(actions)
        advantage = returns - values
        
        ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantage)
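
The two numerical pieces of the training code above, the GAE returns and the clipped surrogate objective, can be checked by hand on scalars.
The following is a minimal sketch in plain Python, not the tutorial's code: ``compute_gae_scalar`` mirrors ``compute_gae`` with floats in place of tensors, and ``clipped_surrogate`` is a hypothetical helper that evaluates one element of ``torch.min(surr1, surr2)``.

.. code-block:: python

    def compute_gae_scalar(next_value, rewards, masks, values, gamma=0.99, lam=0.95):
        """Same recursion as compute_gae, on plain floats (lam plays the role of tau)."""
        values = values + [next_value]
        gae = 0.0
        returns = []
        for step in reversed(range(len(rewards))):
            # TD error; the mask removes the bootstrap term at episode end.
            delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
            gae = delta + gamma * lam * masks[step] * gae
            returns.insert(0, gae + values[step])
        return returns

    def clipped_surrogate(ratio, advantage, clip_param=0.2):
        """One element of min(ratio * A, clamp(ratio, 1 - eps, 1 + eps) * A)."""
        clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
        return min(ratio * advantage, clipped_ratio * advantage)

    # Toy two-step episode that terminates at the second step (mask 0).
    returns = compute_gae_scalar(next_value=5.0, rewards=[1.0, 1.0],
                                 masks=[1.0, 0.0], values=[0.5, 0.5])
    print(returns)
    print(clipped_surrogate(1.5, 1.0), clipped_surrogate(1.5, -1.0))

Because the last mask is zero, ``next_value`` has no influence on the returns in this toy episode.
The clipping only limits the objective when the ratio moves in the direction that the advantage favours; in the opposite direction the unclipped term is the smaller one and is kept.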


It should be emphasized that we need to ``reset`` the network after each forward pass,
because an SNN is stateful: its neurons retain their membrane potentials, while each new input should be processed starting from a clean network state.
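
Why the reset matters can be seen in a toy model.
The sketch below uses a hypothetical ``StatefulNeuron`` class in plain Python (it is not part of ``SpikingJelly``) that keeps a membrane potential between calls, assuming a leaky update with ``tau = 2`` and a resting potential of ``0``.

.. code-block:: python

    class StatefulNeuron:
        """Hypothetical stand-in for a stateful layer: it remembers its membrane potential."""

        def __init__(self, tau=2.0):
            self.tau = tau
            self.v = 0.0

        def forward(self, x):
            # Leaky integration; the state self.v persists across calls.
            self.v = self.v + (x - self.v) / self.tau
            return self.v

        def reset(self):
            self.v = 0.0

    def run(neuron, x, T=4):
        """Present the constant input x for T steps and return the final potential."""
        out = neuron.v
        for _ in range(T):
            out = neuron.forward(x)
        return out

    neuron = StatefulNeuron()
    first = run(neuron, 1.0)    # response to the first input
    leaked = run(neuron, 0.0)   # no reset: the old potential leaks into this result
    neuron.reset()
    clean = run(neuron, 0.0)    # after reset: the result depends on this input only
    print(first, leaked, clean)

Without the reset, the output for the second input still carries a trace of the first one, which is the contamination that ``functional.reset_net(model)`` avoids in the training code.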

The complete script is available at `activation_based/examples/Spiking_PPO.py <https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/examples/Spiking_PPO.py>`_.
Training can be started from the command line as follows.

.. code-block:: bash

    python Spiking_PPO.py

Performance comparison between ANN and SNN
------------------------------------------------------
The reward curve over a training process of 1e5 episodes is shown below:

.. image:: ../_static/tutorials/activation_based/\8_ppo_cart_pole/Spiking-PPO-CartPole-v0.*
    :width: 100%

The result of the ANN version with the same settings is shown below.
The complete code is available at `activation_based/examples/PPO.py <https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/examples/PPO.py>`_.

.. image:: ../_static/tutorials/activation_based/\8_ppo_cart_pole/PPO-CartPole-v0.*
    :width: 100%