import argparse
import itertools
import os
import traceback

import nts_notears
import nts_notears_and
import nts_notears_multiply

# Registry of runnable algorithms: CLI/algorithm name -> imported module.
# Each module is expected to expose a ``main(...)`` entry point, which
# run_experiments calls with keyword arguments for one parameter combination.
ALGORITHM_MODULE_MAP = dict(
    nts_notears=nts_notears,
    nts_notears_multiply=nts_notears_multiply,
    nts_notears_and=nts_notears_and,
)

# Short labels used in result directory names. Unknown SEM types fall back to
# their full name so that two distinct types never share a directory.
_SEM_SHORT_NAMES = {
    'AdditiveIndexModel': 'AIM',
    'AdditiveNoiseModel': 'ANM',
}


def _format_prob(prob):
    """Format an edge probability for use in a result directory name.

    Keeps the historical one-decimal layout (e.g. ``0.2``, ``1.0``) whenever
    that is exact; otherwise falls back to the shortest round-trip repr so
    that, e.g., 0.25 and 0.2 do not map to the same directory.
    """
    short = f"{prob:.1f}"
    return short if float(short) == prob else repr(float(prob))


def _param_dir_name(sl, d_val, sem, lags, seed_val, prob):
    """Build the per-run result directory name for one parameter combination."""
    sem_short_name = _SEM_SHORT_NAMES.get(sem, sem)
    return (
        f"sl{sl}_d{d_val}_sem{sem_short_name}_lags{lags}"
        f"_seed{seed_val}_prob{_format_prob(prob)}"
    )


def run_experiments(
    algorithm_name_strings,
    *,
    sequence_lengths=(200, 1000),
    d_values=(20,),
    sem_types=('AdditiveIndexModel', 'AdditiveNoiseModel'),
    numbers_of_lags=(3,),
    seeds=(1, 2, 3, 4, 5, 6),
    exist_edges_probs=(0.2, 0.4, 0.6, 0.8, 1.0),
    base_result_folder='results',
    algorithm_modules=None,
):
    """Run each requested algorithm over the full grid of experiment parameters.

    For every recognised algorithm name, the Cartesian product of the parameter
    sequences is enumerated and the algorithm module's ``main`` is called once
    per combination, writing into its own directory
    ``<base_result_folder>/<algorithm>/<param_dir_name>``. A failing run is
    reported (message plus traceback, printed) and the sweep continues.

    Args:
        algorithm_name_strings: Iterable of algorithm names. Names missing from
            ``algorithm_modules`` are skipped with a warning.
        sequence_lengths, d_values, sem_types, numbers_of_lags, seeds,
        exist_edges_probs: Values to sweep. The defaults reproduce the
            original hard-coded grid.
        base_result_folder: Root directory for all results (created if missing).
        algorithm_modules: Mapping of name -> module exposing ``main(...)``.
            Defaults to ``ALGORITHM_MODULE_MAP``.
    """
    if algorithm_modules is None:
        algorithm_modules = ALGORITHM_MODULE_MAP

    # Build the grid once and reuse it for every algorithm. A product iterator
    # is single-use, and materializing it also makes one-shot iterables passed
    # by the caller safe to reuse across algorithms.
    param_grid = list(itertools.product(
        sequence_lengths,
        d_values,
        sem_types,
        numbers_of_lags,
        seeds,
        exist_edges_probs,
    ))

    # Create the base results directory if it doesn't exist
    os.makedirs(base_result_folder, exist_ok=True)

    for alg_name_str in algorithm_name_strings:
        if alg_name_str not in algorithm_modules:
            print(f"Warning: Algorithm '{alg_name_str}' is not recognized and will be skipped.")
            continue

        alg_module = algorithm_modules[alg_name_str]

        print(f"Running experiments for: {alg_name_str}")
        # Use the algorithm name string for the subfolder
        alg_result_folder = os.path.join(base_result_folder, alg_name_str)
        os.makedirs(alg_result_folder, exist_ok=True)

        for sl, d_val, sem, lags, seed_val, prob in param_grid:
            param_dir_name = _param_dir_name(sl, d_val, sem, lags, seed_val, prob)
            current_result_path = os.path.join(alg_result_folder, param_dir_name)
            os.makedirs(current_result_path, exist_ok=True)

            print(f"  Params: sl={sl}, d={d_val}, sem={sem}, lags={lags}, seed={seed_val}, prob={prob}")
            print(f"  Results will be saved to: {current_result_path}")

            try:
                alg_module.main(
                    sequence_length=sl,
                    d=d_val,
                    sem_type=sem,
                    number_of_lags=lags,
                    result_folder=current_result_path,
                    seed=seed_val,
                    exist_edges_prob=prob,
                )
            except Exception as e:
                # Best-effort sweep: one failing configuration must not abort
                # the remaining runs, but keep the traceback so it can be debugged.
                print(f"  Error running {alg_name_str} with params {param_dir_name}: {e}")
                print(traceback.format_exc())
            else:
                print(f"  Finished: {alg_name_str} with params {param_dir_name}")
            print("-" * 30)
    print("All specified experiments finished.")

if __name__ == '__main__':
    # Derive the valid names from the registry so the CLI never drifts from it.
    available_algorithms = list(ALGORITHM_MODULE_MAP)

    parser = argparse.ArgumentParser(description='Run experiments with different algorithms')
    parser.add_argument(
        '--algorithms',
        nargs='+',
        # Reject typos at parse time instead of discovering them mid-sweep.
        choices=available_algorithms,
        default=available_algorithms,
        help=f"Algorithms to run. Choose from {', '.join(available_algorithms)}.",
    )
    args = parser.parse_args()

    # nargs='+' requires at least one value when the flag is given, and the
    # default is the non-empty registry, so args.algorithms is never empty.
    run_experiments(args.algorithms)