from pathlib import Path
import h5py

# Stem of the config file given on the command line (e.g. "foo" for
# "foo.yaml"), or "default" when none was given. It is embedded in the
# run output filenames below.
# NOTE(review): newer Snakemake releases may expose this as
# `workflow.overwrite_configfiles` (a list) — confirm for the version in use.
conf = workflow.overwrite_configfile
config_stem = "default" if conf is None else Path(conf).stem

# Top-level target: one trained model per noisy realisation {i} and per
# repetition {rep}. The number of realisations defaults to 3 and can be
# overridden with the optional config key "n_realizations".
rule train_all:
    input:
        expand(
            f"runs/{config['dataset']}_sigma{config['noise_sigma']}_{{i}}"
            f"/run_{config_stem}_rep{{rep}}.h5",
            i=range(config.get("n_realizations", 3)),
            rep=range(config["reps"]),
        )

# Train one model on the patchified data of a single noisy realisation.
# The output name embeds config_stem, so runs launched with different config
# files write to different files; {rep} distinguishes repeated runs.
rule train:
    input:
        f"runs/{{img}}_sigma{{sigma}}_{{i}}/data_patchsize{config['patch_size']}.h5"
    output:
        "runs/{img}_sigma{sigma}_{i}/run_" + config_stem + "_rep{rep}.h5"
    params:
        # Passed to --net-shape as "<patch_size**2>:<H1>:<H0>".
        shape=":".join(map(str, (config["patch_size"]**2, config["H1"], config["H0"])))
    # Python joins the backslash-continued lines of this (non-raw) string into a
    # single command, so every flag keeps a space before its trailing backslash
    # to avoid flags running together if the indentation ever changes.
    shell:
        """
        python src/train.py {input} \
                        --epochs {config[epochs]} \
                        --net-shape {params.shape} \
                        --batch-size={config[batch_size]} \
                        --Ksize={config[S]} \
                        --min-lr={config[min_lr]} \
                        --max-lr={config[max_lr]} \
                        --epochs-per-cycle={config[epochs_per_cycle]} \
                        --n-parents={config[n_parents]} \
                        --n-children={config[n_children]} \
                        --n-generations={config[n_generations]} \
                        --crossover={config[crossover]} \
                        --analytical-pi={config[analytical_pi_updates]} \
                        --analytical-sigma={config[analytical_sigma_updates]} \
                        --output={output[0]}
        """


# Split one noisy realisation into patches. The patch size is not read from
# config here; it is taken from the {patch_size} wildcard of the requested
# output file.
rule patchify:
    input:
        "data/{img}_sigma{sigma}_{i}.h5"
    output:
        "runs/{img}_sigma{sigma}_{i}/data_patchsize{patch_size}.h5"
    shell:
        """
        python src/patchify.py --in-fname {input} --out-fname {output} --patch-size {wildcards.patch_size}
        """


# Produce noisy realisation {i} of data/{img}.png, passing {sigma} to the
# script's --sigma option and writing the result to HDF5.
# NOTE(review): {i} only distinguishes output filenames and is not passed to
# the script. Confirm that src/add_noise.py does not use a fixed random seed;
# otherwise every realisation would be identical.
rule add_noise:
    input: "data/{img}.png"
    output: "data/{img}_sigma{sigma}_{i}.h5"
    shell:
        """
        python src/add_noise.py --img {input[0]} --out-fname {output[0]} --sigma {wildcards.sigma}
        """


# Prefer patchify over add_noise if both rules could produce the same file.
# NOTE(review): their visible output patterns ("runs/..." vs "data/...") cannot
# match the same path, so this directive appears to be a no-op. Confirm before
# removing it.
ruleorder: patchify > add_noise
