# nohup python -u run_pisde_aldp.py > run_pisde_aldp.log 2>&1 &
# nice -n 19 python run_pisde_aldp.py > run_pisde_aldp.log 2>&1 &

import os
import itertools
from run_utils import run_jobs

# --- Configuration ---
GPUS = [3, 4, 5, 6]  # GPU ids handed to run_jobs

LOGGER = "wandb"
NUM_FOLDS = 3  # how many times the full sweep is repeated in job_list

OUTPUT_DIR = "/home/fran/work_fran/sampling/experiments/run/output"

# Sweep: sampler.n_steps (6 values) x sampler.jump_prop (2 values) = 12 configs.
# Every other entry is a single value held fixed (e.g. time_schedule.min_val
# stays at 0.1). Each list is one axis of the Cartesian product that is
# expanded into jobs further down; dict insertion order fixes the job order,
# so keep the key order stable.
_n_steps_jump_prop_sweep = {
    "task": ["aldp_vacuum"],
    "logger.project": ["sampling_new"],
    "task.n_samples_test": [90_000],
    "save_samples": [True],
    "n_chunks": [1],
    "sampler": ["progressive_interpolation_sde"],
    "sampler.save_diagnostics": [False],
    # Grid value is floor-divided by jump_beta_schedule.n_replicas when the
    # jobs are built below.
    "sampler.n_steps": [10_000, 50_000, 100_000, 500_000, 1_000_000, 1_500_000],
    "sampler.jump_prop": [0.5, 0.75],
    "sampler.params.time_schedule.type": ["geometric"],
    "sampler.params.time_schedule.min_val": [0.1],
    "sampler.params.noise_schedule.base_noise_var": [1e-4],
    "sampler.params.noise_schedule.type": ["constant"],
    "sampler.params.corrector_mode": ["mala"],
    "sampler.params.corrector_steps": [1],
    "sampler.params.corrector_adaptation_rate": [0.05],
    "sampler.params.corrector_step_size": [1e-6],
    "sampler.params.jump_ref_std": [1.0],
    "sampler.params.jump_beta_schedule.type": ["geometric"],
    "sampler.params.jump_beta_schedule.min_val": [0.1],
    "sampler.params.jump_beta_schedule.n_replicas": [5],
    "sampler.params.jump_adaptation_rate": [0.05],
}

grid_list = [_n_steps_jump_prop_sweep]

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Delete any *.log files already present in OUTPUT_DIR before launching.
# Best-effort: a file that cannot be removed is reported and skipped rather
# than aborting the launch.
for filename in os.listdir(OUTPUT_DIR):
    if not filename.endswith(".log"):
        continue
    file_path = os.path.join(OUTPUT_DIR, filename)
    try:
        os.remove(file_path)
    except OSError as e:
        # os.remove reports its failures (permissions, a concurrent delete,
        # a directory named "*.log", ...) as OSError subclasses; anything
        # else would be a genuine bug and is allowed to propagate.
        print(f"Could not remove file {file_path}: {e}")

# Expand every grid into concrete job configs: one dict per point of the
# Cartesian product of the grid's value lists, with the whole sweep repeated
# NUM_FOLDS times. Fold copies are identical configs.
# NOTE(review): any per-fold seed/offset would have to be applied by run_jobs;
# confirm.
job_list = []
for grid in grid_list:
    # Keys and value lists do not depend on the fold, so expand them once.
    keys, values = zip(*grid.items())
    grid_jobs = []
    for v in itertools.product(*values):
        combo = dict(zip(keys, v))
        # The grid's sampler.n_steps is floor-divided by the number of
        # jump_beta_schedule replicas before being handed to the sampler.
        # Presumably the grid value is a total budget shared across replicas;
        # confirm against the sampler.
        n_steps = combo["sampler.n_steps"]
        n_replicas = combo["sampler.params.jump_beta_schedule.n_replicas"]
        n_steps_pt, leftover = divmod(n_steps, n_replicas)
        if leftover:
            # Floor division silently shrinks the budget; make that visible.
            print(
                f"Warning: sampler.n_steps={n_steps} is not divisible by "
                f"n_replicas={n_replicas}; remainder {leftover} is dropped"
            )
        combo["sampler.n_steps"] = n_steps_pt
        grid_jobs.append(combo)
    # Same fold-major ordering as a per-fold rebuild. Copy each dict so every
    # job in job_list is an independent object.
    for _ in range(NUM_FOLDS):
        job_list.extend(dict(job) for job in grid_jobs)

# Launch: hand the expanded job list to run_jobs together with the GPU pool,
# the logger backend and the output directory (positional, same order).
run_jobs(job_list, GPUS, LOGGER, OUTPUT_DIR)