#pragma once

#include <cstdio>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <tuple>
#include <algorithm>

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include "gn_utils.hpp"
#include "gn_cuda_kernel.cuh"


namespace group_norm_v2 {

#define DISPATCH_LOWER_BOUND_N(VALUE, CONST_NAME, ...) [&] { \
    if (VALUE >= 16) { constexpr int CONST_NAME = 16; return __VA_ARGS__(); } \
    if (VALUE >= 8) { constexpr int CONST_NAME = 8; return __VA_ARGS__(); } \
    if (VALUE >= 4) { constexpr int CONST_NAME = 4; return __VA_ARGS__(); } \
    if (VALUE >= 2) { constexpr int CONST_NAME = 2; return __VA_ARGS__(); } \
    if (VALUE >= 1) { constexpr int CONST_NAME = 1; return __VA_ARGS__(); } \
    throw std::invalid_argument("DISPATCH_LOWER_BOUND_N " + std::to_string(VALUE)); \
    }()
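
// Usage sketch (hypothetical caller): maps a runtime value to the largest available
// power-of-two lower bound as a compile-time constant, e.g.
//   int r = DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { return foo<LB_N>(); });
// selects LB_N == 8 for n == 11, and LB_N == 16 for any n >= 16.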

#define DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, ...) [&] { \
    if (runtime_cuda_arch == 1000 && sm_count >= 148) { constexpr int RUNTIME_CUDA_ARCH = 1000, LB_SM_COUNT = 148; return __VA_ARGS__(); } \
    throw std::invalid_argument("DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT " + std::to_string(runtime_cuda_arch) + " " + std::to_string(sm_count)); \
    }()
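
// Only compute capability 10.0 (runtime_cuda_arch == 1000) with at least 148 SMs is
// dispatched here; add branches to support other arch / SM-count combinations.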

#define DISPATCH_SM_MARGIN(VALUE, CONST_NAME, ...) [&] { \
    if (VALUE == 0) { constexpr int CONST_NAME = 0; return __VA_ARGS__(); } \
    if (VALUE == 32) { constexpr int CONST_NAME = 32; return __VA_ARGS__(); } \
    throw std::invalid_argument("DISPATCH_SM_MARGIN " + std::to_string(VALUE)); \
    }()


inline constexpr int get_max_cuda_arch() {
    int cuda_arch_list[] = {__CUDA_ARCH_LIST__};
    int max_cuda_arch = -1;
    for (int cuda_arch_item : cuda_arch_list) {
        if (cuda_arch_item > max_cuda_arch) {
            max_cuda_arch = cuda_arch_item;
        }
    }
    return max_cuda_arch;
}
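
// Under nvcc, __CUDA_ARCH_LIST__ expands to the comma-separated list of architectures
// being compiled for (e.g. 900,1000); for that list this returns 1000, the arch
// assumed to have generated the PTX.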

template<typename T, bool BWD, bool REQUIRES_WGRAD, int HW, int G, int CPG, int LB_N, int RUNTIME_CUDA_ARCH, int LB_SM_COUNT, int EFFECTIVE_CUDA_ARCH, int SM_MARGIN>
constexpr auto compute_gn_params() {
    constexpr int C = G * CPG;

    // C++17 requires every local variable in a constexpr function to be initialized
    int BLOCK_DIM_X = 0;
    int C_PER_BLOCK = 0;
    int ROWS_PER_BLOCK = 0;
    bool LOAD_TWICE = false;
    int BLOCKS_PER_SM = 0;
    WgradSyncMethod wgrad_sync_method = WGRAD_SYNC_UNSPECIFIED;

    // There are two tiling strategies:
    //   - block sync: each block handles a whole group, i.e., (CPG * HW) elements
    //   - virtual cluster sync: each virtual cluster handles a group
    // Block sync avoids cross-block synchronization latency, but it may cause low occupancy.
    //   Use block sync when the IO size is small and latency, rather than occupancy, dominates the kernel running time.

    // The forward pass loads `x`; the backward pass loads both `x` and `grad_output`, hence the factor of (1 + BWD)
    if (HW * CPG * (1 + BWD) * sizeof(T) <= 20480) {
        // Strategy 1: block sync
        C_PER_BLOCK = CPG;
        ROWS_PER_BLOCK = HW;
        BLOCK_DIM_X = lcm(32, C_PER_BLOCK);
        while (BLOCK_DIM_X < 256) {
            BLOCK_DIM_X *= 2;
        }
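        // E.g., C_PER_BLOCK == 80: lcm(32, 80) == 160, doubled once to a 320-thread block.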
        BLOCKS_PER_SM = 1;
        // The register file per SM holds 65536 registers * 4 bytes per register.
        //   We have to leave some room for other variables and compiler optimizations,
        //   so we use 36000 registers as the threshold.
        LOAD_TWICE = BLOCKS_PER_SM * ROWS_PER_BLOCK * C_PER_BLOCK * (1 + BWD) * sizeof(T) > 36000 * 4;
    } else {
        // Strategy 2: virtual cluster sync
        //   A virtual cluster is a group of blocks that synchronize with each other.
        //   Each virtual cluster handles a multiple of (CPG * HW) elements, so every group is reduced within a single cluster.
        //   If the virtual cluster size is supported by the hardware, HARDWARE_CLUSTER is preferred;
        //   otherwise, cooperative groups are used (i.e., PERSISTENT kernels).
        int c_per_cluster = lcm(128 / (int)sizeof(T), CPG);

        C_PER_BLOCK = c_per_cluster;
        BLOCK_DIM_X = C_PER_BLOCK == 320 ? 320 : 480;
        // Maximum number of rows that should reside in registers
        int register_max_rows = 36000 * 4 / (C_PER_BLOCK * (1 + BWD) * sizeof(T));
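        // E.g., half precision in the backward pass with C_PER_BLOCK == 320:
        //   36000 * 4 / (320 * 2 * 2) == 112 rows fit in registers.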

        std::tuple<bool, int, int, int, int, int> best_candidate{};
        BLOCKS_PER_SM = 0;
        ROWS_PER_BLOCK = 0;
        for (int blocks_per_sm = 1; blocks_per_sm <= 3; blocks_per_sm++) {
            for (int rows_per_block = HW; rows_per_block >= 1; rows_per_block /= 2) {
                int virtual_cluster_size = (HW / rows_per_block) * (c_per_cluster / C_PER_BLOCK);
                if (virtual_cluster_size > blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)) {
                    continue;
                }
                int num_clusters = blocks_per_sm * (LB_SM_COUNT - SM_MARGIN) / virtual_cluster_size;
                int num_tasks = LB_N * (C / c_per_cluster);
                int num_waves = up_div(num_tasks, num_clusters);
                bool load_twice = rows_per_block > register_max_rows / blocks_per_sm;

                // Wave utilization: the percentage of SMs that are busy in each wave
                //   For example, with SM_COUNT=100 and VIRTUAL_CLUSTER_SIZE=64:
                //     if BLOCKS_PER_SM=1, num_clusters=1, wave_util=64%;
                //     if BLOCKS_PER_SM=2, num_clusters=3, wave_util=96%.
                //   This helps select a good value of BLOCKS_PER_SM
                int wave_util = 10000 * std::min(num_tasks, num_clusters) * virtual_cluster_size / (blocks_per_sm * (LB_SM_COUNT - SM_MARGIN));

                decltype(best_candidate) candidate = {
                    true,
                    !load_twice,                // Prefer no load twice
                    !(num_waves >= 2 && blocks_per_sm == 1),  // When there are multiple waves, prefer multiple blocks per SM to ensure overlapping
                    -num_waves,                 // Prefer fewer waves
                    std::min(9000, wave_util),  // Prefer high wave utilization
                    -blocks_per_sm,             // Prefer fewer blocks per SM to reduce per-thread overhead
                };
                if (candidate > best_candidate) {
                    // std::tuple's assignment operator is not constexpr until C++20, so assign each element individually
                    std::get<0>(best_candidate) = std::get<0>(candidate);
                    std::get<1>(best_candidate) = std::get<1>(candidate);
                    std::get<2>(best_candidate) = std::get<2>(candidate);
                    std::get<3>(best_candidate) = std::get<3>(candidate);
                    std::get<4>(best_candidate) = std::get<4>(candidate);
                    std::get<5>(best_candidate) = std::get<5>(candidate);
                    static_assert(std::tuple_size<decltype(best_candidate)>::value == 6, "missing assignments");

                    BLOCKS_PER_SM = blocks_per_sm;
                    ROWS_PER_BLOCK = rows_per_block;
                }
            }
        }

        LOAD_TWICE = ROWS_PER_BLOCK > register_max_rows / BLOCKS_PER_SM;
    }

    int c_per_cluster = lcm(CPG, C_PER_BLOCK);
    int virtual_cluster_size = (c_per_cluster / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);

    // Occupancy is affected if the cluster size is large.
    //   For example, on H100 with gridDim=128 and each block occupying a whole SM:
    //     if no cluster is used, all blocks can be active simultaneously;
    //     if the cluster size is 16, not all blocks can be active simultaneously (this can be
    //       queried via cudaOccupancyMaxActiveClusters), so there are two waves, which hurts efficiency.
    // When SM_MARGIN is set, no cluster should be used because other kernels may occupy part of the cluster.
    bool HARDWARE_CLUSTER = virtual_cluster_size <= 2 && virtual_cluster_size != 1 && SM_MARGIN == 0;

    int MAX_VEC_BYTES = 8;  // Sometimes 4 or 16 is better, but there is no trivial way to select the best vectorization size.
    int VEC_ELEMS = std::min(gcd(MAX_VEC_BYTES / (int)sizeof(T), C_PER_BLOCK),
                             gcd(MAX_VEC_BYTES / (int)sizeof(T), ROWS_PER_BLOCK * C_PER_BLOCK / BLOCK_DIM_X));
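    // E.g., for half, 8 bytes allow 4 elements per vector; the two gcd's shrink that so a
    //   vector never crosses a block's channel range and evenly divides each thread's element count.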

    return std::make_tuple(
        BLOCK_DIM_X,
        C_PER_BLOCK,
        ROWS_PER_BLOCK,
        VEC_ELEMS,
        LOAD_TWICE,
        BLOCKS_PER_SM,
        HARDWARE_CLUSTER,
        wgrad_sync_method
    );
}
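
// Worked example (hypothetical shape, assuming the lcm/gcd helpers from gn_utils.hpp):
//   T = half, BWD = false, HW = 1024, G = 32, CPG = 10 (C = 320):
//   IO size = 1024 * 10 * 1 * sizeof(half) = 20480 bytes, so block sync is chosen, giving
//   C_PER_BLOCK = 10, ROWS_PER_BLOCK = 1024, BLOCK_DIM_X = lcm(32, 10) * 2 = 320,
//   BLOCKS_PER_SM = 1, LOAD_TWICE = false, VEC_ELEMS = min(gcd(4, 10), gcd(4, 32)) = 2.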

// Save compilation time for unused CUDA archs:
//   for each instantiation produced by DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT, the kernel body is compiled only for the matching CUDA_ARCH
template<int EFFECTIVE_CUDA_ARCH>
class CompileCondition {
public:
    __host__ __device__ static constexpr bool matches() {
#if defined(__CUDA_ARCH__)
        return __CUDA_ARCH__ == EFFECTIVE_CUDA_ARCH;
#else
        return false;
#endif
    }
};
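
// Presumed use inside gn_cuda_kernel.cuh (an assumption; the guard is not shown here):
//   if constexpr (!COMPILE_CONDITION::matches()) return;  // emit no code for other archs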

template<typename T, int HW, int C, int G, bool SILU>
void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)) {
    if (out == x) {
        throw std::invalid_argument("not __restrict__");
    }

    cudaDeviceProp const &deviceProp = get_device_prop(device_id);
    int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10;
    int sm_count = deviceProp.multiProcessorCount;

    DISPATCH_LOWER_BOUND_N(n, LB_N, [&] {
        DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] {
        DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] {
            if (hw != HW) {
                throw std::invalid_argument("wrong HW");
            }
            if (num_groups * channels_per_group != C) {
                throw std::invalid_argument("wrong C");
            }
            if (num_groups != G) {
                throw std::invalid_argument("wrong G");
            }
            if (silu != SILU) {
                throw std::invalid_argument("wrong SILU");
            }
            if (n < LB_N) {
                throw std::invalid_argument("wrong LB_N");
            }
            if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) {
                throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH");
            }
            if (sm_count < LB_SM_COUNT) {
                throw std::invalid_argument("wrong LB_SM_COUNT");
            }
            if (sm_margin != SM_MARGIN) {
                throw std::invalid_argument("wrong SM_MARGIN");
            }
            constexpr int EFFECTIVE_CUDA_ARCH = std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch());  // Assume the max CUDA_ARCH is used to generate PTX

            constexpr int CPG = C / G;

            constexpr auto params = compute_gn_params<T, false, false, HW, G, CPG, LB_N, RUNTIME_CUDA_ARCH, LB_SM_COUNT, EFFECTIVE_CUDA_ARCH, SM_MARGIN>();
            constexpr int BLOCK_DIM_X =       std::get<0>(params);
            constexpr int C_PER_BLOCK =       std::get<1>(params);
            constexpr int ROWS_PER_BLOCK =    std::get<2>(params);
            constexpr int VEC_ELEMS =         std::get<3>(params);
            constexpr bool LOAD_TWICE =       std::get<4>(params);
            constexpr int BLOCKS_PER_SM =     std::get<5>(params);
            constexpr bool HARDWARE_CLUSTER = std::get<6>(params);

            constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK);
            constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);
            constexpr int NUM_VIRTUAL_CLUSTERS = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE;
            constexpr bool PERSISTENT = !HARDWARE_CLUSTER && VIRTUAL_CLUSTER_SIZE >= 2;  // Only virtual cluster sync (not including hardware cluster sync) requires PERSISTENT kernels

            if (meta_ptr) {
                constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;
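                // E.g., C_PER_BLOCK == 64, CPG == 48: blocks start at multiples of 64, so a
                //   block's offset within a group is a multiple of gcd(64, 48) == 16, and it
                //   touches at most up_div(64 - 16, 48) + 1 == 2 groups.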
                meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2;
                meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS;
                meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X;
                meta_ptr->C_PER_BLOCK = C_PER_BLOCK;
                meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK;
                meta_ptr->VEC_ELEMS = VEC_ELEMS;
                meta_ptr->LOAD_TWICE = LOAD_TWICE;
                meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM;
                meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER;
                meta_ptr->wgrad_sync_method = (int)WGRAD_SYNC_UNSPECIFIED;
            }
            if (meta_only) {
                return;
            }

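            // Grid layout: gridDim.x indexes blocks within a virtual cluster; gridDim.y enumerates
            //   the persistent virtual clusters, or all n * (C / C_PER_CLUSTER) tasks otherwise.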
            cudaLaunchConfig_t config = {0};
            config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1);
            config.blockDim = BLOCK_DIM_X;
            config.stream = stream;

            cudaLaunchAttribute attribute[2];
            if constexpr (HARDWARE_CLUSTER) {
                attribute[0].id = cudaLaunchAttributeClusterDimension;
                attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; // Cluster size in X-dimension
                attribute[0].val.clusterDim.y = 1;
                attribute[0].val.clusterDim.z = 1;
                config.attrs = attribute;
                config.numAttrs++;
            }
            if constexpr (PERSISTENT) {
                attribute[config.numAttrs].id = cudaLaunchAttributeCooperative;
                attribute[config.numAttrs].val.cooperative = 1;
                config.attrs = attribute;
                config.numAttrs++;
            }

            auto kernel = &gn_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, ROWS_PER_BLOCK, C_PER_BLOCK, C_PER_CLUSTER, VEC_ELEMS, PERSISTENT, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE, HARDWARE_CLUSTER, CompileCondition<EFFECTIVE_CUDA_ARCH> >;
            if constexpr (HARDWARE_CLUSTER) {
                if constexpr (VIRTUAL_CLUSTER_SIZE > 8) {
                    CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
                }
                int max_cluster_size;
                int active_clusters = 0;
                CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void *)kernel, &config));
                if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) {
                    attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;
                    CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void *)kernel, &config));
                }
                if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && (!PERSISTENT || (PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters))) {
                    attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;
                } else {
                    // Fall back to cooperative groups because not all hardware clusters can be active simultaneously
                    constexpr bool HARDWARE_CLUSTER_NEW = false;
                    constexpr bool PERSISTENT_NEW = !HARDWARE_CLUSTER_NEW && VIRTUAL_CLUSTER_SIZE >= 2;
                    config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT_NEW ? std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1);
                    config.attrs = nullptr;
                    config.numAttrs = 0;
                    if constexpr (PERSISTENT_NEW) {
                        attribute[config.numAttrs].id = cudaLaunchAttributeCooperative;
                        attribute[config.numAttrs].val.cooperative = 1;
                        config.attrs = attribute;
                        config.numAttrs++;
                    }
                    kernel = &gn_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, ROWS_PER_BLOCK, C_PER_BLOCK, C_PER_CLUSTER, VEC_ELEMS, PERSISTENT_NEW, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE, HARDWARE_CLUSTER_NEW, CompileCondition<EFFECTIVE_CUDA_ARCH> >;
                }
            }
            CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, out, x, w, b, eps, n, mean_var_out, red_buffer, barrier));
        });
        });
    });
}


template<typename T, int HW, int C, int G, bool SILU>
void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)) {
    if (grad_input == grad_output || grad_input == x) {
        throw std::invalid_argument("not __restrict__");
    }

    cudaDeviceProp const &deviceProp = get_device_prop(device_id);
    int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10;
    int sm_count = deviceProp.multiProcessorCount;

    DISPATCH_LOWER_BOUND_N(n, LB_N, [&] {
        DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] {
        DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] {
            if (hw != HW) {
                throw std::invalid_argument("wrong HW");
            }
            if (num_groups * channels_per_group != C) {
                throw std::invalid_argument("wrong C");
            }
            if (num_groups != G) {
                throw std::invalid_argument("wrong G");
            }
            if (silu != SILU) {
                throw std::invalid_argument("wrong SILU");
            }
            if (n < LB_N) {
                throw std::invalid_argument("wrong LB_N");
            }
            if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) {
                throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH");
            }
            if (sm_count < LB_SM_COUNT) {
                throw std::invalid_argument("wrong LB_SM_COUNT");
            }
            if (sm_margin != SM_MARGIN) {
                throw std::invalid_argument("wrong SM_MARGIN");
            }
            constexpr int EFFECTIVE_CUDA_ARCH = std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch());  // Assume the max CUDA_ARCH is used to generate PTX

            constexpr bool REQUIRES_WGRAD = true;
            constexpr int CPG = C / G;

            constexpr auto params = compute_gn_params<T, true, REQUIRES_WGRAD, HW, G, CPG, LB_N, RUNTIME_CUDA_ARCH, LB_SM_COUNT, EFFECTIVE_CUDA_ARCH, SM_MARGIN>();
            constexpr int BLOCK_DIM_X =       std::get<0>(params);
            constexpr int C_PER_BLOCK =       std::get<1>(params);
            constexpr int ROWS_PER_BLOCK =    std::get<2>(params);
            constexpr int VEC_ELEMS =         std::get<3>(params);
            constexpr bool LOAD_TWICE =       std::get<4>(params);
            constexpr int BLOCKS_PER_SM =     std::get<5>(params);
            constexpr bool HARDWARE_CLUSTER = std::get<6>(params);
            constexpr WgradSyncMethod wgrad_sync_method_hint = std::get<7>(params);

            constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK);
            constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);
            constexpr int NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE;

            // PERSISTENT is required because wgrad reduction requires synchronization.
            //   TODO: specialize for the case where REQUIRES_WGRAD == false
            constexpr bool PERSISTENT = true;

            // Determine whether to align each virtual cluster to a fixed range of channels.
            //   If aligned, WGRAD_REUSE_SUM_SYNC_GROUP can be used, which needs less local wgrad memory
            //     (leaving more room for compiler optimizations) and makes wgrad reduction more efficient.
            //   However, aligning can cause low occupancy. The trade-off: align when
            //     `NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER)`, or when it is already
            //     an exact multiple of `C / C_PER_CLUSTER` so that aligning costs nothing.
            constexpr WgradSyncMethod wgrad_sync_method =
                wgrad_sync_method_hint == WGRAD_SYNC_UNSPECIFIED ?
                    NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER) || NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED % (C / C_PER_CLUSTER) == 0 ?
                        (HARDWARE_CLUSTER ? WGRAD_ARRIVE_AND_WAIT_GROUP : WGRAD_REUSE_SUM_SYNC_GROUP) :
                        WGRAD_REUSE_SUM_SYNC_GRID :
                    wgrad_sync_method_hint;
            constexpr int NUM_VIRTUAL_CLUSTERS =
                wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP ?
                    NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED / (C / C_PER_CLUSTER) * (C / C_PER_CLUSTER) :
                    NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED;
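            // E.g., 148 SMs, SM_MARGIN == 0, BLOCKS_PER_SM == 1, VIRTUAL_CLUSTER_SIZE == 8,
            //   C / C_PER_CLUSTER == 4: 18 unaligned clusters > 2 * 4, so align down to 16.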

            if (meta_ptr) {
                constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;
                meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2 +
                    std::max(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) * (HW / ROWS_PER_BLOCK) * C * 2;
                meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS + C / C_PER_CLUSTER;
                meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X;
                meta_ptr->C_PER_BLOCK = C_PER_BLOCK;
                meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK;
                meta_ptr->VEC_ELEMS = VEC_ELEMS;
                meta_ptr->LOAD_TWICE = LOAD_TWICE;
                meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM;
                meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER;
                meta_ptr->wgrad_sync_method = (int)wgrad_sync_method;
            }
            if (meta_only) {
                return;
            }

            cudaLaunchConfig_t config = {0};
            config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? NUM_VIRTUAL_CLUSTERS : n * (C / C_PER_CLUSTER), 1);
            config.blockDim = BLOCK_DIM_X;
            config.stream = stream;

            cudaLaunchAttribute attribute[2];
            if constexpr (HARDWARE_CLUSTER) {
                attribute[0].id = cudaLaunchAttributeClusterDimension;
                attribute[0].val.clusterDim.x = 1; // Placeholder cluster size; set to VIRTUAL_CLUSTER_SIZE after the occupancy check below
                attribute[0].val.clusterDim.y = 1;
                attribute[0].val.clusterDim.z = 1;
                config.attrs = attribute;
                config.numAttrs++;
            }
            if constexpr (PERSISTENT) {
                attribute[config.numAttrs].id = cudaLaunchAttributeCooperative;
                attribute[config.numAttrs].val.cooperative = 1;
                config.attrs = attribute;
                config.numAttrs++;
            }

            auto kernel = &gn_bwd_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, REQUIRES_WGRAD, ROWS_PER_BLOCK, C_PER_BLOCK, C_PER_CLUSTER, VEC_ELEMS, PERSISTENT, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE, HARDWARE_CLUSTER, wgrad_sync_method, CompileCondition<EFFECTIVE_CUDA_ARCH> >;
            if constexpr (HARDWARE_CLUSTER) {
                if constexpr (VIRTUAL_CLUSTER_SIZE > 8) {
                    CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
                }
                int max_cluster_size;
                int active_clusters = 0;
                CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void *)kernel, &config));
                if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) {
                    attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;
                    CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void *)kernel, &config));
                }
                if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && (!PERSISTENT || (PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters))) {
                    attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;
                } else {
                    // Fall back to cooperative groups for dgrad computation because not all hardware clusters can be active simultaneously
                    attribute[0].val.clusterDim.x = 1;
                    kernel = &gn_bwd_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, REQUIRES_WGRAD, ROWS_PER_BLOCK, C_PER_BLOCK, C_PER_CLUSTER, VEC_ELEMS, PERSISTENT, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE, false, wgrad_sync_method, CompileCondition<EFFECTIVE_CUDA_ARCH> >;
                }
            }
            CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, n, red_buffer, barrier));
        });
        });
    });
}

#define GN_CUDA_INST_DEFINE(HW, C) \
    template void gn_cuda_single_shape<half, HW, C, 16, true>(GN_CUDA_HOST_PARAMS(half)); \
    template void gn_cuda_single_shape<half, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(half)); \
    template void gn_bwd_cuda_single_shape<half, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(half)); \
    template void gn_bwd_cuda_single_shape<half, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(half)); \
    template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \
    template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \
    template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); \
    template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16));
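
// Hypothetical usage in a per-shape .cu file (the shape values are illustrative only):
//   GN_CUDA_INST_DEFINE(64 * 64, 320)
// instantiates fwd/bwd for half and __nv_bfloat16 with G = 16 (SILU) and G = 32 (no SILU).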

}  // namespace group_norm_v2
