#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

## Test case generator for SM80

import pycutlass
from pycutlass import *
from pycutlass.test import *
import unittest

#
# Create and test GEMM and Conv2d operations
#

def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False,
    epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
    """
    Build a GEMM operation for one configuration and run its tests.

    Args:
        gemm_kind: ``GemmKind.Universal`` or ``GemmKind.Grouped``.
        math_inst: ``MathInstruction`` supplying the A/B/accumulator element types.
        layout: ``(A, B, C)`` layouts.
        alignment: ``(A, B, C)`` alignments.
        tiling: ``(threadblock_shape, stages, warp_count)``, passed positionally
            to ``TileDescription`` together with ``math_inst``.
        arch: target architecture number (callers in this file pass 80).
        mixed: when no ``data_type`` is given, makes C use the accumulator type.
        epilogue_functor: epilogue to use; a ``LinearCombination`` over C is
            built when None.
        swizzling_functor: threadblock swizzling functor.
        **kwargs: ``data_type`` -- ``[A, B, C, epilogue]`` element types that
            override the derived ones; ``precompute_mode`` -- required when
            ``gemm_kind`` is Grouped.

    Returns:
        The result of ``test_all_gemm`` (Universal) or of
        ``TestbedGrouped.run(24)`` (Grouped).

    Raises:
        NotImplementedError: if ``gemm_kind`` is neither Universal nor Grouped.
    """
    if "data_type" in kwargs:
        data_type = kwargs["data_type"]
    else:
        # Mixed-precision and bfloat16 configurations store C in the
        # accumulator type; all others store C in A's element type.
        c_uses_accumulator = mixed or math_inst.element_a == cutlass.bfloat16
        data_type = [
            math_inst.element_a,
            math_inst.element_b,
            math_inst.element_accumulator if c_uses_accumulator else math_inst.element_a,
            math_inst.element_accumulator,
        ]

    threadblock_shape, stages, warp_count = tiling[0], tiling[1], tiling[2]
    tile_description = TileDescription(threadblock_shape, stages, warp_count, math_inst)

    tensor_a, tensor_b, tensor_c = [
        TensorDescription(data_type[idx], layout[idx], alignment[idx]) for idx in range(3)
    ]
    element_epilogue = data_type[3]

    if epilogue_functor is None:
        epilogue_functor = LinearCombination(
            tensor_c.element, tensor_c.alignment,
            math_inst.element_accumulator, element_epilogue)

    if gemm_kind == GemmKind.Universal:
        operation = GemmOperationUniversal(
            arch=arch, tile_description=tile_description,
            A=tensor_a, B=tensor_b, C=tensor_c,
            epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor)
        interleaved_layouts = (cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32)
        # Interleaved-32 A layouts are exercised through the "interleaved"
        # mode of test_all_gemm; everything else through "universal".
        test_mode = "interleaved" if tensor_a.layout in interleaved_layouts else "universal"
        return test_all_gemm(operation, test_mode)

    if gemm_kind == GemmKind.Grouped:
        operation = GemmOperationGrouped(
            arch, tile_description, tensor_a, tensor_b, tensor_c,
            epilogue_functor, swizzling_functor,
            precompute_mode=kwargs["precompute_mode"])
        return TestbedGrouped(operation=operation).run(24)

    raise NotImplementedError("the gemm kind is not implemented")


def TestConv2dOperator(math_inst, alignment, tiling, arch, 
    stride_supports=(StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided),
    epilogue_functor=None, 
    swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs):
    """
    Build and test Conv2d fprop, dgrad and wgrad operations for one configuration.

    One operation per conv kind is created and passed to ``test_all_conv2d``.
    Int8 dgrad/wgrad are skipped.

    Args:
        math_inst: ``MathInstruction`` supplying the A/B/accumulator element types.
        alignment: ``(A, B, C)`` alignments.
        tiling: ``(threadblock_shape, stages, warp_count)``.
        arch: target architecture number (callers in this file pass 80).
        stride_supports: one ``StrideSupport`` per conv kind, in
            (fprop, dgrad, wgrad) order.
        epilogue_functor: epilogue used for every conv kind. When None, a
            ``LinearCombination`` is built per kind from C's element/alignment.
        swizzling_functor: threadblock swizzling functor. Strided dgrad always
            uses ``cutlass.StridedDgradIdentitySwizzle1`` instead.
        interleaved: forwarded to ``test_all_conv2d``.
        **kwargs: ``layout`` -- ``(A, B, C)`` tensor layouts (default all
            ``TensorNHWC``); ``data_type`` -- ``[A, B, C, epilogue]`` element
            types that override the derived ones.

    Returns:
        list: one ``test_all_conv2d`` result per conv kind that was not skipped.
    """
    # Only dgrad derives its default data types in "mixed" mode (C in the
    # accumulator type); fprop and wgrad do not.
    mixeds = [False, True, False]
    conv_kinds = [cutlass.conv.Operator.fprop, cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad]
    backward_kinds = (cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad)

    results = []

    default_swizzling_functor = swizzling_functor

    layout = kwargs.get("layout", (cutlass.TensorNHWC, cutlass.TensorNHWC, cutlass.TensorNHWC))

    for mixed, conv_kind, stride_support in zip(mixeds, conv_kinds, stride_supports):

        if "data_type" in kwargs:
            data_type = kwargs["data_type"]
        else:
            c_uses_accumulator = mixed or math_inst.element_a == cutlass.bfloat16
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator if c_uses_accumulator else math_inst.element_a,
                math_inst.element_accumulator
            ]
        # skip Int8 Conv Backward
        if data_type[0] == cutlass.int8 and conv_kind in backward_kinds:
            continue

        A = TensorDescription(
            element=data_type[0],
            layout=layout[0],
            alignment=alignment[0])
        B = TensorDescription(
            element=data_type[1],
            layout=layout[1], 
            alignment=alignment[1])
        C = TensorDescription(
            element=data_type[2],
            layout=layout[2], 
            alignment=alignment[2])
        
        tile_description = TileDescription(
            threadblock_shape=tiling[0], stages=tiling[1], 
            warp_count=tiling[2],
            math_instruction=math_inst
        )

        # Strided dgrad needs its dedicated swizzle; every other case uses the
        # caller's functor.
        if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided:
            swizzling_functor = cutlass.StridedDgradIdentitySwizzle1
        else:
            swizzling_functor = default_swizzling_functor
        
        if epilogue_functor is None:
            epilogue_functor_ = LinearCombination(
                C.element, C.alignment, 
                math_inst.element_accumulator, data_type[3])
        else:
            # Use the caller-supplied epilogue for every conv kind.
            epilogue_functor_ = epilogue_functor

        operation = Conv2dOperation(
            conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
            arch=arch, tile_description=tile_description, A=A, B=B, C=C, 
            stride_support=stride_support,
            epilogue_functor=epilogue_functor_,
            swizzling_functor=swizzling_functor
        )
        
        results.append(test_all_conv2d(operation, interleaved=interleaved))
    
    return results



class Test_SM80(unittest.TestCase):
    def test_SM80_TensorOp_16816(self):
        math_instructions = [
            MathInstruction(
                [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float32,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add
            ),
            MathInstruction(
                [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float16,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add
            ),
            MathInstruction(
                [16, 8, 16], cutlass.bfloat16, cutlass.bfloat16, cutlass.float32,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add
            )
        ]

        layouts = [
            (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor),
            (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.RowMajor),
            (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajor)
        ]

        alignments = [
            (8, 8, 8), (4, 8, 8), (8, 4, 8)
        ]

        tilings = [
            ([256, 128, 32], 3, [4, 2, 1]),
            ([64, 256, 32], 4, [1, 4, 1]),
            ([128, 64, 64], 3, [2, 2, 1])
        ]

        for math_inst, layout, alignment, tiling in zip(math_instructions, layouts, alignments, tilings):
            self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False))
            self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Host))
            stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
            results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports)
            for res in results:
                self.assertTrue(res)

    def test_SM80_TensorOp_1688(self):
        # tf32 is not supported by most of python environment. Skip the test
        self.assertTrue(True)
    
    def test_SM80_TensorOp_1688_fast_math(self):
        math_instructions = [
            MathInstruction(
                [16, 8, 8], cutlass.tfloat32, cutlass.tfloat32, cutlass.float32,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add
            ),
            MathInstruction(
                [16, 8, 8], cutlass.float16, cutlass.float16, cutlass.float32,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_f16
            ),
            MathInstruction(
                [16, 8, 8], cutlass.bfloat16, cutlass.bfloat16, cutlass.float32,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_bf16
            ),
            MathInstruction(
                [16, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32,
                cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_f32
            )
        ]

        layouts = [
            (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor),
            (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor),
            (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.ColumnMajor),
            (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.RowMajor)
        ]
        alignments = [
            (4, 4, 4), (4, 2, 4), (2, 4, 4), (2, 2, 4)
        ]
        tilings = [
            ([128, 256, 16], 3, [4, 2, 1]),
            ([64, 256, 16], 4, [1, 4, 1]),
            ([128, 64, 32], 3, [2, 2, 1]),
            ([256, 64, 32], 3, [4, 2, 1])
        ]
        data_type = [
            cutlass.float32, cutlass.float32, cutlass.float32, cutlass.float32
        ]
        for math_inst, layout, alignment, tiling in zip(math_instructions, layouts, alignments, tilings):
            self.assertTrue(
                TestGemmOperator(
                    GemmKind.Universal, math_inst, layout, 
                    alignment, tiling, 80, False, data_type=data_type))
            self.assertTrue(
                TestGemmOperator(
                    GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, 
                    True, precompute_mode=SchedulerMode.Device, data_type=data_type))
            stride_supports = [StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity]
            results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type)
            for res in results:
                self.assertTrue(res)

    def test_SM80_TensorOp_884(self):
        math_inst = MathInstruction(
            [8, 8, 4], cutlass.float64, cutlass.float64, cutlass.float64,
            cutlass.OpClass.TensorOp, MathOperation.multiply_add
        )
        layout = (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.ColumnMajor)
        alignment = (1, 1, 1)

        tiling = ([64, 256, 16], 3, [2, 4, 1])
        data_type = [cutlass.float64, cutlass.float64, cutlass.float64, cutlass.float64]
        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type))
        self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type))
        stride_supports = [StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity]
        results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type)
        for res in results:
            self.assertTrue(res)
    
    def test_SM80_TensorOp_16832_TN(self):
        math_inst = MathInstruction(
            [16, 8, 32], cutlass.int8, cutlass.int8, cutlass.int32,
            cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate
        )
        layout = (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor)
        alignment = (16, 16, 4)
        alignment_mixed = (16, 16, 16)
        tiling = ([128, 256, 64], 3, [2, 4, 1])

        data_type = [cutlass.int8, cutlass.int8, cutlass.int32, cutlass.int32]
        data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32]

        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type))
        self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment_mixed, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type_mixed))
        stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
        results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type)
        for res in results:
            self.assertTrue(res)
    
    def test_SM80_Simt_f32(self):
        math_inst = MathInstruction(
            [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32,
            cutlass.OpClass.Simt, MathOperation.multiply_add
        )
        layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor)
        alignment = (1, 1, 1)

        tiling = ([128, 256, 8], 4, [2, 4, 1])
        data_type = [cutlass.float32, cutlass.float32, cutlass.float32, cutlass.float32]
        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type))
        self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Host, data_type=data_type))
        stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
        results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type)
        for res in results:
            self.assertTrue(res)

    def test_SM80_Simt_f64(self):
        math_inst = MathInstruction(
            [1, 1, 1], cutlass.float64, cutlass.float64, cutlass.float64,
            cutlass.OpClass.Simt, MathOperation.multiply_add
        )
        layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor)
        alignment = (1, 1, 1)

        tiling = ([64, 128, 8], 5, [2, 2, 1])
        data_type = [cutlass.float64, cutlass.float64, cutlass.float64, cutlass.float64]
        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type))
        self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type))
        stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
        results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type)
        for res in results:
            self.assertTrue(res)

    def test_SM80_TensorOp_16832_Interleaved(self):
        math_inst = MathInstruction(
            [16, 8, 32], cutlass.int8, cutlass.int8, cutlass.int32,
            cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate
        )

        layout = (cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32)
        alignment_mixed = (16, 16, 8)
        tiling = ([256, 64, 64], 4, [4, 1, 1])
        data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32]

        epilogue_functor = FastLinearCombinationClamp(
            data_type_mixed[2], alignment_mixed[2]
        )

        self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=epilogue_functor))
        stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided]
        layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32]
        results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True)
        for res in results:
            self.assertTrue(res)

    def SM80_SparseTensorOp_16832(self):
        pass
    def SM80_PlanarComplexTensorOp_16816(self):
        pass
    def SM80_SparseTensorOp_16816_fast_math(self):
        pass
    def SM80_TensorOp_1688_complex(self):
        pass
    def SM80_TensorOp_1688_fast_fp32_math_complex(self):
        pass
    def SM80_TensorOp_1688_rank_k(self):
        pass
    def SM80_TensorOp_1688_rank_k_complex(self):
        pass
    def SM80_TensorOp_1688_trmm(self):
        pass
    def SM80_TensorOp_1688_trmm_complex(self):
        pass
    def SM80_TensorOp_1688_symm(self):
        pass
    def SM80_TensorOp_1688_symm_complex(self):
        pass
    def SM80_TensorOp_884_complex(self):
        pass
    def SM80_TensorOp_884_complex_gaussian(self):
        pass
    def SM80_TensorOp_884_rank_k(self):
        pass
    def SM80_TensorOp_884_rank_k_complex(self):
        pass
    def SM80_TensorOp_884_rank_k_complex_gaussian(self):
        pass
    def SM80_TensorOp_884_trmm(self):
        pass
    def SM80_TensorOp_884_trmm_complex(self):
        pass
    def SM80_TensorOp_884_trmm_complex_gaussian(self):
        pass
    def SM80_TensorOp_884_symm(self):
        pass
    def SM80_TensorOp_884_symm_complex(self):
        pass
    def SM80_TensorOp_884_symm_complex_gaussian(self):
        pass
    def SM80_SparseTensorOp_16864_TN(self):
        pass
    def SM80_TensorOp_16864_TN(self):
        pass
    def SM80_SparseTensorOp_168128_TN(self):
        pass
    def SM80_TensorOp_16864_Interleaved(self):
        pass
    def SM80_TensorOp_168256(self):
        pass
    def SM80_Simt_complex(self):
        pass


if __name__ == '__main__':
    # Set up pycutlass's device memory pool before any test runs.
    # NOTE(review): the two arguments look like initial and maximum pool sizes
    # in bytes (1 MiB and 16 GiB) -- confirm against get_memory_pool.
    pycutlass.get_memory_pool(1 << 20, 1 << 34)
    # Presumably selects nvcc as the kernel compiler backend -- confirm.
    pycutlass.compiler.nvcc()
    unittest.main()
