"""Common utilities for running on craffel's servers."""
import base64
import dataclasses
import json
import os
import shlex
from typing import Any, Mapping, Optional, Sequence
import uuid


from . import ssh_execution_util


# Short host names of the servers that can be targeted. EnvConfig checks
# `server` against this set and builds the SSH host as
# f"{user}@{server}.cs.unc.edu".
_SERVERS = frozenset({"watermelon", "guava", "banana", "mango"})

# The file paths here are relative to the project code directory.
_DEFAULT_REQUIREMENTS_FILE = "requirements.txt"
# Script that wraps and launches the experiment command on the server.
# EnvConfig.start_process_script reads this name, so it must be defined.
_START_PROCESS_FILE = "mm/execution/bgmanager/start_process.py"


@dataclasses.dataclass
class RunConfig:
    """Per-run launch options: which GPUs to expose and whether to pip-install."""

    # GPU ids made visible to the job. As the servers stood on 2021-09-02,
    # every id should be one of {0, 1, 2, 3}.
    gpus: Sequence[int]

    # If True, install the project's pip requirements before launching.
    install_pip_requirements: bool = False

    @property
    def num_gpus(self) -> int:
        """How many GPU ids were requested."""
        gpu_ids = self.gpus
        return len(gpu_ids)


@dataclasses.dataclass
class EnvConfig(ssh_execution_util.SshEnvConfig):
    """Where and as whom to run on one of the servers in `_SERVERS`.

    Fields prefixed with `remote_` are paths on `server`;
    `local_project_code_dir` is a path on the launching machine.

    Raises:
      ValueError: if `server` is not one of `_SERVERS`.
    """

    # Login name used to build `ssh_host`.
    user: str
    # Short host name; must be a member of `_SERVERS`.
    server: str

    # Name of the virtualenv to use. Must already exist but does not
    # have to be initialized.
    virtualenv_name: str

    local_project_code_dir: str

    remote_project_code_dir: str
    remote_project_data_dir: str
    remote_logs_dir: str
    remote_tfds_data_dir: Optional[str] = None

    # The file path here is relative to the project code
    # directory.
    requirements_file: str = _DEFAULT_REQUIREMENTS_FILE

    def __post_init__(self):
        # Raise explicitly instead of `assert` so the check is not stripped
        # when Python runs with -O.
        if self.server not in _SERVERS:
            raise ValueError(
                f"Unknown server {self.server!r}; "
                f"expected one of {sorted(_SERVERS)}."
            )

    @property
    def ssh_host(self) -> str:
        """SSH destination in `user@host` form."""
        return f"{self.user}@{self.server}.cs.unc.edu"

    @property
    def start_process_script(self) -> str:
        """Remote path of the start-process script inside the code dir."""
        return os.path.join(self.remote_project_code_dir, _START_PROCESS_FILE)

    @property
    def requirements_filepath(self) -> str:
        """Remote path of the pip requirements file inside the code dir."""
        return os.path.join(self.remote_project_code_dir, self.requirements_file)


@dataclasses.dataclass
class FruitEnv(ssh_execution_util.SshEnv):
    """Launches experiments on one of the servers over SSH.

    A launch copies the local project code to the server and then runs a
    shell snippet that:
      1. activates the virtualenv (optionally pip-installing requirements),
      2. appends the remote code directory to PYTHONPATH, and
      3. invokes the start-process script with the base64-encoded command.
    """

    # GPU selection and per-run options.
    run_config: RunConfig

    def _generate_run_id(self) -> str:
        """Return a fresh random run id (32 lowercase hex characters)."""
        return uuid.uuid4().hex

    def _move_code_to_remote(self):
        """Copy the local project code directory to the server.

        The destination passed to `copy_to_env` is the *parent* of
        `remote_project_code_dir`. Presumably `copy_to_env` places the source
        directory itself inside the destination. NOTE(review): that means the
        local and remote code dirs must share a basename; confirm.
        """
        # Normalize first: os.path.dirname("/a/b/") is "/a/b", not "/a",
        # which would copy the code one level too deep.
        remote_code_dir = os.path.normpath(self.env_config.remote_project_code_dir)
        self.copy_to_env(
            self.env_config.local_project_code_dir,
            os.path.dirname(remote_code_dir),
        )

    def _make_start_process_command(
        self, cmd: str, run_id: str, extra_info: Mapping[str, Any]
    ) -> str:
        """Build the shell command that runs `cmd` via the start-process script.

        Args:
          cmd: Command to execute remotely. It is passed base64-encoded, so it
            needs no shell escaping.
          run_id: Run identifier; also names the log file `<run_id>.log`
            inside `remote_logs_dir`.
          extra_info: JSON-serializable metadata forwarded as `--extra_info`.

        Returns:
          A single space-joined shell command line.
        """
        ec = self.env_config
        rc = self.run_config
        remote_logs_file = os.path.join(ec.remote_logs_dir, f"{run_id}.log")
        env_vars = json.dumps(
            {"CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in rc.gpus)}
        )
        # Standard base64 only emits [A-Za-z0-9+/=], all shell-safe.
        cmd_b64 = base64.b64encode(cmd.encode("utf8")).decode("utf8")
        parts = [
            f"python {ec.start_process_script}",
            f"--cmd_b64={cmd_b64}",
            f"--env_vars={shlex.quote(env_vars)}",
            f"--run_id={run_id}",
            f"--output_file={remote_logs_file}",
            f"--extra_info={shlex.quote(json.dumps(extra_info))}",
        ]
        return " ".join(parts)

    def _make_virtualenv_command(self):
        """Return shell lines that activate the virtualenv.

        If `run_config.install_pip_requirements` is set, a `pip install -r`
        of the requirements file is added. `workon` is presumably provided
        by virtualenvwrapper on the server; confirm.
        """
        ec = self.env_config
        cmd = f"workon {ec.virtualenv_name}"
        if self.run_config.install_pip_requirements:
            cmd += f"\npip install -r {ec.requirements_filepath}"
        return cmd

    def _make_pythonpath_command(self) -> str:
        """Return a shell line appending the remote code dir to PYTHONPATH.

        When PYTHONPATH is unset or empty, a naive `$PYTHONPATH:<dir>`
        produces a leading empty entry, which Python treats as the current
        directory. Branching avoids that. The path stays unquoted and
        directly after `:` or `=`, so a leading `~` still expands.
        """
        code_dir = self.env_config.remote_project_code_dir
        return (
            'if [ -n "$PYTHONPATH" ]; then '
            f"export PYTHONPATH=$PYTHONPATH:{code_dir}; "
            f"else export PYTHONPATH={code_dir}; fi"
        )

    def launch_experiment(self, script_config, args_config):
        """Copy code to the server and start the experiment there.

        Args:
          script_config: Dataclass instance providing `make_run_command`,
            which builds the Python command to run.
          args_config: Dataclass instance with the script's arguments.

        `dataclasses.asdict` is applied to both configs, so they must be
        dataclass instances.
        """
        ec = self.env_config

        run_id = self._generate_run_id()

        self._move_code_to_remote()

        python_cmd = script_config.make_run_command(
            args_config=args_config,
            env_config=ec,
            num_gpus=self.run_config.num_gpus,
        )
        print(python_cmd)

        # Metadata recorded alongside the run by the start-process script.
        extra_info = {
            "env_config": dataclasses.asdict(ec),
            "run_config": dataclasses.asdict(self.run_config),
            "script_config": dataclasses.asdict(script_config),
            "args_config": dataclasses.asdict(args_config),
        }

        start_cmd = self._make_start_process_command(python_cmd, run_id, extra_info)

        # Newline-separated so the steps run in order in the remote shell.
        ll_cmd = "\n".join(
            [
                self._make_virtualenv_command(),
                self._make_pythonpath_command(),
                start_cmd,
            ]
        )

        # Actually run on the server.
        self.run_command(ll_cmd)
