import itertools
import os
import subprocess
import time
from pathlib import Path

# ---------------------------------------------------------------------------
# Sweep configuration: for each dataset, every combination of the values
# below becomes one generated SLURM script / submitted job.
# NOTE(review): "imagenet" has no entry in get_dataset_config, so it is
# skipped with an "Unsupported dataset" message -- confirm this is intended.
# ---------------------------------------------------------------------------
datasets = [
    "cifar10",
    "cifar100",
    "tinyimagenet",
    "imagenet",
]
methods = ["random", "crest", "single_spread_bn"]  # --selection_method
seeds = [0, 1, 2]                                   # --seed
scales = [0.1]                                      # --noise_std
ratios = [0.0, 0.1, 0.3, 0.5]                       # --corrupt_ratio
learning_fracs = [0.1]                              # --train_frac
selects = [1]                                       # --select_every

# Output directories: generated batch scripts and the SLURM stdout/stderr logs.
Path("logs").mkdir(parents=True, exist_ok=True)
Path("slurm_scripts").mkdir(parents=True, exist_ok=True)

def get_dataset_config(dataset):
    """Look up the per-dataset training settings.

    Returns:
        A tuple ``(arch, check_thresh_factor, extra_params)``: the model
        architecture name, the value forwarded as ``--check_thresh_factor``,
        and extra raw ``train.py`` flags (possibly an empty string).

    Raises:
        ValueError: if *dataset* has no known configuration.
    """
    # Ordered (accepted names, settings) pairs; the first matching entry wins.
    known_configs = (
        (("cifar10",), ("resnet20", 0.05, "")),
        (("cifar100",), ("resnet18", 0.01, "--interval_mul=10")),
        (("tinyimagenet",), ("resnet50", 0.005, "--interval_mul=1")),
        (("mnist", "emnist"), ("lenet", 0.05, "")),
    )
    for names, config in known_configs:
        if dataset in names:
            return config
    raise ValueError(f"Unsupported dataset: {dataset}")

def create_slurm_script(script_path, job_name, dataset, method, seed, ratio, scale, select, arch, check_thresh_factor, extra_params, budget):
    """Write a SLURM batch script that runs one ``train.py`` job.

    The generated script loads the cluster modules, changes into
    ``coreset/`` under ``$SLURM_SUBMIT_DIR`` (aborting if that fails) and
    invokes ``train.py``. Job stdout/stderr go to ``logs/<job_name>.out`` and
    ``logs/<job_name>.err``.

    Args:
        script_path: Destination file (str or Path); overwritten if present.
        job_name: SLURM job name, also used for the log file names.
        dataset, method, seed, ratio, scale, select, arch,
        check_thresh_factor, budget: Forwarded to ``train.py`` as
            ``--dataset``, ``--selection_method``, ``--seed``,
            ``--corrupt_ratio``, ``--noise_std``, ``--select_every``,
            ``--arch``, ``--check_thresh_factor`` and ``--train_frac``.
        extra_params: Extra raw flags appended verbatim; may be empty.
    """
    train_args = [
        f"--selection_method={method}",
        "--warm_start_epochs=20",
        f"--seed={seed}",
        f"--dataset={dataset}",
        f"--arch={arch}",
        f"--check_thresh_factor={check_thresh_factor}",
        "--ensemble_num=4",
        f"--corrupt_ratio={ratio}",
        f"--noise_std={scale}",
        "--randomparse=True",
        f"--select_every={select}",
        f"--train_frac={budget}",
    ]
    # Only append extra flags when present, so the command never ends in a
    # dangling line continuation followed by a blank line.
    if extra_params:
        train_args.append(extra_params)
    # One flag per line, joined with bash line continuations.
    train_cmd = "python train.py " + " \\\n    ".join(train_args)

    slurm_script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --output=logs/{job_name}.out
#SBATCH --error=logs/{job_name}.err

module load gcc/12.3.0
module load cuda/12.1.1
module load rcac
module load conda

# Abort instead of running train.py from the wrong directory.
cd "$SLURM_SUBMIT_DIR/coreset" || exit 1

{train_cmd}
"""
    with open(script_path, "w", encoding="utf-8") as f:
        f.write(slurm_script)

def submit_job(script_path, *, nodes=1, gpus_per_node=1, time_limit="4:00:00", account="standby"):
    """Submit *script_path* to SLURM via ``sbatch``.

    Resources are requested on the sbatch command line. The defaults are one
    node, one GPU per node, a 4-hour limit and account ``standby``.

    A non-zero sbatch exit status is reported rather than silently ignored.
    It does not raise, so one rejected submission does not abort a sweep.

    Returns:
        The ``subprocess.CompletedProcess`` of the sbatch call, so callers
        can inspect ``returncode``.

    Raises:
        OSError (e.g. FileNotFoundError): if ``sbatch`` cannot be executed.
    """
    cmd = [
        "sbatch",
        f"--nodes={nodes}",
        f"--gpus-per-node={gpus_per_node}",
        f"--time={time_limit}",
        "-A", account,
        str(script_path),
    ]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print(f"sbatch failed (exit code {result.returncode}) for {script_path}")
    return result

# Iterate over parameter combinations. The dataset configuration is resolved
# once per dataset, so an unsupported dataset is reported a single time
# instead of once for every (seed, method, scale, ...) combination.
for dataset in datasets:
    try:
        arch, check_thresh_factor, extra_params = get_dataset_config(dataset)
    except ValueError as e:
        print(e)
        continue

    # itertools.product yields tuples in the same order as the equivalent
    # nested loops (the last iterable varies fastest).
    for seed, method, scale, ratio, budget, select in itertools.product(
        seeds, methods, scales, ratios, learning_fracs, selects
    ):
        job_name = f"{dataset}_{method}_seed{seed}_noise{scale}_ratio{ratio}_select{select}_budget{budget}"
        script_path = Path("slurm_scripts") / f"{job_name}.sub"

        create_slurm_script(
            script_path, job_name, dataset, method, seed,
            ratio, scale, select, arch,
            check_thresh_factor, extra_params, budget
        )
        submit_job(str(script_path))
        print(f"Submitted: {job_name}")
        # Pause between submissions (presumably to avoid flooding the
        # scheduler with back-to-back sbatch calls).
        time.sleep(1)
