import numpy
from joblib import Parallel, delayed
from itertools import product
import subprocess
import multiprocessing
import os
# Set at import time, before any command is launched, so the value is part of
# this process's environment when the training subprocesses are started.
# NOTE(review): presumably a workaround for an MKL threading-layer conflict in
# the launched training script; confirm it is still needed.
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
def run_command(command):
    """
    Run a single command through the system shell and report failures.

    Args:
        command: Command line to execute. It is handed to the shell as-is
            (``shell=True``), so it must come from a trusted source.

    Returns:
        The exit status of the command (0 means success).
    """
    # ``check`` stays at its default (False): a failing command is reported
    # below instead of raising an exception out of this function.
    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        # Without this message a crashed run would go completely unnoticed.
        print(
            f"Command failed with exit code {completed.returncode}: {command}",
            flush=True,
        )
    return completed.returncode
    
def generate_commands(base_command, params):
    """
    Expand a parameter grid into one command line per combination.

    Args:
        base_command: Command string that every generated command starts with.
        params: Mapping from an override name to the iterable of values to
            try for it.

    Returns:
        A list with one string per element of the Cartesian product of the
        value iterables. Each string is ``base_command`` followed by a
        `` name=value`` pair for every entry of ``params``, in the mapping's
        iteration order.
    """
    names = list(params.keys())
    return [
        base_command
        + "".join(f" {name}={setting}" for name, setting in zip(names, combo))
        for combo in product(*params.values())
    ]

def create_all_commands():
    """
    Build the fixed list of training commands launched by this script.

    Every command invokes ``python main_sl.py`` with the shared base
    arguments followed by ``agent``, ``seed``, ``agent.optimizer`` and
    ``agent.lr`` overrides. The assembled list is printed before it is
    returned.

    Returns:
        A list of command strings, grouped by agent in schedule order.
    """
    base_args = "arch=mlp proj_name=normal_train wandb=true task=shuffle_cifar10"
    launcher = f"python main_sl.py {base_args}"

    # (optimizer, learning rate) pairs; the rate is kept as text so it is
    # written into the command exactly as spelled here.
    adam = ("adam", "0.001")
    sgd = ("sgd", "0.01")

    # Seeds 0-2, each run with adam first and then sgd.
    full_grid = [(seed, optim) for seed in range(3) for optim in (adam, sgd)]

    # One entry per agent: the (seed, optimizer) runs to launch for it.
    schedule = [
        ("cbp", [(2, adam)]),
        ("ewc", full_grid),
        ("ewc_l2", full_grid),  # EWC + L2Init
        ("redo", [(0, sgd)]),
        ("crelu", full_grid),
        ("deep_fourier", full_grid),
        ("prelu", full_grid),
    ]

    commands = [
        f"{launcher} agent={agent} seed={seed} agent.optimizer={optimizer} agent.lr={lr}"
        for agent, runs in schedule
        for seed, (optimizer, lr) in runs
    ]

    print(commands)
    return commands

if __name__ == "__main__":
    # Script entry point: build the command list, then run the commands
    # concurrently, each one in its own shell subprocess via run_command.
    commands = create_all_commands()
    print("Num commands", len(commands))

    # Fixed number of commands to run at the same time; it is not derived
    # from the host's CPU count.
    num_cores = 4
    print(f'*****{num_cores}*****')

    # joblib dispatches one run_command call per command across the workers.
    Parallel(n_jobs=num_cores)(delayed(run_command)(command) for command in commands)