This repo is mostly based on https://github.com/google-research/meliad. 

Install the required packages into your Python virtual environment. To use GPUs, JAX must be upgraded to a CUDA-enabled build. Installing t5 after upgrading JAX may be necessary to avoid link errors (we don't know why).

```
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install t5
```

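To confirm that the CUDA-enabled JAX build is actually picked up, you can ask JAX which backend it uses by default. `jax.default_backend()` is a real JAX function that returns a string such as `"gpu"` or `"cpu"`. The shell function names and the `PYTHON` override below are hypothetical conveniences, not part of this repo:

```shell
# Hypothetical helper: print the backend JAX uses by default
# ("gpu" with a working CUDA install, "cpu" otherwise).
# Set PYTHON to pick the interpreter; defaults to "python" on PATH.
jax_backend() {
  "${PYTHON:-python}" -c 'import jax; print(jax.default_backend())'
}

# Hypothetical helper: succeed only if JAX reports the GPU backend.
jax_uses_gpu() {
  [ "$(jax_backend)" = "gpu" ]
}
```

For example, run `jax_uses_gpu && echo "JAX sees the GPU"` inside the virtual environment. If it prints nothing, JAX fell back to the CPU and the `jax[cuda]` install is worth revisiting.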
On Unix systems, you may need to make sure that PYTHONPATH includes the current directory (the repository root), because all module names are given relative to this root.

```
export PYTHONPATH=.:$PYTHONPATH
```
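The `export` above adds `.`, which Python resolves against whatever directory it is launched from, so imports can break if you start a run from a different directory. A minimal sketch, assuming you run it from the repository root, that adds the absolute path instead and skips it if it is already present (`prepend_pythonpath` is a hypothetical helper name):

```shell
# Hypothetical helper: prepend a directory (default: the current directory,
# assumed to be the repository root) to PYTHONPATH as an absolute path.
prepend_pythonpath() {
  _repo_dir=$(cd "${1:-.}" && pwd) || return 1
  case ":${PYTHONPATH:-}:" in
    *":$_repo_dir:"*) ;;  # already present: leave PYTHONPATH unchanged
    *) PYTHONPATH="$_repo_dir${PYTHONPATH:+:$PYTHONPATH}" ;;
  esac
  export PYTHONPATH
}

prepend_pythonpath
```

Because the function checks for an existing entry, sourcing it repeatedly (for example from a shell profile) does not grow PYTHONPATH.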

To train a lexinvariant model on Wiki-40B with character-level tokenization:
```
python transformer/ht_main.py --gin_file=base_htrans.gin --gin_file=size/medium_150M.gin --gin_file=tasks/wiki40b_char.gin --gin_file=options/seq_512_nocache.gin --gin_file=options/window_512.gin --gin_file=random_all_ga.gin --workdir=<workdir> --run_name <run_name>
```
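The command above stacks several `--gin_file` flags, and only the task file, workdir, and run name need to change between experiments at this model size. A sketch of a wrapper that keeps the other gin files fixed, in the same order as above (later files may override bindings made by earlier ones, so order is worth preserving). `train_lexinvariant` and the `PYTHON` override are hypothetical; the flags and gin file names are copied from the command above:

```shell
# Hypothetical wrapper around transformer/ht_main.py using the medium 150M config.
# Usage: train_lexinvariant TASK_GIN WORKDIR RUN_NAME
train_lexinvariant() {
  if [ "$#" -ne 3 ]; then
    echo "usage: train_lexinvariant TASK_GIN WORKDIR RUN_NAME" >&2
    return 2
  fi
  "${PYTHON:-python}" transformer/ht_main.py \
    --gin_file=base_htrans.gin \
    --gin_file=size/medium_150M.gin \
    --gin_file="$1" \
    --gin_file=options/seq_512_nocache.gin \
    --gin_file=options/window_512.gin \
    --gin_file=random_all_ga.gin \
    --workdir="$2" \
    --run_name "$3"
}
```

For example, `train_lexinvariant tasks/wiki40b_char.gin <workdir> <run_name>` reproduces the Wiki-40B command. Setting `PYTHON=echo` prints the full command instead of running it, which is a quick way to check the flags.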

To train a lexinvariant model on the Pile with character-level tokenization (first set DATAPATH in tasks.py to point to your copy of the Pile):
```
python transformer/ht_main.py --gin_file=base_htrans.gin --gin_file=size/medium_150M.gin --gin_file=tasks/the_pile_char.gin --gin_file=options/seq_512_nocache.gin --gin_file=options/window_512.gin --gin_file=random_all_ga.gin --workdir=<workdir> --run_name <run_name>
```
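Before launching the Pile run, it helps to locate the DATAPATH line you need to edit. The README does not say where tasks.py lives, so this sketch searches for it from the repository root and prints each matching line with its file name and line number (`find_datapath` is a hypothetical helper name):

```shell
# Hypothetical helper: show every line mentioning DATAPATH in any tasks.py
# under the given directory (default: current directory), as file:line:text.
find_datapath() {
  find "${1:-.}" -name tasks.py -type f -exec grep -Hn 'DATAPATH' {} +
}
```

Running `find_datapath` from the repository root lists the candidate lines; edit the assignment so it points at your local copy of the Pile.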