/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass. It has been modified to support either
// row/column or scalar broadcasting, as is already supported in CUTLASS 3.x.
// This is important because it reduces the number of kernels compiled by a
// factor of 4.
//
#pragma once

// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

// clang-format on

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;

template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  // Host-side arguments for this visitor node.
  //   ptr_row      - row vector to load and broadcast. When nullptr, the
  //                  scalar `null_default` is broadcast to every in-bounds
  //                  element instead.
  //   null_default - fill value, used only when ptr_row is nullptr.
  //   dRow         - (M, N, L) stride of the row vector in global memory
  //                  (presumably zero along M for a row broadcast - confirm).
  struct Arguments {
    Element const* ptr_row = nullptr;
    Element null_default = Element(0);
    StrideMNL dRow = {};
  };

  // No preprocessing: the params are the arguments verbatim.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no global-memory workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  // This visitor needs no shared storage.
  struct SharedStorage {};

  // Global load type: an unsigned integer covering one access of
  // kElementsPerAccess elements, capped at 128 bits. VecLength is the number
  // of Elements packed into one VecType.
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  // Initializes params_ptr so a default-constructed visitor does not carry an
  // indeterminate pointer.
  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() : params_ptr(nullptr) { }

  // Stores the address of `params`. `params` must outlive this visitor and
  // every Callbacks object created from it.
  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  // Per-thread state created by get_callbacks():
  //   tC_gRow - this thread's VecType-wide view of the row vector
  //   tC_rRow - thread-local fragment shaped like tC_gRow (make_tensor_like)
  //   tC_cRow - (m, n) coordinates of the elements, used for bounds checks
  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    // Initializers are listed in member-declaration order, which is the order
    // C++ actually initializes them in (avoids -Wreorder).
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      params_ptr(params_ptr),
      n(get<1>(problem_shape)) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;  // problem extent along N; column coordinates must be < n

    // Populates tC_rRow (modified from VisitorRowBroadcast). Presumably
    // invoked once by the epilogue before any visit() call - confirm against
    // the CUTLASS EVT driver.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero everything first, so vectors whose column is out of bounds hold 0.
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row) {
        // Row vector present: predicated vector loads from global memory.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // No row vector: broadcast the scalar null_default. Build one
        // VecType pattern holding VecLength copies of the scalar...
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
        }

        // ...then store it into every in-bounds vector slot.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    // Returns the FragmentSize-wide slice of the loaded row fragment at
    // column_idx. The accumulator values are not read.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Partitions the row vector and its coordinate tensor for this thread and
  // threadblock tile, and returns the per-thread Callbacks.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    // Building the tensor view does not dereference ptr_row, so a null
    // ptr_row is safe here.
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Coordinate tensor for predication: one (m, n) coordinate per VecType
    // access (the first element of each group of VecLength elements).
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector or scalar broadcast
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  // Host-side arguments for this visitor node.
  //   ptr_col      - column vector to load and broadcast. When nullptr, the
  //                  scalar `null_default` is broadcast to every in-bounds
  //                  element instead.
  //   null_default - fill value, used only when ptr_col is nullptr.
  //   dCol         - (M, N, L) stride of the column vector in global memory.
  //                  The default (1, 0, 0) is unit stride along M.
  struct Arguments {
    Element const* ptr_col = nullptr;
    Element null_default = Element(0);
    StrideMNL dCol = {};
  };

  // No preprocessing: the params are the arguments verbatim.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no global-memory workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  // This visitor needs no shared storage.
  struct SharedStorage { };

  // Global load type: an unsigned integer covering one access of
  // kElementsPerAccess elements, capped at 128 bits. VecLength is the number
  // of Elements packed into one VecType.
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  // Initializes params_ptr so a default-constructed visitor does not carry an
  // indeterminate pointer.
  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() : params_ptr(nullptr) { }

  // Stores the address of `params`. `params` must outlive this visitor and
  // every Callbacks object created from it.
  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  // Per-thread state created by get_callbacks():
  //   tC_gCol - this thread's view of the column vector,
  //             (FRAGMENT_ROW, (ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER))
  //   tC_rCol - thread-local fragment shaped like tC_gCol (make_tensor_like)
  //   tC_cCol - (m, n) coordinates of the elements, used for bounds checks
  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    // Initializers are listed in member-declaration order, which is the order
    // C++ actually initializes them in (avoids -Wreorder).
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      params_ptr(params_ptr),
      m(get<0>(problem_shape)) { }

    GTensor tC_gCol;
    RTensor tC_rCol;
    CTensor tC_cCol;
    Params const* params_ptr;
    int m;  // problem extent along M; row coordinates must be < m

    // Populates tC_rCol (modified from VisitorColBroadcast). Presumably
    // invoked once by the epilogue before any visit() call - confirm against
    // the CUTLASS EVT driver.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero everything first, so out-of-bounds rows hold 0.
      clear(tC_rCol);

      // pred(i) is true when the row coordinate of logical element i is < m.
      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->ptr_col) {
        // Column vector present: predicated copy from global memory.
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // No column vector: broadcast the scalar null_default. Index tC_rCol
        // with the same logical index as pred, as copy_if does. Going through
        // filter(tC_rCol) could drop stride-0 modes and misalign the two.
        Element const fill = params_ptr->null_default;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(pred); ++i) {
          if (pred(i)) {
            tC_rCol(i) = fill;
          }
        }
      }
    }

    // Returns a FragmentSize-wide Array filled with the column value for
    // (row_idx, iter_idx). The accumulator values are not read.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  // Partitions the column vector and its coordinate tensor for this thread
  // and threadblock tile, and returns the per-thread Callbacks.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    // Building the tensor view does not dereference ptr_col, so a null
    // ptr_col is safe here.
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
    // The column-varying modes are sliced away, and the three ITERATION modes
    // are grouped into one, giving (FRAGMENT_ROW, ITERATION).
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Coordinate tensor for predication, partitioned identically.
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};

}
