import os
import threading


# Run configuration. The "xxx" defaults are placeholders: either edit them here
# or supply the values through the environment variables named below, so
# different runs can be launched without modifying this file.
load_repre_dir = os.environ.get("LOAD_REPRE_DIR", "xxx")  # passed as --load_repre_dir
save_model_dir = os.environ.get("SAVE_MODEL_DIR", "xxx")  # passed as --checkpoint_path
remark = os.environ.get("BENCH_REMARK", "large_scale_collect")  # embedded in --remark

# GPU indices that tasks are distributed over (round-robin by the launcher).
cuda_devices = [0, 1]

def one_few_shot_test(task, sparse_loss_coef, cuda_device):
    """Run ``../src/main.py`` once for a single task on a single GPU.

    The child process sees only ``cuda_device`` through
    ``CUDA_VISIBLE_DEVICES``. The variable is set in the child's environment
    only, so this process and sibling threads are unaffected.

    Args:
        task: Value passed to ``main.py`` as ``--task_id``.
        sparse_loss_coef: Value passed as ``--sparse_loss_coef``; also
            embedded in the ``--remark`` string.
        cuda_device: GPU index exposed to the child process.

    Returns:
        The child's exit code (0 on success). A failed launch is reported on
        stdout and returns 127 instead of raising, so one bad run does not
        take down the launcher thread.
    """
    import subprocess

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and avoids the POSIX-only `export ...;` prefix.
    cmd = [
        "python",
        "../src/main.py",
        "--config=ts_collect_run",
        "--env-config=dense_p2_navigation",
        f"--remark=bench_test_{sparse_loss_coef}_{remark}",
        f"--checkpoint_path={save_model_dir}",
        f"--load_repre_dir={load_repre_dir}",
        "--task-config=dense_navigation_mpe",
        f"--task_id={task}",
        "--collect_src=False",
        f"--sparse_loss_coef={sparse_loss_coef}",
    ]
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(cuda_device))

    try:
        returncode = subprocess.run(cmd, env=env).returncode
    except OSError as err:
        # Mirror the old best-effort behavior: report and carry on.
        print(f"[task {task}] failed to launch main.py: {err}", flush=True)
        return 127

    if returncode != 0:
        print(f"[task {task}] main.py exited with code {returncode}", flush=True)
    return returncode

if __name__ == "__main__":
    # Start one evaluation per task concurrently. Tasks are assigned to the
    # configured GPUs round-robin; then wait until every run has finished.
    sparse_loss_coef = 0
    train_tasks = range(8, 12)
    n_devices = len(cuda_devices)

    threads = []
    for idx, task_id in enumerate(train_tasks):
        worker = threading.Thread(
            target=one_few_shot_test,
            args=(task_id, sparse_loss_coef, cuda_devices[idx % n_devices]),
        )
        threads.append(worker)
        worker.start()

    for worker in threads:
        worker.join()