# Domain Transfer becomes Identifiable with a Single Alignment


## Requirements

```bash
pip install torch torchvision numpy matplotlib pandas tqdm pyyaml wandb pillow
```

## Quick Start: MNIST Demo

### Step 1: Prepare the Dataset

The MNIST demo translates between original MNIST digits and MNIST digits rotated by 90 degrees.

```bash
cd image_translation

# Prepare the rotated MNIST dataset
python utils/prepare_mnist_dataset.py --output_dir ./data/rotatedmnist
```

This creates the following structure:
```
data/rotatedmnist/
├── trainA/          # Original MNIST training images
├── trainB/          # Rotated MNIST training images
├── testA/           # Original MNIST test images
├── testB/           # Rotated MNIST test images
├── trainA_attr.csv  # Labels for trainA
├── trainB_attr.csv  # Labels for trainB
├── testA_attr.csv   # Labels for testA
└── testB_attr.csv   # Labels for testB
```
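To make the layout above concrete, the sketch below shows one way a rotated domain and its attribute file could be produced. It is not taken from `utils/prepare_mnist_dataset.py`: the helpers `rotate_digit`, `write_attr_csv`, and `save_domain`, the `.npy` file format, the counter-clockwise rotation, and the CSV columns (`filename`, `label`) are all assumptions.

```python
import csv
from pathlib import Path

import numpy as np


def rotate_digit(image: np.ndarray, quarter_turns: int = 1) -> np.ndarray:
    """Rotate a 2-D image by 90 degrees per quarter turn.

    np.rot90 rotates counter-clockwise; the direction used by the real script is an assumption.
    """
    return np.rot90(image, k=quarter_turns).copy()


def write_attr_csv(path: Path, filenames, labels) -> None:
    """Write a hypothetical attribute file with one (filename, label) row per image."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "label"])  # assumed header
        for name, label in zip(filenames, labels):
            writer.writerow([name, int(label)])


def save_domain(images, labels, out_dir: Path, folder_name: str, rotate: bool = False):
    """Save one domain folder (e.g. trainB/) plus its <folder_name>_attr.csv."""
    folder = out_dir / folder_name
    folder.mkdir(parents=True, exist_ok=True)
    names = []
    for i, img in enumerate(images):
        name = f"{i:05d}.npy"  # the real script likely writes image files instead
        np.save(folder / name, rotate_digit(img) if rotate else img)
        names.append(name)
    write_attr_csv(out_dir / f"{folder_name}_attr.csv", names, labels)
    return names
```

For unpaired translation, domains A and B would normally be built from different subsets of MNIST; how the actual script splits the data is not specified here.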

### Step 2: Configure the Training

Edit `configs/mnist.yaml` and set `data_path` to the dataset directory created in Step 1:

```yaml
data_path: ./data/rotatedmnist  # Path to your prepared dataset
```

Key configuration options:
- `network_type: fcn` - Uses a fully connected network (suited to small images such as MNIST)
- `one_sided: True` - Trains only the translation from domain A to domain B
- `jacobian_loss_w: 0.01` - Weight for Jacobian regularization (set to 0 to disable)
- `paired_loss_w: 1.0` - Weight for the anchor (paired) sample loss
- `use_wandb: True` - Enables Weights & Biases logging (set to False to disable)
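As a rough illustration of how these options fit together, the sketch below parses a config fragment with PyYAML and derives a few flags. Only the keys come from this README; the helper `summarize_config` is hypothetical, and `train.py` may read the file differently.

```python
import yaml

EXAMPLE_CONFIG = """
data_path: ./data/rotatedmnist
network_type: fcn
one_sided: True
jacobian_loss_w: 0.01
paired_loss_w: 1.0
use_wandb: True
"""


def summarize_config(text: str) -> dict:
    """Parse a YAML config string and derive a few convenience flags."""
    cfg = yaml.safe_load(text)
    return {
        "data_path": cfg["data_path"],
        "network_type": cfg["network_type"],
        # A weight of 0 disables Jacobian regularization, per the option description.
        "jacobian_enabled": float(cfg.get("jacobian_loss_w", 0)) > 0,
        "one_sided": bool(cfg.get("one_sided", False)),
        "use_wandb": bool(cfg.get("use_wandb", False)),
    }


print(summarize_config(EXAMPLE_CONFIG))
```

PyYAML reads `True` as a boolean, so no string comparison is needed for the flag options.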

### Step 3: Train the Model

```bash
python train.py --config configs/mnist.yaml
```

Training options:
- `--seed 42` - Set random seed (default: 42)
- `--debug` - Debug mode (disables W&B logging and uses minimal data)
- `--resume` - Resume training from a saved checkpoint
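Seeding typically means fixing every random number generator a run touches. The sketch below is an assumption about what `--seed` involves, not the project's code; `set_seed` is a hypothetical helper that seeds Python, NumPy, and PyTorch (if installed).

```python
import random

import numpy as np


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and (if installed) PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available; nothing more to seed
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA


set_seed(42)
```

Note that bit-exact reproducibility on GPU can additionally depend on nondeterministic kernels, which seeding alone does not control.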

Training outputs:
- `results/models/` - Model checkpoints
- `results/samples/` - Sample translations during training

### Step 4: Evaluate the Model

```bash
python train.py --config configs/mnist.yaml --eval --checkpoint_path results/models/mnist_demo/checkpoint-current.pt
```

This generates sample translations from the test set.
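The flags used in Steps 3 and 4 can be summarized as a command-line parser. The sketch below is a hypothetical `argparse` mirror of the documented flags, not the actual `train.py` parser; defaults other than `--seed 42` are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags documented in this README."""
    parser = argparse.ArgumentParser(description="Train or evaluate a domain-transfer model.")
    parser.add_argument("--config", required=True, help="Path to a YAML config, e.g. configs/mnist.yaml")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--debug", action="store_true", help="Disable W&B logging and use minimal data")
    parser.add_argument("--resume", action="store_true", help="Resume training from a saved checkpoint")
    parser.add_argument("--eval", action="store_true", help="Generate test-set translations instead of training")
    parser.add_argument("--checkpoint_path", default=None, help="Checkpoint to load when evaluating")
    return parser


args = build_parser().parse_args([
    "--config", "configs/mnist.yaml",
    "--eval",
    "--checkpoint_path", "results/models/mnist_demo/checkpoint-current.pt",
])
print(args.eval, args.seed, args.checkpoint_path)
```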

## Project Structure

```
image_translation/
├── configs/
│   ├── mnist.yaml                    # MNIST demo config (with Jacobian reg)
│   └── mnist_wo_jac_reg.yaml         # MNIST config without Jacobian reg
├── src/
│   ├── trainer.py                    # Main training logic
│   └── trainer_utils.py              # Model factory functions
├── utils/
│   ├── data_loader.py                # Dataset and data loading
│   ├── data.py                       # Image folder utilities
│   ├── losses.py                     # Loss functions (GAN, Jacobian, etc.)
│   └── prepare_mnist_dataset.py      # MNIST dataset preparation
├── model.py                          # Generator and discriminator architectures
├── train.py                          # Main training script
└── README.md
```

## Configuration Options

### Network Architecture

| Option | Values | Description |
|--------|--------|-------------|
| `network_type` | `fcn`, `cnn` | FCN for small images, CNN for larger images |
| `fcn_hidden_dim` | int | Hidden dimension for FCN (default: 1024) |
| `gen.use_adain` | bool | Use adaptive instance normalization (AdaIN) in the CNN generator |
| `gen.num_downsample` | int | Number of downsampling layers in the CNN generator |
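To make the `fcn` option concrete, here is a minimal sketch of a fully connected generator for 28×28 MNIST images. Only `fcn_hidden_dim` (default 1024) comes from the table above; the class name `FCNGenerator`, the layer count, the ReLU/Tanh activations, and the `[-1, 1]` image scaling are assumptions, and the actual architecture in `model.py` may differ.

```python
import torch
from torch import nn


class FCNGenerator(nn.Module):
    """Hypothetical fully connected A→B generator for 1×28×28 images."""

    def __init__(self, image_shape=(1, 28, 28), hidden_dim: int = 1024):
        super().__init__()
        self.image_shape = image_shape
        in_dim = image_shape[0] * image_shape[1] * image_shape[2]
        self.net = nn.Sequential(
            nn.Flatten(),                  # (B, C, H, W) -> (B, C*H*W)
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_dim),
            nn.Tanh(),                     # assumes images are scaled to [-1, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Map the flat output back to image shape.
        return self.net(x).view(x.shape[0], *self.image_shape)


gen = FCNGenerator(hidden_dim=1024)
fake_b = gen(torch.randn(8, 1, 28, 28))
print(fake_b.shape)  # torch.Size([8, 1, 28, 28])
```

Flattening ignores spatial structure, which is why an FCN is practical only for small images; the `cnn` option is the natural choice as resolution grows.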

### Training

| Option | Description |
|--------|-------------|
| `train_iters` | Total training iterations |
| `batch_size` | Batch size for unpaired data |
| `lr` | Learning rate |
| `one_sided` | If `True`, train only the A→B translation |

### Loss Weights

| Option | Description |
|--------|-------------|
| `gen_w` | Generator adversarial loss weight |
| `dis_w` | Discriminator loss weight |
| `recons_w` | Cycle reconstruction loss weight |
| `paired_loss_w` | Paired supervision loss weight |
| `jacobian_loss_w` | Jacobian regularization weight |
| `r1_reg_w` | R1 gradient penalty weight |
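The weights above scale individual loss terms. The sketch below shows one plausible way they could be combined; the mapping of weights to terms, the hypothetical helpers `combine_generator_losses` and `combine_discriminator_losses`, and the placement of the R1 penalty on the discriminator side are assumptions, not the trainer's actual code.

```python
def combine_generator_losses(losses: dict, cfg: dict):
    """Weighted sum of generator-side loss terms (hypothetical key mapping)."""
    weight_keys = {
        "adversarial": "gen_w",
        "cycle": "recons_w",
        "paired": "paired_loss_w",
        "jacobian": "jacobian_loss_w",
    }
    total = 0.0
    for name, value in losses.items():
        weight = cfg.get(weight_keys[name], 0.0)
        if weight:  # a zero weight switches the term off entirely
            total = total + weight * value
    return total


def combine_discriminator_losses(losses: dict, cfg: dict):
    """Discriminator objective: weighted adversarial term plus weighted R1 penalty."""
    return cfg.get("dis_w", 1.0) * losses["adversarial"] + cfg.get("r1_reg_w", 0.0) * losses["r1"]
```

The values can be Python floats or scalar tensors; with tensors, the result stays differentiable.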

### Jacobian Regularization

| Option | Description |
|--------|-------------|
| `jacobian_loss_w` | Jacobian regularization weight (0 to disable) |
| `jacobian_norm_type` | `fd_l1` (sparse finite-difference estimate) or `l1` |
| `jacobian_num_samples` | Number of samples for the stochastic estimate |
| `jacobian_probe_sparsity` | Sparsity of the perturbation mask (for `fd_l1`) |
| `jacobian_sigma` | Standard deviation of the perturbation magnitude (for `fd_l1`) |
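The `fd_l1` options suggest a sparse, stochastic finite-difference estimate. The sketch below is one plausible reading under stated assumptions, not the implementation in `utils/losses.py`: `fd_l1_jacobian_penalty` is hypothetical, `probe_sparsity` is interpreted as the probability that an input entry is perturbed, and the output change is divided by `sigma` to approximate a Jacobian-vector product.

```python
import torch


def fd_l1_jacobian_penalty(f, x, num_samples=4, probe_sparsity=0.1, sigma=0.01):
    """Stochastic finite-difference estimate of an L1 Jacobian penalty (illustrative).

    For each probe, every input entry is perturbed with probability `probe_sparsity`
    (an assumed interpretation) by Gaussian noise with std `sigma`. The output change
    divided by `sigma` approximates J @ v for the sparse probe v = delta / sigma; its
    L1 norm is averaged over the batch and over probes.
    """
    fx = f(x)
    total = x.new_zeros(())
    for _ in range(num_samples):
        mask = (torch.rand_like(x) < probe_sparsity).to(x.dtype)
        delta = sigma * torch.randn_like(x) * mask
        diff = (f(x + delta) - fx) / sigma  # approximately J @ v
        total = total + diff.flatten(1).abs().sum(dim=1).mean()
    return total / num_samples


torch.manual_seed(0)
x = torch.randn(2, 1, 4, 4)
penalty = fd_l1_jacobian_penalty(lambda z: 2 * z, x)
```

The appeal of finite differences is cost: each probe needs one extra forward pass instead of materializing the full Jacobian, and gradients still flow through `f(x + delta)` during training.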

## Comparing With and Without Jacobian Regularization

To compare the effect of Jacobian regularization, train one model with each config:

```bash
# Train with Jacobian regularization
python train.py --config configs/mnist.yaml

# Train without Jacobian regularization
python train.py --config configs/mnist_wo_jac_reg.yaml
```
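Because the comparison is only meaningful if the regularizer is the main difference between runs, it helps to share the seed as well. The hypothetical helper below builds both commands with a common `--seed` and, by default, only prints them. It assumes you run it from `image_translation/` and that the two configs presumably differ only in `jacobian_loss_w`; also check whether the runs write to distinct output directories under `results/`.

```python
import shlex
import subprocess
import sys

CONFIGS = {
    "with_jacobian": "configs/mnist.yaml",
    "without_jacobian": "configs/mnist_wo_jac_reg.yaml",
}


def build_commands(seed: int = 42) -> dict:
    """Build one train.py command per config, sharing the same seed."""
    return {
        name: [sys.executable, "train.py", "--config", path, "--seed", str(seed)]
        for name, path in CONFIGS.items()
    }


def run_comparison(seed: int = 42, dry_run: bool = True) -> None:
    """Print (or, with dry_run=False, run) both trainings sequentially."""
    for name, cmd in build_commands(seed).items():
        print(f"[{name}] {shlex.join(cmd)}")
        if not dry_run:
            subprocess.run(cmd, check=True)  # run from inside image_translation/


run_comparison(seed=42)
```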
