/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators used by warp-level matrix multiply operations
   targeting Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"

#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"

#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for loading 128b vectors of 64b elements
/// from a tensor with the TensorOpMultiplicandCongruous64b layout.
///
/// Every lane performs one 128b (two-element) access per iteration; the 32
/// lanes of a warp are arranged as Policy::Delta (8 x 4) accesses.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: PitchLinearShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: PitchLinearShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::TensorOpMultiplicandCongruous64b, InstructionShape_,
        OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 4),
                  "Divisibility.");

    static_assert(sizeof_bits<Element_>::value == 64,
                  "This is specialized for 64b accesses.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b;

    /// Shape of one matrix product operation (concept: GemmShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Load two elements per access (2 x 64b = one 128b vector)
    static int const kElementsPerAccess = 2;

    /// Policy defining internal details of tile iterator
    struct Policy {
        /// Arrangement of the warp's accesses for one iteration, in units of
        /// AccessType (8 along contiguous x 4 along strided = 32 lanes)
        using Delta = layout::PitchLinearShape<8, 4>;

        /// Number of iterations needed to load one fragment
        using Iterations = layout::PitchLinearShape<
                Shape::kContiguous / kElementsPerAccess / Delta::kContiguous,
                InstructionShape::kStrided / Delta::kStrided>;
    };

private:
    /// Not working on this feature at the moment.
    static_assert(kOpDelta == 1,
                  "Alternative arrangements not supported at present.");

    /// Pointer type used for accesses (kElementsPerAccess elements, 16B
    /// aligned)
    using AccessType = AlignedArray<Element, kElementsPerAccess, 16>;

    /// Internal counter used to jump to next K partition
    int k_group_idx_;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment =
            Array<Element,
                  Shape::kContiguous * InstructionShape::kStrided / kThreads>;

private:
    /// Stride between consecutive strided indices, in units of AccessType
    Index stride_;

    /// Shared memory base pointers - not advanced
    AccessType const* pointer_;

    /// Byte offset incremented as iterator advances
    Index byte_offset_;

public:
    /// Default ctor constructs null iterator. Every member is initialized so
    /// that a default-constructed iterator has a well-defined (null) state.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator()
            : k_group_idx_(0), stride_(0), pointer_(nullptr), byte_offset_(0) {}

    /// Constructor from TensorRef
    ///
    /// \param ref      reference to the tensor to iterate over
    /// \param lane_id  index of the calling thread within the warp
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            // Initializers are listed in member declaration order.
            : k_group_idx_(0),
              stride_(ref.stride(0) / kElementsPerAccess),
              byte_offset_(0) {
        // Position of this lane inside the Delta arrangement of accesses.
        int access_strided = lane_id / Policy::Delta::kContiguous;
        // The contiguous index is XOR-ed with the strided index.
        // NOTE(review): presumably this mirrors the permutation applied by
        // Layout - confirm against the layout definition.
        int access_contiguous =
                (lane_id % Policy::Delta::kContiguous) ^ access_strided;

        pointer_ = reinterpret_cast<AccessType const*>(ref.data()) +
                   access_contiguous + access_strided * stride_;
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    ///
    /// \param offset  offset in units of Element; it is accumulated into the
    ///                internal byte offset, the base pointer is unchanged
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        byte_offset_ += offset * sizeof(Element);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // Offset in units of Element: stride_ is kept in AccessType units, so
        // it is scaled back by kElementsPerAccess here.
        int offset = (tile_offset.strided() * InstructionShape::kStrided) *
                             stride_ * kElementsPerAccess +
                     tile_offset.contiguous() * Shape::kContiguous;

        add_pointer_offset(offset);

        return *this;
    }

    /// Advances the iterator along the advance dimension (one tile along the
    /// strided dimension)
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        add_tile_offset({0, 1});

        return *this;
    }

    /// Advances the iterator along the opposite of the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        add_tile_offset({0, -1});

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_byte_offset(frag, 0); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in units of bytes
            Index byte_offset) const {
        // The fragment is filled one AccessType (128b vector) at a time.
        AccessType* fetch_ptr = reinterpret_cast<AccessType*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < Policy::Iterations::kStrided; ++s) {
            CUTLASS_PRAGMA_UNROLL
            for (int c = 0; c < Policy::Iterations::kContiguous; ++c) {
                int access_idx = c + s * Policy::Iterations::kContiguous;

                // Step by one Delta footprint per iteration in each dimension.
                AccessType const* source_ptr =
                        pointer_ + Policy::Delta::kContiguous * c +
                        Policy::Delta::kStrided * s * stride_;

                // Apply the caller's offset plus the accumulated iterator
                // offset, both expressed in bytes.
                char const* source_byte_ptr =
                        reinterpret_cast<char const*>(source_ptr) +
                        byte_offset + byte_offset_;

                AccessType const* source =
                        reinterpret_cast<AccessType const*>(source_byte_ptr);

                fetch_ptr[access_idx] = *source;
            }
        }
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset (in units of Element)
            Index pointer_offset) const {
        load_with_byte_offset(frag, pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // pointer_offset is expressed in units of AccessType, so the
        // contiguous element offset is divided by this iterator's own
        // kElementsPerAccess (the number of elements in AccessType). This
        // yields the same displacement as add_tile_offset().
        Index pointer_offset =
                tile_offset.contiguous() * Shape::kContiguous /
                        kElementsPerAccess +
                tile_offset.strided() * InstructionShape::kStrided * stride_;

        byte_offset += sizeof(AccessType) * pointer_offset;

        load_with_byte_offset(frag, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts. For this
    /// specialization the call is a no-op: the argument is ignored.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) {}
};

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator for RowMajorTensorOpMultiplicandCongruous64b tensors.
///
/// Thin adapter over the pitch-linear TensorOpMultiplicandCongruous64b
/// iterator: matrix columns map to the contiguous dimension and matrix rows
/// map to the strided dimension.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation (pitch-linear: contiguous =
    /// column, strided = row)
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, kOperand,
            Element, layout::TensorOpMultiplicandCongruous64b,
            layout::PitchLinearShape<InstructionShape::kColumn,
                                     InstructionShape::kRow>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // (row, column) -> (contiguous = column, strided = row)
        iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Advances the iterator along the opposite of the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // add_tile_offset() performs the matrix -> pitch-linear mapping
        // itself, so the matrix coordinate is forwarded unchanged.
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Same mapping as add_tile_offset(), with both components negated.
        iterator_.add_tile_offset({-tile_offset.column(), -tile_offset.row()});
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // The pointer offset is in units of Element; convert it to bytes.
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // Same (column, row) mapping as add_tile_offset().
        iterator_.load_with_byte_offset(
                frag, {tile_offset.column(), tile_offset.row()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator for ColumnMajorTensorOpMultiplicandCongruous64b tensors,
/// specialized for 32-thread TensorOps.
///
/// Thin adapter over the pitch-linear TensorOpMultiplicandCongruous64b
/// iterator: matrix rows map to the contiguous dimension and matrix columns
/// map to the strided dimension. All loads are delegated to that iterator.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation (pitch-linear: contiguous =
    /// row, strided = column)
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, kOperand,
            Element, layout::TensorOpMultiplicandCongruous64b,
            layout::PitchLinearShape<InstructionShape::kRow,
                                     InstructionShape::kColumn>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // (row, column) -> (contiguous = row, strided = column)
        iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Advances the iterator along the opposite of the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // add_tile_offset() performs the matrix -> pitch-linear mapping
        // itself, so the matrix coordinate is forwarded unchanged.
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Same mapping as add_tile_offset(), with both components negated.
        iterator_.add_tile_offset({-tile_offset.row(), -tile_offset.column()});
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // The pointer offset is in units of Element; convert it to bytes.
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // Same (row, column) mapping as add_tile_offset().
        iterator_.load_with_byte_offset(
                frag, {tile_offset.row(), tile_offset.column()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for loading 128b vectors of 64b elements
/// laid out as TensorOpMultiplicand64bCrosswise.
///
/// Every access reads one AccessType (two elements, 16-byte aligned). The
/// iterator tracks a k-group counter: it XORs the per-lane byte offset with
/// 0x40 as the counter advances, and exchanges each loaded pair of elements
/// while the counter is odd.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: PitchLinearShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: PitchLinearShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::TensorOpMultiplicand64bCrosswise, InstructionShape_,
        OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// The tile must be a whole number of 4 x 16 access deltas (see Policy).
    static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16),
                  "Divisibility.");

    static_assert(sizeof_bits<Element_>::value == 64,
                  "This is specialized for 64b accesses.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise;

    /// Shape of one matrix product operation (concept: GemmShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Load two elements per access
    static int const kElementsPerAccess = 2;

    /// Policy defining internal details of tile iterator
    struct Policy {
        /// Shape of one access
        using Delta = layout::PitchLinearShape<4, 16>;

        /// Number of iterations to load
        using Iterations =
                layout::PitchLinearShape<InstructionShape::kContiguous /
                                                 Delta::kContiguous,
                                         Shape::kStrided / Delta::kStrided>;
    };

private:
    /// Not working on this feature at the moment.
    static_assert(kOpDelta == 1,
                  "Alternative arrangements not supported at present.");

    /// Pointer type used for accesses
    using AccessType = AlignedArray<Element, kElementsPerAccess, 16>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment =
            Array<Element,
                  Shape::kStrided * InstructionShape::kContiguous / kThreads>;

private:
    /// Stride of the tensor, expressed in units of AccessType
    Index stride_;

    /// Base pointer in units of AccessType; moved by add_pointer_offset(),
    /// add_tile_offset() and operator++()
    AccessType const* pointer_;

    /// Per-lane byte offset; XOR-ed with 0x40 as the k-group index advances
    Index byte_offset_;

    /// Internal counter for tracking K-group
    Index k_group_idx_;

public:
    /// Default ctor constructs a null iterator. Every member is initialized so
    /// that a default-constructed iterator never carries an indeterminate
    /// pointer or k-group index.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator()
            : stride_(0), pointer_(nullptr), byte_offset_(0), k_group_idx_(0) {}

    /// Constructor from TensorRef
    ///
    /// \param ref      tensor to iterate over; ref.stride(0) is converted to
    ///                 units of AccessType
    /// \param lane_id  index of the calling thread; it selects the thread's
    ///                 initial byte offset
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : stride_(ref.stride(0) / kElementsPerAccess),
              pointer_(reinterpret_cast<AccessType const*>(ref.data())),
              byte_offset_(0),
              k_group_idx_(0) {
        // Eight lanes per strided step; the remainder indexes within it.
        int access_strided = lane_id / 8;
        int access_contiguous = (lane_id % 8);

        byte_offset_ = (access_contiguous + access_strided * stride_) *
                       sizeof(AccessType);
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    ///
    /// \param offset  offset in units of Element; it is divided by
    ///                kElementsPerAccess before being applied to the pointer
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        pointer_ += offset / kElementsPerAccess;

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // Offset in units of Element; add_pointer_offset() rescales it.
        int offset =
                (tile_offset.contiguous() * InstructionShape::kContiguous) *
                        stride_ * kElementsPerAccess +
                tile_offset.strided() * Shape::kStrided;

        add_pointer_offset(offset);

        int old_k_group_idx = k_group_idx_;

        k_group_idx_ += tile_offset.contiguous();

        // Toggle the 0x40 XOR only when bit 1 of the k-group index changed.
        if ((k_group_idx_ & 2) ^ (old_k_group_idx & 2)) {
            byte_offset_ ^= 0x40;
        }

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        pointer_ += stride_ * InstructionShape::kContiguous;

        // Leaving an odd k-group flips bit 1 of the index, which is the
        // condition add_tile_offset() uses to toggle the XOR.
        if (k_group_idx_ & 0x1) {
            // xor ptr
            byte_offset_ ^= 0x40;
        }

        ++k_group_idx_;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_byte_offset(frag, 0); }

    /// Loads a fragment from memory with an additional byte offset.
    ///
    /// Performs Policy::Iterations::kContiguous x kStrided accesses of one
    /// AccessType each, then exchanges the two elements of every pair when the
    /// k-group index is odd.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in units of bytes
            Index byte_offset) const {
        AccessType* fetch_ptr = reinterpret_cast<AccessType*>(&frag);

        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < Policy::Iterations::kContiguous; ++c) {
            CUTLASS_PRAGMA_UNROLL
            for (int s = 0; s < Policy::Iterations::kStrided; ++s) {
                int access_idx = c + s * Policy::Iterations::kContiguous;

                AccessType const* source_ptr =
                        pointer_ + Policy::Delta::kContiguous * c * stride_ +
                        Policy::Delta::kStrided * s / kElementsPerAccess;

                // Apply the caller's offset and the per-lane offset in bytes.
                char const* source_byte_ptr =
                        reinterpret_cast<char const*>(source_ptr) +
                        byte_offset + byte_offset_;

                AccessType const* source =
                        reinterpret_cast<AccessType const*>(source_byte_ptr);

                fetch_ptr[access_idx] = *source;
            }
        }

        Element* exchange_ptr = reinterpret_cast<Element*>(&frag);

        if (k_group_idx_ & 1) {
            // exchange on 64b granularity
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < Fragment::kElements; i += 2) {
                Element tmp = exchange_ptr[i];
                exchange_ptr[i] = exchange_ptr[i + 1];
                exchange_ptr[i + 1] = tmp;
            }
        }
    }

    /// Loads a fragment from memory with an additional offset given in units
    /// of Element (converted to bytes before the load).
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    ///
    /// NOTE(review): add_tile_offset() scales the contiguous tile index by
    /// stride_, whereas this overload scales the strided tile index by
    /// stride_, and it does not adjust the k-group XOR. Confirm this
    /// difference is intended.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // Tile offset expressed in units of AccessType.
        Index pointer_offset =
                tile_offset.contiguous() * InstructionShape::kContiguous /
                        Layout::kElementsPerAccess +
                tile_offset.strided() * Shape::kStrided * stride_;

        byte_offset += sizeof(AccessType) * pointer_offset;

        load_with_byte_offset(frag, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { k_group_idx_ = k_group; }
};

////////////////////////////////////////////////////////////////////////////////
///
/// Adapter for RowMajorTensorOpMultiplicand64bCrosswise. Matrix coordinates
/// are mapped to pitch-linear ones as (contiguous = column, strided = row) and
/// every operation is forwarded to the TensorOpMultiplicand64bCrosswise
/// specialization held in iterator_.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation (columns become the contiguous
    /// dimension, rows the strided dimension)
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, kOperand,
            Element, layout::TensorOpMultiplicand64bCrosswise,
            layout::PitchLinearShape<InstructionShape::kColumn,
                                     InstructionShape::kRow>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // (row, column) -> (contiguous = column, strided = row)
        iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Moves the iterator backwards along the advance dimension.
    ///
    /// NOTE(review): the TensorOpMultiplicand64bCrosswise specialization that
    /// Base names defines no operator-- in this file, so instantiating this
    /// member is expected to fail to compile until Base provides one.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // add_tile_offset() already performs the (row, column) ->
        // (contiguous, strided) mapping; swapping here as well would undo it.
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< retreats in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Same mapping as add_tile_offset(), with both components negated.
        iterator_.add_tile_offset(
                {-tile_offset.column(), -tile_offset.row()});
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        // Previously an empty stub that left frag untouched.
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // Previously an empty stub. The pointer offset is in units of Element
        // and is converted to bytes here.
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // Use the same row()/column() accessors and mapping as
        // add_tile_offset().
        iterator_.load_with_byte_offset(
                frag, {tile_offset.column(), tile_offset.row()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////
///
/// Adapter for ColumnMajorTensorOpMultiplicand64bCrosswise. Matrix coordinates
/// are mapped to pitch-linear ones as (contiguous = row, strided = column) and
/// every operation is forwarded to the TensorOpMultiplicand64bCrosswise
/// specialization held in iterator_.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept
///
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_,
        cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: PitchLinearShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation (rows become the contiguous
    /// dimension, columns the strided dimension)
    using Base = MmaTensorOpMultiplicandTileIterator<
            layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, kOperand,
            Element, layout::TensorOpMultiplicand64bCrosswise,
            layout::PitchLinearShape<InstructionShape::kRow,
                                     InstructionShape::kColumn>,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        // (row, column) -> (contiguous = row, strided = column)
        iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Moves the iterator backwards along the advance dimension.
    ///
    /// NOTE(review): the TensorOpMultiplicand64bCrosswise specialization that
    /// Base names defines no operator-- in this file, so instantiating this
    /// member is expected to fail to compile until Base provides one.
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // add_tile_offset() already maps (row, column) to
        // (contiguous, strided); no intermediate coordinate is needed.
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< retreats in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Same mapping as add_tile_offset(), with both components negated.
        iterator_.add_tile_offset(
                {-tile_offset.row(), -tile_offset.column()});
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        // Previously an empty stub that left frag untouched.
        load_with_byte_offset(frag, tile_offset, 0);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // Previously an empty stub. The pointer offset is in units of Element
        // and is converted to bytes here.
        load_with_byte_offset(frag, tile_offset,
                              pointer_offset * sizeof(Element));
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        // Use the same row()/column() accessors and mapping as
        // add_tile_offset().
        iterator_.load_with_byte_offset(
                frag, {tile_offset.row(), tile_offset.column()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator specialized for canonical matrix layouts
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Operand identity
        Operand Operand_,
        /// Data type of A elements
        typename Element_,
        /// Layout of operand
        typename Layout_,
        /// Shape of one matrix production operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Delta between *MMA operations (in units of *MMA operations, concept:
        /// MatrixShape)
        int OpDelta_,
        /// Number of threads participating in one matrix operation
        int Threads = 32,
        /// Number of partitions along K dimension
        int PartitionsK_ = 1>
class MmaTensorOpMultiplicandTileIteratorCanonical {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    /// Basic check
    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = Layout_;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Number of elements accessed per Shared Memory load
    static int const kElementsPerAccess =
            (sizeof_bits<Element>::value >= 32
                     ? 1
                     : 32 / sizeof_bits<Element>::value);

private:
    static int const kWarpShapeOuter =
            (kOperand == Operand::kA ? Shape::kRow : Shape::kColumn);

    static int const kWarpShapeInner =
            (kOperand == Operand::kA ? Shape::kColumn : Shape::kRow);

    /// Rounded up instruction counts
    using InstructionCount =
            MatrixShape<Shape::kRow / InstructionShape::kRow,
                        Shape::kColumn / InstructionShape::kColumn>;

    /// Rounded up tile dimensions
    using WarpShapeDivisible =
            MatrixShape<InstructionCount::kRow * InstructionShape::kRow,
                        InstructionCount::kColumn * InstructionShape::kColumn>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment =
            Array<Element, WarpShapeDivisible::kRow *
                                   WarpShapeDivisible::kColumn / kThreads>;

    /// Memory access type
    using AccessType = AlignedArray<Element, kElementsPerAccess>;

private:
    /// Underlying tensor reference
    TensorRef ref_;

    /// Extent of tensor
    MatrixCoord extent_;

    /// Origin
    MatrixCoord origin_;

    /// Used to conditionally enable extents checking
    bool divisible_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical() : divisible_(true) {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical(TensorRef const& ref,
                                                 int lane_id)
            : ref_(ref),
              extent_(Shape::kRow, Shape::kColumn),
              divisible_(true) {
        if (kOperand == Operand::kA) {
            origin_ = MatrixCoord(lane_id / 4,
                                  (lane_id % 4) * kElementsPerAccess);
        } else {
            origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess,
                                  lane_id / 4);
        }

        ref_.add_coord_offset(origin_);
    }

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical(TensorRef const& ref,
                                                 TensorCoord extent,
                                                 int lane_id)
            : ref_(ref), extent_(extent), divisible_(false) {
        if (kOperand == Operand::kA) {
            origin_ = MatrixCoord(lane_id / 4,
                                  (lane_id % 4) * kElementsPerAccess);
        } else {
            origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess,
                                  lane_id / 4);
        }

        ref_.add_coord_offset(origin_);
    }

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical& add_pointer_offset(
            LongIndex offset) {
        ref_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical& add_tile_offset(
            TensorCoord const& tile_offset) {
        TensorCoord coord_offset(tile_offset.row() * Shape::kRow,
                                 tile_offset.column() * Shape::kColumn);
        origin_ += coord_offset;

        ref_.add_coord_offset(coord_offset);

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical& operator++() {
        if (kOperand == Operand::kA) {
            add_tile_offset({0, 1});
        } else {
            add_tile_offset({1, 0});
        }

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical& operator--() {
        if (kOperand == Operand::kA) {
            add_tile_offset({0, -1});
        } else {
            add_tile_offset({-1, 0});
        }

        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical& operator+=(
            TensorCoord const& tile_offset) {
        add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIteratorCanonical& operator-=(
            TensorCoord const& tile_offset) {
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); }

    /// Loads a fragment from memory, displaced by an additional linear offset.
    ///
    /// The fragment is filled one AccessType at a time. For each access the
    /// logical coordinate (origin_ plus the access's offset inside the tile)
    /// is compared against extent_; accesses that fall outside are written as
    /// zeros. When divisible_ is set the comparison is skipped and every
    /// access is read from memory.
    ///
    /// \param frag            fragment that receives the loaded accesses
    /// \param pointer_offset  linear offset added to the address of every
    ///                        access, in the same units as ref_.offset()
    ///
    /// NOTE(review): the bounds check uses the iterator's own logical
    /// coordinates only; pointer_offset does not take part in it, so callers
    /// passing a non-zero offset to a non-divisible iterator must make sure
    /// the displaced accesses are still valid.
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        // Extent of the warp tile along the inner (K) dimension: columns for
        // operand A, rows for operand B.
        int const kWarpShapeDivisibleInner =
                (kOperand == Operand::kA ? WarpShapeDivisible::kColumn
                                         : WarpShapeDivisible::kRow);

        // Take advantage of Tensor Op's 8 x 4T access pattern
        int const kAccessesInner =
                (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;

        // The fragment is written as a flat array of vector accesses.
        AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);

        if (kOperand == Operand::kA) {
            // Number of 8-row groups covered by one instruction.
            int const kTilesPerInstruction = InstructionShape::kRow / 8;

            CUTLASS_PRAGMA_UNROLL
            for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
                 ++inst_m_idx) {
                CUTLASS_PRAGMA_UNROLL
                for (int inner_idx = 0; inner_idx < kAccessesInner;
                     ++inner_idx) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int access_m_idx = 0;
                         access_m_idx < kTilesPerInstruction; ++access_m_idx) {
                        // Position of this access inside the fragment.
                        int access_idx = access_m_idx +
                                         kTilesPerInstruction *
                                                 (inner_idx +
                                                  kAccessesInner * inst_m_idx);

                        // Offset of this access relative to the iterator.
                        MatrixCoord offset(
                                access_m_idx * 8 +
                                        inst_m_idx * InstructionShape::kRow,
                                inner_idx * 4 * kElementsPerAccess);

                        MatrixCoord access_coord = origin_ + offset;

                        if (divisible_ ||
                            (access_coord.row() < extent_.row() &&
                             access_coord.column() < extent_.column())) {
                            // Honor the caller-supplied linear offset.
                            access_ptr[access_idx] =
                                    *reinterpret_cast<AccessType const*>(
                                            ref_.data() + ref_.offset(offset) +
                                            pointer_offset);
                        } else {
                            // Out of bounds: substitute zeros.
                            AccessType zero;
                            zero.clear();
                            access_ptr[access_idx] = zero;
                        }
                    }
                }
            }
        } else {
            CUTLASS_PRAGMA_UNROLL
            for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
                 ++inst_n_idx) {
                CUTLASS_PRAGMA_UNROLL
                for (int inner_idx = 0; inner_idx < kAccessesInner;
                     ++inner_idx) {
                    // Position of this access inside the fragment.
                    int access_idx = inner_idx + kAccessesInner * inst_n_idx;

                    // Offset of this access relative to the iterator.
                    MatrixCoord offset(inner_idx * 4 * kElementsPerAccess,
                                       inst_n_idx * 8);

                    MatrixCoord access_coord = origin_ + offset;

                    if (divisible_ ||
                        (access_coord.row() < extent_.row() &&
                         access_coord.column() < extent_.column())) {
                        // Honor the caller-supplied linear offset.
                        access_ptr[access_idx] =
                                *reinterpret_cast<AccessType const*>(
                                        ref_.data() + ref_.offset(offset) +
                                        pointer_offset);
                    } else {
                        // Out of bounds: substitute zeros.
                        AccessType zero;
                        zero.clear();
                        access_ptr[access_idx] = zero;
                    }
                }
            }
        }
    }

    /// Loads a fragment from memory with an additional offset given in bytes.
    ///
    /// The byte count is converted with `byte_offset * 8 /
    /// sizeof_bits<Element>::value` and handed to load_with_pointer_offset().
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// offset to apply, expressed in bytes
            Index byte_offset) const {
        Index const element_offset =
                byte_offset * 8 / sizeof_bits<Element>::value;

        load_with_pointer_offset(frag, element_offset);
    }

    /// Loads a fragment displaced from the iterator by a number of whole
    /// tiles.
    ///
    /// The tile counts are scaled by Shape::kRow / Shape::kColumn, turned into
    /// a linear offset through ref_.offset(), and forwarded to
    /// load_with_pointer_offset().
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// displacement in units of whole tiles
            TensorCoord const& tile_offset) const {
        TensorCoord const delta(Shape::kRow * tile_offset.row(),
                                Shape::kColumn * tile_offset.column());

        load_with_pointer_offset(frag, ref_.offset(delta));
    }

    /// Loads a fragment displaced from the iterator by a number of whole
    /// tiles plus an additional pointer offset.
    ///
    /// The tile displacement is converted to a linear offset through
    /// ref_.offset(); `pointer_offset` is added to it and the sum is forwarded
    /// to load_with_pointer_offset().
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// displacement in units of whole tiles
            TensorCoord const& tile_offset,
            /// additional linear offset added on top of the tile displacement
            Index pointer_offset) const {
        TensorCoord const delta(Shape::kRow * tile_offset.row(),
                                Shape::kColumn * tile_offset.column());

        load_with_pointer_offset(frag, ref_.offset(delta) + pointer_offset);
    }

    /// Loads a fragment displaced from the iterator by a number of whole
    /// tiles plus an additional offset given in bytes.
    ///
    /// The tile displacement is converted to a linear offset through
    /// ref_.offset(); the byte count is converted with `byte_offset * 8 /
    /// sizeof_bits<Element>::value`; their sum is forwarded to
    /// load_with_pointer_offset().
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// displacement in units of whole tiles
            TensorCoord const& tile_offset,
            /// additional offset, expressed in bytes
            Index byte_offset) const {
        TensorCoord const delta(Shape::kRow * tile_offset.row(),
                                Shape::kColumn * tile_offset.column());

        load_with_pointer_offset(
                frag, ref_.offset(delta) +
                              byte_offset * 8 / sizeof_bits<Element>::value);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. The hook exists so that iterators
    /// for some nontrivial permuted layouts can replace their internal
    /// tracking with a constant-valued k-group index, letting the compiler
    /// fold constants. This iterator keeps no such state, so the call has no
    /// effect.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) {
        // Intentionally a no-op; the argument is unused.
        (void)k_group;
    }
};

/// Wrapper for ColumnMajor
///
/// Partial specialization of MmaTensorOpMultiplicandTileIterator for
/// cutlass::layout::ColumnMajor and 32 threads. It owns a
/// MmaTensorOpMultiplicandTileIteratorCanonical instance and forwards every
/// operation to it.
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_, cutlass::layout::ColumnMajor,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::ColumnMajor;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation
    using Base = MmaTensorOpMultiplicandTileIteratorCanonical<
            Shape, kOperand, Element, layout::ColumnMajor, InstructionShape,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Constructor from TensorRef and the extent of the tensor
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref,
                                        TensorCoord const& extent, int lane_id)
            : iterator_({ref.data(), ref.stride()}, extent, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Steps the iterator backwards along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // tile_offset is already a TensorCoord; pass it straight through.
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Moves backwards in units of whole tiles along the logical coordinate
    /// space of the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Negate the tile offset and reuse add_tile_offset().
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional offset given in bytes
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        // Forward to the underlying iterator instead of leaving frag untouched.
        iterator_.load(frag, {tile_offset.row(), tile_offset.column()});
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // Forward to the underlying iterator instead of leaving frag untouched.
        iterator_.load(frag, {tile_offset.row(), tile_offset.column()},
                       pointer_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a byte offset
            Index byte_offset) const {
        // Use row()/column(), the accessors used everywhere else in this
        // class for TensorCoord.
        iterator_.load_with_byte_offset(
                frag, {tile_offset.row(), tile_offset.column()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

/// Wrapper for RowMajor
///
/// Partial specialization of MmaTensorOpMultiplicandTileIterator for
/// cutlass::layout::RowMajor and 32 threads. It owns a
/// MmaTensorOpMultiplicandTileIteratorCanonical instance and forwards every
/// operation to it.
template <
        /// Size of the matrix to load (concept: MatrixShape)
        typename Shape_,
        /// Identifies A or B multiplicand
        Operand Operand_,
        /// Data type of elements
        typename Element_,
        /// Shape of one matrix product operation (concept: MatrixShape)
        typename InstructionShape_,
        /// Interval between adjacent *MMA instructions (in units of MMA
        /// instructions)
        int OpDelta_,
        /// Number of partitions along K dimension
        int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
        Shape_, Operand_, Element_, cutlass::layout::RowMajor,
        InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
    /// Shape of tile to load (concept: MatrixShape)
    using Shape = Shape_;

    /// Operand tag
    static Operand const kOperand = Operand_;

    static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
                  "MmaTensorOpMultiplicandIterator may only be instantiated "
                  "for A or B operands to warp-level Mma.");

    /// Element type
    using Element = Element_;

    /// Layout of source tile
    using Layout = cutlass::layout::RowMajor;

    /// Shape of one matrix product operation (concept: MatrixShape)
    using InstructionShape = InstructionShape_;

    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    static int const kOpDelta = OpDelta_;

    /// Number of participating threads
    static int const kThreads = 32;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<Element, Layout>;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Underlying tile iterator implementation
    using Base = MmaTensorOpMultiplicandTileIteratorCanonical<
            Shape, kOperand, Element, layout::RowMajor, InstructionShape,
            kOpDelta, kThreads, PartitionsK_>;

public:
    //
    // Derived quantities
    //

    /// Fragment object holding a thread's part of a tile
    using Fragment = typename Base::Fragment;

private:
    /// Underlying tile iterator
    Base iterator_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator() {}

    /// Constructor from TensorRef
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref, int lane_id)
            : iterator_({ref.data(), ref.stride()}, lane_id) {}

    /// Constructor from TensorRef and the extent of the tensor
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator(TensorRef const& ref,
                                        TensorCoord const& extent, int lane_id)
            : iterator_({ref.data(), ref.stride()}, extent, lane_id) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_pointer_offset(LongIndex offset) {
        iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& add_tile_offset(
            TensorCoord const& tile_offset) {
        iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator++() {
        ++iterator_;

        return *this;
    }

    /// Steps the iterator backwards along the advance dimension
    CUTLASS_HOST_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator--() {
        --iterator_;

        return *this;
    }

    /// Advances in units of whole tiles along the logical coordinate space of
    /// the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator+=(
            TensorCoord const& tile_offset) {
        // tile_offset is already a TensorCoord; pass it straight through.
        add_tile_offset(tile_offset);
        return *this;
    }

    /// Moves backwards in units of whole tiles along the logical coordinate
    /// space of the tensor
    CUTLASS_DEVICE
    MmaTensorOpMultiplicandTileIterator& operator-=(
            TensorCoord const& tile_offset) {
        // Negate the tile offset and reuse add_tile_offset().
        add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const { iterator_.load(frag); }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        iterator_.load_with_pointer_offset(frag, pointer_offset);
    }

    /// Loads a fragment from memory with additional offset given in bytes
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index byte_offset) const {
        iterator_.load_with_byte_offset(frag, byte_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        // Forward to the underlying iterator instead of leaving frag untouched.
        iterator_.load(frag, {tile_offset.row(), tile_offset.column()});
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        // Forward to the underlying iterator instead of leaving frag untouched.
        iterator_.load(frag, {tile_offset.row(), tile_offset.column()},
                       pointer_offset);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a byte offset
            Index byte_offset) const {
        // Use row()/column(), the accessors used everywhere else in this
        // class for TensorCoord.
        iterator_.load_with_byte_offset(
                frag, {tile_offset.row(), tile_offset.column()}, byte_offset);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) { iterator_.set_kgroup_index(k_group); }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
