import os,glob,threading,subprocess
from models_and_datas import models_and_datas
# os.environ['HF_HOME'] = "../llama_on_glue/checkpoints"
from huggingface_hub import login
# #
import time

# Number of GPU slots that jobs are spread across; slot k is addressed as "cuda:k".
gpus = 8
# Passed to run.py through --output_folder.
output_folder = "output/test0123"

# Slot index -> worker thread most recently started on that slot
# (None until the slot is used for the first time).
threads = dict.fromkeys(range(gpus))

def simple_run(bs, method, device, extra = ''):
    """Run ``run.py`` once in a shell subprocess and wait for it to finish.

    Args:
        bs: Iterable of branch names; joined with spaces after ``--branches``.
        method: Value passed after ``--method``.
        device: Value passed after ``--device`` (e.g. ``'cuda:0'``).
        extra: Raw text appended to the command line before the output-folder
            and base-model options (e.g. ``'  --no_eval  '``).

    Returns:
        The exit status of the child process (0 on success).
    """
    cmd = f' python run.py --device {device} --branches {" ".join(bs)} --method {method} '
    cmd += extra
    cmd += f'  --output_folder={output_folder}  --base_model={models_and_datas["base"]["model"][0]}'
    # With shell=True the command must be a single string: wrapping it in a
    # one-element list only happens to work on POSIX and breaks on Windows.
    status = subprocess.call(cmd, shell=True)
    if status != 0:
        # Jobs run unattended in worker threads, so make failures visible
        # instead of silently dropping the exit status.
        print(f'[simple_run] command exited with status {status}: {cmd}')
    return status

def get_running(target, args, poll_interval=1.0):
    """Start ``target`` in a daemon thread on the first idle GPU slot.

    Slots are scanned in index order. A slot is idle when it has never been
    used or when its last thread has finished. If every slot is busy, the
    call blocks, re-checking every ``poll_interval`` seconds, until one
    becomes free.

    Args:
        target: Callable to run in the new thread.
        args: Positional arguments for ``target``; the chosen device string
            ``'cuda:<slot>'`` is appended as the last argument.
        poll_interval: Seconds to sleep between scans while all slots are
            busy, so that waiting does not spin a CPU core.
    """
    while True:
        for k in range(gpus):
            worker = threads[k]
            if worker is None or not worker.is_alive():
                # daemon=True replaces the deprecated Thread.setDaemon(True).
                worker = threading.Thread(
                    target=target,
                    args=(*args, f'cuda:{k}'),
                    daemon=True,
                )
                threads[k] = worker
                worker.start()
                return
        # Every slot is occupied: back off briefly before scanning again.
        time.sleep(poll_interval)

# ---------------------------------------------------------------------------
# Settings. Assignments further down override the ones above them, so the
# active configuration is the last uncommented `branches` / `methods` line.
# ---------------------------------------------------------------------------
branches = list(models_and_datas)[1:]  # every key except the first one
# branches = ['guard','math','italian','code']
# branches = ['math','code']
branches = ['qnli','cola']
methods = ['ties', 'dare_ties', 'linear', 'slerp', 'dare_linear']
# methods = [ 'dare_ties', 'linear', 'slerp', 'dare_linear']
methods = ['ties']
# Sort both lists so jobs are generated in a deterministic order.
branches.sort()
methods.sort()



# Note: on an H20, most of the merging can be done on the GPU.

# for i in range(len(branches)):
#     for m in ['simple']:
#         get_running(simple_run, (branches[i:i+1],m,) )

# for k in range(gpus):
#     t = threads[k]
#     if t is not None and t.is_alive():
#         t.join()

# for i in range(len(branches)):
#     for j in range(i+1,len(branches)):
#         for m in methods:
#             simple_run([branches[i],branches[j]],m,'cuda:0','  --no_eval  ')        

# Schedule one job for every unordered pair of branches and every method;
# get_running blocks until a GPU slot is available for each job.
for idx, first in enumerate(branches):
    for second in branches[idx + 1:]:
        for m in methods:
            get_running(simple_run, ([first, second], m))

# The worker threads are daemonic, so wait for those still running before
# the script is allowed to exit.
for worker in threads.values():
    if worker is not None and worker.is_alive():
        worker.join()
