# Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# this file creates the test/unit/gemm/device simt tests


# Prefix prepended verbatim to every generated file name; the empty string
# makes the files land in the current working directory.
outputDir = ""

################################################################################
# parameters
# Edge - for tiles, the edges represent the length of one side
# Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
# MaxEdge - maximum length of each edge
# Min/Max - minimum/maximum of the product of edge lengths
################################################################################

# threadblock tiles with an edge longer than this are skipped
threadblockEdgeMax = 256

# warps per threadblock: candidate counts along each edge, the largest allowed
# ratio between the two counts, and the largest allowed product
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
warpsPerThreadblockMax = 16
warpsPerThreadblockRatio = 2
warpsPerThreadblockEdge = [1, 2, 4, 8, 16]

# warp shapes: candidate edge lengths, the largest allowed ratio between the
# two edges, and the bounds on the product of the edges. The shape filter uses
# a strict `>` against warpShapeMin, so a product equal to it is rejected.
warpShapeMin = 8 * 8
warpShapeMax = 64 * 64
warpShapeRatio = 4
warpShapeEdges = [8, 16, 32, 64, 128, 256]

# One row per precision:
#   [char, type, bits/elem, max tile, L0 threadblock tiles]
#     char      - precision letter used in generated file and test names
#     type      - C++ element type emitted as `using precision = ...`
#     bits/elem - size of one element in bits
#     max tile  - upper bound on the product of the threadblock tile edges
#     L0 tiles  - threadblock tiles [edge0, edge1] emitted as level-0 tests
precisions = [
    ["c", "cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
    ["d", "double", 64, 64 * 64, [[64, 64], [32, 32]]],
    ["h", "cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
    ["i", "int", 32, 128 * 128, [[128, 64], [16, 32]]],
    ["s", "float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
    ["z", "cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
]
# L1 will have a single kernel for every unique shape
# L2 will have everything else

# Every [columnMajorA, columnMajorB] combination, A varying slowest.
# A True entry selects a column-major layout ("n" in the file name) for that
# operand, False selects row-major ("t").
transposes = [
    [aColumnMajor, bColumnMajor]
    for aColumnMajor in (False, True)
    for bColumnMajor in (False, True)
]

################################################################################
# warps per threadblock
################################################################################
# Keep every ordered pair of edge counts whose ratio (in either direction)
# stays within warpsPerThreadblockRatio and whose product does not exceed
# warpsPerThreadblockMax.
warpsPerThreadblocks = [
    [warpCount0, warpCount1]
    for warpCount0 in warpsPerThreadblockEdge
    for warpCount1 in warpsPerThreadblockEdge
    if warpCount0 / warpCount1 <= warpsPerThreadblockRatio
    and warpCount1 / warpCount0 <= warpsPerThreadblockRatio
    and warpCount0 * warpCount1 <= warpsPerThreadblockMax
]
print("WarpsPerThreadblocks", warpsPerThreadblocks)

################################################################################
# warp shapes
################################################################################
# number of threads in one warp
warpNumThreads = 32

# Keep every ordered pair of warp tile edges whose ratio (in either direction)
# stays within warpShapeRatio and whose product is strictly greater than
# warpShapeMin and at most warpShapeMax.
warpShapes = [
    [edge0, edge1]
    for edge0 in warpShapeEdges
    for edge1 in warpShapeEdges
    if edge0 / edge1 <= warpShapeRatio
    and edge1 / edge0 <= warpShapeRatio
    and warpShapeMin < edge0 * edge1 <= warpShapeMax
]
print("WarpShapes", warpShapes)

# running totals of emitted tests per test level, across all generated files
numL0 = 0
numL1 = 0
numL2 = 0

################################################################################
# create kernels
# create a file for each precision/transpose
# each file contains many tile sizes
################################################################################

# Text written verbatim at the top of every generated .cu file: license banner,
# file-level doc comment and the #include lines.
fileHeader = (
"/***************************************************************************************************\n"
" * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.\n"
" *\n"
" * Redistribution and use in source and binary forms, with or without modification, are permitted\n"
" * provided that the following conditions are met:\n"
" *     * Redistributions of source code must retain the above copyright notice, this list of\n"
" *       conditions and the following disclaimer.\n"
" *     * Redistributions in binary form must reproduce the above copyright notice, this list of\n"
" *       conditions and the following disclaimer in the documentation and/or other materials\n"
" *       provided with the distribution.\n"
" *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used\n"
" *       to endorse or promote products derived from this software without specific prior written\n"
" *       permission.\n"
" *\n"
" * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR\n"
" * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n"
" * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE\n"
" * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,\n"
" * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;\n"
" * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,\n"
" * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n"
" * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
" *\n"
" **************************************************************************************************/\n"
"/*! \\file\n"
"    \\brief Tests for device-wide GEMM interface\n"
"*/\n"
"\n"
"#include <iostream>\n"
"\n"
"#include \"cutlass/cutlass.h\"\n"
"#include \"cutlass/gemm/device/gemm.h\"\n"
"#include \"cutlass/numeric_types.h\"\n"
"\n"
"#include \"../../common/cutlass_unit_test.h\"\n"
"\n"
"#include \"cutlass/util/host_tensor.h\"\n"
"#include \"cutlass/util/tensor_view_io.h\"\n"
"#include \"cutlass/util/reference/host/tensor_fill.h\"\n"
"#include \"cutlass/util/reference/host/tensor_copy.h\"\n"
"#include \"cutlass/util/reference/host/tensor_compare.h\"\n"
"#include \"cutlass/util/reference/host/gemm.h\"\n"
"\n"
"#include \"testbed.h\"\n"
"\n"
)

# Upper bound on the combined shared-memory footprint of the A and B tiles:
# 48 KiB expressed in bits, so the check below stays in exact integer math.
smemMaxBits = 48 * 1024 * 8

# precisions
for precision in precisions:

    # unpack the precision row
    precisionChar = precision[0]
    precisionType = precision[1]
    precisionBits = precision[2]
    threadblockMaxElements = precision[3]
    # level-0 threadblock tiles as hashable pairs for constant-time lookup
    threadblockTilesL0 = {(tile[0], tile[1]) for tile in precision[4]}

    # transposes
    for columnMajorA, columnMajorB in transposes:

        # transpose chars for the file/test name and layout names for the C++ types
        transCharA = "n" if columnMajorA else "t"
        transCharB = "n" if columnMajorB else "t"
        layoutA = "Column" if columnMajorA else "Row"
        layoutB = "Column" if columnMajorB else "Row"

        # output file; outputDir is prepended as-is, so a non-empty value must
        # already end with a path separator
        fileName = "simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB)
        print("\n", fileName)
        filePath = "%s%s" % (outputDir, fileName)

        # threadblock tiles that already received an L0 / L1 test in this file
        foundThreadblockTilesL0 = set()
        foundThreadblockTilesL1 = set()

        # the context manager closes the file even if generation raises part-way
        with open(filePath, "w") as out:

            # write file header
            out.write(fileHeader)

            ####################################################################
            # for each combination of tile sizes
            ####################################################################
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    # 8 threads along the first edge when the warp tile is
                    # longer along that edge, otherwise 4; the remaining
                    # threads of the warp go along the second edge
                    warpThreadsM = 8 if warpShape[0] > warpShape[1] else 4
                    warpThreadsN = warpNumThreads // warpThreadsM

                    # skip shapes with conflicting rectangularity
                    # they are unlikely to be fastest
                    # (G = edge 0 greater than edge 1, L = edge 0 less than
                    #  edge 1, the "2" variants require more than a factor of 2)
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]

                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2
                    blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1]*2
                    warpL2 = warpShape[0]*2 < warpShape[1]

                    if blockG2 and warpL: continue
                    if blockL2 and warpG: continue
                    if warpG2 and blockL: continue
                    if warpL2 and blockG: continue

                    # check threadblock ratios and max
                    threadblockTile = [warpShape[0]*warpsPerThreadblock[0],
                            warpShape[1]*warpsPerThreadblock[1]]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue
                    if threadblockTile[0] > threadblockEdgeMax: continue
                    if threadblockTile[1] > threadblockEdgeMax: continue
                    totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1]

                    # calculate unroll
                    # ensure that every iteration at least a full load of A,B are done
                    # (floor division keeps the counts integral; they are
                    #  emitted through %i and used as template arguments)
                    unrollMin = 8
                    unrollMin0 = totalThreads // threadblockTile[0]
                    unrollMin1 = totalThreads // threadblockTile[1]
                    unroll = max(unrollMin, unrollMin0, unrollMin1)

                    # elements computed by each thread along each edge
                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2: continue
                    # cap the per-thread tile at the bit size of 8x8 32-bit elements
                    if threadTileM*threadTileN*precisionBits > 8*8*32: continue

                    # epilogue currently requires threadblock N >= warpNumThreads
                    if threadblockTile[1] < warpNumThreads: continue

                    # limit smem
                    # NOTE(review): the factor 2 presumably corresponds to the
                    # 2 stages requested in the Gemm instantiation - confirm
                    smemBitsA = threadblockTile[0]*unroll*2*precisionBits
                    smemBitsB = threadblockTile[1]*unroll*2*precisionBits
                    if smemBitsA + smemBitsB > smemMaxBits: continue

                    # pick the test level:
                    #   0 - first kernel of a tile listed as L0 for this precision
                    #   1 - first kernel of any other tile (or of an L0 tile
                    #       that already produced its L0 test)
                    #   2 - everything else
                    tileKey = tuple(threadblockTile)
                    if tileKey in threadblockTilesL0 and tileKey not in foundThreadblockTilesL0:
                        testLevel = 0
                        numL0 += 1
                        foundThreadblockTilesL0.add(tileKey)
                    elif tileKey not in foundThreadblockTilesL1:
                        testLevel = 1
                        numL1 += 1
                        foundThreadblockTilesL1.add(tileKey)
                    else:
                        testLevel = 2
                        numL2 += 1

                    ############################################################
                    # write this tile to file
                    ############################################################

                    print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % (
                            threadblockTile[0], threadblockTile[1], unroll,
                            threadTileM, threadTileN,
                            warpThreadsM, warpThreadsN,
                            warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel))

                    out.write("////////////////////////////////////////////////////////////////////////////////\n"
                            "// Elements / Thread: %3i x %3i\n"
                            "//    Threads / Warp: %3i x %3i\n"
                            "//     Warps / Block: %3i x %3i\n"
                            "//       Threadblock: %3i x %3i x %2i\n"
                            % ( threadTileM, threadTileN,
                                warpThreadsM, warpThreadsN,
                                warpsPerThreadblock[0], warpsPerThreadblock[1],
                                threadblockTile[0], threadblockTile[1], unroll
                                )
                            )

                    out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % (
                        testLevel,
                        precisionChar,
                        transCharA,
                        transCharB,
                        threadblockTile[0],
                        threadblockTile[1],
                        unroll,
                        warpShape[0],
                        warpShape[1],
                        threadTileM,
                        threadTileN,
                        warpThreadsM,
                        warpThreadsN,
                        warpsPerThreadblock[0],
                        warpsPerThreadblock[1]
                        ))
                    out.write("    using precision = %s;\n" % precisionType)
                    out.write("    using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % (
                        threadblockTile[0],
                        threadblockTile[1],
                        unroll))
                    out.write("    using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % (
                        warpShape[0],
                        warpShape[1],
                        unroll))
                    out.write("    static int const kEpilogueElementsPerAccess = 1;\n"
                        "    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n"
                        "    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n"
                        "        precision, kEpilogueElementsPerAccess, precision, precision>;\n\n")

                    out.write("    using Gemm = cutlass::gemm::device::Gemm<\n"
                        "        precision, cutlass::layout::%sMajor,\n"
                        "        precision, cutlass::layout::%sMajor,\n"
                        "        precision, cutlass::layout::RowMajor,\n"
                        "        precision,\n"
                        "        cutlass::arch::OpClassSimt,\n"
                        "        cutlass::arch::Sm50,\n"
                        "        ThreadblockShape, WarpShape, InstructionShape,\n"
                        "        EpilogueOutputOp,\n"
                        "        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,\n"
                        "        2 // Stages\n"
                        "    >;\n" % (
                            layoutA,
                            layoutB,
                            ))
                    out.write("    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());\n"
                        "} )\n\n")

print("NumKernels:", numL0, numL1, numL2)

