# Execute experiments over a grid of parameter values, launching each run in its own subprocess.
import subprocess
import os

# Run kWTA experiments: one basis_experiments.py invocation per model and
# (number of basis dims, number of bins) pair; each run gets its own JSON path.
# NOTE(review): the positional 0.5 is passed straight through to
# basis_experiments.py — its meaning is defined there; confirm before changing.
kwta_output_dir = "autoattack_kwta_output_jsons"
os.makedirs(kwta_output_dir, exist_ok=True)
for model_file in ["kWTA_models/spresnet18_0.1_cifar.pth"]:
    for dims, bins in [(1, 1001), (2, 51), (3, 21), (4, 11), (5, 9), (6, 9)]:
        output_json = f"{kwta_output_dir}/{dims}dims_{bins}bins.json"
        command = [
            "python", "code/basis_experiments.py", "cifar10",
            model_file, str(0.5), output_json,
            "--skip", "20",
            "--batch", "800",
            "--num_bins", str(bins),
            "--num_basis", str(dims),
        ]
        print(command)
        # Raises CalledProcessError if the child exits non-zero, aborting the sweep.
        subprocess.check_output(command)

# Run Deterministic Smoothing experiments over the same (dims, bins) grid.
# This group additionally passes --num_corrupt 100.
smooth_output_dir = "autoattack_smooth_output_jsons"
os.makedirs(smooth_output_dir, exist_ok=True)
for model_file in ["models/cifar10/resnet110/noise_0.50/checkpoint.pth.tar"]:
    for dims, bins in [(1, 1001), (2, 51), (3, 21), (4, 11), (5, 9), (6, 9)]:
        # Positional arguments: dataset, model checkpoint, 0.5, output JSON path.
        command = ["python", "code/basis_experiments.py", "cifar10", model_file, str(0.5)]
        command.append(f"{smooth_output_dir}/{dims}dims_{bins}bins.json")
        command += ["--skip", "20", "--batch", "800"]
        command += ["--num_bins", str(bins), "--num_basis", str(dims)]
        command += ["--num_corrupt", "100"]
        print(command)
        # Raises CalledProcessError if the child exits non-zero, aborting the sweep.
        subprocess.check_output(command)

# Run EBM-Def experiments: same (dims, bins) grid, with --jem_defense 1 and a
# smaller batch size (80) than the kWTA/smoothing groups use.
ebm_output_dir = "autoattack_ebm_output_jsons"
os.makedirs(ebm_output_dir, exist_ok=True)
for model_file in ["jem_models/CIFAR10_MODEL.pt"]:
    for dims, bins in [(1, 1001), (2, 51), (3, 21), (4, 11), (5, 9), (6, 9)]:
        output_json = f"{ebm_output_dir}/{dims}dims_{bins}bins.json"
        positional = ["python", "code/basis_experiments.py", "cifar10",
                      model_file, str(0.5), output_json]
        options = ["--skip", "20", "--batch", "80",
                   "--num_bins", str(bins), "--num_basis", str(dims),
                   "--jem_defense", "1"]
        command = positional + options
        print(command)
        # Raises CalledProcessError if the child exits non-zero, aborting the sweep.
        subprocess.check_output(command)


# Run DiffSmall experiments: the diffusion defense (--diffusion_defense)
# evaluated for every timestep respacing below and every (dims, bins) pair.
# Each entry pairs the value passed to --timestep_respacing with a short name
# that is appended to the output JSON filename.
# NOTE(review): the "1steps" entry ends at 3999 while the others end at 500 —
# confirm this is intended.
timestep_respacings = [
    ("[500,3999]", "1steps"),
    ("[250,500]", "2steps"),
    ("[166,333,500]", "3steps"),
    ("[100,200,300,400,500]", "5steps"),
    ("[50,100,150,200,250,300,350,400,450,500]", "10steps"),
]

# Create the output directory up front, as the other experiment groups do;
# without it, writing the first result JSON can fail.
os.makedirs("autoattack_diff_output_jsons", exist_ok=True)
for model_file in ["kWTA_models/resnet18_cifar.pth"]:
    for timestep_respacing_list, timestep_name in timestep_respacings:
        for dims, bins in [(1, 1001), (2, 51), (3, 21), (4, 11), (5, 9), (6, 9)]:
            # Only add the "_<name>" suffix when a respacing name is given.
            suffix = "_" + timestep_name if timestep_name else ""
            command = [
                "python", "code/basis_experiments.py", "cifar10",
                model_file, str(0.5),
                "autoattack_diff_output_jsons/%ddims_%dbins%s.json" % (dims, bins, suffix),
                "--skip", "100", "--batch", "4", "--num_bins", str(bins),
                "--num_basis", str(dims), "--diffusion_defense",
            ]
            # Strip spaces so the list-literal string arrives as a single argv token.
            command.extend(["--timestep_respacing", timestep_respacing_list.replace(" ", "")])
            print(command)
            # NOTE(review): check_output captures the child's stdout and the return
            # value is discarded here, so progress output is not shown; a non-zero
            # exit raises CalledProcessError and aborts the sweep.
            subprocess.check_output(command)