# Dataset groups: "thin" datasets are benchmarked with every method, while
# "fat" datasets are only run with the deterministic trees (see rule all).
thin_datasets = ["avila","breast_cancer","congressional_votes","haberman_survival","iris","poker_hand","vowel","wine"]
fat_datasets = ["mice_protein", "car_evaluation", "digits"]
# Tree-construction methods under comparison. The *_large/*_small variants use
# the per-dataset leaf budgets from get_leaves_large/get_leaves_small below.
all_deterministic_trees = ["kauri_large", "kauri_small", "exkmc", "ktree_large", "ktree_small", "exshallow", "imm", "rdm"]
all_differentiable_trees = ["douglas", "torchdouglas"]
# Number of repeated runs per (method, dataset) pair.
n_runs = 30

def get_leaves_large(wildcards):
    """Return the "large" max-leaves budget for the dataset being benchmarked.

    Args:
        wildcards: Snakemake wildcards object; only ``wildcards.dataset`` is read.

    Returns:
        int: maximum number of leaves for the large tree configuration.

    Raises:
        ValueError: if no budget is configured for the dataset. (The original
            if/elif chain silently returned None here, which would have been
            rendered as the literal string "None" in the shell command.)
    """
    budgets = {
        "breast_cancer": 8,
        "congressional_votes": 8,
        "haberman_survival": 8,
        "vowel": 8,
        "iris": 12,
        "wine": 12,
        "car_evaluation": 16,
        "poker_hand": 40,
        "digits": 40,
        "mice_protein": 32,
        "avila": 48,
    }
    try:
        return budgets[wildcards.dataset]
    except KeyError:
        raise ValueError("no large leaf budget configured for dataset: %r" % (wildcards.dataset,))

def get_leaves_small(wildcards):
    """Return the "small" max-leaves budget for the dataset being benchmarked.

    Args:
        wildcards: Snakemake wildcards object; only ``wildcards.dataset`` is read.

    Returns:
        int: maximum number of leaves for the small tree configuration.

    Raises:
        ValueError: if no budget is configured for the dataset. (The original
            if/elif chain silently returned None here, which would have been
            rendered as the literal string "None" in the shell command.)
    """
    budgets = {
        "breast_cancer": 2,
        "congressional_votes": 2,
        "haberman_survival": 2,
        "vowel": 2,
        "iris": 3,
        "wine": 3,
        "car_evaluation": 4,
        "poker_hand": 10,
        "digits": 10,
        "mice_protein": 8,
        "avila": 12,
    }
    try:
        return budgets[wildcards.dataset]
    except KeyError:
        raise ValueError("no small leaf budget configured for dataset: %r" % (wildcards.dataset,))

def get_batch_size(wildcards):
    if wildcards.dataset in ["avila", "poker_hand"]:
        return "256 --n_epochs 50"
    elif wildcards.dataset in ["breast_cancer", "haberman_survival", "wine"]:
        return 64
    elif wildcards.dataset=="iris":
        return 32
    else:
        return 128

def get_torch(wildcards):
    """Return the extra CLI flag enabling the torch backend, or "" otherwise."""
    return "--torch" if wildcards.tree == "torchdouglas" else ""

# Default target: request every benchmark CSV. Thin datasets are run with all
# methods (deterministic + differentiable); fat datasets only with the
# deterministic trees — presumably the differentiable ones are too costly
# there (TODO confirm). Each pairing is repeated n_runs times.
rule all:
    input:
        expand("{method}/{dataset}_run_{run}.csv", method=all_deterministic_trees+all_differentiable_trees, dataset=thin_datasets, run=range(n_runs)),
        expand("{method}/{dataset}_run_{run}.csv", method=all_deterministic_trees, dataset=fat_datasets, run=range(n_runs)),

# Benchmark kauri/ktree with the per-dataset "large" leaf budget.
# Note: the wildcard constraint captures only "kauri"/"ktree" (without the
# "_large" suffix), so the benchmark script receives the bare method name.
rule run_dt_large:
    output:
        "{tree,(kauri|ktree)}_large/{dataset}_run_{run}.csv"
    log:
        # Include "_large" in the log names: without it these paths collide
        # with run_dt_small's logs for the same tree/dataset/run and the two
        # rules clobber each other's output.
        out="log/{tree}_large_{dataset}_run_{run}.out",
        err="log/{tree}_large_{dataset}_run_{run}.err"
    params:
        max_leaves=get_leaves_large
    shell:
        "python ../scripts/benchmark_main.py {wildcards.tree} --subset_size 0.8 --max_leaves {params.max_leaves} --dataset {wildcards.dataset} --path_to_data ../data/datasets --output_file {output} > {log.out} 2> {log.err}"

# Benchmark kauri/ktree with the per-dataset "small" leaf budget.
# Note: the wildcard constraint captures only "kauri"/"ktree" (without the
# "_small" suffix), so the benchmark script receives the bare method name.
rule run_dt_small:
    output:
        "{tree,(kauri|ktree)}_small/{dataset}_run_{run}.csv"
    log:
        # Include "_small" in the log names: without it these paths collide
        # with run_dt_large's logs for the same tree/dataset/run and the two
        # rules clobber each other's output.
        out="log/{tree}_small_{dataset}_run_{run}.out",
        err="log/{tree}_small_{dataset}_run_{run}.err"
    params:
        max_leaves=get_leaves_small
    shell:
        "python ../scripts/benchmark_main.py {wildcards.tree} --subset_size 0.8 --max_leaves {params.max_leaves} --dataset {wildcards.dataset} --path_to_data ../data/datasets --output_file {output} > {log.out} 2> {log.err}"


# Benchmark the competitor deterministic methods (rdm/exkmc/exshallow/imm).
# These take no --max_leaves argument — presumably they pick their own tree
# size internally (TODO confirm against benchmark_main.py).
rule run_competitor:
    output:
        "{tree,(rdm|exkmc|exshallow|imm)}/{dataset}_run_{run}.csv"
    log:
        out="log/{tree}_{dataset}_run_{run}.out",
        err="log/{tree}_{dataset}_run_{run}.err"
    shell:
        "python ../scripts/benchmark_main.py {wildcards.tree} --subset_size 0.8 --dataset {wildcards.dataset} --path_to_data ../data/datasets --output_file {output} > {log.out} 2> {log.err}"

# Benchmark the differentiable-tree methods (douglas / torchdouglas).
rule run_douglas:
    output:
        "{tree,(douglas|torchdouglas)}/{dataset}_run_{run}.csv"
    log:
        out="log/{tree}_{dataset}_run_{run}.out",
        err="log/{tree}_{dataset}_run_{run}.err"
    params:
        # HACK: for avila/poker_hand, get_batch_size returns the string
        # "256 --n_epochs 50", smuggling an extra flag into the command line
        # through the --batch_size substitution below.
        batch_size=get_batch_size,
        # "--torch" for torchdouglas, empty string otherwise.
        with_torch=get_torch
    shell:
        # Note: no space between {params.with_torch} and ">" — harmless, since
        # ">" is a shell metacharacter and needs no surrounding whitespace.
        "python ../scripts/benchmark_main.py {wildcards.tree} --subset_size 0.8 --dataset {wildcards.dataset} --path_to_data ../data/datasets --output_file {output} --batch_size {params.batch_size} {params.with_torch}> {log.out} 2> {log.err}"