# Neural Data Transformer 3: A Foundation Model for Motor BCI

## Getting started

### Setup
We recommend setting up with `setup.py`. The `env.yml` files are dumps of an active environment, while `setup.py` lists the core dependencies.
```bash
conda create --name torch2 python=3.11  # or: mamba create --name torch2 python=3.11
conda activate torch2
pip install -e . --extra-index-url https://download.pytorch.org/whl/cu118
```

NDT3 is currently backed by FlashAttention, which requires an NVIDIA GPU with Ampere (or newer) architecture. After the above install, additionally run:
`pip install flash-attn --no-build-isolation`

If this fails due to a packaging error, try `python -m pip install setuptools==69.5.1` (see https://github.com/pytorch/serve/issues/3176).

If the install reports that `nvcc` is not available, see https://github.com/Dao-AILab/flash-attention/issues/509.

### < Ampere-level GPUs
NDT3 was developed on Ampere GPUs and uses FlashAttention 2, which does not support older architectures. Limited codebase functionality may be possible by downgrading to flash-attention 1.0.9, though the full pretrained models are not supported, since they were trained in bfloat16.
- This requires manually installing the `rotary_emb` package: https://github.com/Dao-AILab/flash-attention/issues/160
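Before installing `flash-attn`, it can help to confirm which architecture you are on. The sketch below is a hypothetical helper, not part of this repo. It assumes that "Ampere or newer" corresponds to CUDA compute capability 8.0 and above, and that `flash-attn` is importable as `flash_attn`. It uses PyTorch's `torch.cuda.get_device_capability`, and only when PyTorch is installed.

```python
import importlib.util
from typing import Optional, Tuple

# Assumption: FlashAttention 2 targets compute capability 8.0 (Ampere) and newer.
MIN_CAPABILITY: Tuple[int, int] = (8, 0)


def is_ampere_or_newer(capability: Tuple[int, int]) -> bool:
    """Compare a (major, minor) CUDA compute capability against sm_80."""
    return tuple(capability) >= MIN_CAPABILITY


def flash_attn_installed() -> bool:
    """True if a `flash_attn` package is importable (checked without importing it)."""
    return importlib.util.find_spec("flash_attn") is not None


def current_gpu_capability() -> Optional[Tuple[int, int]]:
    """Compute capability of CUDA device 0, or None if torch or CUDA is unavailable."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch

    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)


if __name__ == "__main__":
    cap = current_gpu_capability()
    if cap is None:
        print("No CUDA device visible to PyTorch.")
    else:
        print(f"compute capability {cap}, Ampere or newer: {is_ampere_or_newer(cap)}")
    print(f"flash_attn installed: {flash_attn_installed()}")
```

For reference, an A100 reports `(8, 0)` and a V100 reports `(7, 0)`. Only the former passes the check.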

### Data Setup
Datasets and checkpoints are expected to go under `./data`.
Install the NDT2 public datasets with the following command. For troubleshooting, see the dataset-specific instructions in the individual dataloaders under `tasks`.
```
. data_scripts/install_datasets.sh
```
Several datasets need specific data-processing libraries, which can be installed with `pip install -r data_scripts/data_requirements.txt`.


## Running an experiment (or fine-tuning)
Models are trained or tuned according to experiment configs. See, for example, any YAML file under `./context_general_bci/config/exp/`. The config system is based on Hydra and composes by default with `context_general_bci/config/config_base.py`. The only difference between pretraining and fine-tuning is whether the configuration specifies that the model inherits a checkpoint.

Experiment tracking is done with Weights and Biases (wandb), which should be set up before runs are launched. Please follow the wandb setup guidelines and configure your user in `config_base`. If you are running this on shared workstations, the configured account is likely REDACT's.

Once all paths are set up, start a given run with:
`python run.py +exp/<EXP_SET>=<exp>`.
e.g. to run the experiment configured in `context_general_bci/config/exp/arch/base/f32.yaml`: `python run.py +exp/arch/base=f32`.

You can launch on SLURM-controlled clusters via `sbatch ./launch.sh +exp/<EXPSET>=<exp>`, or with any of the `launch` scripts. Update the SLURM directives in these scripts to match your cluster. Note that several config-level mechanisms (`inherit_exp`, `inherit_tag`) support experiment-loading inheritance, and these are tightly coupled to the wandb checkpoint logging system.
A whole folder can be launched through slurm with `python launch_exp.py -e ./context_general_bci/config/exp/arch/base`.
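The naming convention above maps a YAML file under `context_general_bci/config/` to a Hydra override: the directory becomes the config group and the file stem becomes the option. The sketch below makes that mapping explicit and lists the `sbatch` commands you would get for one folder. It is a hypothetical illustration, not how `launch_exp.py` is implemented. It assumes paths are given relative to the repository root, or together with a matching `config_root`, and it only picks up `*.yaml` files directly inside the folder.

```python
from pathlib import Path
from typing import List, Union

PathLike = Union[str, Path]

# Hydra config root, relative to the repository root (as used in this README).
DEFAULT_CONFIG_ROOT = Path("context_general_bci/config")


def override_for_config(yaml_path: PathLike, config_root: PathLike = DEFAULT_CONFIG_ROOT) -> str:
    """Map <config_root>/exp/arch/base/f32.yaml to the override '+exp/arch/base=f32'."""
    rel = Path(yaml_path).relative_to(Path(config_root))  # raises ValueError if outside root
    if not rel.parts or rel.parts[0] != "exp" or rel.suffix != ".yaml":
        raise ValueError(f"not an experiment config under {config_root}/exp: {yaml_path}")
    # Directory part is the config group; the file stem is the option name.
    return f"+{rel.parent.as_posix()}={rel.stem}"


def sbatch_commands_for_folder(folder: PathLike, config_root: PathLike = DEFAULT_CONFIG_ROOT) -> List[str]:
    """One `sbatch ./launch.sh <override>` command per YAML file directly inside `folder`."""
    return [
        f"sbatch ./launch.sh {override_for_config(cfg, config_root)}"
        for cfg in sorted(Path(folder).glob("*.yaml"))
    ]


if __name__ == "__main__":
    print(override_for_config("context_general_bci/config/exp/arch/base/f32.yaml"))
```

Run from the repository root, the example prints `+exp/arch/base=f32`, which matches the `python run.py +exp/arch/base=f32` invocation above.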
Note that for SLURM jobs, I trigger the necessary environment loads with a `load_env.sh` script located _outside_ this repo, which then points back into the provided samples (`load_env`, `load_env_crc.sh`). Feel free to edit these to match your local environment.

### Hyperparameter Sweeps
Hyperparameter sweeps can be configured; see e.g. `exp/arch/tune_hp`. Only manual grid searches and random searches are currently implemented.
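The two supported strategies differ in how trials are drawn from a search space. The sketch below illustrates them generically. It is not this repo's sweep code, and the keys `lr` and `dropout` are hypothetical placeholders for real config fields. A grid search enumerates the full Cartesian product. Random search draws a fixed number of trials, picking one value per key uniformly and with replacement, so duplicate trials are possible. Each trial is rendered as Hydra-style `key=value` overrides.

```python
import itertools
import random
from typing import Any, Dict, List, Sequence

SearchSpace = Dict[str, Sequence[Any]]


def grid_search(space: SearchSpace) -> List[Dict[str, Any]]:
    """Every combination of values (Cartesian product), in key order."""
    keys = list(space)
    return [dict(zip(keys, combo)) for combo in itertools.product(*(space[k] for k in keys))]


def random_search(space: SearchSpace, n_trials: int, seed: int = 0) -> List[Dict[str, Any]]:
    """n_trials independent draws; each key's value is sampled uniformly with replacement."""
    rng = random.Random(seed)  # seeded so a sweep can be regenerated exactly
    return [{k: rng.choice(list(values)) for k, values in space.items()} for _ in range(n_trials)]


def to_hydra_overrides(trial: Dict[str, Any]) -> List[str]:
    """Render one trial as Hydra-style `key=value` command-line overrides."""
    return [f"{k}={v}" for k, v in trial.items()]


if __name__ == "__main__":
    # Hypothetical keys; real field names live in context_general_bci/config/config_base.py.
    space = {"lr": [1e-4, 3e-4], "dropout": [0.1, 0.2]}
    for trial in grid_search(space):
        print(" ".join(to_hydra_overrides(trial)))
```

A grid grows multiplicatively with each added key. Random search is therefore usually the cheaper choice once a space has more than a few dimensions.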

## Checkpoints
NDT3 is a foundation model meant for off-the-shelf tuning.
The code is currently in development, so checkpoints are unstable. If you intend to use NDT3, please ask REDACT for a checkpoint appropriate for your use case.

## Codebase Overview
This codebase provides tooling for the entire lifecycle of NDT3 creation and deployment. The main package decomposes as follows:
- `context_general_bci/{subjects, tasks, contexts}`: Infrastructure for metadata management for model use and analysis.
- `context_general_bci/config`: Pretraining and fine-tuning experiment management. Uses [Hydra](https://hydra.cc/) for configuration management.
- `context_general_bci/rtndt`: Deployment of the model to a realtime system. Opinionated and a work in progress; mostly prepared to satisfy abstractions in closed-source REDACT BCI infrastructure.
- `context_general_bci/rl`: V0 of RL for closed-loop fine-tuning.

Users may also be interested in referencing `scripts` for example model evaluation. Many of the import statements under `scripts` assume you have installed this package in editable mode, i.e. `pip install -e .`. This code path is not fully reproducible or documented.

## Troubleshooting

### Model does not torch.compile
Specific error: `stdlib.h: No such file or directory`.
- Make sure you have build dependencies: `sudo apt-get install build-essential`
<!-- This is a GCC issue, I think, but GCC is available. -->
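`torch.compile` may build small C extensions at runtime, so this error usually indicates that the C toolchain cannot find the standard headers. One way to check this outside of PyTorch is to compile a trivial C file that includes `stdlib.h`. The sketch below does that. It assumes the relevant compiler is the one named by the `CC` environment variable, falling back to `gcc`. That is an assumption: the compiler PyTorch or Triton actually selects may differ on your system.

```python
import os
import shutil
import subprocess
import tempfile


def can_compile_with_stdlib(compiler: str = os.environ.get("CC", "gcc")) -> bool:
    """Return True if `compiler` can build a trivial C program that includes stdlib.h."""
    exe = shutil.which(compiler)
    if exe is None:  # compiler not on PATH at all
        return False
    source = "#include <stdlib.h>\nint main(void) { return EXIT_SUCCESS; }\n"
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "check.c")
        out = os.path.join(tmp, "check")
        with open(src, "w") as f:
            f.write(source)
        result = subprocess.run([exe, src, "-o", out], capture_output=True, text=True)
    if result.returncode != 0:
        print(result.stderr)  # e.g. "stdlib.h: No such file or directory"
    return result.returncode == 0


if __name__ == "__main__":
    print(f"C toolchain OK: {can_compile_with_stdlib()}")
```

If this prints `False` with the same `stdlib.h` error, installing `build-essential` as above is the likely fix. If it prints `True`, the problem is more likely which compiler or include paths PyTorch is picking up.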
