# BOSW: Learning Sliced Wasserstein Projections via Adaptive Bayesian Optimization

This repo contains only our implementation of BOSW and its hybrid variants, which plug into the public Quasi-SW codebase.

> **Upstream dependency:** all experiments were run by integrating our methods into the official Quasi-SW repo:
> `https://github.com/khainb/Quasi-SW.git`
> We do not re-host their code here. See “Setup” below for clone instructions.

---

## Repo structure

```
.
├─ bosw.py
├─ qsw.py
└─ utils.py
```

## File summaries

-   `utils.py`
    Helper functions for BOSW and hybrids.
-   `qsw.py`
    QSW family implementations.
-   `bosw.py`
    Core of our method. Two main entry points:
    -   `get_bosw_projections(...)` — Plain BOSW (acquisition-driven projection selection). Used directly for BOSW and for the RBOSW refresh variant.
    -   `get_bosw_adaptive(...)` — Hybrid/Adaptive BO that returns a projection sampler object with `.sample()`. Powers ABOSW (adaptive) and ARBOSW (adaptive + periodic restart).

---

## How we tested

For each experiment from the Quasi-SW repo (Approximation Error, Gradient Flow, Color Transfer, Autoencoder, etc.), we kept their pipeline intact and only swapped the direction-sampling step to call our functions. Concretely:

1.  Clone Quasi-SW.
2.  Import our functions where directions are sampled.
3.  Replace their random/QMC sampler with:
    -   `theta = get_bosw_projections(...)` for BOSW/RBOSW, or
    -   a persistent `projection_sampler = get_bosw_adaptive(...)` and `theta = projection_sampler.sample()` for ABOSW/ARBOSW.

We did not change their losses, training loops, or evaluation code—just the projection selection.
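Step 3 only changes how the direction matrix is produced, because the Monte Carlo SW estimator takes `theta` as an input. The NumPy sketch below illustrates that point. It is not Quasi-SW's (PyTorch) code; it assumes equal-size, uniformly weighted point clouds and a `theta` of shape `(L, d)` with unit-norm rows, and `random_directions` is a hypothetical stand-in for the random baseline.

```python
import numpy as np


def random_directions(L, d, rng):
    """Baseline sampler: L i.i.d. directions, uniform on the unit sphere S^{d-1}."""
    theta = rng.standard_normal((L, d))
    return theta / np.linalg.norm(theta, axis=1, keepdims=True)


def sliced_wasserstein(X, Y, theta, p=2):
    """Monte Carlo SW_p between equal-size, uniformly weighted point clouds.

    X, Y: (n, d) arrays; theta: (L, d) array of unit directions.
    """
    x_proj = np.sort(X @ theta.T, axis=0)  # (n, L) sorted 1D projections
    y_proj = np.sort(Y @ theta.T, axis=0)
    # In 1D, optimal transport pairs sorted samples, so W_p^p is a mean over ranks.
    wp_per_direction = np.mean(np.abs(x_proj - y_proj) ** p, axis=0)  # (L,)
    return np.mean(wp_per_direction) ** (1.0 / p)


rng = np.random.default_rng(0)
X = rng.standard_normal((500, 3))
Y = rng.standard_normal((500, 3)) + 2.0
# Only this line changes when a random/QMC sampler is swapped for BOSW output.
theta = random_directions(L=100, d=3, rng=rng)
print(sliced_wasserstein(X, Y, theta))
```

Replacing the `random_directions(...)` call with the projections returned by `get_bosw_projections(...)` (in whatever array type the pipeline uses) leaves the estimator itself untouched.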

---

## Setup

Clone Quasi-SW and work from its root directory:

```bash
git clone https://github.com/khainb/Quasi-SW.git
cd Quasi-SW/
```

Then install Quasi-SW’s requirements (and your usual PyTorch/CUDA stack):

```bash
pip install -r requirements.txt
```

---

## API at a glance

### Plain BOSW (also used by RBOSW)

```python
from bosw import get_bosw_projections

# X, Y: the two point clouds being compared (passed as pc1, pc2); L: number of projections
theta = get_bosw_projections(
    L, device, pc1=X, pc2=Y, p=2,
    acq_kind='ei',      # or 'ucb', 'ts' depending on your ablations
    beta=0.7,           # exploration parameter (e.g., UCB/PI/EI variants)
    n_init=min(L//2, 64),
    batch_size=max(1, L//20),
    n_candidates=4096,
    seed=seed
)
```
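For reference, the textbook forms of two acquisition functions selectable via `acq_kind` are sketched below, written for maximization given a surrogate posterior mean `mu` and standard deviation `sigma`. This is an illustration, not BOSW's code: whether `bosw.py` uses exactly these parametrizations (for example, `beta * sigma` versus `sqrt(beta) * sigma` in UCB, or how `beta` enters EI) is an assumption to check against the source.

```python
import math


def normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def ucb(mu, sigma, beta):
    """Upper confidence bound (maximization): larger beta favors uncertain candidates."""
    return mu + beta * sigma


def expected_improvement(mu, sigma, best, xi=0.0):
    """Expected improvement over the incumbent value `best` (maximization).

    xi >= 0 is an optional margin that shifts the balance toward exploration.
    """
    improvement = mu - best - xi
    if sigma <= 0.0:
        return max(improvement, 0.0)
    z = improvement / sigma
    return improvement * normal_cdf(z) + sigma * normal_pdf(z)


# Two hypothetical candidates given as (posterior mean, posterior std).
confident, uncertain = (1.0, 0.05), (0.9, 0.5)
scores = [ucb(m, s, beta=0.7) for m, s in (confident, uncertain)]
print(scores)
```

With `beta=0.7` the uncertain candidate scores higher despite its lower mean, which is the sense in which `beta` acts as an exploration parameter.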

### Hybrid / Adaptive BO (ABOSW & ARBOSW)

```python
from bosw import get_bosw_adaptive

# Create the sampler once and keep it alive across iterations
projection_sampler = get_bosw_adaptive(
    L=L, device=device, pc1=X, pc2=Y, p=2,
    beta=0.7, seed=seed, ai="ucb",
    task_type='gradient',
    gradient_steps=25   # internally plans to learn projections over this horizon
)

# call per-iteration to get fresh projections
theta = projection_sampler.sample()
```
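The adaptive API returns an object that is created once and queried every iteration. The class below is a hypothetical stand-in with the same `.sample()` call pattern: it draws random unit directions and has no surrogate, so it is only useful for testing the integration plumbing or as a random baseline. The `(L, d)` return shape with unit-norm rows is our assumption about what the real sampler returns.

```python
import numpy as np


class RandomProjectionSampler:
    """Hypothetical stand-in mimicking the `.sample()` pattern of the object
    returned by get_bosw_adaptive (assumed to yield an (L, d) array of unit rows)."""

    def __init__(self, L, d, seed=0):
        self.L = L
        self.d = d
        self.rng = np.random.default_rng(seed)  # state persists across calls
        self.calls = 0

    def sample(self):
        self.calls += 1
        theta = self.rng.standard_normal((self.L, self.d))
        return theta / np.linalg.norm(theta, axis=1, keepdims=True)


sampler = RandomProjectionSampler(L=50, d=3, seed=0)  # create once, outside the loop
for i in range(5):
    theta = sampler.sample()  # fresh directions every iteration
    # ... compute the SW loss with theta and take an optimizer step ...
```

Creating the sampler outside the loop presumably matters more for the real object: its surrogate/acquisition state lives on the instance, so re-creating it every step would discard what it has learned.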

---

## Usage patterns in our experiments

Below are the exact integration snippets we used for RBOSW (refresh) and ARBOSW (adaptive with restarts). Insert these into the Quasi-SW training loop where directions are chosen.

### RBOSW (refresh every R steps)

```python
import torch
from bosw import get_bosw_projections

# Inside the training loop (step index i): refresh BOSW directions every R steps using the current X
if i % R == 0:
    with torch.no_grad():
        theta = get_bosw_projections(
            L, device, X.detach(), Y, p=2,
            acq_kind='ei', beta=0.7,
            n_init=min(L//2, 64),
            batch_size=max(1, L//20),
            n_candidates=4096,
            seed=seed*10_000 + i
        )
# ... use theta for the next step(s)
```

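**What this does:** reruns BOSW on the current iterate `X` every `R` steps and reuses the resulting `theta` in between. Because `X` moves during training, the objective is nonstationary; periodic refreshes let the projections track it, while holding `theta` fixed between refreshes keeps consecutive updates stable. Selection runs under `torch.no_grad()` on `X.detach()`, so it stays out of the autograd graph, and the `seed*10_000 + i` seed makes each refresh distinct but reproducible.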
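The refresh schedule can be seen end to end in a toy gradient flow. In this NumPy sketch, `random_directions` is a hypothetical stand-in for `get_bosw_projections` (no BO is involved), the loss is the squared sliced 2-Wasserstein distance between equal-size clouds with its gradient taken through the sorted 1D matching, and `R`, `step`, and the data are illustrative choices rather than values from our experiments.

```python
import numpy as np


def random_directions(L, d, rng):
    """Hypothetical stand-in for get_bosw_projections: uniform unit directions."""
    theta = rng.standard_normal((L, d))
    return theta / np.linalg.norm(theta, axis=1, keepdims=True)


def sw2_squared_and_grad(X, Y, theta):
    """Squared SW_2 between equal-size clouds X, Y (n, d) and a gradient w.r.t. X."""
    L = theta.shape[0]
    x_proj = X @ theta.T                     # (n, L) projections of the moving cloud
    y_sorted = np.sort(Y @ theta.T, axis=0)  # (n, L) sorted target projections
    order = np.argsort(x_proj, axis=0)
    matched = np.empty_like(x_proj)
    # 1D optimal matching: the x point of rank r is paired with the r-th smallest y.
    np.put_along_axis(matched, order, y_sorted, axis=0)
    residual = x_proj - matched
    value = np.mean(residual ** 2)
    # Equals (n / 2) * d(value)/dX, so the step size does not depend on n.
    grad = residual @ theta / L
    return value, grad


rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2)) + 5.0      # source cloud, moved by the flow
Y = rng.standard_normal((200, 2))            # fixed target cloud
L, R, step = 20, 10, 1.0                     # illustrative settings
theta_eval = random_directions(200, 2, rng)  # fixed directions, used only for monitoring
start, _ = sw2_squared_and_grad(X, Y, theta_eval)

for i in range(100):
    if i % R == 0:
        # Refresh every R steps (real RBOSW would pass the current X here).
        theta = random_directions(L, 2, rng)
    _, grad = sw2_squared_and_grad(X, Y, theta)
    X = X - step * grad                      # theta is reused between refreshes

end, _ = sw2_squared_and_grad(X, Y, theta_eval)
print(f"SW_2^2 before: {start:.3f}, after: {end:.3f}")
```

Monitoring on a separate, fixed `theta_eval` keeps the before/after numbers comparable even though the training directions change at every refresh.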

### ARBOSW (adaptive BO with periodic restarts)

```python
from bosw import get_bosw_adaptive

# Inside the training loop (step index i): restart projection learning every 25 steps
if i == 0 or i % 25 == 0:
    projection_sampler = get_bosw_adaptive(
        L=L, device=device, pc1=X.detach(), pc2=Y,
        p=2, beta=2.0, seed=seed + i, ai="ucb",  # Different seed for restart
        task_type='gradient',
        gradient_steps=25  # Learn for next 25 steps
    )

# Sample projections (will be different each time due to randomization)
theta = projection_sampler.sample()
# ... use theta this step
```

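**What this does:** maintains an adaptive sampler that is re-initialized on a schedule (every 25 steps here). Between restarts, `.sample()` produces projections guided by the current surrogate/acquisition state. `gradient_steps=25` matches the restart interval, so each sampler plans over exactly the window in which it is used, and the `seed + i` offset gives every restart its own seed.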
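The restart logic can also be factored into a small wrapper. The sketch below is hypothetical and not part of `bosw.py`: `make_sampler` stands in for a call such as `get_bosw_adaptive(..., seed=seed + i, ...)`, and `UnitDirectionSampler` is a random stand-in with no surrogate. The `self.inner is None` check ensures a sampler exists even if the loop resumes at a step that is not a multiple of `restart_every`.

```python
import numpy as np


class UnitDirectionSampler:
    """Hypothetical stand-in for the get_bosw_adaptive(...) object: random unit directions."""

    def __init__(self, L, d, seed):
        self.L, self.d = L, d
        self.rng = np.random.default_rng(seed)

    def sample(self):
        theta = self.rng.standard_normal((self.L, self.d))
        return theta / np.linalg.norm(theta, axis=1, keepdims=True)


class RestartingSampler:
    """Hypothetical wrapper for the ARBOSW schedule: rebuild the inner sampler
    every `restart_every` steps with a step-dependent seed."""

    def __init__(self, make_sampler, restart_every, seed):
        self.make_sampler = make_sampler  # callable(seed) -> object with .sample()
        self.restart_every = restart_every
        self.seed = seed
        self.inner = None
        self.restarts = 0

    def sample(self, i):
        if self.inner is None or i % self.restart_every == 0:
            self.inner = self.make_sampler(self.seed + i)  # distinct seed per restart
            self.restarts += 1
        return self.inner.sample()


sampler = RestartingSampler(
    lambda s: UnitDirectionSampler(L=10, d=3, seed=s), restart_every=25, seed=7
)
for i in range(100):
    theta = sampler.sample(i)  # restarts happen at i = 0, 25, 50, 75
print(sampler.restarts)
```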
