# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .vllm_inductor_pass import VllmInductorPass

# Module-level logger; ActivationQuantFusionPass reports how many patterns it
# replaced through this logger at debug level.
logger = init_logger(__name__)


def silu_mul_pattern_static(result: torch.Tensor,
                            result_silu_mul: torch.Tensor, input: torch.Tensor,
                            scale: torch.Tensor):
    """Search pattern: ``silu_and_mul`` feeding a static FP8 quantization.

    This function is traced by the Inductor pattern matcher; the resulting
    graph is what gets searched for. It must keep the same parameter list as
    ``silu_mul_replacement_static`` because both are traced with the same
    example inputs.

    Args:
        result: output buffer of ``static_scaled_fp8_quant``.
        result_silu_mul: output buffer of ``silu_and_mul``.
        input: input to ``silu_and_mul`` (name mirrors the op's schema).
        scale: quantization scale passed to ``static_scaled_fp8_quant``.

    Returns:
        The functionalized ``result`` of the quantization op.
    """
    silu_and_mul_op = torch.ops._C.silu_and_mul.default
    static_quant_op = torch.ops._C.static_scaled_fp8_quant.default

    # Element [1] of an auto_functionalized result is the updated value of the
    # op's first mutated argument. NOTE(review): assumes `result` is that
    # argument in both op schemas.
    activated = auto_functionalized(silu_and_mul_op,
                                    result=result_silu_mul,
                                    input=input)[1]
    quantized = auto_functionalized(static_quant_op,
                                    result=result,
                                    input=activated,
                                    scale=scale)[1]
    return quantized


def silu_mul_replacement_static(result: torch.Tensor,
                                result_silu_mul: torch.Tensor,
                                input: torch.Tensor, scale: torch.Tensor):
    """Replacement: the fused ``silu_and_mul_quant`` custom op.

    Substituted for every match of ``silu_mul_pattern_static``.
    ``result_silu_mul`` is intentionally unused. It exists only so this
    function accepts the same arguments as the search pattern, since both are
    traced with the same example inputs.

    Returns:
        The functionalized ``result`` buffer of the fused op.
    """
    fused_op = torch.ops._C.silu_and_mul_quant.default
    # Element [1] of an auto_functionalized result is the updated value of the
    # op's first mutated argument. NOTE(review): assumes that is `result`.
    fused_outputs = auto_functionalized(fused_op,
                                        result=result,
                                        input=input,
                                        scale=scale)
    return fused_outputs[1]


def empty_bf16(*args, **kwargs):
    """Allocate an uninitialized ``torch.bfloat16`` tensor.

    Positional arguments are forwarded to :func:`torch.empty` (typically the
    shape). Keyword arguments are forwarded as well. ``device`` defaults to
    ``"cuda"`` but may now be overridden, e.g. ``device="cpu"``; previously
    that raised ``TypeError``. ``dtype`` is fixed, and passing it still raises
    ``TypeError``.
    """
    # setdefault keeps the historical CUDA default while letting callers pick
    # another device; **kwargs is a fresh dict, so the caller is unaffected.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16)


def empty_fp8(*args, **kwargs):
    """Allocate an uninitialized tensor of the platform's FP8 dtype.

    The dtype comes from ``current_platform.fp8_dtype()``. Positional and
    keyword arguments are forwarded to :func:`torch.empty`. ``device``
    defaults to ``"cuda"`` but may be overridden. ``dtype`` is fixed, and
    passing it raises ``TypeError``.
    """
    fp8 = current_platform.fp8_dtype()
    # Keep the historical CUDA default while allowing an explicit device.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=fp8)


def empty_fp32(*args, **kwargs):
    """Allocate an uninitialized ``torch.float32`` tensor.

    Positional and keyword arguments are forwarded to :func:`torch.empty`.
    ``device`` defaults to ``"cuda"`` but may be overridden, e.g.
    ``device="cpu"``; previously that raised ``TypeError``. ``dtype`` is
    fixed, and passing it still raises ``TypeError``.
    """
    # Keep the historical CUDA default while allowing an explicit device.
    kwargs.setdefault("device", "cuda")
    return torch.empty(*args, **kwargs, dtype=torch.float32)


class ActivationQuantFusionPass(VllmInductorPass):
    """
    Inductor pass fusing ``silu_and_mul`` + static FP8 quant into one op.

    At construction time, ``silu_mul_pattern_static`` is registered, with
    ``silu_mul_replacement_static`` as its replacement, in a private
    ``PatternMatcherPass``. Calling the pass runs that matcher over an FX
    graph, rewriting matches in the graph it is given.

    Patterns can only be registered once, so this pass is meant to be used as
    a singleton. Note that this class does not enforce that itself. A future
    PyTorch version is expected to lift the restriction:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass")

        # Example tensors used only to trace the pattern and replacement
        # functions. Order must match their parameter lists. Shapes are
        # presumably placeholders (only dtypes/ranks matter) — confirm.
        quant_out = empty_fp8(5, 4)
        silu_mul_out = empty_bf16(5, 4)
        act_input = empty_bf16(5, 4)
        quant_scale = empty_fp32(1, 1)
        example_inputs = [quant_out, silu_mul_out, act_input, quant_scale]

        register_replacement(silu_mul_pattern_static,
                             silu_mul_replacement_static, example_inputs,
                             fwd_only, self.patterns)

    def __call__(self, graph: torch.fx.Graph):
        """Apply the registered fusion to ``graph``, dumping it before/after."""
        self.begin()
        self.dump_graph(graph, "before_act_quant_fusion")

        num_replaced = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
                     num_replaced)

        self.dump_graph(graph, "after_act_quant_fusion")
        self.end_and_log()
