/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;

template<bool Is_training>
__global__ 
void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
                                           const int total_heads) {

    fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
}

template<bool Is_training>
__global__ 
void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,
                                           const int num_full_heads,
                                           const int num_main_groups,
                                           const int main_group_size,
                                           const int main_steps,
                                           const int rest_steps) {

    fmha::device_1xN<Kernel_traits, Is_training>(
        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}

void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {

    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel<true> : &fmha_fprop_fp16_512_64_sm80_kernel<false>;

    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    const int sm_count = launch_params.props->multiProcessorCount;
    int ctas_per_sm;
    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
    int total_ctas = sm_count * ctas_per_sm;

    const int heads_total = launch_params.params.b * launch_params.params.h;
    if(configure) {

        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;

        size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
        launch_params.elts_per_thread = heads_per_cta * elts_per_head;
        return;
    }

    dim3 grid(total_ctas);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
        launch_params.params,
        heads_total);

    FMHA_CHECK_CUDA(cudaPeekAtLastError());

}

void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {

    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl<true> : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;

    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    const int sm_count = launch_params.props->multiProcessorCount;
    int ctas_per_sm;
    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
    int total_ctas = sm_count * ctas_per_sm;

    if(configure) {
        const int heads_total = launch_params.params.b * launch_params.params.h;
        std::tie(launch_params.num_full_heads,
                 launch_params.num_main_groups, 
                 launch_params.heads_last_wave, 
                 launch_params.main_steps, 
                 launch_params.rest_steps, 
                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
        return;
    }

    dim3 grid(total_ctas);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
        launch_params.params,
        launch_params.num_full_heads, 
        launch_params.num_main_groups, 
        launch_params.heads_last_wave, 
        launch_params.main_steps, 
        launch_params.rest_steps);

    FMHA_CHECK_CUDA(cudaPeekAtLastError());

}

void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
    if( launch_params.is_nl ) {
        run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
    } else {
        run_fmha_fp16_512_64_sm80_(launch_params, configure);
    }
}
