# PAVI: Plate Amortized Variational Inference

## Install
To install the package, run the following from this directory:
```bash
pip install .
``` 

## Reference

This repository relates to our preprint (under review):
> PAVI: Plate Amortized Variational Inference

## Directory organization
* subdirectory `pavi` contains our package
* subdirectory `examples` contains scripts to reproduce experiments (see `README` inside the directory)
* subdirectory `data` is a placeholder directory for the data produced by the scripts

### pavi.set_transformer

Provides a fully parametrized Keras implementation of Set Transformers:
> Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks (Lee et al. 2019)

[link to the paper](http://proceedings.mlr.press/v97/lee19d.html)
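
To give an intuition for why attention-based pooling suits set-valued inputs, here is a minimal sketch in plain Python. It is not the `pavi.set_transformer` API: `attention_pool` is a hypothetical helper, the query stands in for a learned seed vector, and there are no learned projections or multiple heads as in the paper. It only illustrates that a softmax-weighted sum over set elements does not depend on their order.

```python
import math


def attention_pool(elements, query):
    """Pool a set of vectors into one vector with softmax attention.

    elements: list of equal-length lists of floats (the set).
    query: list of floats of the same length (hypothetical stand-in
        for a learned seed vector).
    """
    dim = len(query)
    scale = math.sqrt(dim)
    # Scaled dot-product score between the query and each element.
    scores = [
        sum(q * e for q, e in zip(query, elem)) / scale
        for elem in elements
    ]
    # Numerically stable softmax weights.
    max_score = max(scores)
    weights = [math.exp(s - max_score) for s in scores]
    total = sum(weights)
    # Weighted average of the elements (they act as keys and values).
    return [
        sum(w * elem[d] for w, elem in zip(weights, elements)) / total
        for d in range(dim)
    ]


points = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
seed = [0.5, -0.5]
pooled = attention_pool(points, seed)
# Reordering the set leaves the pooled vector unchanged
# (up to floating-point rounding).
pooled_shuffled = attention_pool(points[::-1], seed)
```

Because each weight depends only on its own element and the shared query, permuting the inputs permutes the terms of the sum without changing its value.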


### pavi.normalizing_flow

Contains TensorFlow Probability utilities to construct chains of normalizing flow bijectors.
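
As a rough illustration of the bookkeeping a chain of bijectors performs, the sketch below composes two invertible maps in plain Python and evaluates the resulting density with the change-of-variables formula. It does not use the `pavi.normalizing_flow` helpers or TFP: the classes `Affine` and `Exp` and the functions `chain_forward` and `chain_log_prob` are hypothetical names introduced here, and the base distribution is assumed to be a scalar standard normal. The right-to-left application order is chosen to match the convention of TFP's `tfb.Chain`.

```python
import math


class Affine:
    """Bijector y = scale * x + shift, with scale != 0."""

    def __init__(self, shift, scale):
        self.shift, self.scale = shift, scale

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        return (y - self.shift) / self.scale

    def inverse_log_det_jacobian(self, y):
        return -math.log(abs(self.scale))


class Exp:
    """Bijector y = exp(x), mapping the real line to (0, inf)."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def inverse_log_det_jacobian(self, y):
        return -math.log(y)


def standard_normal_log_prob(x):
    return -0.5 * (x * x + math.log(2.0 * math.pi))


def chain_forward(bijectors, x):
    # The last bijector in the list is applied first.
    for bijector in reversed(bijectors):
        x = bijector.forward(x)
    return x


def chain_log_prob(bijectors, y):
    """Log-density of y = chain_forward(bijectors, x) for x ~ N(0, 1)."""
    total_ildj = 0.0
    # Undo the outermost bijector first, accumulating log-determinants.
    for bijector in bijectors:
        total_ildj += bijector.inverse_log_det_jacobian(y)
        y = bijector.inverse(y)
    return standard_normal_log_prob(y) + total_ildj


# exp(0.5 * x + 1.0): pushes a standard normal to a log-normal.
flow = [Exp(), Affine(shift=1.0, scale=0.5)]
sample = chain_forward(flow, 0.0)
log_density = chain_log_prob(flow, 2.0)
```

With these two bijectors the transformed variable is log-normal with parameters `mu=1.0` and `sigma=0.5`, so `chain_log_prob` can be checked against the closed-form log-normal density.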

### pavi.dual

Contains our main methodological contribution: utilities that derive an inference architecture from a `generative_hbm`.

Notably provides a TFP - Keras implementation of Cascading Flows:
> Automatic variational inference with cascading flows (Ambrogioni et al. 2021)

[link to the paper](https://arxiv.org/abs/2102.04801)

Also provides a TFP - Keras implementation of ADAVI:
> Automatic Dual Variational Inference (Rouillard and Wassermann 2021)

[link to the paper](https://arxiv.org/abs/2106.12248)

Also provides a TFP - Keras implementation of UIVI:
> Unbiased Implicit Variational Inference (Titsias and Ruiz 2018)

[link to the paper](https://arxiv.org/abs/1808.02078)

## Example usage
```python
import tensorflow as tf  # needed for tf.zeros / tf.ones below
import tensorflow_probability as tfp
from pavi.dual.models import PAVFFamily

tfd = tfp.distributions
tfb = tfp.bijectors

def get_hbm(G: int) -> tfd.Distribution:
    generative_hbm = tfd.JointDistributionNamed(
        model=dict(
            theta=tfd.Normal(
                loc=tf.zeros((1,)),
                scale=tf.ones((1,))
            ),
            X=lambda theta: tfd.TransformedDistribution(
                tfd.Sample(
                    distribution=tfd.Normal(loc=theta, scale=0.1),
                    sample_shape=(G,)
                ),
                tfb.Transpose([1, 0, 2])
            )
        )
    )

    return generative_hbm

# Same model at two plate cardinalities: G=10 (full) and G=2 (reduced)
full_hbm = get_hbm(G=10)
reduced_hbm = get_hbm(G=2)

hbm_kwargs = dict(
    full_hbm=full_hbm,
    reduced_hbm=reduced_hbm,
    plates_per_rv={
        "theta": ['P'],
        "X": ['P', 'G']
    },
    link_functions={...}
)

pav_family = PAVFFamily(
    posterior_schemes_kwargs={
        "theta": ("flow", {...}),
        "X": ("observed", {})
    },
    encodings_sizes={
        ('P',): 8,
        ('P', 'G'): 8
    },
    **hbm_kwargs
)

train_data = full_hbm.sample((100,))
val_datum = full_hbm.sample((1,))

pav_family.compile(
    train_method="reverse_KL",
    n_theta_draws=8,
    optimizer="adam"
)
pav_family.fit(train_data)
posterior_sample = (
    pav_family
    .sample(
        sample_shape=(32,),
        observed_values=val_datum
    )
)
```