from joblib import Parallel, delayed
from itertools import product
import subprocess
import multiprocessing
import os
# Exported into this process's environment before any command is launched,
# so every shell/training subprocess started via subprocess.run inherits it
# (subprocess passes os.environ to children when no env= is given).
# NOTE(review): presumably works around an mkl-service / MKL threading-layer
# conflict in the training runs — confirm.
os.environ.update(MKL_SERVICE_FORCE_INTEL="1")


def run_command(command):
    """
    Run ``command`` through the shell and return the finished process.

    Args:
        command: A complete shell command line. It is executed with
            ``shell=True``, so it must only be built from trusted,
            script-controlled values.

    Returns:
        subprocess.CompletedProcess: The finished process. A non-zero exit
        code is reported but not raised, so one failing run does not abort
        the remaining jobs of a parallel sweep; callers may inspect
        ``returncode`` themselves.
    """
    result = subprocess.run(command, shell=True)
    if result.returncode != 0:
        # Previously failures were silently discarded; surface them so a
        # crashed run is noticed in the sweep output.
        print(f"Command failed with exit code {result.returncode}: {command}")
    return result
    
def generate_commands(base_command, params):
    """
    Expand ``params`` into one command per point of its Cartesian product.

    Each combination of values (``itertools.product`` order, following the
    dict's insertion order of keys) is rendered as space-prefixed
    ``key=value`` pairs appended to ``base_command``. An empty ``params``
    yields ``[base_command]``; any empty value list yields ``[]``.
    """
    names = list(params)
    return [
        base_command
        + "".join(f" {name}={setting}" for name, setting in zip(names, combo))
        for combo in product(*params.values())
    ]

def create_all_commands(reoccurring_freq, agents=None, num_seeds=1,
                        reoccurring_type='random_seq'):
    """
    Build the ``main_sl.py`` shell commands for one reoccurring-frequency sweep.

    Every agent override is combined with every seed in ``range(num_seeds)``.
    Commands are grouped by agent in the order given, and by seed within an agent.

    Args:
        reoccurring_freq: Passed as ``task.reoccurring_freq`` and embedded in
            the ``proj_name`` override.
        agents: Iterable of agent override strings appended after the base
            arguments, e.g. ``"agent=crelu"`` or
            ``"agent=baseline agent.reset_network=true"``. Defaults to the
            CReLU, PReLU and DeepFourier agents.
        num_seeds: Number of seeds per agent (seeds ``0 .. num_seeds - 1``).
        reoccurring_type: Passed as ``task.reoccurring_type`` and embedded in
            the ``proj_name`` override.

    Returns:
        list[str]: One command per (agent, seed) pair.
    """
    # Other agent overrides this sweep has been run with (formerly kept as
    # commented-out batches): "agent=baseline",
    # "agent=baseline agent.reset_network=true", "agent=l2_init", "agent=l2",
    # "agent=cbp", "agent=layer_norm", "agent=redo", "agent=s_and_p",
    # "agent=neuro_sync", "agent=ewc", "agent=ewc_l2".
    if agents is None:
        agents = ('agent=crelu', 'agent=prelu', 'agent=deep_fourier')

    base_args = (
        "arch=mlp "
        f"proj_name=reoccurring_setting_{reoccurring_freq}_{reoccurring_type} "
        "wandb=true task=shuffle_cifar10 "
        f"task.reoccurring_type={reoccurring_type} "
        f"task.reoccurring_freq={reoccurring_freq} "
        "task.epochs=10"
    )
    base_command_str = f'python main_sl.py {base_args}'

    # The seed grid is identical for every agent, so build it once.
    seed_params = {'seed': [str(i) for i in range(num_seeds)]}

    commands = []
    for agent in agents:
        commands.extend(generate_commands(f'{base_command_str} {agent}', seed_params))
    return commands

if __name__ == "__main__":
    # Number of commands joblib runs concurrently. The original derived a value
    # from multiprocessing.cpu_count() and then unconditionally overwrote it
    # with 4; keep the effective value of 4 and drop the dead computation.
    num_jobs = 4

    reoccurring_freqs = [2, 5, 10]
    for reoccurring_freq in reoccurring_freqs:
        commands = create_all_commands(reoccurring_freq)
        print("Num commands", len(commands))
        # Commands within one frequency run in parallel; the next frequency
        # starts only after this Parallel call returns.
        Parallel(n_jobs=num_jobs)(delayed(run_command)(command) for command in commands)