# nohup python -u run_digs_mlp_posterior.py > run_digs_mlp_posterior.log 2>&1 &

import os
import itertools
from run_utils import run_jobs

# --- Configuration ---
# GPU device ids handed to run_jobs.
GPUS = [7]

# Logger name handed to run_jobs.
LOGGER = "wandb"

# How many times every grid combination is repeated in the job list.
NUM_FOLDS = 3

# Output directory handed to run_jobs; leftover *.log files in it are removed
# when this script starts.
OUTPUT_DIR = "/home/fran/work_fran/sampling/experiments/run/output"

# Each entry maps a dotted config key to the list of values to try. The script
# expands every entry into the Cartesian product of its value lists, so the
# key order below determines the order of the generated jobs.
grid_list = [
    # DiGS on the MLP-posterior task. Only the step budget, the number of
    # noise levels and the number of denoising steps take several values here;
    # every other key is pinned to a single value.
    {
        "task": ["mlp_posterior"],
        "sampler": ["digs"],
        "sampler.n_steps": [1_000, 5_000, 10_000, 50_000, 100_000],
        "sampler.params.noise_schedule.alpha_min": [0.1],
        "sampler.params.noise_schedule.alpha_max": [0.9],
        "sampler.params.noise_schedule.n_noise_levels": [1, 5],
        "sampler.params.n_denoising_steps": [1, 4],
        "sampler.params.denoising_kernel.name": ["mala"],
        "sampler.params.denoising_step_size": [1e-5],
    },
]

# Make sure the output directory exists before anything is listed or written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Remove existing *.log files in OUTPUT_DIR so a new launch starts clean.
# Removal is best-effort: an entry that cannot be deleted is reported and
# skipped instead of aborting the launch.
for filename in os.listdir(OUTPUT_DIR):
    if not filename.endswith(".log"):
        continue
    file_path = os.path.join(OUTPUT_DIR, filename)
    try:
        os.remove(file_path)
    except OSError as e:
        # os.remove signals filesystem failures (permissions, directory,
        # entry vanished) as OSError; catching only that keeps unrelated
        # programming errors visible instead of printing them away.
        print(f"Could not remove file {file_path}: {e}")

# Expand every grid into concrete job configs, repeated NUM_FOLDS times.
# Each fold builds fresh dicts, so no job object is shared between folds.
job_list = []
for grid in grid_list:
    for _fold in range(NUM_FOLDS):
        # Cartesian product over the per-key value lists of this grid.
        param_names, param_options = zip(*grid.items())
        for choice in itertools.product(*param_options):
            job = dict(zip(param_names, choice))

            requested_steps = job["sampler.n_steps"]
            noise_levels = job["sampler.params.noise_schedule.n_noise_levels"]
            denoise_steps = job["sampler.params.n_denoising_steps"]

            # Rescale the requested step count by (1 + n_denoising_steps), then
            # split the result evenly over the noise levels to get the sweep
            # count. NOTE(review): presumably this keeps the total amount of
            # work comparable across settings -- confirm against the sampler.
            digs_steps = requested_steps // (1 + denoise_steps)
            job.update(
                {
                    "sampler.params.n_gibbs_sweeps": digs_steps // noise_levels,
                    "sampler.n_steps": digs_steps,
                }
            )
            job_list.append(job)

# Hand the expanded job list to run_jobs together with the GPU ids, the logger
# name and the output directory defined in the configuration section.
run_jobs(job_list, GPUS, LOGGER, OUTPUT_DIR)