# config_ablation.py

"""
Configuration File for the Induction Heads Experiment with Ablation Controls.

This is the central "Command Center" for the ablation study.
To run a specific experiment, edit the ablation flags in the "ABLATION FLAGS"
section at the bottom of this file, then execute 'train_ablation.py'.

- Baseline (no ablation): set both ablate_gc and ablate_y2 to False.
- Ablate only the 'gc' branch: set ablate_gc = True and ablate_y2 = False.
- Ablate only the 'y2' branch: set ablate_y2 = True and ablate_gc = False.
"""

from neuromamba.models.config_neuromamba import NeuroMambaConfig

# -----------------------------------------------------------------------------
# 1. Configuration for Training Loop
# -----------------------------------------------------------------------------
# Plain hyper-parameter mapping consumed by the training script
# (presumably train_ablation.py — TODO confirm which keys it reads).
# NOTE: 204_800 == 25 * 8_192, so num_steps is an exact multiple of
# eval_interval.
training_config = dict(
    batch_size=8,
    learning_rate=2e-4,
    num_steps=204_800,
    eval_interval=8_192,
    weight_decay=0.0,
)

# -----------------------------------------------------------------------------
# 2. Configuration for Induction Heads Dataset
# -----------------------------------------------------------------------------
# NOTE(review): "level_4_noise_type" and "max_noise_len" look like they only
# take effect at difficulty level 4, while difficulty_level is 2 here —
# confirm against the dataset code.
dataset_config = dict(
    difficulty_level=2,
    level_4_noise_type="none",
    max_noise_len=4,
    vocab_size=16,
    num_induction_pairs=3,
    train_seq_len=256,
    # Powers of two from 2**6 (64) up to 2**20 (1,048,576) inclusive, i.e.
    # evaluation lengths that reach far beyond train_seq_len.
    eval_seq_lens=[1 << exponent for exponent in range(6, 21)],
)

# -----------------------------------------------------------------------------
# 3. Configuration for NeuroMamba (NeuMa) Model (Architecture Definition)
# -----------------------------------------------------------------------------
# The model vocabulary is read from dataset_config instead of being
# duplicated here, so the two cannot drift apart.
neuromamba_config = NeuroMambaConfig(
    # Width / depth.
    d_model=128,
    n_layer=4,
    expand_gc=2,
    # Vocabulary & embeddings. pad_vocab_size_multiple presumably rounds the
    # embedding table size up to a multiple of 8 — confirm in NeuroMambaConfig.
    vocab_size=dataset_config["vocab_size"],
    pad_vocab_size_multiple=8,
    tie_embeddings=True,
    # Normalisation / residual path.
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    # No SSM overrides: an empty mapping is passed explicitly.
    ssm_cfg=dict(),
)


# ==========================================================
# --- ABLATION FLAGS: EXPERIMENT CONTROL SWITCHES ---
# ==========================================================
# These are attached to the config object after construction (they are not
# passed to the NeuroMambaConfig constructor).
#   ablate_gc = True -> remove the 'gc' branch
#   ablate_y2 = True -> remove the 'y2' branch
#   both False       -> baseline run (no ablation)
# NOTE(review): if NeuroMambaConfig is saved/serialized from its declared
# fields only, these post-hoc attributes may not be persisted — confirm.
# Current run: keep gc, ablate y2.
neuromamba_config.ablate_gc, neuromamba_config.ablate_y2 = False, True


