# config.py
import numpy as np

# --- Simulation Configuration ---
# Number of independent runs and the seed they start from.
simulation_config = dict(
    total_simulations=16,  # count of independent runs (value from the new script)
    base_seed=200,         # starting seed, fixed for reproducibility
)

# --- Training Configuration ---
# Alternative settings kept for reference (not active):
#   epochs = 500000, save_interval = 10000
training_config = dict(
    epochs=50000,
    learning_rate=1e-4,  # learning rate taken from the new script
    save_interval=1000,  # metrics are saved / evaluation runs at this interval
)

# --- Dataset Configuration ---
# Structure of the cognitive task: trial counts plus one token template per
# trial type. Each template is a 1-D integer array of token ids.
dataset_config = dict(
    num_trials=100,       # total trials per epoch
    num_train_trials=50,  # trials used for training within an epoch
    # Token template for trial type 0.
    trial1x=np.array([
        1, 1, 1, 1, 1, 1,
        2, 2, 2, 2,
        1, 1, 1,
        4, 6,
        1, 1, 1,
        5, 5,
        1, 1, 0,
    ]),
    # Token template for trial type 1.
    trial2x=np.array([
        1, 1, 1, 1, 1, 1,
        3, 3, 3, 3,
        1, 1, 1,
        4, 4,
        1, 1, 1,
        5, 6,
        1, 1, 0,
    ]),
)
# Derived entries, computed from the templates so they never drift from them:
#   trial_length - number of tokens in the type-0 template
#   vocab_size   - number of distinct token ids appearing in either template
# NOTE(review): the distinct-id count equals (max id + 1) only while the ids
# are contiguous from 0 - confirm this still holds if the templates change.
dataset_config.update(
    trial_length=dataset_config["trial1x"].shape[0],
    vocab_size=np.union1d(
        dataset_config["trial1x"], dataset_config["trial2x"]
    ).size,
)

# --- Model Configuration ---
# Hyperparameters for the custom NeuroMamba model architecture (new script).
model_config = dict(
    vocab_size=dataset_config["vocab_size"],  # reuse the value derived from the trial templates
    d_model=256,
    d_state=16,
    d_conv=4,
    d_conv_gc=4,
    expand=2,
    expand_gc=2,  # the original script set this equal to 'expand'
)
# Correlation analysis works on the d_inner hidden state, whose size is
# d_model * expand.
model_config.update(
    hidden_size_for_corr=model_config["d_model"] * model_config["expand"],
)


# --- Paths Configuration ---
# Naming patterns for the output directory and the files written into it.
# NOTE(review): the {timestamp} / {metric} placeholders are presumably filled
# in by the caller - confirm against the code that consumes these patterns.
paths_config = dict(
    # Directory that receives the .npy results and plots,
    # e.g. "0604" for June 4th.
    output_dir_pattern="%m%d",
    plot_filename_pattern="neuromamba_training_curves_{timestamp}.pdf",
    # e.g. loss_all_neuromamba_...
    npy_filename_pattern="{metric}_neuromamba_{timestamp}.npy",
)