# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of base kernel class.

Init Phase:
1. Define base kernel class.
2. Define abstract methods.

"""

from abc import ABC, abstractmethod
from typing import Any

from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel


class BaseKernel(ABC):
    r"""Abstract parent of every kernel implementation.

    A concrete kernel overrides the two class attributes below and provides its
    own :meth:`apply`. Everything on this class is exposed through classmethods,
    so none of the methods here need an instance.
    """

    # Identifier of the kernel implementation; any hashable value is accepted.
    _kernel_id: Any = ""
    # Device the kernel targets ("cuda", "npu", "cpu", etc.); the base default is CPU.
    _device: DeviceType = DeviceType.CPU

    @classmethod
    def get_kernel_id(cls) -> str:
        """Return the class-level ``_kernel_id`` unchanged.

        The return annotation is ``str``, but no conversion is performed: the
        value handed back is whatever the class stores in ``_kernel_id``.
        """
        return cls._kernel_id

    @classmethod
    def get_device(cls) -> str:
        """Return the class-level ``_device`` (e.g., "cuda", "npu", "cpu") unchanged."""
        return cls._device

    @classmethod
    def check_deps(cls) -> bool:
        """Report whether this kernel can be used on the current accelerator.

        The default check only compares the kernel's ``_device`` with the
        ``type`` of the object returned by ``get_current_accelerator()``.

        Returns:
            bool: ``True`` when the two device types match, ``False`` otherwise.

        .. note::
            In explicit mode, if a user specifies an implementation but this check fails,
            it should raise an error instead of silently switching.
            Kernels can override this method to implement custom dependency checks.
        """
        active_type = get_current_accelerator().type
        return cls._device == active_type

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs) -> HFModel:
        """Apply the kernel optimization to the model.

        The base implementation never returns: it only validates dependencies
        and then signals that the subclass has to supply the actual logic.

        Args:
            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

        Returns:
            HFModel: The model with the kernel applied (produced by subclass implementations).

        Raises:
            RuntimeError: If :meth:`check_deps` reports that the dependencies are not met.
            NotImplementedError: If the dependencies are met but the base body is reached.

        Example:
            >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
            >>> model = HFModel(config=config)
            >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
        """
        if cls.check_deps():
            # Dependencies are fine, so the only problem is the missing override.
            raise NotImplementedError
        raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")
