"""
Configuration file for the Induction Heads experiment.

This file is the central "control center" for running a suite of diagnostic
tests on sequence models. By modifying the parameters in the `dataset_config`
section, you can systematically control the difficulty and nature of the task
to probe different capabilities of the model, such as basic recall, long-range
memory, and robustness to interleaved noise.
"""

from neuromamba.models.config_neuromamba import NeuroMambaConfig

# -----------------------------------------------------------------------------
# 1. Configuration for Training Loop (Generally Kept Constant Across Experiments)
# -----------------------------------------------------------------------------
training_config = {
    # Mamba paper, Appendix B.2: "batch size of 8".
    "batch_size": 8,

    # NOTE: Appendix B.2 reports that Mamba learned better at the larger
    # learning rate of 1e-3. This config currently uses 2e-4, which differs
    # from that setting; use 1e-3 to match the paper's Mamba configuration.
    "learning_rate": 2e-4,

    # Appendix B.2: results at the "25th epoch", i.e. 25 epochs of 8192
    # steps each (8192 x 25 = 204800 steps).
    "num_steps": 25 * 8192,

    # One "epoch" is defined as 8192 steps; used as the evaluation frequency.
    "eval_interval": 8192,

    # Appendix B.2: "Adam optimizer with no weight decay".
    "weight_decay": 0.0,
}

# -----------------------------------------------------------------------------
# 2. Configuration for Induction Heads Dataset (THE MAIN CONTROL PANEL)
# -----------------------------------------------------------------------------
dataset_config = {
    # ----------------------------------------------------------------------
    #                     <<< EXPERIMENT CONTROL PANEL >>>
    # ----------------------------------------------------------------------

    # A) Choose the main difficulty level (0-4)
    #    0: Baseline - [P,A,B] clean triplets. Tests basic recall.
    #    1: Memory Robustness - Noise BETWEEN triplets [P,A,B,N,N,P,A,B]. Tests long-term memory.
    #    2: Abstract Pattern Recognition - Noise WITHIN triplets [P,N,A,N,B]. Tests non-local association.
    #    3: Combined Stress Test - Noise both WITHIN and BETWEEN.
    #    4: Autonomous Learning Suite - No prefix 'P'.
    "difficulty_level": 3,

    # B) If difficulty_level is 4, choose the sub-type. Ignored otherwise.
    #    'none':    (Level 4.0) Sanity Check. Structure: [A,B,A,C,...]
    #    'between': (Level 4.1) Robust Discovery. Structure: [A,B,N,N,A,C,...]
    #    'conflict':(Level 4.2) Dynamic World Modeling. Structure: [A,B,N,N,A,C,...,A,...] -> predicts C
    "level_4_noise_type": 'none',

    # C) For difficulty levels 1, 2 and 3, controls the maximum amount of noise.
    "max_noise_len": 4,

    # ----------------------------------------------------------------------
    #                  <<< QUANTITATIVE DIFFICULTY LEVERS >>>
    # ----------------------------------------------------------------------

    # Increase to test the model's predictive fidelity over a larger output space.
    # The original paper used 16. A harder setting could be 512, 1024, or more.
    "vocab_size": 16,

    # Increase to test the model's memory capacity (how many facts it can hold).
    # A harder setting could be 15, 20, or more.
    "num_induction_pairs": 3,

    # ----------------------------------------------------------------------
    #                        <<< TASK DEFINITION >>>
    # ----------------------------------------------------------------------

    # The sequence length used during training.
    "train_seq_len": 256,

    # Sequence lengths to evaluate on, for testing length extrapolation.
    # The current setting goes up to 2**20, as in the Mamba paper, which
    # requires massive VRAM. Reduce the upper bound to fit your hardware
    # (e.g. range(6, 17) to stop at 2**16).
    "eval_seq_lens": [2**i for i in range(6, 21)],
}

# -----------------------------------------------------------------------------
# 3. Configuration for the NeuroMamba Model (Architecture Definition)
# -----------------------------------------------------------------------------
neuromamba_config = NeuroMambaConfig(
    # NOTE: Section 4.2 of the Mamba paper uses "a model dimension D of 64".
    # This config uses 128 instead; set d_model=64 to match the paper.
    d_model=128,

    # NOTE: Section 4.2 uses "2 layer models". This config uses 4 layers;
    # set n_layer=2 to match the paper.
    n_layer=4,
    # Presumably an expansion factor inside each block — TODO confirm its
    # meaning against NeuroMambaConfig.
    expand_gc=2,

    # This MUST match the dataset's vocabulary size.
    # Referencing the dictionary directly keeps the two in sync.
    vocab_size=dataset_config['vocab_size'],

    # Standard Mamba settings, kept at their usual values.
    ssm_cfg={},
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,

    # Padding the vocab size to a multiple of 8 is common practice for
    # performance (presumably applied inside NeuroMambaConfig — confirm).
    pad_vocab_size_multiple=8,

    # Tying input and output embeddings is standard practice for language models.
    tie_embeddings=True,
)