/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a single-stage threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Threadblock-scoped matrix multiply-accumulate built on MmaBase with exactly
/// one stage: every mainloop iteration stores one pair of operand tiles
/// through the shared-memory iterators and then consumes it with the
/// warp-level operator.
template <
        /// Size of the Gemm problem - concept: gemm::GemmShape<>
        typename Shape_,
        /// Iterates over tiles of A operand in global memory
        /// (concept: ReadableTileIterator | ForwardTileIterator |
        /// MaskedTileIterator)
        typename IteratorA_,
        /// Iterates over tiles of A operand in shared memory
        /// (concept: WriteableTileIterator | RandomAccessTileIterator)
        typename SmemIteratorA_,
        /// Iterates over tiles of B operand in global memory
        /// (concept: ReadableTileIterator | ForwardTileIterator |
        /// MaskedTileIterator)
        typename IteratorB_,
        /// Iterates over tiles of B operand in shared memory
        /// (concept: WriteableTileIterator | RandomAccessTileIterator)
        typename SmemIteratorB_,
        /// Data type of accumulator matrix
        typename ElementC_,
        /// Layout of accumulator matrix
        typename LayoutC_,
        /// Policy describing tuning details (concept: MmaPolicy)
        typename Policy_,
        /// Used for partial specialization
        typename Enable = bool>
class MmaSingleStage : public MmaBase<Shape_, Policy_, 1> {
public:
    /// Base class, instantiated with a stage count of 1
    using Base = MmaBase<Shape_, Policy_, 1>;

    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    /// Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;

    /// Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;

    /// Data type of accumulator matrix
    using ElementC = ElementC_;

    /// Layout of accumulator matrix
    using LayoutC = LayoutC_;

    /// Policy describing tuning details
    using Policy = Policy_;

    /// Iterates over tiles of A operand in shared memory
    using SmemIteratorA = SmemIteratorA_;

    /// Iterates over tiles of B operand in shared memory
    using SmemIteratorB = SmemIteratorB_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Operator::FragmentC;

    /// Architecture tag exported by this mainloop
    using ArchTag = arch::Sm70;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // Statically assert that the base class really was configured with a
    // single stage.
    static_assert((Base::kStages == 1),
                  "MmaSingleStage requires kStages set to value 1");

private:
    /// Warp-level fragment of operand A consumed by Operator
    using WarpFragmentA = typename Operator::FragmentA;

    /// Warp-level fragment of operand B consumed by Operator
    using WarpFragmentB = typename Operator::FragmentB;

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

public:
    /// Construct from shared storage and the calling thread's identifiers
    CUTLASS_DEVICE
    MmaSingleStage(
            /// Shared storage needed for internal use by threadblock-scoped
            /// GEMM
            typename Base::SharedStorage& shared_storage,
            /// ID within the threadblock
            int thread_idx,
            /// ID of warp
            int warp_idx,
            /// ID of each thread within a warp
            int lane_idx)
            : Base(shared_storage, thread_idx, warp_idx, lane_idx),
              smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
              smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
        // Split the linear warp index into three coordinates locating the
        // warp within the threadblock tile:
        //   m_coord: position along the M dimension
        //   n_coord: position along the N dimension
        //   k_coord: position along the K dimension
        int const warps_per_k_slice =
                Base::WarpCount::kM * Base::WarpCount::kN;

        int const k_coord = warp_idx / warps_per_k_slice;
        int const mn_coord = warp_idx % warps_per_k_slice;

        int const m_coord = mn_coord % Base::WarpCount::kM;
        int const n_coord = mn_coord / Base::WarpCount::kM;

        // Offset along K, in units of warp-level tiles, for this warp
        int const k_tile_offset = Base::kWarpGemmIterations * k_coord;

        // Move the warp-level iterators to this warp's tiles
        this->warp_tile_iterator_A_.add_tile_offset({m_coord, k_tile_offset});
        this->warp_tile_iterator_B_.add_tile_offset({k_tile_offset, n_coord});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
            /// number of iterations of the mainloop
            int gemm_k_iterations,
            /// destination accumulator tile
            FragmentC& accum,
            /// iterator over A operand in global memory
            IteratorA iterator_A,
            /// iterator over B operand in global memory
            IteratorB iterator_B,
            /// source accumulator tile
            FragmentC const& src_accum) {
        //
        // Prologue
        //

        // Accumulate on top of the source accumulators
        accum = src_accum;

        // Threadblock-level fragments holding the most recent global loads;
        // cleared first so that elements skipped by a masked load are zero.
        FragmentA global_frag_A;
        FragmentB global_frag_B;

        global_frag_A.clear();
        global_frag_B.clear();

        // Fetch the first pair of tiles and advance the global iterators
        iterator_A.load(global_frag_A);
        iterator_B.load(global_frag_B);

        ++iterator_A;
        ++iterator_B;

        // Warp-level operand fragments and the warp-level operator
        WarpFragmentA smem_frag_A;
        WarpFragmentB smem_frag_B;

        Operator mma_op;

        // With at most one iteration to run, nothing beyond the tile already
        // fetched is needed: clear the masks to avoid reading out of bounds.
        if (gemm_k_iterations <= 1) {
            iterator_A.clear_mask();
            iterator_B.clear_mask();
        }

        // Tile offset along K that returns the warp-level iterators to the
        // 'start' of the shared memory after one pass of the inner loop.
        int const rewind_tiles =
                -Policy::kPartitionsK * Base::kWarpGemmIterations;

        //
        // Mainloop
        //

        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations) {
            // Publish the tiles fetched from global memory to shared memory
            this->smem_iterator_A_.store(global_frag_A);
            this->smem_iterator_B_.store(global_frag_B);

            // Barrier between the stores above and the warp-level loads below
            __syncthreads();

            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
                 ++warp_mma_k) {
                // Select the k-group for this step, load the warp-level
                // operands, advance the iterators, then accumulate.
                int const k_group = warp_mma_k % Base::kWarpGemmIterations;

                this->warp_tile_iterator_A_.set_kgroup_index(k_group);
                this->warp_tile_iterator_B_.set_kgroup_index(k_group);

                this->warp_tile_iterator_A_.load(smem_frag_A);
                this->warp_tile_iterator_B_.load(smem_frag_B);

                ++this->warp_tile_iterator_A_;
                ++this->warp_tile_iterator_B_;

                mma_op(accum, smem_frag_A, smem_frag_B, accum);
            }

            // Add negative offsets to return the warp-level iterators to the
            // 'start' of the shared memory
            this->warp_tile_iterator_A_.add_tile_offset({0, rewind_tiles});
            this->warp_tile_iterator_B_.add_tile_offset({rewind_tiles, 0});

            // Barrier between the warp-level loads above and the shared
            // memory stores at the top of the next iteration
            __syncthreads();

            // Fetch the tiles for the next iteration and advance
            iterator_A.load(global_frag_A);
            iterator_B.load(global_frag_B);

            ++iterator_A;
            ++iterator_B;

            // Avoid reading out of bounds once no further tiles are needed
            if (gemm_k_iterations <= 2) {
                iterator_A.clear_mask();
                iterator_B.clear_mask();
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass
