# Classification Task

## Overview
[classification_task.py](https://github.com/ofsoundof/TimeFM/blob/add_documentation/tasks/classification_task.py) implements a fine-tuning task for binary and multi-class classification. The `ClassificationTask` is built on PyTorch Lightning and supports on-the-fly data augmentation (white noise and spectrogram augmentations), a range of metrics (accuracy, AUROC, AUPR, etc.), and optional layerwise learning rate decay for the encoder.

---

## Key Components

### 1. **Initialization**

- **Encoder (`self.model`)**: A backbone (e.g., a Transformer or CNN) responsible for extracting features from the input.  
- **Classifier Head (`self.model_head`)**: A simple head for classification, producing logits over `self.num_classes`.
- **Freeze Backbone**: If `freeze_backbone=True`, the encoder weights are frozen during training.  
- **Augmentations**:  
  - White noise augmentation (`WhiteNoiseAugment`) can be applied to both raw waveforms and spectrograms.  
  - `SpecAugment` (masking frequencies/time bands) can be applied to spectrograms.  
- **Metrics**:  
  - Accuracy (macro-average)  
  - Balanced accuracy (macro-average of recall scores)  
  - AUROC  
  - AUPR (Average Precision)  
  - F1 Score  
  - Precision  

  These are instantiated separately for the training, validation, and test sets under distinct attribute names (`train_acc`, `val_acc`, etc.).
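
The encoder/head split and the effect of `freeze_backbone` can be illustrated with plain PyTorch. This is a minimal sketch, not the repository's code: `TinyClassifier` and `count_trainable` are hypothetical names, and a small MLP stands in for the real Transformer or CNN encoder. Only the attribute names `model` and `model_head` mirror the description above.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the task's encoder + head pair."""

    def __init__(self, in_dim: int, embed_dim: int, num_classes: int, freeze_backbone: bool = False):
        super().__init__()
        # Encoder: any feature extractor; a small MLP keeps the sketch short.
        self.model = nn.Sequential(nn.Linear(in_dim, embed_dim), nn.GELU())
        # Classifier head: one logit per class.
        self.model_head = nn.Linear(embed_dim, num_classes)
        if freeze_backbone:
            # Frozen encoder weights get no gradients, so only the head is updated.
            for p in self.model.parameters():
                p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model_head(self.model(x))


def count_trainable(module: nn.Module) -> int:
    """Number of parameters the optimizer will actually update."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


clf = TinyClassifier(in_dim=16, embed_dim=8, num_classes=3, freeze_backbone=True)
print(count_trainable(clf))  # 27: the head's 8 * 3 weights + 3 biases
```

With the backbone frozen, only the head contributes trainable parameters, which is what makes linear-probing style fine-tuning cheap.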

### 2. **Training Step**

1. **Data Flow**: The batch contains:
   - `X = batch['input']`  
   - `y = batch['label']`
   - Optionally, `batch['nr_padded_channels']`, which tracks padded channels (this script does not apply mask tokens).
2. **Forward Pass**:
   - `self.model(X)`: The encoder processes the input (waveform or spectrogram).  
   - `self.model_head(...)`: The encoded features are passed to a classification head, producing logits (`y_preds_logits`).
3. **Loss**:
   - The script uses `self.criterion(y_preds_logits, batch)` to compute classification loss.  
4. **Metrics**:
   - Predictions are derived by `torch.argmax(y_preds_logits, dim=1)` for classification.  
   - Accuracy, balanced accuracy, AUROC, and AUPR are computed from either the raw logits or the softmax probabilities, depending on whether the task is binary or multiclass.
5. **Logging**:
   - Training loss is logged as `train_loss` at both the step and epoch level. Accuracy, AUROC, AUPR, and the other metrics are updated internally during training and logged at the end of each epoch in `on_train_epoch_end`.
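
The training-step flow can be sketched roughly as follows. This is an illustration under stated assumptions, not the task's implementation: `F.cross_entropy` stands in for `self.criterion(y_preds_logits, batch)`, and the hypothetical `class_scores` helper assumes a binary task has two output logits and is scored by the positive-class probability.

```python
import torch
import torch.nn.functional as F


def class_scores(logits: torch.Tensor) -> torch.Tensor:
    """Turn (batch, num_classes) logits into scores for AUROC/AUPR.

    Assumption: a binary task has two output logits and is scored by the
    positive-class probability; a multiclass task keeps the full softmax.
    """
    probs = torch.softmax(logits, dim=1)
    return probs[:, 1] if logits.shape[1] == 2 else probs


def training_step_sketch(model, model_head, batch):
    X, y = batch["input"], batch["label"]
    y_preds_logits = model_head(model(X))  # (batch, num_classes)
    # Assumption: plain cross-entropy stands in for self.criterion(y_preds_logits, batch).
    loss = F.cross_entropy(y_preds_logits, y)
    y_preds = torch.argmax(y_preds_logits, dim=1)  # hard predictions for accuracy-style metrics
    scores = class_scores(y_preds_logits.detach())  # probabilities for ranking metrics
    return loss, y_preds, scores


torch.manual_seed(0)
encoder, head = torch.nn.Linear(10, 6), torch.nn.Linear(6, 3)
batch = {"input": torch.randn(4, 10), "label": torch.tensor([0, 2, 1, 0])}
loss, y_preds, scores = training_step_sketch(encoder, head, batch)
loss.backward()  # gradients flow to both encoder and head
```

Detaching the logits before computing scores keeps metric bookkeeping out of the autograd graph; only `loss` is backpropagated.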

### 3. **Validation and Test Steps**

- **Validation** (`validation_step`) and **Test** (`test_step`) follow a similar flow to training:
  1. Extract inputs and labels from the batch.
  2. Forward pass through encoder + head.
  3. Compute metrics with the same approach (argmax for predicted class, softmax for probabilities).
  4. Log the loss (`val_loss` or `test_loss`) at both the step and epoch level.

- **End-of-Epoch Logging**:  
  - `on_validation_epoch_end` logs metrics like `val_acc`, `val_balanced_acc`, `val_auroc`, `val_aupr`.
  - `on_test_epoch_end` logs the test metrics (`test_acc`, `test_balanced_acc`, `test_auroc`, `test_aupr`, `test_precision`, `test_f1_score`), which summarize final performance.
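
Reporting balanced accuracy next to accuracy matters most on imbalanced data. The hand-rolled helpers below are hypothetical and are not the metric objects the task instantiates; `accuracy` here is sample-averaged, whereas the task's accuracy is described as macro-averaged and may be computed differently.

```python
import torch


def accuracy(preds: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of samples predicted correctly (sample-averaged)."""
    return (preds == target).float().mean().item()


def balanced_accuracy(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> float:
    """Macro-average of per-class recall; classes absent from `target` are skipped."""
    recalls = []
    for c in range(num_classes):
        in_class = target == c
        if in_class.any():
            recalls.append((preds[in_class] == c).float().mean())
    return torch.stack(recalls).mean().item()


# Imbalanced labels: 8 negatives, 2 positives; a model that always predicts class 0.
target = torch.tensor([0] * 8 + [1] * 2)
preds = torch.zeros_like(target)
print(accuracy(preds, target))                          # ~0.8
print(balanced_accuracy(preds, target, num_classes=2))  # 0.5
```

A classifier that ignores the minority class still scores 0.8 accuracy, while balanced accuracy drops to chance level (0.5), exposing the failure.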

### 4. **On-the-Fly Augmentations**

The `on_after_batch_transfer` hook applies data augmentations immediately after the batch is transferred to the GPU:
1. **White Noise**: Applied to the input if `self.training` is `True`.
2. **Spectrogram Conversion**: If `using_spectrogram=True`, the code applies STFT (or other transforms) on the fly before passing data to the encoder.
3. **SpecAugment**: If spectrograms are used, frequency and time masking is also applied for data augmentation.
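
The augmentation pipeline can be approximated in plain PyTorch. This is a minimal sketch, not the task's `WhiteNoiseAugment` or `SpecAugment` classes: the helper names, noise level, STFT settings, and mask widths are hypothetical, inputs are assumed to be waveforms shaped `(batch, time)`, and SpecAugment is assumed to run only during training.

```python
import torch


def add_white_noise(x: torch.Tensor, std: float = 0.01) -> torch.Tensor:
    """Hypothetical white-noise augmentation: add zero-mean Gaussian noise."""
    return x + std * torch.randn_like(x)


def to_spectrogram(wave: torch.Tensor, n_fft: int = 64, hop_length: int = 16) -> torch.Tensor:
    """Magnitude STFT: (batch, time) -> (batch, n_fft // 2 + 1, frames)."""
    window = torch.hann_window(n_fft, device=wave.device)
    spec = torch.stft(wave, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True)
    return spec.abs()


def spec_augment(spec: torch.Tensor, max_freq: int = 8, max_time: int = 10) -> torch.Tensor:
    """SpecAugment-style masking: zero one random frequency band and one time band per example."""
    spec = spec.clone()
    batch, n_freq, n_time = spec.shape
    for i in range(batch):
        f = int(torch.randint(0, max_freq + 1, (1,)))           # band width in bins
        f0 = int(torch.randint(0, n_freq - f + 1, (1,)))        # band start
        spec[i, f0:f0 + f, :] = 0.0
        t = int(torch.randint(0, max_time + 1, (1,)))           # band width in frames
        t0 = int(torch.randint(0, n_time - t + 1, (1,)))
        spec[i, :, t0:t0 + t] = 0.0
    return spec


def on_after_batch_transfer_sketch(batch: dict, training: bool, using_spectrogram: bool) -> dict:
    x = batch["input"]
    if training:
        x = add_white_noise(x)       # noise only while training
    if using_spectrogram:
        x = to_spectrogram(x)        # convert on the fly, on the batch's device
        if training:
            x = spec_augment(x)      # assumption: masking is training-only
    return {**batch, "input": x}
```

Running these transforms after the device transfer keeps the STFT and masking on the GPU instead of in the data-loader workers.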

### 5. **Layerwise Learning Rate Decay**

In `configure_optimizers`, the script configures a layerwise learning rate decay for the encoder if desired:
- `num_blocks = self.hparams.model.depth` determines how many layers or blocks exist in the encoder.  
- The classification head parameters always receive the base learning rate.  
- Moving from the head toward the input, each encoder block's learning rate is multiplied by a further factor of `layerwise_lr_decay`, so blocks closer to the input receive smaller learning rates.  

This can help stabilize fine-tuning: the earlier blocks, which tend to hold more general pretrained features, change slowly, while the blocks nearest the head adapt more readily to the new task.
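
One common way to build such parameter groups is sketched below, assuming the head keeps the base rate and the decay compounds toward the input (as in BEiT/MAE-style fine-tuning code). `layerwise_param_groups` is a hypothetical helper; real encoders also have embedding and normalization parameters that need their own group, and the task's exact indexing may differ.

```python
import torch
from torch import nn


def layerwise_param_groups(blocks: nn.ModuleList, head: nn.Module, base_lr: float, decay: float):
    """Hypothetical helper: one optimizer param group per encoder block.

    Block i (0 = closest to the input) of num_blocks gets
    base_lr * decay ** (num_blocks - i); the head keeps base_lr.
    """
    num_blocks = len(blocks)
    groups = [
        {"params": list(block.parameters()), "lr": base_lr * decay ** (num_blocks - i)}
        for i, block in enumerate(blocks)
    ]
    groups.append({"params": list(head.parameters()), "lr": base_lr})
    return groups


blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])  # stand-in for `depth` = 3 blocks
head = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(layerwise_param_groups(blocks, head, base_lr=1e-3, decay=0.5))
print([g["lr"] for g in optimizer.param_groups])
```

With `decay=0.5` the rates halve at each step away from the head, so the block nearest the input trains at one eighth of the head's rate.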

### 6. **Optimizer and Scheduler**

`configure_optimizers` returns a dictionary containing:
- **Optimizer**: Depending on `self.hparams.optimizer.optim`, it can be `SGD`, `Adam`, `AdamW`, or `LAMB`.  
- **Scheduler**: The LR scheduler is instantiated via Hydra (`hydra.utils.instantiate(self.hparams.scheduler, ...)`) using the estimated number of training steps.  
- **LR Step**: Because some `timm` schedulers are stepped per optimizer update through `scheduler.step_update(...)`, `lr_scheduler_step` is overridden to call it with `num_updates=self.global_step`.
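
The optimizer selection and the per-update stepping pattern can be sketched without Lightning or `timm`. This is illustrative only: `build_optimizer` and `PerUpdateCosine` are hypothetical, LAMB is omitted because it is not part of `torch.optim`, and `PerUpdateCosine` merely mimics the `step_update(num_updates)` interface with a warmup-plus-cosine schedule.

```python
import math

import torch


def build_optimizer(params, name: str, lr: float, weight_decay: float = 0.0):
    """Map an optimizer name to a torch.optim class (LAMB is not in torch.optim, so it is omitted)."""
    name = name.lower()
    if name == "sgd":
        return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay)
    if name == "adam":
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == "adamw":
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    raise ValueError(f"unsupported optimizer: {name}")


class PerUpdateCosine:
    """Hypothetical stand-in for a scheduler driven by step_update(num_updates)."""

    def __init__(self, optimizer, total_updates: int, warmup_updates: int = 0):
        self.optimizer = optimizer
        self.total = total_updates
        self.warmup = warmup_updates
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]

    def step_update(self, num_updates: int) -> None:
        if num_updates < self.warmup:
            scale = (num_updates + 1) / self.warmup  # linear warmup
        else:
            progress = min(1.0, (num_updates - self.warmup) / max(1, self.total - self.warmup))
            scale = 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to zero
        for group, base in zip(self.optimizer.param_groups, self.base_lrs):
            group["lr"] = base * scale


# Inside a LightningModule (Lightning 2.x signature), the override is roughly:
# def lr_scheduler_step(self, scheduler, metric):
#     scheduler.step_update(num_updates=self.global_step)
```

Driving the schedule by the global update count, rather than by epochs, is what lets the total number of training steps estimated up front determine where the decay ends.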

### 7. **Checkpoint Loading**

`load_from_checkpoint` is overridden to:
- Strip away any classification head weights from the checkpoint if `head=False`.
- Freeze the backbone layers if `freeze_backbone=True`.  
- Print diagnostic messages about incompatible shapes and dropped parameters.
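
The checkpoint handling above can be sketched with plain `state_dict` operations. This is an assumption-laden illustration: `strip_head` and `load_backbone` are hypothetical helpers, the `"model_head."` key prefix is assumed from the attribute name, and a Lightning `.ckpt` file would first need `torch.load(path, map_location="cpu")["state_dict"]` to extract the weights.

```python
import torch
from torch import nn


def strip_head(state_dict: dict, head_prefix: str = "model_head.") -> dict:
    """Drop classification-head tensors so a fresh head can be trained."""
    return {k: v for k, v in state_dict.items() if not k.startswith(head_prefix)}


def load_backbone(module: nn.Module, state_dict: dict, freeze_backbone: bool) -> None:
    own = module.state_dict()
    compatible = {}
    for k, v in state_dict.items():
        if k in own and own[k].shape != v.shape:
            # Report and skip tensors whose shapes do not match the target module.
            print(f"dropping {k}: checkpoint {tuple(v.shape)} vs model {tuple(own[k].shape)}")
        else:
            compatible[k] = v
    result = module.load_state_dict(compatible, strict=False)
    print("missing keys:", result.missing_keys)
    print("unexpected keys:", result.unexpected_keys)
    if freeze_backbone:
        for p in module.model.parameters():  # assumes the encoder lives at `module.model`
            p.requires_grad = False


class Net(nn.Module):
    """Hypothetical model with the same attribute layout as the task."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.model = nn.Linear(4, 8)
        self.model_head = nn.Linear(8, num_classes)


pretrained = Net(num_classes=5)
target = Net(num_classes=3)
ckpt = strip_head(pretrained.state_dict())
load_backbone(target, ckpt, freeze_backbone=True)  # head keys are reported as missing
```

Loading with `strict=False` and printing the missing and unexpected keys makes silent mismatches visible, which is the purpose of the diagnostic messages described above.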

---

