# SPDX-License-Identifier: Apache-2.0

import torch

from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

# NOTE(review): not referenced in this part of the file; presumably the M
# dimension threshold at which the Ampere path switches to cuBLAS -- confirm
# against the kernel dispatch code.
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
# Weight quantization types accepted by the AllSpark support check.
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
# output_size_per_partition must be a multiple of this value.
ALLSPARK_AMPERE_N_ALIGN = 16
# input_size_per_partition must be a multiple of this value.
ALLSPARK_AMPERE_K_ALIGN = 16


def check_allspark_supported_dtype_shape(input_size_per_partition: int,
                                         output_size_per_partition: int,
                                         group_size: int,
                                         weight_dtype: ScalarType,
                                         act_dtype: torch.dtype):
    """Check whether AllSpark can handle the given layer configuration.

    The current device capability is queried from ``current_platform``.
    Only capabilities in the range [80, 90) (treated as Ampere here) are
    accepted; for those, the group size, weight quant type, partition
    alignment and activation dtype are validated in that order.

    Args:
        input_size_per_partition: Must be a multiple of
            ``ALLSPARK_AMPERE_K_ALIGN``.
        output_size_per_partition: Must be a multiple of
            ``ALLSPARK_AMPERE_N_ALIGN``.
        group_size: Quantization group size; only ``-1`` is accepted.
        weight_dtype: Weight quant type; must be listed in
            ``ALLSPARK_SUPPORTED_QUANT_TYPES``.
        act_dtype: Activation dtype; must be ``torch.float16`` or
            ``torch.bfloat16``.

    Returns:
        A ``(supported, reason)`` tuple: ``(True, None)`` when every check
        passes, otherwise ``(False, message)`` where ``message`` describes
        the first failing check.
    """
    capability_tuple = current_platform.get_device_capability()
    # Use -1 as the capability when the platform reports none, so the
    # range check below rejects it.
    device_capability = (-1 if capability_tuple is None else
                         capability_tuple.to_int())

    # Anything outside the Ampere capability range is unsupported.
    if not 80 <= device_capability < 90:
        return False, ("AllSpark currently does not support "
                       f"device_capability = {device_capability}.")

    if group_size != -1:
        return False, (
            "For Ampere GPU, AllSpark does not support group_size "
            f"= {group_size}. Only group_size = -1 are supported.")

    if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
        return False, (
            "For Ampere GPU, AllSpark does not support "
            f"quant type ({weight_dtype}). Only quant type "
            f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported.")

    k_aligned = input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN == 0
    n_aligned = output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN == 0
    if not (k_aligned and n_aligned):
        return False, (
            "AllSpark needs input_size_per_partition % "
            f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "
            f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "
            "for Ampere GPU optimized kernels.")

    if act_dtype not in (torch.float16, torch.bfloat16):
        # A space now follows the comma; the original message rendered as
        # "bfloat16,for Ampere GPU".
        return False, (
            "AllSpark only supports act_dtype = float16 or bfloat16, "
            f"for Ampere GPU, but got act_dtype = {act_dtype}.")

    return True, None
