# MAIN RULE OUTPUTS:
# data dir:
#  - one file for bars sampled from BSC
#  - one file for bars with uniform random noise
#
# runs dir:
#  - one directory for each dataset, with:
#     * one output file per training configuration and training kind (BSC, TVAE, TVAE+BSC)
#
# figures dir:
#  - one directory for each dataset, with:
#     - one directory per training output, with:
#        - gif and png viz for that run

from itertools import product
from pathlib import Path
from collections import deque

# Fail fast on any shell error, and force bash (not sh) for all rule shells.
shell.prefix("set -e;")
shell.executable("/bin/bash")

configfile: "config.yml"

# Dataset-generation parameter grid (lists of N and H values) and output dir.
data_conf = config["data"]
data_dir = "data"

# Training hyper-parameter grid and the directory holding training outputs.
run_conf = config["experiment"]
runs_dir = "runs"

# Post-processing / visualization settings and the figures output dir.
post_conf = config["post"]
figs_dir = "figures"

# One dataset filename per (H, N) combination, e.g. "H8_N16.h5".
datasets = expand("H{H}_N{N}.h5", N=data_conf["N"], H=data_conf["H"])

def dirname_from_filename(fname):
   """Return *fname*'s basename without its final extension.

   E.g. "data/H8_N16.h5" -> "H8_N16". Uses Path.stem so only the trailing
   suffix is stripped, rather than removing every ".h5" occurrence anywhere
   in the name (which the previous str.replace-based version did).
   """
   return Path(fname).stem

def run_outputs_per_dataset():
   """Build the list of training-output filenames for a single dataset.

   Each filename encodes the full hyper-parameter combination, e.g.
   "batch32_S5_lr1e-04to1e-02_rep0.h5". Learning rates are rendered in
   scientific notation by the {:1.0e} format specs that snakemake's
   expand() applies via str.format.
   """
   batches, Ss, min_lr, max_lr = (run_conf[p] for p in ("batch-size", "Ksize", "min_lr", "max_lr"))
   reps = range(int(run_conf["reps"]))
   return expand("batch{b}_S{s}_lr{min_lr:1.0e}to{max_lr:1.0e}_rep{r}.h5", b=batches, s=Ss, min_lr=min_lr, max_lr=max_lr, r=reps)

# The callable is only needed once at parse time, so the name is rebound to
# its result (a list of filenames) for the rest of the file.
run_outputs_per_dataset = run_outputs_per_dataset()

def all_run_outputs():
   """Return every training-output path: runs/<dataset>/<run-config>.h5.

   One entry per (dataset, run configuration) pair. Uses the shared
   dirname_from_filename helper (as all_figures does) instead of
   re-implementing the ".h5"-stripping inline.
   """
   return [
      f"{runs_dir}/{dirname_from_filename(ds)}/{out}"
      for ds, out in product(datasets, run_outputs_per_dataset)
   ]

# Rebind the name to the materialized list, matching the file's pattern.
all_run_outputs = all_run_outputs()

def all_figures():
   """Return every figure path: figures/<dataset>/<run-config>/viz.<ext>.

   One directory per (dataset, run) pair, one file per enabled suffix.
   """
   figs = []
   for ds, run in product(datasets, run_outputs_per_dataset):
      # Renamed from data_dir/run_dir: the old local "data_dir" shadowed the
      # module-level data_dir ("data"), which was confusing even though this
      # function never needs the global.
      ds_dir, run_subdir = map(dirname_from_filename, (ds, run))
      # "gif" is intentionally disabled; add it back to the tuple to restore
      # gif targets (see the matching commented-out output in rule viz_run).
      for suffix in ("png",):
         figs.append(f"{figs_dir}/{ds_dir}/{run_subdir}/viz.{suffix}")
   return figs

# Rebind the name to the materialized list, matching the file's pattern.
all_figures = all_figures()

def print_list(title, l):
   """Print *title* followed by each item of *l* on its own tab-indented line,
   ending with a blank separator line.

   Items are formatted with str(), so non-string entries are accepted too
   (the previous "\\t".__add__ version required every item to be a str).
   """
   print(f"{title}:", *(f"\t{item}" for item in l), sep="\n", end="\n\n")

# Echo all expected targets at parse time, so every snakemake invocation
# (including dry runs) shows the full plan up front.
print_list("datasets", datasets)
print_list("run outs", all_run_outputs)
print_list("figures", all_figures)

# Aggregate target: requesting every figure transitively forces data
# generation, training, and visualization for all configurations.
rule run_experiment:
   input: all_figures

# Render the visualization(s) for one training run. Needs both the run
# output and the dataset it was trained on; writes viz.png under
# figures/<dataset>/<run>/. The gif target is currently disabled (see the
# commented-out output and the matching suffix tuple in all_figures).
rule viz_run:
   input:
      run_out=runs_dir + "/{dataset}/{run}.h5",
      dataset=data_dir + "/{dataset}.h5"
   output: figs_dir + "/{dataset}/{run}/viz.png"#, figs_dir + "/{dataset}/{run}/viz.gif"
   shell:
      """
      python viz.py --data {input.dataset} --train-output {input.run_out}\
             --output-dir {figs_dir}/{wildcards.dataset}/{wildcards.run}\
             --gif-step={post_conf[gif_step]}
      """

# Generate one bars dataset for a given (H, N) pair.
# NOTE(review): gen_bars_data.py takes --output-dir only, so it presumably
# derives the "H{H}_N{N}.h5" filename from -N/-H itself — confirm it matches
# this rule's declared output.
rule gen_data:
   output: data_dir + "/H{H}_N{N}.h5"
   shell:
      """
      python gen_bars_data.py -N {wildcards.N} -H {wildcards.H} --output-dir {data_dir}
      """

# Train one model configuration on one dataset. All hyper-parameters are
# recovered from the output filename's wildcards; epochs and net shape come
# from the config file.
rule train:
   input:
      dataset=data_dir + "/{dataset}.h5",
   # Use runs_dir instead of a hardcoded "runs" so this output stays in sync
   # with viz_run's input, all_run_outputs, and rule clean.
   output: runs_dir + "/{dataset}/batch{batch_size}_S{Ksize}_lr{min_lr}to{max_lr}_rep{rep}.h5"
   shell:
      """
      python train.py {input.dataset} \
                      --epochs {run_conf[epochs]} \
                      --net-shape {run_conf[net-shape]} \
                      --batch-size={wildcards.batch_size}\
                      --Ksize={wildcards.Ksize} \
                      --min_lr={wildcards.min_lr} \
                      --max_lr={wildcards.max_lr} \
                      --output={output}
      """

# Remove all generated artifacts: datasets, training runs, figures, and the
# Python bytecode cache.
rule clean:
   shell: "rm -rf {runs_dir} {figs_dir} {data_dir} __pycache__"
