# Datasets benchmarked by this workflow (one result.csv per method/dataset pair).
all_datasets = [
    "avila",
    "breast_cancer",
    "car_evaluation",
    "congressional_votes",
    "digits",
    "haberman_survival",
    "iris",
    "mice_protein",
    "poker_hand",
    "vowel",
    "wine",
]
# Methods passed as the first argument to benchmark_main.py.
all_methods = "kauri exkmc ktree".split()
# Largest leaf count requested for any (method, dataset) pair.
max_leaves = 20

def get_input_list(wildcards):
    """Return the per-leaf-count CSV paths that make up one method/dataset result.

    Leaf counts run from the dataset's cluster count (table below) up to and
    including ``max_leaves``. A dataset missing from the table is treated as
    having a single cluster.
    """
    # Cluster count per dataset; any other dataset name falls back to 1.
    # NOTE(review): confirm the value for "vowel" -- commonly distributed
    # versions of that dataset have more than two classes.
    clusters_by_dataset = {
        "breast_cancer": 2,
        "congressional_votes": 2,
        "haberman_survival": 2,
        "vowel": 2,
        "iris": 3,
        "wine": 3,
        "car_evaluation": 4,
        "mice_protein": 8,
        "poker_hand": 10,
        "digits": 10,
        "avila": 12,
    }
    smallest = clusters_by_dataset.get(wildcards.dataset, 1)
    leaf_counts = range(smallest, max_leaves + 1)
    return expand(
        "{method}/{dataset}/{n_leaf}_leaves.csv",
        method=wildcards.method,
        dataset=wildcards.dataset,
        n_leaf=leaf_counts,
    )

# Default target: request one concatenated result.csv for every combination
# of a method in all_methods and a dataset in all_datasets.
rule all:
    input:
        expand("{method}/{dataset}/result.csv", method=all_methods, dataset=all_datasets)

# Merge the per-leaf-count CSVs of one (method, dataset) pair into a single
# result.csv. The header line is taken from the first input file; every input
# then contributes its rows without its own header line.
# Options are placed before the file operands so the command does not rely on
# GNU getopt's argument permutation. -q suppresses the "==> file <==" banners
# tail prints when given several files. Writing straight to {output} avoids the
# intermediate head/content files that were previously created and could be
# left behind on failure.
rule concatenate_results:
    input:
        get_input_list
    output:
        "{method}/{dataset}/result.csv"
    shell:
        "(head -n 1 {input[0]} && tail -q -n +2 {input}) > {output}"

# Produce one benchmark CSV by running benchmark_main.py for a single method,
# dataset and leaf budget (--max_leaves {n_leaf}). Stdout and stderr go to
# per-job files under logs/.
# wildcard_constraints keep the default ".+" wildcard regex from spanning "/"
# (which could split a path into the wrong method/dataset) and restrict n_leaf
# to an integer, since it is passed straight to --max_leaves.
# NOTE(review): ../scripts and ../data/datasets are relative paths resolved
# against the directory the job runs in; confirm the expected workdir.
rule make_run:
    output:
        "{method}/{dataset}/{n_leaf}_leaves.csv"
    wildcard_constraints:
        method="[^/]+",
        dataset="[^/]+",
        n_leaf=r"\d+",
    log:
        out="logs/{method}_{dataset}_{n_leaf}.out",
        err="logs/{method}_{dataset}_{n_leaf}.err"
    shell:
        "python ../scripts/benchmark_main.py {wildcards.method} --dataset {wildcards.dataset} --output_file {output} --subset_size 1 --max_leaves {wildcards.n_leaf} --path_to_data ../data/datasets >{log.out} 2> {log.err}"
