"""JIT C++ strings into executables."""
import atexit
import os
import re
import shutil
import textwrap
import threading
from typing import Any, List, Optional

import torch
from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType
from torch.utils.benchmark.utils.common import _make_temp_dir
from torch.utils import cpp_extension


# Module-wide lock shared by the helpers in this file that touch global state.
LOCK = threading.Lock()

# Absolute path of the directory that contains this file.
SOURCE_ROOT = os.path.dirname(os.path.abspath(__file__))

# We calculate uuid once per process (lazily, on the first call to
# `_get_build_root`, rather than once per invocation) so that separate
# processes will have separate build roots, but threads will share the same
# build root.
# `cpp_extension` uses build root as part of the cache key, so per-invocation
# uuids (e.g. different build root per _compile_template call) would lead to
# a 0% cache hit rate and spurious recompilation. Consider the following:
#   ```
#   setup = "auto x = torch::ones({1024, 1024});"
#   stmt = "torch::mm(x, x);"
#   for num_threads in [1, 2, 4, 8]:
#     print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange())
#   ```
# `setup` and `stmt` do not change, so we can reuse the executable from the
# first pass through the loop.
_BUILD_ROOT: Optional[str] = None

def _get_build_root() -> str:
    """Return the build root directory for this process, creating it on first use.

    The first call creates a directory with `_make_temp_dir`, stores its path
    in the module-level `_BUILD_ROOT`, and registers `shutil.rmtree` on that
    path with `atexit`. Later calls return the stored path unchanged.

    This function does no locking of its own; `_compile_template` calls it
    while holding `LOCK`.
    """
    global _BUILD_ROOT
    if _BUILD_ROOT is not None:
        # Already initialised by an earlier call.
        return _BUILD_ROOT

    _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
    # Remove the whole build tree when the interpreter exits.
    atexit.register(shutil.rmtree, _BUILD_ROOT)
    return _BUILD_ROOT


# BACK_TESTING_NOTE:
#   There are two workflows where this code could be used. One is the obvious
#   case where someone simply builds or installs PyTorch and uses Timer.
#   The other is that the entire `torch/utils/benchmark` folder from a CURRENT
#   PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
#   source code. This is what we refer to here as "back testing". The rationale
#   is that we might want to use current tooling to study some aspect of an
#   earlier version of PyTorch. (e.g. a regression.)
#
#   The problem is that Timer relies on several aspects of core PyTorch, namely
#   some binding functions for Valgrind symbols in `torch._C` and the
#   `torch.__config__._cxx_flags()` method. If we were to naively copy code
#   around this wouldn't work as the symbols of interest aren't present in
#   earlier versions of PyTorch. In order to work around this, we must add back
#   testing shims. These shims will never activate during normal use, but will
#   allow Timer to function outside of the "correct" version of PyTorch by
#   emulating functionality that was added later.
#
#   These shims are temporary, and as Timer becomes more integrated with
#   PyTorch the cost and complexity of such shims will increase. Once back
#   testing is no longer required (which is to say we have done enough historic
#   analysis and the shims no longer justify their maintenance and code
#   complexity costs) back testing paths will be removed.

# Compiler flags forwarded to `cpp_extension.load` as `extra_cflags`.
# `None` means no flags could be obtained from `torch.__config__`.
CXX_FLAGS: Optional[List[str]]
if not hasattr(torch.__config__, "_cxx_flags"):
    # Back testing shim (see BACK_TESTING_NOTE): this PyTorch predates
    # `torch.__config__._cxx_flags`, so fall back to a fixed set of flags.
    # FIXME: Remove when back testing is no longer required.
    CXX_FLAGS = ["-O2", "-fPIC", "-g"]
else:
    try:
        CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
    except RuntimeError:
        # We are in FBCode.
        CXX_FLAGS = None
    else:
        # Always request debug info, on top of whatever PyTorch was built with.
        if "-g" not in CXX_FLAGS:
            CXX_FLAGS.append("-g")

# Include directories forwarded to `cpp_extension.load` as
# `extra_include_paths`. The `valgrind_wrapper` directory next to this file is
# always present in the list.
EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]

CONDA_PREFIX = os.environ.get("CONDA_PREFIX")
if CONDA_PREFIX is not None:
    # Load will automatically search /usr/include, but not conda include.
    EXTRA_INCLUDE_PATHS += [os.path.join(CONDA_PREFIX, "include")]


COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None
def get_compat_bindings() -> CallgrindModuleType:
    """Return the compat Callgrind bindings module, building it on first use.

    The first call builds `valgrind_wrapper/compat_bindings.cpp` (located next
    to this file) with `cpp_extension.load` and stores the result in the
    module-level `COMPAT_CALLGRIND_BINDINGS`. Later calls return the stored
    object. The check and the build both happen while holding `LOCK`.
    """
    global COMPAT_CALLGRIND_BINDINGS
    with LOCK:
        if COMPAT_CALLGRIND_BINDINGS is not None:
            return COMPAT_CALLGRIND_BINDINGS

        bindings_src = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "compat_bindings.cpp")
        COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
            name="callgrind_bindings",
            sources=[bindings_src],
            extra_cflags=CXX_FLAGS,
            extra_include_paths=EXTRA_INCLUDE_PATHS,
        )
        return COMPAT_CALLGRIND_BINDINGS


def _compile_template(
    *,
    stmt: str,
    setup: str,
    global_setup: str,
    src: str,
    is_standalone: bool
) -> Any:
    """Splice C++ snippets into a template and build it with `cpp_extension`.

    Args:
        stmt: C++ code inserted at `// STMT_TEMPLATE_LOCATION`.
        setup: C++ code inserted at `// SETUP_TEMPLATE_LOCATION`.
        global_setup: C++ code inserted at `// GLOBAL_SETUP_TEMPLATE_LOCATION`.
        src: Template source text containing the marker comments above.
        is_standalone: Forwarded to `cpp_extension.load`. When False the
            result is loaded as a Python module (`is_python_module=True`).

    Returns:
        Whatever `cpp_extension.load` returns for the requested mode.
    """
    for before, after, indentation in (
        ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
        ("// SETUP_TEMPLATE_LOCATION", setup, 4),
        ("// STMT_TEMPLATE_LOCATION", stmt, 8)
    ):
        # C++ doesn't care about indentation so this code isn't load
        # bearing the way it is with Python, but this makes the source
        # look nicer if a human has to look at it.
        replacement = textwrap.indent(after, " " * indentation)[indentation:]

        # The snippets are arbitrary C++ and may contain backslashes (e.g.
        # "\n" inside a string literal). A string `repl` would have those
        # escapes processed by `re.sub`, which would silently rewrite the
        # snippet or raise `re.error`. The return value of a callable `repl`
        # is inserted verbatim. `text=replacement` binds the current value.
        src = re.sub(before, lambda _match, text=replacement: text, src)

    # We want to isolate different Timers. However `cpp_extension` will
    # cache builds which will significantly reduce the cost of repeated
    # invocations.
    with LOCK:
        # Identical generated sources map to the same name and build
        # directory within this process, so a previous build can be reused.
        name = f"timer_cpp_{abs(hash(src))}"
        build_dir = os.path.join(_get_build_root(), name)
        os.makedirs(build_dir, exist_ok=True)

        src_path = os.path.join(build_dir, "timer_src.cpp")
        with open(src_path, "wt") as f:
            f.write(src)

    # `cpp_extension` has its own locking scheme, so we don't need our lock.
    return cpp_extension.load(
        name=name,
        sources=[src_path],
        build_directory=build_dir,
        extra_cflags=CXX_FLAGS,
        extra_include_paths=EXTRA_INCLUDE_PATHS,
        is_python_module=not is_standalone,
        is_standalone=is_standalone,
    )


def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
    """Build the timeit C++ template for the given snippets.

    Reads `timeit_template.cpp` from the directory containing this file and
    hands it to `_compile_template` with `is_standalone=False`.

    Args:
        stmt: C++ statement snippet forwarded to `_compile_template`.
        setup: C++ setup snippet forwarded to `_compile_template`.
        global_setup: C++ global setup snippet forwarded to `_compile_template`.

    Returns:
        The object returned by `_compile_template`, asserted to be a
        `TimeitModuleType`.
    """
    with open(os.path.join(SOURCE_ROOT, "timeit_template.cpp"), "rt") as template_file:
        template_src: str = template_file.read()

    compiled = _compile_template(
        stmt=stmt,
        setup=setup,
        global_setup=global_setup,
        src=template_src,
        is_standalone=False,
    )
    assert isinstance(compiled, TimeitModuleType)
    return compiled


def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
    """Build the Callgrind C++ template for the given snippets.

    Reads `valgrind_wrapper/timer_callgrind_template.cpp` from the directory
    containing this file and hands it to `_compile_template` with
    `is_standalone=True`.

    Args:
        stmt: C++ statement snippet forwarded to `_compile_template`.
        setup: C++ setup snippet forwarded to `_compile_template`.
        global_setup: C++ global setup snippet forwarded to `_compile_template`.

    Returns:
        The object returned by `_compile_template`, asserted to be a `str`.
        NOTE(review): presumably the path of the built executable; confirm
        against `cpp_extension.load`.
    """
    callgrind_template = os.path.join(
        SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp"
    )
    with open(callgrind_template, "rt") as template_file:
        template_src: str = template_file.read()

    built = _compile_template(
        stmt=stmt,
        setup=setup,
        global_setup=global_setup,
        src=template_src,
        is_standalone=True,
    )
    assert isinstance(built, str)
    return built
