/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators used by warp-level matrix multiply operations
   targeting Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"

#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

/// Primary template of the warp-level tile iterator over one multiplicand
/// operand. It is only declared here; the layout-specific partial
/// specializations supply the definitions.
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Operand identity
        Operand Operand,
        /// Data type of the operand's elements
        typename Element_,
        /// Layout of operand
        typename Layout_,
        /// Shape of one matrix product operation (concept: GemmShape)
        typename InstructionShape_,
        /// Delta between *MMA operations (in units of *MMA operations),
        /// passed as a plain integer
        int OpDelta_,
        /// Number of threads participating in one matrix operation
        int Threads,
        /// Number of partitions along K dimension (defaults to 1)
        int PartitionsK_ = 1>
class MmaTensorOpMultiplicandTileIterator;

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to
/// load from shared memory and therefore must be initialized with a TensorRef
/// to shared memory.
///
/// The iterator keeps a small set of per-lane base pointers that are computed
/// once at construction and never advanced; all movement is accumulated in a
/// single byte offset that is added to every base pointer at load time.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: PitchLinearShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: PitchLinearShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::TensorOpMultiplicandCongruous<
                sizeof_bits<Element_>::value, 64>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::TensorOpMultiplicandCongruous<
            sizeof_bits<Element_>::value, 64>;

    /// Shape of one matrix product operation (concept: GemmShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Internal structure of iterator - made public to enable introspection
    struct Policy {
        static_assert(
                !(Shape::kContiguous % InstructionShape::kContiguous),
                "Shape of warp-level Mma must be divisible by operator shape.");

        // Determine number of elements along outer dimension per individual
        // LDSM op
        static int const kLdsmOpOuter = Layout::kElementsPerAccess;
        static int const kLdsmOpInner = 8;

        static_assert(!(Shape::kContiguous % kLdsmOpOuter),
                      "Shape of warp-level mma must be divisible by LDSM's "
                      "fundamental tile size.");

        static_assert(!(Shape::kStrided % kLdsmOpInner),
                      "Shape of warp-level mma must be divisible by LDSM's "
                      "fundamental tile size.");

        /// Shape of one individual LDSM instruction. The strided extent is
        /// derived from the instruction shape; the contiguous extent is chosen
        /// so that the two multiply to 4.
        static int const LdsmShapeStrided =
                InstructionShape::kStrided / kLdsmOpInner;
        static int const LdsmShapeContiguous = 4 / LdsmShapeStrided;
        using LdsmShape =
                layout::PitchLinearShape<LdsmShapeContiguous, LdsmShapeStrided>;

        /// Number and arrangement of LDSM instructions
        using LdsmIterations =
                layout::PitchLinearShape<Shape::kContiguous /
                                                 Layout::kElementsPerAccess /
                                                 LdsmShapeContiguous,
                                         1>;

        /// Number of groups for each tile
        static int const kGroupsPerTile =
                Shape::kStrided / InstructionShape::kStrided;
    };

private:
    /// Not working on this feature at the moment.
    static_assert(kOpDelta == 1,
                  "Alternative arrangements not supported at present.");

    /// Number of internal pointers needed to reference shared memory
    static int const kPointerCount =
            Layout::TileShape::kContiguous / Policy::LdsmShape::kContiguous;

    /// Pointer type used for accesses
    using AccessType = Array<Element, Layout::kElementsPerAccess>;

    /// Internal counter used to jump to next K partition
    int k_group_idx_;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment =
            Array<Element,
                  Shape::kContiguous * InstructionShape::kStrided / kThreads>;

private:
    /// Layout object storing stride values (in units of AccessType)
    Index stride_;

    /// Shared memory base pointers - not advanced
    AccessType const* pointer_[kPointerCount];

    /// Byte offset incremented as iterator advances
    Index byte_offset_;

public:
    /// Default ctor constructs null iterator.
    ///
    /// All scalar state, including the k-group counter, is zeroed so that
    /// operator++ never reads an indeterminate value. The base pointers are
    /// left unset.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator()
            : k_group_idx_(0), stride_(0), byte_offset_(0) {}

    /// Constructor from TensorRef.
    ///
    /// Computes one base pointer per entry of pointer_ from the lane index;
    /// which mapping is used depends on Policy::LdsmShape::kContiguous and,
    /// for the value 2, on the operand.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            // Initializers are listed in member declaration order.
            : k_group_idx_(0),
              stride_(ref.stride(0) / Layout::kElementsPerAccess),
              byte_offset_(0) {
        int quad_pair = (lane_id >> 3);
        int quad_quad = (lane_id >> 4);
        int lane_in_quad = (lane_id & 3);
        int lane_in_quad_pair = (lane_id & 7);
        int lane_in_quad_quad = (lane_id & 15);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kPointerCount; ++i) {
            // These stay at -1 if none of the branches below applies.
            int partition_contiguous_idx = -1;
            int access_contiguous_idx = -1;
            int access_strided_idx = -1;

            if (Policy::LdsmShape::kContiguous == 4) {
                // Matrix multiply 1688 A/B
                // Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block).
                // Four blocks are next to each other in the contiguous
                // dimension.
                partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i);
                access_contiguous_idx = (quad_pair ^ lane_in_quad);
                access_strided_idx = lane_in_quad_pair;
            } else if (Policy::LdsmShape::kContiguous == 2 &&
                       kOperand == Operand::kA) {
                // Matrix multiply 16816 A
                // Q0 Q2
                // Q1 Q3
                partition_contiguous_idx =
                        ((lane_in_quad_pair >> 2) ^ (i >> 1));
                access_contiguous_idx =
                        (((quad_pair & 1) + ((i & 1) << 1)) ^ lane_in_quad);
                access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3);
            } else if (Policy::LdsmShape::kContiguous == 2 &&
                       kOperand == Operand::kB) {
                // Matrix multiply 16816 B
                // Q0 Q1
                // Q2 Q3
                partition_contiguous_idx =
                        ((lane_in_quad_pair >> 2) ^ (i >> 1));
                access_contiguous_idx =
                        ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad);
                access_strided_idx = lane_in_quad_quad;
            } else if (Policy::LdsmShape::kContiguous == 1) {
                // Matrix multiply 16832.SP B
                // Q0
                // Q1
                // Q2
                // Q3
                partition_contiguous_idx =
                        ((lane_in_quad_pair >> 2) ^ (i >> 2));
                access_contiguous_idx = ((i & 3) ^ lane_in_quad);
                access_strided_idx = lane_id;
            }

            int access_contiguous =
                    partition_contiguous_idx *
                            Layout::PartitionShape::kContiguous +
                    access_contiguous_idx;

            int access_strided = access_strided_idx;

            // Offsets are applied in units of AccessType.
            pointer_[i] = reinterpret_cast<AccessType const*>(ref.data()) +
                          access_contiguous + access_strided * stride_;
        }
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory.
    /// The offset is given in elements and accumulated into byte_offset_.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        byte_offset_ += offset * sizeof(Element);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        int contiguous_offset = tile_offset.contiguous();
        if (Shape::kContiguous ==
            Layout::PartitionShape::kContiguous * Layout::kElementsPerAccess) {
            // An odd contiguous step exchanges the lower and upper halves of
            // the pointer array; only the even part of the step is turned
            // into an address offset below.
            if (tile_offset.contiguous() % 2) {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kPointerCount / 2; ++i) {
                    AccessType const* tmp_pointer = pointer_[i];
                    pointer_[i] = pointer_[i + kPointerCount / 2];
                    pointer_[i + kPointerCount / 2] = tmp_pointer;
                }
            }
            contiguous_offset = (tile_offset.contiguous() >> 1) << 1;
        }

        // Offset in elements: stride_ is kept in AccessType units, hence the
        // multiplication by kElementsPerAccess.
        int offset = (tile_offset.strided() * InstructionShape::kStrided) *
                             stride_ * Layout::kElementsPerAccess +
                     contiguous_offset * Shape::kContiguous;

        add_pointer_offset(offset);

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        add_tile_offset({0, 1});

        if (kPartitionsK > 1) {
            ++k_group_idx_;
            // Jump to next stage
            if (k_group_idx_ == Policy::kGroupsPerTile) {
                k_group_idx_ = 0;
                add_tile_offset(
                        {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)});
            }
        }

        return *this;
    }

    /// Advances the iterator along the opposite of the advance dimension.
    /// Only byte_offset_ is adjusted; the k-group counter is not touched.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) *
                        Layout::kElementsPerAccess;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_byte_offset(frag, 0); }

    /// Loads a fragment from memory with an additional offset in bytes.
    ///
    /// The fragment is viewed as a sequence of register groups, one group per
    /// LDSM call; each call reads from one of the base pointers plus the
    /// accumulated and the supplied byte offsets.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in units of bytes
            Index byte_offset) const {
        Array<unsigned, Policy::LdsmShape::kCount>* fetch_ptr =
                reinterpret_cast<Array<unsigned, Policy::LdsmShape::kCount>*>(
                        &frag);

        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) {
            CUTLASS_PRAGMA_UNROLL
            for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) {
                int access_idx = c + s * Policy::LdsmIterations::kContiguous;

                // Base pointers are reused cyclically; every full cycle moves
                // on by one layout tile along the contiguous dimension.
                AccessType const* source_ptr =
                        pointer_[c % kPointerCount] +
                        Layout::TileShape::kContiguous * (c / kPointerCount) +
                        Policy::LdsmShape::kStrided * s * stride_;

                char const* source_byte_ptr =
                        reinterpret_cast<char const*>(source_ptr) +
                        byte_offset + byte_offset_;

                cutlass::arch::ldsm<layout::ColumnMajor,
                                    Policy::LdsmShape::kCount>(
                        fetch_ptr[access_idx], source_byte_ptr);
            }
        }
    }

    /// Loads a fragment from memory with an additional offset in elements
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles. The tile offset is converted to AccessType units, scaled to
    /// bytes and added to the supplied byte offset; the iterator itself is
    /// not modified.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        Index pointer_offset =
                tile_offset.contiguous() * Shape::kContiguous /
                        Layout::kElementsPerAccess +
                tile_offset.strided() * InstructionShape::kStrided * stride_;

        byte_offset += sizeof(AccessType) * pointer_offset;

        load_with_byte_offset(frag, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts. In this
    /// specialization the call does nothing.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) {
        // no op
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread MMA.TF32 NT TensorOps. It
/// uses LDS.32 to load from shared memory and therefore must be initialized
/// with a TensorRef to shared memory.
///
/// The iterator keeps a set of per-lane base pointers that are computed once
/// at construction and never advanced; all movement is accumulated in a
/// single byte offset that is added to every base pointer at load time.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: PitchLinearShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: PitchLinearShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::TensorOpMultiplicandCongruous<32, 32>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(
            kOperand == Operand::kA || kOperand == Operand::kB,
            "MmaTensorOpMultiplicandIterator may only be instantiated for "
            "A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::TensorOpMultiplicandCongruous<32, 32>;

    /// Shape of one matrix product operation (concept: GemmShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Internal structure of iterator - made public to enable introspection
    struct Policy {
        static_assert(
                !(Shape::kContiguous % InstructionShape::kContiguous),
                "Shape of warp-level Mma must be divisible by operator shape.");

        // Determine number of elements along outer dimension per individual
        // 32bit shared memory load op.  Every one warp of 32bit shared memory
        // load loads 8x4 elements
        static int const kLdsOpInner = Layout::TileShape::kStrided;
        static int const kLdsOpOuter = kThreads / kLdsOpInner;

        static_assert(!(Shape::kContiguous % kLdsOpOuter),
                      "Shape of warp-level mma must be divisible by 32bit "
                      "fundamental tile size.");

        static_assert(!(Shape::kStrided % kLdsOpInner),
                      "Shape of warp-level mma must be divisible by 32bit "
                      "fundamental tile size.");

        /// Number of 32 bit shared memory load instructions needed by one MMA
        /// instruction 1688  A 2x2 1688  B 1x2 16816 B 1x4
        static int const LdsShapeContiguous =
                InstructionShape::kContiguous / kLdsOpOuter;
        static int const LdsShapeStrided =
                InstructionShape::kStrided / kLdsOpInner;
        using LdsShape =
                layout::PitchLinearShape<LdsShapeContiguous, LdsShapeStrided>;

        /// Number and arrangement of LDS instructions
        using LdsIterations = layout::PitchLinearShape<
                Shape::kContiguous / LdsShapeContiguous / kLdsOpOuter, 1>;

        /// Number of groups for each tile
        static int const kGroupsPerTile =
                Shape::kStrided / InstructionShape::kStrided;
    };

private:
    /// Not working on this feature at the moment.
    static_assert(kOpDelta == 1,
                  "Alternative arrangements not supported at present.");

    /// Number of internal pointers needed to reference shared memory
    static int const kPointerCount = Layout::TileShape::kContiguous *
                                     Layout::kElementsPerAccess /
                                     Policy::kLdsOpOuter;

    /// Vectorized access is not used
    static int const kElementsPerAccess = 1;

    /// Pointer type used for accesses (a single element)
    using AccessType = Element;

    /// Internal counter used to jump to next K partition
    int k_group_idx_;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment =
            Array<Element,
                  Shape::kContiguous * InstructionShape::kStrided / kThreads>;

private:
    /// Layout object storing stride values (in units of elements)
    Index stride_;

    /// Shared memory base pointers - not advanced
    AccessType const* pointer_[kPointerCount];

    /// Byte offset incremented as iterator advances
    Index byte_offset_;

public:
    /// Default ctor constructs null iterator.
    ///
    /// All scalar state, including the k-group counter, is zeroed so that
    /// operator++ never reads an indeterminate value. The base pointers are
    /// left unset.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator()
            : k_group_idx_(0), stride_(0), byte_offset_(0) {}

    /// Constructor from TensorRef.
    ///
    /// Computes one base pointer per entry of pointer_ from the lane index.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            // Initializers are listed in member declaration order.
            : k_group_idx_(0), stride_(ref.stride(0)), byte_offset_(0) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kPointerCount; ++i) {
            int access_strided = lane_id % Policy::kLdsOpInner;
            // The strided index is XOR-ed with the pointer index to pick the
            // contiguous position for this pointer.
            int access_contiguous = (lane_id / Policy::kLdsOpInner) +
                                    (access_strided ^ i) * Policy::kLdsOpOuter;

            pointer_[i] = reinterpret_cast<AccessType const*>(ref.data()) +
                          access_contiguous + access_strided * stride_;
        }
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory.
    /// The offset is given in elements and accumulated into byte_offset_.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        byte_offset_ += offset * sizeof(Element);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        int contiguous_offset = tile_offset.contiguous();
        if (Shape::kContiguous ==
            Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) {
            if (tile_offset.contiguous() % 2) {
                // An odd contiguous step exchanges the lower and upper halves
                // of the pointer array. For matrix multiply 1688:
                //   pointer_[0] <=> pointer_[4]   pointer_[1] <=> pointer_[5]
                //   pointer_[2] <=> pointer_[6]   pointer_[3] <=> pointer_[7]
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kPointerCount / 2; ++i) {
                    AccessType const* tmp_pointer = pointer_[i];
                    pointer_[i] = pointer_[i + kPointerCount / 2];
                    pointer_[i + kPointerCount / 2] = tmp_pointer;
                }
            }
            // Only the even part of the step becomes an address offset.
            contiguous_offset = (tile_offset.contiguous() >> 1) << 1;
        }

        // Offset in elements.
        int offset =
                (tile_offset.strided() * InstructionShape::kStrided) * stride_ +
                contiguous_offset * Shape::kContiguous;

        add_pointer_offset(offset);

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        add_tile_offset({0, 1});

        if (kPartitionsK > 1) {
            ++k_group_idx_;
            // Jump to next stage
            if (k_group_idx_ == Policy::kGroupsPerTile) {
                k_group_idx_ = 0;
                add_tile_offset(
                        {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)});
            }
        }

        return *this;
    }

    /// Advances the iterator along the opposite of the advance dimension.
    /// Only byte_offset_ is adjusted; the k-group counter is not touched.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) *
                        kElementsPerAccess;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_byte_offset(frag, 0); }

    /// Loads a fragment from memory with an additional offset in bytes.
    ///
    /// One element is read per innermost iteration and written to the
    /// fragment at a linear index built from the four loop counters.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in units of bytes
            Index byte_offset) const {
        Element* fetch_ptr = reinterpret_cast<Element*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) {
            CUTLASS_PRAGMA_UNROLL
            for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) {
                CUTLASS_PRAGMA_UNROLL
                for (int ss = 0; ss < Policy::LdsShape::kStrided; ++ss) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int cc = 0; cc < Policy::LdsShape::kContiguous; ++cc) {
                        // Destination slot inside the fragment.
                        int access_idx =
                                cc +
                                (ss +
                                 (c + s * Policy::LdsIterations::kContiguous) *
                                         Policy::LdsShape::kStrided) *
                                        Policy::LdsShape::kContiguous;
                        int access_idx_contiguous =
                                cc + c * Policy::LdsShape::kContiguous;
                        int access_idx_strided =
                                (ss + s * Policy::LdsShape::kStrided) *
                                Policy::kLdsOpInner;

                        // Base pointers are reused cyclically; every full
                        // cycle moves on by one layout tile along the
                        // contiguous dimension.
                        AccessType const* source_ptr =
                                pointer_[access_idx_contiguous %
                                         kPointerCount] +
                                Layout::TileShape::kContiguous *
                                        Layout::kElementsPerAccess *
                                        (access_idx_contiguous /
                                         kPointerCount) +
                                access_idx_strided * stride_;

                        char const* source_byte_ptr =
                                reinterpret_cast<char const*>(source_ptr) +
                                byte_offset + byte_offset_;

                        fetch_ptr[access_idx] =
                                *reinterpret_cast<Element const*>(
                                        source_byte_ptr);
                    }
                }
            }
        }
    }

    /// Loads a fragment from memory with an additional offset in elements
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles. The iterator itself is not modified.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // NOTE(review): AccessType is a single Element in this
        // specialization, yet the contiguous term is divided by
        // Layout::kElementsPerAccess rather than the class-local
        // kElementsPerAccess that operator-- uses - confirm this is intended.
        Index pointer_offset =
                tile_offset.contiguous() * Shape::kContiguous /
                        Layout::kElementsPerAccess +
                tile_offset.strided() * InstructionShape::kStrided * stride_;

        byte_offset += sizeof(AccessType) * pointer_offset;

        load_with_byte_offset(frag, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts. In this
    /// specialization the call does nothing.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) {
        // no op
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to
/// load from shared memory and therefore must be initialized with a TensorRef
/// to shared memory.
///
/// This specialization covers the A operand in the column-major congruous
/// layout. It is a thin adapter: matrix coordinates are translated to
/// pitch-linear ones (row -> contiguous, column -> strided) and every
/// operation is forwarded to the underlying pitch-linear iterator `Base`.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<
                sizeof_bits<Element_>::value, int(128 / sizeof(Element_))>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(
            kOperand == Operand::kA,
            "MmaTensorOpMultiplicandIterator for ColumnMajor Congruous may "
            "only be instantiated for A operand to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<
            sizeof_bits<Element_>::value, int(128 / sizeof(Element_))>;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation. Rows map to the contiguous
    /// dimension and columns to the strided dimension.
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, kOperand,
            Element,
            layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
                                                  int(128 / sizeof(Element_))>,
            layout::PitchLinearShape<InstructionShape::kRow,
                                     InstructionShape::kColumn>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // (row, column) -> (contiguous, strided)
        iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Moves the iterator backwards along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // add_tile_offset already performs the matrix -> pitch-linear
        // translation, so the offset is handed over unchanged.
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Moves back in units of whole tiles along the logical coordinate space
    /// of the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Negate component-wise while staying in matrix coordinates.
        add_tile_offset(
                TensorCoord(-tile_offset.row(), -tile_offset.column()));
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional byte offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in bytes
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles plus a linear offset in units of elements.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // Scale the element count to bytes via the element bit width.
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof_bits<Element>::value / 8);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles plus an offset in bytes.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a byte offset
            Index byte_offset) const {
        // Same (row, column) -> (contiguous, strided) mapping as
        // add_tile_offset.
        iterator_.load_with_byte_offset(
                frag, {tile_offset.row(), tile_offset.column()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to
/// load from shared memory and therefore must be initialized with a TensorRef
/// to shared memory.
///
/// This specialization covers the B operand in the row-major congruous
/// layout. It is a thin adapter: matrix coordinates are translated to
/// pitch-linear ones (column -> contiguous, row -> strided) and every
/// operation is forwarded to the underlying pitch-linear iterator `Base`.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
                sizeof_bits<Element_>::value, int(128 / sizeof(Element_))>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator for RowMajor Congruous may "
                  "only be instantiated for B operand to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
            sizeof_bits<Element_>::value, int(128 / sizeof(Element_))>;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation. Columns map to the contiguous
    /// dimension and rows to the strided dimension.
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, kOperand,
            Element,
            layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
                                                  int(128 / sizeof(Element_))>,
            layout::PitchLinearShape<InstructionShape::kColumn,
                                     InstructionShape::kRow>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // (row, column) -> (contiguous = column, strided = row)
        iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Moves the iterator backwards along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // add_tile_offset already swaps (row, column) into pitch-linear
        // order; swapping here as well would undo that translation.
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Moves back in units of whole tiles along the logical coordinate space
    /// of the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Negate component-wise while staying in matrix coordinates; the
        // swap into pitch-linear order happens once, in add_tile_offset.
        add_tile_offset(
                TensorCoord(-tile_offset.row(), -tile_offset.column()));
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional byte offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in bytes
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles plus a linear offset in units of elements.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // Scale the element count to bytes via the element bit width.
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof_bits<Element>::value / 8);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles plus an offset in bytes.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a byte offset
            Index byte_offset) const {
        // Same (column, row) -> (contiguous, strided) mapping as
        // add_tile_offset.
        iterator_.load_with_byte_offset(
                frag, {tile_offset.column(), tile_offset.row()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to
/// load from shared memory and therefore must be initialized with a TensorRef
/// to shared memory.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: PitchLinearShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: PitchLinearShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Element number when the layout crosses (in units of elements)
        int Crosswise,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::TensorOpMultiplicandCrosswise<
                sizeof_bits<Element_>::value, Crosswise>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(
            kOperand == Operand::kA || kOperand == Operand::kB,
            "MmaTensorOpMultiplicandIterator may only be instantiated for "
            "A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Element number when the layout crosses
    static int const kCrosswise = Crosswise;

    /// Layout of source tile
    using Layout = cutlass::layout::TensorOpMultiplicandCrosswise<
            sizeof_bits<Element_>::value, kCrosswise>;

    /// Shape of one matrix product operation (concept: GemmShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Compile-time description of how the warp tile is covered by LDSM
    /// operations - kept public so it can be inspected from outside.
    struct Policy {
        /// Elements along the outer dimension covered by one LDSM op
        static int const kLdsmOpOuter = Layout::kElementsPerAccess;
        /// Extent of one LDSM op along the inner dimension
        static int const kLdsmOpInner = 8;

        static_assert(
                (Shape::kContiguous % InstructionShape::kContiguous) == 0,
                "Shape of warp-level Mma must be divisible by operator shape.");

        static_assert((Shape::kContiguous % kLdsmOpOuter) == 0,
                      "Shape of warp-level mma must be divisible by LDSM's "
                      "fundamental tile size.");

        static_assert((Shape::kStrided % kLdsmOpInner) == 0,
                      "Shape of warp-level mma must be divisible by LDSM's "
                      "fundamental tile size.");

        /// Contiguous extent of one LDSM instruction, counted in LDSM ops
        static int const LdsmShapeContiguous =
                InstructionShape::kContiguous / kLdsmOpOuter;

        /// Strided extent of one LDSM instruction, counted in LDSM ops:
        /// 4 / LdsmShapeContiguous when the tile is tall enough for that,
        /// otherwise only as many ops as fit into Shape::kStrided.
        static int const LdsmShapeStrided =
                ((4 / LdsmShapeContiguous * kLdsmOpInner) <= Shape::kStrided)
                        ? (4 / LdsmShapeContiguous)
                        : (Shape::kStrided / kLdsmOpInner);

        /// Shape of one individual LDSM instruction
        using LdsmShape =
                layout::PitchLinearShape<LdsmShapeContiguous, LdsmShapeStrided>;

        /// Number and arrangement of LDSM instructions
        using LdsmIterations =
                layout::PitchLinearShape<1, Shape::kStrided / kLdsmOpInner /
                                                    LdsmShape::kStrided>;

        /// Layout tile width along the contiguous dimension, divided by the
        /// layout factor and by the contiguous extent of one LDSM
        /// instruction
        static int const kGroupsPerTile = Layout::TileShape::kContiguous /
                                          Layout::kFactor /
                                          LdsmShape::kContiguous;
    };

private:
    /// Not working on this feature at the moment.
    static_assert(kOpDelta == 1,
                  "Alternative arrangements not supported at present.");

    /// Pointer type used for accesses
    using AccessType = Array<Element, Layout::kElementsPerAccess>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment =
            Array<Element,
                  Shape::kStrided * InstructionShape::kContiguous / kThreads>;

private:
    /// Total number of sections.  The memory is divided into stages.  One stage
    /// can store one tile.  Stage is divided into sections.  Interleaved layout
    /// can have multiple sections in a stage.  The rest layout only has one
    /// section in a stage.
    int sections_;

    /// Layout object storing stride values
    Index stride_;

    /// Shared memory base pointers - not advanced
    AccessType const* pointer_;

    /// Byte offset incremented as iterator advances
    Index byte_offset_;

    /// Internal counter used to determine when to increment byte offset and
    /// when to XOR it
    int k_group_idx_;

public:
    /// Default ctor constructs a null iterator: null base pointer and every
    /// counter, stride and offset set to zero.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator()
            // Listed in member declaration order, which is the order the
            // initializers actually run in (avoids -Wreorder diagnostics).
            : sections_(0),
              stride_(0),
              pointer_(nullptr),
              byte_offset_(0),
              k_group_idx_(0) {}

    /// Constructor from TensorRef.
    ///
    /// Stores the base pointer and stride information taken from `ref` and
    /// derives this lane's starting byte offset. The offset depends on
    /// `lane_id`, on Layout::kFactor, on the LDSM shape selected by Policy
    /// and, for some shapes, on whether the operand is A or B.
    ///
    /// \param ref      tensor reference providing data pointer and stride
    /// \param lane_id  lane index of the calling thread
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            // Listed in member declaration order, which is the order the
            // initializers actually run in (avoids -Wreorder diagnostics).
            : sections_(ref.stride(0) / kCrosswise),
              // stride_ = kCrosswise x sections_ x kFactor, expressed in
              // AccessType units (hence the division by kElementsPerAccess)
              stride_(ref.stride(0) * Layout::kFactor /
                      Layout::kElementsPerAccess),
              pointer_(reinterpret_cast<AccessType const*>(ref.data())),
              byte_offset_(0),
              k_group_idx_(0) {
        // Warp level iterator at most use double buffer to hide latency.  If
        // there are more than 2 sections, every stage should have more than 1
        // section.

        // Turing silicon requires all 32 threads in a warp provide valid
        // addresses even for LDSM.1 and LDSM.2
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 750))
        lane_id = lane_id % (Policy::LdsmShape::kCount * Policy::kLdsmOpInner);
#endif

        // Decompositions of the lane index used by the cases below:
        // group-of-16 / group-of-8 index, and position within groups of
        // 2, 4, 8 and 16 lanes.
        int quad_quad = (lane_id >> 4);
        int quad_pair = (lane_id >> 3);
        int lane_in_pair = (lane_id & 1);
        int lane_in_quad = (lane_id & 3);
        int lane_in_quad_pair = (lane_id & 7);
        int lane_in_quad_quad = (lane_id & 15);

        // NOTE: if none of the cases below matches, these keep their -1
        // sentinel values and byte_offset_ is computed from them as-is.
        int partition_contiguous_idx = -1;
        int access_contiguous_idx = -1;
        int access_strided_idx = -1;

        if (Layout::kFactor == 4) {
            // Super Integer matrix multiply Interleaved-32

            int factor_in_partition =
                    (Layout::PartitionShape::kContiguous * Layout::kFactor /
                     Layout::TileShape::kContiguous);

            if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
                // Integer matrix multiply 8816  A/B
                partition_contiguous_idx = lane_in_quad / factor_in_partition;
                access_contiguous_idx = ((lane_in_pair * factor_in_partition) ^
                                         (lane_in_quad_quad / Layout::kFactor));
                access_strided_idx = lane_id / Layout::kFactor;
            } else if (Policy::LdsmShape::kStrided ==
                               (Policy::LdsmShape::kCount / 2) &&
                       kOperand == Operand::kA) {
                // Integer matrix multiply 16832 A
                partition_contiguous_idx = lane_in_quad / factor_in_partition;
                access_strided_idx = lane_in_quad_quad / Layout::kFactor;
                access_contiguous_idx =
                        ((lane_in_pair * factor_in_partition + quad_quad) ^
                         access_strided_idx);
            } else if (Policy::LdsmShape::kStrided ==
                               (Policy::LdsmShape::kCount / 2) &&
                       kOperand == Operand::kB) {
                // Integer matrix multiply 16832 B
                partition_contiguous_idx = lane_in_quad / factor_in_partition;
                access_strided_idx =
                        lane_in_quad_pair / Layout::kFactor + quad_quad * 2;
                access_contiguous_idx = ((lane_in_pair * factor_in_partition +
                                          ((lane_id & 8) >> 3)) ^
                                         access_strided_idx);
            }
        } else if (Layout::kFactor == 2) {
            // Super Matrix multiply kBlock = 32
            if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
                // Matrix multiply 1688 A/B
                // (Q stands for 1 8x128bit block).
                // Q0
                // Q1
                // Q2
                // Q3
                // Four blocks are next to each other in the strided dimension.
                partition_contiguous_idx = (lane_id % Layout::kFactor);
                access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor);
                access_strided_idx = lane_id / Layout::kFactor;
            } else if (Policy::LdsmShape::kStrided ==
                               (Policy::LdsmShape::kCount / 2) &&
                       kOperand == Operand::kA) {
                // Matrix multiply 16816|1688.TF32 A
                // Q0 Q2
                // Q1 Q3
                partition_contiguous_idx = (lane_id % Layout::kFactor);
                access_contiguous_idx =
                        (quad_quad ^ (lane_in_quad_pair / Layout::kFactor));
                access_strided_idx = (lane_in_quad_quad / Layout::kFactor);
            } else if (Policy::LdsmShape::kStrided ==
                               (Policy::LdsmShape::kCount / 2) &&
                       kOperand == Operand::kB) {
                // Matrix multiply 16816|1688.TF32 B
                // Q0 Q1
                // Q2 Q3
                partition_contiguous_idx = (lane_id % Layout::kFactor);
                access_contiguous_idx = ((quad_pair & 1) ^
                                         (lane_in_quad_pair / Layout::kFactor));
                access_strided_idx = (lane_in_quad_pair + (lane_id >> 4 << 3)) /
                                     Layout::kFactor;
            } else if (Policy::LdsmShape::kContiguous ==
                       Policy::LdsmShape::kCount) {
                // Matrix multiply 16832.SP B
                // Q0 Q1 Q2 Q3
                partition_contiguous_idx = (lane_id % Layout::kFactor);
                access_contiguous_idx =
                        (quad_pair ^ (lane_in_quad_pair / Layout::kFactor));
                access_strided_idx = lane_in_quad_pair / Layout::kFactor;
            }
        } else if (Layout::kFactor == 1) {
            // Super Matrix multiply kBlock = 64
            if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
                // Q0
                // Q1
                // Q2
                // Q3
                partition_contiguous_idx = (lane_in_quad_pair >> 2);
                access_contiguous_idx = lane_in_quad;
                access_strided_idx = lane_id;
            } else if (Policy::LdsmShape::kStrided ==
                               (Policy::LdsmShape::kCount / 2) &&
                       kOperand == Operand::kA) {
                // Matrix multiply 16816|1688.TF32 A
                // Q0 Q2
                // Q1 Q3
                partition_contiguous_idx = (lane_in_quad_pair >> 2);
                access_contiguous_idx = (quad_quad ^ lane_in_quad);
                access_strided_idx = lane_in_quad_quad;
            } else if (Policy::LdsmShape::kStrided ==
                               (Policy::LdsmShape::kCount / 2) &&
                       kOperand == Operand::kB) {
                // Matrix multiply 16816|1688.TF32 B
                // Q0 Q1
                // Q2 Q3
                partition_contiguous_idx = (lane_in_quad_pair >> 2);
                access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad);
                access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3);
            } else if (Policy::LdsmShape::kContiguous ==
                       Policy::LdsmShape::kCount) {
                // Matrix multiply 16832.SP B
                // Q0 Q1 Q2 Q3
                partition_contiguous_idx = (lane_in_quad_pair >> 2);
                access_contiguous_idx = (quad_pair ^ lane_in_quad);
                access_strided_idx = lane_in_quad_pair;
            }
        }

        // Position of this lane's first access along the contiguous
        // dimension: partition start plus position within the partition.
        int access_contiguous =
                partition_contiguous_idx * Layout::PartitionShape::kContiguous +
                access_contiguous_idx;

        int access_strided = access_strided_idx;

        // Convert the access index into bytes: one access holds
        // kElementsPerAccess elements of sizeof_bits<Element> bits each.
        byte_offset_ = (access_contiguous + access_strided * stride_) *
                       sizeof_bits<Element>::value *
                       Layout::kElementsPerAccess / 8;
    }

    /// Moves the iterator by `offset` elements. Only the running byte offset
    /// changes; the base pointer is left untouched.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        // Work in bits first so that the division by 8 happens last.
        LongIndex const offset_in_bits = offset * sizeof_bits<Element>::value;
        byte_offset_ += offset_in_bits / 8;

        return *this;
    }

    /// Moves the iterator by a logical offset given in units of whole tiles.
    ///
    /// The contiguous component is split into whole layout tiles, which move
    /// the base pointer, and a remainder of k-groups, which is XOR-ed into
    /// the byte offset. The strided component moves the base pointer.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        int tile_steps = tile_offset.contiguous() / Policy::kGroupsPerTile;
        int group_steps = tile_offset.contiguous() % Policy::kGroupsPerTile;

        // Byte pattern toggled by the k-group remainder.
        auto const group_bytes = group_steps * sizeof_bits<Element>::value *
                                 Layout::kElementsPerAccess *
                                 Policy::LdsmShape::kContiguous / 8;
        byte_offset_ ^= group_bytes;

        // Pointer displacement, counted in AccessType elements.
        auto const strided_accesses = tile_offset.strided() * stride_ *
                                      Shape::kStrided / Layout::kFactor;
        auto const tile_accesses = tile_steps * stride_ / sections_;
        pointer_ += strided_accesses + tile_accesses;

        return *this;
    }

    /// Moves the iterator by a logical offset given in units of whole tiles,
    /// normalizing a negative k-group remainder before applying it.
    ///
    /// Unlike add_tile_offset, this variant also keeps k_group_idx_ up to
    /// date and derives the XOR pattern bit by bit from the step and the
    /// current k-group index.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset_negative(
            TensorCoord const& tile_offset) {
        // k-groups of one tile that belong to a single K partition
        int const groups_per_partition = Policy::kGroupsPerTile / kPartitionsK;

        int tile_steps = tile_offset.contiguous() / Policy::kGroupsPerTile;
        int group_steps = tile_offset.contiguous() % Policy::kGroupsPerTile;

        // A negative remainder is turned into a non-negative one by borrowing
        // one whole tile.
        if (group_steps < 0) {
            --tile_steps;
            group_steps += Policy::kGroupsPerTile;
        }

        // Bit 0 of the k-group step
        if (groups_per_partition >= 2) {
            byte_offset_ ^= (group_steps & 1) *
                            Policy::LdsmShape::kContiguous *
                            sizeof_bits<Element>::value *
                            Layout::kElementsPerAccess / 8;
        }
        // Bit 1, including the carry out of the current index's low bit
        if (groups_per_partition >= 4) {
            byte_offset_ ^= ((group_steps + (k_group_idx_ & 1)) & 2) *
                            Policy::LdsmShape::kContiguous *
                            sizeof_bits<Element>::value *
                            Layout::kElementsPerAccess / 8;
        }
        // Bit 2, including the carry out of the current index's low two bits
        if (groups_per_partition == 8) {
            byte_offset_ ^= ((group_steps + (k_group_idx_ & 3)) & 4) *
                            Policy::LdsmShape::kContiguous *
                            sizeof_bits<Element>::value *
                            Layout::kElementsPerAccess / 8;
        }

        // Advance the k-group counter; overflow past the partition becomes
        // additional whole tiles.
        k_group_idx_ += group_steps;
        tile_steps += k_group_idx_ / groups_per_partition;
        k_group_idx_ %= groups_per_partition;

        // Pointer displacement, counted in AccessType elements.
        auto const strided_accesses = tile_offset.strided() * stride_ *
                                      Shape::kStrided / Layout::kFactor;
        auto const tile_accesses = tile_steps * stride_ / sections_;
        pointer_ += strided_accesses + tile_accesses;

        return *this;
    }

    /// Advances the iterator by one k-group along the advance dimension.
    ///
    /// byte_offset_ is updated by XOR with a multiple (1, 3 or 7) of a fixed
    /// byte scale; the multiple depends on the current k-group index and on
    /// how many k-groups one K partition holds. Once the last k-group of the
    /// partition has been passed, the index wraps to zero and the iterator
    /// moves on by Policy::kGroupsPerTile via add_tile_offset().
    ///
    /// XOR sequences noted for the supported configurations:
    ///
    ///   Integer matrix multiply 16832 Interleaved-32
    ///     NONE
    ///   Integer matrix multiply 16816 Interleaved-32 ||
    ///   Integer matrix multiply 16816 kblock=32
    ///
    ///   Integer matrix multiply 8816 Interleaved-32
    ///     ^1 ^1
    ///   Matrix multiply 1684.TF32 kblock=16 ||
    ///   Integer matrix multiply 16816 kblock=64
    ///   Matrix multiply 1688 kblock=32 ||
    ///   Integer matrix multiply 8816 kblock=64
    ///     ^1 ^3 ^1 ^3
    ///   Matrix multiply 1688 kblock=64
    ///     ^1 ^3 ^1 ^7 ^1 ^3 ^1 ^7
    ///
    ///   Matrix multiply 16816 kblock=32 | 1688.TF32 kblock=16 ||
    ///   Integer matrix multiply 16832 kblock=64
    ///     ^2 ^2
    ///   Matrix multiply 16816 kblock=64 | 1688.TF32 kblock=32 ||
    ///   Integer matrix multiply 16832 kblock=128
    ///     ^2 ^6 ^2 ^6
    ///
    /// \return reference to this iterator
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        int const groups_per_partition = Policy::kGroupsPerTile / kPartitionsK;

        if (groups_per_partition > 1) {
            // Only the low bits of the k-group index select the XOR multiple.
            int mask = 0;
            if (groups_per_partition == 8) {
                mask = 3;
            } else if (groups_per_partition == 4) {
                mask = 1;
            }

            int const phase = k_group_idx_ & mask;

            // Zero leaves byte_offset_ untouched.
            int xor_multiple = 0;
            if ((phase % 2) == 0) {
                xor_multiple = 1;
            } else if (phase == 1) {
                xor_multiple = 3;
            } else if (phase == 3) {
                xor_multiple = 7;
            }

            byte_offset_ ^= xor_multiple * Policy::LdsmShape::kContiguous *
                            sizeof_bits<Element>::value *
                            Layout::kElementsPerAccess / 8;
        }

        ++k_group_idx_;

        // Wrapped past the last k-group of this partition: restart the index
        // and step over a whole tile's worth of k-groups.
        if (k_group_idx_ == groups_per_partition) {
            k_group_idx_ = 0;
            add_tile_offset({Policy::kGroupsPerTile, 0});
        }

        return *this;
    }

    /// Reverse advance along the advance dimension; not supported.
    ///
    /// Always trips assert(0). When assertions are compiled out the call
    /// leaves the iterator unchanged and returns it, so control never flows
    /// off the end of this non-void function (which would be undefined
    /// behavior).
    ///
    /// \return reference to this iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        assert(0);
        return *this;
    }

    /// Advances the iterator by \p tile_offset, given in units of whole tiles
    /// along the logical coordinate space of the tensor.
    ///
    /// Forwards to add_tile_offset().
    ///
    /// \return reference to this iterator
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        this->add_tile_offset(tile_offset);

        return *this;
    }

    /// Moves the iterator backwards by \p tile_offset, given in units of
    /// whole tiles along the logical coordinate space of the tensor.
    ///
    /// Forwards the negated offset to add_tile_offset().
    ///
    /// \return reference to this iterator
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        this->add_tile_offset(-tile_offset);

        return *this;
    }

    /// Loads a fragment from the location the iterator currently points to,
    /// i.e. with no additional byte offset.
    ///
    /// \param frag fragment to load from the tensor
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const {
        load_with_byte_offset(frag, 0);
    }

    /// Loads a fragment, applying an additional linear offset in bytes.
    ///
    /// The fragment storage is viewed as an array of LDSM results. One ldsm
    /// is issued per (strided, contiguous) LDSM iteration, with the strided
    /// iteration as the outer loop.
    ///
    /// \param frag fragment to load from the tensor
    /// \param byte_offset linear offset in units of bytes, added on top of
    ///        the iterator's own byte_offset_
    CUTLASS_DEVICE
    void load_with_byte_offset(Fragment& frag, Index byte_offset) const {
        using LdsmResult = Array<unsigned, Policy::LdsmShape::kCount>;

        LdsmResult* const results = reinterpret_cast<LdsmResult*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int strided_iter = 0;
             strided_iter < Policy::LdsmIterations::kStrided; ++strided_iter) {
            CUTLASS_PRAGMA_UNROLL
            for (int contig_iter = 0;
                 contig_iter < Policy::LdsmIterations::kContiguous;
                 ++contig_iter) {
                // Address of this access in units of AccessType.
                AccessType const* const access_ptr =
                        pointer_ + Policy::LdsmShape::kContiguous * contig_iter +
                        Policy::kLdsmOpInner / Layout::kFactor *
                                Policy::LdsmShape::kStrided * strided_iter *
                                stride_;

                // Byte-granular adjustments are applied on top of it.
                char const* const byte_ptr =
                        reinterpret_cast<char const*>(access_ptr) +
                        byte_offset + byte_offset_;

                cutlass::arch::ldsm<layout::RowMajor,
                                    Policy::LdsmShape::kCount>(
                        results[contig_iter +
                                strided_iter *
                                        Policy::LdsmIterations::kContiguous],
                        byte_ptr);
            }
        }
    }

    /// Loads a fragment, applying an additional linear offset given in units
    /// of elements.
    ///
    /// The element offset is converted to bytes with sizeof_bits<Element>
    /// rather than sizeof(Element), matching the other byte computations in
    /// this iterator: for element types narrower than one byte,
    /// sizeof(Element) would overstate the element size. For byte-multiple
    /// element types the result is unchanged.
    ///
    /// \param frag fragment to load from the tensor
    /// \param pointer_offset linear offset, in elements
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        load_with_byte_offset(
                frag, pointer_offset * sizeof_bits<Element>::value / 8);
    }

    /// Loads a fragment at a logical offset given in units of whole tiles,
    /// with no additional byte offset.
    ///
    /// \param frag fragment to load from the tensor
    /// \param tile_offset logical offset in units of whole tiles
    CUTLASS_DEVICE
    void load(Fragment& frag, TensorCoord const& tile_offset) const {
        // Same as the three-argument byte-offset load with zero extra bytes.
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment at a logical offset given in units of whole tiles,
    /// plus an additional linear offset given in units of elements.
    ///
    /// The element offset is converted to bytes with sizeof_bits<Element>
    /// rather than sizeof(Element), matching the other byte computations in
    /// this iterator: for element types narrower than one byte,
    /// sizeof(Element) would overstate the element size. For byte-multiple
    /// element types the result is unchanged.
    ///
    /// \param frag fragment to load from the tensor
    /// \param tile_offset logical offset in units of whole tiles
    /// \param pointer_offset additional linear offset, in elements
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        load_with_byte_offset(
                frag, tile_offset,
                pointer_offset * sizeof_bits<Element>::value / 8);
    }

    /// Loads a fragment at a logical offset given in units of whole tiles,
    /// plus an additional linear offset in bytes.
    ///
    /// The tile offset is first converted to an offset in units of
    /// AccessType, then to bytes, and added to \p byte_offset before
    /// delegating to the two-argument load_with_byte_offset().
    ///
    /// \param frag fragment to load from the tensor
    /// \param tile_offset logical offset in units of whole tiles
    /// \param byte_offset additional linear offset, in bytes
    CUTLASS_DEVICE
    void load_with_byte_offset(Fragment& frag, TensorCoord const& tile_offset,
                               Index byte_offset) const {
        // Offset contributed by the tile coordinate, in AccessType units.
        Index const access_offset =
                tile_offset.strided() * Shape::kStrided * stride_ +
                tile_offset.contiguous() * InstructionShape::kContiguous /
                        Layout::kElementsPerAccess;

        load_with_byte_offset(
                frag, byte_offset +
                              sizeof_bits<AccessType>::value * access_offset /
                                      8);
    }

    /// Tells the iterator which k-group it currently points to.
    ///
    /// The iterator is not moved; only its internal k-group index is
    /// overwritten. Supplying a constant-valued index lets the compiler fold
    /// constants and generate more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    ///
    /// \param k_group k-group index; stored modulo the number of k-groups in
    ///        one K partition
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) {
        int const groups_per_partition = Policy::kGroupsPerTile / kPartitionsK;

        k_group_idx_ = k_group % groups_per_partition;
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to
/// load from shared memory and therefore must be initialized with a TensorRef
/// to shared memory.
///
/// This specialization handles the column-major crosswise layout. It is a
/// thin adapter: every operation is forwarded to an underlying iterator over
/// the equivalent pitch-linear layout, with the row index passed as the
/// contiguous coordinate and the column index as the strided coordinate.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Element number when the layout crosses (in units of elements)
        int Crosswise,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
                sizeof_bits<Element_>::value, Crosswise>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(
            kOperand == Operand::kB,
            "MmaTensorOpMultiplicandIterator for ColumnMajor Crosswise may "
            "only be instantiated for B operand to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// KBlock size
    static int const kCrosswise = Crosswise;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
            sizeof_bits<Element_>::value, kCrosswise>;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation: the pitch-linear crosswise
    /// iterator, with rows as the contiguous dimension and columns as the
    /// strided dimension.
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, kOperand,
            Element,
            layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
                                                  kCrosswise>,
            layout::PitchLinearShape<InstructionShape::kRow,
                                     InstructionShape::kColumn>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef; the data pointer and stride of \p ref are
    /// handed to the underlying iterator together with \p lane_id.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles. The offset is forwarded as (row, column).
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles, using the underlying iterator's add_tile_offset_negative.
    /// The offset is forwarded as (row, column).
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset_negative(
            TensorCoord const& tile_offset) {
        iterator_.add_tile_offset_negative(
                {tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Moves the iterator backwards along the advance dimension by forwarding
    /// to the underlying iterator's operator--.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(
                PitchLinearCoord(tile_offset.row(), tile_offset.column()));
        return *this;
    }

    /// Moves backwards in units of whole tiles along the logical coordinate
    /// space of the tensor.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(
                -PitchLinearCoord(tile_offset.row(), tile_offset.column()));
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles. Not implemented: always trips assert(0).
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        // TODO
        assert(0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles. Not implemented: always trips assert(0).
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // TODO
        assert(0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles plus a byte offset.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // TensorCoord is read through row()/column() everywhere else in this
        // class; use the same accessors and the same (row, column) mapping
        // as add_tile_offset() when building the pitch-linear coordinate.
        iterator_.load_with_byte_offset(
                frag, {tile_offset.row(), tile_offset.column()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to
/// load from shared memory and therefore must be initialized with a TensorRef
/// to shared memory.
///
/// This specialization handles the row-major crosswise layout. It is a thin
/// adapter: every operation is forwarded to an underlying iterator over the
/// equivalent pitch-linear layout, with the column index passed as the
/// contiguous coordinate and the row index as the strided coordinate.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Element number when the layout crosses (in units of elements)
        int Crosswise,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
                sizeof_bits<Element_>::value, Crosswise>,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA,
                  "MmaTensorOpMultiplicandIterator for RowMajor Crosswise may "
                  "only be instantiated for A operand to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Element number when the layout crosses
    static int const kCrosswise = Crosswise;

    /// Layout of source tile
    using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
            sizeof_bits<Element_>::value, kCrosswise>;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation: the pitch-linear crosswise
    /// iterator, with columns as the contiguous dimension and rows as the
    /// strided dimension.
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, kOperand,
            Element,
            layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
                                                  kCrosswise>,
            layout::PitchLinearShape<InstructionShape::kColumn,
                                     InstructionShape::kRow>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef; the data pointer and stride of \p ref are
    /// handed to the underlying iterator together with \p lane_id.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles. The offset is forwarded as (column, row).
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles, using the underlying iterator's add_tile_offset_negative.
    /// The offset is forwarded as (column, row).
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset_negative(
            TensorCoord const& tile_offset) {
        iterator_.add_tile_offset_negative(
                {tile_offset.column(), tile_offset.row()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Moves the iterator backwards along the advance dimension by forwarding
    /// to the underlying iterator's operator--.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor.
    ///
    /// NOTE(review): the coordinate built here already has row and column
    /// exchanged, and add_tile_offset() exchanges them again before
    /// forwarding. The underlying iterator therefore appears to receive
    /// (row, column), whereas a direct add_tile_offset(tile_offset) call
    /// forwards (column, row) - confirm which is intended.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(
                PitchLinearCoord(tile_offset.column(), tile_offset.row()));
        return *this;
    }

    /// Moves backwards in units of whole tiles along the logical coordinate
    /// space of the tensor.
    ///
    /// NOTE(review): the coordinate is transposed here and again inside
    /// add_tile_offset(), the same pattern as in operator+= - confirm.
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(
                -PitchLinearCoord(tile_offset.column(), tile_offset.row()));
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles. Not implemented: always trips assert(0).
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        // TODO
        assert(0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles. Not implemented: always trips assert(0).
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // TODO
        assert(0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles plus a byte offset.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // TensorCoord is read through row()/column() everywhere else in this
        // class; use the same accessors and the same (column, row) mapping
        // as add_tile_offset() when building the pitch-linear coordinate.
        iterator_.load_with_byte_offset(
                frag, {tile_offset.column(), tile_offset.row()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

/// Primary template of the warp-level accumulator tile iterator.
///
/// Only declared here, never defined: usable iterators are provided by
/// partial specializations of this template.
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Element type
        typename Element_,
        /// Layout of operand in memory
        typename Layout_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions, concept: MatrixShape)
        typename OpDelta_>
class MmaTensorOpAccumulatorTileIterator;

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It is used to
/// load or store accumulators from memory and is agnostic to layout. It could
/// be faster if it assumed row-major accumulator layout.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept |
///   WriteableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Element type
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions, concept: MatrixShape)
        typename OpDelta_>
class MmaTensorOpAccumulatorTileIterator<Shape_, Element_,
                                         cutlass::layout::RowMajor,
                                         InstructionShape_, OpDelta_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand::kC;

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::RowMajor;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    using OpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Compile-time description of how the tile decomposes into MMA
    /// operations. Made public to enable introspection.
    struct Policy {
        static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
                      "Layouts must be defined for logical MatrixCoord "
                      "coordinate space.");

        /// True when the tile shape is an exact multiple of the instruction
        /// shape in both dimensions.
        static bool const kDivisible =
                (Shape::kRow % InstructionShape::kM == 0) &&
                (Shape::kColumn % InstructionShape::kN == 0);

        /// Number of mma operations performed, rounded up in each dimension
        using MmaIterations = MatrixShape<
                (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM,
                (Shape::kColumn + InstructionShape::kN - 1) /
                        InstructionShape::kN>;
    };

private:
    // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over
    // the entire shape, with each quad mapped to one row and each thread mapped
    // to 1/4 of the elements of that row. The accumulators within one row are
    // assumed to be consecutive.
    static int const kElementsPerAccess = InstructionShape::kN / 4;
    static int const kRowsPerTile = 8;
    static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = Array<Element, Policy::MmaIterations::kCount *
                                            InstructionShape::kMN / kThreads>;

private:
    /// Reference to output tensor
    TensorRef ref_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator() {}

    /// Constructs the iterator from a TensorRef and positions it for the
    /// given lane.
    ///
    /// The lane id is split into a quad index (lane_id >> 2), used as the row
    /// offset, and a position within the quad (lane_id & 3), which selects a
    /// column offset in steps of kElementsPerAccess.
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator(TensorRef const& ref, int lane_id)
            : ref_(ref) {
        int const quad_row = lane_id >> 2;
        int const quad_col = (lane_id & 3) * kElementsPerAccess;

        ref_.add_coord_offset(MatrixCoord(quad_row, quad_col));
    }

    /// Shifts the underlying tensor reference by a linear pointer offset.
    ///
    /// \param offset offset forwarded to the TensorRef
    /// \return reference to this iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_pointer_offset(LongIndex offset) {
        ref_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances the iterator along the logical dimensions of the matrix in
    /// units of whole tiles.
    ///
    /// The tile offset is scaled by the tile shape (Shape::kRow,
    /// Shape::kColumn) and applied to the tensor reference as a coordinate
    /// offset.
    ///
    /// \return reference to this iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        ref_.add_coord_offset(
                tile_offset * make_Coord(Shape::kRow, Shape::kColumn));

        return *this;
    }

    /// Pre-increment; intentionally leaves the iterator unchanged.
    ///
    /// \return reference to this iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator++() {
        // Nothing to advance: deliberate no-op.
        return *this;
    }

    /// Pre-decrement; intentionally leaves the iterator unchanged.
    ///
    /// \return reference to this iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator--() {
        // Nothing to move back: deliberate no-op.
        return *this;
    }

    /// Advances the iterator by \p tile_offset, given in units of whole tiles
    /// along the logical coordinate space of the tensor.
    ///
    /// Forwards to add_tile_offset().
    ///
    /// \return reference to this iterator
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        this->add_tile_offset(tile_offset);

        return *this;
    }

    /// Moves the iterator backwards by \p tile_offset, given in units of
    /// whole tiles along the logical coordinate space of the tensor.
    ///
    /// Forwards the negated offset to add_tile_offset().
    ///
    /// \return reference to this iterator
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        this->add_tile_offset(-tile_offset);

        return *this;
    }

    /// Loads a fragment from the location the iterator currently points to,
    /// i.e. with a pointer offset of zero.
    ///
    /// \param frag fragment to load from the tensor
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const {
        load_with_pointer_offset(frag, 0);
    }

    /// Loads a fragment, applying an additional linear pointer offset.
    ///
    /// Iterates over every MMA operation of the tile (columns outermost) and,
    /// within each, over kAccumulatorRows rows spaced kRowsPerTile apart and
    /// kElementsPerAccess consecutive columns, reading one element per step
    /// from the offset tensor reference into consecutive fragment slots.
    ///
    /// \param frag fragment to load from the tensor
    /// \param pointer_offset linear offset applied to a copy of the tensor
    ///        reference before loading
    CUTLASS_DEVICE
    void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const {
        TensorRef shifted_ref(ref_);
        shifted_ref.add_pointer_offset(pointer_offset);

        CUTLASS_PRAGMA_UNROLL
        for (int n_iter = 0; n_iter < Policy::MmaIterations::kColumn;
             ++n_iter) {
            // First column touched by this MMA column.
            int const col_base =
                    n_iter * InstructionShape::kN * OpDelta::kColumn;

            CUTLASS_PRAGMA_UNROLL
            for (int m_iter = 0; m_iter < Policy::MmaIterations::kRow;
                 ++m_iter) {
                // First fragment slot belonging to this MMA operation.
                int const frag_base =
                        kAccumulatorRows * kElementsPerAccess *
                        (n_iter * Policy::MmaIterations::kRow + m_iter);
                // First row touched by this MMA row.
                int const row_base =
                        m_iter * InstructionShape::kM * OpDelta::kRow;

                CUTLASS_PRAGMA_UNROLL
                for (int r = 0; r < kAccumulatorRows; ++r) {
                    int const coord_row = row_base + r * kRowsPerTile;

                    CUTLASS_PRAGMA_UNROLL
                    for (int e = 0; e < kElementsPerAccess; ++e) {
                        frag[frag_base + r * kElementsPerAccess + e] =
                                shifted_ref.at({coord_row, col_base + e});
                    }
                }
            }
        }
    }

    /// Loads a fragment from memory with an additional offset given in bytes.
    ///
    /// The byte offset is divided by sizeof(Element) and the result is
    /// forwarded to load_with_pointer_offset() together with `frag`.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// linear offset in bytes
            Index byte_offset) const {
        // Divide in the signed Index domain so the quotient is not computed
        // through an unsigned size_t intermediate.
        Index const element_size = static_cast<Index>(sizeof(Element));

        // `frag` has to be forwarded as well: load_with_pointer_offset() has
        // no single-argument overload.
        load_with_pointer_offset(frag, byte_offset / element_size);
    }

    /// Loads a fragment from memory at a logical offset and without any
    /// additional pointer offset. Forwards to the three-argument load()
    /// overload with a pointer offset of zero.
    CUTLASS_DEVICE
    void load(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// logical offset handed to the three-argument overload
            TensorCoord const& tile_offset) const {
        Index const no_pointer_offset = 0;
        load(frag, tile_offset, no_pointer_offset);
    }

    /// Loads a fragment from memory at a logical offset combined with a
    /// pointer offset.
    ///
    /// The linear offset obtained from ref_.offset(tile_offset) is added to
    /// `pointer_offset` and the sum is passed to load_with_pointer_offset().
    /// NOTE(review): `tile_offset` is handed to ref_.offset() unscaled;
    /// confirm that callers do not expect it to be multiplied by the tile
    /// shape first.
    CUTLASS_DEVICE
    void load(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// logical offset converted to a linear offset by the reference
            TensorCoord const& tile_offset,
            /// extra linear offset added on top of the logical offset
            Index pointer_offset) const {
        Index const combined_offset = ref_.offset(tile_offset) + pointer_offset;
        load_with_pointer_offset(frag, combined_offset);
    }

    /// Stores a fragment to the location the iterator currently references,
    /// i.e. store_with_pointer_offset() with a pointer offset of zero.
    CUTLASS_HOST_DEVICE
    void store(Fragment const& frag) const {
        Index const no_offset = 0;
        store_with_pointer_offset(frag, no_offset);
    }

    /// Writes this thread's accumulator elements from `frag` to memory through
    /// a copy of the internal reference advanced by `pointer_offset`. The
    /// iterator itself is not modified.
    CUTLASS_DEVICE
    void store_with_pointer_offset(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// linear offset added to the reference before storing
            Index pointer_offset) const {
        TensorRef shifted_ref(ref_);
        shifted_ref.add_pointer_offset(pointer_offset);

        // Number of fragment elements contributed by a single MMA operation.
        int const accums_per_mma = kAccumulatorRows * kElementsPerAccess;

        CUTLASS_PRAGMA_UNROLL
        for (int tile_n = 0; tile_n < Policy::MmaIterations::kColumn;
             ++tile_n) {
            // Column coordinate of the first element of this MMA column.
            int const base_n =
                    tile_n * InstructionShape::kN * OpDelta::kColumn;

            CUTLASS_PRAGMA_UNROLL
            for (int tile_m = 0; tile_m < Policy::MmaIterations::kRow;
                 ++tile_m) {
                // Row coordinate of the first element of this MMA row.
                int const base_m =
                        tile_m * InstructionShape::kM * OpDelta::kRow;

                // Fragment index of the first element of this MMA operation;
                // MMA operations are ordered with the row index fastest.
                int const frag_base =
                        accums_per_mma *
                        (tile_n * Policy::MmaIterations::kRow + tile_m);

                CUTLASS_PRAGMA_UNROLL
                for (int r = 0; r < kAccumulatorRows; ++r) {
                    int const coord_m = base_m + r * kRowsPerTile;

                    CUTLASS_PRAGMA_UNROLL
                    for (int c = 0; c < kElementsPerAccess; ++c) {
                        int const frag_idx =
                                frag_base + r * kElementsPerAccess + c;

                        shifted_ref.at({coord_m, base_n + c}) =
                                frag[frag_idx];
                    }
                }
            }
        }
    }

    /// Stores a fragment to memory with an additional offset given in bytes.
    ///
    /// The byte offset is divided by sizeof(Element) and the result is
    /// forwarded to store_with_pointer_offset() together with `frag`.
    CUTLASS_DEVICE
    void store_with_byte_offset(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// linear offset in bytes
            Index byte_offset) const {
        // Divide in the signed Index domain so the quotient is not computed
        // through an unsigned size_t intermediate.
        Index const element_size = static_cast<Index>(sizeof(Element));

        // `frag` has to be forwarded as well: store_with_pointer_offset() has
        // no single-argument overload.
        store_with_pointer_offset(frag, byte_offset / element_size);
    }

    /// Stores a fragment to memory at a logical offset and without any
    /// additional pointer offset. Forwards to the three-argument store()
    /// overload with a pointer offset of zero.
    ///
    /// The fragment is only read, so it is accepted by const reference like
    /// in the other store overloads; non-const fragments still bind to it.
    CUTLASS_DEVICE
    void store(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// logical offset handed to the three-argument overload
            TensorCoord const& tile_offset) const {
        store(frag, tile_offset, 0);
    }

    /// Stores a fragment to memory at a logical offset combined with a
    /// pointer offset.
    ///
    /// The linear offset obtained from ref_.offset(tile_offset) is added to
    /// `pointer_offset` and the sum is passed to store_with_pointer_offset().
    /// NOTE(review): `tile_offset` is handed to ref_.offset() unscaled;
    /// confirm that callers do not expect it to be multiplied by the tile
    /// shape first.
    CUTLASS_DEVICE
    void store(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// logical offset converted to a linear offset by the reference
            TensorCoord const& tile_offset,
            /// extra linear offset added on top of the logical offset
            Index pointer_offset) const {
        Index const combined_offset = ref_.offset(tile_offset) + pointer_offset;
        store_with_pointer_offset(frag, combined_offset);
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps and for tensors
/// stored with cutlass::layout::ColumnMajor. It is used to load or store
/// accumulators from memory. Elements are accessed one at a time through
/// TensorRef::at().
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept |
///   WriteableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Element type
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions, concept: MatrixShape)
        typename OpDelta_>
class MmaTensorOpAccumulatorTileIterator<Shape_, Element_,
                                         cutlass::layout::ColumnMajor,
                                         InstructionShape_, OpDelta_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand::kC;

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajor;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    using OpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Internal structure of iterator - made public to enable introspection
    struct Policy {
        /// True when the tile shape is a whole multiple of the instruction
        /// shape in both dimensions.
        static bool const kDivisible = !(Shape::kRow % InstructionShape::kM) &&
                                       !(Shape::kColumn % InstructionShape::kN);

        static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
                      "Layouts must be defined for logical MatrixCoord "
                      "coordinate space.");

        /// Number of mma operations performed; each extent is rounded up so
        /// that a partially covered instruction tile still counts.
        using MmaIterations =
                MatrixShape<(Shape::kRow + InstructionShape::kM - 1) /
                                    InstructionShape::kM,
                            (Shape::kColumn + InstructionShape::kN - 1) /
                                    InstructionShape::kN>;
    };

private:
    // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over
    // the entire shape, with each quad mapped to one row and each thread mapped
    // to 1/4 of the elements of that row. The accumulators within one row are
    // assumed to be consecutive.
    static int const kElementsPerAccess = InstructionShape::kN / 4;
    static int const kRowsPerTile = 8;
    static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = Array<Element, Policy::MmaIterations::kCount *
                                            InstructionShape::kMN / kThreads>;

private:
    /// Reference to output tensor
    TensorRef ref_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator() {}

    /// Constructor from TensorRef.
    ///
    /// The reference is advanced to the calling lane's first element: the
    /// row is the lane's quad (lane_id / 4) and the column is the lane's
    /// position inside the quad (lane_id % 4) times kElementsPerAccess.
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator(TensorRef const& ref, int lane_id)
            : ref_(ref) {
        int quad = (lane_id >> 2);
        int lane_in_quad = (lane_id & 3);

        MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess);

        ref_.add_coord_offset(lane_offset);
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_pointer_offset(LongIndex offset) {
        ref_.add_pointer_offset(offset);
        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles; the tile offset is scaled by (Shape::kRow,
    /// Shape::kColumn) before it is applied to the reference.
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        ref_.add_coord_offset(tile_offset *
                              make_Coord(Shape::kRow, Shape::kColumn));

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator++() {
        // deliberate no-op
        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator--() {
        // deliberate no-op
        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Moves backwards in units of whole tiles along the logical coordinate
    /// space of the tensor
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); }

    /// Loads a fragment from memory with an additional linear offset. The
    /// iterator itself is not modified; a shifted copy of the reference is
    /// used for the accesses.
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// linear offset added to the reference before loading
            Index pointer_offset) const {
        TensorRef offset_ref(ref_);
        offset_ref.add_pointer_offset(pointer_offset);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                // First fragment index of this MMA operation; operations are
                // ordered with the row index fastest.
                int mma_accum_start =
                        kAccumulatorRows * kElementsPerAccess *
                        (mma_n * Policy::MmaIterations::kRow + mma_m);

                CUTLASS_PRAGMA_UNROLL
                for (int row = 0; row < kAccumulatorRows; ++row) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int col = 0; col < kElementsPerAccess; ++col) {
                        int accum_m =
                                mma_m * InstructionShape::kM * OpDelta::kRow +
                                row * kRowsPerTile;
                        int accum_n = mma_n * InstructionShape::kN *
                                              OpDelta::kColumn +
                                      col;
                        int idx = mma_accum_start + row * kElementsPerAccess +
                                  col;

                        frag[idx] = offset_ref.at({accum_m, accum_n});
                    }
                }
            }
        }
    }

    /// Loads a fragment from memory with an additional offset given in bytes.
    /// The byte offset is divided by sizeof(Element) and forwarded to
    /// load_with_pointer_offset() together with `frag`.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// linear offset in bytes
            Index byte_offset) const {
        // Divide in the signed Index domain so the quotient is not computed
        // through an unsigned size_t intermediate.
        Index const element_size = static_cast<Index>(sizeof(Element));

        // `frag` has to be forwarded as well: load_with_pointer_offset() has
        // no single-argument overload.
        load_with_pointer_offset(frag, byte_offset / element_size);
    }

    /// Loads a fragment from memory at a logical offset and without any
    /// additional pointer offset.
    CUTLASS_DEVICE
    void load(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// logical offset handed to the three-argument overload
            TensorCoord const& tile_offset) const {
        load(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory at a logical offset combined with a
    /// pointer offset. The linear offset obtained from
    /// ref_.offset(tile_offset) is added to `pointer_offset`.
    /// NOTE(review): unlike add_tile_offset(), `tile_offset` is not scaled by
    /// the tile shape here; confirm this is what callers expect.
    CUTLASS_DEVICE
    void load(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// logical offset converted to a linear offset by the reference
            TensorCoord const& tile_offset,
            /// extra linear offset added on top of the logical offset
            Index pointer_offset) const {
        load_with_pointer_offset(frag,
                                 ref_.offset(tile_offset) + pointer_offset);
    }

    /// Stores a fragment to memory
    CUTLASS_HOST_DEVICE
    void store(Fragment const& frag) const {
        store_with_pointer_offset(frag, 0);
    }

    /// Stores a fragment to memory with an additional linear offset. The
    /// iterator itself is not modified; a shifted copy of the reference is
    /// used for the accesses.
    CUTLASS_DEVICE
    void store_with_pointer_offset(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// linear offset added to the reference before storing
            Index pointer_offset) const {
        TensorRef offset_ref(ref_);
        offset_ref.add_pointer_offset(pointer_offset);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                // First fragment index of this MMA operation; operations are
                // ordered with the row index fastest.
                int mma_accum_start =
                        kAccumulatorRows * kElementsPerAccess *
                        (mma_n * Policy::MmaIterations::kRow + mma_m);

                CUTLASS_PRAGMA_UNROLL
                for (int row = 0; row < kAccumulatorRows; ++row) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int col = 0; col < kElementsPerAccess; ++col) {
                        int accum_m =
                                mma_m * InstructionShape::kM * OpDelta::kRow +
                                row * kRowsPerTile;
                        int accum_n = mma_n * InstructionShape::kN *
                                              OpDelta::kColumn +
                                      col;
                        int idx = mma_accum_start + row * kElementsPerAccess +
                                  col;

                        offset_ref.at({accum_m, accum_n}) = frag[idx];
                    }
                }
            }
        }
    }

    /// Stores a fragment to memory with an additional offset given in bytes.
    /// The byte offset is divided by sizeof(Element) and forwarded to
    /// store_with_pointer_offset() together with `frag`.
    CUTLASS_DEVICE
    void store_with_byte_offset(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// linear offset in bytes
            Index byte_offset) const {
        // Divide in the signed Index domain so the quotient is not computed
        // through an unsigned size_t intermediate.
        Index const element_size = static_cast<Index>(sizeof(Element));

        // `frag` has to be forwarded as well: store_with_pointer_offset() has
        // no single-argument overload.
        store_with_pointer_offset(frag, byte_offset / element_size);
    }

    /// Stores a fragment to memory at a logical offset and without any
    /// additional pointer offset. The fragment is only read, so it is taken
    /// by const reference like in the other store overloads.
    CUTLASS_DEVICE
    void store(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// logical offset handed to the three-argument overload
            TensorCoord const& tile_offset) const {
        store(frag, tile_offset, 0);
    }

    /// Stores a fragment to memory at a logical offset combined with a
    /// pointer offset. The linear offset obtained from
    /// ref_.offset(tile_offset) is added to `pointer_offset`.
    /// NOTE(review): unlike add_tile_offset(), `tile_offset` is not scaled by
    /// the tile shape here; confirm this is what callers expect.
    CUTLASS_DEVICE
    void store(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// logical offset converted to a linear offset by the reference
            TensorCoord const& tile_offset,
            /// extra linear offset added on top of the logical offset
            Index pointer_offset) const {
        store_with_pointer_offset(frag,
                                  ref_.offset(tile_offset) + pointer_offset);
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps and for tensors
/// stored with cutlass::layout::ColumnMajorInterleaved<InterleavedN>. It is
/// used to load or store accumulators from memory. Each access moves
/// kElementsPerAccess (two) elements at once through an Array-typed pointer.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept |
///   WriteableRandomAccessContiguousTileIteratorConcept
///

template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Element type
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions, concept: MatrixShape)
        typename OpDelta_,
        /// Interleaved N
        int InterleavedN>
class MmaTensorOpAccumulatorTileIterator<
        Shape_, Element_, cutlass::layout::ColumnMajorInterleaved<InterleavedN>,
        InstructionShape_, OpDelta_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand::kC;

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajorInterleaved<InterleavedN>;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    using OpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Internal structure of iterator - made public to enable introspection
    struct Policy {
        static_assert(
                !(Shape::kRow % InstructionShape::kM) &&
                        !(Shape::kColumn % InstructionShape::kN),
                "Shape of warp-level Mma must be divisible by operator shape.");

        static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
                      "Layouts must be defined for logical MatrixCoord "
                      "coordinate space.");

        /// Number of mma operations performed
        using MmaIterations =
                MatrixShape<Shape::kRow / InstructionShape::kM,
                            Shape::kColumn / InstructionShape::kN>;
    };

private:
    /// Number of elements moved by a single access
    static int const kElementsPerAccess = 2;

public:
    //
    // Derived quantities
    //

    /// Unit of a single memory access: kElementsPerAccess packed elements
    using AccessType = Array<Element, kElementsPerAccess>;

    /// Fragment object holding a thread's part of a tile
    using Fragment = Array<Element, Shape::kCount / kThreads>;

private:
    /// Reference to output tensor
    TensorRef ref_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator() {}

    /// Constructor from TensorRef.
    ///
    /// The reference is advanced to the calling lane's first element: the
    /// row is the lane's quad (lane_id / 4) and the column is the lane's
    /// position inside the quad (lane_id % 4) times kElementsPerAccess.
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator(TensorRef const& ref, int lane_id)
            : ref_(ref) {
        int quad = (lane_id >> 2);
        int lane_in_quad = (lane_id & 3);

        MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess);

        ref_.add_coord_offset(lane_offset);
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_pointer_offset(LongIndex offset) {
        ref_.add_pointer_offset(offset);
        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles; the tile offset is scaled by (Shape::kRow,
    /// Shape::kColumn) before it is applied to the reference.
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        ref_.add_coord_offset(tile_offset *
                              make_Coord(Shape::kRow, Shape::kColumn));

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator++() {
        // deliberate no-op
        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator--() {
        // deliberate no-op
        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Moves backwards in units of whole tiles along the logical coordinate
    /// space of the tensor
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); }

    /// Loads a fragment from memory with an additional linear offset. The
    /// iterator itself is not modified; a shifted copy of the reference is
    /// used for the accesses. One AccessType is read per MMA operation, and
    /// OpDelta is not applied to the coordinates.
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// linear offset added to the reference before loading
            Index pointer_offset) const {
        TensorRef offset_ref(ref_);
        offset_ref.add_pointer_offset(pointer_offset);

        // View the fragment as a sequence of AccessType-sized slots.
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                int accum_m = mma_m * InstructionShape::kM;
                int accum_n = mma_n * InstructionShape::kN;

                // Slot index inside the fragment; the row index is fastest.
                int idx = mma_m + mma_n * Policy::MmaIterations::kRow;

                AccessType* access_ptr = reinterpret_cast<AccessType*>(
                        offset_ref.data() +
                        offset_ref.offset(TensorCoord(accum_m, accum_n)));

                frag_ptr[idx] = access_ptr[0];
            }
        }
    }

    /// Loads a fragment from memory with an additional offset given in bytes.
    /// The byte offset is divided by sizeof(Element) and forwarded to
    /// load_with_pointer_offset() together with `frag`.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// linear offset in bytes
            Index byte_offset) const {
        // Divide in the signed Index domain so the quotient is not computed
        // through an unsigned size_t intermediate.
        Index const element_size = static_cast<Index>(sizeof(Element));

        // `frag` has to be forwarded as well: load_with_pointer_offset() has
        // no single-argument overload.
        load_with_pointer_offset(frag, byte_offset / element_size);
    }

    /// Loads a fragment from memory at a logical offset and without any
    /// additional pointer offset.
    CUTLASS_DEVICE
    void load(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// logical offset handed to the three-argument overload
            TensorCoord const& tile_offset) const {
        load(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory at a logical offset combined with a
    /// pointer offset. The linear offset obtained from
    /// ref_.offset(tile_offset) is added to `pointer_offset`.
    /// NOTE(review): unlike add_tile_offset(), `tile_offset` is not scaled by
    /// the tile shape here; confirm this is what callers expect.
    CUTLASS_DEVICE
    void load(
            /// fragment that receives the elements read from the tensor
            Fragment& frag,
            /// logical offset converted to a linear offset by the reference
            TensorCoord const& tile_offset,
            /// extra linear offset added on top of the logical offset
            Index pointer_offset) const {
        load_with_pointer_offset(frag,
                                 ref_.offset(tile_offset) + pointer_offset);
    }

    /// Stores a fragment to memory
    CUTLASS_HOST_DEVICE
    void store(Fragment const& frag) const {
        store_with_pointer_offset(frag, 0);
    }

    /// Stores a fragment to memory with an additional linear offset. The
    /// iterator itself is not modified; a shifted copy of the reference is
    /// used for the accesses. One AccessType is written per MMA operation,
    /// and OpDelta is not applied to the coordinates.
    CUTLASS_DEVICE
    void store_with_pointer_offset(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// linear offset added to the reference before storing
            Index pointer_offset) const {
        TensorRef offset_ref(ref_);
        offset_ref.add_pointer_offset(pointer_offset);

        // View the fragment as a sequence of AccessType-sized slots.
        AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                int accum_m = mma_m * InstructionShape::kM;
                int accum_n = mma_n * InstructionShape::kN;

                // Slot index inside the fragment; the row index is fastest.
                int idx = mma_m + mma_n * Policy::MmaIterations::kRow;

                AccessType* access_ptr = reinterpret_cast<AccessType*>(
                        offset_ref.data() +
                        offset_ref.offset(TensorCoord(accum_m, accum_n)));

                access_ptr[0] = frag_ptr[idx];
            }
        }
    }

    /// Stores a fragment to memory with an additional offset given in bytes.
    /// The byte offset is divided by sizeof(Element) and forwarded to
    /// store_with_pointer_offset() together with `frag`.
    CUTLASS_DEVICE
    void store_with_byte_offset(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// linear offset in bytes
            Index byte_offset) const {
        // Divide in the signed Index domain so the quotient is not computed
        // through an unsigned size_t intermediate.
        Index const element_size = static_cast<Index>(sizeof(Element));

        // `frag` has to be forwarded as well: store_with_pointer_offset() has
        // no single-argument overload.
        store_with_pointer_offset(frag, byte_offset / element_size);
    }

    /// Stores a fragment to memory at a logical offset and without any
    /// additional pointer offset. The fragment is only read, so it is taken
    /// by const reference like in the other store overloads.
    CUTLASS_DEVICE
    void store(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// logical offset handed to the three-argument overload
            TensorCoord const& tile_offset) const {
        store(frag, tile_offset, 0);
    }

    /// Stores a fragment to memory at a logical offset combined with a
    /// pointer offset. The linear offset obtained from
    /// ref_.offset(tile_offset) is added to `pointer_offset`.
    /// NOTE(review): unlike add_tile_offset(), `tile_offset` is not scaled by
    /// the tile shape here; confirm this is what callers expect.
    CUTLASS_DEVICE
    void store(
            /// fragment whose elements are written to the tensor
            Fragment const& frag,
            /// logical offset converted to a linear offset by the reference
            TensorCoord const& tile_offset,
            /// extra linear offset added on top of the logical offset
            Index pointer_offset) const {
        store_with_pointer_offset(frag,
                                  ref_.offset(tile_offset) + pointer_offset);
    }
};

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It is used to
/// load or store accumulators from memory and is agnostic to layout. It could
/// be faster if it assumed row-major accumulator layout.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept |
///   WriteableRandomAccessContiguousTileIteratorConcept
///

template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Element type
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions, concept: MatrixShape)
        typename OpDelta_,
        /// Interleaved N
        int InterleavedN>
class MmaTensorOpAccumulatorTileIterator<
        Shape_, Element_, cutlass::layout::TensorNCxHWx<InterleavedN>,
        InstructionShape_, OpDelta_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand::kC;

    /// Element type
    using Element = int8_t;

    /// Layout of source tile
    using Layout = cutlass::layout::TensorNCxHWx<InterleavedN>;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    using OpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Internal structure of iterator - made public to enable introspection
    struct Policy {
        static_assert(
                !(Shape::kRow % InstructionShape::kM) &&
                        !(Shape::kColumn % InstructionShape::kN),
                "Shape of warp-level Mma must be divisible by operator shape.");

        /// Number of elements in strided dimension that each STG writes
        static int const kStridedPerSTG = 8;

        /// Factor to calculate reorder index to pack accumulator.
        static int const kPackedFactor = Shape::kColumn / 32;

        /// Number of mma operations performed
        using MmaIterations = MatrixShape<Shape::kRow / kStridedPerSTG,
                                          Shape::kColumn / InterleavedN>;
    };

private:
    static int const kElementsPerAccess = InterleavedN / 4;

public:
    //
    // Derived quantities
    //

    struct alignas((kElementsPerAccess * sizeof_bits<Element>::value /
                    8)) AccessType {
        Array<Element, kElementsPerAccess> storage;
    };

    /// Fragment object holding a thread's part of a tile
    using Fragment = Array<int32_t, Shape::kCount / kThreads>;

private:
    /// Reference to output tensor
    TensorRef ref_;

    /// Row offset index globally
    LongIndex global_offset_row_;

    /// Column offset index globally
    LongIndex global_offset_col_;

    /// Output tensor size
    TensorCoord extent_;

    /// Alpha
    float alpha_;

    /// Beta
    float beta_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator(TensorRef const& ref, int const lane_id,
                                       TensorCoord extent, float alpha = 1.0f,
                                       float beta = 0.0f)
            : ref_(ref), extent_(extent), alpha_(alpha), beta_(beta) {
        int quad = (lane_id >> 2);
        int lane_in_quad = (lane_id & 3);

        global_offset_row_ = quad;

        global_offset_col_ = lane_in_quad * kElementsPerAccess;
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_pointer_offset(LongIndex offset) {
        ref_.add_pointer_offset(offset);
        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& add_tile_offset(
            MatrixCoord const& tile_offset) {
        global_offset_row_ += tile_offset.row() * Shape::kRow;

        global_offset_col_ += tile_offset.column() * Shape::kColumn;

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator++() {
        // deliberate no-op
        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator--() {
        // deliberate no-op
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpAccumulatorTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_pointer_offset(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            Fragment& frag,  ///< fragment to load from the tensor
            Index pointer_offset)
            const {  ///< loads a tile with a linear offset

        TensorRef offset_ref(ref_);
        offset_ref.add_pointer_offset(pointer_offset);

        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kN; ++mma_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kM; ++mma_m) {
                int accum_m = mma_m * InstructionShape::kM;
                int accum_n = mma_n * InstructionShape::kN;

                int idx = mma_m + mma_n * Policy::MmaIterations::kM;

                AccessType* access_ptr = reinterpret_cast<AccessType*>(
                        offset_ref.data() + accum_m * offset_ref.stride(0) +
                        accum_n);

                frag_ptr[idx] = access_ptr[0];
            }
        }
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            Fragment& frag,             ///< fragment to load from the tensor
            Index byte_offset) const {  ///< loads a tile with a linear offset

        load_with_pointer_offset(byte_offset / sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(Fragment& frag,  ///< fragment to load from the tensor
              TensorCoord const& tile_offset)
            const {  ///< loads a tile with a logical offset in units of whole
                     ///< tiles

        load(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            Fragment& frag,  ///< fragment to load from the tensor
            TensorCoord const& tile_offset,  ///< loads a tile with a logical
                                             ///< offset in units of whole tiles
            Index pointer_offset) const {    ///< loads a tile with a logical
                                             ///< offset AND a pointer offset

        load_with_pointer_offset(frag,
                                 ref_.offset(tile_offset) + pointer_offset);
    }

    /// Stores a fragment to memory
    CUTLASS_HOST_DEVICE
    void store(Fragment const& frag) const {
        store_with_pointer_offset(frag, 0);
    }

    /// Stores a fragment to memory with additional pointer offset
    CUTLASS_DEVICE
    void store_with_pointer_offset(
            Fragment const& frag,  ///< fragment to store from the tensor
            Index pointer_offset)
            const {  ///< store a tile with a linear offset

        TensorRef offset_ref(ref_);
        offset_ref.add_pointer_offset(pointer_offset);

        Array<float, Shape::kCount / kThreads> output_frag_f;
        Array<Element, Shape::kCount / kThreads> output_frag;

        LongIndex pq = extent_.h() * extent_.w();

        LongIndex extent_row = extent_.n() * pq;
        LongIndex extent_col = extent_.c();

        LongIndex k_major = (global_offset_col_ / InterleavedN) * pq;
        Index k_minor = global_offset_col_ % InterleavedN;
        LongIndex k_offset = k_major * InterleavedN + k_minor;
        LongIndex k_offset_delta = pq * InterleavedN;

        LongIndex stride_n = pq * extent_.c();

        Index n;
        LongIndex pq_rem;

        unsigned int pq_mul, pq_shr;
        find_divisor(pq_mul, pq_shr, pq);

        if (beta_ == 0.0f) {
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < frag.size(); ++i) {
                output_frag_f[i] = frag[i];
            }

            if (InstructionShape::kM == Policy::kStridedPerSTG) {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < frag.size(); ++i) {
                    output_frag[i] = (Element)(output_frag_f[i] * alpha_);
                }
            } else {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < frag.size(); ++i) {
                    int map_i = (i / (16 * Policy::kPackedFactor)) *
                                        (16 * Policy::kPackedFactor) +
                                (i % (8 * Policy::kPackedFactor)) / 2 * 4 +
                                (i % (8 * Policy::kPackedFactor)) % 2 +
                                (i / (8 * Policy::kPackedFactor)) % 2 * 2;
                    output_frag[i] = (Element)(output_frag_f[map_i] * alpha_);
                }
            }

            AccessType const* frag_ptr =
                    reinterpret_cast<AccessType const*>(&output_frag);

            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                int accum_m = mma_m * Policy::kStridedPerSTG;

                fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul,
                            pq_shr);
                LongIndex offset_m =
                        n * stride_n + k_offset + pq_rem * InterleavedN;

                CUTLASS_PRAGMA_UNROLL
                for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
                     ++mma_n) {
                    int accum_n = mma_n * InterleavedN;

                    int idx = mma_n + mma_m * Policy::MmaIterations::kColumn;

                    if ((global_offset_row_ + accum_m < extent_row) &&
                        (global_offset_col_ + accum_n < extent_col)) {
                        AccessType* access_ptr = reinterpret_cast<AccessType*>(
                                offset_ref.data() + offset_m +
                                mma_n * k_offset_delta);

                        access_ptr[0] = frag_ptr[idx];
                    }
                }
            }
        } else {
            if (InstructionShape::kM == Policy::kStridedPerSTG) {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < frag.size(); ++i) {
                    output_frag_f[i] = frag[i];
                }
            } else {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < frag.size(); ++i) {
                    int map_i = (i / (16 * Policy::kPackedFactor)) *
                                        (16 * Policy::kPackedFactor) +
                                (i % (8 * Policy::kPackedFactor)) / 2 * 4 +
                                (i % (8 * Policy::kPackedFactor)) % 2 +
                                (i / (8 * Policy::kPackedFactor)) % 2 * 2;
                    output_frag_f[i] = frag[map_i];
                }
            }

            AccessType const* frag_ptr =
                    reinterpret_cast<AccessType const*>(&output_frag);

            Array<Element, kElementsPerAccess> ref_frag;
            AccessType* ref_frag_ptr = reinterpret_cast<AccessType*>(&ref_frag);

            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                int accum_m = mma_m * Policy::kStridedPerSTG;

                fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul,
                            pq_shr);
                LongIndex offset_m =
                        n * stride_n + k_offset + pq_rem * InterleavedN;

                CUTLASS_PRAGMA_UNROLL
                for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
                     ++mma_n) {
                    int accum_n = mma_n * InterleavedN;

                    int idx = mma_n + mma_m * Policy::MmaIterations::kColumn;

                    if ((global_offset_row_ + accum_m < extent_row) &&
                        (global_offset_col_ + accum_n < extent_col)) {
                        AccessType* access_ptr = reinterpret_cast<AccessType*>(
                                offset_ref.data() + offset_m +
                                mma_n * k_offset_delta);

                        ref_frag_ptr[0] = access_ptr[0];

                        CUTLASS_PRAGMA_UNROLL
                        for (int i = 0; i < kElementsPerAccess; ++i) {
                            output_frag[idx * kElementsPerAccess + i] = Element(
                                    alpha_ * output_frag_f
                                                     [idx * kElementsPerAccess +
                                                      i] +
                                    beta_ * ref_frag[i]);
                        }

                        access_ptr[0] = frag_ptr[idx];
                    }
                }
            }
        }
    }

    /// Stores a fragment to memory with additional pointer offset
    CUTLASS_DEVICE
    void store_with_byte_offset(
            Fragment const& frag,       ///< fragment to store from the tensor
            Index byte_offset) const {  ///< store a tile with a linear offset

        store_with_pointer_offset(byte_offset / sizeof(Element));
    }

    /// Stores a fragment to memory with logical offset in units of whole tiles.
    CUTLASS_DEVICE
    void store(Fragment& frag,  ///< fragment to store to the tensor
               TensorCoord const& tile_offset)
            const {  ///< stores a tile with a logical offset in units of whole
                     ///< tiles

        store(frag, tile_offset, 0);
    }

    /// Stores a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void store(
            /// fragment to store to the tensor
            Fragment const& frag,
            /// stores a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// stores a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        store_with_pointer_offset(frag,
                                  ref_.offset(tile_offset) + pointer_offset);
    }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
