"""Common utilities for starting execution over ssh."""
import abc
import os
import subprocess
from typing import List, Sequence

from . import execution_abcs

# Default patterns to exclude via rsync's --exclude flags.
# Bare names (no "/") are matched against the final path component at any
# depth, including the top level of the transfer. A pattern such as
# "*/.git" needs a parent directory inside the transfer, so it would miss a
# top-level .git when the source path ends in "/" (e.g. "~/proj/").
DEFAULT_EXCLUDES = (
    "__pycache__",
    ".git",
)


class SshEnvConfig(abc.ABC):
    """Abstract configuration for an ssh-based execution environment.

    Concrete subclasses must implement the ``ssh_host`` property.
    """

    @property
    @abc.abstractmethod
    def ssh_host(self) -> str:
        """The target passed to ssh, e.g. ``"m@longleaf.unc.edu"``."""
        raise NotImplementedError


class SshEnv(execution_abcs.ExecutionEnv):
    """Execution environment reached over ssh, with files moved via rsync.

    Every operation shells out locally to ``ssh`` or ``rsync`` targeting
    ``self.env_config.ssh_host``.
    """

    # NOTE(review): ``self.env_config`` is not assigned in this class; it is
    # presumably provided by ExecutionEnv (an SshEnvConfig) — confirm.

    def get_copy_excludes(self) -> Sequence[str]:
        """Returns a list of patterns to exclude via rsync's --exclude flags.

        Override in a subclass to change what is skipped when copying.
        """
        return DEFAULT_EXCLUDES

    def _remote_path(self, path: str) -> str:
        """Formats ``path`` as an rsync ``host:path`` spec on the ssh host."""
        return f"{self.env_config.ssh_host}:{path}"

    def _make_rsync_cmd(self, src: str, dst: str) -> List[str]:
        """Builds the argv for an rsync-over-ssh copy from ``src`` to ``dst``."""
        # -a (archive mode) already implies -r; the extra -r is harmless.
        cmd = ["rsync", "-ra", "-e", "ssh"]
        for pattern in self.get_copy_excludes():
            cmd.extend(["--exclude", pattern])
        cmd.extend([src, dst])
        return cmd

    def _run_locally_via_subprocess(self, cmd: Sequence[str]) -> bytes:
        """Runs a command locally via subprocess.check_output(cmd).

        Args:
            cmd: the argv to execute (no local shell is involved).

        Returns:
            The command's combined stdout and stderr, as bytes.

        Raises:
            subprocess.CalledProcessError: if the command exits non-zero. The
                captured output is printed before the error is re-raised.
        """
        try:
            return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            # check_output captures bytes; decode so the diagnostic shows real
            # text instead of a b'...' literal with escaped newlines.
            output = e.output
            if isinstance(output, bytes):
                output = output.decode(errors="replace")
            print(output)
            raise

    def copy_to_env(self, localpath: str, dstpath: str) -> bytes:
        """Copies a local folder to the execution environment.

        Args:
            localpath: the path to the local folder/file to copy over. A
                leading ``~`` is expanded locally.
            dstpath: the destination folder/file to copy to. It is passed
                through unchanged, so it is resolved on the remote side.

        Returns:
            The combined stdout/stderr of the rsync invocation.
        """
        localpath = os.path.expanduser(localpath)
        local_cmd = self._make_rsync_cmd(localpath, self._remote_path(dstpath))
        return self._run_locally_via_subprocess(local_cmd)

    def copy_from_env(self, envpath: str, localpath: str) -> bytes:
        """Copies from the remote execution environment to local machine.

        Args:
            envpath: the path to the folder/file on the environment to copy
                over. It is passed through unchanged.
            localpath: the destination folder/file on the local machine to
                copy to. A leading ``~`` is expanded locally.

        Returns:
            The combined stdout/stderr of the rsync invocation.
        """
        localpath = os.path.expanduser(localpath)
        local_cmd = self._make_rsync_cmd(self._remote_path(envpath), localpath)
        return self._run_locally_via_subprocess(local_cmd)

    def run_command(self, cmd: str) -> bytes:
        """Run the command on the execution environment.

        NOTE: Need to have password-free access to the remote ssh server set
        up. See https://serverfault.com/questions/2429 for how to do this.

        Args:
            cmd: the command to run. Will typically be a shell command. It is
                handed to ssh as a single argument.

        Returns:
            The combined stdout/stderr of the ssh invocation.
        """
        local_cmd = [
            "ssh",
            self.env_config.ssh_host,
            cmd,
        ]
        return self._run_locally_via_subprocess(local_cmd)
