# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Submit a function to be run either locally or in a computing cluster."""

import copy
import inspect
import os
import pathlib
import pickle
import platform
import pprint
import re
import shutil
import sys
import time
import traceback

from enum import Enum

from .. import util
from ..util import EasyDict

from . import internal

class SubmitTarget(Enum):
    """Enumerates the places a submitted function can be executed.

    Members:
        LOCAL: Execute the function on the local machine.
    """
    LOCAL = 1


class PathType(Enum):
    """Selects the style in which a path string is rendered.

    Members:
        WINDOWS: Render the path in Windows style.
        LINUX: Render the path in Linux/Posix style.
        AUTO: Pick WINDOWS or LINUX based on the OS the code is running on.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3


class PlatformExtras:
    """Grab bag of tunables consumed by dnnlib heuristics.

    Attributes:
        data_reader_buffer_size: Size used by DataReader for its internal shared memory buffers.
        data_reader_process_count: How many worker processes to spawn (zero means single thread operation).
    """
    def __init__(self):
        # 2**30 bytes, i.e. 1 GB.
        self.data_reader_buffer_size = 2 ** 30
        # Zero selects the single threaded default.
        self.data_reader_process_count = 0


# Module-wide user name override. While this is None, get_user_name() asks the
# operating system; set_user_name_override() assigns it.
_user_name_override = None

class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    The constructor fills in defaults for every field. Fields in the "set these"
    groups are meant to be assigned by the caller before submit_run() is called;
    the remaining fields are filled in during submission.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run. submit_run() rejects None and zero.
        print_info: Whether to print debug information when submitting.
        nvprof: Defaults to False. Not read anywhere in this module; presumably toggles NVIDIA profiler support in the launch target (TODO confirm).
        local: Options object for the local target, created from internal.local.TargetOptions().
        local.do_not_copy_source_files: Do not copy source files from the working directory to the run dir.
        datasets: Defaults to an empty list. Not read anywhere in this module (TODO confirm intended use).
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit. Defaults to "localhost".
        platform_extras: Automatically populated values during submit.  Used by various dnnlib libraries such as the DataReader class.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode", "_cudacache"]
        self.run_dir_extra_files = []

        # submit (set these)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        self.nvprof = False
        self.local = internal.local.TargetOptions()
        self.datasets = []

        # (automatically populated)
        self.run_id = None
        self.run_name = None
        self.run_dir = None
        self.run_func_name = None
        self.run_func_kwargs = None
        self.user_name = None
        self.task_name = None
        self.host_name = "localhost"
        self.platform_extras = PlatformExtras()


def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Replace tags in the given path template and return either Windows or Linux formatted path.

    The only tag handled here is "<USERNAME>", which is substituted with the
    value returned by get_user_name().

    Args:
        path_template: Path that may contain the "<USERNAME>" tag.
        path_type: Desired output style. PathType.AUTO selects WINDOWS on
            Windows and LINUX (Posix style) on Linux and macOS.

    Returns:
        The path with tags replaced, normalized through the matching pathlib
        pure path class.

    Raises:
        RuntimeError: If path_type is AUTO on an unrecognized OS, or if
            path_type is not a known PathType member.
    """
    # automatically select path type depending on running OS
    if path_type == PathType.AUTO:
        system = platform.system()
        if system == "Windows":
            path_type = PathType.WINDOWS
        elif system in ("Linux", "Darwin"):
            # macOS ("Darwin") uses Posix style paths just like Linux.
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    # Only resolve the user name when the tag is present, so templates without
    # it do not fail just because the user name lookup is unavailable.
    if "<USERNAME>" in path_template:
        path_template = path_template.replace("<USERNAME>", get_user_name())

    # return correctly formatted path
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(path_template))
    elif path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(path_template))
    else:
        raise RuntimeError("Unknown platform")


def get_template_from_path(path: str) -> str:
    """Turn a concrete path into its template form by using forward slashes only."""
    # Splitting on backslashes and re-joining with "/" swaps every separator.
    return "/".join(path.split("\\"))


def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
    """Round-trip a path through its template form to re-render it in the given path type.

    The path is first converted to a template with get_template_from_path and
    then converted back to a normal path with get_path_from_template.
    """
    return get_path_from_template(get_template_from_path(path), path_type)


def set_user_name_override(name: str) -> None:
    """Store `name` as the module-wide user name override.

    get_user_name() returns this value instead of querying the OS for as long
    as it is not None.
    """
    global _user_name_override
    _user_name_override = name


def get_user_name():
    """Get the current user name.

    Resolution order:
      1. The override set through set_user_name_override(), if not None.
      2. On Windows, os.getlogin().
      3. On Linux and macOS, the password database entry of the effective
         user id; "unknown" if that lookup fails.

    Returns:
        The user name as a string.

    Raises:
        RuntimeError: If no override is set and the OS is not recognized.
    """
    if _user_name_override is not None:
        return _user_name_override

    system = platform.system()
    if system == "Windows":
        return os.getlogin()
    if system in ("Linux", "Darwin"):
        try:
            import pwd
            return pwd.getpwuid(os.geteuid()).pw_name
        except Exception:
            # Best effort lookup. Catch Exception rather than using a bare
            # except so KeyboardInterrupt and SystemExit still propagate.
            return "unknown"
    raise RuntimeError("Unknown platform")


def make_run_dir_path(*paths):
    """Build a path/filename located under the current submit run_dir.

    Args:
        *paths: Path components to be passed to os.path.join

    Returns:
        A file/dirname rooted at submit_config.run_dir.  When there is no
        submit_config, or its run_dir is None, the current working
        directory is used as the base instead.

    E.g., `os.path.join(dnnlib.submit_config.run_dir, "output.txt")`
    """
    import dnnlib
    active_config = dnnlib.submit_config
    have_run_dir = (active_config is not None) and (active_config.run_dir is not None)
    base_dir = active_config.run_dir if have_run_dir else os.getcwd()
    return os.path.join(base_dir, *paths)


def _create_run_dir_local(submit_config: SubmitConfig) -> str:
    """Create a new run dir with increasing ID number at the start.

    As a side effect this assigns submit_config.run_id and
    submit_config.run_name.

    Args:
        submit_config: Config providing run_dir_root and run_desc.

    Returns:
        Path of the newly created run directory
        ("<run_dir_root>/<run_id as 5 digits>-<run_desc>").

    Raises:
        RuntimeError: If the run directory already exists.
    """
    run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)

    # exist_ok avoids the check-then-create race when the root is created
    # concurrently by another submission.
    os.makedirs(run_dir_root, exist_ok=True)

    submit_config.run_id = _get_next_run_id_local(run_dir_root)
    submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
    run_dir = os.path.join(run_dir_root, submit_config.run_name)

    # Let makedirs itself detect the collision so that a directory created by
    # another process after the ID scan is reported with the same RuntimeError.
    try:
        os.makedirs(run_dir)
    except FileExistsError as err:
        raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) from err

    return run_dir


def _get_next_run_id_local(run_dir_root: str) -> int:
    """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
    dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
    r = re.compile("^\\d+")  # match one or more digits at the start of the string
    run_id = 0

    for dir_name in dir_names:
        m = r.match(dir_name)

        if m is not None:
            i = int(m.group())
            run_id = max(run_id, i + 1)

    return run_id


def _populate_run_dir(submit_config: SubmitConfig, run_dir: str) -> None:
    """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.

    Always writes "submit_config.pkl" (pickle) and "submit_config.txt"
    (pretty-printed) into run_dir. Unless the target is LOCAL and
    submit_config.local.do_not_copy_source_files is set, it also copies:
      * the source tree that contains the run function,
      * the dnnlib package,
      * every entry of submit_config.run_dir_extra_files,
    all below "<run_dir>/src", plus dnnlib's submission/internal/run.py as
    "<run_dir>/run.py".

    Args:
        submit_config: Fully populated submit config; run_func_name must be a
            dotted name.
        run_dir: Existing, writable directory to populate.
    """
    # Use a context manager so the pickle file handle is flushed and closed
    # deterministically instead of being left to the garbage collector.
    with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
        pickle.dump(submit_config, f)
    with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)

    if (submit_config.submit_target == SubmitTarget.LOCAL) and submit_config.local.do_not_copy_source_files:
        return

    files = []

    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
    assert '.' in submit_config.run_func_name
    # Walk up one directory for every package level beyond "module.function"
    # in the dotted run function name.
    for _idx in range(submit_config.run_func_name.count('.') - 1):
        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)

    dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)

    files += submit_config.run_dir_extra_files

    # Re-root every relative destination under "<run_dir>/src".
    files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
    # run.py is placed directly in the run dir rather than under src.
    files += [(os.path.join(dnnlib_module_dir_path, "submission", "internal", "run.py"), os.path.join(run_dir, "run.py"))]

    util.copy_files_and_create_dirs(files)



def run_wrapper(submit_config: SubmitConfig) -> SubmitConfig:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc.

    Resolves submit_config.run_func_name to a callable and invokes it with
    submit_config.run_func_kwargs. If the callable has a parameter named
    "submit_config", the config is passed along as that keyword argument.
    While the function runs, dnnlib.submit_config is set to the given config.
    A "_finished.txt" marker file is created in the run dir whether or not the
    function succeeded.

    Args:
        submit_config: Populated submit config (run_dir, run_func_name,
            run_func_kwargs, host_name, ... must be set).

    Returns:
        The submit_config that was passed in.

    Raises:
        Any exception raised by the run function, when the target is LOCAL.
        SystemExit: With exit code 1 when the run function failed on a
            non-local target (after the logs are closed).
    """
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
    if is_local:
        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
        logger = util.Logger(file_name=None, should_flush=True)

    # Publish the config globally for the duration of the run; make_run_dir_path() reads it.
    import dnnlib
    dnnlib.submit_config = submit_config

    exit_with_errcode = False
    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()

        run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
        assert callable(run_func_obj)
        # Only hand over submit_config if the run function declares a parameter for it.
        sig = inspect.signature(run_func_obj)
        if 'submit_config' in sig.parameters:
            run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
        else:
            run_func_obj(**submit_config.run_func_kwargs)

        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        if is_local:
            # NOTE(review): re-raising here skips the cleanup below (RunContext close,
            # dnnlib.submit_config reset, logger.close); only the finally clause runs.
            raise
        else:
            traceback.print_exc()

            # Copy the log next to the run dir as "<run_name>-error.txt" under the run dir root.
            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)

            # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt
            exit_with_errcode = True
    finally:
        # Marker file: created on success and on failure alike.
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    # Tear down the global state set up above.
    dnnlib.RunContext.get().close()
    dnnlib.submit_config = None
    logger.close()

    # If we hit an error, get out of the script now and signal the error
    # to whatever process that started this script.
    if exit_with_errcode:
        sys.exit(1)

    return submit_config


def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.

    The given submit_config is deep-copied first, so the caller's object is
    left untouched. The value returned by the target's submit() is passed
    back to the caller.

    Args:
        submit_config: Config describing where and how to run.
        run_func_name: Dotted name of the function to run.
        **run_func_kwargs: Keyword arguments forwarded to the run function.

    Raises:
        RuntimeError: If num_gpus is None or zero, or if the generated task
            name is rejected by the docker name regex.
    """
    # Everything below mutates only this private copy.
    submit_config = copy.deepcopy(submit_config)

    # Select the backend object that performs the launch.
    farm = internal.local.Target() if submit_config.submit_target == SubmitTarget.LOCAL else None
    assert farm is not None # unknown target

    # Disallow submitting jobs with zero num_gpus.
    gpu_count = submit_config.num_gpus
    if gpu_count is None or gpu_count == 0:
        raise RuntimeError("submit_config.num_gpus must be set to a non-zero value")

    # A user name supplied by the caller takes precedence over the detected one.
    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    #--------------------------------------------------------------------
    # Prepare submission by populating the run dir
    #--------------------------------------------------------------------
    host_run_dir = _create_run_dir_local(submit_config)

    task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
    submit_config.task_name = task_name

    docker_valid_name_regex = "^[a-zA-Z0-9][a-zA-Z0-9_.-]+$"
    if re.match(docker_valid_name_regex, task_name) is None:
        message = "Invalid task name.  Probable reason: unacceptable characters in your submit_config.run_desc.  Task name must be accepted by the following regex: " + docker_valid_name_regex + ", got " + task_name
        raise RuntimeError(message)

    # Farm specific preparations for a submit
    farm.finalize_submit_config(submit_config, host_run_dir)
    _populate_run_dir(submit_config, host_run_dir)
    return farm.submit(submit_config, host_run_dir)
