import glob
import subprocess
from multiprocessing import Pool
import sys
import os
from tqdm import tqdm
import argparse
from tqdm import tqdm


def run_process(cmd):
    """Run ``cmd`` as a child process and block until it terminates.

    Args:
        cmd: Argument list (program followed by its arguments). It is handed
            to ``subprocess`` unchanged and executed without a shell.

    Returns:
        The exit status of the child process, so that callers can tell a
        failed job from a successful one (``0`` conventionally means success).
    """
    # subprocess.call starts the process, waits for it to finish and returns
    # its exit status instead of discarding it.
    return subprocess.call(cmd)

# Root directory that contains one sub-directory per dataset.
data_folder = "./data/datasets/"


def list_env_ids(dataset_dir):
    """Return the sorted names of the directories found in ``dataset_dir/env``.

    Args:
        dataset_dir: Path of the dataset directory (the one holding ``env``).

    Returns:
        A sorted list of directory names; empty when nothing matches.
    """
    # The trailing empty component makes the pattern end with a separator,
    # which restricts the glob to directories.
    pattern = os.path.join(dataset_dir, "env", "*", "")
    return sorted(
        os.path.basename(os.path.dirname(match)) for match in glob.glob(pattern)
    )


def find_pending_goals(env_dir):
    """Return the sorted goal names in ``env_dir`` that still need a policy.

    A goal is pending when its directory contains ``data_with_reward.csv``
    but no ``Q_policy.csv`` yet. Entries whose own name contains ``"csv"``
    are skipped.

    Args:
        env_dir: Path of one env directory.

    Returns:
        A sorted list of goal directory names; empty when none is pending.
    """
    pending = []
    for path in sorted(glob.glob(os.path.join(env_dir, "*"))):
        goal = os.path.basename(path)
        # Test only the entry's own name: a parent directory (dataset or env)
        # containing "csv" must not hide every goal underneath it.
        if "csv" in goal:
            continue
        if not os.path.exists(os.path.join(path, "data_with_reward.csv")):
            continue
        if os.path.exists(os.path.join(path, "Q_policy.csv")):
            continue
        pending.append(goal)
    return pending


def build_commands(dataset_name, env_id, goals, n_devices=4):
    """Build one ``policy_improvement.py`` command line per goal.

    Args:
        dataset_name: Value passed as ``--dataset``.
        env_id: Value passed (as a string) as ``--env_id``.
        goals: Goal names; each one produces a command with ``--goal``.
        n_devices: Number of device indices to cycle through; the goal at
            position ``i`` gets ``--device`` set to ``i % n_devices``.

    Returns:
        A list of argument lists, in the same order as ``goals``.

    Raises:
        ValueError: If ``n_devices`` is smaller than 1.
    """
    if n_devices < 1:
        raise ValueError("n_devices must be at least 1")
    return [
        [
            "python",
            "source/policy/policy_improvement.py",
            "--dataset",
            dataset_name,
            "--goal",
            goal,
            "--env_id",
            str(env_id),
            "--device",
            str(index % n_devices),
        ]
        for index, goal in enumerate(goals)
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Launch policy improvement for every pending goal of a dataset"
    )
    parser.add_argument("--name", type=str, default="BabyAI", help="The name of the dataset")
    parser.add_argument(
        "--n_jobs",
        type=int,
        default=16,
        help="Maximum number of policy improvement processes running at once",
    )
    parser.add_argument(
        "--n_devices",
        type=int,
        default=4,
        help="Number of device indices the jobs are spread over (round-robin)",
    )
    args = parser.parse_args()

    dataset_name = args.name
    dataset_dir = os.path.join(data_folder, dataset_name)
    n_failed = 0

    # A single pool is reused for every env and torn down by the context
    # manager, so worker processes do not accumulate over the run.
    with Pool(args.n_jobs) as pool:
        for env_id in tqdm(list_env_ids(dataset_dir)):
            goals = find_pending_goals(os.path.join(dataset_dir, "env", env_id))
            if not goals:
                continue

            cmds = build_commands(dataset_name, env_id, goals, n_devices=args.n_devices)
            results = pool.imap_unordered(run_process, cmds)
            for status in tqdm(results, total=len(cmds), leave=False):
                # A truthy result is counted as a non-zero exit status; a
                # run_process that returns nothing never trips this check.
                if status:
                    n_failed += 1

    if n_failed:
        print(
            f"{n_failed} policy improvement job(s) exited with a non-zero status",
            file=sys.stderr,
        )
