# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definition of the kernel registry.

Init phase:
1. Define the kernel registry.
2. Register kernels (only kernels whose device matches the current accelerator are kept).

"""

from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel


__all__ = ["Registry", "register_kernel"]


class Registry:
    """Registry for managing kernel implementations.

    Storage structure: ``{ "kernel_id": Class }``

    The storage is a class-level dictionary, so it is shared by every user of
    :class:`Registry` in the process. Only kernels whose device matches the
    current accelerator's device type are stored.
    """

    # Class-level mapping of kernel ID -> kernel class, shared by all callers.
    _kernels: dict[str, type[BaseKernel]] = {}

    @classmethod
    def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
        """Decorator to register a kernel class.

        The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
        The class is validated before the device filter is applied, so a malformed kernel (e.g. a missing
        kernel ID) is reported on every machine, not only on machines whose accelerator matches its device.

        Args:
            kernel_cls (type[BaseKernel]): The kernel class to register.

        Returns:
            type[BaseKernel] | None: The registered kernel class if the device type matches the current
            accelerator, otherwise ``None`` (registration is skipped).

        Raises:
            TypeError: If ``kernel_cls`` is not a class inheriting from :class:`BaseKernel`.
            ValueError: If the kernel ID is missing or already registered.
        """
        # Check ``isinstance(..., type)`` first: ``issubclass`` itself raises a generic TypeError
        # for non-class arguments, which would hide the intended message below.
        if not isinstance(kernel_cls, type) or not issubclass(kernel_cls, BaseKernel):
            raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")

        kernel_id = kernel_cls.get_kernel_id()
        if not kernel_id:
            raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")

        # The device type of the current accelerator does not match the device type required by the kernel,
        # skip registration.
        # NOTE: when used as a class decorator, returning None rebinds the decorated class name to None
        # in its defining module; look kernels up through ``Registry.get`` instead.
        if kernel_cls.get_device() != get_current_accelerator().type:
            return None

        if kernel_id in cls._kernels:
            raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")

        cls._kernels[kernel_id] = kernel_cls
        return kernel_cls

    @classmethod
    def get(cls, kernel_id: str) -> type[BaseKernel] | None:
        """Retrieves a registered kernel implementation by its ID.

        Args:
            kernel_id (str): The ID of the kernel to retrieve.

        Returns:
            type[BaseKernel] | None: The kernel class if found, else ``None``.
        """
        return cls._kernels.get(kernel_id)

    @classmethod
    def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
        """Returns a dictionary of all registered kernels.

        A shallow copy is returned so callers cannot accidentally add or remove
        entries in the shared registry; use :meth:`register` to add kernels.

        Returns:
            dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
        """
        return dict(cls._kernels)


# Exported decorator alias: ``@register_kernel`` is equivalent to ``@Registry.register``.
register_kernel = Registry.register
