# Masked Autoencoder (MAE) Pretraining Task

## Overview
[simmim_mae_pretraining.py](https://github.com/ofsoundof/TimeFM/blob/add_documentation/tasks/simmim_mae_pretraining.py) implements [SimMIM](https://arxiv.org/abs/2111.09886)-style masked autoencoding. The encoder consumes both the masked and unmasked tokens of a given input and produces a latent representation, which the decoder then decodes into a reconstruction of the original input.
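The flow described above can be illustrated with a toy model. This is a minimal sketch, not the code in `simmim_mae_pretraining.py`: the class `TinyMaskedAutoencoder`, the helper `masked_l1_loss`, the layer choices, and the dimensions are all hypothetical. Restricting the loss to masked patches follows the SimMIM paper and is assumed here; the task's actual `criterion` may differ.

```python
import torch
from torch import nn


class TinyMaskedAutoencoder(nn.Module):
    """Toy SimMIM-style model (hypothetical; layers are placeholders)."""

    def __init__(self, patch_dim: int = 16, embed_dim: int = 32):
        super().__init__()
        self.patch_embed = nn.Linear(patch_dim, embed_dim)
        # Learnable vector that stands in for every masked token.
        self.mask_token = nn.Parameter(torch.zeros(embed_dim))
        self.encoder = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU())
        self.decoder = nn.Linear(embed_dim, patch_dim)

    def forward(self, patches: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # patches: (batch, num_patches, patch_dim); mask: (batch, num_patches) bool
        tokens = self.patch_embed(patches)
        # The encoder sees all positions; masked ones carry the mask token.
        tokens = torch.where(mask.unsqueeze(-1), self.mask_token, tokens)
        latent = self.encoder(tokens)
        return self.decoder(latent)  # reconstruction, same shape as patches


def masked_l1_loss(recon: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean absolute error averaged over masked patches only (assumed criterion)."""
    per_patch = (recon - target).abs().mean(dim=-1)  # (batch, num_patches)
    weights = mask.float()
    return (per_patch * weights).sum() / weights.sum().clamp(min=1.0)
```

Calling `model(patches, mask)` with a boolean mask returns a tensor shaped like `patches`, and `masked_l1_loss` reduces it to a scalar that can be backpropagated.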

---

## Key Components

### Training and Validation Steps

1. **Patch Embedding**:
   - Given an input tensor representing a batch of patchified waveforms/spectrograms, this step produces an embedding for each patch.
2. **Prepare Tokens**:
   - The `prepare_tokens` function prepares tokens for processing by the encoder. It calls `_mask_tokens`, which in turn calls `random_masking`. The latter performs the masking and returns a binary mask together with the indices needed to restore the masked input to its original unmasked form. See the functions for details.
3. **Encoder (model)**:
   - Consumes the masked input (containing both masked and unmasked tokens) and produces a latent representation of these tokens.
4. **Decoder (model_head)**:
   - Consumes the latent representation and the restoration ids in order to produce a reconstruction.
5. **Loss computation**:
   - Given the decoder's reconstruction, the `criterion` is called to compute the loss between the reconstruction and the original input.
6. **Logging**:
   - We log the training and validation losses, along with a few (original, reconstruction) pairs for visualization in TensorBoard.
7. **Extras**:
   - GPU-accelerated augmentations can be applied on the fly in the `on_after_batch_transfer` hook. We do not use any augmentations for our pre-training, but they could potentially help.
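To make the masking step more concrete, here is a minimal sketch of per-sample random masking in the style popularized by the MAE reference code. It is an illustration under assumptions, not the repository's `random_masking`: the names `random_masking_sketch` and `unshuffle`, the `mask_ratio` parameter, and the convention that `1` marks a masked token are all hypothetical.

```python
import torch


def random_masking_sketch(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """Return (mask, ids_restore) for tokens of shape (batch, num_tokens, dim).

    mask is 0 for kept tokens and 1 for masked tokens (assumed convention).
    """
    batch, num_tokens, _ = tokens.shape
    num_keep = int(num_tokens * (1 - mask_ratio))

    # Sorting random noise yields an independent permutation per sample.
    noise = torch.rand(batch, num_tokens, device=tokens.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    # The inverse permutation maps shuffled positions back to the original order.
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # In shuffled order the first num_keep tokens are kept; unshuffle the mask.
    mask = torch.ones(batch, num_tokens, device=tokens.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return mask, ids_restore


def unshuffle(shuffled: torch.Tensor, ids_restore: torch.Tensor) -> torch.Tensor:
    """Undo the shuffle so every token returns to its original position."""
    index = ids_restore.unsqueeze(-1).expand(-1, -1, shuffled.shape[-1])
    return torch.gather(shuffled, dim=1, index=index)
```

In this sketch the restoration indices are what allow a downstream module, such as the decoder, to place tokens back in their original order after they have been processed in shuffled order.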

### Configuring Optimizer
`configure_optimizers` sets up the optimizer and the learning-rate scheduler (see [PyTorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers) for details).

Since we use a custom [timm](https://timm.fast.ai/Optimizers) optimizer, we need an additional `lr_scheduler_step` function that tells the PyTorch Lightning `Trainer` how and when to update the learning rate.

