import numpy
from joblib import Parallel, delayed
from itertools import product
import subprocess
import multiprocessing
import os
# Set in this process's environment at import time, before any command is
# launched; subprocess.run() passes the parent environment to its children
# by default, so every launched command sees this variable as well.
# NOTE(review): presumably this makes mkl-service force the Intel threading
# layer to avoid its threading-layer incompatibility error -- confirm.
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
def run_command(command):
    """
    Run a single command line through the shell and report failures.

    Args:
        command: Command line string. It is executed with ``shell=True``,
            so it must come from a trusted source (here: the hard-coded
            experiment list), never from untrusted input.

    Returns:
        The exit status of the command (0 means success).
    """
    import sys  # local import: only needed for the failure message below

    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        # Make failed runs visible instead of dropping the exit status.
        # Deliberately not raising: the original was best-effort, and the
        # remaining commands in the batch should still get their turn.
        print(
            f"Command exited with status {completed.returncode}: {command}",
            file=sys.stderr,
        )
    return completed.returncode
    
def generate_commands(base_command, params):
    """
    Expand a parameter grid into one command line per combination.

    Args:
        base_command: Command prefix that every generated command starts with.
        params: Mapping from override name to an iterable of candidate values.

    Returns:
        A list of strings. Each one is ``base_command`` followed by a
        `` name=value`` suffix for every key in ``params``, in the mapping's
        iteration order. The combinations follow the order produced by
        ``itertools.product`` (last key varies fastest).
    """
    names = tuple(params)
    value_lists = tuple(params.values())
    return [
        base_command + "".join(f" {name}={value}" for name, value in zip(names, combo))
        for combo in product(*value_lists)
    ]

def create_all_commands():
    """
    Return the hard-coded list of experiment command lines to launch.

    The list holds twelve ``python main_sl.py ...`` invocations: the agents
    ``baseline`` and ``crelu`` with ``er_type=agem``, and the agents ``ewc``
    and ``ewc_l2`` without replay-buffer options, each on the tasks
    ``permuted_MNIST``, ``continual_cifar10`` and ``continual_cifar100``.
    The order of the returned list is the order in which they are written
    below.
    """
    # Disabled command set (not executed), kept for reference:
    #     'python main_sl.py arch=mix agent=baseline agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST forgetting_mech=true er_type=er buffer_size=4000 buffer_batch_size_ratio=1 task.num_tasks=10',
    #     'python main_sl.py arch=mix agent=crelu agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST forgetting_mech=true er_type=er buffer_size=4000 buffer_batch_size_ratio=1 task.num_tasks=10',
    #     'python main_sl.py arch=mix agent=hat agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST task.num_tasks=10',
    #     'python main_sl.py arch=mix agent=neuro_sync agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST forgetting_mech=true er_type=er buffer_size=4000 buffer_batch_size_ratio=1 task.num_tasks=10'
    return [
        # --- agent=baseline, er_type=agem ---
        # arch=cnn, permuted_MNIST (10 tasks)
        'python main_sl.py arch=cnn agent=baseline agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST forgetting_mech=true er_type=agem buffer_size=4000 buffer_batch_size_ratio=1 task.num_tasks=10',
        # arch=mix, continual_cifar10
        'python main_sl.py arch=mix agent=baseline agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar10 forgetting_mech=true er_type=agem buffer_size=4000 buffer_batch_size_ratio=1',
        # arch=mix, continual_cifar100
        'python main_sl.py arch=mix agent=baseline agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar100 forgetting_mech=true er_type=agem buffer_size=4000 buffer_batch_size_ratio=1',

        # --- agent=crelu, er_type=agem ---
        # arch=mlp, permuted_MNIST (10 tasks)
        # NOTE(review): every other permuted_MNIST run here uses arch=cnn,
        # and this entry was previously labelled "CNN" -- confirm that
        # arch=mlp is intended.
        'python main_sl.py arch=mlp agent=crelu agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST forgetting_mech=true er_type=agem buffer_size=4000 buffer_batch_size_ratio=1 task.num_tasks=10',
        # arch=mix, continual_cifar10
        'python main_sl.py arch=mix agent=crelu agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar10 forgetting_mech=true er_type=agem buffer_size=4000 buffer_batch_size_ratio=1',
        # arch=mix, continual_cifar100
        'python main_sl.py arch=mix agent=crelu agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar100 forgetting_mech=true er_type=agem buffer_size=4000 buffer_batch_size_ratio=1',

        # --- agent=ewc (no replay-buffer options) ---
        # arch=cnn, permuted_MNIST (10 tasks)
        'python main_sl.py arch=cnn agent=ewc agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST task.num_tasks=10',
        # arch=mix, continual_cifar10
        'python main_sl.py arch=mix agent=ewc agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar10',
        # arch=mix, continual_cifar100
        'python main_sl.py arch=mix agent=ewc agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar100',

        # --- agent=ewc_l2 (no replay-buffer options) ---
        # arch=cnn, permuted_MNIST (10 tasks)
        'python main_sl.py arch=cnn agent=ewc_l2 agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=permuted_MNIST task.num_tasks=10',
        # arch=mix, continual_cifar10
        'python main_sl.py arch=mix agent=ewc_l2 agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar10',
        # arch=mix, continual_cifar100
        'python main_sl.py arch=mix agent=ewc_l2 agent.optimizer=adam agent.lr=0.001 proj_name=forgetting_with_ER wandb=true task=continual_cifar100',
    ]

if __name__ == "__main__":
    # Build the experiment list and report how many runs will be launched.
    commands = create_all_commands()
    print("Num commands", len(commands))

    # The detected core count is immediately overridden: at most two
    # commands are executed concurrently.
    num_cores = multiprocessing.cpu_count()
    num_cores = 2

    # One joblib task per command; each task shells out via run_command.
    jobs = [delayed(run_command)(command) for command in commands]
    Parallel(n_jobs=num_cores)(jobs)