
"""
Rules file to run comparison scripts over a specified dataset
"""

import json
import os
import pandas as pd
import numpy as np

from trainable_scattering.conductor.train_model_shallow import train_model_shallow

configfile: "config.yaml"
# Config validation is currently disabled:
# validate(config, schema="../schemas/config.schema.yaml")

# Recipient of the onsuccess/onerror notification mails.
# Override with `email: ...` in config.yaml or `--config email=...`.
EMAIL = config.get('email', "anonymous@gmail.com")

# JSON file (inside config['args_dir']) that lists the runs to execute.
# Defaults to 'shallow.json'; pick another parameter set (e.g. 'v2.json') with
# `param_file: v2.json` in config.yaml or `--config param_file=v2.json`.
PARAM_FILE = config.get('param_file', 'shallow.json')

with open(os.path.join(config['args_dir'], PARAM_FILE), 'r', encoding='utf-8') as fp:
    run_args = json.load(fp)
# The parameter file must provide a "meta_data" entry and a "runs" list; each
# run entry is used below via its 'model_dir' key.
meta_data = run_args["meta_data"]
runs = run_args["runs"]

def compute_result_files(wildcards):
    """Return the path of the result.npz expected for every run in ``runs``.

    Shaped as a Snakemake input function, so it accepts ``wildcards``; the
    argument is ignored because the file list depends only on ``runs``.
    """
    return [os.path.join(entry['model_dir'], 'result.npz') for entry in runs]

# Default target (first rule in the file): request the result.npz of every
# run listed in the parameter file, which pulls in the rules that train them.
# compute_result_files returns the same list, as an input function.
rule:
    input:
        [os.path.join(spec['model_dir'], 'result.npz') for spec in runs]

# Index every run by its normalised args.json path, so that each args.json can
# be produced by its own job (rule write_run_args below).
_ARGS_BY_PATH = {
    os.path.normpath(os.path.join(rp['model_dir'], 'args.json')): rp
    for rp in runs
}

rule write_run_args:
    # Write the arguments of one run. A single job owning every args.json
    # would, whenever any one file was missing, rewrite all of them and make
    # every existing result.npz look out of date, retraining all runs.
    output:
        '{run_dir}/args.json'
    run:
        key = os.path.normpath(output[0])
        if key not in _ARGS_BY_PATH:
            raise ValueError('%s does not belong to any run in %s'
                             % (output[0], PARAM_FILE))
        with open(output[0], 'w') as fp:
            json.dump(_ARGS_BY_PATH[key], fp, indent=4, sort_keys=True)

rule write_args:
    # Kept so `snakemake write_args` still works: it requests every run's
    # args.json, which write_run_args creates as needed.
    input:
        [os.path.join(rp['model_dir'], 'args.json') for rp in runs]

# Train one model: consume the run's args.json and write result.npz into the
# same "{model_dir}/{index}" directory.
# NOTE(review): train_model_shallow receives Snakemake's whole `input`/`output`
# file lists (not input[0]/output[0]); confirm that is what it expects.
rule train:
    output:
        "{model_dir}/{index}/result.npz"
    input:
        "{model_dir}/{index}/args.json"
    run:
        train_model_shallow(input, output)


# Disabled: the rule below is kept inside a string literal, so Snakemake never
# registers it. It is not ideal because its input is every run's result, so it
# acts as a global barrier that cannot run until all results are computed.
# NOTE(review): before re-enabling, reconcile file names: it loads
# '<model_dir>/<j>/result.npy', whereas rule train writes 'result.npz'.
# NOTE(review): range(1, num_tps - 1) skips j == 0 and j == num_tps - 1;
# confirm that is intended.
"""
rule aggregate:
    input:
        compute_result_files
    output:
        [os.path.join(rp['model_dir'], 'result.npy') for rp in runs]
    run:
        for rp in runs:
            results = []
            for j in range(1, rp['dataset_args']['num_tps']-1):
                results.append(np.load(os.path.join(rp['model_dir'], str(j), 'result.npy')))
            np.save(os.path.join(rp['model_dir'], 'result.npy'), np.array(results).flatten())
"""

# Notification hooks: mail the Snakemake log to EMAIL when the workflow ends.
# The '%' operator inserts EMAIL first; '{log}' is left as a placeholder for
# shell() to fill in (per the Snakemake docs, with the workflow's log file).
# Requires a working `mail` command on the host.
onsuccess:
    shell("mail -s 'Snakemake Completed!' %s < {log}" % EMAIL)

# Same notification, sent when the workflow fails.
onerror:
    shell("mail -s 'Snakemake Error' %s < {log}" % EMAIL)
