# Welcome to Scaling L2O!

# Installation 

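Run the following commands. They assume the `scaling_l2o` repository is located inside the `l2o_install` directory created by the first two commands.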
```
mkdir l2o_install
cd l2o_install

wget https://repo.anaconda.com/miniconda/Miniconda3-py39_24.5.0-0-Linux-x86_64.sh
bash Miniconda3-py39_24.5.0-0-Linux-x86_64.sh -b -p $PWD/miniconda3
source $PWD/miniconda3/bin/activate

cd scaling_l2o
pip install -r requirements.txt

cd ..
git clone https://github.com/google/learned_optimization
cd learned_optimization
pip install -e .

cd ..
git clone https://github.com/google-research/vision_transformer
cd vision_transformer
git checkout ac6e056
pip install -e . 

cd ../scaling_l2o
pip install mmengine seqio wandb
pip install -U dm-haiku chex flax
pip install optax==0.1.7
pip install "jax[cuda12]==0.4.26"
conda install -c conda-forge openmpi=4.1.2


# Change the following as appropriate for your environment
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/include
export TFDS_DATA_DIR=/scr/data/tensorflow_datasets
export WANDB_DIR=$PWD/wandb
```

# Single GPU meta-training 

Here are some example commands for running single-GPU meta-training.

```
# Meta-train VeLO on a toy Fashion-MNIST task

OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --bind-to none python src/main.py \
--config config/meta_train/learned_optimizers/velo_h=32_lstm=4_inits=1.py \
--num_tasks 8 \
--local_batch_size 128 \
--train_project mup-meta-training \
--optimizer MuHyperV2 \
--needs_state \
--steps_per_jit 5 \
--name_suffix _mup_Hyper_DEBUG \
--prefetch_batches 20 \
--adafac_step_mult 0.01 \
--truncation_length 50 \
--task "mumlp-w128-d2_fashionmnist-8x8x1"

# Meta-training settings for Mu-LO

OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --bind-to none python src/main.py \
--config config/meta_train/schedules/mxlr=3e-3_mnlr=1e-3_it=5000_clip.py \
--num_tasks 8 \
--local_batch_size 4096 4096 4096 \
--train_project mup-meta-training \
--optimizer mup_small_fc_mlp \
--needs_state \
--steps_per_jit 2 \
--name_suffix _mulos-muvit_it=5000_mxlr=3e-3_stepm=01_tasks=24 \
--task "mumlp-w128-d3_imagenet-32x32x3,mumlp-w512-d3_imagenet-32x32x3,mumlp-w1024-d3_imagenet-32x32x3" \
--prefetch_batches 20 \
--adafac_step_mult 0.01
```



# Multi-GPU Meta-Training 

Here are some example commands for running multi-GPU meta-training.

```
# MuP RNN
OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=5,6,7 mpirun -np 3 --bind-to none python src/main.py \
--config config/meta_train/schedules/mxlr=3e-3_mnlr=1e-3_it=5000_clip.py \
--num_tasks 8 \
--local_batch_size 4096 4096 4096 \
--train_project mup-meta-training \
--optimizer MuRNNMLPLOpt \
--needs_state \
--steps_per_jit 2 \
--name_suffix _mup_RNN_distributed \
--prefetch_batches 10 \
--adafac_step_mult 0.01 \
--truncation_length 50 \
--task "mumlp-w1024-d3_imagenet-32x32x3,mumlp-w512-d3_imagenet-32x32x3,mumlp-w128-d3_imagenet-32x32x3"

# SP RNN
OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=5,6,7 mpirun -np 3 --bind-to none python src/main.py \
--config config/meta_train/schedules/mxlr=3e-3_mnlr=1e-3_it=5000_clip.py \
--num_tasks 8 \
--local_batch_size 4096 4096 4096 \
--train_project mup-meta-training \
--optimizer RNNMLPLOpt \
--needs_state \
--steps_per_jit 2 \
--name_suffix _sp_RNN_distributed \
--prefetch_batches 10 \
--adafac_step_mult 0.01 \
--truncation_length 50 \
--task "mlp-w1024-d3_imagenet-32x32x3,mlp-w512-d3_imagenet-32x32x3,mlp-w128-d3_imagenet-32x32x3"
```

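# Evaluate learned optimizers saved to WandB

Replace `WANDB_ORG/mup-meta-training/woz3g9l0` with the `<entity>/<project>/<run_id>` path of your own meta-training run.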

```
# MLP ImageNet
OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --bind-to none python src/main.py \
    --config config/meta_test/meta_test_base.py \
    --name_suffix _single-task-lopt_no-clip \
    --local_batch_size 1024 \
    --test_project mup-meta-testing \
    --task "mlp-w1024-d3_imagenet-32x32x3" \
    --optimizer small_fc_mlp \
    --wandb_checkpoint_id WANDB_ORG/mup-meta-training/woz3g9l0 \
    --num_runs 10 \
    --num_inner_steps 2000 \
    --gradient_accumulation_steps 1 \
    --needs_state \
    --test_interval 50

# ViT ImageNet
OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --bind-to none python src/main.py \
    --config config/meta_test/meta_test_base.py \
    --name_suffix _single-task-lopt_no-clip \
    --local_batch_size 1024 \
    --test_project mup-meta-testing \
    --task "vit-w1024-d3_imagenet-32x32x3" \
    --optimizer small_fc_mlp \
    --wandb_checkpoint_id WANDB_ORG/mup-meta-training/woz3g9l0 \
    --num_runs 10 \
    --num_inner_steps 2000 \
    --gradient_accumulation_steps 1 \
    --needs_state \
    --test_interval 50

# Transformer LM
OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --bind-to none python src/main.py \
    --config config/meta_test/meta_test_base.py \
    --name_suffix _mu-lo-m_only_lm \
    --local_batch_size 128 \
    --test_project mup-meta-testing \
    --task "mutransformer-w2048-d3_lm1b-s64-v32k" \
    --optimizer mup_small_fc_mlp \
    --wandb_checkpoint_id WANDB_ORG/mup-meta-training/woz3g9l0 \
    --num_runs 5 \
    --num_inner_steps 5000 \
    --gradient_accumulation_steps 1 \
    --needs_state \
    --test_interval 50 \
    --adafac_step_mult 0.01 \
    --use_bf16
```

## Launching a sweep

This example sweeps the hyperparameters of AdamW on a width-1024 MLP task.

```
OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --bind-to none python src/main.py \
    --config config/sweeps/mulo_sweep_adamw.py \
    --name_suffix _adamw_sweep \
    --local_batch_size 4096 \
    --test_project mup-meta-testing \
    --task "mlp-w1024-d3_imagenet-32x32x3" \
    --optimizer adamw \
    --num_runs 1 \
    --learning_rate 3e-4 \
    --num_inner_steps 1000 \
    --gradient_accumulation_steps 1 \
    --needs_state \
    --mup_input_mult 1 \
    --mup_output_mult 1 \
    --mup_hidden_lr_mult 1 \
    --test_interval 50
```

# Optimizing Meta-Training Programs
Tweaking the hyperparameters of a meta-training program can have a significant impact on the iteration speed and memory consumption of meta-training. Here are some tips for optimizing meta-training programs:

## Optimizing meta-training memory usage when using `custom_preload_tfds_image_classification_datasets`
- `prefetch_batches`: when using this preloading function, each task keeps a buffer of prefetched batches on the GPU so that meta-training does not wait on CPU-to-GPU transfers. The `prefetch_batches` variable therefore controls how much GPU memory the buffer uses: larger values consume more memory.
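As a rough, back-of-the-envelope sketch, the size of one task's buffer can be estimated from `prefetch_batches`, the local batch size, and the per-example size. The helper below is hypothetical (not part of the repository); it assumes only images are buffered, stored as `float32` by default, and it ignores labels and allocator overhead.

```python
def prefetch_buffer_bytes(prefetch_batches: int,
                          batch_size: int,
                          image_shape: tuple,
                          bytes_per_element: int = 4) -> int:
    """Estimate the GPU memory (in bytes) held by one task's prefetch buffer.

    Hypothetical helper: assumes only images are buffered, stored with
    `bytes_per_element` bytes per value (4 for float32, 1 for uint8).
    """
    elements_per_image = 1
    for dim in image_shape:
        elements_per_image *= dim
    return prefetch_batches * batch_size * elements_per_image * bytes_per_element


# Example: --prefetch_batches 20 with a local batch size of 4096 on 32x32x3 ImageNet
per_task = prefetch_buffer_bytes(20, 4096, (32, 32, 3))
print(f"{per_task / 2**30:.4f} GiB per task")  # 0.9375 GiB per task
```

If each entry passed to `--task` keeps its own buffer, a three-task run such as the Mu-LO example would need roughly three times this amount, which is why lowering `--prefetch_batches` is a natural first step when running out of GPU memory.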


## Optimizing meta-training efficiency and memory usage
- `num_tasks`: this variable controls the number of perturbations of the learned optimizer's weights sampled by the gradient estimator. During meta-training, one unroll is performed for each perturbation (two if antithetic sampling is used), and each optimization step of each task requires one batch of data.
- `steps_per_jit`: this variable controls the number of unrolling steps performed within the jitted `unroll_step` function. Since data cannot be sampled within jitted functions, it also determines how much data must be loaded before each call to `unroll_step`: in total, `steps_per_jit * num_tasks * batch_size` examples. If antithetic sampling is used (as is the case for PES), this quantity is multiplied by 2.
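The formula above can be checked with a small helper. This is an illustrative sketch rather than code from the repository, and the function name is hypothetical.

```python
def examples_per_unroll_call(steps_per_jit: int,
                             num_tasks: int,
                             batch_size: int,
                             antithetic: bool = False) -> int:
    """Number of examples that must be loaded before each `unroll_step` call.

    Computes steps_per_jit * num_tasks * batch_size, doubled when antithetic
    sampling (e.g. PES) evaluates every perturbation twice.
    """
    total = steps_per_jit * num_tasks * batch_size
    return 2 * total if antithetic else total


# Settings from the Mu-LO example: --steps_per_jit 2 --num_tasks 8, batch size 4096
print(examples_per_unroll_call(2, 8, 4096))                   # 65536
print(examples_per_unroll_call(2, 8, 4096, antithetic=True))  # 131072
```

Halving `--steps_per_jit` halves this amount, at the cost of more frequent (and therefore more overhead-heavy) calls into the jitted function.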


# Config file structure

Using MMEngine's config file parser, we can write config files directly in Python and use an inheritance structure to avoid redundant configuration. To inherit from other config files, set the special variable `_base_ = ['my_config.py']` at the top of a config file. More information is available in the [mmengine config docs](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html).

In this repository, configuration files are logically separated into directories based on the task to be executed: `config/meta_test`, `config/meta_train`, and `config/sweeps`.

# Setting up a sweep
To sweep over the hyperparameters of a model during meta-testing, specify a sweep configuration in the `sweep_config` variable of a config file.
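The exact schema expected by `sweep_config` is defined by this repository's sweep code, so check `config/sweeps/mulo_sweep_adamw.py` for a working example. Purely as an illustrative sketch, assuming a W&B-style grid definition (a `method` plus a `parameters` dict of `values` lists), the snippet below shows such a definition and enumerates the runs a grid would produce. The parameter names mirror the AdamW sweep flags above; the values are hypothetical.

```python
import itertools

# Hypothetical W&B-style grid sweep; keys mirror --learning_rate and --mup_*_mult.
sweep_config = {
    "method": "grid",
    "parameters": {
        "learning_rate": {"values": [1e-4, 3e-4, 1e-3]},
        "mup_input_mult": {"values": [1, 2]},
        "mup_output_mult": {"values": [1]},
    },
}


def grid_points(config):
    """Yield one dict per combination of the listed parameter values."""
    names = list(config["parameters"])
    value_lists = [config["parameters"][name]["values"] for name in names]
    for combo in itertools.product(*value_lists):
        yield dict(zip(names, combo))


points = list(grid_points(sweep_config))
print(len(points))  # 6
print(points[0])    # {'learning_rate': 0.0001, 'mup_input_mult': 1, 'mup_output_mult': 1}
```

A full grid grows multiplicatively with each added parameter, so it is worth counting the runs before launching a sweep.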


# Checkpointing during meta training
The `checkpoints_to_keep` and `save_iter` config variables control the number of checkpoints to keep and the checkpointing interval (in iterations), respectively. The default values, `checkpoints_to_keep=10` and `save_iter=1000`, save a checkpoint every 1000 iterations and keep at most the 10 most recent checkpoints.
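The retention rule can be illustrated with a small simulation. This is a sketch of the behaviour described above, not the repository's checkpointing code; it assumes saves happen at exact multiples of `save_iter`, counting iterations from 1.

```python
from collections import deque


def retained_checkpoints(total_iters: int, save_iter: int = 1000,
                         checkpoints_to_keep: int = 10) -> list:
    """Iterations whose checkpoints survive after `total_iters` iterations.

    Assumes a checkpoint is written whenever the (1-based) iteration is a
    multiple of `save_iter`, and the oldest one is deleted once more than
    `checkpoints_to_keep` exist.
    """
    kept = deque(maxlen=checkpoints_to_keep)  # drops the oldest entry automatically
    for it in range(1, total_iters + 1):
        if it % save_iter == 0:
            kept.append(it)
    return list(kept)


kept = retained_checkpoints(15000)
print(kept[0], kept[-1], len(kept))  # 6000 15000 10
```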

# Loading from a checkpoint during meta-training

When a checkpoint is logged, it is saved under `checkpoints/<meta-train-dir>`, where `<meta-train-dir>` is the dynamically assigned meta-training run name. Whenever a new checkpoint is logged, a file called `latest` is updated with the name of the most recent checkpoint. To resume from a checkpoint, set the `--from_checkpoint` flag; meta-training will automatically resume from the checkpoint named in the `latest` file.
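The lookup performed on resume can be sketched as follows. This is a hypothetical helper, not the repository's implementation: it assumes `latest` is a plain-text file inside `checkpoints/<meta-train-dir>` that holds only the file name of the most recent checkpoint, and the checkpoint name `checkpoint_15000` is made up for the demo.

```python
import os
import tempfile


def resolve_latest_checkpoint(meta_train_dir: str) -> str:
    """Return the path of the checkpoint named in `<meta_train_dir>/latest`.

    Assumes `latest` holds only the file name of the most recent checkpoint,
    stored in the same directory.
    """
    with open(os.path.join(meta_train_dir, "latest")) as f:
        name = f.read().strip()
    return os.path.join(meta_train_dir, name)


with tempfile.TemporaryDirectory() as run_dir:
    # Simulate a run directory whose `latest` file points at a hypothetical checkpoint.
    with open(os.path.join(run_dir, "latest"), "w") as f:
        f.write("checkpoint_15000\n")
    resolved = resolve_latest_checkpoint(run_dir)
    print(os.path.basename(resolved))  # checkpoint_15000
```

In practice this lookup is handled for you: re-running the same meta-training command with `--from_checkpoint` added should be enough.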


# Fine-tuning a pre-trained optimizer
