/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>
#include <cutlass/trace.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
    This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088).
    It is replicated here because the kernel-level APIs for mixed-dtype GEMMs and SmoothQuant had to
    be duplicated, and the newer device layer is not compatible with these older kernel-level APIs.

    Note: While CUTLASS 3.x supports stream-K, none of the kernels in the extensions folder support
          that feature at the moment.
  */

template <typename GemmKernel_>
class GemmUniversalBaseCompat {
 public:
  using GemmKernel = GemmKernel_;
  using ThreadblockShape = typename GemmKernel::Mma::Shape;

  using ElementA = typename GemmKernel::ElementA;
  using LayoutA = typename GemmKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = GemmKernel::kTransformA;

  using ElementB = typename GemmKernel::ElementB;
  using LayoutB = typename GemmKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = GemmKernel::kTransformB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename GemmKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
  using Operator = typename GemmKernel::Operator;

  /// Argument structure
  using Arguments = typename GemmKernel::Arguments;

 protected:
  /// Kernel parameters object
  typename GemmKernel::Params params_;

 protected:
  /// Private helper to obtain the grid dimensions with fix-up for split-K
  static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

    gemm_k_size = args.problem_size.k();

    if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
      int const kAlignK =
          const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

      gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

      if (gemm_k_size) {
        grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
      }
    }
  }

 public:
  /// Constructs the GEMM.
  GemmUniversalBaseCompat() {}

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const& args) {
    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    ThreadblockSwizzle threadblock_swizzle;
    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

    if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
      return Status::kErrorInvalidProblem;
    }

    return GemmKernel::can_implement(args);
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const& args) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

    size_t workspace_bytes = 0;

    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
      // Split-K parallel always requires a temporary workspace
      workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
    } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
      // Serial split-K only requires a temporary workspace if the number of partitions along the
      // GEMM K dimension is greater than one.
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    }

    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

    workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

    return workspace_bytes;
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const& args) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    CUTLASS_TRACE_HOST(
        "  grid_tiled_shape: " << grid_tiled_shape << "\n"
                               << "  result = {" << result << "}");

    return result;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

    int max_active_blocks = -1;
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

    if (smem_size <= (48 << 10)) {
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

      if (result == cudaSuccess) {
        CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
        return max_active_blocks;
      }
    } else {
      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST(
            "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));

        return -1;
      }

      if (smem_capacity < 0) {
        int device_idx = 0;
        result = cudaGetDevice(&device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        cudaDeviceProp properties;
        result = cudaGetDeviceProperties(&properties, device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
      }

      int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

      CUTLASS_TRACE_HOST("  occupancy: " << occupancy);

      return occupancy;
    }

    CUTLASS_TRACE_HOST("  returning internal error");

    return -1;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST(
        "GemmUniversalBaseCompat::initialize() - workspace " << workspace
                                                             << ", stream: " << (stream ? "non-null" : "null"));

    size_t workspace_bytes = get_workspace_size(args);

    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

    if (workspace_bytes) {
      if (!workspace) {
        CUTLASS_TRACE_HOST("  error: device workspace must not be null");

        return Status::kErrorWorkspaceNull;
      }

      if (args.mode == GemmUniversalMode::kGemm) {
        CUTLASS_TRACE_HOST("  clearing device workspace");
        cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

        if (result != cudaSuccess) {
          CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));

          return Status::kErrorInternal;
        }
      }
    }

    // Get CUDA grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    // Initialize the Params structure
    params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));

    // Specify shared memory capacity for kernel.
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      cudaError_t result =
          cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes && !workspace) {
      return Status::kErrorWorkspaceNull;
    }

    params_.update(args, workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

    //
    // Configure grid and block dimensions
    //

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    //
    // Launch kernel
    //

    CUTLASS_TRACE_HOST("  grid: (" << grid << "),  block: (" << block << "),  SMEM: " << smem_size << " bytes");

    // Launch
    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    //
    // Query for errors
    //
    cudaError_t result = cudaGetLastError();

    if (result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
