# Copyright (c) Facebook, Inc. and its affiliates.

import os
import torch

from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager

from .torchscript_patch import freeze_training_mode, patch_instances

# Names exported by a star-import of this module. The legacy alias
# `export_torchscript_with_instances` is not listed here.
__all__ = ["scripting_with_instances", "dump_torchscript_IR"]


def scripting_with_instances(model, fields):
    """
    Script a model that uses the :class:`Instances` class with :func:`torch.jit.script`.

    In eager mode the attributes of :class:`Instances` are added "dynamically", which
    scripting cannot handle out of the box. To make such a model scriptable, this function:

    1. Creates a scriptable ``new_Instances`` class that behaves similarly to ``Instances``,
       but whose attributes are all "static". These attributes must be declared up front
       through the ``fields`` argument.
    2. Registers ``new_Instances`` and forces the scripting compiler to use it whenever it
       tries to compile ``Instances``.

    The patching is reverted once this function returns, so another model can be scripted
    afterwards with a different set of fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        Only models in evaluation mode are supported.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and their corresponding types that
            ``Instances`` will use in the model. Every attribute used in ``Instances``
            must be listed, regardless of whether it is an input/output of the model.
            Data types not defined in detectron2 are not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format

    Raises:
        AssertionError: if PyTorch is older than 1.8, or if ``model`` is in training mode.
    """
    # Preconditions: a recent enough PyTorch, and a model that is in eval mode.
    assert TORCH_VERSION >= (1, 8), "This feature is not available in PyTorch < 1.8"
    assert (
        not model.training
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    # Both context managers stay active for the whole compilation and are exited
    # (in reverse order) as soon as the scripted module has been produced.
    with freeze_training_mode(model):
        with patch_instances(fields):
            return torch.jit.script(model)


# Backward-compatible alias: `export_torchscript_with_instances` is the old name of
# `scripting_with_instances`. Both names refer to the same function object.
export_torchscript_with_instances = scripting_with_instances


def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph,
    inlined graph). Useful for debugging.

    The following files are written into ``dir``:

    * ``model_ts_code.txt``: pretty-printed TorchScript code of ``model`` and, for modules,
      of every submodule reachable through ``named_children()``.
    * ``model_ts_IR.txt``: the graph IR of ``model``.
    * ``model_ts_IR_inlined.txt``: ``str(model.inlined_graph)``.
    * ``model.txt``: ``str(model)``; not written when ``model`` is a ``ScriptFunction``.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files. ``~`` is expanded and
            ``PathManager.mkdirs`` is called on the result before writing.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        # For a TracedModule, the private `_c` handle is reached via its
        # `_actual_script_module`; any other object is used as-is.
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    def _get_code(mod):
        # Try a few ways to get code using private attributes.
        try:
            # This contains more information than just `mod.code`
            return _get_script_mod(mod)._c.code
        except AttributeError:
            pass
        # Fall back to the public attribute; None signals "no code available".
        return getattr(mod, "code", None)

    def _dump_code(f, prefix, mod):
        # Write the code of `mod`, then recurse into its direct children.
        code = _get_code(mod)
        display_name = prefix or "root model"
        if code is None:
            f.write(f"Could not found code for {display_name} (type={mod.original_name})\n")
            f.write("\n")
        else:
            f.write(f"\nCode for {display_name}, type={mod.original_name}:\n")
            f.write(code)
            f.write("\n")
            f.write("-" * 80)

        for child_name, child in mod.named_children():
            # Join with "." only below the root, so direct children are labelled
            # "backbone" rather than ".backbone".
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            _dump_code(f, child_prefix, child)

    def _get_graph(mod):
        try:
            # Recursively dump IR of all modules
            return _get_script_mod(mod)._c.dump_to_str(True, False, False)
        except AttributeError:
            # Objects without `_c.dump_to_str` only expose their own graph.
            return mod.graph.str()

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:
        if isinstance(model, torch.jit.ScriptFunction):
            code = _get_code(model)
            if code is None:
                # `f.write(None)` would raise TypeError; record the absence instead.
                f.write("No code available for the given ScriptFunction\n")
            else:
                f.write(code)
        else:
            _dump_code(f, "", model)

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))
