# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess

import jinja2

# Preamble written at the top of every generated .cu file. It opens the
# kernel namespace; the matching closing brace is appended by the generator
# after the template instantiations.
FILE_HEAD = "\n".join((
    "// auto generated by generate.py",
    "// clang-format off",
    "",
    '#include "kernel.h"',
    '#include "marlin_template.h"',
    "",
    "namespace MARLIN_NAMESPACE_NAME {",
))

# Jinja2 source for a single explicit instantiation of the Marlin kernel
# template. The template arguments are listed one per line, in the order in
# which they are emitted; boolean flags are rendered as C++ `true`/`false`.
TEMPLATE = (
    "template __global__ void Marlin<"
    + ", ".join((
        "{{scalar_t}}",
        "{{w_type_id}}",
        "{{threads}}",
        "{{thread_m_blocks}}",
        "{{thread_n_blocks}}",
        "{{thread_k_blocks}}",
        "{{'true' if m_block_size_8 else 'false'}}",
        "{{stages}}",
        "{{'true' if has_act_order else 'false'}}",
        "{{'true' if has_zp else 'false'}}",
        "{{group_blocks}}",
        "{{'true' if is_zp_float else 'false'}}",
    ))
    + ">( MARLIN_KERNEL_PARAMS );"
)

# Weight scalar types to instantiate kernels for. The generator treats a
# type whose name contains "B" as having no zero point (has_zp=False).
# The int8-with-zero-point case (vllm::kU8) is also supported, but it is
# not generated here in order to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
# Each entry is (thread_k, thread_n, num_threads): the generator divides the
# first two values by 16 to obtain thread_k_blocks / thread_n_blocks and
# passes the third through as the `threads` template argument.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

# Candidate thread_m_blocks values. 0.5 is a marker value: it is rendered as
# thread_m_blocks=1 with the m_block_size_8 flag set to true.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
# Activation dtypes; "fp16" is emitted as `half`, anything else as
# `nv_bfloat16`.
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    """Delete previously generated ``kernel_*.cu`` files next to this script.

    Only files directly inside this script's directory whose names match
    ``kernel_*.cu`` are removed. Files are deleted with ``os.remove`` rather
    than by spawning ``rm``, so this also works where no ``rm`` binary is
    available.

    Raises:
        OSError: if a matching path exists but cannot be removed (for
            example because of missing permissions). A file that vanished
            between listing and removal is ignored.
    """
    # Resolve to an absolute path: when __file__ is a bare relative name its
    # dirname is "" and a naive `dirname + "/kernel_*.cu"` pattern would
    # point at the filesystem root instead of this directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Escape the directory part so glob metacharacters in the checkout path
    # (e.g. "[" or "*") are matched literally.
    pattern = os.path.join(glob.escape(script_dir), "kernel_*.cu")
    for filename in glob.glob(pattern):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Already gone; same outcome as `rm -f`.
            pass


def generate_new_kernels():
    """Write one ``kernel_<dtype>_<type>.cu`` file per scalar type and dtype.

    For every combination of ``SCALAR_TYPES`` and ``DTYPES`` this renders
    ``TEMPLATE`` once for each accepted (group_blocks, m_blocks,
    thread_config) combination and writes ``FILE_HEAD``, the rendered
    instantiations and a closing ``}`` to a file in this script's directory.
    Existing files with the same name are overwritten.

    Combinations are skipped when:
      * the scalar type has a zero point and group_blocks selects the
        act-order case (group_blocks == 0);
      * the config uses 256 threads and its thread_k does not match the
        value required for the m_blocks size (128 for m_blocks <= 1,
        64 otherwise).
    """
    # The template source never changes, so compile it once instead of once
    # per rendered instantiation.
    template = jinja2.Template(TEMPLATE)
    output_dir = os.path.dirname(os.path.abspath(__file__))
    # Length of the "vllm::" qualifier stripped from the type name when
    # building the output file name.
    prefix_len = len("vllm::")

    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        # Type names containing "B" (kU4B8, kU8B128) are generated without a
        # zero point; the others with one.
        has_zp = "B" not in scalar_type
        # Depends only on dtype, so resolve it once per output file.
        c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
        all_template_str_list = []

        for group_blocks, m_blocks, thread_config in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
            thread_k, thread_n, threads = thread_config

            has_act_order = group_blocks == 0
            if has_zp and has_act_order:
                continue
            if threads == 256:
                # With 256 threads, small m uses thread_k=128 and larger m
                # uses thread_k=64; other pairings are not generated.
                required_thread_k = 128 if m_blocks <= 1 else 64
                if thread_k != required_thread_k:
                    continue

            all_template_str_list.append(
                template.render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    threads=threads,
                    # m_blocks == 0.5 is emitted as 1 block plus the
                    # m_block_size_8 flag.
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=thread_n // 16,
                    thread_k_blocks=thread_k // 16,
                    m_block_size_8=m_blocks == 0.5,
                    stages="pipe_stages",
                    has_act_order=has_act_order,
                    has_zp=has_zp,
                    group_blocks=group_blocks,
                    is_zp_float=False,
                ))

        file_content = FILE_HEAD + "\n\n"
        # The trailing "}" closes the namespace opened in FILE_HEAD.
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        filename = f"kernel_{dtype}_{scalar_type[prefix_len:].lower()}.cu"

        with open(os.path.join(output_dir, filename), "w",
                  encoding="utf-8") as f:
            f.write(file_content)


if __name__ == "__main__":
    # Remove the kernel_*.cu files left by a previous run before generating,
    # so files that are no longer produced do not linger next to the new ones.
    remove_old_kernels()
    generate_new_kernels()
