
#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>

using namespace cute;

// Architecture-dependent building blocks shared by the forward and backward kernel traits.
//
// Template parameters:
//   kHeadDim_, kBlockM_, kBlockN_, kNWarps_
//       Accepted so that the signature lines up with the derived traits; none of them is
//       read inside this struct.
//   elem_type
//       Requested element type (defaults to cutlass::half_t).
//
// Every selection below is keyed off __CUDA_ARCH__. A compilation pass in which the macro is
// not defined therefore takes the fallback branch of each group.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

    // These two do not depend on the target architecture.
    using index_t = int64_t;
    using ElementAccum = float;

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
    // Fallback (no __CUDA_ARCH__, or below 800): Element is pinned to cutlass::half_t
    // regardless of elem_type, Has_cp_async is false, and the SM75 16x8x8 MMA atom is used.
    static constexpr bool Has_cp_async = false;
    using Element = cutlass::half_t;
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#else
    // __CUDA_ARCH__ >= 800: honour elem_type, set Has_cp_async, and pick the SM80 16x8x16
    // MMA atom. Any elem_type other than cutlass::half_t selects the BF16 atom.
    static constexpr bool Has_cp_async = true;
    using Element = elem_type;
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
#endif

    // Copy atoms for shared-memory reads. Note that both branches are parameterised on
    // elem_type, not on Element.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 750
    // Fallback: plain DefaultCopy for both the normal and the transposed variant.
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
#else
    // __CUDA_ARCH__ >= 750: SM75 LDSM copy operations (_T for the transposed variant).
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
#endif
};

// Compile-time configuration of the forward kernel: tile extents, the tiled MMA, shared-memory
// layouts and byte sizes, and the tiled copies used for global-memory traffic.
//
// Template parameters:
//   kHeadDim_        Head dimension; must be a multiple of 32 (static_assert below).
//   kBlockM_         Row extent of the Q and O shared-memory tiles.
//   kBlockN_         Row extent of the K/V shared-memory tile.
//   kNWarps_         Number of warps; kNThreads is kNWarps_ * 32.
//   Is_Q_in_regs_    Exported as Is_Q_in_regs after OR-ing with Share_Q_K_smem_.
//   Share_Q_K_smem_  When true, kSmemSize uses max(Q, KV) instead of Q + KV.
//   Has_mask_        When true, kSmemSize reserves one extra kBlockM x kBlockN tile.
//   Has_bias_        When true, kSmemSize reserves one extra kBlockM x kBlockN tile.
//   elem_type        Forwarded to Base.
//   Base             Supplies Element, ElementAccum, index_t, Has_cp_async, MMA_Atom_Arch and
//                    the shared-memory copy atoms.
//
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<
    int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
    bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, bool Has_mask_=false, bool Has_bias_=false, typename elem_type=cutlass::half_t,
    typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>
>
struct Flash_fwd_kernel_traits : public Base {
    // Re-export the architecture-dependent choices made by Base.
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    // Sharing the Q and K buffers implies Is_Q_in_regs.
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
    static constexpr bool Has_mask = Has_mask_;
    static constexpr bool Has_bias = Has_bias_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    // Column width of the swizzled smem atom: 64 if kHeadDim is a multiple of 64, otherwise 32.
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // Largest of {128, 64, 32} that divides kHeadDim. Not referenced again inside this struct.
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    // First Swizzle<> parameter for the Q/K/V/O atoms: 2 for the 32-wide atom, 3 for the 64-wide one.
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    // First Swizzle<> parameter for the kBlockM x kBlockN (PS) atom; fixed at 3.
    static constexpr int kSwizzlePS = 3;

    // kNWarps MMA atoms stacked along the first mode; the tile is (16 * kNWarps) x 16 x 16.
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        Tile<Int<16 * kNWarps>, _16, _16>
    >;

    // Shared memory layout for Q matrix and Mask matrix
    // Row-major 8 x kBlockKSmem atom with Swizzle<kSwizzle, 3, 3> applied.
    using SmemLayoutAtomQ = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
            Layout<Shape<_8, Int<kBlockKSmem>>,
            Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    // Row-major atom that already spans the whole kBlockM x kBlockN tile, so the tile_to_shape
    // in SmemLayoutPS below needs exactly one repetition of it.
    using SmemLayoutAtomPS = decltype(
        composition(
            Swizzle<kSwizzlePS, 3, 3>{},
            Layout<Shape<Int<kBlockM>, Int<kBlockN>>,
            Stride<Int<kBlockN>, _1>>{}
        )
    );

    // Q tile: the Q atom repeated to cover kBlockM x kHeadDim.
    using SmemLayoutQ = decltype(
        tile_to_shape(
            SmemLayoutAtomQ{},
            Shape<Int<kBlockM>, Int<kHeadDim>>{}
        )
    );

    // K/V tile: the same atom as Q, repeated to cover kBlockN x kHeadDim.
    using SmemLayoutKV = decltype(
        tile_to_shape(
            SmemLayoutAtomQ{},
            Shape<Int<kBlockN>, Int<kHeadDim>>{}
        )
    );

    // SmemLayoutKV composed with a row-major (kHeadDim, kBlockN) layout, i.e. the same storage
    // addressed with the two modes exchanged.
    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(
            SmemLayoutKV{},
            make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})
        )
    );
    // The transposed layout with its swizzle component removed.
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    // kBlockM x kBlockN tile; kSmemMaskSize and kSmemBiasSize below are both derived from it.
    using SmemLayoutPS = decltype(
        tile_to_shape(
            SmemLayoutAtomPS{},
            Shape<Int<kBlockM>, Int<kBlockN>>{}
        )
    );
    using SmemCopyAtomPS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;

    // Shared memory layout for output
    // Same definition as SmemLayoutAtomQ (8 x kBlockKSmem, row-major, Swizzle<kSwizzle, 3, 3>).
    using SmemLayoutAtomO = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            Layout<Shape<Int<8>, Int<kBlockKSmem>>,
            Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    using SmemLayoutO = decltype(
        tile_to_shape(
            SmemLayoutAtomO{},
            Shape<Int<kBlockM>, Int<kHeadDim>>{}
        )
    );
    // Copy atoms for O in Element precision and in accumulator (ElementAccum) precision.
    using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

    // Shared memory size calculations (all in bytes)
    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    // Two SmemLayoutKV-sized tiles -- presumably one for K and one for V; confirm against the kernel.
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemMaskSize = size(SmemLayoutPS{}) * sizeof(Element);
    static constexpr int kSmemBiasSize = size(SmemLayoutPS{}) * sizeof(Element);

    // Shared memory size with QKV matrices and mask/bias matrices
    // Q and KV either overlap (max) or sit side by side (sum); the mask and bias tiles are
    // added only when the corresponding flag is set.
    static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + (Has_mask ? kSmemMaskSize : 0) + (Has_bias ? kSmemBiasSize : 0);

    // Number of Elements that fit in one 128-bit (cute::uint128_t) access.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    // Row-major thread arrangement: (kNThreads / kGmemThreadsPerRow) rows of kGmemThreadsPerRow threads.
    using GmemLayoutAtom = Layout<
        Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
        Stride<Int<kGmemThreadsPerRow>, _1>
    >;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    // Falls back to a 128-bit auto-vectorizing copy when Has_cp_async is false.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    // NOTE(review): the value layouts of the four copies below hard-code 8 values per thread.
    // That equals kGmemElemsPerLoad only when sizeof(Element) == 2.
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(
            Copy_Atom<Gmem_copy_struct, Element>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per read
    // Mask and bias copies have identical definitions and never use Gmem_copy_struct.
    using GmemTiledCopyMask = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per read
    using GmemTiledCopyBias = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per store

    // Accumulator layout for output
    // NOTE(review): both alternatives contain 128 threads (16 * 8 and 8 * 16) and do not depend
    // on kNThreads -- confirm this is intended when kNWarps != 4.
    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_16, _8>, Stride< _8, _1>>,   // Thread layout, 8 threads per row
        Layout<Shape <_8, _16>, Stride< _16, _1>>   // Thread layout, 16 threads per row
    >;

    // Copy in ElementAccum precision; 4 values per thread is 128 bits since ElementAccum is float.
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
            GmemLayoutAtomOaccum{},
            Layout<Shape < _1, _4>>{}
        )
    );  // Val layout, 4 vals per store
};

// Compile-time configuration of the backward kernel: tile extents, the three tiled MMAs
// (SdP, dKV, dQ), shared-memory layouts and byte sizes, and the tiled global-memory copies.
//
// Template parameters:
//   kHeadDim_          Head dimension; must be a multiple of 32 (static_assert below).
//   kBlockM_           Row extent of the Q/dO and dQ shared-memory tiles.
//   kBlockN_           Row extent of the K/V and dK/dV shared-memory tiles; must be >= 32.
//   kNWarps_           Number of warps; kNThreads is kNWarps_ * 32.
//   AtomLayoutMSdP_    Number of MMA atoms along the first mode of TiledMmaSdP.
//   AtomLayoutNdKV     Number of MMA atoms along the first mode of TiledMmadKV.
//   AtomLayoutMdQ      Number of MMA atoms along the first mode of TiledMmadQ.
//                      Each of the three must divide kNWarps_. Only AtomLayoutMSdP is
//                      re-exported as a static member.
//   Is_V_in_regs_      Selects the smaller of the two shared-memory size formulas below.
//   No_double_buffer_  Reduces the Q/dO buffer count from 3 to 2.
//   elem_type          Forwarded to Base; also used directly by several copy atoms below.
//   Base               Supplies Element, ElementAccum, index_t, Has_cp_async, MMA_Atom_Arch and
//                      the shared-memory copy atoms.
//
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<
    int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
    int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
    bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
    typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>
>
struct Flash_bwd_kernel_traits : public Base {
    // Re-export the architecture-dependent choices made by Base.
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    static constexpr bool No_double_buffer = No_double_buffer_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    // Column width of the swizzled smem atom: 64 if kHeadDim is a multiple of 64, otherwise 32.
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // Largest of {128, 64, 32} that divides kHeadDim. Not referenced again inside this struct.
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    // First Swizzle<> parameter for the Q/dO, K/V, dK/dV and dQ atoms.
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
    // Each atom-layout factor must split kNWarps evenly, because the second mode of every
    // tiled MMA below is kNWarps divided by that factor.
    static_assert(kNWarps % AtomLayoutMSdP == 0);
    static_assert(kNWarps % AtomLayoutNdKV == 0);
    static_assert(kNWarps % AtomLayoutMdQ == 0);

    // The three tiled MMAs share one pattern: an (A, kNWarps / A, 1) atom arrangement with a
    // (16 * A) x (16 * kNWarps / A) x 16 tile, where A is the respective atom-layout factor.
    using TiledMmaSdP = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>
    >;

    using TiledMmadKV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>
    >;

    using TiledMmadQ = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,    // 2x4x1 or 4x2x1 thread group
        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>
    >;

    // Row-major 8 x kBlockKSmem atom with Swizzle<kSwizzle, 3, 3>, tiled to kBlockM x kHeadDim.
    using SmemLayoutAtomQdO = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            Layout<Shape<_8, Int<kBlockKSmem>>,
            Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    using SmemLayoutQdO = decltype(
        tile_to_shape(
            SmemLayoutAtomQdO{},
            make_shape(Int<kBlockM>{}, Int<kHeadDim>{})
        )
    );

    // NOTE(review): unlike the other atoms, the row count here is kBlockM / kNWarps rather
    // than 8, and it is derived from kBlockM although the atom tiles a kBlockN-row buffer.
    using SmemLayoutAtomKV = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
            Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    using SmemLayoutKV = decltype(
        tile_to_shape(
            // SmemLayoutAtomQdO{},
            SmemLayoutAtomKV{},
            make_shape(Int<kBlockN>{}, Int<kHeadDim>{})
        )
    );

    // SmemLayoutKV composed with a row-major (kHeadDim, kBlockN) layout, i.e. the same storage
    // addressed with the two modes exchanged; the NoSwizzle alias drops the swizzle component.
    using SmemLayoutKtransposed = decltype(
        composition(
            SmemLayoutKV{},
            make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})
        )
    );
    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));

    // TODO: generalize to other values of kBlockN
    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
    // static constexpr int kPBlockN = kBlockN;
    // Temporarily disabling this for hdim 256 on sm86 and sm89
    // static_assert(kBlockN >= 64);
    static_assert(kBlockN >= 32);
    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
    // NOTE(review): the expression below only ever yields 64 or 32, so the kPBlockN == 16 case
    // mentioned above and allowed by the static_assert is never selected here.
    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
    static constexpr int kSwizzlePdS = 3;
    // Row-major kBlockM x kPBlockN atom, tiled to kBlockM x kBlockN. The resulting layout sizes
    // both kSmemBiasdSSize and kSmemPMaskSize below.
    using SmemLayoutAtomMaskBiasPdS = decltype(
        composition(
            Swizzle<kSwizzlePdS, 3, 3>{},
            Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
            Stride<Int<kPBlockN>, _1>>{}
        )
    );
    using SmemLayoutMaskBiasPdS = decltype(
        tile_to_shape(
            SmemLayoutAtomMaskBiasPdS{},
            make_shape(Int<kBlockM>{}, Int<kBlockN>{})
        )
    );
    // Same storage addressed as (kBlockN, kBlockM); the NoSwizzle alias drops the swizzle component.
    using SmemLayoutPdStransposed = decltype(
        composition(
            SmemLayoutMaskBiasPdS{},
            make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})
        )
    );
    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

    // NOTE(review): this and the other copy atoms below are parameterised on elem_type, not on
    // Element (which Flash_fwd_kernel_traits uses). The two differ whenever Base::Element is
    // not elem_type.
    using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

    // SmemLayoutQdO addressed as (kHeadDim, kBlockM); the NoSwizzle alias drops the swizzle component.
    using SmemLayoutQdOtransposed = decltype(
        composition(
            SmemLayoutQdO{},
            make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})
        )
    );
    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));

    // dK/dV tile: 8 x kBlockKSmem swizzled atom tiled to kBlockN x kHeadDim.
    using SmemLayoutAtomdKV = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            Layout<Shape<_8, Int<kBlockKSmem>>,
            Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    using SmemLayoutdKV = decltype(
        tile_to_shape(
            SmemLayoutAtomdKV{},
            make_shape(Int<kBlockN>{}, Int<kHeadDim>{})
        )
    );
    using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

    // dQ tile: 8 x kBlockKSmem swizzled atom tiled to kBlockM x kHeadDim.
    using SmemLayoutAtomdQ = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            Layout<Shape<_8, Int<kBlockKSmem>>,
            Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    using SmemLayoutdQ = decltype(
        tile_to_shape(
            SmemLayoutAtomdQ{},
            make_shape(Int<kBlockM>{}, Int<kHeadDim>{})
        )
    );
    using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

    // Shared-memory sizes, all in bytes.
    // Double buffer for sQ
    // Three SmemLayoutQdO-sized buffers by default, two when No_double_buffer is set.
    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
    // Two SmemLayoutKV-sized tiles -- presumably one for K and one for V; confirm against the kernel.
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemBiasdSSize = size(SmemLayoutMaskBiasPdS{}) * sizeof(Element);
    static constexpr int kSmemPMaskSize = size(SmemLayoutMaskBiasPdS{}) * sizeof(Element);
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    // Total: the Q/dO buffers plus
    //   - without Is_V_in_regs: KV + bias/dS + max(P/mask, dQ);
    //   - with Is_V_in_regs:    max(KV, KV / 2 + bias/dS + max(P/mask, dQ)).
    // The max() terms mean the P/mask and dQ regions are sized to overlap.
    static constexpr int kSmemSize = kSmemQdOSize
        + (
            !Is_V_in_regs
            ? kSmemKVSize + kSmemBiasdSSize + std::max(kSmemPMaskSize, kSmemdQSize)
            : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemBiasdSSize + std::max(kSmemPMaskSize, kSmemdQSize))
        );
    // Same formula as kSmemSize with the dQ term left out (only the P/mask tile is counted).
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (
            !Is_V_in_regs
            ? kSmemKVSize + kSmemBiasdSSize + kSmemPMaskSize
            : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemBiasdSSize + kSmemPMaskSize)
        );

    // Number of Elements that fit in one 128-bit (cute::uint128_t) access.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    // Row-major thread arrangement: (kNThreads / kGmemThreadsPerRow) rows of kGmemThreadsPerRow threads.
    using GmemLayoutAtom = Layout<
        Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
        Stride<Int<kGmemThreadsPerRow>, _1>
    >;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    // Falls back to a 128-bit auto-vectorizing copy when Has_cp_async is false.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    // NOTE(review): the value layouts of the copies below hard-code 8 values per thread, which
    // equals kGmemElemsPerLoad only for a 2-byte element type.
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(
            Copy_Atom<Gmem_copy_struct, elem_type>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per read
    // dO, dKV, dQ, Mask and Bias copies share one definition and never use Gmem_copy_struct.
    using GmemTiledCopydO = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
            GmemLayoutAtom{},
            Layout<Shape < _1, _8>>{}
        )
    );      // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
            GmemLayoutAtom{},
            Layout<Shape < _1, _8>>{}
        )
    );      // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
            GmemLayoutAtom{},
            Layout<Shape < _1, _8>>{}
        )
    );      // Val layout, 8 vals per store
    using GmemTiledCopyMask = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per read
    using GmemTiledCopyBias = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}
        )
    );      // Val layout, 8 vals per read
    // NOTE(review): both alternatives contain 256 threads (32 * 8 and 16 * 16) and do not
    // depend on kNThreads -- confirm this is intended when kNWarps != 8.
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_32, _8>, Stride< _8, _1>>,       // Thread layout, 8 threads per row
        Layout<Shape <_16, _16>, Stride< _16, _1>>      // Thread layout, 16 threads per row

    >;
    // Copy in ElementAccum precision; 4 values per thread is 128 bits since ElementAccum is float.
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
            GmemLayoutAtomdQaccum{},
            Layout<Shape < _1, _4>>{}
        )
    );      // Val layout, 4 vals per store

    // One ElementAccum value per thread, with threads arranged as 8 rows of 32 (256 threads).
    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
            Layout<Shape <_8, _32>,                     // Thread layout, 8 rows of 32 threads
            Stride<_32, _1>>{},
            Layout<Shape < _1, _1>>{}
        )
    );      // Val layout, 1 val per store
};

////////////////////////////////////////////////////////////////////////////////////////////////////