# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-strict
# pyre-ignore-all-errors[2,33]

import logging
import typing
import warnings
from collections import Counter
from copy import copy
from dataclasses import dataclass
from numbers import Number
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar, Union

import numpy as np
import torch
import torch.nn as nn
from fvcore.common.checkpoint import _named_modules_with_dup
from torch import Tensor
from torch.jit import _get_trace_graph, TracerWarning

from .jit_handles import Handle


# Type variable bound to JitModelAnalysis. The fluent configuration methods
# are annotated ``self: T -> T`` so that chaining keeps the caller's type.
T = TypeVar("T", bound="JitModelAnalysis")


# Only ignore ops that are technically truly 0 flops:
# shape-manipulation ops, integer ops, memory copy ops
#
# This is the default ignore list. JitModelAnalysis copies it into a
# per-instance set (on construction and in clear_op_handles()), so ops
# ignored on one analysis object are never added to this module-level set.
_IGNORED_OPS: Set[str] = {
    "aten::Int",
    "aten::ScalarImplicit",
    "aten::__and__",
    "aten::arange",
    "aten::bitwise_not",
    "aten::cat",
    "aten::chunk",
    "aten::clamp",
    "aten::clamp_",
    "aten::constant_pad_nd",
    "aten::contiguous",
    "aten::copy_",
    "aten::detach",
    "aten::dropout",
    "aten::empty",
    "aten::eq",
    "aten::expand",
    "aten::flatten",
    "aten::floor",
    "aten::floor_divide",
    "aten::full",
    "aten::full_like",
    "aten::gather",
    "aten::ge",
    "aten::gt",
    "aten::index",
    "aten::index_put_",
    "aten::masked_fill",
    "aten::max",
    "aten::narrow",
    "aten::new_empty",
    "aten::new_full",
    "aten::new_zeros",
    "aten::nonzero",
    "aten::ones",
    "aten::permute",
    "aten::relu",
    "aten::relu_",
    "aten::remainder",
    "aten::reshape",
    "aten::roll",
    "aten::select",
    "aten::size",
    "aten::slice",
    "aten::split",
    "aten::split_with_sizes",
    "aten::squeeze",
    "aten::stack",
    "aten::t",
    "aten::to",
    "aten::transpose",
    "aten::type_as",
    "aten::unbind",
    "aten::unsqueeze",
    "aten::unsqueeze_",
    "aten::view",
    "aten::zeros",
    "aten::zeros_like",
}


@dataclass
class Statistics:
    """
    For keeping track of the various model statistics recorded during
    analysis.

    An instance is built at the end of ``JitModelAnalysis._analyze`` and
    cached on the analysis object until the operator handles change.
    """

    # Canonical module name -> Counter of the statistic, keyed by the operator
    # names that the handles return.
    counts: "Dict[str, Counter[str]]"
    # Canonical module name -> Counter of how many times each operator kind was
    # seen without a handle and without being ignored.
    unsupported_ops: "Dict[str, Counter[str]]"
    # Canonical names of modules that never showed up in the scope of any
    # traced graph node.
    uncalled_mods: "Set[str]"


def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
    """
    Iterate over ``_named_modules_with_dup(model)`` but emit each module
    object only once, under the first name it is encountered with.

    This plays the role of ``model.named_modules()``, whose results are
    slightly different for some wrapped models.

    Args:
        model (nn.Module) : The root module to walk.

    Yields:
        (str, nn.Module) : name and module pairs, one per distinct module.
    """
    visited = set()
    for qualified_name, submodule in _named_modules_with_dup(model):
        if submodule in visited:
            # Shared module reached again through another attribute path.
            continue
        visited.add(submodule)
        yield qualified_name, submodule


def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    Scope information is recorded by temporarily registering a forward
    pre-hook (pushes the scope name) and a forward hook (pops it) on every
    module. All hooks are removed before this function returns, including
    when tracing or an alias lookup raises.

    Args:
        module (nn.Module) : The module to trace
        inputs (tuple) : Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str) : maps modules and module
            names to the canonical name to be used as the scope for
            that module.

    Returns:
        graph (torch._C.Graph) : The pytorch JIT trace of the model

    Raises:
        KeyError : if a traced module has no entry in ``aliases``.
    """

    class ScopePushHook:
        """Forward pre-hook that pushes ``name`` onto the tracer's scope stack."""

        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            # Only touch the scope stack while a trace is actually running.
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook:
        """Forward hook that pops the scope pushed by the matching pre-hook."""

        def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    hook_handles: List[Any] = []

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(ScopePushHook(name))
        posthook = mod.register_forward_hook(ScopePopHook())
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    try:
        # Unwrap DDP, but correct the scope names for the root module.
        if isinstance(
            module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
        ):
            # Since DataParallel just wraps the model, add an extra set of hooks
            # to the model it wraps to account for the wrapper. Then trace it.
            root_name = aliases[module]
            module = module.module
            register_hooks(module, root_name)

        for _, mod in _named_modules_without_dup(module):
            register_hooks(mod, aliases[mod])

        graph, _ = _get_trace_graph(module, inputs)
    finally:
        # Always detach the hooks: leaving them on the user's model after a
        # failed trace would keep pushing/popping scopes on later traces.
        for handle in hook_handles:
            handle.remove()

    return graph


class JitModelAnalysis:
    """
    Provides access to per-submodule model statistics obtained by
    tracing a model with pytorch's jit tracing functionality. Calculates
    a statistic on a per-operator basis using the provided set of functions
    that acts on the inputs and outputs to the operator, then aggregates
    this over modules in the model. Can return the aggregate statistic for
    any submodule in the model. Is lazily evaluated, and will perform the
    trace when a statistic is first requested. Changing the operator handles
    will cause the trace to be rerun on the next request.

    Submodules may be referred to using the module's name. The input model has
    name "", while its descendants have names of the form
    "child.grandchild.grandgrandchild...".

    An operator is treated as within the scope of a module if calling that
    module directly resulted in that operator being run. In particular,
    this means that calls to other functions owned by a module or explicit
    calls to module.forward(...) will not register resulting operators as
    contributing statistics to that module.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        """
        Set up a lazy analysis of ``model.forward(inputs)``; nothing is
        traced until a statistic is first requested.

        Args:
            model: The model to analyze
            inputs: The inputs to the model for analysis.

        Because the execution is traced, inputs have to be tensors or a
        tuple of tensors (see
        https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace).
        In order to trace other methods or unsupported input types, you may need
        to implement a wrapper module.
        """
        self._model = model
        self._inputs = inputs
        # No operator handles until the user (or a subclass) registers some.
        self._op_handles: Dict[str, Handle] = {}
        # name -> submodule, for every name a submodule is reachable by
        self._named_modules: Dict[str, nn.Module] = {
            mod_name: mod for mod_name, mod in _named_modules_with_dup(model)
        }
        # submodule / alias name -> canonical name of that submodule
        self._aliases: Dict[Union[nn.Module, str], str] = self._get_aliases(model)
        # Cached analysis results; None means "not computed yet".
        self._stats: Optional[Statistics] = None

        # Private copy, so per-instance additions do not leak into the default.
        self._ignored_ops: Set[str] = set(_IGNORED_OPS)

        # Default configuration.
        self.unsupported_ops_warnings(True)
        self.uncalled_modules_warnings(True)
        self.tracer_warnings("no_tracer_warning")
        self.ancestor_mode("owner")

    def total(self, module_name: str = "") -> int:
        """
        Returns the total aggregated statistic across all operators
        for the requested module.

        Args:
            module_name (str) : The submodule to get data for. Defaults to
                the entire model.
        Returns:
            int : The aggregated statistic.
        """
        stats = self._analyze()
        module_name = self.canonical_module_name(module_name)
        total_count = sum(stats.counts[module_name].values())
        return total_count

    def by_operator(self, module_name: str = "") -> typing.Counter[str]:
        """
        Returns the statistics for a requested module, grouped by operator
        type. The operator handle determines the name associated with each
        operator type.

        Args:
            module_name (str) : The submodule to get data for. Defaults
                to the entire model.
        Returns:
            Counter(str) : The statistics for each operator.
        """
        stats = self._analyze()
        module_name = self.canonical_module_name(module_name)
        return stats.counts[module_name]

    def by_module_and_operator(self) -> Dict[str, typing.Counter[str]]:
        """
        Returns the statistics for all submodules, separated out by
        operator type for each submodule. The operator handle determines
        the name associated with each operator type.

        Returns:
            dict(str, Counter(str)):
                The statistics for each submodule and each operator.
                Grouped by submodule names, then by operator name.
        """
        stats = self._analyze()
        return stats.counts

    def by_module(self) -> typing.Counter[str]:
        """
        Returns the statistics for all submodules, aggregated over
        all operators.

        Returns:
            Counter(str): statistics counter grouped by submodule names
        """
        stats = self._analyze()
        summed_counts = Counter()
        for mod, results in stats.counts.items():
            summed_counts[mod] = sum(results.values())
        return summed_counts

    def unsupported_ops(self, module_name: str = "") -> typing.Counter[str]:
        """
        Lists the number of operators that were encountered but unsupported
        because no operator handle is available for them. Does not include
        operators that are explicitly ignored.

        Args:
            module_name (str) : The submodule to list unsupported ops.
                Defaults to the entire model.

        Returns:
            Counter(str) : The number of occurences each unsupported operator.
        """
        if self._stats is None:
            raise RuntimeError(
                "Analysis results should be computed "
                "before calling unsupported_ops()"
            )
        module_name = self.canonical_module_name(module_name)
        return self._stats.unsupported_ops[module_name]  # pyre-fixme

    def uncalled_modules(self) -> Set[str]:
        """
        Returns a set of submodules that were never called during the
        trace of the graph. This may be because they were unused, or
        because they were accessed via direct calls .forward() or with
        other python methods. In the latter case, statistics will not be
        attributed to the submodule, though the statistics will be included
        in the parent module.

        Returns:
            set(str) : The set of submodule names that were never called
                during the trace of the model.
        """
        stats = self._analyze()
        return stats.uncalled_mods

    def set_op_handle(self, *args, **kwargs: Optional[Handle]) -> "JitModelAnalysis":
        """
        Sets additional operator handles, or replaces existing ones.

        Args:
            args: (str, Handle) pairs of operator names and handles.
            kwargs: mapping from operator names to handles.

        If a handle is ``None``, the op will be explicitly ignored. Otherwise,
        handle should be a function that calculates the desirable statistic
        from an operator. The function must take two arguments, which are the
        inputs and outputs of the operator, in the form of ``list(torch._C.Value)``.
        The function should return a counter object with per-operator statistics.

        Examples
        ::
            handlers = {"aten::linear": my_handler}
            counter.set_op_handle("aten::matmul", None, "aten::bmm", my_handler2)
                   .set_op_handle(**handlers)
        """
        self._stats = None
        if len(args) % 2 != 0:
            raise TypeError(
                "set_op_handle should be called with pairs of names and handles!"
            )
        for name, handle in zip(args[::2], args[1::2]):
            kwargs[name] = handle
        for name, handle in kwargs.items():
            if handle is None:
                self._ignored_ops.add(name)
            else:
                self._op_handles[name] = handle
        return self

    def clear_op_handles(self) -> "JitModelAnalysis":
        """
        Drop every registered operator handle and reset the ignored ops to
        the module defaults. Cached results are discarded as well.

        Returns:
            JitModelAnalysis : ``self``, to allow chaining.
        """
        # Invalidate the cache; the next request re-runs the analysis.
        self._stats = None
        self._ignored_ops = set(_IGNORED_OPS)
        self._op_handles = {}
        return self

    def canonical_module_name(self, name: str) -> str:
        """
        Returns the canonical module name of the given ``name``, which might be
        different from the given ``name`` if the module is shared.
        This is the name that will be used as a key when statistics are
        output using .by_module() and .by_module_and_operator().

        Args:
            name (str) : The name of the module to find the canonical name for.
        Returns:
            str : The canonical name of the module.
        """
        # Blocks access by a direct module reference
        assert isinstance(name, str), "Module name must be a string."
        if name in self._aliases:
            return self._aliases[name]
        else:
            raise KeyError(
                "Requested module name is not among "
                "the descendants of the analyzed model."
            )

    def copy(
        self,
        new_model: Optional[nn.Module] = None,
        new_inputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
    ) -> "JitModelAnalysis":
        """
        Returns a copy of the :class:`JitModelAnalysis` object, keeping all
        settings, but on a new model or new inputs.

        The settings carried over are the operator handles, the ignored
        operators, the warning switches, the tracer warning mode and the
        ancestor mode. Cached results are not carried over.

        Args:
            new_model (nn.Module or None) : a new model for the new
                JitModelAnalysis. If None, uses the original model.
            new_inputs (typing.Tuple[object, ...] or None) : new inputs
                for the new JitModelAnalysis. If None, uses the original
                inputs.
        Returns:
            JitModelAnalysis : the new model analysis object
        """
        model = self._model if new_model is None else new_model
        inputs = self._inputs if new_inputs is None else new_inputs
        duplicate = (
            JitModelAnalysis(model=model, inputs=inputs)
            .set_op_handle(**self._op_handles)
            .unsupported_ops_warnings(self._enable_warn_unsupported_ops)
            .uncalled_modules_warnings(self._enable_warn_uncalled_mods)
            .tracer_warnings(self._warn_trace)
            .ancestor_mode(self._ancestor_mode)
        )
        # Ops ignored through set_op_handle(name, None) are recorded only in
        # the ignore set, so it has to be carried over separately. A fresh
        # set keeps the two objects independent.
        duplicate._ignored_ops = set(self._ignored_ops)
        return duplicate

    def tracer_warnings(self: T, mode: str) -> T:
        """
        Sets which warnings to print when tracing the graph to calculate
        statistics. There are three modes. Defaults to 'no_tracer_warning'.
        Allowed values are:

        * 'all' : keeps all warnings raised while tracing
        * 'no_tracer_warning' : suppress torch.jit.TracerWarning only
        * 'none' : suppress all warnings raised while tracing

        Args:
            mode (str) : warning mode in one of the above values.
        """
        if mode not in ["all", "no_tracer_warning", "none"]:
            raise ValueError(f"Unrecognized tracer warning mode {mode}.")
        self._warn_trace = mode
        return self

    def ancestor_mode(self: T, mode: str) -> T:
        """
        Sets how to determine the ancestor modules of an operator. Must be one of
        "owner" or "caller".

        * "caller": an operator belongs to all modules that is currently executing
          `forward()` at the time the operator is called.
        * "owner": an operator belongs to the last module that's executing
          `forward()` at the time the operator is called, plus this module's recursive
          parents.  If an module has multiple parents (e.g. a shared module), only one
          will be picked.

        For most cases, a module only calls submodules it owns, so both options would
        work identically. In certain edge cases, this option will affect the hierarchy
        of results, but won't affect the total count.
        """
        if mode not in ["owner", "caller"]:
            raise ValueError(f"Unrecognized ancestor mode: {mode}")
        self._ancestor_mode = mode
        return self

    def unsupported_ops_warnings(self: T, enabled: bool) -> T:
        """
        Sets if warnings for unsupported operators are shown. Defaults
        to True. Counts of unsupported operators may be obtained from
        :meth:`unsupported_ops` regardless of this setting.

        Args:
            enabled (bool) : Set to 'True' to show unsupported operator
                warnings.
        """
        self._enable_warn_unsupported_ops = enabled
        return self

    def uncalled_modules_warnings(self: T, enabled: bool) -> T:
        """
        Sets if warnings from uncalled submodules are shown. Defaults to true.
        A submodule is considered "uncalled" if it is never called during
        tracing. This may be because it is actually unused, or because it is
        accessed via calls to ``.forward()`` or other methods of the module.
        The set of uncalled modules may be obtained from
        :meth:`uncalled_modules` regardless of this setting.

        Args:
            enabled (bool) : Set to 'True' to show warnings.
        """
        self._enable_warn_uncalled_mods = enabled
        return self

    def _warn_unsupported_ops(self, ops: typing.Counter[str]) -> None:
        if not self._enable_warn_unsupported_ops:
            return
        logger = logging.getLogger(__name__)
        for op, freq in ops.items():
            logger.warning(
                "Unsupported operator {} encountered {} time(s)".format(op, freq)
            )

    def _warn_uncalled_mods(self, uncalled_mods: Set[str]) -> None:
        if not self._enable_warn_uncalled_mods:
            return
        uncalled_mods = {x for x in uncalled_mods if self._has_forward(x)}
        if len(uncalled_mods) == 0:
            return

        logger = logging.getLogger(__name__)
        logger.warning(
            "The following submodules of the model were never "
            "called during the trace of the graph. They may be "
            "unused, or they were accessed by direct calls to "
            ".forward() or via other python methods. In the latter "
            "case they will have zeros for statistics, though their "
            "statistics will still contribute to their parent calling "
            "module.\n" + ", ".join(sorted(uncalled_mods))
        )

    def _get_aliases(self, model: nn.Module) -> Dict[Union[str, nn.Module], str]:
        """
        Build the alias table of ``model``.

        Every module object and every name it is reachable by map to the
        canonical name of the module, which is the first name it is
        encountered with. For a name containing "/", the part after the
        last "/" is registered as an alias too.

        Args:
            model (nn.Module) : The root module.

        Returns:
            dict(str or nn.Module, str) : alias -> canonical name.
        """
        table: Dict[Union[str, nn.Module], str] = {}
        for mod_name, mod in _named_modules_with_dup(model):
            # First sighting of a module fixes its canonical name.
            canonical = table.setdefault(mod, mod_name)
            table[mod_name] = canonical
            if "/" in mod_name:
                table[mod_name.rpartition("/")[2]] = canonical
        return table

    def _get_all_ancestors(self, module_name: str) -> Set[str]:
        """
        Get all ancestors of the given module, defined by ownership.
        If the given module has multiple owners, use its canonical name.
        """
        parts = self.canonical_module_name(module_name).split(".")
        res = {""}
        for k in range(len(parts) + 1):
            res.add(".".join(parts[:k]))
        return res

    def _analyze(self) -> "Statistics":
        """
        Trace the model, run the operator handles on every graph node and
        aggregate the results per module.

        The result is cached in ``self._stats``; later calls return the
        cached object until the cache is reset to None.

        Returns:
            Statistics : per-module counts, per-module unsupported operator
                counts and the set of uncalled modules.

        Raises:
            ValueError : if a handle returns a value that is not an int,
                float, np.float64 or np.int64.
        """
        # Don't calculate if results are already stored.
        stats = self._stats
        if stats is not None:
            return stats

        # The warning filters only apply inside this block ("all" installs
        # no filter at all).
        with warnings.catch_warnings():
            if self._warn_trace == "none":
                warnings.simplefilter("ignore")
            elif self._warn_trace == "no_tracer_warning":
                warnings.filterwarnings("ignore", category=TracerWarning)
            graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)

        # Assures even modules not in the trace graph are initialized to zero count
        counts = {}
        unsupported_ops = {}
        # We don't need the duplication here, but self._model.named_modules()
        # gives slightly different results for some wrapped models.
        for _, mod in _named_modules_with_dup(self._model):
            name = self._aliases[mod]
            counts[name] = Counter()
            unsupported_ops[name] = Counter()

        # Every module name that shows up in the scope of at least one node.
        all_seen = set()
        for node in graph.nodes():
            kind = node.kind()
            if kind == "prim::PythonOp":
                # for PythonOp, pyname contains the actual name in Python
                # pyre-fixme[16]: `Node` has no attribute `pyname`.
                kind = kind + "." + node.pyname()
            # The scope name is the "/"-joined stack of module names pushed
            # by the hooks installed in _get_scoped_trace_graph.
            scope_names = node.scopeName().split("/")
            all_seen.update(scope_names)
            if self._ancestor_mode == "caller":
                ancestors = set(scope_names)
            else:
                # "owner" mode: innermost module plus its parents by name.
                ancestors = self._get_all_ancestors(scope_names[-1])
                all_seen.update(ancestors)
            if kind not in self._op_handles:
                # NOTE(review): _should_ignore_node() looks at node.kind(), not
                # at the extended "prim::PythonOp.<pyname>" name computed above,
                # so ignoring a single PythonOp by its extended name appears to
                # have no effect here -- confirm whether that is intended.
                if self._should_ignore_node(node):
                    continue
                for name in ancestors:
                    unsupported_ops[name][kind] += 1
            else:
                inputs, outputs = list(node.inputs()), list(node.outputs())
                op_counts = self._op_handles[kind](inputs, outputs)
                # A handle may return a bare number; file it under the op's
                # simplified name.
                if isinstance(op_counts, Number):
                    op_counts = Counter({self._simplify_op_name(kind): op_counts})
                for v in op_counts.values():
                    if not isinstance(v, (int, float, np.float64, np.int64)):
                        raise ValueError(
                            f"Invalid type {type(v)} for the flop count! "
                            "Please use a wider type to avoid overflow."
                        )

                # Assures an op contributes at most once to a module
                # (Counter's in-place add keeps only positive totals.)
                for name in ancestors:
                    counts[name] += op_counts

        uncalled_mods = set(self._aliases.values()) - all_seen
        stats = Statistics(
            counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods
        )
        # Store the results before warning, so the warnings are emitted once
        # per analysis rather than on every request.
        self._stats = stats
        self._warn_unsupported_ops(unsupported_ops[""])
        self._warn_uncalled_mods(uncalled_mods)
        return stats

    def _simplify_op_name(self, full_op_name: str) -> str:
        """
        Get simplified name of the op without the preceding namespace, e.g.
        aten::batch_norm -> batch_norm
        """
        p = full_op_name.find("::")
        if p != -1:
            return full_op_name[p + 2 :]
        else:
            return full_op_name

    def _has_forward(self, mod_name: str) -> bool:
        # Whether the module has a valid forward method.
        # Modules without forward are not expected to get called
        # and therefore should not produce "uncalled" warnings
        module = self._named_modules.get(mod_name)
        if module is None:
            return False
        module_type = type(module)
        # Containers are not meant to be called anyway (they don't have forward)
        # NOTE: We add nn.Identity as well to silence the uncalled warning, but it's
        # different from other containers: Identity has a forward but the forward does
        # not contain ops, so it appears "uncalled" after tracing. A more proper way
        # may be to use forward hooks (instead of the graph) to decide whether a module
        # has been called.
        no_forward_mods = {nn.ModuleList, nn.ModuleDict, nn.Module, nn.Identity}
        for mod in no_forward_mods:
            if module_type.forward is mod.forward:
                return False
        return True

    def _should_ignore_node(self, node) -> bool:
        kind = node.kind()
        if kind in self._ignored_ops:
            return True
        # Ignore all prim:: operators, with two exceptions:
        # * prim::PythonOp can be a user-implemented `torch.autograd.Function`
        # * prim::CallFunction an be a call to scripted module/function.
        if kind.startswith("prim::PythonOp") or kind.startswith("prim::CallFunction"):
            return False
        if kind.startswith("prim::"):
            return True
        return False
