import argparse

import pandas as pd
from Bio import SeqIO


def generate_query_sequence(dataset: str, verbose: bool = True) -> None:
    """Build a query sequence for use in EVE from the dataset's MSA entries.

    The MSA FASTA file is scanned for records whose id appears in the name column of
    the dataset CSV. The resulting string has the character ``A`` at every alignment
    column where at least one of those records has a non-gap character; the remaining
    positions are gap-filled (``-``). The string is written to
    ``data/processed/{dataset}/{dataset}_EVE_query.txt``.

    Args:
        dataset: One of `cm`, `ppat` or `tim`.
        verbose: If True, print progress and summary statistics of the query.

    Raises:
        NotImplementedError: If `dataset` is not one of the supported datasets.
        ValueError: If the MSA file contains no sequences.
    """

    # Define paths
    if dataset == "cm":
        msa_path = f"data/raw/{dataset}/{dataset}_uniref100.aln.fasta"
        csv_path = f"data/raw/{dataset}/{dataset}.csv"
        name_key = "name"
    elif dataset in ["ppat", "tim"]:
        msa_path = f"data/raw/{dataset}/{dataset}_family.aln.fasta"
        csv_path = f"data/processed/{dataset}/{dataset}.csv"
        name_key = "name"
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")

    # Load csv to find which MSA entries contribute to the query.
    # A set keeps the per-record membership test O(1).
    names = set(pd.read_csv(csv_path)[name_key])

    if verbose:
        print("Generating query sequence.")

    # Single pass over the MSA; the first record fixes the alignment length and the
    # query starts out as all-gaps. The context manager closes the file handle.
    query_chars = None
    with open(msa_path) as handle:
        for record in SeqIO.parse(handle, "fasta"):
            sequence = str(record.seq)
            if query_chars is None:
                query_chars = ["-"] * len(sequence)
            if record.id in names:
                for i, char in enumerate(sequence):
                    if char != "-":
                        query_chars[i] = "A"

    if query_chars is None:
        raise ValueError(f"No sequences found in {msa_path}")

    query_sequence = "".join(query_chars)

    if verbose:
        # Positions are 0-indexed; both are -1 when the query contains no non-gap.
        first_non_gap = query_sequence.find("A")
        last_non_gap = query_sequence.rfind("A")
        # Inclusive span between the first and last non-gap columns.
        interval_length = last_non_gap - first_non_gap + 1 if first_non_gap != -1 else 0
        print(f"Length of MSA: {len(query_sequence)}.")
        print(f'Non-gaps in query: {query_sequence.count("A")}')
        print(f"First non-gap at position {first_non_gap}.")
        print(f"Last non-gap at position {last_non_gap}.")
        print(f"Non-gap interval length: {interval_length}")

    with open(f"data/processed/{dataset}/{dataset}_EVE_query.txt", "w") as f:
        f.write(query_sequence)

    if verbose:
        print(
            f"Query sequence written to data/processed/{dataset}/{dataset}_EVE_query.txt"
        )


if __name__ == "__main__":
    # Command-line entry point: the single positional argument selects the dataset.
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", type=str)
    generate_query_sequence(cli.parse_args().dataset)
