# Optimization on Stiefel manifold (optimization under orthonormal constraints)

This is the code for the paper [Momentum Stiefel Optimizer, with Applications to Suitably-Orthogonal Attention, and Optimal Transport](https://arxiv.org/abs/2205.14173) (ICLR 2023).

## What is Stiefel optimization
A Stiefel manifold $\mathsf{St}(n,m)$ is the set of all $n \times m$ matrices $X$ satisfying $X^\top X=I$, i.e., matrices with orthonormal columns. If you want to find the minimum value of a function $f(X)$ subject to $X$ having orthonormal columns, you are in the right place. Formally, the problem is
$$\min_{X \in \mathsf{St}(n,m)} f(X)=\min_{X \in \mathbb{R}^{n \times m}, s.t. X^\top X=I} f(X),\qquad n\ge m.$$
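
As a quick illustration of the constraint, the following minimal sketch builds a point on $\mathsf{St}(n,m)$ and checks $X^\top X=I$ numerically. Using a QR factorization to obtain orthonormal columns is an assumption made for this example only; it is not taken from the repository code.

```python
import torch

torch.manual_seed(0)
n, m = 5, 3  # the Stiefel manifold St(n, m) requires n >= m

# Reduced QR of a random n x m matrix gives a factor with orthonormal columns.
A = torch.randn(n, m, dtype=torch.float64)
X, _ = torch.linalg.qr(A)  # X has shape (n, m)

# Feasibility residual ||X^T X - I||_F: close to zero on the manifold.
residual = torch.linalg.norm(X.T @ X - torch.eye(m, dtype=torch.float64))

# For n > m the product X X^T is only a rank-m projector, not the identity.
outer_gap = torch.linalg.norm(X @ X.T - torch.eye(n, dtype=torch.float64))

print(tuple(X.shape), residual.item(), outer_gap.item())
```

Only the columns are orthonormal: $X X^\top$ equals the identity only in the square case $n=m$.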


Here is an example of the usage of Stiefel optimization.

### Example: Subspace pursuit
When function evaluation in high-dimensional spaces is too costly, why not try it in a lower-dimensional subspace? 

Suppose we have a dataset $\lbrace x_i \rbrace_{i=1}^k$ with $x_i \in \mathbb{R}^n$. Instead of evaluating our function $f(\lbrace x_i\rbrace_{i=1}^k)$ directly, we consider the optimization problem $$\max_{U\in \mathsf{St}(n,m)} f(\lbrace U^\top x_i\rbrace_{i=1}^k).$$ We maximize over $U$ so that as much information as possible is preserved; the columns of $U$ form an orthonormal basis of the subspace. If we choose $m\ll n$, this can save substantial computation and also reduce noise. Please refer to Sec. 3.1 in our paper.
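
The idea can be sketched with a toy objective. The choice of $f$ below (total squared norm of the projected points) and the synthetic data are assumptions for illustration only; they are not the objective from Sec. 3.1 of the paper. For this particular $f$, the maximizer is known in closed form (the top-$m$ principal directions), which makes it easy to compare against a random point on $\mathsf{St}(n,m)$.

```python
import torch

torch.manual_seed(0)
k, n, m = 200, 10, 2

# Synthetic data (rows are x_i): most of the variance sits in two coordinates.
scales = torch.tensor([5.0, 3.0] + [0.1] * (n - 2), dtype=torch.float64)
data = torch.randn(k, n, dtype=torch.float64) * scales

def f_projected(U, data):
    """Illustrative objective: sum over i of ||U^T x_i||^2."""
    projected = data @ U  # row i holds U^T x_i, shape (k, m)
    return (projected ** 2).sum()

# A random point on St(n, m), obtained from a reduced QR factorization.
U_random, _ = torch.linalg.qr(torch.randn(n, m, dtype=torch.float64))

# Top-m right singular vectors of the data: the maximizer for this f.
_, _, Vh = torch.linalg.svd(data, full_matrices=False)
U_best = Vh[:m].T  # shape (n, m), orthonormal columns

f_random = f_projected(U_random, data).item()
f_best = f_projected(U_best, data).item()
total = (data ** 2).sum().item()
print(f_random / total, f_best / total)
```

The two printed ratios are the fraction of the total squared norm that survives the projection onto each $m$-dimensional subspace. For a general $f$ no closed form is available, which is where an optimizer on the Stiefel manifold comes in.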



## Improve your model in a few lines
![Demo](./demo.gif)

### Example: Transformer model
Our experiments show that simply applying an orthonormal constraint to a vanilla transformer lets it outperform later, fancier models (Table 1 in the paper). To apply this to your own model, all you need to do is change the optimizer for each of the $W_i^Q$'s and $W_i^K$'s to our corresponding Stiefel SGD/Adam and use the same hyperparameters. Please refer to Sec. 3.2 in our paper for details.
```python
import torch

# put the Euclidean and Stiefel parameters into 2 different lists
euclidean_parameters = []
stiefel_parameters = []
for name, param in net.named_parameters():
    if 'q.weight' in name or 'k.weight' in name:
        torch.nn.init.orthogonal_(param) # optional
        stiefel_parameters.append(param)
    else:
        euclidean_parameters.append(param)
# the Euclidean optimizer only receives the non-Stiefel parameters
optimizer_euclidean=torch.optim.Adam(euclidean_parameters)
optimizer_stiefel=StiefelAdam(stiefel_parameters)
optimizer=CombinedOptimizer(optimizer_euclidean, optimizer_stiefel)
```
By modifying just the above few lines, your model will be improved **WITHOUT tuning any hyperparameters**!

*Note: different implementations of the Attention layer may need different modifications. See the `Attention` class in [ViT.py](ViT.py) for details.*
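
To make the role of a combined optimizer concrete, here is a minimal sketch of a wrapper that forwards `zero_grad()` and `step()` to several optimizers. The class name `TwoOptimizers` is hypothetical and this is not the repository's `CombinedOptimizer`, whose implementation is not shown in this README. Two standard PyTorch optimizers stand in for the Euclidean/Stiefel pair so that the snippet runs on its own.

```python
import torch

class TwoOptimizers:
    """Hypothetical stand-in: forwards zero_grad() and step() to each wrapped optimizer."""

    def __init__(self, *optimizers):
        self.optimizers = list(optimizers)

    def zero_grad(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
group_a = list(net[0].parameters())  # plays the role of one parameter group
group_b = list(net[1].parameters())  # plays the role of the other group

# Each group gets its own optimizer; the wrapper drives both together.
optimizer = TwoOptimizers(torch.optim.Adam(group_a, lr=1e-2),
                          torch.optim.SGD(group_b, lr=1e-2))

before = [p.detach().clone() for p in net.parameters()]
loss = net(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Every parameter should have moved after a single combined step.
changed = [not torch.equal(b, p) for b, p in zip(before, net.parameters())]
print(changed)
```

With such a wrapper the training loop stays unchanged: it still calls `zero_grad()`, `backward()` and `step()` once per iteration.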

## Details of the optimizers
[StiefelOptimizers.py](StiefelOptimizers.py) is the implementation of our proposed Momentum (S)GD and Adam on $\mathsf{St}(n,m)$ (Algorithms 1 and 2 in our paper). They can also be used on $\mathsf{SO}(n)$. [utils_StiefelOptimizers.py](utils_StiefelOptimizers.py) contains some auxiliary code. Please put both files in your path when using them.

### Momentum (S)GD on Stiefel manifold
This corresponds to Algorithm 1 and can also be used for the special case of $\mathsf{SO}(n)$ in Algorithm 4. The signature is:
```python
class StiefelSGD(params, lr=required, momentum=0.9, dampening=0, expm_method='ForwardEuler', inner_prod='Canonical', max_inner_iter=100)
```
Parameters:
- **params**: parameters to be optimized
- **lr**: learning rate
- **momentum** (float, optional): momentum factor (default: 0.9)
- **dampening** (float, optional): dampening for momentum (default: 0)
- **expm_method** (str in `['MatrixExp', 'Cayley', 'ForwardEuler']`, optional): method used to compute the matrix exponential (default: `'ForwardEuler'`)
- **inner_prod** (float <1 or str in `['Canonical', 'Euclidean']`, optional): Canonical-type metric; please refer to Definition 1 in the paper for details (default: `'Canonical'`)
- **max_inner_iter** (int, optional): maximum number of iterations when computing the matrix root inversion (default: 100)
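
For intuition about the three `expm_method` names, the sketch below compares the standard constructions they are usually associated with, applied to a skew-symmetric matrix $A$ and a step size $h$: the exact exponential $e^{hA}$, the Cayley transform $(I-\tfrac{h}{2}A)^{-1}(I+\tfrac{h}{2}A)$, and the forward Euler truncation $I+hA$. That the option names map to exactly these formulas inside the optimizer is an assumption; the variable names and the value of `h` are illustrative.

```python
import torch

torch.manual_seed(0)
n = 4
h = 0.1  # illustrative step size

# Skew-symmetric A, so that exp(h*A) is exactly orthogonal.
B = torch.randn(n, n, dtype=torch.float64)
A = B - B.T
I = torch.eye(n, dtype=torch.float64)

exact = torch.linalg.matrix_exp(h * A)
# Cayley transform: solves (I - hA/2) Q = (I + hA/2); orthogonal for skew A.
cayley = torch.linalg.solve(I - 0.5 * h * A, I + 0.5 * h * A)
# Forward Euler: first-order truncation of the exponential series.
euler = I + h * A

def orth_error(Q):
    """Frobenius norm of Q^T Q - I."""
    return torch.linalg.norm(Q.T @ Q - I).item()

for name, Q in [("MatrixExp", exact), ("Cayley", cayley), ("ForwardEuler", euler)]:
    print(name, torch.linalg.norm(Q - exact).item(), orth_error(Q))
```

In this isolated comparison the Cayley transform stays orthogonal up to round-off, while forward Euler is the cheapest but drifts from orthogonality at order $h^2$. How the optimizer compensates for that inside its update is described in the paper, not here.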

### Adam on Stiefel manifold
This corresponds to Algorithm 2 and can also be used for the special case of $\mathsf{SO}(n)$ in Algorithm 5. The signature is:
```python
class StiefelAdam(params, lr=0.001, betas=(0.9,0.99), epsilon=1e-5, expm_method='ForwardEuler', inner_prod='Canonical', max_inner_iter=100)
```
Parameters:
- **params**: parameters to be optimized
- **lr** (float, optional): learning rate (default: 0.001)
- **betas** (Tuple[float, float], optional): coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.99))
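- **epsilon** (float, optional): small constant added for numerical stability (default: 1e-5)
- **expm_method**, **inner_prod** and **max_inner_iter** are the same as in StiefelSGD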

Note:
- The only package needed to use these optimizers is *PyTorch*.
- Both optimizers inherit from `torch.optim.Optimizer` and have almost the same usage.
- Further tuning **expm_method**, **inner_prod** and **max_inner_iter** makes no significant difference. The defaults are good enough to use.
- We recommend using the same hyperparameters for the Euclidean and Stiefel parameters when the model contains both. See Remark 1 in the paper for details.
- The matrices being optimized should have number of rows $\ge$ number of columns. Otherwise, the matrix will be transposed without warning. For tensors with more than 2 dimensions, all dimensions except the first will be flattened in order to make it a matrix.
- No special orthonormal initialization for Stiefel matrices is required. Commonly used element-wise random Gaussian matrices will work, and our optimizer will automatically project them onto the Stiefel manifold. However, explicit initialization using `torch.nn.init.orthogonal_` is still recommended.
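
The shape convention and the projection mentioned in the notes above can be sketched as follows. The helper names `as_tall_matrix` and `project_to_stiefel` are hypothetical, and the use of the polar factor (computed via SVD) as the projection is an assumption; the README does not state which projection the optimizers apply internally.

```python
import torch

def as_tall_matrix(t):
    """Flatten all but the first dimension, then transpose if rows < columns."""
    mat = t.reshape(t.shape[0], -1)
    return mat if mat.shape[0] >= mat.shape[1] else mat.T

def project_to_stiefel(mat):
    """Closest matrix with orthonormal columns in Frobenius norm (polar factor)."""
    U, _, Vh = torch.linalg.svd(mat, full_matrices=False)
    return U @ Vh

torch.manual_seed(0)
conv_weight = torch.randn(64, 3, 3, 3, dtype=torch.float64)  # flattens to 64 x 27
wide_weight = torch.randn(8, 32, dtype=torch.float64)        # transposed to 32 x 8

results = []
for w in (conv_weight, wide_weight):
    X = project_to_stiefel(as_tall_matrix(w))
    err = torch.linalg.norm(X.T @ X - torch.eye(X.shape[1], dtype=torch.float64))
    results.append((tuple(X.shape), err.item()))
    print(tuple(w.shape), "->", tuple(X.shape), err.item())
```

Because the transpose happens silently, it is worth checking which axis of a wide weight matrix ends up orthonormal before relying on the constraint.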

## Reproduce the experiments in the paper
First install packages using the following code: 
```
pip install -r requirements.txt
```
### Projection Robust Wasserstein Distance
Please check the folder ProjectionRobustWasserstein. Run [test_mnist.py](test_mnist.py) and [test_shakespeare.py](test_shakespeare.py) to reproduce the results and use [plot.ipynb](plot.ipynb) to visualize. 
(Modified from [official implementation of Projection Robust Wasserstein Distance](https://github.com/fanchenyou/PRW))
### Vision Transformer (ViT)
Please check the folder ViT. Run [ViT_main.py](ViT_main.py) and use arguments `--label-smoothing` and `--autoaugment` for every optimizer, constraint and dataset. For example: 
```
python ViT_main.py --optim-method StiefelSGD --dataset c10 --constraint OnlyWithin --label-smoothing --autoaugment
```

- `optim-method` should be chosen from `['SGD','Adam','RegularizerStiefelSGD', 'RegularizerStiefelAdam', 'ProjectedStiefelSGD', 'ProjectedStiefelAdam', 'StiefelSGD', 'StiefelAdam', 'MomentumlessStiefelSGD']`

- `constraint` should be chosen from `['Across', 'OnlyWithin', None]`

- `dataset` should be chosen from `['c10', 'c100']`

(Modified from the following repositories: [Training process](https://github.com/omihub777/ViT-CIFAR); [model implementation](https://github.com/lucidrains/vit-pytorch))
### Leading eigenvalue problem
Please run [LEV/LEV.ipynb](LEV/LEV.ipynb).


## Citation
Feel free to cite if you want to use these optimizers in your research!

    @inproceedings{kong2022momentum,
        title={Momentum Stiefel Optimizer, with Applications to Suitably-Orthogonal Attention, and Optimal Transport},
        author={Kong, Lingkai and Wang, Yuqing and Tao, Molei},
        booktitle={International Conference on Learning Representations},
        year={2023}
    }
