/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for threadblock level GEMV
*/

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/threadblock/gemv.h"
#include "cutlass/gemm/threadblock/default_gemv_core.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Gemv, typename LongIndex, typename RefA, typename RefB, typename RefC>
__global__ void batched_gemv_threadblock_test_kernel(
  cutlass::gemm::GemmCoord problem_size,
  LongIndex stride_a,
  LongIndex stride_b,
  LongIndex stride_c,
  RefA ref_A,
  RefB ref_B,
  RefC ref_C
  ) {

  typename Gemv::IteratorA::TensorCoord threadblock_offset_A(0, 0);
  typename Gemv::IteratorB::TensorCoord threadblock_offset_B(0, 0);
  typename Gemv::IteratorB::TensorCoord threadblock_offset_C(0, 0);

  // Move to the right batches for these threads
  ref_A.add_pointer_offset(threadIdx.y * stride_a);
  ref_B.add_pointer_offset(threadIdx.y * stride_b);
  ref_C.add_pointer_offset(threadIdx.y * stride_c);

  // Construct iterators to A and B operands
  typename Gemv::IteratorA::Params params_A(ref_A.layout());
  typename Gemv::IteratorA iterator_A(params_A, ref_A.data(), { problem_size.m(), problem_size.k() }, 0, threadblock_offset_A);
  typename Gemv::IteratorB::Params params_B(ref_B.layout());
  typename Gemv::IteratorB iterator_B(params_B, ref_B.data(), { problem_size.k(), problem_size.n() }, threadIdx.x, threadblock_offset_B);

  Gemv gemv;

  typename Gemv::FragmentC accum;
  accum.clear();

  // Compute threadblock-scoped matrix multiply-add
  gemv(problem_size, accum, iterator_A, iterator_B, accum);

  // IteratorC is PitchLinear<> assumes n() contiguous
  typename Gemv::IteratorC::Params params_C(ref_C.layout());
  typename Gemv::IteratorC iterator_C(params_C, ref_C.data(), { problem_size.m(), problem_size.n() }, threadIdx.x, threadblock_offset_C);
  iterator_C.store(accum);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Shape_,
         typename ElementAB_,
         typename ElementC_,
         typename LayoutA_,
         typename LayoutB_,
         typename LayoutC_,
         int LDG_N,
         int LDG_K,
         int MAX_THREADS_PER_BLOCK=512,
         bool DEBUG=false>
void batched_gemv_threadblock_test(cutlass::gemm::GemmCoord problem_size, int num_batch)
{
  using Shape = Shape_;
  using ElementA = ElementAB_;
  using LayoutA = LayoutA_;
  using ElementB = ElementAB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N, LDG_K>;

  using Core = typename cutlass::gemm::threadblock::DefaultGemvCore<
    Shape,
    ThreadShape,
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC
  >;

  if (DEBUG)
  { 
      num_batch = 1;
  }

  using Mma = cutlass::gemm::threadblock::Gemv<Core>;

  // Create host tensors that will be the backing store for the batches
  // Note that no device memory is initially allocated
  cutlass::HostTensor<ElementA, LayoutA> matrix_A({problem_size.m(), problem_size.k()}, false); 
  cutlass::HostTensor<ElementB, LayoutB> matrix_B({problem_size.k(), problem_size.n()}, false); 
  cutlass::HostTensor<ElementC, LayoutC> matrix_C_computed({problem_size.m(), problem_size.n()}, false); 
  cutlass::HostTensor<ElementC, LayoutC> matrix_C_reference({problem_size.m(), problem_size.n()}, false);

  // Reserve memory for the batch of tensors
  matrix_A.reserve(problem_size.m()*problem_size.k()*num_batch);
  matrix_B.reserve(problem_size.n()*problem_size.k()*num_batch);
  matrix_C_computed.reserve(problem_size.m()*problem_size.n()*num_batch);
  matrix_C_reference.reserve(problem_size.m()*problem_size.n()*num_batch, false);

  // Fill eatch tensor batch
  const int seed = 6834;
  for (int b = 0; b < num_batch; b++)
  {
    if(DEBUG)
    {
      cutlass::reference::host::BlockFillSequential(
        matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity());
      cutlass::reference::host::BlockFillSequential(
        matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity());
    }
    else
    {
      cutlass::reference::host::TensorFillRandomUniform(
        matrix_A.host_view(b*matrix_A.capacity()),
        seed + 1660,
        8,
        -8,
        0
      );

      cutlass::reference::host::TensorFillRandomUniform(
        matrix_B.host_view(b*matrix_B.capacity()),
        seed + 1880,
        8,
        -8,
        0
      );
    }

    cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity()));
    cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity()));
  }

  matrix_A.sync_device();
  matrix_B.sync_device();
  matrix_C_computed.sync_device();

  dim3 grid(1, 1);      // only 1 CTA is used
  dim3 block(Shape::kN / LDG_N, num_batch, 1);

  #if 0
  printf("block dim = %d x %d\n", block.x, block.y);
  #endif

  // Some sanity checks
  EXPECT_TRUE( problem_size.n() % LDG_N == 0 );
  EXPECT_TRUE( block.x*block.y <= MAX_THREADS_PER_BLOCK );

  test::gemm::threadblock::batched_gemv_threadblock_test_kernel<Mma><<< grid, block >>>(
    problem_size,
    matrix_A.capacity(),
    matrix_B.capacity(),
    matrix_C_computed.capacity(),
    matrix_A.device_ref(),
    matrix_B.device_ref(),
    matrix_C_computed.device_ref()
  );

  cudaError_t result = cudaDeviceSynchronize();
  EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result);

  matrix_C_computed.sync_host();

  // Compute the batched gemms
  for (int b = 0; b < num_batch; b++)
  {

    cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                   ElementC, LayoutC, ElementC, ElementC> reference_gemm;

    reference_gemm(
      problem_size.mnk(),
      ElementC(1),
      matrix_A.host_ref(b*matrix_A.capacity()),
      matrix_B.host_ref(b*matrix_B.capacity()),
      ElementC(0),
      matrix_C_reference.host_ref(b*matrix_C_computed.capacity())
    );

    bool passed = cutlass::reference::host::TensorEquals(
                    matrix_C_computed.host_view(b*matrix_C_computed.capacity()), 
                    matrix_C_reference.host_view(b*matrix_C_reference.capacity()));

    EXPECT_TRUE(passed)
    //<< "A:\n" << matrix_A.host_view() << "\n"
    //<< "B:\n" << matrix_B.host_view() << "\n"
      << "Batch: " << b << "\n"
      << "Reference:\n" << matrix_C_reference.host_view(b*matrix_C_reference.capacity()) << "\n"
      << "Computed:\n" << matrix_C_computed.host_view(b*matrix_C_computed.capacity()) << "\n";
  }
}

} // namespace threadblock
} // namespace gemm
} // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////

// A: ColumnMajor
// B: RowMajor
// C: ColumnMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp32_fp32_2N_2K) {
  
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
 
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_crc_fp32_fp32_4N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 128, 128);
  const int num_batch = 5;
  const int LDG_N = 4;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp32_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                float, float, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_8K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 8;
 
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp16_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_i8_i32_2N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                int8_t, int32_t, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_i8_i32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                int8_t, int32_t, 
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

// A: RowMajor
// B: ColumnMajor
// C: RowMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp32_fp32_2N_2K) {
  
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
 
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcr_fp32_fp32_4N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 128, 128);
  const int num_batch = 5;
  const int LDG_N = 4;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp32_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                float, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_8K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 8;
 
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp16_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_i8_i32_2N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                int8_t, int32_t, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_i8_i32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                int8_t, int32_t, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

// A: RowMajor
// B: ColumnMajor
// C: ColumnMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp32_fp32_2N_2K) {
  
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
 
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcc_fp32_fp32_4N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 128, 128);
  const int num_batch = 5;
  const int LDG_N = 4;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape, float, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp32_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                float, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_2K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 2;
  
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_8K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 8;
 
  using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp16_fp32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                cutlass::half_t, float, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_i8_i32_2N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 64, 64);
  const int num_batch = 4;
  const int LDG_N = 2;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                int8_t, int32_t, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_i8_i32_1N_4K) {
  using namespace test::gemm::threadblock;
  cutlass::gemm::GemmCoord problem_size(1, 17, 64);
  const int num_batch = 16;
  const int LDG_N = 1;
  const int LDG_K = 4;

  using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
  batched_gemv_threadblock_test<Shape,
                                int8_t, int32_t, 
                                cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::ColumnMajor,
                                LDG_N, LDG_K>(problem_size, num_batch);
}
