# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import json
import math
import os
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Set, Tuple

import torch.cuda.memory
import torch.cuda.nvtx
import torch.profiler
import torch.utils.hooks
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily
from torch.utils._pytree import tree_map

from ..ops.common import FUNC_TO_XFORMERS_OPERATOR
from .device_limits import get_device_limits
from .profiler import _Profiler


class TorchFuncMockNoDispatch:
    """
    Wraps a method to call it without the custom
    pytorch dispatcher.

    Instances are installed as class attributes (descriptors) in place of an
    original method. Every call runs the original implementation inside
    ``_pop_mode_temporarily()``, i.e. with the innermost active dispatch mode
    temporarily removed.
    """

    def __init__(self, pt_impl):
        # The original, unwrapped implementation that calls are forwarded to.
        self.pt_impl = pt_impl

    def __get__(self, obj, c):
        # Accessed on the class itself (e.g. ``torch.Tensor.record_stream``):
        # behave like a plain function so the caller passes the instance
        # explicitly. Binding ``None`` here would inject a bogus first argument.
        if obj is None:
            return self
        # Accessed on an instance: bind it as the first argument.
        return partial(self, obj)

    def __call__(self, obj, *args, **kwargs):
        with _pop_mode_temporarily():
            return self.pt_impl(obj, *args, **kwargs)


class DispatcherWithoutBrokenFuncs(TorchDispatchMode):
    """
    Dispatch mode that bypasses itself for a few ``torch.Tensor`` methods.

    While the mode is active, each method named in ``TENSOR_FUNCS_NO_DISPATCH``
    is replaced on ``torch.Tensor`` by a ``TorchFuncMockNoDispatch`` wrapper
    that pops the active dispatch mode around the call. The original methods
    are restored on exit, and also if entering the mode fails part-way.
    """

    TENSOR_FUNCS_NO_DISPATCH = [
        # Can't convert Stream argument to Python object
        # https://github.com/pytorch/pytorch/issues/94403
        "record_stream"
    ]

    def __enter__(self) -> "DispatcherWithoutBrokenFuncs":
        self._pt_impls = {}
        try:
            for name in self.TENSOR_FUNCS_NO_DISPATCH:
                original = getattr(torch.Tensor, name)
                self._pt_impls[name] = original
                setattr(torch.Tensor, name, TorchFuncMockNoDispatch(original))
            return super().__enter__()
        except BaseException:
            # Never leave torch.Tensor monkey-patched if the mode was not pushed.
            self._restore_tensor_funcs()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._restore_tensor_funcs()
        return super().__exit__(exc_type, exc_val, exc_tb)

    def _restore_tensor_funcs(self) -> None:
        """Put back the original ``torch.Tensor`` methods saved by ``__enter__``."""
        for name, original in self._pt_impls.items():
            setattr(torch.Tensor, name, original)


def get_shape(i):
    """Return the ``shape`` attribute of ``i`` (typically a tensor)."""
    shape = i.shape
    return shape


def prod(x):
    """Return the product of the items of ``x`` (1 for an empty iterable)."""
    return math.prod(x)


class GemmOpComputeFlops:
    """
    Flop counter for matrix-multiply ops (used for ``aten.mm`` / ``aten.matmul``).

    A GEMM of size (M, N, K) is counted as ``2 * M * N * K`` flops. Subclasses
    adapt the counter to other ops by overriding ``_get_mnk`` only.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # Every leading dim of the left operand is folded into M; K is its
        # last dim; N is read from dim 1 of the right operand.
        lhs, rhs = inputs[0], inputs[1]
        m = prod(lhs.shape[:-1])
        return (m, rhs.shape[1], lhs.shape[-1])

    def __call__(self, inputs: List[Any], outputs: List[Any]) -> float:
        mnk = self._get_mnk(inputs)
        return 2 * prod(mnk)

    def op_suffix(self, inputs: List[Any]) -> str:
        """Suffix appended to the op name, e.g. ``_128x256x64`` for M x N x K."""
        m, n, k = self._get_mnk(inputs)
        return "_{}x{}x{}".format(m, n, k)


class GemmOpComputeFlopsLinear(GemmOpComputeFlops):
    """GEMM size for ``aten.linear``: N is taken from dim 0 of the weight."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        x_shape = inputs[0].shape
        out_features = inputs[1].shape[0]
        return (prod(x_shape[:-1]), out_features, x_shape[-1])


class GemmOpComputeFlopsMv(GemmOpComputeFlops):
    """GEMM size for ``aten.mv`` (matrix-vector): N is always 1."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        mat_shape = inputs[0].shape
        return (prod(mat_shape[:-1]), 1, mat_shape[-1])


class GemmOpComputeFlopsBmm(GemmOpComputeFlops):
    """
    GEMM size for ``aten.bmm`` on two 3-D operands.

    The batch dimension is folded into M, so a (B, M, K) @ (B, K, N) product
    is counted as one (B*M, N, K) GEMM.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        lhs, rhs = inputs[0], inputs[1]
        assert lhs.ndim == 3
        assert rhs.ndim == 3
        batch = max(lhs.shape[0], rhs.shape[0])
        rows = lhs.shape[1]
        inner, cols = rhs.shape[-2], rhs.shape[-1]
        return (batch * rows, cols, inner)


class GemmOpComputeFlopsAddmm(GemmOpComputeFlops):
    """GEMM size for ``aten.addmm``: ignore the additive first argument."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        matmul_operands = inputs[1:]
        return super()._get_mnk(matmul_operands)


class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm):
    """GEMM size for ``aten.addbmm``: ignore the additive first argument."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        batched_operands = inputs[1:]
        return super()._get_mnk(batched_operands)


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> float:
    """
    Count flops for a convolution, counting multiplications only
    (additions and bias are ignored):

        flops = batch_size * prod(w_shape) * prod(spatial)

    where ``batch_size`` is ``x_shape[0]`` and ``spatial`` is ``out_shape[2:]``
    for a regular convolution or ``x_shape[2:]`` for a transposed one.

    Args:
        x_shape (list(int)): The input shape, batch dimension first.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape.
        transposed (bool): Whether the convolution is transposed.
    Returns:
        The multiplication count (an int when all shapes are ints).
    """
    spatial = x_shape[2:] if transposed else out_shape[2:]
    return x_shape[0] * prod(w_shape) * prod(spatial)


def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for ``aten.convolution`` / ``aten._convolution``.

    Uses the input and weight (first two arguments), the first output, and
    the ``transposed`` flag found at argument position 6.
    """
    x, w = inputs[:2]
    x_shape = get_shape(x)
    w_shape = get_shape(w)
    out_shape = get_shape(outputs[0])
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=inputs[6])


def transpose_shape(shape):
    """Return ``shape`` as a list with its first two dimensions swapped."""
    return [shape[1], shape[0], *shape[2:]]


def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for ``aten.convolution_backward``.

    Reads grad_output / input / weight from ``inputs[:3]``, the forward
    ``transposed`` flag from ``inputs[7]``, and the gradient mask from
    ``inputs[-1]``. ``outputs[0]`` and ``outputs[1]`` are the input and weight
    gradients. Only the gradients requested by the mask are counted, and the
    bias gradient is ignored.
    """
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0.0

    if output_mask[0]:
        # grad_input is a convolution of grad_output with the weight, in the
        # opposite direction to the forward pass.
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(
            grad_out_shape, w_shape, grad_input_shape, not fwd_transposed
        )
    if output_mask[1]:
        # grad_weight is a regular convolution whose "batch" is a channel dim
        # and whose output spatial extent is the kernel size. Feed it the
        # channel/batch-swapped shapes so the spatial factor comes from
        # grad_weight, i.e. the kernel. The previous formula used grad_output
        # untransposed with transposed=True, which multiplied by the forward
        # input's spatial size instead of the kernel size for transposed convs.
        grad_weight_shape = get_shape(outputs[1])
        if fwd_transposed:
            flop_count += conv_flop_count(
                transpose_shape(grad_out_shape),
                transpose_shape(x_shape),
                transpose_shape(grad_weight_shape),
                transposed=False,
            )
        else:
            flop_count += conv_flop_count(
                transpose_shape(x_shape),
                transpose_shape(grad_out_shape),
                transpose_shape(grad_weight_shape),
                transposed=False,
            )

    return flop_count


def tensor_storage_size_in_mem(x: torch.Tensor):
    """
    Number of elements of ``x`` that occupy distinct memory.

    Dimensions with stride 0 (e.g. produced by ``expand``) do not multiply the
    count. A 0-dim tensor yields 1.
    """
    return math.prod(
        size for size, step in zip(x.shape, x.stride()) if step >= 1
    )


def get_size(inputs: List[Any]):
    """
    Total bytes of every tensor found anywhere in the (possibly nested) ``inputs``.

    Each tensor contributes ``tensor_storage_size_in_mem(t) * t.element_size()``.
    Non-tensor leaves are ignored.
    """
    per_tensor_bytes: List[int] = []

    def _collect(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            per_tensor_bytes.append(
                tensor_storage_size_in_mem(leaf) * leaf.element_size()
            )

    tree_map(_collect, inputs)
    return sum(per_tensor_bytes)


def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]):
    """Default IO estimate: every input byte is read and every output byte written."""
    read_bytes = get_size(inputs)
    written_bytes = get_size(outputs)
    return read_bytes + written_bytes


def output_read_from_input(inputs: List[Any], outputs: List[Any]):
    """IO estimate: write all outputs and read at most as many bytes from the inputs."""
    in_bytes = get_size(inputs)
    out_bytes = get_size(outputs)
    read_bytes = min(in_bytes, out_bytes)
    return out_bytes + read_bytes


def output_total_size(inputs: List[Any], outputs: List[Any]):
    """IO estimate counting only the bytes of ``outputs``; ``inputs`` is ignored."""
    written_bytes = get_size(outputs)
    return written_bytes


def input_total_size(inputs: List[Any], outputs: List[Any]):
    """IO estimate counting only the bytes of ``inputs``; ``outputs`` is ignored."""
    touched_bytes = get_size(inputs)
    return touched_bytes


def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]):
    """
    Fallback flop estimate for ops without a dedicated counter.

    Returns half the total number of elements across all tensors found in
    ``inputs`` and ``outputs`` (nested containers included). This is a rough
    approximation, not an exact count.
    """
    element_counts: List[int] = []

    def _count(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            element_counts.append(leaf.numel())

    for tree in (inputs, outputs):
        tree_map(_count, tree)
    return sum(element_counts) / 2


def no_flop(inputs: List[Any], outputs: List[Any]):
    """Flop counter for ops that do no arithmetic: always 0."""
    del inputs, outputs  # unused; signature matches the other flop counters
    return 0


def no_io(inputs: List[Any], outputs: List[Any]):
    """IO estimator for ops that move no memory: always 0 bytes."""
    del inputs, outputs  # unused; signature matches the other IO estimators
    return 0


# Shorthand for the aten operator namespace used by the lookup tables below.
aten = torch.ops.aten
# Ops counted as free: 0 flops and 0 bytes of memory traffic. These are
# view/aliasing ops, uninitialized allocations (empty*), and a size check.
NO_FLOPS_NO_IO_OPS = [
    aten.permute,
    aten.view,
    aten.view_as,
    aten.detach,
    aten.t,
    aten.transpose,
    aten.expand,
    aten._unsafe_view,
    aten.select,
    aten.split,
    aten.split_with_sizes,
    aten.empty,
    aten.empty_strided,
    aten.empty_like,
    aten.is_same_size,
]
# Ops that do no arithmetic (0 flops) but still touch memory. Their IO cost
# comes from io_mapping when listed there, otherwise from the profiler's
# default estimate (operation_memory_rw_bytes).
NO_FLOPS_OPS = [
    aten._reshape_alias,
    aten.reshape,
    aten.clone,
    aten.cat,
    aten.select_backward,
    aten.slice,
    aten.slice_backward,
    aten.ones,
    aten.ones_like,
    aten.zeros_like,
    aten.zero_,
    aten.zeros,
    aten.masked_fill,
    aten.masked_fill_,
]

# Flop counters keyed by aten overload packet. Each value is called as
# fn(inputs, outputs) and returns a flop count. Ops missing from this table
# are estimated with guess_flops_unknown_op and reported as inexact.
flop_mapping = {
    aten.mv: GemmOpComputeFlopsMv(),  # mat-vec
    aten.mm: GemmOpComputeFlops(),
    aten.matmul: GemmOpComputeFlops(),
    aten.addmm: GemmOpComputeFlopsAddmm(),
    aten.bmm: GemmOpComputeFlopsBmm(),
    aten.addbmm: GemmOpComputeFlopsAddbmm(),
    aten.linear: GemmOpComputeFlopsLinear(),
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Operations with 0 flop
    **{op: no_flop for op in NO_FLOPS_OPS},
    **{op: no_flop for op in NO_FLOPS_NO_IO_OPS},
}
# Memory-traffic (bytes) estimators keyed by aten overload packet, called as
# fn(inputs, outputs). Ops missing from this table use
# operation_memory_rw_bytes (bytes of all inputs plus all outputs).
io_mapping = {
    # Charged for writing the output plus reading at most as many input bytes.
    aten.clone: output_read_from_input,
    aten.cat: output_read_from_input,
    aten.slice: output_read_from_input,
    # Fills: only the filled tensor is counted.
    aten.ones_like: output_total_size,
    aten.zeros_like: output_total_size,
    aten.zero_: input_total_size,
    **{op: no_io for op in NO_FLOPS_NO_IO_OPS}
    # TODO: Check how this is implemented in PT. Until then these ops fall back
    # to operation_memory_rw_bytes:
    # aten.slice_backward: no_flop,
    # aten.select_backward: no_flop,
}


@dataclass
class _OpInfo:
    flop_count: float = 0.0
    time_ms: float = 0.0
    io_bytes: int = 0
    is_exact_flop: bool = True
    op_name: str = ""
    op_suffix: str = ""
    stacktrace: Tuple[str, ...] = field(default_factory=tuple)
    ev_start: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )
    ev_end: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )

    # Hardware limits for this operation (inf if unknown)
    hardware_tflops_limit: float = math.inf
    hardware_membw_limit: float = math.inf

    @property
    def time_membound_ms(self) -> float:
        assert self.time_ms > 0.0
        if self.io_bytes == 0:
            return 0.0
        return min(self.time_ms, 1000 * self.io_bytes / self.hardware_membw_limit)

    @property
    def time_computebound_ms(self) -> float:
        assert self.time_ms > 0.0
        tflop = self.flop_count / (1000**4)
        if tflop == 0.0:
            return 0.0
        return min(self.time_ms, 1000 * tflop / self.hardware_tflops_limit)

    def finalize(self) -> None:
        self.time_ms = self.ev_start.elapsed_time(self.ev_end)


@dataclass
class _OpInfoAggregated:
    """
    Running totals over a group of ``_OpInfo`` records.

    A group is either all ops under one module path or all calls of one op
    name. ``as_dict`` turns the totals into one JSON row.
    """

    # True only while every aggregated op had an exact flop count.
    is_exact_flop: bool = True
    total_flop_count: float = 0.0
    total_io_bytes: int = 0
    total_time_ms: float = 0.0
    total_time_membound_ms: float = 0.0
    total_time_computebound_ms: float = 0.0
    num: int = 0
    stacktraces: List[Tuple[str, ...]] = field(default_factory=list)

    def add(self, op: _OpInfo) -> None:
        """Accumulate one finalized op into the totals."""
        self.total_flop_count += op.flop_count
        self.total_time_ms += op.time_ms
        self.total_io_bytes += op.io_bytes
        self.total_time_membound_ms += op.time_membound_ms
        self.total_time_computebound_ms += op.time_computebound_ms
        self.num += 1
        # The aggregate is exact only if *every* contributing op is exact.
        # Overwriting with the latest op's flag would let a single exact op
        # hide earlier guessed counts.
        self.is_exact_flop = self.is_exact_flop and op.is_exact_flop
        self.stacktraces.append(op.stacktrace)

    def as_dict(self, **kwargs) -> Dict[str, Any]:
        """
        Summarize the totals as a JSON-serializable dict.

        ``mem_bound`` / ``compute_bound`` are the fractions (capped at 1) of
        the measured time explained by memory traffic / arithmetic.
        ``Tflops`` is the achieved throughput. Extra ``kwargs`` are merged
        into the result.
        """
        if self.total_time_ms > 0:
            mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms)
            tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4)
            compute_bound = min(
                1, self.total_time_computebound_ms / self.total_time_ms
            )
        else:
            # Nothing measured (e.g. empty aggregate): avoid dividing by zero.
            mem_bound = tflops = compute_bound = 0.0
        return {
            "is_exact_flop": self.is_exact_flop,
            "total_flop_count": self.total_flop_count,
            "total_time_ms": self.total_time_ms,
            "total_io_bytes": self.total_io_bytes,
            "num": self.num,
            "Tflops": tflops,
            "mem_bound": mem_bound,
            "compute_bound": compute_bound,
            **kwargs,
        }


class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs):
    """
    Dispatch-mode profiler that times every aten op and compares it to
    hardware compute / memory-bandwidth limits.

    Each dispatched op is bracketed by a pair of CUDA events. Its flop count
    and bytes moved are estimated via the xformers operator registry,
    ``flop_mapping`` and ``io_mapping``. On exit, results are aggregated per
    module path and per op name and written to
    ``<output_dir>/<worker_name>_ops.json``.

    Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/
    """

    # Bookkeeping ops that are never recorded.
    _IGNORED_OPS = ("_record_function_exit", "_record_function_enter_new")

    def __init__(self, main_profiler: _Profiler) -> None:
        super().__init__()
        self.main_profiler = main_profiler
        self.trace: List[_OpInfo] = []
        # Set while estimating flops/IO so that ops dispatched by the
        # estimators themselves are not recorded.
        self.temp_disabled = False

    def _hardware_tflops_membw_limit(
        self, args: Tuple[Any, ...], outputs: Tuple[Any, ...]
    ) -> Tuple[float, float]:
        """
        Return ``(gemm_tflops, gmem_bandwidth)`` limits for an op.

        The device is that of the first top-level tensor found in ``outputs``
        then ``args``. The dtype is the first tensor dtype the device limits
        know about; float32 is replaced by the autocast GPU dtype when
        autocast is enabled. Returns ``(inf, inf)`` when there is no tensor
        or no known dtype.
        """
        device = None
        dtypes: List[torch.dtype] = []
        for a in itertools.chain(outputs, args):
            if isinstance(a, torch.Tensor):
                if device is None:
                    device = a.device
                dtypes.append(a.dtype)
        if device is None:
            # No tensor at all: there is no device to look limits up for.
            return (math.inf, math.inf)
        limits = get_device_limits(device)
        known_dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops]
        if not known_dtypes:
            return (math.inf, math.inf)
        dtype = known_dtypes[0]
        if torch.is_autocast_enabled() and dtype is torch.float32:
            dtype = torch.get_autocast_gpu_dtype()
        return limits.gemm_tflops[dtype], limits.gmem_bandwidth

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        func_packet = func._overloadpacket
        if self.temp_disabled or func_packet.__name__ in self._IGNORED_OPS:
            return func(*args, **kwargs)

        op = _OpInfo()
        op.ev_start.record()
        out = func(*args, **kwargs)
        op.ev_end.record()
        # The flop/IO estimators expect a sequence of outputs.
        outputs = out if isinstance(out, tuple) else (out,)

        (
            op.hardware_tflops_limit,
            op.hardware_membw_limit,
        ) = self._hardware_tflops_membw_limit(args, outputs)
        op.op_name = func_packet.__name__

        # Prevent functions called by flop counting ops to be recorded.
        # try/finally so a failing estimator cannot leave recording disabled
        # for the rest of the session.
        self.temp_disabled = True
        try:
            flop_count = -1
            compute_flops = None
            if func_packet in FUNC_TO_XFORMERS_OPERATOR:
                # A result of -1 means "fall back to the generic counters".
                flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop(
                    *args, **kwargs
                )
            if flop_count == -1:
                compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op)
                flop_count = compute_flops(args, outputs)
                if isinstance(compute_flops, GemmOpComputeFlops):
                    op.op_suffix = compute_flops.op_suffix(args)
                    op.op_name += op.op_suffix

            compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes)
            op.io_bytes = compute_io(args, outputs)
        finally:
            self.temp_disabled = False

        op.stacktrace = tuple(self.main_profiler.parents)
        op.flop_count = flop_count
        op.is_exact_flop = compute_flops is not guess_flops_unknown_op
        self.trace.append(op)

        return out

    def __enter__(self):
        self.main_profiler._install_hooks()
        try:
            super().__enter__()
        except BaseException:
            # Do not leave the profiler's hooks installed if the mode failed.
            self.main_profiler._remove_hooks()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            self.main_profiler._remove_hooks()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        for op in self.trace:
            op.finalize()
        self.save_json()

    def step(self) -> None:
        """Nothing to do per step; all processing happens on exit."""
        pass

    def save_json(self) -> None:
        """
        Aggregate ``self.trace`` and write it as a JSON list of rows.

        One row is written per module path (``agg="module"``) and one per op
        name (``agg="opname"``). The output path is also appended to
        ``main_profiler.summary``.
        """
        # Aggregate data at the module + op level
        all_paths: Set[Tuple[str, ...]] = {op.stacktrace for op in self.trace}
        per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict(
            _OpInfoAggregated
        )
        per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated)
        for op in self.trace:
            # Credit the op to every prefix of its path that is itself a
            # recorded path.
            for depth in range(1, len(op.stacktrace) + 1):
                prefix = op.stacktrace[:depth]
                if prefix in all_paths:
                    per_module_data[prefix].add(op)
            per_op_data[op.op_name].add(op)

        # Generate JSON
        all_data = [
            agg_info.as_dict(agg="module", path=stacktrace, name=stacktrace[-1], op="")
            for stacktrace, agg_info in per_module_data.items()
        ]
        for op_name, agg_info in per_op_data.items():
            # Find the most common path. Paths are counted in sorted order and
            # max() keeps the first maximum, so ties always resolve the same way.
            agg_info.stacktraces.sort()
            paths_count: Dict[Tuple[str, ...], int] = defaultdict(int)
            for p in agg_info.stacktraces:
                paths_count[p] += 1
            maxp = max(paths_count, key=paths_count.__getitem__)
            all_data.append(
                agg_info.as_dict(
                    agg="opname",
                    path=f"{'.'.join(maxp)} (x{paths_count[maxp]})",
                    name="",
                    op=op_name,
                )
            )

        filename = os.path.abspath(
            os.path.join(
                self.main_profiler.output_dir,
                f"{self.main_profiler.worker_name}_ops.json",
            )
        )
        self.main_profiler.summary.append(("OpsSummary", filename))
        with open(filename, "w") as f:
            json.dump(all_data, f)
