# Learning to Permute with Discrete Diffusion

We provide code to reproduce our results on sorting 4-digit MNIST numbers (`"sort-MNIST"`) with $n \leq 100$ and on the jigsaw puzzle task on CIFAR-10 (`"unscramble-CIFAR10"`). The complete code will be provided upon request or upon publication.

## How to Run

We use Python 3.10, PyTorch 2.1, and CUDA 11.8.

First install the required packages:
```bash
pip install -r requirements.txt
```

We use JSON files to set configurations and hyperparameters. The config files that reproduce our results are in the `./reproduction` folder.

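Below is an annotated example config with a description of each key. The `//` comments are for explanation only; standard JSON does not allow comments, so remove them before using the example as a config file:
```jsonc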
{
    "CNN": { // CNN hyperparameters. For Jigsaw Puzzle, we only use the keys ending with 1.
        "hidden_channels1": 32,
        "hidden_channels2": 64,
        "in_channels": 3,
        "kernel_size1": 3,
        "kernel_size2": 5,
        "padding1": 1,
        "padding2": 1,
        "stride1": 1,
        "stride2": 1
    },
    "beam_search": true, // true: beam search; false: greedy search
    "beam_size": {
        "PL": 200, // beam size for the generalized Plackett-Luce (GPL) reverse distribution
        "time": 20 // beam size along the diffusion timesteps
    },
    "dataset": "sort-MNIST", // "sort-MNIST" | "unscramble-CIFAR10"
    "eval_only": false, // true if doing evaluation only
    "eval_batch_size": 64,
    "image_size": 28,
    "num_digits": 4, // 4-digit MNIST numbers
    "num_pieces": 52, // n: number of items to permute (MNIST numbers or puzzle pieces)
    "save_wrong_images": false,
    "train": {
        "resume": false,
        "batch_size": 64,
        "diffusion": {
            "latent": false,
            "num_timesteps": 13,
            "reverse": "generalized_PL",
            "reverse_steps": [0, 5, 6, 7, 10, 13],
            "transition": "riffle"
        },
        "entropy_reg_rate": 0.05, // used for REINFORCE only
        "epochs": 120,
        "learning_rate": 1e-5,
        "loss": "log_likelihood",
        "record_wandb": false, // whether to log to Weights & Biases (wandb); default false
        "reinforce_N": 10, // number of Monte Carlo samples in REINFORCE
        "reinforce_ema_rate": 0.995,
        "run_name": "sort-MNIST_n=52_[0,5,6,7,10,13]_42",
        "sample_N": 3, // number of Monte Carlo trajectories when computing the loss
        "scheduler": "transformer", // "transformer" | "cosine-decay"; "transformer" is the schedule from the "Attention Is All You Need" paper
        "warmup_steps": 51600 // warmup steps for "transformer"
    },
    "transformer": { // Transformer hyperparameters
        "d_hid": 512, // hidden dimension of the feed-forward layers
        "dropout": 0.1,
        "embd_dim": 128, // d_model
        "n_layers": 7,
        "nhead": 8
    },
    "seed": 42
}
```
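Because `json.loads` rejects the `//` comments in the annotated example, a small Python sketch like the following can strip them before parsing and then sanity-check the diffusion schedule. The check encodes our reading of the keys, not a documented rule: we assume `reverse_steps` is a strictly increasing list of timesteps from 0 to `num_timesteps`, and that the reverse process jumps between consecutive entries (13 to 10, 10 to 7, and so on in the example). `strip_line_comments` and `reverse_transitions` are hypothetical helpers, not functions from this repository.

```python
import json


def strip_line_comments(text: str) -> str:
    """Drop `//` comments that sit outside JSON string literals."""
    out = []
    in_string = escaped = False
    i = 0
    while i < len(text):
        ch = text[i]
        if in_string:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            i += 1
        elif ch == '"':
            in_string = True
            out.append(ch)
            i += 1
        elif text.startswith("//", i):
            # Skip to the end of the line but keep the newline itself.
            newline = text.find("\n", i)
            i = len(text) if newline == -1 else newline
        else:
            out.append(ch)
            i += 1
    return "".join(out)


def reverse_transitions(config: dict) -> list[tuple[int, int]]:
    """(from_t, to_t) jumps of the reverse process (assumed reading of `reverse_steps`)."""
    diffusion = config["train"]["diffusion"]
    steps, T = diffusion["reverse_steps"], diffusion["num_timesteps"]
    if steps[0] != 0 or steps[-1] != T:
        raise ValueError(f"reverse_steps should start at 0 and end at num_timesteps={T}: {steps}")
    if any(a >= b for a, b in zip(steps, steps[1:])):
        raise ValueError(f"reverse_steps should be strictly increasing: {steps}")
    # Denoise from the last timestep back to 0, one listed interval at a time.
    return list(zip(steps[1:], steps[:-1]))[::-1]


example = """{
    "num_pieces": 52, // n
    "train": {"diffusion": {"num_timesteps": 13,
                            "reverse_steps": [0, 5, 6, 7, 10, 13]}}
}"""
config = json.loads(strip_line_comments(example))
print(reverse_transitions(config))  # [(13, 10), (10, 7), (7, 6), (6, 5), (5, 0)]
```

The scanner tracks string literals so that a `//` inside a value (for example in a URL) is preserved rather than treated as a comment.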
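We read `"transition": "riffle"` as selecting riffle shuffles for the forward (noising) process, applied `num_timesteps` times. As a rough illustration only, the sketch below assumes the standard Gilbert-Shannon-Reeds (GSR) model: cut the deck at a Binomial(n, 1/2) position, then interleave the two packets, dropping each next card from a packet with probability proportional to its remaining size. The repository's own (likely batched, GPU) implementation may differ, and `gsr_riffle`, `forward_trajectory`, and `rising_sequences` are hypothetical names.

```python
import random


def gsr_riffle(deck: list, rng: random.Random) -> list:
    """One Gilbert-Shannon-Reeds riffle shuffle of `deck` (returns a new list)."""
    n = len(deck)
    cut = sum(rng.random() < 0.5 for _ in range(n))  # cut size ~ Binomial(n, 1/2)
    left, right = deck[:cut], deck[cut:]
    out, i, j = [], 0, 0
    while i < len(left) or j < len(right):
        a, b = len(left) - i, len(right) - j
        # Take the next card from a packet with probability proportional to its size.
        if rng.random() < a / (a + b):
            out.append(left[i])
            i += 1
        else:
            out.append(right[j])
            j += 1
    return out


def forward_trajectory(deck: list, num_timesteps: int, seed: int = 42) -> list[list]:
    """States x_0, ..., x_T, where x_t is x_{t-1} after one riffle shuffle."""
    rng = random.Random(seed)
    states = [list(deck)]
    for _ in range(num_timesteps):
        states.append(gsr_riffle(states[-1], rng))
    return states


def rising_sequences(perm: list[int]) -> int:
    """Number of rising sequences in a permutation of 0..n-1."""
    pos = {v: k for k, v in enumerate(perm)}
    return 1 + sum(pos[v + 1] < pos[v] for v in range(len(perm) - 1))


# n = 52 and T = 13, matching "num_pieces" and "num_timesteps" in the example config.
states = forward_trajectory(list(range(52)), num_timesteps=13)
print(rising_sequences(states[1]), rising_sequences(states[-1]))
```

Each packet keeps its internal order during a riffle, so one shuffle of a sorted deck leaves at most two rising sequences; repeated shuffles gradually randomize the order.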
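For `"scheduler": "transformer"`, the schedule from "Attention Is All You Need" warms the learning rate up linearly and then decays it as the inverse square root of the step. A minimal Python sketch, taking d_model from `embd_dim` (128) and the warmup from `warmup_steps` (51600) in the example config. How this codebase combines the schedule with `learning_rate` is not stated in this README, so the sketch keeps a free `scale` factor; `transformer_lr` is a hypothetical name.

```python
import math


def transformer_lr(step: int, d_model: int = 128, warmup_steps: int = 51600,
                   scale: float = 1.0) -> float:
    """lr = scale * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5).

    Grows linearly for `warmup_steps` steps, then decays like 1 / sqrt(step).
    """
    step = max(step, 1)  # step^-0.5 is undefined at step 0
    return scale / math.sqrt(d_model) * min(1.0 / math.sqrt(step),
                                            step / warmup_steps ** 1.5)


for s in (1, 10_000, 51_600, 200_000):
    print(s, f"{transformer_lr(s):.2e}")
```

The peak is reached at `step == warmup_steps`, where both branches of the `min` coincide and the rate equals 1 / sqrt(d_model * warmup_steps), about 3.9e-4 for these values. In PyTorch, one option is `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=transformer_lr)`, which multiplies the optimizer's base learning rate by the returned factor at each scheduler step.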

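To run the code, use the following command, replacing `$PATH_TO_CONFIG_FILE` with a config file (e.g., one from `./reproduction`) and `$PATH_TO_CKPT_DIR` with the directory for checkpoints: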
```bash
python3 main.py --config_json $PATH_TO_CONFIG_FILE --ckpt_dir $PATH_TO_CKPT_DIR
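To reproduce several settings in one go, a small Python driver could loop over the configs and call `main.py` once per file. This is a sketch that relies only on the two flags shown above; the one-checkpoint-folder-per-config layout under `./checkpoints`, the `--run` switch, and the helper names `command_for`, `build_commands`, and `run_all` are our own assumptions, not part of this repository. By default it only prints the commands (a dry run).

```python
import subprocess
import sys
from pathlib import Path


def command_for(cfg: Path, config_root: Path, ckpt_root: Path) -> list[str]:
    """Build one `main.py` invocation for a config file."""
    # Hypothetical layout: mirror the config's relative path under ckpt_root,
    # e.g. reproduction/a/b.json -> checkpoints/a/b, so runs never share a folder.
    ckpt_dir = ckpt_root / cfg.relative_to(config_root).with_suffix("")
    return [sys.executable, "main.py",
            "--config_json", str(cfg),
            "--ckpt_dir", str(ckpt_dir)]


def build_commands(config_dir: str = "./reproduction",
                   ckpt_dir: str = "./checkpoints") -> list[list[str]]:
    """One command per *.json file found (recursively) under config_dir."""
    root = Path(config_dir)
    return [command_for(cfg, root, Path(ckpt_dir)) for cfg in sorted(root.rglob("*.json"))]


def run_all(dry_run: bool = True) -> None:
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            Path(cmd[-1]).mkdir(parents=True, exist_ok=True)
            subprocess.run(cmd, check=True)  # stop at the first failing run


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv)
```

Saved as, say, `run_all.py` next to `main.py`, running it without arguments lists the commands, and passing `--run` executes them one after another.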
```
