# NeuroMamba Pretraining

This repository contains a framework for pretraining the **NeuroMamba** language model on multi-source, web-scale datasets, with a strong emphasis on training stability and fault tolerance.

## Framework Overview

This pretraining pipeline is built upon the Hugging Face `Trainer` API and incorporates several key features to handle the challenges of long-duration training runs on massive datasets:

1.  **Large-Scale Data Streaming**: The framework streams data directly from multiple sources (e.g., SlimPajama, The Stack, SkyPile), so the full datasets never have to be downloaded to local disk.
2.  **Multi-Source Data Interleaving**: It mixes datasets from different domains (prose, code, multilingual text) with configurable sampling probabilities, producing a diverse and balanced pretraining corpus.
3.  **Robust Training Pipeline**: Two custom utility classes are included to prevent common failures during long training runs:
    -   `SafeIterableDataset`: A wrapper that prevents the entire training process from crashing due to a single corrupt or unreadable data sample.
    -   `HealthCheckCallback`: A `Trainer` callback that actively monitors model weights for numerical instability (NaNs or Infs) and automatically halts training if an issue is detected, saving valuable compute resources.
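
The skip-on-error behavior described for `SafeIterableDataset` can be illustrated with a minimal, standard-library-only sketch. It is not the code in `safe_dataset.py`: the names `SafeIterableSketch`, `FlakySource`, and `max_consecutive_errors` are hypothetical, and a real implementation would probably subclass `torch.utils.data.IterableDataset`.

```python
import logging
from typing import Any, Iterable, Iterator

logger = logging.getLogger(__name__)


class SafeIterableSketch:
    """Iterate over `source`, skipping samples whose retrieval raises an exception."""

    def __init__(self, source: Iterable[Any], max_consecutive_errors: int = 100):
        self.source = source
        # Assumption: give up if the source fails many times in a row,
        # since that usually means the whole stream is broken.
        self.max_consecutive_errors = max_consecutive_errors

    def __iter__(self) -> Iterator[Any]:
        iterator = iter(self.source)
        consecutive_errors = 0
        while True:
            try:
                sample = next(iterator)
            except StopIteration:
                return  # source exhausted
            except Exception as exc:
                consecutive_errors += 1
                logger.warning("Skipping unreadable sample (%s: %s)", type(exc).__name__, exc)
                if consecutive_errors >= self.max_consecutive_errors:
                    raise RuntimeError("Too many consecutive data errors") from exc
                continue
            consecutive_errors = 0
            yield sample


class FlakySource:
    """Toy iterator that raises on every third record to simulate corrupt data."""

    def __init__(self, n: int):
        self.n = n
        self.i = 0

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration
        self.i += 1
        if self.i % 3 == 0:
            raise ValueError(f"corrupt record {self.i}")
        return {"text": f"sample {self.i}"}


clean = [row["text"] for row in SafeIterableSketch(FlakySource(7))]
print(clean)
```

One caveat of this approach: a plain Python generator is finished once it raises, so the wrapper can only continue past an error if the underlying iterator is able to keep producing samples afterwards.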

## Files

-   `pretrain.py`: The main **executable script** for launching the pretraining process. It handles dataset loading and interleaving, model and tokenizer initialization, and configures the Hugging Face `Trainer`.
-   `safe_dataset.py`: Contains `SafeIterableDataset`, a wrapper class that makes the data loading pipeline resilient to errors in individual data points. It logs warnings for corrupt samples and skips them instead of crashing.
-   `HealthCheckCallback.py`: Contains `HealthCheckCallback`, a custom callback that monitors the model's parameters for numerical stability (NaN/Inf values) at each step, stopping the training run early if instability is detected.
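
As a rough illustration of the per-step check that `HealthCheckCallback` is described as performing, the sketch below uses the `TrainerCallback.on_step_end` hook and the `control.should_training_stop` flag from `transformers`. The class name `HealthCheckSketch` and the log message are assumptions; the actual file may be structured differently.

```python
import logging

import torch
from transformers import TrainerCallback

logger = logging.getLogger(__name__)


class HealthCheckSketch(TrainerCallback):
    """Request a training stop as soon as any parameter contains NaN or Inf."""

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if model is None:
            return control
        for name, param in model.named_parameters():
            # torch.isfinite is False for both NaN and +/-Inf entries.
            if not torch.isfinite(param.detach()).all():
                logger.critical(
                    "Non-finite values in parameter %r at step %d; stopping training.",
                    name,
                    state.global_step,
                )
                control.should_training_stop = True
                break
        return control
```

Such a callback would be registered by passing `callbacks=[HealthCheckSketch()]` to the `Trainer` constructor. Scanning every parameter at every step adds overhead that grows with model size, which is the trade-off for catching divergence immediately.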

## How to Run Pretraining

### 1. Prerequisites

Ensure you have the necessary libraries installed:
```bash
pip install torch transformers datasets accelerate
```
This framework targets GPU training. A GPU with `bfloat16` support (e.g., NVIDIA Ampere series or newer) is highly recommended for best performance.
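
To check whether the current machine supports `bfloat16` before launching, a short probe along these lines may help. It is only a sketch: the variable names are illustrative, and whether `pretrain.py` sets the `bf16` training argument this way is an assumption.

```python
import torch


def bf16_supported() -> bool:
    """True only if a CUDA device is visible and it supports bfloat16."""
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


use_bf16 = bf16_supported()
# Keyword arguments that could be unpacked into TrainingArguments(...).
precision_kwargs = {"bf16": use_bf16}
print(f"bfloat16 supported: {use_bf16}")
```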

### 2. Configuration

Open `pretrain.py` to configure your training run. Key areas to adjust include:

-   **Datasets**: Modify the `interleave_datasets` call to change the data sources or their mixing probabilities.
-   **Model Configuration**: Adjust the `NeuroMambaConfig` to define your model's architecture (e.g., `hidden_size`, `num_hidden_layers`).
-   **Training Arguments**: Modify the `TrainingArguments` object to set crucial hyperparameters like `output_dir`, `per_device_train_batch_size`, `learning_rate`, and `max_steps`.

### 3. Execution

Launch the pretraining script from your terminal:
```bash
python pretrain.py
```
The script will initialize all components and begin the training process, managed by the Hugging Face `Trainer`.

### 4. Resuming from a Checkpoint

For long pretraining runs, the ability to resume is critical. The script is configured to save checkpoints periodically. To resume from a specific checkpoint:

1.  Locate the desired checkpoint directory (e.g., `.../NeuMa_140M/checkpoint-25000`).
2.  Modify the final line in `pretrain.py` to point to this directory:
    ```python
    # To resume, pass the checkpoint path (a new run calls trainer.train() with no arguments).
    trainer.train(resume_from_checkpoint="/path/to/your/checkpoint-25000")
    ```
3.  Re-run the script: `python pretrain.py`. The `Trainer` will automatically load the model weights, optimizer state, and scheduler state from the checkpoint and continue training.
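
Editing the script for every resume can be avoided with a small command-line switch. The sketch below is one possible approach, not part of the repository: the `--resume` and `--output_dir` flags and the helper names are hypothetical. It relies only on the `checkpoint-<step>` directory naming that the `Trainer` uses.

```python
import argparse
import re
from pathlib import Path
from typing import Optional

CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = CHECKPOINT_PATTERN.match(child.name)
        if match and child.is_dir() and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return str(best_path) if best_path is not None else None


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Hypothetical CLI for pretrain.py")
    parser.add_argument("--output_dir", default="./output")
    parser.add_argument("--resume", default=None, help="checkpoint path, or 'latest'")
    return parser.parse_args(argv)


def resolve_resume_path(args: argparse.Namespace) -> Optional[str]:
    """Map the --resume flag to a value for trainer.train(resume_from_checkpoint=...)."""
    if args.resume == "latest":
        return find_latest_checkpoint(args.output_dir)
    return args.resume  # explicit path, or None for a fresh run


# In pretrain.py the final line would then become:
#     trainer.train(resume_from_checkpoint=resolve_resume_path(parse_args()))
```

With this in place, `python pretrain.py --resume latest` would continue from the newest checkpoint, while omitting the flag starts a fresh run. The `Trainer` also accepts `resume_from_checkpoint=True`, which loads the last checkpoint found in `output_dir`.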

## Monitoring and Results

-   **Console Output**: The `Trainer` will provide live progress updates in your terminal, including the current step, loss, and learning rate.
-   **Checkpoints**: Model checkpoints will be saved to the specified `output_dir` at intervals defined by `save_steps`.
-   **Robustness Alerts**:
    -   If a corrupt data sample is encountered, a warning from `SafeIterableDataset` will be logged, and training will continue.
    -   If numerical instability (NaN/Inf) is detected, `HealthCheckCallback` will log a critical error detailing the problematic parameter and step, and then gracefully terminate the training process.