import shlex
import subprocess
import sys


# Experiment sweep: launches `python main_block.py` once for every
# (dataset, skip, non-linearity) combination, one run after another. A run that
# exits with a non-zero status raises subprocess.CalledProcessError, which
# aborts the remaining runs.


def _per_block(value, n_blocks):
    """Return ``value`` repeated ``n_blocks`` times, separated by spaces.

    Used for command-line options that take one value per block.
    """
    return " ".join([str(value)] * n_blocks)


for dataset in ["cifar", "imagenet"]:
    # Per-dataset settings. PATH/TO/IMAGENET is a literal placeholder that is
    # passed through to the command line as-is.
    data = dict(cifar="--cifar10", imagenet="--data PATH/TO/IMAGENET")[dataset]
    n_blocks = dict(cifar=8, imagenet=11)[dataset]
    epochs = dict(cifar=300, imagenet=150)[dataset]
    lr_adjust = dict(cifar=70, imagenet=45)[dataset]
    factorize_filters = dict(cifar=0, imagenet=1)[dataset]
    angle = dict(cifar="4", imagenet="8")[dataset]

    # Options shared by every run on this dataset.
    cmd = (
        f"{data} --n-blocks {n_blocks} -p 100 --scattering-wph {_per_block(1, n_blocks)}"
        f" --epochs {epochs} --learning-rate-adjust-frequency {lr_adjust} --lr 0.01 -j 10"
        f" --weight-decay 0.0001 --batch-size 128 --factorize-filters {factorize_filters}"
        f" --scat-angles {_per_block(angle, n_blocks)}"
    )

    # Values for --Pc-size. The imagenet list wraps the common list with two
    # leading and one trailing entry; the resulting length equals n_blocks for
    # both datasets (8 and 11).
    proj_sizes = ["64", "128", "256", "512", "512", "512", "512", "512"]
    if dataset == "imagenet":
        proj_sizes = ["32", "64"] + proj_sizes + ["256"]
    Pc_sizes = "--Pc-size " + " ".join(proj_sizes)

    # Non-linearities to sweep, mapped to their extra (parameter, value)
    # options. This depends only on the dataset, so it is built once here
    # rather than once per `skip` value. "mod" is always run; the other three
    # are run on cifar only.
    nonlins = dict(mod=[])
    if dataset == "cifar":
        nonlins.update(cst=[("bias", 0.1)], ms=[("gain", 1.0), ("bias", 0.0)], mc=[])

    for skip in [False, True]:
        # skip=True drops `rho` from the architecture string and adds
        # `--psi-arch mod`.
        if skip:
            arch = "-a 'Fw Std Pc N' --psi-arch mod"
        else:
            arch = "-a 'Fw rho Std Pc N'"

        for nonlin, params in nonlins.items():
            # --scat-non-linearity-learned is 0 for "mod" and 1 for the others.
            learned = "0" if nonlin == "mod" else "1"
            non_lin_parts = [
                f"--scat-non-linearity {_per_block(nonlin, n_blocks)}",
                f"--scat-non-linearity-learned {_per_block(learned, n_blocks)}",
            ]
            non_lin_parts.extend(
                f"--scat-non-linearity-{param} {_per_block(value, n_blocks)}"
                for param, value in params
            )
            non_lin_str = " ".join(non_lin_parts)

            name = f"{dataset}-{nonlin}-skip-{skip}"
            to_run = f"python main_block.py {cmd} {arch} {Pc_sizes} {non_lin_str} --dir {name}"

            # shlex.split keeps the quoted architecture string as a single
            # argv entry.
            try:
                subprocess.run(shlex.split(to_run), check=True, capture_output=True)
            except subprocess.CalledProcessError as err:
                # The child's output is captured, so without this the only
                # trace of a failed run would be its exit status.
                if err.stderr:
                    sys.stderr.write(err.stderr.decode(errors="replace"))
                raise
