import itertools
import os

# ---------------------------------------------------------------------------
# Sweep definition
# ---------------------------------------------------------------------------

# Executable (plus script path) that every generated command line starts with.
base_command = "python scripts/sweep_cli_entry.py"

# Flags emitted for every run, in this order. The entries set to None are
# placeholders that reserve a position in the flag list; each of them is also
# a key of `param_grid`, so it is overwritten before a command is rendered.
fixed_params = dict(
    vq_decay=0.95,
    commit_weight=2.0,
    threshold_ema_dead_code=0.008,
    model="vqvae",
    codebook_size=512,
    codebook_dim=-1,
    batch_size=256,
    gpu=0,
    dataset=None,
    arch=None,
    dim_per_quantizer=None,
    num_quantizers=None,
    split_vq=None,
)

# Swept flags and the values each one takes. "imagenet" is intentionally left
# out of `dataset`: it was too large to train with this script.
param_grid = dict(
    dataset=["cifar10", "celeba"],
    arch=["taming", "enhancing"],
    num_quantizers=[4, 8, 16],
    dim_per_quantizer=[8, 32],
    split_vq=[True, False],
)

# ---------------------------------------------------------------------------
# Grid expansion: one dict per point of the Cartesian product
# ---------------------------------------------------------------------------
keys = tuple(param_grid)
values = tuple(param_grid.values())
combinations = [
    dict(zip(keys, chosen)) for chosen in itertools.product(*values)
]


# ---------------------------------------------------------------------------
# Command rendering
# ---------------------------------------------------------------------------
def _render_command(overrides):
    """Return the full command line for one grid point.

    `overrides` maps swept flag names to the values chosen for this run; they
    replace the matching entries of `fixed_params` without changing the
    position of those flags in the output.
    """
    merged = {**fixed_params, **overrides}
    if merged["dataset"] == "imagenet":
        # To avoid OOM on imagenet, shrink the batch and compensate with
        # gradient accumulation. Not triggered while imagenet is absent from
        # `param_grid`.
        merged["batch_size"] = 32
        merged["n_accumulate"] = 8
    flags = (f"--{k}={v}" for k, v in merged.items())
    return " ".join([base_command, *flags])


commands = [_render_command(point) for point in combinations]

# ---------------------------------------------------------------------------
# Output: one single-line shell script per command
# ---------------------------------------------------------------------------
# Commands are dealt out in blocks of N_GPU: position i lands in
# config_<i // N_GPU>_<i % N_GPU>.sh and is prefixed with
# CUDA_VISIBLE_DEVICES=<i % N_GPU>.
N_GPU = 8
os.makedirs("src/configs/all_configs", exist_ok=True)
for i, cmd in enumerate(commands):
    node_id, gpu_id = divmod(i, N_GPU)
    script_path = f"src/configs/all_configs/config_{node_id}_{gpu_id}.sh"
    script_line = f"CUDA_VISIBLE_DEVICES={gpu_id} {cmd}\n"
    with open(script_path, "w") as script:
        script.write(script_line)