<h2 align="center">ReinMax</h2>
<h4 align="center"> Beyond Straight-Through</h4>

<p align="center">
  <a href="#st">Straight-Through</a> •
  <a href="#reinmax">ReinMax</a> •
  <a href="#how-to-use">How To Use</a> •
  <a href="#examples">Examples</a>
</p>


ReinMax estimates gradients with **second-order** accuracy while being **as fast as** the original Straight-Through estimator, which has only first-order accuracy.

<h3 align="center" id="st"><i>What is Straight-Through?</i></h3>

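Straight-Through, shown below, bridges discrete variables (`y_hard`) and back-propagation: the forward pass uses a discrete sample, while the backward pass uses the gradient of the continuous `y_soft`.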
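```python
import torch

def one_hot_multinomial(probs):
    # non-differentiable: sample one category per row and one-hot encode it
    idx = torch.multinomial(probs.detach(), num_samples=1)
    return torch.zeros_like(probs).scatter_(-1, idx, 1.0)

theta = torch.randn(4, 10, requires_grad=True)  # example logits
y_soft = theta.softmax(dim=-1)
y_hard = one_hot_multinomial(y_soft)

# with straight-through, the forward value is still y_hard, but the
# derivative of y_hard acts as if y_soft had been used in the forward pass
y_hard = y_soft - y_soft.detach() + y_hard
```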
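A minimal sanity check of the trick above, assuming PyTorch and an arbitrary linear loss `(w * y).sum()` chosen only for illustration: the straight-through output equals the one-hot sample in the forward pass, and its gradient with respect to `theta` matches the gradient computed from `y_soft` alone.

```python
import torch

torch.manual_seed(0)
theta = torch.randn(5, requires_grad=True)   # logits of one 5-way categorical variable
w = torch.randn(5)                           # arbitrary downstream weights (illustrative)

y_soft = theta.softmax(dim=-1)
# non-differentiable sample, one-hot encoded
idx = torch.multinomial(y_soft.detach(), num_samples=1)
y_hard = torch.zeros(5).scatter_(-1, idx, 1.0)

# straight-through: value of y_hard, gradient of y_soft
y_st = y_soft - y_soft.detach() + y_hard

grad_st, = torch.autograd.grad((w * y_st).sum(), theta)
grad_soft, = torch.autograd.grad((w * theta.softmax(dim=-1)).sum(), theta)

print(torch.equal(y_st.detach(), y_hard))   # True: forward value is the sample
print(torch.allclose(grad_st, grad_soft))   # True: backward matches y_soft
```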
How Straight-Through works has long been a mystery, which leaves many questions open, such as whether the surrogate in the backward pass should be:
- `y_soft - y_soft.detach()`,
- `(theta/tau).softmax() - (theta/tau).softmax().detach()`,
- or something else?
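These questions matter because the surrogate changes the gradient but not the forward pass. A small illustration, assuming PyTorch, one shared sample, and two arbitrary temperatures (`tau = 1` and `tau = 2`); the helper `straight_through` is hypothetical. Both variants produce the same forward value yet different gradient estimates, so the forward computation alone cannot settle the choice.

```python
import torch

torch.manual_seed(0)
theta = torch.randn(5, requires_grad=True)   # logits
w = torch.randn(5)                           # arbitrary downstream weights (illustrative)

# one shared one-hot sample drawn from softmax(theta)
idx = torch.multinomial(theta.detach().softmax(dim=-1), num_samples=1)
y_hard = torch.zeros(5)
y_hard[idx] = 1.0

def straight_through(tau):
    # hypothetical helper: the tempered softmax is used only in the backward pass
    y_soft = (theta / tau).softmax(dim=-1)
    y = y_soft - y_soft.detach() + y_hard
    grad, = torch.autograd.grad((w * y).sum(), theta)
    return y.detach(), grad

y1, g1 = straight_through(tau=1.0)
y2, g2 = straight_through(tau=2.0)

print(torch.equal(y1, y2))      # same forward value
print(torch.allclose(g1, g2))   # different gradient estimates
```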



<h3 align="center" id="reinmax"><i>Understand Straight-Through and Go Beyond</i></h3>

We show that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. 
Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which *approximates the gradient with second-order accuracy at negligible computational overhead.*
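As a rough illustration of the idea, below is a minimal PyTorch sketch of a ReinMax-style estimator. It is not the packaged `reinmax` implementation: the name `reinmax_sketch`, the `(batch, categories)` layout, and the specific steps are assumptions made for illustration. Where Straight-Through relies on a single gradient evaluation of a softmax surrogate (a forward Euler step), the sketch also evaluates the softmax Jacobian at the midpoint between the one-hot sample and `softmax(theta / tau)`, then combines the two evaluations, much as Heun's method combines two slope evaluations.

```python
import torch

def reinmax_sketch(theta: torch.Tensor, tau: float) -> torch.Tensor:
    """Illustrative sketch (not the packaged implementation): one-hot samples
    from softmax(theta) with a ReinMax-style backward pass.

    theta: logits of shape (batch, categories); tau: temperature (tau >= 1 suggested).
    """
    pi0 = theta.softmax(dim=-1)

    # forward: one categorical sample per row, one-hot encoded (no gradient)
    idx = torch.multinomial(pi0.detach(), num_samples=1)
    d = torch.zeros_like(pi0).scatter_(-1, idx, 1.0)

    # midpoint between the sample and the tempered distribution softmax(theta / tau)
    pi1 = 0.5 * (d + (theta / tau).softmax(dim=-1))
    # keep the value of pi1 but route its gradient through a softmax evaluated at pi1
    pi1 = ((pi1.log() - theta).detach() + theta).softmax(dim=-1)

    # combine the two gradient evaluations (weights 2 and -1/2, assumed here)
    pi2 = 2.0 * pi1 - 0.5 * pi0

    # forward value is exactly d; backward uses the gradient of pi2
    return pi2 - pi2.detach() + d


# usage: stands in for a straight-through one-hot sample
logits = torch.randn(8, 10, requires_grad=True)
y = reinmax_sketch(logits, tau=1.5)
loss = (y * torch.randn(8, 10)).sum()   # arbitrary downstream loss for illustration
loss.backward()
```

As with Straight-Through, the forward value is the one-hot sample itself; relative to Straight-Through with a tempered softmax, the sketch adds only a log, one extra softmax, and a few element-wise operations.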

### **How to use?**

`reinmax` can be installed with `pip` from the root of this repository:

```bash
pip install --editable .
```

To replace Straight-Through Gumbel-Softmax with ReinMax: 

```diff
from reinmax import reinmax
...
- y_hard = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
+ y_hard, _ = reinmax(logits, tau) # note: reinmax works best with tau >= 1, while Gumbel-Softmax usually works best with tau < 1
...
```
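For reference, the removed line relies on PyTorch's built-in `torch.nn.functional.gumbel_softmax` with `hard=True`. The short check below uses an illustrative batch of three 4-way variables and `tau = 0.5`; it shows that the forward output is one-hot, while the gradient reaching `logits` comes from the relaxed Gumbel-Softmax sample.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4, requires_grad=True)  # 3 variables, 4 categories each (illustrative)
w = torch.randn(3, 4)                           # arbitrary downstream weights

# PyTorch's straight-through Gumbel-Softmax: one-hot forward, relaxed backward
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True)
(w * y_hard).sum().backward()

print(y_hard)       # each row is one-hot (up to floating-point rounding)
print(logits.grad)  # gradients reach the logits through the relaxed sample
```

With `hard=True`, the forward sample is an exact draw from `softmax(logits)` for any `tau`, since dividing by `tau > 0` does not change the argmax; `tau` only affects the backward pass.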

To replace Straight-Through with ReinMax:
```diff
from reinmax import reinmax
...
- y_hard = one_hot_multinomial(logits.softmax(dim=-1))
- y_soft_tau = (logits/tau).softmax(dim=-1)
- y_hard = y_soft_tau - y_soft_tau.detach() + y_hard
+ y_hard, y_soft = reinmax(logits, tau) 
...
```

### **Examples**

#### **Polynomial Programming**
Following previous studies (Tucker et al., 2017; Grathwohl et al., 2018; Pervez et al., 2020;
Paulus et al., 2021), we start with a simple, classic problem: polynomial programming.

The implementation for this problem is available in the `poly` folder.
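A toy instance in the spirit of the cited works shows why estimator accuracy matters. The setup below is hypothetical and is not the code in the `poly` folder: a single binary variable `b` (encoded as a 2-way one-hot), the objective `E[(b - t)^2]` with an assumed target `t = 0.45`, and logits fixed at `theta = [0, 1]`. Because `b` has only two outcomes, the exact gradient can be computed by enumeration and compared with the average straight-through estimate.

```python
import torch

torch.manual_seed(0)
t = 0.45                                  # assumed target value
values = torch.tensor([0.0, 1.0])         # b = <y, values> for a 2-way one-hot y
theta = torch.tensor([0.0, 1.0], requires_grad=True)

def f(y):
    # objective on a one-hot (or straight-through) vector y: (b - t)^2
    return ((y * values).sum(-1) - t) ** 2

# exact gradient: E[f] = sum_k pi_k * f(e_k), differentiated by enumeration
pi = theta.softmax(dim=-1)
expected_f = (pi * f(torch.eye(2))).sum()
grad_exact, = torch.autograd.grad(expected_f, theta)

# straight-through estimate, averaged over many samples
n = 100_000
pi = theta.softmax(dim=-1)
idx = torch.multinomial(pi.detach(), num_samples=n, replacement=True)
y_hard = torch.zeros(n, 2).scatter_(-1, idx.unsqueeze(-1), 1.0)
y_st = pi - pi.detach() + y_hard          # broadcasts pi over the n samples
grad_st, = torch.autograd.grad(f(y_st).mean(), theta)

print("exact:", grad_exact)
print("straight-through:", grad_st)
```

For this toy, with `p = sigmoid(1)`, the exact derivative for `theta[1]` is `p(1 - p)(1 - 2t)`, about 0.020, whereas the straight-through estimate averages `p(1 - p) * 2(p - t)`, about 0.11; the two coincide only at `p = 0.5`.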

#### **MNIST-VAE**
We also benchmark performance by training variational auto-encoders (VAEs) with
categorical latent variables on MNIST.

The implementation for MNIST-VAE is available in the `mnist_vae` folder.

#### **ListOps**
For unsupervised parsing on ListOps, the implementation is available in the `listops` folder.
