import argparse
import os

import numpy as np
import pandas as pd
from Bio import SeqIO


def generate_onehot_msa_encoding(dataset: str) -> None:
    """Save a one-hot encoding of each selected aligned sequence to disk.

    Reads the alignment ``data/processed/{dataset}/{dataset}_local.aln.fasta``
    and the table ``data/processed/{dataset}/{dataset}.csv``. For every
    alignment record whose id appears in the table's ``name`` column, a
    ``(seq_len, 20)`` float array is written to
    ``data/processed/{dataset}/onehot_msa_encodings/{record id}.npy``.
    ``seq_len`` is the length of the first record in the alignment; columns
    follow the order of the alphabet ``ACDEFGHIKLMNPQRSTVWY``.

    Positions whose character is not in that alphabet are left as all-zero
    rows. The lookup is case-sensitive, so lowercase residues also stay zero.

    Args:
        dataset: Dataset name used to build the input and output paths.

    Raises:
        ValueError: If the alignment file contains no records.
        IndexError: If a selected record is longer than the first record.
    """
    # Dataset specific paths
    dataset_dir = f"data/processed/{dataset}"
    output_dir = f"{dataset_dir}/onehot_msa_encodings"
    msa_path = f"{dataset_dir}/{dataset}_local.aln.fasta"

    # AA dictionary: residue letter -> column index
    aa_dict = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}

    # Build the set of wanted names once so the per-record membership test
    # is O(1) instead of re-creating and scanning a list on every iteration.
    df = pd.read_csv(f"{dataset_dir}/{dataset}.csv", index_col=0)
    wanted_names = set(df["name"].tolist())

    # Determine the array length from the first record. The handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(msa_path) as handle:
        first_record = next(iter(SeqIO.parse(handle, "fasta")), None)
    if first_record is None:
        raise ValueError(f"No sequences found in alignment file: {msa_path}")
    seq_len = len(first_record.seq)

    # np.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    # Create zero-padded one-hot encoding and save to disk
    with open(msa_path) as handle:
        for record in SeqIO.parse(handle, "fasta"):
            if record.id not in wanted_names:
                continue
            one_hot = np.zeros((seq_len, 20))
            for j, letter in enumerate(str(record.seq)):
                k = aa_dict.get(letter)
                if k is not None:
                    one_hot[j, k] = 1.0
            np.save(file=f"{output_dir}/{record.id}.npy", arr=one_hot)


if __name__ == "__main__":
    # Command-line entry point: the single positional argument is the
    # dataset name, which is passed straight to the encoder.
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", type=str)
    generate_onehot_msa_encoding(cli.parse_args().dataset)
