### JAX to PyTorch Conversion

This directory converts JAX model weights to PyTorch.

First, `cd jax_weights` and download the models into that directory. Then return to this directory and run `convert.py`. The converted weights are saved to `weights/`.
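The contents of `convert.py` are not shown here. As a rough illustration of what a JAX-to-PyTorch conversion generally involves, the sketch below flattens a nested parameter dict into PyTorch-style dot-separated keys and adjusts layouts. It assumes Flax-style parameter names (`kernel`, `bias`, `scale`) and handles only 2-D dense kernels. Flax `Dense` stores its kernel as `(in_features, out_features)`, while `torch.nn.Linear.weight` is `(out_features, in_features)`, so the kernel must be transposed. The helpers `flatten_params` and `to_torch_layout` are hypothetical and are not the functions used in `convert.py`.

```python
import numpy as np
from collections.abc import Mapping
from typing import Any, Dict


def flatten_params(tree: Mapping, prefix: str = "") -> Dict[str, np.ndarray]:
    """Hypothetical helper: flatten a nested param dict into dot-separated keys."""
    flat: Dict[str, np.ndarray] = {}
    for name, value in tree.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, Mapping):
            # Recurse into submodules, e.g. {"encoder": {"dense": {...}}}.
            flat.update(flatten_params(value, key))
        else:
            # Leaves are arrays (JAX arrays convert cleanly via np.asarray).
            flat[key] = np.asarray(value)
    return flat


def to_torch_layout(flat: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: rename/transpose Flax-style params for PyTorch.

    Assumes Flax naming; only 2-D dense kernels are transposed.
    Conv kernels would need a different permutation and are not handled.
    """
    out: Dict[str, np.ndarray] = {}
    for key, arr in flat.items():
        if key.endswith(".kernel") and arr.ndim == 2:
            # Flax Dense kernel: (in, out) -> torch Linear weight: (out, in).
            out[key[: -len(".kernel")] + ".weight"] = np.ascontiguousarray(arr.T)
        elif key.endswith(".scale"):
            # Flax LayerNorm/BatchNorm "scale" is called "weight" in PyTorch.
            out[key[: -len(".scale")] + ".weight"] = arr
        else:
            # Biases and other params keep their name and shape.
            out[key] = arr
    return out


if __name__ == "__main__":
    params = {"encoder": {"dense": {"kernel": np.ones((3, 4)), "bias": np.zeros(4)}}}
    state = to_torch_layout(flatten_params(params))
    print({k: v.shape for k, v in state.items()})
```

To produce an actual PyTorch checkpoint, the resulting arrays could be wrapped with `torch.from_numpy` and written with `torch.save`. A real model may use different parameter names that need an explicit mapping table.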

