import os
import sys
import subprocess
import time
import logging
import torch

# Shared settings used by main() when preparing output folders and launching workers.
root_dir = "generated_outputs"  # parent directory that receives the folder_N subdirectories
num_folders = 200               # how many folder_N subdirectories are created / passed to workers
stop_file = "stop.txt"          # main() polls for this file; its presence triggers shutdown

# Send INFO-and-above records both to logs/main.log and to the console.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    handlers=[logging.FileHandler("logs/main.log"), logging.StreamHandler()],
    format='%(asctime)s [%(levelname)s] [Main] %(message)s',
    level=logging.INFO,
)

def prepare_folders(root_dir, num_folders=200):
    """Ensure ``root_dir`` and its subfolders ``folder_1`` .. ``folder_<num_folders>`` exist.

    Directories that already exist are left as they are (``exist_ok=True``),
    so calling this repeatedly is harmless.

    Args:
        root_dir: Parent directory; created if missing.
        num_folders: Number of numbered subfolders to create (1-based names).
    """
    os.makedirs(root_dir, exist_ok=True)
    for number in range(1, num_folders + 1):
        subfolder = os.path.join(root_dir, f"folder_{number}")
        os.makedirs(subfolder, exist_ok=True)
    logging.info(f"Prepared {num_folders} folders in '{root_dir}'.")

def _visible_device_ids(num_gpus):
    """Return the CUDA device identifiers to hand to workers 0..num_gpus-1.

    If the parent already restricts GPUs via CUDA_VISIBLE_DEVICES (e.g. "2,3"),
    ``torch.cuda.device_count()`` counts only those devices, so worker ``i``
    must receive the i-th *listed* id rather than the bare index ``i`` (which
    would refer to a different physical GPU).
    """
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible:
        ids = [dev.strip() for dev in visible.split(',') if dev.strip()]
        if len(ids) >= num_gpus:
            return ids[:num_gpus]
    return [str(i) for i in range(num_gpus)]


def _start_worker(jsonl_file, worker_index, num_workers, device_id, use_original):
    """Launch one inference.py subprocess restricted to GPU ``device_id``."""
    env = {**os.environ, 'CUDA_VISIBLE_DEVICES': device_id}
    # The interpreter can be overridden per deployment; the default keeps the
    # original hard-coded path.
    python_exe = os.environ.get('INFERENCE_PYTHON', '/usr/local/bin/python3.11')
    cmd = [
        python_exe, './inference.py',
        jsonl_file,             # jsonl_file
        root_dir,               # root_dir
        str(num_folders),       # num_folders
        str(num_workers),       # num_workers
        str(worker_index),      # worker_index
    ]
    if use_original:
        cmd.append("original")
    return subprocess.Popen(cmd, env=env)


def _wait_for_stop(processes, poll_interval=1.0):
    """Block until the stop file appears, Ctrl+C is pressed, or every worker has exited."""
    try:
        while True:
            time.sleep(poll_interval)
            if os.path.exists(stop_file):
                logging.info("Main process: Stop file detected. Initiating shutdown.")
                return
            # Without this check the supervisor would sleep forever after the
            # workers finish (or all crash).
            if all(p.poll() is not None for p in processes):
                logging.info("Main process: All workers have finished. Initiating shutdown.")
                return
    except KeyboardInterrupt:
        logging.info("Main process: KeyboardInterrupt detected. Initiating shutdown.")


def _shutdown_workers(processes, timeout=30.0):
    """Terminate still-running workers and reap all of them.

    SIGTERM is sent to every worker first so they shut down in parallel.
    Any worker still alive ``timeout`` seconds later is killed, so a worker
    that ignores SIGTERM cannot hang the main process.
    """
    for p in processes:
        if p.poll() is None:
            p.terminate()
    deadline = time.monotonic() + timeout
    for p in processes:
        try:
            p.wait(timeout=max(0.0, deadline - time.monotonic()))
        except subprocess.TimeoutExpired:
            logging.warning(f"Worker with PID {p.pid} ignored SIGTERM; killing it.")
            p.kill()
            p.wait()
        logging.info(f"Worker with PID {p.pid} has exited.")


def main():
    """Launch one inference worker per visible GPU and supervise them.

    Command line: ``<jsonl_file> [original]``. If the second argument is the
    literal ``original``, it is forwarded to every worker.

    Workers run until the stop file appears, Ctrl+C is pressed, or all of them
    exit on their own. They are then terminated (killed if unresponsive)
    and reaped. The process exits with status 1 on bad usage or when no CUDA
    device is available.
    """
    if len(sys.argv) < 2:
        print("Usage: python main.py <jsonl_file> [original]", file=sys.stderr)
        sys.exit(1)
    jsonl_file = sys.argv[1]
    use_original = len(sys.argv) > 2 and sys.argv[2] == "original"

    prepare_folders(root_dir, num_folders)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Previously this started zero workers and then waited forever.
        logging.error("No CUDA devices found; no workers to start. Exiting.")
        sys.exit(1)
    num_workers = num_gpus  # one worker per GPU
    device_ids = _visible_device_ids(num_gpus)

    processes = []
    try:
        for worker_index, device_id in enumerate(device_ids):
            p = _start_worker(jsonl_file, worker_index, num_workers, device_id, use_original)
            processes.append(p)
            logging.info(f"Started worker {worker_index} with PID {p.pid}")

        logging.info("All workers started. Press Ctrl+C to stop or create 'stop.txt' file.")
        _wait_for_stop(processes)
    finally:
        # Runs even if a Popen call fails mid-startup, so already-launched
        # workers are never orphaned.
        _shutdown_workers(processes)

    logging.info("All workers have exited. Main process terminating.")

if __name__ == '__main__':
    main()
