# Code for EEG foundation model

This code repository is based on [BrainBERT](https://github.com/czlwang/BrainBERT), a modeling approach for learning self-supervised representations of intracranial electrode data,
and on the [GRL code repository](https://github.com/ofsoundof/GRL-Image-Restoration/tree/main).
See the [BrainBERT paper](https://arxiv.org/abs/2302.14367) for details.

## Installation
```bash
conda create -n TimeFM python=3.11.5
conda activate TimeFM
pip install -r requirements.txt
```

## How to run pre-training
The following command runs pre-training for the large version (37 M params) of our model.

```bash
python -u run_train.py +experiment=waveECG_large.yaml

```

## How to run fine-tuning
The following command runs fine-tuning for the small version (5 M params) of our model.
```bash
python -u run_train.py +experiment=test_waveECGsmall.yaml

```

# General Introduction

## About [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)
We use PyTorch Lightning to build this code base.
Its main benefit is that we do not need to write boilerplate code.
It also enables [distributed data parallel](https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_intermediate.html) training with only a few lines of configuration.
PyTorch Lightning is also used to train some state-of-the-art models, such as [Stable Diffusion](https://github.com/Stability-AI/stablediffusion/blob/main/scripts/txt2img.py).

## About [hydra](https://hydra.cc/docs/intro/)
In this repository, we use [hydra](https://hydra.cc/docs/intro/) to manage the [config](./config) files for running experiments. **The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line.**

In hydra, the configurations for different purposes (e.g., data, model, loss, optimizer) are separated and stored in different files. The basic structure of the config folder is as follows:

```bash
    config/
    ├── __init__.py                      # This is needed by hydra.
    ├── defaults.yaml                    # The default configuration file.
    ├── task                             # Task configuration directory.
    ├── scheduler                        # Learning rate scheduler configuration directory.
    ├── model                            # Model (encoder) configuration directory.
    ├── model_head                       # Model head configuration directory.
    ├── data_module                      # Data module configuration directory.
    ├── criterion                        # Loss configuration directory.
    ├── experiment                       # Experiment configuration directory.
    ├── experiments_ft_waveforms         # Fine-tuning experiment configuration directory for our final model.
    ├── experiments_pt_waveforms         # Pre-training experiment configuration directory for our final model.
    ├── experiments_scratch_waveforms    # Training-from-scratch experiment configuration directory for our final model.
    └── ...
```
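Configuration parameters are overridden in the following order of precedence, from highest to lowest (a value set earlier in the chain overrides the same parameter set later):
CLI input -> experiment config -> other modular configs (data, model, loss) -> defaults.yaml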
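The precedence rule can be pictured as layered dictionary merging, where each layer overwrites matching keys of the layers below it while leaving other keys untouched. The sketch below is only a plain-Python illustration of that idea (hydra itself composes configs with OmegaConf); the keys and values are hypothetical and are not taken from this repository's config files.

```python
from copy import deepcopy
from typing import Any, Dict

Config = Dict[str, Any]


def deep_merge(base: Config, override: Config) -> Config:
    """Return a copy of `base` with `override` applied on top, recursing into nested dicts."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def compose(*layers: Config) -> Config:
    """Merge layers given from LOWEST to HIGHEST priority."""
    cfg: Config = {}
    for layer in layers:
        cfg = deep_merge(cfg, layer)
    return cfg


# Hypothetical values, for illustration only.
defaults = {"trainer": {"max_epochs": 10, "devices": 1}, "optimizer": {"lr": 1e-3}}
modular = {"model": {"name": "small", "hidden_dim": 256}}         # e.g. a model config
experiment = {"model": {"hidden_dim": 512}, "trainer": {"max_epochs": 50}}
cli = {"optimizer": {"lr": 5e-4}}                                  # command-line override

# defaults.yaml < modular configs < experiment config < CLI input
cfg = compose(defaults, modular, experiment, cli)
print(cfg)
```

Keys that no higher layer mentions (here `trainer.devices` and `model.name`) keep their lower-priority values, which is why an experiment file only needs to list the parameters it changes.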

Please check [hydra](https://hydra.cc/docs/intro/) for more information.

## Explanation of the important folders
```bash
    TimeFM/
    ├── config                   # All the hydra config files.
    ├── criterion                # Loss function. Default: baseline_criterion.py
    ├── data_module              # Data module directory. Default: eeg_data_module.py
    ├── datasets                 # Dataset directory. Default: custom_eeg_dataset.py
    ├── models                   # Model directory.
    ├── make_datasets            # Scripts to download and create the EEG datasets.
    ├── schedulers               # The learning rate scheduler. Default: multi_step_lr.py
    ├── tasks                    # The task. Default: base_task.py
    ├── scripts                  # Markdown files recording experiment commands.
    ├── utils                    # Utilities (currently unused).
    ├── preprocessors            # Directory for additional signal pre-processing (e.g. spectrogram computation).
    └── ...
```

PyTorch Lightning removes much of the boilerplate code, such as the training and validation loops.
Instead, the training and validation steps are defined in a `pl.LightningModule`.
See [`./tasks/base_task.py`](./tasks/base_task.py) for details.
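To illustrate this division of labour, here is a framework-free toy in which a trainer owns the epoch and batch loops and only calls user-defined `training_step` and `validation_step` hooks (the hook names match Lightning's). `LinearTask` and `MiniTrainer` are hypothetical names; the real `pl.Trainer` additionally handles backpropagation, optimizers, devices, logging, and checkpointing, so this is a simplified sketch, not how Lightning is implemented.

```python
from typing import List, Tuple

Batch = Tuple[float, float]  # (input x, target y)


class LinearTask:
    """Toy 'task' exposing Lightning-style hooks; fits y = w * x with plain SGD."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w = 0.0
        self.lr = lr

    def training_step(self, batch: Batch, batch_idx: int) -> float:
        x, y = batch
        error = self.w * x - y
        loss = error ** 2
        # Manual gradient step; Lightning would run backward() and optimizer.step() for you.
        self.w -= self.lr * 2 * error * x
        return loss

    def validation_step(self, batch: Batch, batch_idx: int) -> float:
        x, y = batch
        return (self.w * x - y) ** 2


class MiniTrainer:
    """Owns the loops and calls the task's hooks, like a (much simplified) Trainer."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    def fit(self, task: LinearTask, train_data: List[Batch], val_data: List[Batch]) -> List[float]:
        val_history = []
        for _ in range(self.max_epochs):
            for i, batch in enumerate(train_data):
                task.training_step(batch, i)
            val_losses = [task.validation_step(b, i) for i, b in enumerate(val_data)]
            val_history.append(sum(val_losses) / len(val_losses))
        return val_history


train = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
val = [(4.0, 8.0)]
task = LinearTask(lr=0.05)
history = MiniTrainer(max_epochs=20).fit(task, train, val)
print(f"w = {task.w:.3f}, final val loss = {history[-1]:.6f}")
```

The task class only describes what happens on one batch; everything about iteration lives in the trainer, which is the part Lightning provides.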


## How to add a new model?
1. Add the code of the model to [`./models`](./models).
2. Add the configuration file of the model to [`./config/model`](./config/model).
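Both steps rely on some way of mapping a model name in a config file to a model class. One common pattern is a name-based registry; the sketch below assumes such a registry purely for illustration. `MODEL_REGISTRY`, `register_model`, `build_model`, and `ToyEncoder` are hypothetical, and the mechanism actually used in [`./models`](./models) may differ. The dictionary stands in for a YAML file in `./config/model`, and `ToyEncoder` stands in for a `torch.nn.Module`.

```python
from typing import Any, Callable, Dict, Type

# Hypothetical registry; the actual mechanism in ./models may differ.
MODEL_REGISTRY: Dict[str, Type] = {}


def register_model(name: str) -> Callable[[Type], Type]:
    """Class decorator that records a model class under `name`."""
    def decorator(cls: Type) -> Type:
        if name in MODEL_REGISTRY:
            raise ValueError(f"model '{name}' is already registered")
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


@register_model("toy_encoder")
class ToyEncoder:
    """Stand-in for a real torch.nn.Module encoder."""

    def __init__(self, hidden_dim: int, num_layers: int) -> None:
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers


def build_model(model_cfg: Dict[str, Any]) -> Any:
    """Instantiate the class registered under model_cfg['name'] with the remaining keys."""
    cfg = dict(model_cfg)  # copy so the caller's config is not mutated
    name = cfg.pop("name")
    if name not in MODEL_REGISTRY:
        raise KeyError(f"unknown model '{name}'; registered: {sorted(MODEL_REGISTRY)}")
    return MODEL_REGISTRY[name](**cfg)


# Plays the role of a hypothetical ./config/model/toy_encoder.yaml.
model_cfg = {"name": "toy_encoder", "hidden_dim": 256, "num_layers": 4}
model = build_model(model_cfg)
print(type(model).__name__, model.hidden_dim, model.num_layers)
```

With this pattern, adding a model touches only the new model file and its config; the training code never has to import the class by name.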

## How to start a new experiment with the added model?
1. Add an experiment configuration file to [`./config/experiment`](./config/experiment).
    If you are interested, see the [hydra documentation on configuring experiments](https://hydra.cc/docs/patterns/configuring_experiments/).
2. Override the default configurations in the added experiment configuration file.
3. Optionally, add a .md file under [`./scripts`](./scripts) and store all your experiment commands there.
4. Run the command: `python run_train.py +experiment=eeg_demo`

## How to use distributed data parallel?
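In your experiment configuration file, add the following arguments. `${num_nodes}` is an OmegaConf interpolation, so it must resolve to a top-level `num_nodes` value defined in your configuration: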
```yaml
trainer:
  accelerator: gpu  # Using GPU
  num_nodes: ${num_nodes}  # The number of computing nodes
  devices: -1  # Automatically uses all GPUs available
  strategy: ddp  # distributed data parallel
```
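With `strategy: ddp`, Lightning runs one process per GPU, so the world size is `num_nodes` times the number of GPUs per node, and each process trains on its own shard of the data in every epoch. The sketch below mimics how PyTorch's `DistributedSampler` (with `shuffle=False`, `drop_last=False`) assigns sample indices to ranks, and computes the resulting effective batch size. The cluster size and per-device batch size are hypothetical values, and gradient accumulation is assumed to be off.

```python
import math
from typing import List


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Indices seen by one rank, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    num_per_rank = math.ceil(num_samples / world_size)
    total_size = num_per_rank * world_size
    indices = list(range(num_samples))
    # Pad by wrapping around to the start so every rank gets the same number of samples.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Rank r takes every world_size-th index, starting at r.
    return indices[rank:total_size:world_size]


num_nodes = 2          # hypothetical value of `num_nodes`
gpus_per_node = 4      # hypothetical: what `devices: -1` resolves to on each node
per_device_batch = 16  # hypothetical dataloader batch size
world_size = num_nodes * gpus_per_node
effective_batch = per_device_batch * world_size  # samples per optimizer step

shards = [shard_indices(10, world_size, r) for r in range(world_size)]
print("world size:", world_size, "| effective batch size:", effective_batch)
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

Because the effective batch size grows with the number of GPUs, a learning rate tuned on a single GPU may need adjusting when scaling out.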

## How to save GPU memory?
1. Try activation checkpointing with fairscale first. See the [fairscale activation checkpointing documentation](https://fairscale.readthedocs.io/en/stable/api/nn/checkpoint/checkpoint_activations.html) and [an example of its use in the GRL code](https://github.com/ofsoundof/GRL-Image-Restoration/blob/main/models/networks/grl.py#L134).
2. Use sharded training. See the [Lightning documentation on model-parallel training](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html).
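As a rough guide to why sharding helps, the sketch below estimates per-GPU memory for the model states alone of the 37 M parameter large model. It assumes fp32 weights and gradients plus the two fp32 moment buffers of an Adam-style optimizer (16 bytes per parameter), and treats "fully sharded" as an even split across GPUs. It ignores activations (which activation checkpointing targets), buffers, and communication overhead, so real numbers depend on the strategy and precision used.

```python
# Assumption: fp32 weights + fp32 gradients + Adam exp_avg + Adam exp_avg_sq, 4 bytes each.
BYTES_PER_PARAM = 4 + 4 + 4 + 4


def model_state_bytes_per_gpu(num_params: int, num_gpus: int, fully_sharded: bool) -> float:
    """Approximate per-GPU bytes for weights, gradients, and optimizer states (activations excluded)."""
    total = num_params * BYTES_PER_PARAM
    # Plain DDP keeps a full replica on every GPU; full sharding splits the states evenly.
    return total / num_gpus if fully_sharded else float(total)


num_params = 37_000_000  # large model size quoted in this README
for gpus in (1, 4):
    replicated_mb = model_state_bytes_per_gpu(num_params, gpus, fully_sharded=False) / 1e6
    sharded_mb = model_state_bytes_per_gpu(num_params, gpus, fully_sharded=True) / 1e6
    print(f"{gpus} GPU(s): replicated {replicated_mb:.0f} MB, sharded {sharded_mb:.0f} MB")
```

For a model of this size the model states are modest, so activations often dominate memory use, which is consistent with trying checkpointing before sharding.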
