# ViTMAETask

## Overview
[vitmae_pretraining.py](https://github.com/ofsoundof/TimeFM/blob/add_documentation/tasks/vitmae_pretraining.py) implements [ViTMAE](https://arxiv.org/abs/2111.06377)-style masked-autoencoder pre-training.
In this task, a Vision Transformer (ViT) encoder processes only the *visible tokens* that remain after random masking. A lightweight decoder then reconstructs the original input from the encoder outputs together with mask tokens placed at the masked positions. The objective is to learn high-quality representations by reconstructing the missing patches.

---

## Key Components

### 1. **Initialization**

- **Encoder (`self.model`)**: A ViT-style Transformer that receives only the visible (unmasked) tokens.
- **Decoder (`self.model_head`)**: A lightweight Transformer that receives:
  1. The visible token embeddings (from the encoder).
  2. Mask token embeddings inserted at masked positions.
  3. The index-restoration mapping (`ids_restore`) used to return tokens to their original order.
- **Masking & Pad Tokens**:
  - `self.mask_token`: Special embedding used to fill the positions of masked patches.
  - `self.pad_token`: Embedding substituted at padded positions when the input is padded.
- **Spectrogram Option**: If `using_spectrogram=True`, an STFT transform is applied before patch embedding.
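A minimal sketch of how the two learnable embeddings could be declared, together with the padding substitution that `prepare_tokens` performs before masking. It assumes PyTorch, a `(N, L, D)` token layout, and a boolean padding mask; the class `TokenParams` and the method `replace_padding` are hypothetical names, and the `std=0.02` initialization is borrowed from the reference MAE code rather than confirmed for this repo.

```python
import torch
import torch.nn as nn


class TokenParams(nn.Module):
    """Hypothetical container for the learnable mask and pad embeddings."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # One learnable vector each, broadcast over batch and sequence positions.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pad_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Small random init (assumption, following the reference MAE code).
        nn.init.normal_(self.mask_token, std=0.02)
        nn.init.normal_(self.pad_token, std=0.02)

    def replace_padding(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """x: (N, L, D) patch embeddings; pad_mask: (N, L) bool, True where padded."""
        # Broadcast (N, L, 1) condition against the (1, 1, D) pad embedding.
        return torch.where(pad_mask.unsqueeze(-1), self.pad_token.to(x.dtype), x)
```

Using a single shared vector for every masked (or padded) position keeps the parameter count independent of sequence length; the model can only tell positions apart through the positional embeddings.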

### 2. **Training Step**

1. **Patch Embedding**:
   - `x = self.model.patch_embed(X)`: Converts the input `X` into tokens (patch embeddings).
2. **Masking**:
   - `self.prepare_tokens(...)` randomly masks a subset of patches.
   - Returns the visible embeddings (`x`), a binary mask indicator (`token_mask`), and an attention mask (`attn_mask`) when padding is present; the `ids_restore` mapping used by the decoder also comes out of this step.
3. **Encoder**:
   - `latent = self.model(x, directly_input_tokens=True, attn_mask=attn_mask)`: Processes only unmasked tokens.
4. **Decoder**:
   - `pred = self.model_head(latent, ids_restore)`: Uses the encoder output plus mask tokens at the masked positions to predict the original patches.
5. **Loss & Logging**:
   - The `patchify` function transforms the original input into patches for ground truth.  
   - `self.criterion(pred, batch)` computes the reconstruction loss.  
   - Training loss is logged as `train_loss`.
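The ground-truth side of the loss can be sketched as below. `patchify` is written for a `(N, C, H, W)` input with square patches of size `p`, in the row-major patch order of the reference MAE code; `masked_mse` averages the squared error over masked patches only, as in the MAE paper. Both are assumptions: this document does not state the input layout `patchify` expects, nor whether `self.criterion` ignores the visible patches.

```python
import torch


def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, L, p*p*C), with L = (H // p) * (W // p), row-major patch order."""
    n, c, h, w = x.shape
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    gh, gw = h // p, w // p
    x = x.reshape(n, c, gh, p, gw, p)
    # Move the patch-grid axes forward and flatten each patch as (row, col, channel).
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(n, gh * gw, p * p * c)


def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """pred/target: (N, L, P); mask: (N, L) with 1 = masked. Mean MSE over masked patches only."""
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (N, L)
    return (per_patch * mask).sum() / mask.sum()
```

Restricting the loss to masked patches means the model gets no credit for copying visible inputs through, which is the point of the reconstruction task.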

### 3. **Validation Step**

Similar to training:
1. `x = self.model.patch_embed(X)`
2. Randomly mask tokens via `prepare_tokens(...)`
3. Encoder processes visible tokens -> `latent`
4. Decoder reconstructs -> `pred`
5. Compute & log `val_loss`

### 4. **On-the-Fly Spectrogram Transformation**

If `self.using_spectrogram` is set:
- In `on_after_batch_transfer`, the waveform is transformed into a spectrogram (`self.transform(batch['input'])`).
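One way `self.transform` could look, as a sketch assuming a batched `(N, T)` waveform and PyTorch's `torch.stft`. The `n_fft`, `hop_length`, Hann window, and `log1p` magnitude compression are illustrative choices, not values taken from the repo; the function name `waveform_to_spectrogram` is hypothetical.

```python
import torch


def waveform_to_spectrogram(wave: torch.Tensor, n_fft: int = 256, hop_length: int = 64) -> torch.Tensor:
    """(N, T) waveform -> (N, 1, n_fft // 2 + 1, frames) log-magnitude spectrogram."""
    window = torch.hann_window(n_fft, device=wave.device, dtype=wave.dtype)
    spec = torch.stft(
        wave,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,  # must be set explicitly for real inputs in recent PyTorch
    )
    mag = spec.abs()  # (N, freq_bins, frames)
    # Compress dynamic range and add a channel dim for a Conv2d-style patch embed.
    return torch.log1p(mag).unsqueeze(1)
```

With the default `center=True`, a 1024-sample waveform and `hop_length=64` yield 17 frames and 129 frequency bins, so the result has shape `(N, 1, 129, 17)`. Running this in `on_after_batch_transfer` means the STFT executes on the training device rather than in the data-loader workers.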

### 5. **Optimizer & Scheduler**

`configure_optimizers`:
- Combines encoder and decoder parameters into a single list (`params_to_pass`).
- Chooses an optimizer (`SGD`, `Adam`, `AdamW`, `LAMB`) based on `hparams`.
- Instantiates a scheduler from Hydra configuration.  
- `lr_scheduler_step` calls `scheduler.step_update(num_updates=self.global_step)` if the scheduler requires manual per-step updates.
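A sketch of the optimizer selection and of a scheduler that needs per-step updates, assuming plain PyTorch. `build_optimizer` and `WarmupCosine` are hypothetical; LAMB is not part of `torch.optim`, so it is left as an error branch rather than guessing which implementation the repo uses. The scheduler exposes the same `step_update(num_updates=...)` method that `lr_scheduler_step` calls, and recomputes the learning rate on every optimizer step instead of once per epoch.

```python
import math

import torch
import torch.nn as nn


def build_optimizer(encoder: nn.Module, decoder: nn.Module, name: str, lr: float,
                    weight_decay: float = 0.0) -> torch.optim.Optimizer:
    # Encoder and decoder are optimized jointly from one parameter list.
    params_to_pass = list(encoder.parameters()) + list(decoder.parameters())
    name = name.lower()
    if name == "sgd":
        return torch.optim.SGD(params_to_pass, lr=lr, weight_decay=weight_decay)
    if name == "adam":
        return torch.optim.Adam(params_to_pass, lr=lr, weight_decay=weight_decay)
    if name == "adamw":
        return torch.optim.AdamW(params_to_pass, lr=lr, weight_decay=weight_decay)
    # LAMB is not in torch.optim; a third-party implementation would be plugged in here.
    raise ValueError(f"unsupported optimizer: {name}")


class WarmupCosine:
    """Hypothetical per-step scheduler: linear warmup, then cosine decay to min_lr."""

    def __init__(self, optimizer: torch.optim.Optimizer, base_lr: float,
                 warmup_steps: int, total_steps: int, min_lr: float = 0.0):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def lr_at(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.base_lr * (step + 1) / self.warmup_steps
        span = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / span)
        return self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))

    def step_update(self, num_updates: int) -> None:
        # Called once per optimizer step with the global step count.
        lr = self.lr_at(num_updates)
        for group in self.optimizer.param_groups:
            group["lr"] = lr
```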

### 6. **Checkpoint Loading**

`load_from_checkpoint`:
1. Uses `get_params_from_checkpoint` to strip away any decoder-specific parameters if `head=False`.  
2. Matches encoder parameters in the current model with those from the checkpoint; logs and skips mismatches.  
3. Optionally freezes the encoder (`self.freeze_backbone`) to keep pretrained weights fixed.
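A minimal sketch of the filtering logic, assuming the checkpoint's `state_dict` prefixes encoder keys with `model.` and decoder keys with `model_head.` (as Lightning would for attributes named `self.model` and `self.model_head`). `load_encoder_weights` is a hypothetical helper; the repo's `get_params_from_checkpoint` may organize this differently.

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def load_encoder_weights(encoder: nn.Module, ckpt_state: dict[str, torch.Tensor],
                         prefix: str = "model.", freeze: bool = False):
    """Copy matching encoder weights from a checkpoint state_dict.

    Keys outside `prefix` (e.g. hypothetical "model_head." decoder keys) are dropped,
    and keys whose shapes do not match the current encoder are logged and skipped.
    """
    own = encoder.state_dict()
    to_load, skipped = {}, []
    for key, value in ckpt_state.items():
        if not key.startswith(prefix):
            continue  # decoder or other non-encoder parameters
        name = key[len(prefix):]
        if name in own and own[name].shape == value.shape:
            to_load[name] = value
        else:
            logger.warning("skipping %s (missing or shape mismatch)", key)
            skipped.append(key)
    # strict=False: parameters absent from to_load keep their current values.
    result = encoder.load_state_dict(to_load, strict=False)
    if freeze:
        for p in encoder.parameters():
            p.requires_grad = False  # keep pretrained weights fixed
    return skipped, result.missing_keys
```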

### 7. **Masking Mechanism**

- **`prepare_tokens`**:  
  1. Optionally replaces padded positions with `self.pad_token`.  
  2. Calls `_mask_tokens` to perform random token masking.
- **`_mask_tokens`**:
  - Delegates masking logic to `random_masking`, which:
    - Generates random noise for each token of each sample and sorts the tokens by ascending noise value.
    - Keeps only a fraction of tokens (`1 - mask_ratio`) for the encoder.
    - Returns the kept (visible) tokens, a binary `mask` (0 = kept, 1 = masked), and `ids_restore` for unshuffling.
- **Decoder Behavior**:
  - The decoder re-inserts `mask_token` in place of masked patches to reconstruct the full original input.
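The masking and unshuffling described above can be sketched as follows, mirroring `random_masking` and the decoder's mask-token insertion in the reference MAE implementation. This version assumes a `(N, L, D)` token tensor, has no CLS token, and ignores padding; `unshuffle_with_mask_tokens` is a hypothetical name for what the decoder does internally.

```python
import torch


def random_masking(x: torch.Tensor, mask_ratio: float):
    """x: (N, L, D). Returns kept tokens (N, L_keep, D), mask (N, L) with 1 = masked, ids_restore (N, L)."""
    n, l, d = x.shape
    len_keep = int(l * (1 - mask_ratio))
    noise = torch.rand(n, l, device=x.device)        # one random value per token
    ids_shuffle = torch.argsort(noise, dim=1)         # ascending: smallest noise is kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)   # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, d))
    mask = torch.ones(n, l, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)         # back to original token order
    return x_kept, mask, ids_restore


def unshuffle_with_mask_tokens(latent: torch.Tensor, mask_token: torch.Tensor,
                               ids_restore: torch.Tensor) -> torch.Tensor:
    """latent: (N, L_keep, D); mask_token: (1, 1, D). Returns (N, L, D) in original token order."""
    n, l_keep, d = latent.shape
    l = ids_restore.shape[1]
    fill = mask_token.expand(n, l - l_keep, d)
    x_ = torch.cat([latent, fill], dim=1)             # still in shuffled order
    return torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, d))
```

Because `ids_restore` is the inverse of the shuffle permutation, gathering with it places each encoder output back at its original position, and every masked position receives a copy of `mask_token`.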

This approach follows the MAE principle: train a ViT encoder on visible (unmasked) tokens only, then let a lightweight decoder learn to reconstruct from the partial latent representation plus mask tokens. 
