
# Welcome to Scaling L2O!

# Installation 

Run the following code:
```
apt-get update
apt install tmux vim rsync htop -y
tmux

mkdir l2o_install
cd l2o_install

wget https://repo.anaconda.com/miniconda/Miniconda3-py311_24.7.1-0-Linux-x86_64.sh
bash Miniconda3-py311_24.7.1-0-Linux-x86_64.sh -b -p $PWD/miniconda3
source $PWD/miniconda3/bin/activate

cd learned_optimization
pip install -e .

cd vision_transformer
pip install -e . 

cd ..
pip install mmengine seqio wandb
pip install orbax-checkpoint==0.3.2
pip install torch==2.2.0
conda install -y -c conda-forge openmpi=4.1.2
conda install -y -c conda-forge mpi4py openmpi
conda install -y nccl

pip install git+https://github.com/haydn-jones/SOAP_JAX
pip install aiofiles
pip install gin-config optax_shampoo
pip install numpy==1.24.3

# if the above does not work, try this
# conda install -c conda-forge mpi4py openmpi

pip install gin-config optax_shampoo
pip install "jax[cuda12]==0.6.0"
pip install -U dm-haiku chex flax
pip install git+https://github.com/google-deepmind/optax.git
pip install numpy==1.26

# change the following as is appropriate for your environment
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/include
export WANDB_API_KEY=

export MASTER_NODE=$HOSTNAME
export MASTER_PORT=12345

git config --global credential.helper 'cache --timeout=172800'
```

# Setup (env variables that you need to set) 

```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/include
export WANDB_API_KEY=

export MASTER_NODE=$HOSTNAME
export MASTER_PORT=12345

export TFDS_DATA_DIR=$PWD/data
export WANDB_DIR=$PWD/wandb
```


# Distributed PES meta-training of MuP small_fc_mlp:
```
mpirun -np 8 --allow-run-as-root --oversubscribe  --bind-to none bash -c 'OMP_NUM_THREADS=64 CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_RANK python src/main.py \
--config config/meta_train/meta_train_base.py,\
config/learned_optimizer/muadafacmlplopt.py,\
config/inner_length_schedule/constant_100_1000.py,\
config/schedule/warmup_cosine_decay.py,\
config/optimizer/adamw.py,\
config/gradient_transform/after/none.py,\
config/gradient_transform/before/clip.py \
--cfg_options \
gradient_transform_before_optim.0.kwargs.max_delta=3.0 \
es_std=0.01 \
schedule.kwargs.peak_value=0.003 \
schedule.kwargs.end_value=0.001 \
schedule.kwargs.decay_steps=9900 \
schedule.kwargs.warmup_steps=100 \
--num_tasks 1 \
--master_node $MASTER_NODE \
--master_port $MASTER_PORT \
--local_batch_size 4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
--train_project mup-meta-training \
--optimizer mup_small_fc_mlp \
--needs_state \
--steps_per_jit 1 \
--name_suffix _mulo_d=3_orig \
--prefetch_batches 20 \
--truncation_length 50 \
--num_outer_steps 5000 \
--pmap_pes_across_devices \
--task "mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d16_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d16_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3" \
--auto_resume'
```


# Distributed PES meta-training of MuP VeLO:
```

mpirun -np 8 \
  --allow-run-as-root \
  --map-by ppr:8:node \
  --bind-to none \
  --oversubscribe \
  -x CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' \
  bash -c 'CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_RANK OMP_NUM_THREADS=64 python src/main.py \
--config config/meta_train/meta_train_base.py,\
config/learned_optimizer/muvelo_h=4_lstm=16_inits=8_long.py,\
config/inner_length_schedule/constant_100_1000.py,\
config/schedule/warmup_cosine_decay.py,\
config/optimizer/adamw.py,\
config/gradient_transform/after/clip_by_global_norm.py,\
config/gradient_transform/before/none.py \
--cfg_options \
gradient_transform_after_optim.0.kwargs.max_norm=1.0 \
es_std=0.01 \
schedule.kwargs.peak_value=0.003 \
schedule.kwargs.end_value=0.001 \
schedule.kwargs.decay_steps=99900 \
schedule.kwargs.warmup_steps=100 \
optimizer_args.kwargs.weight_decay=0.001 \
optimizer_args.kwargs.b1=0.9 \
optimizer_args.kwargs.b2=0.999 \
optimizer_args.kwargs.eps=1e-11 \
--num_tasks 1 \
--master_node $MASTER_NODE \
--master_port $MASTER_PORT \
--local_batch_size 4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
4096 4096 4096 \
--train_project mup-meta-training \
--optimizer muHyper \
--needs_state \
--steps_per_jit 1 \
--name_suffix _muvelo_d=3_eps_1000_t1_orig \
--prefetch_batches 20 \
--truncation_length 20 \
--num_outer_steps 100000 \
--pmap_pes_across_devices \
--task "mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3,\
mumlp-w128-d3_imagenet-32x32x3,\
mumlp-w512-d3_imagenet-32x32x3,\
mumlp-w1024-d3_imagenet-32x32x3"'

```


# Optimizing Meta-Training Programs
Tweaking the hyperparameters of a meta-training program can have a significant impact on the iteration speed and memory consumption of meta-training. Here are some tips for optimizing meta-training programs:

## Optimizing meta-training memory usage when using `custom_preload_tfds_image_classification_datasets`
- `prefetch_batches`: when using this preloading function, each task keeps a buffer of samples on the GPU to avoid waiting for CPU-GPU transfers during meta-training. The `prefetch_batches` variable therefore controls how much GPU memory will be used by the buffer.


## Optimizing meta-training efficiency and memory usage
- `num_tasks`: this variable controls the number of perturbations to the learned optimizer's weights sampled in the gradient estimator. During meta-training, one unroll is performed for each perturbation (two if antithetic sampling is used). One batch of data is required per optimization step per task.
- `steps_per_jit`: This variable controls the number of unrolling steps performed within the jitted unroll_step function. Since data cannot be sampled within jitted functions, this has the effect of also specifying the amount of data we are required to load before each call to `unroll_step`. The total amount of data is `steps_per_jit * num_tasks * batch_size`. If antithetic sampling is used (as is the case for PES), this quantity should be multiplied by 2. 


# Config file structure

Using MMengine's config file parser, we can write config files directly in Python and use an inheritance config structure to avoid redundant configurations. This can be achieved by specifying config files to inherit from using the 
```_base_=['my_config.py']``` 
special variable at the top of config files. More information is available at [mmengine config docs](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html).

In learned_aggregation, configuration files are logically separated into different directories based on the task to be executed: ```config/meta_test```, ```config/meta_train```, and ```config/sweeps```.

# Setting up a sweep
To sweep over the hyperparameters of a model during meta-testing, one can simply specify a sweep configuration using the ```sweep_config``` variable.


# Checkpointing during meta training
The ```checkpoints_to_keep``` and ```save_iter``` config variables control the number of checkpoints that should be kept and the checkpointing multiple, respectively. Default values of ```checkpoints_to_keep=10``` and ```save_iter=1000``` ensure that at most 10 previous checkpoints will be kept and that a checkpoint will be saved every 1000 iterations.

# Loading from a checkpoint during meta-training

When a checkpoint is logged, it is saved under ```checkpoints/<meta-train-dir>```, where ```<meta-train-dir>``` is the dynamically assigned meta-train-name. Whenever a new checkpoint is logged, a file called ```latest``` is updated with the name of the most recent checkpoint. When resuming from a checkpoint, the user simply has to set the ```--from_checkpoint``` flag, and meta-training will automatically resume from the checkpoint specified in the ```latest``` file.


# Fine-tuning a pre-trained optimizer



# Useful information about VeLO, which is printed on startup
```
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/Adam.learning_rate = 0.0003
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/build_gradient_estimators.gradient_estimator_fn = @FullESOrPMAP
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/build_gradient_estimators.sample_task_family_fn = @april28_distribution_bigger
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/FullES.loss_type = 'last_recompute'
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/FullES.recompute_samples = 100
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/FullES.sign_delta_loss_scalar = 1.0
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/FullES.truncation_schedule = @LogUniformLengthSchedule()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/gradient_worker_compute.extra_metrics = False
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientAccumulator.num_average = 20
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientAccumulator.opt = @Adam()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientClipOptimizer.opt = @GradientAccumulator()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientLearner.init_theta_from_path =     'jul18_continue_on_bigger_2xbs_morestale_9264/params'
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientLearner.meta_init = @HyperV2()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientLearner.reset_outer_iteration = True
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/GradientLearner.theta_opt = @GradientClipOptimizer()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/HyperV2.lstm_hidden_size = 512
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/HyperV2.param_inits = 256
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/HyperV2.use_bugged_loss_features = False
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/LogUniformLengthSchedule.max_length = 200000
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/LogUniformLengthSchedule.min_length = 200
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/periodically_save_checkpoint.time_interval = 60
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/PMAPFullES.truncation_schedule = @LogUniformLengthSchedule()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.lopt = @HyperV2()
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.num_estimators = 8
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.num_steps = 100000
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.outer_learner_fn = @GradientLearner
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.run_num_estimators_per_gradient = 1
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.staleness = 500
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.stochastic_resample_frequency = 200
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.summary_every_n = 25
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/run_train.trainer_batch_size = 512
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/VectorizedLOptTruncatedStep.num_tasks = 8
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/VectorizedLOptTruncatedStep.random_initial_iteration_offset = 0
opt_from_checkpoint__6cf1d6ba_d295_4f96_88f3_ca14cdaf0da9/VectorizedLOptTruncatedStep.trunc_sched = @NeverEndingTruncationSchedule()
```
