#pragma once

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_quant.h"
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue selector that wires the pieces chosen by DefaultEpilogueTensorOp
/// into the EpilogueQuantMx threadblock epilogue.
///
/// The first seven template parameters are passed through verbatim to
/// DefaultEpilogueTensorOp, which is also the public base of this struct.
/// `is_quartet` and `RotationSize` are not interpreted here; they are handed
/// unchanged to EpilogueQuantMx as its last two template arguments.
template <typename Shape_,
          typename WarpMmaTensorOp_,
          int PartitionsK,
          typename OutputOp_,
          int ElementsPerAccess,
          bool ScatterD = false,
          typename PermuteDLayout = layout::NoPermute,
          bool is_quartet = true,
          int RotationSize = 32>
struct DefaultEpilogueTensorOpQuantMx
    : DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout> {
  /// Same instantiation as the base class, named so the component types
  /// below can be looked up through it.
  using DefaultEpilogueTensorOp =
      DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout>;

  /// Output operator, exactly as supplied by the caller.
  using OutputOp = OutputOp_;

  /// The resulting epilogue: every component comes from
  /// DefaultEpilogueTensorOp except the output operator and the two trailing
  /// pass-through parameters.
  using Epilogue = cutlass::epilogue::threadblock::EpilogueQuantMx<
      typename DefaultEpilogueTensorOp::Shape,
      typename DefaultEpilogueTensorOp::WarpMmaTensorOp,
      DefaultEpilogueTensorOp::kPartitionsK,
      typename DefaultEpilogueTensorOp::OutputTileIterator,
      typename DefaultEpilogueTensorOp::AccumulatorFragmentIterator,
      typename DefaultEpilogueTensorOp::WarpTileIterator,
      typename DefaultEpilogueTensorOp::SharedLoadIterator,
      OutputOp,
      typename DefaultEpilogueTensorOp::Padding,
      DefaultEpilogueTensorOp::kFragmentsPerIteration,
      is_quartet,
      RotationSize>;
};

/// Epilogue selector that wires the pieces chosen by DefaultEpilogueTensorOp
/// into the EpilogueQuantWushMx threadblock epilogue.
///
/// The first seven template parameters are passed through verbatim to
/// DefaultEpilogueTensorOp, which is also the public base of this struct.
/// `is_quartet` and `RotationSize` are not interpreted here; they are handed
/// unchanged to EpilogueQuantWushMx as its last two template arguments.
template <typename Shape_,
          typename WarpMmaTensorOp_,
          int PartitionsK,
          typename OutputOp_,
          int ElementsPerAccess,
          bool ScatterD = false,
          typename PermuteDLayout = layout::NoPermute,
          bool is_quartet = true,
          int RotationSize = 32>
struct DefaultEpilogueTensorOpQuantWushMx
    : DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout> {
  /// Same instantiation as the base class, named so the component types
  /// below can be looked up through it.
  using DefaultEpilogueTensorOp =
      DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout>;

  /// Output operator, exactly as supplied by the caller.
  using OutputOp = OutputOp_;

  /// The resulting epilogue: every component comes from
  /// DefaultEpilogueTensorOp except the output operator and the two trailing
  /// pass-through parameters.
  using Epilogue = cutlass::epilogue::threadblock::EpilogueQuantWushMx<
      typename DefaultEpilogueTensorOp::Shape,
      typename DefaultEpilogueTensorOp::WarpMmaTensorOp,
      DefaultEpilogueTensorOp::kPartitionsK,
      typename DefaultEpilogueTensorOp::OutputTileIterator,
      typename DefaultEpilogueTensorOp::AccumulatorFragmentIterator,
      typename DefaultEpilogueTensorOp::WarpTileIterator,
      typename DefaultEpilogueTensorOp::SharedLoadIterator,
      OutputOp,
      typename DefaultEpilogueTensorOp::Padding,
      DefaultEpilogueTensorOp::kFragmentsPerIteration,
      is_quartet,
      RotationSize>;
};

/// Epilogue selector that wires the pieces chosen by DefaultEpilogueTensorOp
/// into the EpilogueQuantMxMask threadblock epilogue.
///
/// All seven template parameters are passed through verbatim to
/// DefaultEpilogueTensorOp, which is also the public base of this struct.
/// Unlike the sibling selectors in this header, no extra parameters are
/// forwarded to the epilogue template.
template <typename Shape_,
          typename WarpMmaTensorOp_,
          int PartitionsK,
          typename OutputOp_,
          int ElementsPerAccess,
          bool ScatterD = false,
          typename PermuteDLayout = layout::NoPermute>
struct DefaultEpilogueTensorOpQuantMxMask
    : DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout> {
  /// Same instantiation as the base class, named so the component types
  /// below can be looked up through it.
  using DefaultEpilogueTensorOp =
      DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout>;

  /// Output operator, exactly as supplied by the caller.
  using OutputOp = OutputOp_;

  /// The resulting epilogue: every component comes from
  /// DefaultEpilogueTensorOp except the output operator.
  using Epilogue = cutlass::epilogue::threadblock::EpilogueQuantMxMask<
      typename DefaultEpilogueTensorOp::Shape,
      typename DefaultEpilogueTensorOp::WarpMmaTensorOp,
      DefaultEpilogueTensorOp::kPartitionsK,
      typename DefaultEpilogueTensorOp::OutputTileIterator,
      typename DefaultEpilogueTensorOp::AccumulatorFragmentIterator,
      typename DefaultEpilogueTensorOp::WarpTileIterator,
      typename DefaultEpilogueTensorOp::SharedLoadIterator,
      OutputOp,
      typename DefaultEpilogueTensorOp::Padding,
      DefaultEpilogueTensorOp::kFragmentsPerIteration>;
};

/// Epilogue selector that wires the pieces chosen by DefaultEpilogueTensorOp
/// into the EpilogueQuantNv threadblock epilogue.
///
/// The first seven template parameters are passed through verbatim to
/// DefaultEpilogueTensorOp, which is also the public base of this struct.
/// `is_quartet` and `RotationSize` are not interpreted here; they are handed
/// unchanged to EpilogueQuantNv as its last two template arguments. Note that
/// `RotationSize` defaults to 16 in this selector.
template <typename Shape_,
          typename WarpMmaTensorOp_,
          int PartitionsK,
          typename OutputOp_,
          int ElementsPerAccess,
          bool ScatterD = false,
          typename PermuteDLayout = layout::NoPermute,
          bool is_quartet = true,
          int RotationSize = 16>
struct DefaultEpilogueTensorOpQuantNv
    : DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout> {
  /// Same instantiation as the base class, named so the component types
  /// below can be looked up through it.
  using DefaultEpilogueTensorOp =
      DefaultEpilogueTensorOp<Shape_, WarpMmaTensorOp_, PartitionsK, OutputOp_,
                              ElementsPerAccess, ScatterD, PermuteDLayout>;

  /// Output operator, exactly as supplied by the caller.
  using OutputOp = OutputOp_;

  /// The resulting epilogue: every component comes from
  /// DefaultEpilogueTensorOp except the output operator and the two trailing
  /// pass-through parameters.
  ///
  /// TODO(review): the original note here read "remove/add?" -- it appears to
  /// question whether `is_quartet` / `RotationSize` belong in this argument
  /// list; confirm against EpilogueQuantNv before changing.
  using Epilogue = cutlass::epilogue::threadblock::EpilogueQuantNv<
      typename DefaultEpilogueTensorOp::Shape,
      typename DefaultEpilogueTensorOp::WarpMmaTensorOp,
      DefaultEpilogueTensorOp::kPartitionsK,
      typename DefaultEpilogueTensorOp::OutputTileIterator,
      typename DefaultEpilogueTensorOp::AccumulatorFragmentIterator,
      typename DefaultEpilogueTensorOp::WarpTileIterator,
      typename DefaultEpilogueTensorOp::SharedLoadIterator,
      OutputOp,
      typename DefaultEpilogueTensorOp::Padding,
      DefaultEpilogueTensorOp::kFragmentsPerIteration,
      is_quartet,
      RotationSize>;
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace epilogue
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
