import numpy
from joblib import Parallel, delayed
from itertools import product
import subprocess
import multiprocessing
import os
import argparse
# Export the flag into this process's environment. Child processes started
# with subprocess (no explicit env=) inherit os.environ, so every launched
# training command sees it as well.
# NOTE(review): presumably a workaround for MKL/OpenMP threading-layer
# conflicts in the child runs -- confirm.
os.environ.update(MKL_SERVICE_FORCE_INTEL="1")
def run_command(command):
    """
    Run a single command line through the shell and report its exit status.

    ``command`` is handed to the shell (``shell=True``), so it must come from a
    trusted source; it is not escaped or validated here.

    Args:
        command: The full command line to execute.

    Returns:
        The command's exit status (0 on success).
    """
    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        # Many jobs run in parallel and their output interleaves, so make a
        # failed run explicit instead of silently discarding its status.
        print(f'Command failed with exit code {completed.returncode}: {command}', flush=True)
    return completed.returncode
    
def generate_commands(base_command, params):
    """
    Expand a parameter grid into one command line per grid point.

    Args:
        base_command: Command prefix shared by every generated command.
        params: Mapping of override key -> iterable of candidate values. Every
            combination is enumerated in ``itertools.product`` order (the last
            key varies fastest).

    Returns:
        A list of strings, each ``base_command`` followed by `` key=value`` for
        every key of ``params`` in the mapping's iteration order.
    """
    names = tuple(params)
    return [
        base_command + ''.join(f" {name}={val}" for name, val in zip(names, point))
        for point in product(*params.values())
    ]

def create_all_commands():
    """
    Build the CBP / ReDO / EWC / EWC+L2Init command list for the three tasks.

    Each entry has the form ``python main_sl.py <task args> agent=<agent>
    seed=0 agent.optimizer=<opt> agent.lr=<lr>``, where ``lr`` is 0.01 for SGD
    and 0.001 for Adam. The list is printed before being returned.
    """
    lr_for = {'sgd': '0.01', 'adam': '0.001'}
    cifar100 = "arch=cnn proj_name=bwt3 wandb=true task=continual_cifar100"
    mnist = "arch=mlp proj_name=bwt3 wandb=true task=permuted_MNIST"
    imagenet = ("arch=cnn proj_name=bwt3 wandb=true task=continual_imagenet "
                "task.benchmark=new_continual_imagenet task.num_classes=200")

    # (task args, agent, optimizer) rows, in the order the runs are emitted.
    runs = [
        # continual CIFAR-100
        (cifar100, 'cbp', 'sgd'),
        (cifar100, 'redo', 'adam'),
        (cifar100, 'ewc', 'adam'),
        (cifar100, 'ewc_l2', 'adam'),
        # permuted MNIST
        (mnist, 'cbp', 'adam'),
        (mnist, 'redo', 'sgd'),
        (mnist, 'ewc', 'sgd'),
        (mnist, 'ewc_l2', 'sgd'),
        # continual ImageNet (200 classes)
        (imagenet, 'cbp', 'sgd'),
        (imagenet, 'redo', 'sgd'),
        (imagenet, 'ewc', 'adam'),
        (imagenet, 'ewc_l2', 'adam'),
    ]
    commands = [
        f'python main_sl.py {task} agent={agent} seed=0 agent.optimizer={opt} agent.lr={lr_for[opt]}'
        for task, agent, opt in runs
    ]
    print(commands)
    return commands

def create_all_commands_crelu_prelu_df():
    """
    Build the CReLU / Deep Fourier / L2Init / L2 / LayerNorm / PReLU command list.

    For each of the three tasks (continual CIFAR-100, permuted MNIST and
    continual ImageNet with 200 classes), one ``python main_sl.py`` command is
    produced per agent. Each uses seed 0 and a per-(task, agent) optimizer. The
    learning rate is 0.01 for SGD and 0.001 for Adam. The list is printed before
    being returned.
    """
    lr_for = {'sgd': '0.01', 'adam': '0.001'}
    cifar100 = 'arch=cnn proj_name=bwt3 wandb=true task=continual_cifar100'
    mnist = 'arch=mlp proj_name=bwt3 wandb=true task=permuted_MNIST'
    imagenet = ('arch=cnn proj_name=bwt3 wandb=true task=continual_imagenet '
                'task.benchmark=new_continual_imagenet task.num_classes=200')

    # Per task: agent -> optimizer, listed in emission order
    # (dict iteration follows insertion order).
    plan = [
        (cifar100, {'crelu': 'adam', 'deep_fourier': 'adam', 'l2_init': 'adam',
                    'l2': 'adam', 'layer_norm': 'sgd', 'prelu': 'adam'}),
        (mnist, {'crelu': 'adam', 'deep_fourier': 'sgd', 'l2_init': 'sgd',
                 'l2': 'sgd', 'layer_norm': 'sgd', 'prelu': 'adam'}),
        (imagenet, {'crelu': 'sgd', 'deep_fourier': 'adam', 'l2_init': 'adam',
                    'l2': 'adam', 'layer_norm': 'sgd', 'prelu': 'adam'}),
    ]

    commands = []
    for task_args, optimizer_by_agent in plan:
        for agent, optimizer in optimizer_by_agent.items():
            commands.append(
                f'python main_sl.py {task_args} agent={agent} seed=0 '
                f'agent.optimizer={optimizer} agent.lr={lr_for[optimizer]}'
            )
    print(commands)
    return commands

def create_all_commands_neuro_sync():
    """
    Build one Neuro Sync command per task.

    The tasks are continual CIFAR-100, permuted MNIST and continual ImageNet
    with 200 classes. Every run uses seed 0, Adam and lr 0.001. The list is
    printed before being returned.
    """
    task_args = (
        "arch=cnn proj_name=bwt3 wandb=true task=continual_cifar100",
        "arch=mlp proj_name=bwt3 wandb=true task=permuted_MNIST",
        "arch=cnn proj_name=bwt3 wandb=true task=continual_imagenet "
        "task.benchmark=new_continual_imagenet task.num_classes=200",
    )
    commands = [
        f'python main_sl.py {args} agent=neuro_sync seed=0 agent.optimizer=adam agent.lr=0.001'
        for args in task_args
    ]
    print(commands)
    return commands

def create_all_commands_neuro_sync_cifar100():
    """
    Build the Neuro Sync and EWC + L2Init commands for continual CIFAR-100 only.

    Both runs use seed 0, Adam and lr 0.001. The list is printed before being
    returned.
    """
    prefix = 'python main_sl.py arch=cnn proj_name=bwt3 wandb=true task=continual_cifar100'
    commands = [
        f'{prefix} agent={agent} seed=0 agent.optimizer=adam agent.lr=0.001'
        for agent in ('neuro_sync', 'ewc_l2')
    ]
    print(commands)
    return commands

if __name__ == "__main__":
    # Entry point: pick an experiment group by --name, build its command list,
    # and run the commands in parallel with joblib.
    parser = argparse.ArgumentParser(description="Run commands based on a given name.")

    # Experiment-group name -> (command factory, number of parallel jobs).
    experiment_groups = {
        'ewc': (create_all_commands, 8),
        'neuro': (create_all_commands_neuro_sync, 8),
        'crelu': (create_all_commands_crelu_prelu_df, 8),
        'cifar100': (create_all_commands_neuro_sync_cifar100, 4),
    }
    # `choices` makes argparse reject an unknown name with a usage message and
    # a non-zero exit status. Previously the script printed a message and
    # called the interactive-only `exit()`, which reports success (status 0).
    parser.add_argument("--name", type=str, required=True, choices=list(experiment_groups))

    args = parser.parse_args()
    factory, num_cores = experiment_groups[args.name]
    commands = factory()

    print("Num commands", len(commands))
    print(f'*****{num_cores}*****')
    Parallel(n_jobs=num_cores)(delayed(run_command)(command) for command in commands)