## Marginal Flow
Marginal flow is a **universal density approximator**, 
meaning that it can approximate any well-behaved distribution arbitrarily well.
The variational family $q_\theta(x)$ is defined as a mixture of $N_c$ distributions $q(x|w_{\theta,i})$ 
parameterized by $w_{\theta,i}$ (e.g. for Gaussian mixtures one choice could be $w_{\theta,i}=\mu(\theta_i)$, hence $q(x|w_{\theta,i})=\mathcal{N}(x|\mu(\theta_i),\sigma)$):

$$q_\theta(x) := \frac{1}{N_c} \sum_{i=1}^{N_c} q(x|w_{\theta,i}) \quad \text{where} \quad w_{\theta,i} := f_\theta(z_i) \quad \text{with} \quad z_i \sim p_{\text{base}}(z) $$

where $f_\theta$ is any neural network and $p_{\text{base}}$ any base distribution of choice.
Note that $\dim(z)$ doesn't have to be the same as $\dim(x)$.
Importantly, the parameters $w_{\theta,i}$ that define the distributions $q(x|w_{\theta,i})$ are re-sampled each time we evaluate or sample from $q_\theta(x)$, 
so every call effectively uses a different mixture.
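
One possible reading of the definition above (an interpretation, not a statement taken from the source): since the $z_i$ are i.i.d. draws from $p_{\text{base}}$, each freshly sampled mixture is an unbiased Monte Carlo estimate of a marginal density over $z$:

$$\mathbb{E}_{z_1,\dots,z_{N_c} \sim p_{\text{base}}}\left[ q_\theta(x) \right] = \int q(x|f_\theta(z)) \, p_{\text{base}}(z) \, dz$$

Under this reading, increasing $N_c$ reduces the variance of the density estimate without changing its expectation.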

Unlike most density estimation models, Marginal flow is efficient both at training and at inference.
In both cases we first need to sample the parameters $w_{\theta,i}$:
- $z_i \sim p_{\text{base}}(z) \rightarrow w_{\theta,i} = f_\theta(z_i)$

Then, we can either evaluate $q_\theta(x)$ or sample from it:
- **evaluate** $q_\theta(x)$: simply evaluate the pdf $q_\theta(x) = \frac{1}{N_c} \sum_{i=1}^{N_c} q(x|w_{\theta,i})$
- **sample** from $q_\theta(x)$: sample from the mixture model, 
i.e. draw an index $j$ uniformly from $\{1,...,N_c\}$ (independently for each sample, so with replacement) and then draw $x \sim q(x|w_{\theta,j})$
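
The two steps above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the `margflow` implementation: `f_theta` is a hypothetical affine stand-in for the neural network, the base distribution is assumed to be a standard normal with $\dim(z)=\dim(x)$, and the components are isotropic Gaussians with a fixed `sigma`.

```python
import numpy as np

rng = np.random.default_rng(0)

def f_theta(z):
    """Hypothetical stand-in for the neural network (a fixed affine map)."""
    return 2.0 * z + 1.0

def sample_parameters(n_mixtures, z_dim):
    """Draw z_i ~ p_base (standard normal assumed) and map to w_i = f_theta(z_i)."""
    z = rng.standard_normal((n_mixtures, z_dim))
    return f_theta(z)  # shape (n_mixtures, x_dim); here x_dim == z_dim

def log_prob(x, n_mixtures, sigma=0.1):
    """log q_theta(x) for a freshly sampled mixture of isotropic Gaussians."""
    d = x.shape[1]
    w = sample_parameters(n_mixtures, d)
    sq_dist = ((x[:, None, :] - w[None, :, :]) ** 2).sum(axis=-1)  # (n, N_c)
    log_comp = -0.5 * sq_dist / sigma**2 - d * np.log(sigma) - 0.5 * d * np.log(2 * np.pi)
    # log of the uniform-weight mixture, computed with a stable log-sum-exp
    m = log_comp.max(axis=1, keepdims=True)
    return m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1)) - np.log(n_mixtures)

def sample(n_samples, n_mixtures, x_dim, sigma=0.1):
    """Pick a component index uniformly (with replacement), then add Gaussian noise."""
    w = sample_parameters(n_mixtures, x_dim)
    j = rng.integers(0, n_mixtures, size=n_samples)
    return w[j] + sigma * rng.standard_normal((n_samples, x_dim))

x = sample(n_samples=5, n_mixtures=1024, x_dim=2)
logp = log_prob(x, n_mixtures=1024)
print(x.shape, logp.shape)
```

Note that `log_prob` and `sample` each draw their own set of component parameters, which mirrors the re-sampling behaviour described above.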

## Installation

We recommend installation with [uv](https://docs.astral.sh/uv/#installation).
Once installed, simply run:
```console
uv sync
```

## Implementation details
We now walk through a basic use case of margflow.
A more detailed example can be found in [this jupyter notebook](experiments/toy_example.ipynb).

We first define the dataset we want to train our model on:

```python
import argparse
from margflow.datasets.datasets import create_dataset

# dataset parameters
seed = 1234
device = "cuda" # or "cpu"
dataset_name = "mog" # "mog", "two_moons", "swiss_roll", "checkerboard", "two_circles", "pinwheel"
n_dim = 2 # for all datasets above
n_samples_dataset = 2_000
fixed_datapoints = False # if False, training has infinite samples (re-sample at each iteration)
epsilon = 0.05 # Gaussian noise added to the data

# if dataset_name == "mog" specify number of mixtures and standard deviation of each mixture
n_mog = 10
mog_sigma = 0.1

# collect all parameters as args
args = argparse.Namespace(
    dataset=dataset_name,
    n_samples_dataset=n_samples_dataset,
    n_samples_dataset_test=n_samples_dataset,
    epsilon=epsilon,
    script_path="",
    device=device,
    x_dim=n_dim,
    n_mog=n_mog,
    mog_sigma=mog_sigma,
    seed=seed
)

dataset = create_dataset(args=args)
```
We now define the proposed model, Marginal flow. 
In almost all our settings we define the variational family $q_\theta(x)$ as a mixture of isotropic Gaussians,
i.e. $q(x|w_{\theta,i})=q(x|\mu(\theta_i),\sigma)$, 
but other choices are also possible.
Furthermore, in order to easily learn multi-modal distributions, we define the base distribution as a mixture of Gaussians with `n_mixture_base` components.

```python
from margflow.marginal_flow import MarginalFlow
from margflow.trainer import train_marginal_flow

# problem parameters
x_dim = n_dim # space where the distribution is defined
n_mixture_base = 10 # by default the base distribution is a mixture of Gaussians with n_mixture_base components
base_dim = x_dim  # can be any int <= x_dim
marginal_flow = MarginalFlow(x_dim=x_dim,
                             z_dim=base_dim,
                             n_base_means=n_mixture_base,
                             device=device)
```

We can now train the model on the dataset by maximising the log-likelihood 
(other objective functions are also possible but not yet implemented):

```python
# training parameters
n_epochs = 1000
batch_size = 2000

# model parameters
n_samples = 1024 # how many samples to take from marginal flow (at each evaluation/sampling step)
n_mixtures = 1024 # how many mixture components to use in the marginal flow definition (at each evaluation/sampling step)
lr_network = 5e-4
lr_sigma = 5e-2

train_marginal_flow(
    model=marginal_flow,
    n_mixtures=n_mixtures,
    n_samples=n_samples,
    n_epochs=n_epochs,
    batch_size=batch_size,
    dataset=dataset,
    lr_network=lr_network,
    lr_sigma=lr_sigma,
    fixed_datapoints=fixed_datapoints,
    save_best_val=False)
```
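
To make the maximum-likelihood objective concrete, here is a minimal NumPy sketch on a 1D toy problem. It is not the repo's `train_marginal_flow`: the affine map `a * z + b` is a hypothetical stand-in for the neural network, `sigma` is kept fixed (whereas the `lr_sigma` argument above suggests the repo learns it), and the gradients are derived by hand instead of using automatic differentiation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1D data: the target density is N(2, 1).
data = rng.normal(loc=2.0, scale=1.0, size=1000)

# Hypothetical stand-in for the network: f_theta(z) = a * z + b.
a, b = 1.0, 0.0
sigma = 0.5        # component std, fixed in this sketch (an assumption)
n_mixtures = 256
lr = 0.1

def log_components(x, w, sigma):
    """log N(x_n | w_i, sigma^2) for all pairs, shape (n_data, n_mixtures)."""
    diff = x[:, None] - w[None, :]
    return -0.5 * (diff / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

for step in range(300):
    z = rng.standard_normal(n_mixtures)  # fresh base samples at every step
    w = a * z + b                        # component means w_i = f_theta(z_i)
    logc = log_components(data, w, sigma)

    # Responsibilities r[n, i]: softmax of the component log-densities over i.
    r = np.exp(logc - logc.max(axis=1, keepdims=True))
    r /= r.sum(axis=1, keepdims=True)

    # d log q(x_n) / d w_i = r[n, i] * (x_n - w_i) / sigma^2
    dlog_dw = r * (data[:, None] - w[None, :]) / sigma**2

    # Chain rule through w_i = a * z_i + b, averaged over the data.
    grad_a = (dlog_dw * z[None, :]).sum(axis=1).mean()
    grad_b = dlog_dw.sum(axis=1).mean()

    # Gradient ascent on the average log-likelihood.
    a += lr * grad_a
    b += lr * grad_b

print(f"a={a:.2f}, b={b:.2f}")
```

In this toy setting the marginal is approximately $\mathcal{N}(b, a^2+\sigma^2)$, so `b` should drift towards the data mean and `a**2 + sigma**2` towards the data variance.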
We can now sample from the trained model and evaluate the learnt log-pdf at those samples:
```python
n_samples = 10_000
n_mixtures = 2048 # can be any integer number, also different from the one used during training
mflow_samples = marginal_flow.sample(n_samples=n_samples, n_mixtures=n_mixtures)
logp_samples = marginal_flow.log_prob(x=mflow_samples, n_mixtures=n_mixtures)
```

## Other baseline models
In order to compare Marginal flows with other baseline density estimation models,
in `margflow/other_models/` we provide implementations of the following models:
- [flow matching](https://arxiv.org/abs/2210.02747), implemented as in the [official repo](https://github.com/facebookresearch/flow_matching)
- [normalizing flows](https://arxiv.org/abs/1908.09257), implemented as in [FlowConductor](https://github.com/FabricioArendTorres/FlowConductor)
- [free-form flows](https://arxiv.org/abs/2310.16624), implemented as in the [official repo](https://github.com/vislearn/FFF)

All implementations share the abstract class structure found in `margflow/abstract_model/`.
Specifically, all models expose the following methods: `model.sample()` and `model.log_prob()`.
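
As a rough illustration of why a shared interface is convenient, the sketch below defines a hypothetical abstract base class with the two methods named above. The class names, method signatures, and the toy model are assumptions made for this example; the actual classes in `margflow/abstract_model/` may look different.

```python
import math
import random
from abc import ABC, abstractmethod

class DensityModel(ABC):
    """Hypothetical shared interface: every model can sample and score points."""

    @abstractmethod
    def sample(self, n_samples):
        """Return a list of n_samples draws from the model."""

    @abstractmethod
    def log_prob(self, x):
        """Return the log-density of each point in x."""

class StandardNormal1D(DensityModel):
    """Toy model used only to exercise the interface."""

    def sample(self, n_samples):
        return [random.gauss(0.0, 1.0) for _ in range(n_samples)]

    def log_prob(self, x):
        return [-0.5 * xi**2 - 0.5 * math.log(2 * math.pi) for xi in x]

def mean_log_likelihood(model, x):
    """Model-agnostic evaluation: works for any DensityModel subclass."""
    logp = model.log_prob(x)
    return sum(logp) / len(logp)

model = StandardNormal1D()
samples = model.sample(1000)
score = mean_log_likelihood(model, samples)
print(score)
```

With such an interface, evaluation code like `mean_log_likelihood` can be written once and reused across Marginal flow and the baselines.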

