# Trajectory General Mellowmax (TGM)
A `jax` implementation of the TGM algorithm. Contains:
- Algorithm implementations of
    - `TGM` (encompasses `GFN`)
    - `SAC`
    - `PPO`
- Code and data for the following synthetic and biological sequence design tasks:
    - `BitSequence`
    - Untranslated region (`UTR`)
    - Antimicrobial peptides (`AMP`)
    - Green fluorescent protein (`GFP`)
- Proxy training

![Demo of TGM training](media/run_example_cropped.gif)
<center><i>Example run of TGM.</i></center>

## Setup
To get started (currently tested with `Python 3.10`), copy the files, then run:
```bash
# Or your favorite virtual env
python -m venv env && source env/bin/activate && pip install --upgrade pip
pip install -e tgm
python example.py
```

## Training
The easiest way to train a sampler is to call the `run` function from `medium_rl.run` with a `Config` object from `src/medium_rl/config.py`.

### Config
The `Config` object contains general training options (e.g. sampling policy, minibatch size, evaluation). It also holds the following three sub-configurations:
- `EnvConfig`: Configuration for the environment
- `AlgConfig`: Configuration for the training algorithm
- `NetworkConfig`: Configuration for the neural network

To see the available options, check `src/medium_rl/config.py`, which also contains base configurations for each of the three. They can be composed as follows:
```python
from medium_rl.config import (
    AMPConfig,
    TGMConfig,
    BaseTransformerConfig,
    Config,
)

env_cfg = AMPConfig()  # Taking default options
alg_cfg = TGMConfig(alpha=1, omega=1, q=0.75)  # Changing some values
network_cfg = BaseTransformerConfig(dropout=0.05)

cfg = Config(
    env=env_cfg,
    alg=alg_cfg,
    network=network_cfg,
    reward_exp=64,  # Change beta
    lr=1e-4,  # Change lr
)
```
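
The `# Change beta` comment suggests that `reward_exp` acts as the reward exponent β familiar from the GFlowNet literature, where the sampler targets a distribution proportional to `R(x) ** beta`. That reading is an assumption and is not confirmed by the config file. Under that assumption, the following standalone sketch (the `sharpen` helper is hypothetical and not part of `medium_rl`) shows how a larger exponent concentrates probability mass on high-reward sequences:

```python
def sharpen(rewards, beta):
    """Return the normalized distribution proportional to reward ** beta."""
    powered = [r ** beta for r in rewards]
    total = sum(powered)
    return [p / total for p in powered]


# Two candidate sequences, the second with twice the reward of the first.
rewards = [1.0, 2.0]

flat = sharpen(rewards, beta=1)    # mass split 1/3 vs 2/3
peaked = sharpen(rewards, beta=4)  # mass split 1/17 vs 16/17
```

A large value such as `reward_exp=64` would therefore push the target distribution strongly toward the top-scoring sequences.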

### Run
Once the `Config` object is created, running training simply requires:
```python
from medium_rl.run import run

run(cfg)
```

## Training speed
On an L40S GPU, training for 100k samples should take:
- `BitSequence`: <3 minutes.
- `UTR`: <5 minutes.
- `AMP`: <5 minutes.
- `GFP`: <30 minutes.

## Environments
All environments are subclasses of `SequenceEnv`, which describes a generic sequence generation DCG. As in [PGX](https://github.com/sotetsuk/pgx), the core object is a `State` that contains information about the current sequence. `SequenceEnv` then defines the `init`, `step` and `get_rewards` functions to initialize the state, advance it by one token, and compute proxy rewards for a sequence.
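
To illustrate the `init` / `step` / `get_rewards` pattern, here is a deliberately simplified, pure-Python toy. It is not the repo's API: `ToyState`, the function signatures, the token indices and the reward rule are all hypothetical, and the real `SequenceEnv` is written in `jax` and may be batched and structured differently.

```python
from typing import NamedTuple, Tuple

EOS = 2      # hypothetical end-of-sequence token index
MAX_LEN = 4  # hypothetical maximum sequence length


class ToyState(NamedTuple):
    tokens: Tuple[int, ...]  # tokens generated so far
    done: bool               # True once EOS is emitted or MAX_LEN is reached


def init() -> ToyState:
    """Start from an empty sequence."""
    return ToyState(tokens=(), done=False)


def step(state: ToyState, action: int) -> ToyState:
    """Append one token; finished states are returned unchanged."""
    if state.done:
        return state
    tokens = state.tokens + (action,)
    done = action == EOS or len(tokens) >= MAX_LEN
    return ToyState(tokens=tokens, done=done)


def get_reward(tokens: Tuple[int, ...]) -> float:
    """Stand-in for a proxy: count how often token 3 appears."""
    return float(sum(t == 3 for t in tokens))


state = init()
for action in [3, 4, 3, EOS, 3]:  # the last action arrives after termination
    state = step(state, action)
reward = get_reward(state.tokens)
```

The state is immutable and every function returns a new value, which is the style that makes such environments easy to `jit` and `vmap` in `jax`.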

For each of the biological sequence design tasks, the checkpoint for the proxy reward function is provided in `src/medium_rl/envs/proxies/<env_name>/proxy.pkl`, and the validation mean/std in `src/medium_rl/envs/proxies/<env_name>/val_stats.pkl`.

### BitSequence
Synthetic task described in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).

### UTR
Sequence design task for the 5' UTR mRNA region, which regulates translation of the main coding sequence. The data used to train the proxy comes from https://github.com/brandontrabucco/design-bench/issues/11#issuecomment-1912494126 and consists of 250,000 sequences and their associated ribosome loads.

### AMP
Antimicrobial peptide design task. The proxy was trained as a binary classifier to predict whether a sequence is antimicrobial, using a dataset of 9222 non-AMP sequences and 6438 AMP sequences from https://github.com/MJ10/clamp-gen-data/tree/master/data/dataset. The logit of the classifier is used as the proxy reward.
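
As a reminder of what using the logit implies, the standalone snippet below relates a classifier probability to its logit. It is a generic illustration rather than code from this repo, and the `logit` and `sigmoid` helpers are defined here only for the example.

```python
import math


def logit(p: float) -> float:
    """Log-odds of a probability p in (0, 1)."""
    return math.log(p) - math.log1p(-p)


def sigmoid(z: float) -> float:
    """Inverse of logit: map log-odds back to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


uncertain = logit(0.5)   # zero: the classifier is undecided
likely = logit(0.9)      # positive: predicted antimicrobial
unlikely = logit(0.1)    # negative: predicted non-antimicrobial
```

Unlike the probability, the logit is unbounded and can be negative, so it keeps distinguishing between sequences that the classifier already scores close to 0 or 1.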

### GFP
Green fluorescent protein design task. Data to train the proxy was sourced from https://github.com/brandontrabucco/design-bench/issues/11#issuecomment-1912494126 and consists of 56086 variations of the original GFP protein and their associated fluorescence.

## Custom environments
To create a custom environment, one can extend the `SequenceEnv` environment as follows:
```python
from jax import Array  # Assumed source of the `Array` type used in the annotations below

from medium_rl.envs.sequence_env import SequenceEnv

class NEW_ENV_NAMESequence(SequenceEnv):
    # Everything below needs to be specified
    name = "NEW_ENV_NAME"

    num_tokens = len(NEW_ENV_ALPHABET)
    alphabet = NEW_ENV_ALPHABET
    dict = {NEW_ENV_ALPHABET[i]: i for i in range(len(NEW_ENV_ALPHABET))}

    CLS = 0 # CLS or BOS token index
    PAD = 1 # PAD token index
    EOS = 2 # EOS token index

    def __init__(self, min_len: int, max_len: int, **kwargs):
        super().__init__(min_len, max_len)
        self.proxy = NEW_ENV_PROXY() # Initialize proxy if necessary

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequence of tokens
    ):
        # Need to specify/write proxy reward function that takes in a [B, T] array of token indexes
        # - B: Batch size
        # - T: Sequence length
        # and returns the proxy reward
        rewards = self.proxy.evaluate(token_seq)
        return rewards
```
Then, an `EnvConfig` can be specified as follows:
```python
from medium_rl.config import EnvConfig

class NEW_ENVConfig(EnvConfig):
    name: str = "NEW_ENV"
    min_len: int = 5
    max_len: int = 10
    ...
```


## Proxy training
`src/medium_rl/envs/proxies/train_proxy.py` contains code for training proxy reward functions from data. `train_model` expects `x` (an `[N, T]` array of token indexes) and `y` (an `[N]` array of either floats to regress to or binary classes), as well as a `model_cfg` specifying the hyperparameters of the network. See `train_proxies.py` for example uses.
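
As a rough sketch of how raw sequences might be turned into the `[N, T]` token array and `[N]` target array described above, consider the following. The alphabet, the special-token layout (mirroring the `CLS`/`PAD`/`EOS` indices from the custom environment template) and the `encode` helper are all hypothetical; whether the provided proxies expect `CLS`/`EOS` tokens or a particular padding scheme is an assumption to verify against `train_proxies.py`.

```python
# Hypothetical alphabet: three special tokens followed by RNA nucleotides.
ALPHABET = ["[CLS]", "[PAD]", "[EOS]", "A", "C", "G", "U"]
TOKEN_TO_ID = {token: i for i, token in enumerate(ALPHABET)}
CLS, PAD, EOS = 0, 1, 2


def encode(seq: str, max_len: int) -> list:
    """Map a string to token indexes, wrap in CLS/EOS, right-pad with PAD."""
    ids = [CLS] + [TOKEN_TO_ID[ch] for ch in seq] + [EOS]
    if len(ids) > max_len:
        raise ValueError(f"sequence of length {len(seq)} does not fit in {max_len}")
    return ids + [PAD] * (max_len - len(ids))


sequences = ["ACG", "UU"]  # toy data, not from the repo's datasets
targets = [0.7, 0.1]       # toy regression targets

T = max(len(s) for s in sequences) + 2  # room for CLS and EOS
x = [encode(s, T) for s in sequences]   # nested list of shape [N, T]
y = list(targets)                       # list of shape [N]
```

The nested lists can be converted with `jnp.asarray(x)` and `jnp.asarray(y)` before being passed on to training code.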

## Acknowledgments
The biological environments are `jax` implementations, with moderate modifications, of the environments of [Biological Sequence Design with GFlowNets](https://github.com/MJ10/BioSeq-GFN-AL) as well as the benchmarks of [Design-Bench: Benchmarks for Data-Driven Offline Model-Based Optimization](https://github.com/brandontrabucco/design-bench). The training process for the proxy reward functions comes from the former and the data used from the latter. The BitSequence environment comes from [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259). The design of the `SequenceEnv` environment is inspired by the [PGX library](https://github.com/sotetsuk/pgx).

<!-- ## Citation -->
