# run_all_neuromamba_experiments.py

import logging
import time
import copy
import torch
import os

# Import the main execution function from our refactored NeuroMamba script
from induction_heads_neuromamba_auto import execute_neuromamba_training_run

# Import the base configurations to be modified
from config import training_config, dataset_config, neuromamba_config

# --- 1. DEFINE THE NEUROMAMBA EXPERIMENT MATRIX ---
# The driver below pairs every expand_gc value with every difficulty setting,
# so the suite performs len(EXPAND_GC_VALUES) * len(DIFFICULTY_LEVELS_TO_TEST)
# training runs in total.

# A) Values written to the NeuroMamba config's `expand_gc` attribute, one sweep point each.
EXPAND_GC_VALUES = [1, 1.5, 2, 2.5, 3]

# B) Dataset difficulty settings. Levels 0-3 are identified by their level
#    number alone; level 4 is additionally swept over each noise sub-type
#    (written to the dataset config's 'level_4_noise_type' entry).
DIFFICULTY_LEVELS_TO_TEST = (
    [{'level': level} for level in range(4)]
    + [{'level': 4, 'noise_type': noise} for noise in ('none', 'between', 'conflict')]
)

# --- 2. SCRIPT EXECUTION LOGIC ---

def _configure_suite_logging(log_dir):
    """Send INFO-level root-logger output to a timestamped file and the console.

    Args:
        log_dir: Directory for the suite log; created if it does not exist.

    Returns:
        A ``(logger, log_filename)`` tuple: the root logger and the path of
        the log file it now writes to.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_filename = os.path.join(log_dir, f"neuromamba_suite_{time.strftime('%Y%m%d-%H%M%S')}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename),
            logging.StreamHandler(),
        ],
        # basicConfig() silently does nothing when the root logger already has
        # handlers (e.g. if an imported module configured logging at import
        # time), which would leave the suite log file empty. force=True
        # (Python 3.8+) replaces any existing root handlers so this log is
        # always actually written.
        force=True,
    )
    return logging.getLogger(), log_filename


def _describe_difficulty(diff):
    """Return a readable label for a difficulty dict, e.g. 'Level 4 (Subtype: between)'."""
    label = f"Level {diff['level']}"
    if 'noise_type' in diff:
        label += f" (Subtype: {diff['noise_type']})"
    return label


def _build_run_configs(base_training, base_dataset, base_neuromamba, gc_value, diff):
    """Return deep-copied ``(training, dataset, neuromamba)`` configs for one run.

    The base configs are never mutated, so settings from one run cannot leak
    into the next.

    Args:
        base_training: Base training config; copied unchanged.
        base_dataset: Base dataset config (accessed as a mapping); gets
            'difficulty_level' and, when ``diff`` has a 'noise_type',
            'level_4_noise_type'.
        base_neuromamba: Base model config (accessed by attribute); gets
            ``expand_gc = gc_value``.
        gc_value: The expand_gc value for this run.
        diff: Difficulty dict with a required 'level' and optional 'noise_type'.
    """
    t_config = copy.deepcopy(base_training)
    d_config = copy.deepcopy(base_dataset)
    nm_config = copy.deepcopy(base_neuromamba)

    nm_config.expand_gc = gc_value

    d_config['difficulty_level'] = diff['level']
    if 'noise_type' in diff:
        # Without a 'noise_type', the base config's value is kept as-is.
        d_config['level_4_noise_type'] = diff['noise_type']

    return t_config, d_config, nm_config


def main():
    """Run every (expand_gc, difficulty) combination in sequence.

    Each run receives deep copies of the base ``training_config``,
    ``dataset_config`` and ``neuromamba_config``. A run that raises is logged
    with its traceback and the suite moves on to the next combination. All
    failed runs are listed again in the final summary so they are not buried
    in the log.
    """
    logger, log_filename = _configure_suite_logging("logs_neuromamba")

    total_experiments = len(EXPAND_GC_VALUES) * len(DIFFICULTY_LEVELS_TO_TEST)
    logger.info(f"Starting NeuroMamba experiment suite: {len(EXPAND_GC_VALUES)} expand_gc values x {len(DIFFICULTY_LEVELS_TO_TEST)} difficulties = {total_experiments} total runs.")
    logger.info(f"Master log for this suite is being saved to: {log_filename}")

    experiment_counter = 0
    failed_experiments = []  # (experiment number, expand_gc, difficulty label)
    # monotonic() is immune to wall-clock adjustments (NTP sync, DST) that
    # can happen during a multi-hour suite.
    suite_start_time = time.monotonic()

    # Outer loop: expand_gc values; inner loop: difficulty settings.
    for gc_value in EXPAND_GC_VALUES:
        for diff in DIFFICULTY_LEVELS_TO_TEST:
            experiment_counter += 1

            t_config, d_config, nm_config = _build_run_configs(
                training_config, dataset_config, neuromamba_config, gc_value, diff
            )
            level_str = _describe_difficulty(diff)

            logger.info("=" * 80)
            logger.info(f"STARTING EXPERIMENT {experiment_counter} / {total_experiments}")
            logger.info(f"  - NeuroMamba Parameter: expand_gc = {gc_value}")
            logger.info(f"  - Difficulty: {level_str}")
            logger.info("=" * 80)

            try:
                execute_neuromamba_training_run(t_config, d_config, nm_config)
            except Exception as e:
                # Deliberately broad: one bad configuration must not abort the
                # whole sweep. KeyboardInterrupt is not an Exception subclass,
                # so Ctrl-C still stops the suite.
                failed_experiments.append((experiment_counter, gc_value, level_str))
                logger.error(f"!!! EXPERIMENT {experiment_counter} FAILED !!!")
                logger.error(f"  - expand_gc: {gc_value}")
                logger.error(f"  - Difficulty: {diff}")
                logger.exception(f"  - Error: {e}")
                logger.info("Continuing to the next experiment...")
            finally:
                # Release cached GPU memory after failed runs too (e.g. a CUDA
                # out-of-memory error), not only after successful ones; that is
                # when freeing it before the next, possibly larger,
                # configuration matters most.
                if torch.cuda.is_available():
                    logger.info("Clearing CUDA cache...")
                    torch.cuda.empty_cache()

    elapsed_hours = (time.monotonic() - suite_start_time) / 3600
    logger.info("=" * 80)
    logger.info("NEUROMAMBA EXPERIMENT SUITE COMPLETE")
    logger.info(f"Total time taken: {elapsed_hours:.2f} hours")
    if failed_experiments:
        logger.warning(f"{len(failed_experiments)} / {total_experiments} experiments FAILED:")
        for number, failed_gc, failed_level in failed_experiments:
            logger.warning(f"  - Experiment {number}: expand_gc = {failed_gc}, {failed_level}")
    else:
        logger.info(f"All {total_experiments} experiments completed without raising.")
    logger.info("=" * 80)


if __name__ == '__main__':
    main()