/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

/// Common base for threadblock-scoped MMA implementations that operate on two
/// GEMMs, referred to throughout by the suffixes 0 and 1.
///
/// It provides:
///   - the per-GEMM shape, policy and warp-level operator types,
///   - compile-time warp counts and warp-level iteration counts,
///   - the shared-memory storage definition for the A/B operands of each GEMM,
///   - the warp-level iterators over A0, B0 and B1 held in shared memory.
///
/// No iterator over an A1 operand is kept by this class.
template <
        /// Threadblock tile of GEMM 0 - concept: gemm::GemmShape<>
        typename Shape0_,
        /// Threadblock tile of GEMM 1 - concept: gemm::GemmShape<>
        typename Shape1_,
        /// Policy describing tuning details of GEMM 0 (concept: MmaPolicy)
        typename Policy0_,
        /// Policy describing tuning details of GEMM 1 (concept: MmaPolicy)
        typename Policy1_,
        /// Number of stages
        int Stages,
        /// Used for partial specialization
        typename Enable = bool>
class B2bMmaBase {
public:
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape0 = Shape0_;
    using Shape1 = Shape1_;

    /// Policy describing tuning details
    using Policy0 = Policy0_;
    using Policy1 = Policy1_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator0 = typename Policy0::Operator;
    using Operator1 = typename Policy1::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm0 = typename Policy0::Operator::Shape;
    using WarpGemm1 = typename Policy1::Operator::Shape;

    /// Shape describing the number of warps filling the CTA: the threadblock
    /// tile extent divided (integer division) by the warp tile extent along
    /// each of M, N and K.
    using WarpCount0 =
            GemmShape<Shape0::kM / WarpGemm0::kM, Shape0::kN / WarpGemm0::kN,
                      Shape0::kK / WarpGemm0::kK>;
    using WarpCount1 =
            GemmShape<Shape1::kM / WarpGemm1::kM, Shape1::kN / WarpGemm1::kN,
                      Shape1::kK / WarpGemm1::kK>;

    /// Number of warp-level GEMM operations: the warp tile's K extent divided
    /// by the K extent of the operator's Policy::MmaShape.
    static int const kWarpGemmIterations0 =
            (WarpGemm0::kK / Operator0::Policy::MmaShape::kK);
    static int const kWarpGemmIterations1 =
            (WarpGemm1::kK / Operator1::Policy::MmaShape::kK);

    /// Number of stages
    static int const kStages = Stages;

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM.
    ///
    /// Holds one buffer for the A operand and one for the B operand of a
    /// single GEMM described by Shape_ / Policy_. Each buffer is sized for
    /// kStages tiles along the K dimension plus the policy's padding.
    template <typename Shape_, typename Policy_>
    class SharedStorage {
    public:
        //
        // Type definitions
        //
        using Shape = Shape_;
        using Policy = Policy_;
        using Operator = typename Policy::Operator;

        /// Tensor reference to the A operand
        using TensorRefA = TensorRef<typename Operator::ElementA,
                                     typename Operator::LayoutA>;

        /// Tensor reference to the B operand
        using TensorRefB = TensorRef<typename Operator::ElementB,
                                     typename Operator::LayoutB>;

        /// Shape of the A matrix operand in shared memory:
        /// (kM + row padding) x (kK * kStages + column padding)
        using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                                   Shape::kK * kStages +
                                           Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory:
        /// (kK * kStages + row padding) x (kN + column padding)
        using ShapeB =
                MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                            Shape::kN + Policy::SmemPaddingB::kColumn>;

    public:
        //
        // Data members
        //

        /// Buffer for A operand (ShapeA::kCount elements)
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand (ShapeB::kCount elements)
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

    public:
        //
        // Methods
        //

        /// Returns a packed layout object for the A matrix with extent
        /// ShapeA. Callable from host and device, matching LayoutB() and
        /// operand_A_ref(), which invokes it.
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutA LayoutA() {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a packed layout object for the B matrix with extent
        /// ShapeB.
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB() {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand buffer using LayoutA()
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref() {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand buffer using LayoutB()
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref() {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };

    /// Shared storage for GEMM 0 and GEMM 1 respectively
    using SharedStorage0 = SharedStorage<Shape0, Policy0>;
    using SharedStorage1 = SharedStorage<Shape1, Policy1>;

    /// Storage handed to the constructor. This is a union, so the storage of
    /// GEMM 0 and the storage of GEMM 1 occupy the same memory.
    union B2bMmaSharedStorage {
        SharedStorage0 sharedStorage0;
        SharedStorage1 sharedStorage1;
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A0 operand from shared memory
    typename Operator0::IteratorA warp_tile_iterator_A0_;

    /// Iterator to load a warp-scoped tile of B0 operand from shared memory
    typename Operator0::IteratorB warp_tile_iterator_B0_;

    /// Iterator to load a warp-scoped tile of B1 operand from shared memory
    typename Operator1::IteratorB warp_tile_iterator_B1_;

public:
    /// Construct from shared storage.
    ///
    /// Binds the A0 and B0 iterators to the operand buffers of
    /// `sharedStorage0` and the B1 iterator to the B operand buffer of
    /// `sharedStorage1`, each positioned by `lane_idx`.
    CUTLASS_DEVICE
    B2bMmaBase(
            /// Shared storage needed for internal use by threadblock-scoped
            /// GEMM
            B2bMmaSharedStorage& shared_storage,
            /// ID within the threadblock (not used by this base constructor)
            int thread_idx,
            /// ID of warp (not used by this base constructor)
            int warp_idx,
            /// ID of each thread within a warp
            int lane_idx)
            : warp_tile_iterator_A0_(
                      shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
              warp_tile_iterator_B0_(
                      shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
              warp_tile_iterator_B1_(
                      shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
