// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

namespace cg = cooperative_groups;

/*
Pure quantization kernel with no fusion (__half input).

Each thread stages UNROLL * internal_unroll chunks of quantize::h_per_load
halves from its group into a thread-local buffer and then hands that buffer to
quantize::local_array, which is given the params/output pointers.

Indexing, as written below:
  - thread_index().x selects the chunk inside a group, thread_index().y selects
    the group inside the block, and each block covers
    (max_threads / threads_per_group) consecutive groups.
  - consecutive iterations of one thread are tb.size() * quantize::h_per_load
    elements apart.

Arguments:
  output_data     - destination buffer, forwarded untouched to local_array.
  params          - quantization parameter buffer, forwarded to local_array.
  input_data      - groups * elems_per_group contiguous halves.
  groups          - number of groups in input_data.
  elems_per_group - number of elements in each group.

NOTE(review): the per-thread addresses are used with
mem_access::load_global<quantize::granularity>, so input_data and
elems_per_group presumably have to keep every chunk granularity-aligned -
confirm against the callers.
*/
template <int q_bits,
          quantize::Type quant_type,
          int UNROLL,
          int internal_unroll,
          int threads_per_group,
          int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
                                    float* __restrict__ params,
                                    const __half* __restrict__ input_data,
                                    int groups,
                                    int elems_per_group)
{
    cg::thread_block tb = cg::this_thread_block();

    // Group handled by this thread. The launcher rounds the grid up, so the last
    // block may contain thread rows whose group index is >= groups.
    constexpr int groups_per_block = max_threads / threads_per_group;
    const int64_t group_id =
        static_cast<int64_t>(tb.group_index().x) * groups_per_block + tb.thread_index().y;
    const bool group_in_range = group_id < groups;

    // Indexing offsets (64-bit so that large inputs do not overflow)
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int64_t base_offset = group_id * elems_per_group + elem_offset;
    const int stride = tb.size() * quantize::h_per_load;

    const __half* input_base = input_data + base_offset;

    __half2 local_buffer[UNROLL * internal_unroll * quantize::h2_per_load];

#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
        // Convenience helper, should resolve to register indices and not realize.
        __half2* iteration_buffer = local_buffer + i * internal_unroll * quantize::h2_per_load;
#pragma unroll
        for (int j = 0; j < internal_unroll; j++) {
            const int iteration = i * internal_unroll + j;
            // The last argument is the access predicate: only touch global memory
            // when both the group and the element offset are inside the input.
            // Without the group check, padding rows of the tail block would read
            // past the end of input_data.
            mem_access::load_global<quantize::granularity>(
                iteration_buffer + j * quantize::h2_per_load,
                input_base + iteration * stride,
                group_in_range && (elem_offset + iteration * stride < elems_per_group));
        }
    }

    // Every thread (including padding rows) still calls local_array; it receives
    // `groups` and is presumably responsible for guarding its own stores -
    // TODO confirm in quantization_utils.h.
    quantize::
        local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
            local_buffer, params, output_data, elems_per_group, groups);
}

#ifdef BF16_AVAILABLE
/*
Pure quantization kernel with no fusion (__nv_bfloat16 input).

Each thread stages UNROLL * internal_unroll chunks of quantize::bf_per_load
bfloat16 values from its group into a thread-local buffer and then hands that
buffer to quantize::local_array together with the params/output pointers.

Indexing, as written below:
  - thread_index().x selects the chunk inside a group, thread_index().y selects
    the group inside the block, and each block covers
    (max_threads / threads_per_group) consecutive groups.
  - consecutive iterations of one thread are tb.size() * quantize::bf_per_load
    elements apart.

Arguments:
  output_data     - destination buffer, forwarded untouched to local_array.
  params          - quantization parameter buffer, forwarded to local_array.
  input_data      - groups * elems_per_group contiguous bfloat16 values.
  groups          - number of groups in input_data.
  elems_per_group - number of elements in each group.
*/
template <int q_bits,
          quantize::Type quant_type,
          int UNROLL,
          int internal_unroll,
          int threads_per_group,
          int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
                                    float* __restrict__ params,
                                    const __nv_bfloat16* __restrict__ input_data,
                                    int groups,
                                    int elems_per_group)
{
    cg::thread_block tb = cg::this_thread_block();

    // Group handled by this thread. The launcher rounds the grid up, so the last
    // block may contain thread rows whose group index is >= groups.
    constexpr int groups_per_block = max_threads / threads_per_group;
    const int64_t group_id =
        static_cast<int64_t>(tb.group_index().x) * groups_per_block + tb.thread_index().y;
    const bool group_in_range = group_id < groups;

    // Indexing offsets (64-bit so that large inputs do not overflow)
    const int elem_offset = tb.thread_index().x * quantize::bf_per_load;
    const int64_t base_offset = group_id * elems_per_group + elem_offset;
    const int stride = tb.size() * quantize::bf_per_load;

    const __nv_bfloat16* input_base = input_data + base_offset;

    __nv_bfloat162 local_buffer[UNROLL * internal_unroll * quantize::bf2_per_load];

#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
        // Convenience helper, should resolve to register indices and not realize.
        __nv_bfloat162* iteration_buffer =
            local_buffer + i * internal_unroll * quantize::bf2_per_load;
#pragma unroll
        for (int j = 0; j < internal_unroll; j++) {
            const int iteration = i * internal_unroll + j;
            // The last argument is the access predicate: only touch global memory
            // when both the group and the element offset are inside the input.
            // Without the group check, padding rows of the tail block would read
            // past the end of input_data.
            mem_access::load_global<quantize::granularity>(
                iteration_buffer + j * quantize::bf2_per_load,
                input_base + iteration * stride,
                group_in_range && (elem_offset + iteration * stride < elems_per_group));
        }
    }

    // Every thread (including padding rows) still calls local_array; it receives
    // `groups` and is presumably responsible for guarding its own stores -
    // TODO confirm in quantization_utils.h.
    quantize::
        local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
            local_buffer, params, output_data, elems_per_group, groups);
}
#endif

/*
Pure quantization kernel with no fusion (float input).

Each thread stages UNROLL * internal_unroll chunks of quantize::f_per_load
floats from its group into a thread-local buffer and then hands that buffer to
quantize::local_array together with the params/output pointers.

Indexing, as written below:
  - thread_index().x selects the chunk inside a group, thread_index().y selects
    the group inside the block, and each block covers
    (max_threads / threads_per_group) consecutive groups.
  - consecutive iterations of one thread are tb.size() * quantize::f_per_load
    elements apart.

Arguments:
  output_data     - destination buffer, forwarded untouched to local_array.
  params          - quantization parameter buffer, forwarded to local_array.
  input_data      - groups * elems_per_group contiguous floats.
  groups          - number of groups in input_data.
  elems_per_group - number of elements in each group.
*/
template <int q_bits,
          quantize::Type quant_type,
          int UNROLL,
          int internal_unroll,
          int threads_per_group,
          int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
                                    float* __restrict__ params,
                                    const float* __restrict__ input_data,
                                    int groups,
                                    int elems_per_group)
{
    cg::thread_block tb = cg::this_thread_block();

    // Group handled by this thread. The launcher rounds the grid up, so the last
    // block may contain thread rows whose group index is >= groups.
    constexpr int groups_per_block = max_threads / threads_per_group;
    const int64_t group_id =
        static_cast<int64_t>(tb.group_index().x) * groups_per_block + tb.thread_index().y;
    const bool group_in_range = group_id < groups;

    // Indexing offsets (64-bit so that large inputs do not overflow)
    const int elem_offset = tb.thread_index().x * quantize::f_per_load;
    const int64_t base_offset = group_id * elems_per_group + elem_offset;
    const int stride = tb.size() * quantize::f_per_load;

    const float* input_base = input_data + base_offset;

    float local_buffer[UNROLL * internal_unroll * quantize::f_per_load];

#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
        // Convenience helper, should resolve to register indices and not realize.
        float* iteration_buffer = local_buffer + i * internal_unroll * quantize::f_per_load;
#pragma unroll
        for (int j = 0; j < internal_unroll; j++) {
            const int iteration = i * internal_unroll + j;
            // The last argument is the access predicate: only touch global memory
            // when both the group and the element offset are inside the input.
            // Without the group check, padding rows of the tail block would read
            // past the end of input_data.
            mem_access::load_global<quantize::granularity>(
                iteration_buffer + j * quantize::f_per_load,
                input_base + iteration * stride,
                group_in_range && (elem_offset + iteration * stride < elems_per_group));
        }
    }

    // Every thread (including padding rows) still calls local_array; it receives
    // `groups` and is presumably responsible for guarding its own stores -
    // TODO confirm in quantization_utils.h.
    quantize::
        local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
            local_buffer, params, output_data, elems_per_group, groups);
}

/********* Launcher methods ***********/
// Expands to a single cached_quantization<<<...>>> launch for the given
// compile-time (q_bits, quant_type) pair.
//
// The expansion picks up the following names from the scope it is used in:
//   - template arguments: unroll_factor, internal_unroll_l, threads_per_group,
//     max_threads (all must be usable as constant expressions);
//   - launch configuration: grid, block, stream (no dynamic shared memory);
//   - kernel arguments: output_data, params, input_data, groups, elems_per_group.
// The cached_quantization overload is selected by the type of input_data.
// The expansion already ends with ';', so call sites do not add one.
#define LAUNCH_CACHED_QUANT_CALL(q_bits, quant_type) \
    cached_quantization<q_bits,                      \
                        quant_type,                  \
                        unroll_factor,               \
                        internal_unroll_l,           \
                        threads_per_group,           \
                        max_threads>                 \
        <<<grid, block, 0, stream>>>(output_data, params, input_data, groups, elems_per_group);

// Maps the runtime (q_bits, quant_type) pair onto the matching
// cached_quantization template instantiation and launches it through
// LAUNCH_CACHED_QUANT_CALL.
//
//   q_bits               - runtime bit width; only 1, 4 and 8 are dispatched,
//                          any other value launches nothing.
//   quant_type           - runtime quantize::Type; every value other than
//                          Asymmetric takes the Symmetric instantiation.
//   unroll_factor_in     - compile-time UNROLL template argument.
//   internal_unroll_in   - compile-time internal_unroll template argument.
//   threads_per_group_in - compile-time threads_per_group template argument.
//
// The expansion declares three const locals (unroll_factor, internal_unroll_l,
// threads_per_group) followed by an if/else chain, so it must be the only use
// inside its own brace-delimited scope; the local threads_per_group
// intentionally shadows any outer variable of the same name.
//
// The final line deliberately carries no line continuation, so the macro ends
// at its closing brace regardless of what follows it in the file.
#define LAUNCH_CACHED_QUANT(                                                        \
    q_bits, quant_type, unroll_factor_in, internal_unroll_in, threads_per_group_in) \
    const int unroll_factor = unroll_factor_in;                                     \
    const int internal_unroll_l = internal_unroll_in;                               \
    const int threads_per_group = threads_per_group_in;                             \
    if (q_bits == 4) {                                                              \
        if (quant_type == quantize::Type::Asymmetric) {                             \
            LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Asymmetric)                 \
        } else {                                                                    \
            LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Symmetric)                  \
        }                                                                           \
    } else if (q_bits == 8) {                                                       \
        if (quant_type == quantize::Type::Asymmetric) {                             \
            LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Asymmetric)                 \
        } else {                                                                    \
            LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Symmetric)                  \
        }                                                                           \
    } else if (q_bits == 1) {                                                       \
        if (quant_type == quantize::Type::Asymmetric) {                             \
            LAUNCH_CACHED_QUANT_CALL(1, quantize::Type::Asymmetric)                 \
        } else {                                                                    \
            LAUNCH_CACHED_QUANT_CALL(1, quantize::Type::Symmetric)                  \
        }                                                                           \
    }

// Host launcher for cached_quantization on __half input.
//
//   output_data     - device buffer receiving the quantized values.
//   params          - device buffer receiving the quantization parameters.
//   input_data      - device buffer holding groups * elems_per_group halves.
//   groups          - number of groups to quantize.
//   elems_per_group - number of elements in each group.
//   num_bits        - bit width; LAUNCH_CACHED_QUANT dispatches 1, 4 and 8 only.
//   quant_type      - symmetric or asymmetric quantization.
//   stream          - CUDA stream the kernel is enqueued on.
//
// Scheduling: groups of at most 128 elements use the "sub-block" schedule, where
// several groups share one block (block = threads_per_group x groups_per_block).
// Larger groups get one block each and are covered by 1-4 unrolled steps of
// max_threads * h_per_step elements.
//
// NOTE(review): no kernel is launched when elems_per_group needs more than four
// steps, when threads_per_group has no matching branch, or when num_bits is not
// dispatched by the macro; the output is left untouched in those cases.
void launch_quant(int8_t* output_data,
                  float* params,
                  const __half* input_data,
                  const int groups,
                  const int elems_per_group,
                  const int num_bits,
                  const quantize::Type quant_type,
                  cudaStream_t stream)
{
    // Nothing to quantize. Returning early also avoids building a zero-sized
    // grid/block (an invalid launch configuration) and the division by
    // threads_per_group below should next_pow2 yield 0 for an empty group.
    if (groups <= 0 || elems_per_group <= 0) { return; }

    constexpr int max_threads = 256;

    constexpr int internal_unroll = 2;

    const bool is_subblock_schedule = elems_per_group <= 128;
    const int h_per_step = is_subblock_schedule ? quantize::h_per_load
                                                : quantize::h_per_load * internal_unroll;

    // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
    // warp-sized blocks rather than stepping up to 64/96 threads
    const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;

    // Sub-block schedule packs as many groups as fit into max_threads threads.
    const int groups_per_block =
        is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
    // Ceil-div: the last block may be only partially populated with groups.
    const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;

    dim3 block(threads_per_group, groups_per_block);
    dim3 grid(groups_launch);

    const int elems_per_step = threads_per_group * h_per_step;
    const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;

    if (is_subblock_schedule) {
        // <=128
        if (threads_per_group == 1) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 1);
        } else if (threads_per_group == 2) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 2);
        } else if (threads_per_group == 4) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 4);
        } else if (threads_per_group == 8) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 8);
        } else if (threads_per_group == 16) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 16);
        }
    } else if (external_unroll == 1) {
        // 129 - 4096 elems
        // (this can launch with 1-7 warps as well)
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, internal_unroll, max_threads);
    } else if (external_unroll == 2) {
        // 4097 - 8192 elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 2, internal_unroll, max_threads);
    } else if (external_unroll == 3) {
        // 8193 - 12288 elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 3, internal_unroll, max_threads);
    } else if (external_unroll == 4) {
        // 12289 - 16384 elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 4, internal_unroll, max_threads);
    }
}

// Host launcher for cached_quantization on float input.
//
//   output_data     - device buffer receiving the quantized values.
//   params          - device buffer receiving the quantization parameters.
//   input_data      - device buffer holding groups * elems_per_group floats.
//   groups          - number of groups to quantize.
//   elems_per_group - number of elements in each group.
//   num_bits        - bit width; LAUNCH_CACHED_QUANT dispatches 1, 4 and 8 only.
//   quant_type      - symmetric or asymmetric quantization.
//   stream          - CUDA stream the kernel is enqueued on.
//
// Scheduling: groups of at most 128 elements use the "sub-block" schedule, where
// several groups share one block (block = threads_per_group x groups_per_block).
// Larger groups get one block each and are covered by 1-4 unrolled steps of
// max_threads * f_per_step elements.
//
// NOTE(review): no kernel is launched when elems_per_group needs more than four
// steps, when threads_per_group has no matching branch, or when num_bits is not
// dispatched by the macro; the output is left untouched in those cases.
void launch_quant(int8_t* output_data,
                  float* params,
                  const float* input_data,
                  const int groups,
                  const int elems_per_group,
                  const int num_bits,
                  const quantize::Type quant_type,
                  cudaStream_t stream)
{
    // Nothing to quantize. Returning early also avoids building a zero-sized
    // grid/block (an invalid launch configuration) and the division by
    // threads_per_group below should next_pow2 yield 0 for an empty group.
    if (groups <= 0 || elems_per_group <= 0) { return; }

    constexpr int max_threads = 256;

    constexpr int internal_unroll = 2;

    const bool is_subblock_schedule = elems_per_group <= 128;
    const int f_per_step = is_subblock_schedule ? quantize::f_per_load
                                                : quantize::f_per_load * internal_unroll;

    // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
    // warp-sized blocks rather than stepping up to 64/96 threads
    const int one_step_threads = next_pow2((elems_per_group + f_per_step - 1) / f_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;

    // Sub-block schedule packs as many groups as fit into max_threads threads.
    const int groups_per_block =
        is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
    // Ceil-div: the last block may be only partially populated with groups.
    const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;

    dim3 block(threads_per_group, groups_per_block);
    dim3 grid(groups_launch);

    const int elems_per_step = threads_per_group * f_per_step;
    const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;

    // In the comments below S = max_threads * f_per_step is the number of floats
    // a full block covers in one unrolled step.
    if (is_subblock_schedule) {
        // <=128
        if (threads_per_group == 1) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 1);
        } else if (threads_per_group == 2) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 2);
        } else if (threads_per_group == 4) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 4);
        } else if (threads_per_group == 8) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 8);
        } else if (threads_per_group == 16) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 16);
        } else if (threads_per_group == 32) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 32);
        }
    } else if (external_unroll == 1) {
        // 129 - S elems
        // (this can launch with fewer than max_threads threads as well)
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, internal_unroll, max_threads);
    } else if (external_unroll == 2) {
        // S + 1 - 2 * S elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 2, internal_unroll, max_threads);
    } else if (external_unroll == 3) {
        // 2 * S + 1 - 3 * S elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 3, internal_unroll, max_threads);
    } else if (external_unroll == 4) {
        // 3 * S + 1 - 4 * S elems
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 4, internal_unroll, max_threads);
    }
}

#ifdef BF16_AVAILABLE
// Host launcher for cached_quantization on __nv_bfloat16 input.
//
//   output_data     - device buffer receiving the quantized values.
//   params          - device buffer receiving the quantization parameters.
//   input_data      - device buffer holding groups * elems_per_group bfloat16 values.
//   groups          - number of groups to quantize.
//   elems_per_group - number of elements in each group.
//   num_bits        - bit width; LAUNCH_CACHED_QUANT dispatches 1, 4 and 8 only.
//   quant_type      - symmetric or asymmetric quantization.
//   stream          - CUDA stream the kernel is enqueued on.
//
// Scheduling: groups of at most 128 elements use the "sub-block" schedule, where
// several groups share one block (block = threads_per_group x groups_per_block).
// Larger groups get one block each and are covered by 1-4 unrolled steps of
// max_threads * bf_per_step elements.
//
// NOTE(review): no kernel is launched when elems_per_group needs more than four
// steps, when threads_per_group has no matching branch, or when num_bits is not
// dispatched by the macro; the output is left untouched in those cases.
void launch_quant(int8_t* output_data,
                  float* params,
                  const __nv_bfloat16* input_data,
                  const int groups,
                  const int elems_per_group,
                  const int num_bits,
                  const quantize::Type quant_type,
                  cudaStream_t stream)
{
    // Nothing to quantize. Returning early also avoids building a zero-sized
    // grid/block (an invalid launch configuration) and the division by
    // threads_per_group below should next_pow2 yield 0 for an empty group.
    if (groups <= 0 || elems_per_group <= 0) { return; }

    constexpr int max_threads = 256;
    constexpr int internal_unroll = 2;

    const bool is_subblock_schedule = elems_per_group <= 128;
    const int bf_per_step = is_subblock_schedule ? quantize::bf_per_load
                                                 : quantize::bf_per_load * internal_unroll;

    // Threads needed to cover one group in a single step, capped at max_threads.
    const int one_step_threads = next_pow2((elems_per_group + bf_per_step - 1) / bf_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;

    // Sub-block schedule packs as many groups as fit into max_threads threads.
    const int groups_per_block =
        is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
    // Ceil-div: the last block may be only partially populated with groups.
    const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;

    dim3 block(threads_per_group, groups_per_block);
    dim3 grid(groups_launch);

    const int elems_per_step = threads_per_group * bf_per_step;
    const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;

    if (is_subblock_schedule) {
        // <=128
        if (threads_per_group == 1) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 1);
        } else if (threads_per_group == 2) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 2);
        } else if (threads_per_group == 4) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 4);
        } else if (threads_per_group == 8) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 8);
        } else if (threads_per_group == 16) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 16);
        }
    } else if (external_unroll == 1) {
        // One step of up to max_threads * bf_per_step elements
        // (this can launch with fewer than max_threads threads as well)
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, internal_unroll, max_threads);
    } else if (external_unroll == 2) {
        // Two steps
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 2, internal_unroll, max_threads);
    } else if (external_unroll == 3) {
        // Three steps
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 3, internal_unroll, max_threads);
    } else if (external_unroll == 4) {
        // Four steps
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 4, internal_unroll, max_threads);
    }
}
#endif