
import argparse
import multiprocessing
import os
import subprocess
import sys
import time

import numpy as np

# Config

parser = argparse.ArgumentParser(description = 'Runs various scripts for the synthetic experiments')
parser.add_argument('--mode', type = str, default = 'complex',
                    help = "dataset family; datasets are named '<mode>-<id>'")
parser.add_argument('--range', type = str, default = '1-20',
                    help = "inclusive identifier range 'START-END', or a single identifier 'N'")
parser.add_argument('--script', type = str, default = '',
                    help = 'step to run: setup, blindspots, spotlight, barlow or domino')
parser.add_argument('--gpu_ids', type = str, default = '[0,1,2,3]',
                    help = "GPU ids for the worker pool, e.g. '[0,1,2,3]' (one worker per id)")
args = parser.parse_args()

mode = args.mode

# Parse '--range' as 'START-END' (inclusive) or a bare 'N' meaning just N.
# parser.error() exits, so first/last are always bound past this point.
range_start, range_sep, range_end = args.range.partition('-')
try:
    first = int(range_start)
    last = int(range_end) if range_sep else first
except ValueError:
    parser.error("--range must be 'START-END' or 'N', got {!r}".format(args.range))
if last < first:
    # An empty range would otherwise launch the pool with no work and exit silently.
    parser.error('--range end must not be smaller than its start, got {!r}'.format(args.range))
identifiers = range(first, last + 1)

script = args.script

# Accept '[0,1,2,3]' or '0,1,2,3'; blank entries (e.g. a trailing comma) are skipped.
gpu_ids = [int(v) for v in args.gpu_ids.strip('][').split(',') if v.strip()]
if not gpu_ids:
    # The pool below needs at least one worker (one per GPU id).
    parser.error('--gpu_ids must list at least one GPU id, got {!r}'.format(args.gpu_ids))

print()
print('Running with arguments:')
print('Mode:', mode)
print('Identifiers:', identifiers)
print('Script:', script)
print('GPU IDs:', gpu_ids)
print()

parent_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')

# Define the scripts
def setup(name, gpu_id):
    """Prepare dataset *name*: create it, train the baseline, then train SCVIS.

    Runs these commands in order from the current working directory:
      0-SetupDataset.py   in conda env ``blindspots``
      1-TrainBaseline.py  in conda env ``blindspots``, with only GPU ``gpu_id`` visible
      2-TrainSCVIS.py     in conda env ``scvis``
    Later steps presumably consume the earlier steps' outputs, so the chain stops
    at the first command that exits non-zero. The failure is reported on stderr
    rather than raised, so the pool keeps processing the other datasets.

    Returns:
        0 if every command succeeded, otherwise the failing command's exit status.
    """
    # CUDA_VISIBLE_DEVICES goes in the child's environment. It is not passed as a
    # `VAR=value` word after `conda run`, which expects an executable as its first
    # argument.
    gpu_env = dict(os.environ, CUDA_VISIBLE_DEVICES = str(gpu_id))
    steps = [
        ('0-SetupDataset.py', 'blindspots', None),   # env=None -> inherit as-is
        ('1-TrainBaseline.py', 'blindspots', gpu_env),
        ('2-TrainSCVIS.py', 'scvis', None),
    ]
    for step_script, conda_env, env in steps:
        cmd = ['conda', 'run', '-n', conda_env, 'python', step_script, name]
        status = subprocess.run(cmd, env = env).returncode
        if status != 0:
            print('[setup] {}: {} failed with exit status {}'.format(name, step_script, status),
                  file = sys.stderr)
            return status
    return 0
    
def blindspots(name, gpu_id):
    """Run 3-Blindspots.py for dataset *name* in the ``blindspots`` conda env.

    The command runs from ``parent_dir``.

    Returns:
        The command's exit status (0 on success). A failure is reported on
        stderr rather than raised, so the pool keeps processing other datasets.
    """
    # NOTE(review): unlike the other steps, gpu_id is not applied here, so this
    # command inherits the launcher's GPU visibility. Confirm 3-Blindspots.py is
    # CPU-only, or pin it with CUDA_VISIBLE_DEVICES like the other steps.
    cmd = ['conda', 'run', '-n', 'blindspots', 'python', '3-Blindspots.py',
           '--dataset', 'Synthetic', '--name', name]
    # Use cwd= instead of os.chdir so the worker process's own working directory is untouched.
    status = subprocess.run(cmd, cwd = parent_dir).returncode
    if status != 0:
        print('[blindspots] {} failed with exit status {}'.format(name, status), file = sys.stderr)
    return status
 
def spotlight(name, gpu_id):
    """Run 3-Spotlight.py for dataset *name* in the ``spotlight`` conda env.

    The command runs from ``parent_dir``, with only GPU ``gpu_id`` visible.

    Returns:
        The command's exit status (0 on success). A failure is reported on
        stderr rather than raised, so the pool keeps processing other datasets.
    """
    # CUDA_VISIBLE_DEVICES goes in the child's environment. It is not passed as a
    # `VAR=value` word after `conda run`, which expects an executable as its first
    # argument.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES = str(gpu_id))
    cmd = ['conda', 'run', '-n', 'spotlight', 'python', '3-Spotlight.py',
           '--dataset', 'Synthetic', '--name', name]
    # Use cwd= instead of os.chdir so the worker process's own working directory is untouched.
    status = subprocess.run(cmd, cwd = parent_dir, env = env).returncode
    if status != 0:
        print('[spotlight] {} failed with exit status {}'.format(name, status), file = sys.stderr)
    return status

def barlow(name, gpu_id):
    """Run 3-Barlow.py (with --skip_feature_vis) for dataset *name* in the ``barlow`` conda env.

    The command runs from ``parent_dir``, with only GPU ``gpu_id`` visible.

    Returns:
        The command's exit status (0 on success). A failure is reported on
        stderr rather than raised, so the pool keeps processing other datasets.
    """
    # CUDA_VISIBLE_DEVICES goes in the child's environment. It is not passed as a
    # `VAR=value` word after `conda run`, which expects an executable as its first
    # argument.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES = str(gpu_id))
    cmd = ['conda', 'run', '-n', 'barlow', 'python', '3-Barlow.py',
           '--dataset', 'Synthetic', '--name', name, '--skip_feature_vis']
    # Use cwd= instead of os.chdir so the worker process's own working directory is untouched.
    status = subprocess.run(cmd, cwd = parent_dir, env = env).returncode
    if status != 0:
        print('[barlow] {} failed with exit status {}'.format(name, status), file = sys.stderr)
    return status
    
def domino(name, gpu_id):
    """Run 3-Domino.py (with --skip_plot) for dataset *name* in the ``domino`` conda env.

    The command runs from ``parent_dir``, with only GPU ``gpu_id`` visible.

    Returns:
        The command's exit status (0 on success). A failure is reported on
        stderr rather than raised, so the pool keeps processing other datasets.
    """
    # CUDA_VISIBLE_DEVICES goes in the child's environment. It is not passed as a
    # `VAR=value` word after `conda run`, which expects an executable as its first
    # argument.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES = str(gpu_id))
    cmd = ['conda', 'run', '-n', 'domino', 'python', '3-Domino.py',
           '--dataset', 'Synthetic', '--name', name, '--skip_plot']
    # Use cwd= instead of os.chdir so the worker process's own working directory is untouched.
    status = subprocess.run(cmd, cwd = parent_dir, env = env).returncode
    if status != 0:
        print('[domino] {} failed with exit status {}'.format(name, status), file = sys.stderr)
    return status
    
# Dispatch table: --script value -> step function, invoked as fn(name, gpu_id) by run().
SCRIPTS = {
    'setup': setup,
    'blindspots': blindspots,
    'spotlight': spotlight,
    'barlow': barlow,
    'domino': domino,
}

# Run
# One dataset name per identifier (e.g. 'complex-1'); each becomes one pool task.
names = [mode + '-' + str(i) for i in identifiers]

def init(queue):
    """Pool initializer: bind this worker to one GPU id taken from *queue*.

    The id is stored in the module-level ``idx``, which ``run`` reads. In this
    script the queue is filled with one id per pool worker, so a worker started
    after every id has been claimed would block here waiting for another.
    """
    global idx
    claimed_gpu = queue.get()
    idx = claimed_gpu

def run(name):
    """Pool task: run the selected step (``SCRIPTS[script]``) for dataset *name*.

    Uses the GPU id this worker claimed in ``init`` (the module-level ``idx``).

    Returns:
        Whatever the step function returns.
    """
    # Flush so the progress line is not held in this worker's stdout buffer and
    # printed after the output of the child commands the step launches.
    print(name, idx, flush = True)
    return SCRIPTS[script](name, idx)

# Guarded so that start methods which re-import this module in every worker
# (e.g. 'spawn') do not recursively create pools.
if __name__ == '__main__':
    # Fail fast in the parent. An unknown --script (the default is '') would
    # otherwise raise a KeyError inside every worker.
    if script not in SCRIPTS:
        parser.error('--script must be one of: {} (got {!r})'.format(', '.join(SCRIPTS), script))

    with multiprocessing.Manager() as manager:
        # One GPU id per worker; each worker claims exactly one id in `init`.
        idQueue = manager.Queue()
        for i in gpu_ids:
            idQueue.put(i)
        # The context managers tear down the pool, then the manager process, once all tasks have finished.
        with multiprocessing.Pool(len(gpu_ids), init, (idQueue,)) as p:
            # chunksize=1: hand out one dataset at a time so long runs spread across GPUs.
            p.map(run, names, chunksize = 1)
