# SON-GOKU: Graph Coloring for Multi-Task Learning
This repository contains the implementation of the SON-GOKU method introduced in "Graph Coloring for Multi-Task Learning".

It provides an installable Python package for applying SON-GOKU in your own applications, including on top of existing multi-task learning (MTL) methods. It also contains the code for downloading, loading, and preprocessing all the datasets used in the paper, along with training scripts for SON-GOKU and its ablations that can be applied to all of those datasets.

The core implementation is framework-agnostic and can be used with JAX, TensorFlow, PyTorch, etc. A PyTorch integration is provided, and the training scripts use PyTorch.

## Installation
From a local copy of this repository, build and install the package by running the following from the repository root:

```bash
pip install .
```

## Usage Examples

### 1) Framework-agnostic scheduling
```python
import numpy as np
from son_goku import SonGokuScheduler

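K, d = 4, 512  # number of tasks, length of the flattened gradient vector
sched = SonGokuScheduler(num_tasks=K, grad_dim=d, refresh_period=32)

for step in range(100):
    active = sched.next_active_set()  # tasks scheduled for this step
    # Random stand-in for real per-task gradients: one row per active task.
    grads = np.random.randn(len(active), d).astype("float32")
    sched.update_ema(active, grads)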
    sched.step_finished()
    if sched.should_refresh():
        sched.refresh()
```
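
For intuition about what a graph-coloring scheduler of this kind might compute when it refreshes, here is a minimal standalone sketch. It is not the package's implementation: the helper names (`cosine`, `conflict_graph`, `greedy_coloring`, `color_groups`), the zero cosine threshold, and the toy gradients are all assumptions made for illustration. It treats two tasks as conflicting when their averaged gradients have negative cosine similarity, then greedily colors the resulting graph so that conflicting tasks end up in different groups.

```python
import math
from typing import Dict, List, Sequence, Set


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot / (nu * nv)


def conflict_graph(grads: Sequence[Sequence[float]],
                   threshold: float = 0.0) -> Dict[int, Set[int]]:
    """Add an edge between tasks whose gradient cosine falls below threshold."""
    adj: Dict[int, Set[int]] = {i: set() for i in range(len(grads))}
    for i in range(len(grads)):
        for j in range(i + 1, len(grads)):
            if cosine(grads[i], grads[j]) < threshold:
                adj[i].add(j)
                adj[j].add(i)
    return adj


def greedy_coloring(adj: Dict[int, Set[int]]) -> Dict[int, int]:
    """Visit tasks by decreasing degree; give each the smallest free color."""
    colors: Dict[int, int] = {}
    for node in sorted(adj, key=lambda n: (-len(adj[n]), n)):
        used = {colors[nb] for nb in adj[node] if nb in colors}
        c = 0
        while c in used:
            c += 1
        colors[node] = c
    return colors


def color_groups(colors: Dict[int, int]) -> List[List[int]]:
    """Collect tasks that share a color into one group, ordered by color."""
    groups: Dict[int, List[int]] = {}
    for task, c in colors.items():
        groups.setdefault(c, []).append(task)
    return [sorted(groups[c]) for c in sorted(groups)]


# Toy averaged gradients (hypothetical values), one row per task.
ema_grads = [
    [1.0, 0.0],   # task 0
    [0.9, 0.1],   # task 1: roughly aligned with task 0
    [-1.0, 0.0],  # task 2: opposes tasks 0 and 1
    [0.0, 1.0],   # task 3: does not conflict with any task
]
adj = conflict_graph(ema_grads)
groups = color_groups(greedy_coloring(adj))
print(groups)
```

In this toy setup a scheduler could cycle through `groups`, activating one group per training step, so that tasks with opposing gradients are never updated together.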

### 2) PyTorch multi-task training with a shared backbone
```python
import torch
from son_goku import SonGokuScheduler
from son_goku.torch_integration import shared_gradient_vector

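# `loader`, `shared` (backbone module), `heads` (per-task modules), `optimizer`,
# and `compute_loss` are placeholders for your own training setup.
# `grad_dim` presumably needs to match the length of the flattened shared gradient.
sched = SonGokuScheduler(num_tasks=3, grad_dim=100000, refresh_period=16)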

for batch in loader:
    tasks = sched.next_active_set()
    shared.zero_grad()
    for k in tasks:
        loss_k = compute_loss(k, shared, heads[k], batch)
        loss_k.backward(retain_graph=True)
        gk = shared_gradient_vector(shared.parameters())
        sched.update_ema([k], gk[None, :])
    optimizer.step()
    sched.step_finished()
    if sched.should_refresh():
        sched.refresh()
```

### 3) Running a pre-made experiment script
```bash
# CIFAR-10 with auxiliary tasks
python experiments/train_cifar10.py --download --epochs 10

# AV-MNIST
python experiments/train_avmnist.py --download --epochs 5
```
