/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <multihead_attn/philox.cuh>

#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    template<typename Params>
    __device__ BlockInfoPadded(const Params &params,
                               const int bidb,
                               const int bidh,
                               const int tidx)
        : bidb(bidb), bidh(bidh), h(params.h) {

        // The block index.
        sum_s = params.cu_seqlens[bidb];
        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
        bidx = sum_s * params.h + bidh;

        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
    }

    __device__ bool stop_early() const {
        return actual_seqlen == 0;
    }

    int actual_seqlen;
    int bidx;
    int sum_s;
    int bidh;
    int bidb;
    int tidx_global;
    int h;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Cta_tile> 
struct Noloop_traits{
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum{ STEP = Cta_tile::M };
    enum{ SEQLEN = Cta_tile::N };

    template<typename Block_info>
    inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) 
        : bidc_(bidc) {
        const int seqlen = binfo.actual_seqlen;
        const int steps = (seqlen  + STEP - 1) / STEP;
        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;

        const int step_begin = bidc_ * steps_per_chunk; 
        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
        const int actual_steps = max(0, step_end - step_begin);
        loop_offset_ = step_begin;
        num_steps_ = actual_steps;

    }

    template<typename ... Tiles> 
    inline __device__ void move_all(Tiles & ... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // convert loop counter to position in the outer sequence
        return (loop_offset_ + l) * STEP;
    }

    const uint32_t bidc_;
    int loop_offset_;
    int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits>
std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {

    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;

    const int num_full_heads = heads_total / total_ctas;
    const int heads_last_wave = heads_total % total_ctas; 

    int num_main_groups = 0;
    int main_steps = 0;
    int rest_steps = 0;
    if( heads_last_wave > 0 ) {
        // Number of CTA groups that process within heads.
        num_main_groups = total_ctas / heads_last_wave;
        // Remaining CTAs that process between heads.
        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
        if(rest_ctas == 0) {
            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
            num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0
            rest_steps = STEPS_PER_HEAD % main_steps;

        } else {
            // Ideal number of steps if we could load-balance as evenly as possible.
            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
            // Iterations that a "rest" CTA has to do at most.
            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
            main_steps = steps_ideal;
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
                const int max_rest_total_steps = rest_steps * max_rest_iters;
                if( max_rest_total_steps < main_steps )
                    break;
            }
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
        }
    }

    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
    const int elts_per_thread = max_steps * elts_per_thread_per_step;

    return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
