import numpy
from joblib import Parallel, delayed
from itertools import product
import subprocess
import multiprocessing
import os
import argparse
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"


def run_command(command):
    """
    Run a single shell command and report its exit status.

    The string is handed to the shell (``shell=True``), so it is interpreted
    exactly as if typed at a prompt. Only pass trusted command strings.

    Args:
        command: The complete command line to execute, as one string.

    Returns:
        The exit code of the command (0 means success).
    """
    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        # Report instead of raising so that a single failing command does not
        # interrupt a caller that is running a whole batch of commands.
        print(f'Command failed with exit code {completed.returncode}: {command}')
    return completed.returncode
    

def create_all_commands(method):
    """
    Build the list of shell commands to launch for the given method.

    Every command invokes ``main_sl.py`` with the same base arguments and
    differs only in the ``agent=...`` selection (plus optimizer settings for
    the baseline agents). The resulting list is also printed.

    Args:
        method: Which group of runs to build. ``'baseline'`` yields one
            command per enabled entry in the baseline table below;
            ``'neuro'`` yields a single ``agent=neuro_sync`` command with no
            optimizer override appended.

    Returns:
        A list of command strings.

    Raises:
        ValueError: If ``method`` is neither ``'baseline'`` nor ``'neuro'``.
    """
    # SWEEP='sweep=TRAIN'

    BASE_ARGS = "arch=mlp proj_name=scratch_train wandb=true task=random_MNIST task.num_tasks=30"

    base_command_str = f'python main_sl.py {BASE_ARGS} seed=0 agent.reset_network=true monitor_backward_transfer=false'

    # Optimizer argument strings, keyed by name so the table below reads
    # without magic list indices.
    optimizers = {
        'sgd': 'agent.optimizer=sgd agent.lr=0.01',
        'adam': 'agent.optimizer=adam agent.lr=0.001',
    }

    if method == 'baseline':
        # (agent, optimizer) pairs. Uncomment an entry to include that agent
        # in the batch; comment it out to skip it.
        baseline_runs = [
            # ('baseline', 'sgd'),
            # ('cbp', 'adam'),
            # ('crelu', 'adam'),
            # ('deep_fourier', 'sgd'),
            # ('ewc_l2', 'sgd'),
            ('ewc', 'sgd'),
            # ('l2_init', 'sgd'),
            # ('l2', 'sgd'),
            # ('layer_norm', 'sgd'),
            # ('prelu', 'sgd'),
            # ('redo', 'adam'),
        ]
        commands_batchs = [
            f'{base_command_str} agent={agent} {optimizers[optimizer]}'
            for agent, optimizer in baseline_runs
        ]
    elif method == 'neuro':
        commands_batchs = [f'{base_command_str} agent=neuro_sync']
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected 'baseline' or 'neuro'."
        )

    print(commands_batchs)

    return commands_batchs

if __name__ == "__main__":
    # Command-line entry point: build the commands for the requested method
    # and dispatch them through joblib.
    parser = argparse.ArgumentParser(description="Run commands based on a given name.")
    parser.add_argument(
        "--method",
        type=str,
        required=True,
        help="Group of runs to launch, e.g. 'baseline' or 'neuro'.",
    )
    parser.add_argument(
        "--num-cores",
        type=int,
        default=8,
        help="Value passed to joblib as n_jobs (default: 8).",
    )
    args = parser.parse_args()

    commands = create_all_commands(args.method)
    print("Num commands", len(commands))
    num_cores = args.num_cores
    print(f'*****{num_cores}*****')
    Parallel(n_jobs=num_cores)(delayed(run_command)(command) for command in commands)