from neuromamba.models.config_neuromamba import NeuroMambaConfig

# Optimisation hyper-parameters used for training.
training_config = dict(
    batch_size=64,
    learning_rate=1e-4,
    num_steps=400_000,  # number of training steps
)

# Parameters for the memorization dataset.
dataset_config = dict(
    l_noise=4096,            # count of padding tokens in each sequence
    l_memorize=16,           # count of tokens the model must memorize
    n_tokens=16,             # alphabet size; also read by the model config as vocab_size
    lag=False,               # NOTE(review): semantics defined by the dataset loader — confirm
    variable=True,           # scatter memorization tokens across the sequence rather than placing them up front
    variable_length=False,   # when True, the number of tokens to memorize is randomized
    one_hot=False,           # NOTE(review): presumably one-hot encodes inputs — confirm
    reverse=False,           # NOTE(review): presumably reverses the target order — confirm
    static=False,            # NOTE(review): semantics defined by the dataset loader — confirm
)

# Configuration for NeuMa model
neuma_config = NeuroMambaConfig(
    d_model=18,
    n_layer=2,
    expand_gc=2,
    # Vocabulary size is taken from the dataset so the two stay in sync.
    vocab_size=dataset_config['n_tokens'],
    # Empty dict: no SSM overrides are passed. An alternative used previously
    # was ssm_cfg=dict(layer="NeuroMamba").
    # NOTE(review): confirm which layer NeuroMambaConfig selects by default.
    ssm_cfg={},
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    # NOTE(review): presumably a multiple of 1 leaves vocab_size unpadded — confirm.
    pad_vocab_size_multiple=1,
    tie_embeddings=False,
)