# CONFIGURABLE VALUES
# All parameters come from the Snakemake `config` dict (e.g. --config / configfile),
# with the defaults below.
N_TIPS = config.get("n_tips", 50)  # Tips (leaves) per simulated tree
N_BATCHES = config.get("n_batches", 2)  # this is to parallelize
N_TREES = config.get("n_trees", 10)  # Number per batch
SEQLEN = config.get("seqlen", 500)  # Alignment length passed to the simulator
SUBST_MAT = config.get("substitution_matrix", "LG")  # Amino-acid substitution model
SITE_RATES = config.get("site_rates", "GC")  # Passed to alisim's --gamma; presumably a rate-heterogeneity code — TODO confirm
# Set default root dir
ROOT = config.get("root", "simulated_data")  # Root directory for all workflow outputs

# Imports for python rules
import os
from Bio import SeqIO
from glob import glob
from pathlib import Path


def move_batch(srcdir, dstdir, filepattern, prefix):
    """Move every file in ``srcdir`` matching ``filepattern`` into ``dstdir``.

    Each moved file is renamed to ``<prefix>_<original name>`` so that files
    from different batches cannot collide in the shared destination directory.

    Args:
        srcdir: Directory containing the batch's files.
        dstdir: Destination directory (must already exist).
        filepattern: Glob pattern (e.g. ``"*.nwk"``) relative to ``srcdir``.
        prefix: Batch tag (e.g. ``"b0"``) prepended to each filename.
    """
    for srcpath in glob(f"{srcdir}/{filepattern}"):
        filename = Path(srcpath).name
        # BUG FIX: destination previously used a literal placeholder instead of
        # the source filename, so every file in a batch was renamed onto the
        # same path, silently overwriting all but the last one.
        destpath = os.path.join(dstdir, f"{prefix}_{filename}")
        # NOTE(review): os.rename fails across filesystems; fine here since
        # batch and merged dirs share one ROOT — shutil.move would be safer.
        os.rename(srcpath, destpath)


# Final targets: the merged (all-batches) directories of trees, filtered MSAs,
# and intermediate MSAs produced by `concat_batches`.
rule all:
    input:
        trees=f"{ROOT}/trees",
        msas=f"{ROOT}/msas",
        msas_i=f"{ROOT}/msas_int",


# Merge the per-batch output directories into single top-level
# trees/msas_int/msas directories, tagging each file with its batch number
# (via move_batch) so filenames from different batches do not collide.
rule concat_batches:
    input:
        trees=expand(f"{ROOT}/batch_{{batchnum}}/trees", batchnum=range(N_BATCHES)),
        msas_i=expand(f"{ROOT}/batch_{{batchnum}}/msas_int", batchnum=range(N_BATCHES)),
        msas=expand(f"{ROOT}/batch_{{batchnum}}/msas", batchnum=range(N_BATCHES)),
    output:
        trees=directory(f"{ROOT}/trees"),
        msas_i=directory(f"{ROOT}/msas_int"),
        msas=directory(f"{ROOT}/msas"),
    run:
        # Make sure output directories exist
        os.makedirs(output.trees, exist_ok=True)
        os.makedirs(output.msas, exist_ok=True)
        os.makedirs(output.msas_i, exist_ok=True)

        # Move files. expand() preserves batch order, so zipping the three
        # input lists pairs up directories belonging to the same batch.
        for tree_batch, msa_i_batch, msa_batch in zip(
            input.trees, input.msas_i, input.msas
        ):
            # Recover the batch number from a path like "<ROOT>/batch_<n>/trees":
            # text after the last "_" and before the following "/".
            # NOTE(review): relies on the final path components being
            # "batch_<n>/<dirname-without-underscores>" — holds for this layout.
            batchnum = tree_batch.split("_")[-1].split("/")[0]
            prefix = f"b{batchnum}"
            move_batch(tree_batch, output.trees, "*.nwk", prefix)
            move_batch(msa_i_batch, output.msas_i, "*.fa", prefix)
            move_batch(msa_batch, output.msas, "*.fa", prefix)


# From each intermediate MSA, keep only records whose id starts with "T" —
# presumably the tip (leaf) sequences, dropping the ancestral/internal-node
# sequences that the simulator emits (see --ancestral below) — TODO confirm
# the id convention against the simulator's output.
rule filter_tips:
    input:
        f"{ROOT}/batch_{{batchnum}}/msas_int",
    output:
        msas=temp(directory(f"{ROOT}/batch_{{batchnum}}/msas")),
    threads: 1
    run:
        os.makedirs(output.msas, exist_ok=True)
        # f"{input}" stringifies the single-element Snakemake input to its path.
        for srcpath in glob(f"{input}/*.fa"):
            dstpath = os.path.join(output.msas, Path(srcpath).name)
            # SeqIO.write accepts the lazy filter iterator directly, so records
            # are streamed rather than loaded into memory all at once.
            SeqIO.write(
                filter(
                    lambda rec: rec.id.startswith("T"),
                    SeqIO.parse(srcpath, format="fasta"),
                ),
                dstpath,
                format="fasta",
            )


# Run the external deduplication script on the raw ("duped") MSAs of a batch,
# writing the cleaned alignments to the intermediate msas_int directory.
rule find_unduplicated:
    input:
        deduplicator="find_unduplicated.py",  # listed as input so the rule reruns if the script changes
        msas=f"{ROOT}/batch_{{batchnum}}/msas_duped",
    params:
        length=SEQLEN,
    output:
        temp(directory(f"{ROOT}/batch_{{batchnum}}/msas_int")),
    threads: 1
    shell:
        """
      python {input.deduplicator} {input.msas} {output} --length {params.length}
      """


# Simulate N_TREES birth-death trees with N_TIPS tips each for one batch,
# via the external SimulateAlterRescale.py script.
rule simulate_trees:
    input:
        simulator="SimulateAlterRescale.py",  # listed as input so the rule reruns if the script changes
    params:
        numtips=N_TIPS,
        numtrees=N_TREES,
    output:
        temp(directory(f"{ROOT}/batch_{{batchnum}}/trees")),
    threads: 1
    shell:
        """
      python {input.simulator}\
        --ntips {params.numtips}\
        --ntrees {params.numtrees}\
        --type birth-death\
        --outdir {output}
      """


# Simulate alignments along each batch's trees with the external alisim.py
# wrapper: SEQLEN sites under SUBST_MAT with SITE_RATES rate heterogeneity,
# including indels and ancestral sequences, allowing duplicate sequences
# (deduplication happens downstream in find_unduplicated).
rule simulate_alignments:
    input:
        simulator="alisim.py",  # listed as input so the rule reruns if the script changes
        trees=f"{ROOT}/batch_{{batchnum}}/trees",
    params:
        seqlen=SEQLEN,
        matrix=SUBST_MAT,
        rates=SITE_RATES,
    output:
        temp(directory(f"{ROOT}/batch_{{batchnum}}/msas_duped")),
    threads: 4
    shell:
        """
    python {input.simulator}\
      --outdir {output}\
      --length {params.seqlen}\
      --substitution {params.matrix}\
      --gamma {params.rates}\
      --indels\
      --ancestral\
      --allow-duplicate-sequences\
      --processes {threads}\
      {input.trees}
    """