import os
import subprocess
from typing import Sequence

from absl import logging

from openfold.data.tools import utils


def _to_a3m(sequences: Sequence[str]) -> str:
    """Serialize sequences as header/sequence records with generated names.

    Every input sequence yields two newline-terminated lines: a header of
    the form ``>sequence N`` (``N`` counts from 1 in input order) followed
    by the sequence itself.

    Args:
        sequences: The sequence strings to serialize.

    Returns:
        All records concatenated into one string; the empty string when
        ``sequences`` is empty.
    """
    return "".join(
        ">sequence %d\n" % number + residues + "\n"
        for number, residues in enumerate(sequences, start=1)
    )


class Kalign:
    """Python wrapper that runs the Kalign binary on a set of sequences."""

    def __init__(self, *, binary_path: str):
        """Initializes the wrapper.

        Args:
            binary_path: Path to the Kalign executable that `align` launches.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the given sequences by invoking Kalign as a subprocess.

        The sequences are written to a temporary input file with generated
        headers, Kalign is run with ``-format fasta``, and the file it
        produces is read back.

        Args:
            sequences: The sequence strings to align. Each one must contain
                at least 6 residues.

        Returns:
            The contents of the output file written by Kalign.

        Raises:
            ValueError: If any sequence is shorter than 6 residues.
            RuntimeError: If the Kalign process exits with a non-zero
                return code.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Reject short inputs up front, before any file or process is created.
        for s in sequences:
            if len(s) < 6:
                raise ValueError(
                    "Kalign requires all sequences to be at least 6 "
                    "residues long. Got %s (%d residues)." % (s, len(s))
                )

        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
            output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            with open(input_fasta_path, "w") as f:
                f.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-o",
                output_a3m_path,
                "-format",
                "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                # communicate() drains both pipes and waits for the process
                # to exit, so the return code is already set afterwards.
                stdout, stderr = process.communicate()
                retcode = process.returncode
                # Decode once and tolerate invalid bytes: a strict decode
                # could raise UnicodeDecodeError here and hide the tool's
                # actual exit status and diagnostics.
                stdout_text = stdout.decode("utf-8", errors="replace")
                stderr_text = stderr.decode("utf-8", errors="replace")
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    stdout_text,
                    stderr_text,
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout_text, stderr_text)
                )

            # Read the result before the temporary directory context exits.
            with open(output_a3m_path) as f:
                a3m = f.read()

            return a3m
