# GraphFM

Official implementation of GraphFM.

> [!WARNING]  
> This repository is under construction. It currently provides code for training our
> largest GraphFM model on an 8-GPU cluster. We will update the repository with
> extensive documentation, support for finetuning, more datasets, and weights for
> pretrained models.

### Setup

Create and activate a Python virtual environment, then upgrade pip:
```
python3 -m venv graphfm_env
source graphfm_env/bin/activate
pip install --upgrade pip
```

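Next, install PyTorch, PyG, and the remaining dependencies. The commands below pin 
PyTorch 2.5.1 built for CUDA 12.1; for other CUDA versions, adjust the wheel index URLs 
following the [PyTorch](https://pytorch.org/get-started/locally/) and 
[PyG](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) 
installation instructions: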
```
pip install uv
uv pip install -U "ray[data,train,tune,serve]==2.40.0"
uv pip install torch==2.5.1 xformers==0.0.29.post1 --index-url https://download.pytorch.org/whl/cu121
uv pip install torch_geometric
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.1+cu121.html
uv pip install ogb einops torchtyping torch-optimizer tabulate yacs pydantic torchmetrics wandb black prettytable
uv pip install optuna
uv pip install hydra-core
```

The code uses PyG (PyTorch Geometric).
All datasets are available through this package.
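
To confirm that the install commands above took effect, a quick check like the following can help. It is a minimal sketch and not part of the repository: the `REQUIRED` list is an illustrative subset of the packages named in the pip commands, and `installed_versions` is a hypothetical helper that uses only the Python standard library.

```python
from importlib import metadata

# Illustrative subset of the distributions installed above (assumption: extend as needed).
REQUIRED = ["torch", "torch_geometric", "ray", "xformers", "ogb", "hydra-core"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Print one line per package so missing installs are easy to spot.
for name, version in installed_versions(REQUIRED).items():
    print(f"{name:16s} {version or 'NOT INSTALLED'}")
```

The setup above pins `torch==2.5.1` and `ray==2.40.0`, so those are the versions this check should report in a fresh environment.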


### Download Network Repo Datasets
```
cd network_repo_download
python regex.py
sh unzip.sh
```


### Preprocess Data
```
python preprocess_datasets.py --cfg configs/pretrain_model.yaml
```

### Pretrain Model
```
python main.py --cfg configs/pretrain_model.yaml
```
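
Preprocessing and pretraining take the same config file, so the two steps can be chained from a small driver script. The sketch below is an assumption about how one might do this, not a script shipped with the repository: `build_commands` and `run_pipeline` are hypothetical helpers, and only the two commands and the `--cfg` flag come from the instructions above.

```python
import subprocess
import sys


def build_commands(cfg_path):
    """Return the preprocessing and pretraining commands, in run order."""
    return [
        [sys.executable, "preprocess_datasets.py", "--cfg", cfg_path],
        [sys.executable, "main.py", "--cfg", cfg_path],
    ]


def run_pipeline(cfg_path, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands(cfg_path):
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError on a non-zero exit code,
            # so pretraining never starts if preprocessing fails.
            subprocess.run(cmd, check=True)


# Dry run: only prints the two commands.
run_pipeline("configs/pretrain_model.yaml")
```

Calling `run_pipeline("configs/pretrain_model.yaml", dry_run=False)` from the repository root would execute both steps in sequence.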


