import os
import subprocess
import threading

# Experiment settings. The "xxx" values look like placeholders that must be
# filled in before running.
load_repre_dir = "xxx"  # forwarded to main.py as --load_repre_dir
save_model_dir = "xxx"  # forwarded to main.py as --checkpoint_path
# NOTE(review): `remark` and `use_dense` are not referenced by the command
# this script builds — confirm whether they are still needed.
remark = "model_error"
use_dense = True

# GPU indices; test tasks are assigned to them round-robin by position.
cuda_devices = [0, 1]

# Sub-directories of `res_direc` that hold per-task representations.
item_list = ["xxx", "xxx", "xxx"]

# Path components (below "<two levels above this file>/collect") that
# identify the collected run to evaluate against.
configs = ["dense_navigation", "0-1-2-3-4-5-6-7"]
res_direc = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "collect",
    *configs,
)

def one_few_shot_test(train_task, task, target_repre_path, cuda_device):
    """Run one model-error evaluation of ``train_task`` against test ``task``.

    Launches ``python ../src/main.py`` in a child process with
    ``CUDA_VISIBLE_DEVICES`` set to ``cuda_device`` and blocks until it exits.
    ``../src/main.py`` is resolved against the current working directory.

    Args:
        train_task: Task id passed as ``--task_id``; also part of ``--remark``.
        task: Test task id; used only to build ``--remark``.
        target_repre_path: Path passed as ``--target_repre_path``.
        cuda_device: GPU index exported as ``CUDA_VISIBLE_DEVICES``.

    Returns:
        The child process's exit code (0 on success).
    """
    # Pass arguments as a list (no shell): paths containing spaces or shell
    # metacharacters reach main.py intact instead of being split or evaluated.
    cmd = [
        "python",
        "../src/main.py",
        "--config=ts_collect_run",
        "--env-config=dense_p2_navigation",
        f"--remark=model_error_test_{task}_{train_task}",
        f"--checkpoint_path={save_model_dir}",
        f"--load_repre_dir={load_repre_dir}",
        "--task-config=dense_navigation_mpe",
        f"--task_id={train_task}",
        "--collect_src=False",
        "--compute_model_error=True",
        f"--target_repre_path={target_repre_path}",
        "--test_nepisode=5",
    ]
    # Equivalent of `export CUDA_VISIBLE_DEVICES=...`, scoped to the child only.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(cuda_device))
    result = subprocess.run(cmd, env=env, check=False)
    if result.returncode != 0:
        # Surface failures instead of silently dropping the exit status.
        print(
            f"[one_few_shot_test] train_task={train_task} task={task} "
            f"device={cuda_device} exited with code {result.returncode}",
            flush=True,
        )
    return result.returncode

if __name__ == "__main__":
    # For each training task, launch one evaluation thread per
    # (test task, item) pair, then wait for that whole batch to finish
    # before moving on to the next training task.
    train_tasks = range(8)
    test_tasks = range(8, 12)
    n_devices = len(cuda_devices)
    for train_task in train_tasks:
        batch = []
        for slot, test_task in enumerate(test_tasks):
            # All items of one test task share the same GPU; test tasks
            # rotate through `cuda_devices` by position.
            device = cuda_devices[slot % n_devices]
            for item in item_list:
                repre_path = os.path.join(
                    res_direc, item, str(test_task), "task_repre"
                )
                worker = threading.Thread(
                    target=one_few_shot_test,
                    args=(train_task, test_task, repre_path, device),
                )
                worker.start()
                batch.append(worker)

        for worker in batch:
            worker.join()