/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a pipelined batched GEMM kernel. Each threadblock
   loops over its batch indices. Does not support split-K.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Batched GEMM kernel functor.
///
/// For every batch index assigned to the threadblock, runs the
/// threadblock-scoped multiply-accumulate (Mma_) over operands A and B and then
/// hands the accumulators to the epilogue (Epilogue_) together with tile
/// iterators over C and D. ThreadblockSwizzle_ supplies the tile offset and the
/// first batch index of the threadblock.
template <typename Mma_,       ///< Threadblock-scoped matrix multiply-accumulate
          typename Epilogue_,  ///< Epilogue
          typename ThreadblockSwizzle_  ///< Threadblock swizzling function
          >
struct GemmBatched {
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using OutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;

    /// Number of threads per threadblock: 32 threads for each warp
    static int const kThreadCount = 32 * WarpCount::kCount;

    /// Parameters structure
    struct Params {
        /// Extent of one GEMM problem; mk(), kn() and mn() of it bound the
        /// A, B and C/D iterators respectively
        cutlass::gemm::GemmCoord problem_size;
        /// Grid extent in tiles; threadblocks whose tile offset lies outside
        /// it in M or N exit immediately
        cutlass::gemm::GemmCoord grid_tiled_shape;
        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorA::TensorRef ref_A;
        /// Pointer offset applied to A per batch index (stride_A * batch_idx)
        int64_t stride_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorB::TensorRef ref_B;
        /// Pointer offset applied to B per batch index (stride_B * batch_idx)
        int64_t stride_B;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        /// Pointer offset applied to C per batch index (stride_C * batch_idx)
        int64_t stride_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        /// Pointer offset applied to D per batch index (stride_D * batch_idx)
        int64_t stride_D;
        /// Parameters used to construct the output operator
        typename OutputOp::Params epilogue;
        /// Total number of batches
        int batch_count;
        /// Number of mainloop iterations: ceil(problem_size.k() / Mma::Shape::kK)
        int gemm_k_iterations;

        //
        // Methods
        //

        /// Default constructor. Scalar members are zeroed so that a
        /// default-constructed Params describes an empty problem
        /// (batch_count == 0) instead of holding indeterminate values.
        CUTLASS_HOST_DEVICE
        Params()
                : stride_A(0),
                  stride_B(0),
                  stride_C(0),
                  stride_D(0),
                  batch_count(0),
                  gemm_k_iterations(0) {}

        /// Constructs the parameters from tensor references and batch strides.
        /// Iterator parameters are derived from the layouts of the tensor
        /// references, and gemm_k_iterations from the K extent of the problem.
        CUTLASS_HOST_DEVICE
        Params(cutlass::gemm::GemmCoord const& problem_size_,
               cutlass::gemm::GemmCoord const& grid_tiled_shape_,
               typename Mma::IteratorA::TensorRef ref_A_, int64_t stride_A_,
               typename Mma::IteratorB::TensorRef ref_B_, int64_t stride_B_,
               typename Epilogue::OutputTileIterator::TensorRef ref_C_,
               int64_t stride_C_,
               typename Epilogue::OutputTileIterator::TensorRef ref_D_,
               int64_t stride_D_, typename OutputOp::Params epilogue_,
               int batch_count_)
                : problem_size(problem_size_),
                  grid_tiled_shape(grid_tiled_shape_),
                  params_A(ref_A_.layout()),
                  ref_A(ref_A_),
                  stride_A(stride_A_),
                  params_B(ref_B_.layout()),
                  ref_B(ref_B_),
                  stride_B(stride_B_),
                  params_C(ref_C_.layout()),
                  ref_C(ref_C_),
                  stride_C(stride_C_),
                  params_D(ref_D_.layout()),
                  ref_D(ref_D_),
                  stride_D(stride_D_),
                  epilogue(epilogue_),
                  batch_count(batch_count_),
                  // Round up so a partial final K tile is still processed
                  gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) /
                                    Mma::Shape::kK) {}
    };

    /// Shared memory storage structure. The mainloop and the epilogue storage
    /// are members of a union and therefore occupy the same memory.
    union SharedStorage {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmBatched() {}

    /// Executes one GEMM per batch index assigned to this threadblock.
    ///
    /// \param params          kernel parameters (problem size, tensor
    ///                        references, batch strides, batch count)
    /// \param shared_storage  storage shared by the mainloop and the epilogue
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage) {
        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord const threadblock_tile_offset =
                threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
            params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
            return;
        }

        //
        // Values that do not depend on the batch index are computed once.
        //

        // Initial locations of the A and B tiles in logical coordinates
        cutlass::MatrixCoord const tb_offset_A{
                threadblock_tile_offset.m() * Mma::Shape::kM, 0};

        cutlass::MatrixCoord const tb_offset_B{
                0, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Location of the output tile (assumes identity swizzle)
        MatrixCoord const threadblock_offset(
                threadblock_tile_offset.m() * Mma::Shape::kM,
                threadblock_tile_offset.n() * Mma::Shape::kN);

        // Position within threadblock
        int const thread_idx = threadIdx.x;

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int const warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

        int const lane_idx = threadIdx.x % 32;

        // Each CTA handles multiple batch indices to accommodate limited range
        // of CUDA grid's Z dimension
        for (int batch_idx = threadblock_swizzle.get_batch_idx();
             batch_idx < params.batch_count; batch_idx += gridDim.z) {
            // Construct iterators to A and B operands and move them to the
            // current batch
            typename Mma::IteratorA iterator_A(
                    params.params_A, params.ref_A.data(),
                    params.problem_size.mk(), thread_idx, tb_offset_A);

            iterator_A.add_pointer_offset(params.stride_A * batch_idx);

            typename Mma::IteratorB iterator_B(
                    params.params_B, params.ref_B.data(),
                    params.problem_size.kn(), thread_idx, tb_offset_B);

            iterator_B.add_pointer_offset(params.stride_B * batch_idx);

            //
            // Main loop
            //

            Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            // Compute threadblock-scoped matrix multiply-add
            mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B,
                accumulators);

            //
            // Epilogue
            //

            OutputOp output_op(params.epilogue);

            //
            // Masked tile iterators constructed from members
            //

            // Tile iterator over C, moved to the current batch
            typename Epilogue::OutputTileIterator iterator_C(
                    params.params_C, params.ref_C.data(),
                    params.problem_size.mn(), thread_idx, threadblock_offset);

            iterator_C.add_pointer_offset(params.stride_C * batch_idx);

            // Tile iterator over D, moved to the current batch
            typename Epilogue::OutputTileIterator iterator_D(
                    params.params_D, params.ref_D.data(),
                    params.problem_size.mn(), thread_idx, threadblock_offset);

            iterator_D.add_pointer_offset(params.stride_D * batch_idx);

            Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx,
                              lane_idx);

            // run efficient epilogue
            epilogue(output_op, iterator_D, accumulators, iterator_C);

            // SharedStorage is a union, so the next batch's mainloop reuses
            // the memory the epilogue has just been working in. Wait until
            // every thread of the CTA has left the epilogue before that
            // memory can be overwritten. The loop bounds (batch_count,
            // gridDim.z, the swizzle's batch index) are expected to be
            // identical for all threads of the CTA, so every thread reaches
            // this barrier the same number of times.
            // NOTE(review): the Mma/Epilogue internals are not visible here;
            // this barrier is at worst redundant if the epilogue already ends
            // with a block-wide synchronization.
            __syncthreads();
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
