# Usage (run in the background, logging to a file):
#   nohup python -u run_pt_gm.py > run_pt_gm.log 2>&1 &

import os
import itertools
from run_utils import run_jobs

# --- Configuration ---
# Available GPUs; the list is passed straight to run_jobs.
GPUS = [5]

# Each grid is expanded into a full set of jobs this many times.
NUM_FOLDS = 3
# Logger selection forwarded to run_jobs.
LOGGER = "wandb"

# Directory that is created if missing, cleared of old *.log files, and then
# handed to run_jobs.
OUTPUT_DIR = "/home/fran/work_fran/sampling/experiments/run/output"


# Sweep definition: each dict maps a dotted option name to the list of values
# to try. Every grid is later expanded into the Cartesian product of its value
# lists, so key order here determines the order in which jobs are generated.
grid_list = [
    {
        # Targets and sampler selection.
        "task": [
            "gm_2_40",
            "gm_16_40",
            "gmnu_2_40",
            "gmnu_16_40",
        ],
        "sampler": ["parallel_tempering"],
        # Step counts as listed here; the job-building loop divides each one
        # by the number of replicas before the job is queued.
        "sampler.n_steps": [1_000, 5_000, 10_000, 50_000, 100_000],
        # Beta schedule options.
        "sampler.params.beta_schedule.type": ["geometric"],
        "sampler.params.beta_schedule.min_val": [0.001],
        "sampler.params.beta_schedule.n_replicas": [5, 10],
        "sampler.params.beta_schedule.optimize": [False, True],
        # Kernel and swap options (single value each, i.e. not swept).
        "sampler.params.kernel.name": ["mala"],
        "sampler.params.step_size": [0.1],
        "sampler.params.swap_mode": ["nrpt"],
        "sampler.params.swap_every": [1],
        "sampler.params.adaptation_rate": [0.05],
    },
]

# Make sure the output directory exists before it is scanned below.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Best-effort removal of the *.log files already present in OUTPUT_DIR. A file
# that cannot be deleted is reported and skipped instead of aborting the script.
for filename in os.listdir(OUTPUT_DIR):
    if not filename.endswith(".log"):
        continue
    file_path = os.path.join(OUTPUT_DIR, filename)
    try:
        os.remove(file_path)
    except OSError as e:
        # os.remove reports its failures (permission denied, path is a
        # directory, file already gone, ...) as OSError; anything else would
        # be a programming error and should not be hidden.
        print(f"Could not remove file {file_path}: {e}")

# Expand every grid into concrete job configurations: for each grid, the full
# Cartesian product of its value lists is emitted NUM_FOLDS times, in order.
job_list = []
for grid in grid_list:
    # The option names and their candidate values do not change between folds,
    # so split them once per grid rather than once per fold.
    keys = list(grid)
    value_lists = list(grid.values())
    for _ in range(NUM_FOLDS):
        for values in itertools.product(*value_lists):
            # A fresh dict per job, so no two entries of job_list share state.
            combo = dict(zip(keys, values))
            # Replace the listed step count by its integer share per replica.
            # NOTE(review): presumably this keeps the total number of steps
            # across replicas comparable between replica counts; confirm.
            combo["sampler.n_steps"] //= combo[
                "sampler.params.beta_schedule.n_replicas"
            ]
            job_list.append(combo)

# Hand the expanded jobs, the GPU list, the logger selection and the output
# directory to run_jobs (imported from run_utils).
run_jobs(job_list, GPUS, LOGGER, OUTPUT_DIR)