/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/array.h>
#include <cutlass/numeric_conversion.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

// NOTE(review): this `cute` namespace block is empty and declares nothing;
// confirm whether it is a leftover placeholder that can be removed.
namespace cute {

}  // namespace cute

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool A_in_regs = false,
         bool B_in_regs = false,
         typename Tensor0,
         typename Tensor1,
         typename Tensor2,
         typename Tensor3,
         typename Tensor4,
         typename TiledMma,
         typename TiledCopyA,
         typename TiledCopyB,
         typename ThrCopyA,
         typename ThrCopyB>
inline __device__ void gemm(Tensor0&       acc,
                            Tensor1&       tCrA,
                            Tensor2&       tCrB,
                            Tensor3 const& tCsA,
                            Tensor4 const& tCsB,
                            TiledMma       tiled_mma,
                            TiledCopyA     smem_tiled_copy_A,
                            TiledCopyB     smem_tiled_copy_B,
                            ThrCopyA       smem_thr_copy_A,
                            ThrCopyB       smem_thr_copy_B)
{
    // acc += A * B, one K slice (mode 2 of tCrA / tCrB) at a time.
    // A_in_regs / B_in_regs: when true, that operand's register fragment is assumed to be
    // filled already and is never loaded from tCsA / tCsB.

    // Mode 1 of A/B must match the M/N modes of the accumulator; A and B must agree on K.
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K

    // Re-tiled views of the register fragments in the shape the smem copies write into.
    Tensor rA_dst = smem_thr_copy_A.retile_D(tCrA);
    Tensor rB_dst = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(rA_dst));  // M
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(rB_dst));  // N

    // Prologue: load K slice 0 of every operand that is not already resident in registers.
    if constexpr (!A_in_regs) {
        cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), rA_dst(_, _, _0{}));
    }
    if constexpr (!B_in_regs) {
        cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), rB_dst(_, _, _0{}));
    }

    auto const num_k = size<2>(tCrA);
#pragma unroll
    for (int k = 0; k < num_k; ++k) {
        // The copy for slice k + 1 is issued before the MMA on slice k
        // (presumably so the smem load latency overlaps with the math).
        int const k_next = k + 1;
        if (k_next < num_k) {
            if constexpr (!A_in_regs) {
                cute::copy(smem_tiled_copy_A, tCsA(_, _, k_next), rA_dst(_, _, k_next));
            }
            if constexpr (!B_in_regs) {
                cute::copy(smem_tiled_copy_B, tCsB(_, _, k_next), rB_dst(_, _, k_next));
            }
        }
        cute::gemm(tiled_mma, tCrA(_, _, k), tCrB(_, _, k), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0,
         typename Tensor1,
         typename Tensor2,
         typename Tensor3,
         typename TiledMma,
         typename TiledCopy,
         typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0&       acc,
                                      Tensor1&       tCrA,
                                      Tensor2&       tCrB,
                                      Tensor3 const& tCsB,
                                      TiledMma       tiled_mma,
                                      TiledCopy      smem_tiled_copy_B,
                                      ThrCopy        smem_thr_copy_B)
{
    // acc += A * B where the whole A fragment (tCrA) is already in registers; only B is
    // streamed from tCsB into tCrB, one K slice (mode 2) at a time.
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K

    // View of tCrB in the shape the smem copy writes into.
    Tensor rB_dst = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(rB_dst));  // N

    auto const num_k = size<2>(tCrA);

    // Prologue: bring in B's first K slice.
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), rB_dst(_, _, _0{}));
#pragma unroll
    for (int k = 0; k < num_k; ++k) {
        // Load the next B slice (if any) before running the MMA on the current one.
        if (k + 1 < num_k) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, k + 1), rB_dst(_, _, k + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, k), tCrB(_, _, k), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until at most N of the most recently committed cp.async groups are still pending,
// i.e. every older group has completed (PTX `cp.async.wait_group N`).
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// When CUTE_ARCH_CP_ASYNC_SM80_ENABLED is not defined this compiles to nothing.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template<int N>
CUTE_HOST_DEVICE void cp_async_wait()
{
    // wait_group takes a non-negative integer immediate; reject a bad N at compile time
    // instead of surfacing it as an assembler error.
    static_assert(N >= 0, "cp_async_wait<N>: N must be non-negative");
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Copies S into D one (m, k) slice at a time, with optional bounds predication on the
// M/N mode (mode 1) and the K mode (mode 2).
//
//   thr_copy    : copy object passed to cute::copy for every slice (its type is TiledCopy).
//   S, D        : rank-3 tensors of identical shape; mode 0 is the group of elements moved by
//                 one cute::copy call, mode 1 is indexed by m, mode 2 by k.
//   identity_MN : slice m is in bounds iff get<0>(identity_MN(0, m, 0)) < max_MN
//                 (presumably an identity/coordinate tensor holding the row coordinate).
//   predicate_K : slice k is in bounds iff predicate_K(k) is true.
//   max_MN      : M/N bound; only read when Is_even_MN is false.
//
// Template flags:
//   Is_even_MN / Is_even_K     : skip the corresponding bounds check (treat all in bounds).
//   Clear_OOB_MN / Clear_OOB_K : apply cute::clear to out-of-bounds slices of D; when the
//                                flag is false those slices of D are left untouched.
//                                Clear_OOB_MN requires Clear_OOB_K.
template<bool Is_even_MN   = true,
         bool Is_even_K    = true,
         bool Clear_OOB_MN = false,
         bool Clear_OOB_K  = true,
         typename TiledCopy,
         typename Engine0,
         typename Layout0,
         typename Engine1,
         typename Layout1,
         typename Engine2,
         typename Layout2,
         typename Engine3,
         typename Layout3>
inline __device__ void copy(TiledCopy                       thr_copy,
                            Tensor<Engine0, Layout0> const& S,
                            Tensor<Engine1, Layout1>&       D,
                            Tensor<Engine2, Layout2> const& identity_MN,
                            Tensor<Engine3, Layout3> const& predicate_K,
                            int                             max_MN = 0)
{
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // CPY (elements per copy call)
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // CPY_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // CPY_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    // Qualified so overload resolution can never land back on flash::copy.
                    cute::copy(thr_copy, S(_, m, k), D(_, m, k));
                }
                else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
        else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Binary "maximum" functor for reductions such as Allreduce<...>::run.
// Returns x when x > y, otherwise y.
template<typename T>
struct MaxOp {
    // const-qualified so the functor can also be invoked through a const reference/object.
    __device__ inline T operator()(T const& x, T const& y) const
    {
        return x > y ? x : y;
    }
};

// float specialization of MaxOp: calls `max` instead of the generic compare-and-select.
// NOTE(review): NaN handling may differ from the primary template's `x > y ? x : y`;
// confirm if NaN inputs can reach this reduction.
template<>
struct MaxOp<float> {
    // This is slightly faster.
    // const-qualified so the functor can also be invoked through a const reference/object.
    __device__ inline float operator()(float const& x, float const& y) const
    {
        return max(x, y);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Binary addition functor for reductions such as Allreduce<...>::run; returns x + y.
template<typename T>
struct SumOp {
    // const-qualified so the functor can also be invoked through a const reference/object.
    __device__ inline T operator()(T const& x, T const& y) const
    {
        return x + y;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// XOR-butterfly all-reduce over aligned groups of THREADS consecutive lanes of a warp.
// After run(), every lane in a group holds op applied across the group's values.
// The shuffles use the full-warp mask, so all 32 lanes of the warp must call run().
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op)
    {
        // Lane distances THREADS/2, THREADS/4, ..., 1; each step folds in the partner's value.
#pragma unroll
        for (int lane_mask = THREADS / 2; lane_mask >= 1; lane_mask >>= 1) {
            x = op(x, __shfl_xor_sync(uint32_t(-1), x, lane_mask));
        }
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Base case of the all-reduce: combine each lane with its neighbour (lane ^ 1).
// Uses the full-warp shuffle mask, so all 32 lanes must participate.
template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op)
    {
        T const partner = __shfl_xor_sync(uint32_t(-1), x, 1);
        return op(x, partner);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)).
// Only the layout is rearranged; the result describes the same offsets as acc_layout.
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout)
{
    static_assert(decltype(rank(acc_layout))::value == 3);
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    // Split the 4-element mode 0 into (2, 2): ((2, 2), MMA_M, MMA_N).
    auto const split = logical_divide(acc_layout, Shape<_2>{});
    // Row mode: second factor of the split, then MMA_M.
    auto const row_mode = make_layout(get<0, 1>(split), get<1>(split));
    // Column mode: first factor of the split, then MMA_N.
    auto const col_mode = make_layout(get<0, 0>(split), get<2>(split));
    return make_layout(row_mode, col_mode);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
// MMA_traits must expose Shape_MNK; its K extent (8 or 16) selects the divisor D below:
// for K = 16 pairs of adjacent MMA_N blocks are folded into mode 0 (D = 2), for K = 8 none are (D = 1).
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout)
{
    using X = Underscore;  // "keep this mode whole" in the divide tiler
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
    // ((2, MMA_M), (2, (D, MMA_N / D))) with D = MMA_N_divisor
    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});
    // ((2 col, 2 row, D), MMA_M, MMA_N / D)
    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), get<0, 1>(l), get<1, 1, 1>(l));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise converts `tensor` to To_type, returning a tensor with the same layout.
//
// Precondition (HACK): `tensor` must be "contiguous" — its data() must point at numel
// consecutive elements covering the whole tensor (cosize == size, enforced below), because
// the conversion reinterprets that storage as one flat cutlass::Array. Its layout must be
// static so the result can be allocated as an owning register tensor.
//
// The result owns its storage. Returning a view of a function-local array would dangle once
// this function returns; it would only work if the call happened to be inlined.
template<typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const& tensor)
{
    using From_type     = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(decltype(cosize(tensor.layout()))::value == numel,
                  "convert_type requires a contiguous (compact) layout");
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto const& src = *reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data());
    // Same layout as the input, so flat storage index i of `out` holds the conversion of
    // flat storage index i of `tensor`, and out(coord) corresponds to tensor(coord).
    Tensor out = make_tensor<To_type>(tensor.layout());
    *reinterpret_cast<cutlass::Array<To_type, numel>*>(out.data()) = convert_op(src);
    return out;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
