# nohup python -u run_pisde_gm_ablations.py > run_pisde_gm_ablations.log 2>&1 &

import os
import itertools
from run_utils import run_jobs

# --- Configuration ---
# Output directory: created (if missing) and purged of *.log files further down,
# then handed to run_jobs. Absolute, machine-specific path.
OUTPUT_DIR = "/home/fran/work_fran/sampling/experiments/run/output"

# GPU indices handed to run_jobs (presumably the devices the jobs are
# scheduled on) -- edit to match the host.
GPUS = list(range(5))

# Logger backend name handed to run_jobs.
LOGGER = "wandb"
# How many times every grid combination is repeated in the job list.
NUM_FOLDS = 3


# Settings shared by every ablation grid. Each key is an override name and each
# value is the list of settings to sweep over. The split into a head and a tail
# lets a grid insert extra keys (e.g. jump_ref_std) at the same position as a
# hand-written dict would.
_SHARED_HEAD = {
    "task": ["gm_16_40"],
    "logger.project": ["sampling_ablations"],
    "sampler": ["progressive_interpolation_sde"],
    "sampler.n_steps": [1000, 5000, 10000],
    "sampler.jump_prop": [0.75],
    "sampler.params.time_schedule.type": ["geometric"],
    "sampler.params.time_schedule.min_val": [0.01],
    "sampler.params.noise_schedule.base_noise_var": [0.1],
    "sampler.params.noise_schedule.type": ["constant"],
    "sampler.params.corrector_mode": ["mala"],
    "sampler.params.corrector_steps": [0],
    "sampler.params.corrector_adaptation_rate": [0.0],
    "sampler.params.corrector_step_size": [0.1],
}
_SHARED_TAIL = {
    "sampler.params.jump_beta_schedule.type": ["geometric"],
    "sampler.params.jump_beta_schedule.min_val": [0.001],
    "sampler.params.jump_beta_schedule.n_replicas": [10],
    "sampler.params.jump_adaptation_rate": [0.05],
}

# One dict per ablation. Overriding a key that already exists in the head keeps
# that key's original position (dict insertion order), so each merged dict has
# the same key order as a fully spelled-out one.
grid_list = [
    # Jump-proportion ablation.
    # NOTE(review): only this grid pins jump_ref_std; the other two grids fall
    # back to whatever default the sampler config uses -- confirm intended.
    {
        **_SHARED_HEAD,
        "sampler.n_steps": [10000],
        "sampler.jump_prop": [1.0, 0.9, 0.75, 0.5, 0.25, 0.1, 0.0],
        "sampler.params.jump_ref_std": [1.0],
        **_SHARED_TAIL,
    },
    # Corrector-steps ablation.
    {
        **_SHARED_HEAD,
        "sampler.params.corrector_steps": [0, 1],
        **_SHARED_TAIL,
    },
    # Noise-schedule ablation.
    # NOTE(review): its "constant" combinations are identical to the
    # corrector_steps=0 combinations of the previous grid, so those runs are
    # launched twice.
    {
        **_SHARED_HEAD,
        "sampler.params.noise_schedule.type": ["constant", "linear"],
        **_SHARED_TAIL,
    },
]

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Best-effort cleanup: delete *.log files already present in OUTPUT_DIR
# (presumably left over from a previous sweep) before new jobs start.
# A file that cannot be removed is reported and skipped; it does not abort
# the run. Only OSError is caught -- that is what os.remove raises for missing
# files, permission problems or directories -- so genuine bugs still surface.
for filename in os.listdir(OUTPUT_DIR):
    if not filename.endswith(".log"):
        continue
    file_path = os.path.join(OUTPUT_DIR, filename)
    try:
        os.remove(file_path)
    except OSError as e:
        print(f"Could not remove file {file_path}: {e}")

# Expand every grid into the Cartesian product of its value lists, repeating the
# whole expansion NUM_FOLDS times. The resulting order is grid by grid, then fold
# by fold, then itertools.product order (the last key varies fastest).
# An empty grid expands to a single job with no overrides, following product
# semantics, instead of failing on tuple unpacking.
# NOTE(review): all folds receive identical override dicts; anything that makes
# folds differ (e.g. seeding) must happen inside run_jobs -- confirm.
job_list = []
for grid in grid_list:
    # The key order and the value product do not depend on the fold, so they
    # are computed once per grid instead of once per fold.
    keys = list(grid)
    combos = list(itertools.product(*grid.values()))
    for _ in range(NUM_FOLDS):
        # Build fresh dicts for every fold so no two jobs share a mutable object.
        job_list.extend(dict(zip(keys, v)) for v in combos)

# Launch the sweep: hand the expanded job list, together with the GPU list,
# logger backend and output directory configured at the top, to the shared
# run_utils launcher.
run_jobs(job_list, GPUS, LOGGER, OUTPUT_DIR)