# Usage (run detached in the background, stdout+stderr to run_hmc_aldp.log):
# nohup python -u run_hmc_aldp.py > run_hmc_aldp.log 2>&1 &

import os
import itertools
from run_utils import run_jobs

# --- Configuration ---
# GPU ids handed to run_jobs (presumably CUDA device indices — confirm in run_utils).
GPUS = [3]

# Logger backend name passed through to run_jobs.
LOGGER = "wandb"

# How many times every hyperparameter combination is queued (independent repeats).
NUM_FOLDS = 2

# Output directory handed to run_jobs. Defaults to the original location; set the
# SAMPLING_OUTPUT_DIR environment variable to run the sweep on another machine/user.
OUTPUT_DIR = os.environ.get(
    "SAMPLING_OUTPUT_DIR",
    "/home/fran/work_fran/sampling/experiments/run/output",
)


# Each dict is one hyperparameter grid: keys are dotted config overrides
# (presumably Hydra-style, consumed by run_jobs) and values are the candidate
# settings. Every combination of one value per key becomes a job.
grid_list = [
    {
        "task": ["aldp_vacuum"],
        "task.n_samples_test": [90_000],
        "sampler": ["hmc"],
        "sampler.n_steps": [10_000, 100_000, 1_000_000],
        "sampler.params.step_size": [1e-6],
        "sampler.params.n_leapfrog_steps": [1, 3],
        "sampler.params.adaptation_rate": [0.05],
    }
]

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Delete any "*.log" files already sitting in OUTPUT_DIR before launching the
# sweep (presumably leftovers from earlier runs). Deletion is best-effort: a file
# that cannot be removed is reported and skipped instead of aborting the launch.
for filename in os.listdir(OUTPUT_DIR):
    if not filename.endswith(".log"):
        continue
    file_path = os.path.join(OUTPUT_DIR, filename)
    try:
        os.remove(file_path)
    except OSError as e:
        # Only filesystem errors (permissions, a directory named *.log, a file
        # removed concurrently, ...) are tolerated; anything else is a real bug.
        print(f"Could not remove file {file_path}: {e}")

# Expand each grid into the Cartesian product of its value lists and queue every
# resulting combination NUM_FOLDS times. Order is fold-major: all combinations
# for the first repeat, then all of them again for the next, and so on.
# An empty grid dict yields a single job with no overrides.
job_list = []
for grid in grid_list:
    keys = list(grid)
    # The combinations do not depend on the repeat index, so build them once.
    combos = [
        dict(zip(keys, point))
        for point in itertools.product(*(grid[key] for key in keys))
    ]
    for _ in range(NUM_FOLDS):
        # Copy per repeat so every queued job is its own dict, in case
        # run_jobs mutates job entries (unverified — see run_utils).
        job_list.extend(dict(combo) for combo in combos)

# Launch the sweep: hand the queued jobs to run_jobs together with the GPU ids,
# logger backend and output directory (how jobs are scheduled onto GPUS is
# decided inside run_utils).
run_jobs(job_list, GPUS, LOGGER, OUTPUT_DIR)