[![Paper](https://img.shields.io/badge/Paper-arXiv-red)](https://arxiv.org/abs/2505.19619)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)

# SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows

## Quick start

Install the package with pip:
```bash
pip install sesamo
```

Here is a quick example of how to use SESaMo to build a normalizing flow with stochastic modulation:
```python
import torch
from sesamo import Sesamo
from sesamo.models import GaussianPrior, RealNVP, Z2Modulation, Z2Regularization
from sesamo.loss import StochmodLoss

# Initialize SESaMo
sesamo = Sesamo(
    prior=GaussianPrior(
        var=1,
        lat_shape=[1,2]
    ),
    flow=RealNVP(
        lat_shape=[1,2],
        num_coupling_layers=10,
        num_hidden_layers=2,
        num_hidden_features=40
    ),
    stochastic_modulation=Z2Modulation(),
    regularization=Z2Regularization(),
).to("cuda")

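# Define the action S(x) of the target distribution p(x) = exp(-S(x)) / Z.
# Example only (assumption): a Z2-symmetric double-well action. It assumes samples
# have shape (batch, *lat_shape) and returns one action value per sample.
def action(x):
    return ((x**2 - 1) ** 2).flatten(start_dim=1).sum(dim=1)
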
loss_fn = StochmodLoss()
optimizer = torch.optim.Adam(sesamo.parameters(), lr=5e-4)

# Training loop
for _ in range(10_000):
    # reset gradients
    optimizer.zero_grad()

    # sample from sesamo
    samples, log_prob, log_prob_stochmod, penalty = sesamo.sample_for_training(8_000)
    
    # compute action and loss
    action_samples = action(samples)
    loss = loss_fn(action_samples, log_prob, log_prob_stochmod, penalty).mean()
    
    # backpropagate and update flow parameters
    loss.backward()
    optimizer.step()
```
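To monitor training, a common diagnostic for flows trained against a known action is the normalized effective sample size (ESS) of the importance weights w(x) = exp(-S(x)) / q(x). The helper below is a minimal standard-library sketch and is not part of the `sesamo` package. It assumes `log_prob` holds log q(x) for the returned `samples`. With stochastic modulation enabled, whether `log_prob` or `log_prob_stochmod` is the appropriate density here is an assumption to check against the library.

```python
import math
from typing import Sequence


def _logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def effective_sample_size(action_values: Sequence[float], log_q: Sequence[float]) -> float:
    """Normalized ESS in (0, 1] from log-weights log w = -S(x) - log q(x).

    Assumes action_values[i] = S(x_i) and log_q[i] = log q(x_i) for samples
    x_i drawn from the model q. The unknown constant log Z cancels in the ratio.
    """
    if len(action_values) != len(log_q) or len(action_values) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    log_w = [-s - lq for s, lq in zip(action_values, log_q)]
    n = len(log_w)
    # ESS / N = (sum w)^2 / (N * sum w^2), evaluated in log space for stability
    log_ess = 2 * _logsumexp(log_w) - _logsumexp([2 * lw for lw in log_w])
    return math.exp(log_ess) / n
```

Inside the training loop this could be called as `effective_sample_size(action_samples.detach().tolist(), log_prob.detach().tolist())`; values close to 1 indicate that the model closely matches the target.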

### Examples

For more examples, see the `SESaMo/examples` folder, which contains Jupyter notebooks for the Hubbard model and the Gaussian mixture model.


## Run experiments

To run the experiments from the paper, follow the instructions below.

Clone the repository and move into the directory:
```bash
git clone https://github.com/fifi-research/SESaMo.git
cd SESaMo
```

Create a Python virtual environment and install the package in editable mode:
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```

Run an experiment with:
```bash
cd experiments
python train.py -cp configs/<experiment> -cn <model>
```

Available `<experiment>` values are:
```
hubbard2x1
hubbard18x100
gaussian-mixture
broken-gaussian-mixture
complex-phi4
broken-complex-phi4
broken-scalar-phi4
```

Available `<model>` values are:
```
realnvp
vmonf
canonicalization
sesamo
```
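The experiment and model names above can be combined into a simple sweep. The helper below is hypothetical (it does not ship with the repository): it builds one `train.py` command per experiment/model pair and, by default, only prints them. It assumes it is launched from the repository root so that `cwd="experiments"` resolves. The `-cp`/`-cn` flags resemble Hydra's `--config-path`/`--config-name` options, but that is an assumption based on their names.

```python
import itertools
import subprocess

# Experiment and model names as listed in this README.
EXPERIMENTS = [
    "hubbard2x1",
    "hubbard18x100",
    "gaussian-mixture",
    "broken-gaussian-mixture",
    "complex-phi4",
    "broken-complex-phi4",
    "broken-scalar-phi4",
]
MODELS = ["realnvp", "vmonf", "canonicalization", "sesamo"]


def build_commands(experiments=EXPERIMENTS, models=MODELS):
    """Return one `python train.py -cp configs/<experiment> -cn <model>` argv per pair."""
    return [
        ["python", "train.py", "-cp", f"configs/{exp}", "-cn", model]
        for exp, model in itertools.product(experiments, models)
    ]


def run_all(dry_run=True, cwd="experiments"):
    """Print each command; if dry_run is False, also run it sequentially from `cwd`."""
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run
            subprocess.run(cmd, cwd=cwd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Running all 28 combinations sequentially can take a long time, so the dry-run default lets you inspect the commands before launching them.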

Checkpoints, TensorBoard logs, configs, and stats files are stored in the `SESaMo/scripts/runs` folder.
After training completes or is interrupted, the distribution is plotted and saved as `SESaMo/scripts/runs/.../samples.png`.
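To find the most recent plot without knowing the exact layout of the run folders, a small hypothetical helper can search the runs folder recursively. It assumes it is run from the repository root, so the default `scripts/runs` points at the `SESaMo/scripts/runs` folder described above, and it picks the newest `samples.png` by modification time.

```python
from pathlib import Path
from typing import Optional


def latest_samples_plot(runs_dir: str = "scripts/runs") -> Optional[Path]:
    """Return the most recently modified samples.png below runs_dir, or None."""
    root = Path(runs_dir)
    if not root.is_dir():
        return None
    # Search all nested run folders without assuming their naming scheme
    plots = list(root.rglob("samples.png"))
    if not plots:
        return None
    return max(plots, key=lambda p: p.stat().st_mtime)
```

For example, `print(latest_samples_plot())` prints the path of the newest plot, or `None` if no run has produced one yet.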



## Citation
If you use SESaMo in your research, please consider citing our paper:
```bibtex
@article{kreit2025sesamo,
    title={SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows}, 
    author={Janik Kreit and Dominic Schuh and Kim A. Nicoli and Lena Funcke},
    year={2025},
    eprint={2505.19619},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2505.19619}, 
}
```
