/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
   \brief Helper functions for mapping CUTLASS concepts to cuDNN.
*/
#if CUTLASS_ENABLE_CUDNN

#include <stdexcept>

#include "cudnn_helpers.h"

namespace cutlass {
namespace profiler {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Converts a cuDNN status to cutlass::Status
Status get_cutlass_status(cudnnStatus_t cudnn_status) {

  if (cudnn_status == CUDNN_STATUS_SUCCESS) {
    return Status::kSuccess;
  }
  else if (cudnn_status == CUDNN_STATUS_INVALID_VALUE) {
    return Status::kErrorInvalidProblem;
  }
  if (cudnn_status == CUDNN_STATUS_NOT_SUPPORTED) {
    return Status::kErrorNotSupported;
  }
  return Status::kErrorInternal;
}

/// Converts a cuDNN status to cutlass::profiler::Disposition
Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status) {

  if (cudnn_status == CUDNN_STATUS_INVALID_VALUE) {
    return Disposition::kInvalidProblem;
  }
  else if (cudnn_status == CUDNN_STATUS_NOT_SUPPORTED) {
    return Disposition::kNotSupported;
  }
  return Disposition::kFailed;
}

/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception
Status checkCudnnErr(cudnnStatus_t cudnn_status) {
  Status cutlass_status = get_cutlass_status(cudnn_status);
  if(cutlass_status != Status::kSuccess) {
    throw std::runtime_error("checkCudnnErr failed");
  }
  return cutlass_status;
}

/// Maps a CUTLASS conv mode to a cuDNN cudnnConvolutionMode_t
bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode) {
  switch (conv_mode) {
    case conv::Mode::kCrossCorrelation:
      cudnn_conv_mode = CUDNN_CROSS_CORRELATION;
      return true;
    case conv::Mode::kConvolution:
      cudnn_conv_mode = CUDNN_CONVOLUTION;
      return true;
    default: break;
  }
  return false;
}

/// Maps a CUTLASS tensor layout to a cuDNN cudnnTensorFormat_t
bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout) {
  switch (layout) {
    // cudnn uses the same enum for TensorNC*HW along nDim (ConvDescription::conv_dim)
    case library::LayoutTypeID::kTensorNCHW:
    case library::LayoutTypeID::kTensorNCDHW:
      cudnn_layout = CUDNN_TENSOR_NCHW;
      return true;
    case library::LayoutTypeID::kTensorNHWC:
    case library::LayoutTypeID::kTensorNDHWC:
      cudnn_layout = CUDNN_TENSOR_NHWC;
      return true;
    default: break;
  }
  return false;
}

/// Maps a CUTLASS numeric type to a cuDNN cudnnDataType_t
bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type) {
  switch (element_type) {
    case library::NumericTypeID::kF16:
      cudnn_element_type = CUDNN_DATA_HALF;
      return true;

    case library::NumericTypeID::kF32:
      cudnn_element_type = CUDNN_DATA_FLOAT;
      return true;
    
    case library::NumericTypeID::kF64: 
      cudnn_element_type = CUDNN_DATA_DOUBLE;
      return true;
    
    case library::NumericTypeID::kS2: 
      break;
  
    case library::NumericTypeID::kS4: 
      break;
  
    case library::NumericTypeID::kS8: 
      cudnn_element_type = CUDNN_DATA_INT8;
      return true;

    case library::NumericTypeID::kS16: 
      break;
 
    case library::NumericTypeID::kS32: 
      cudnn_element_type = CUDNN_DATA_INT32;
      return true;

    case library::NumericTypeID::kS64: 
      break;

    case library::NumericTypeID::kU2: 
      break;
  
    case library::NumericTypeID::kU4: 
      break;
  
    case library::NumericTypeID::kU8: 
      cudnn_element_type = CUDNN_DATA_UINT8;
      return true;

    case library::NumericTypeID::kU16: 
      break;
    
    case library::NumericTypeID::kU32: 
      break;
    
    case library::NumericTypeID::kU64: 
      break;

    case library::NumericTypeID::kB1: 
      break;
  
    case library::NumericTypeID::kInvalid:
  
    default: 
      break;
  }

  return false;
}

/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type
bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc) {

  switch (conv_desc.tile_description.math_instruction.opcode_class) {

    case library::OpcodeClassID::kTensorOp:
    {
      cudnn_math_type = CUDNN_TENSOR_OP_MATH;

      library::MathOperationID math_op = conv_desc.tile_description.math_instruction.math_operation;
      
      // Allow conversion on input data type for fast math operations
      if (math_op == library::MathOperationID::kMultiplyAddFastF16 || 
        math_op == library::MathOperationID::kMultiplyAddFastBF16) 
      {
        cudnn_math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
      }

      return true;
    }
    case library::OpcodeClassID::kSimt:
      return false;
  }

  return false;
}

/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue)
float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src) {

  switch (type) {
    case library::NumericTypeID::kF16:
    {
      return float(*(static_cast<half_t const*>(src)));
    }
    case library::NumericTypeID::kF32:
    {
      return float(*(static_cast<float const*>(src)));
    }
    case library::NumericTypeID::kS32:
    {
      return float(*(static_cast<int const*>(src)));
    }
    default:
      throw std::runtime_error("Data type handled in cast_compute_type_to_float");
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns a status if cuDNN can satisfy a particular Conv2d description
Status cudnn_satisfies(
  library::ConvDescription const &desc, 
  library::Conv2dConfiguration const &configuration) {
  
  auto const &a_tensor = desc.A;
  auto const &b_tensor = desc.B;
  auto const &c_tensor = desc.C;
  auto const &math_instruction = desc.tile_description.math_instruction;

  if(a_tensor.element != b_tensor.element) {
    return Status::kErrorInvalidDataType;
  }

  ////////////////////////  Convolution output dimensions p and q ///////////////////////
  // Cutlass convolutions support arbitrary output dimensions and not constriant by    //
  // input, filter, padding, striding, dilation sizes.                                 //
  // cuDNN sets the output dimensions (p, q) using following equations:                //
  //                                                                                   //
  // output = div_up(input + 2 * pad - ((filter - 1) * dilation + 1) + 1, stride)      //
  // where; div_up(a, b) : (a - 1)/b + 1                                               //
  //                                                                                   //
  // Before launching cudnn verification or profiling check that output p and q        //
  // dimensions are cuDNN compliant.                                                   //
  //                                                                                   //
  // If user sets output p and q which do not follow above constraints, cutlass conv,  //
  // host reference, device reference can run. However, cudnn convolution returns      //
  // "Invalid problem"                                                                 //
  //                                                                                   //
  ///////////////////////////////////////////////////////////////////////////////////////

  // check conv output dimension p for cudnn
  int cudnn_output_p = 
  (
    (
      configuration.problem_size.H + 
      2 * configuration.problem_size.pad_h - 
      ((configuration.problem_size.R - 1) * 
      configuration.problem_size.dilation_h + 1)
    ) / 
    (configuration.problem_size.stride_h) 
    + 1
  );

  if (cudnn_output_p != configuration.problem_size.P) {
    return Status::kErrorInvalidProblem;
  }

  // check conv output dimension q for cudnn
  int cudnn_output_q = 
  (
    (
      configuration.problem_size.W + 
      2 * configuration.problem_size.pad_w - 
      ((configuration.problem_size.S - 1) * 
      configuration.problem_size.dilation_w + 1)
    ) / 
    (configuration.problem_size.stride_w) 
    + 1
  );

  if (cudnn_output_q != configuration.problem_size.Q) {
    return Status::kErrorInvalidProblem;
  }
  //////////////////////////////////////////////////////////////////////////////////////

  // conv operator with input=FP16, accumulator=FP32, output=FP32 datatype 
  if (a_tensor.element ==  library::NumericTypeID::kF16 && 
      b_tensor.element ==  library::NumericTypeID::kF16 &&
      math_instruction.element_accumulator == library::NumericTypeID::kF32 &&
      c_tensor.element == library::NumericTypeID::kF32
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kBF16 || 
      b_tensor.element ==  library::NumericTypeID::kBF16 ||
      c_tensor.element == library::NumericTypeID::kBF16
      ) {

    return Status::kErrorNotSupported;
  }

  // TF32 input not supported in cuDNN
  if (a_tensor.element ==  library::NumericTypeID::kTF32 || 
      b_tensor.element ==  library::NumericTypeID::kTF32 ||
      c_tensor.element == library::NumericTypeID::kTF32
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kS8 || 
      b_tensor.element ==  library::NumericTypeID::kS8 ||
      c_tensor.element == library::NumericTypeID::kS8
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kU8 || 
      b_tensor.element ==  library::NumericTypeID::kU8 ||
      c_tensor.element == library::NumericTypeID::kU8
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kS4 || 
      b_tensor.element ==  library::NumericTypeID::kS4 ||
      c_tensor.element == library::NumericTypeID::kS4
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kU4 || 
      b_tensor.element ==  library::NumericTypeID::kU4 ||
      c_tensor.element == library::NumericTypeID::kU4
      ) {

    return Status::kErrorNotSupported;
  }

  return Status::kSuccess;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Returns a status if cuDNN can satisfy a particular Conv3d description
Status cudnn_satisfies(
  library::ConvDescription const &desc, 
  library::Conv3dConfiguration const &configuration) {
  
  auto const &a_tensor = desc.A;
  auto const &b_tensor = desc.B;
  auto const &c_tensor = desc.C;
  auto const &math_instruction = desc.tile_description.math_instruction;

  if(a_tensor.element != b_tensor.element) {
    return Status::kErrorInvalidDataType;
  }

  ////////////////////////  Convolution output dimensions p and q ///////////////////////
  // Cutlass convolutions support arbitrary output dimensions and not constriant by    //
  // input, filter, padding, striding, dilation sizes.                                 //
  // cuDNN sets the output dimensions (p, q) using following equations:                //
  //                                                                                   //
  // output = div_up(input + 2 * pad - ((filter - 1) * dilation + 1) + 1, stride)      //
  // where; div_up(a, b) : (a - 1)/b + 1                                               //
  //                                                                                   //
  // Before launching cudnn verification or profiling check that output p and q        //
  // dimensions are cuDNN compliant.                                                   //
  //                                                                                   //
  // If user sets output p and q which do not follow above constraints, cutlass conv,  //
  // host reference, device reference can run. However, cudnn convolution returns      //
  // "Invalid problem"                                                                 //
  //                                                                                   //
  ///////////////////////////////////////////////////////////////////////////////////////

  // check conv output dimension z for cudnn
  int cudnn_output_z = 
  (
    (
      configuration.problem_size.D + 
      2 * configuration.problem_size.pad_d - 
      ((configuration.problem_size.T - 1) * 
      configuration.problem_size.dilation_d + 1)
    ) / 
    (configuration.problem_size.stride_d) 
    + 1
  );

  if (cudnn_output_z != configuration.problem_size.Z) {
    return Status::kErrorInvalidProblem;
  }

  // check conv output dimension p for cudnn
  int cudnn_output_p = 
  (
    (
      configuration.problem_size.H + 
      2 * configuration.problem_size.pad_h - 
      ((configuration.problem_size.R - 1) * 
      configuration.problem_size.dilation_h + 1)
    ) / 
    (configuration.problem_size.stride_h) 
    + 1
  );

  if (cudnn_output_p != configuration.problem_size.P) {
    return Status::kErrorInvalidProblem;
  }

  // check conv output dimension q for cudnn
  int cudnn_output_q = 
  (
    (
      configuration.problem_size.W + 
      2 * configuration.problem_size.pad_w - 
      ((configuration.problem_size.S - 1) * 
      configuration.problem_size.dilation_w + 1)
    ) / 
    (configuration.problem_size.stride_w) 
    + 1
  );

  if (cudnn_output_q != configuration.problem_size.Q) {
    return Status::kErrorInvalidProblem;
  }
  //////////////////////////////////////////////////////////////////////////////////////

  // conv operator with input, accumulator, output datatype of (hss) are not supported 
  // in cuDNN
  if (a_tensor.element ==  library::NumericTypeID::kF16 && 
      b_tensor.element ==  library::NumericTypeID::kF16 &&
      math_instruction.element_accumulator == library::NumericTypeID::kF32 &&
      c_tensor.element == library::NumericTypeID::kF32
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kBF16 || 
      b_tensor.element ==  library::NumericTypeID::kBF16 ||
      c_tensor.element == library::NumericTypeID::kBF16
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kTF32 || 
      b_tensor.element ==  library::NumericTypeID::kTF32 ||
      c_tensor.element == library::NumericTypeID::kTF32
      ) {

    return Status::kErrorNotSupported;
  }

  if (a_tensor.element ==  library::NumericTypeID::kS8 || 
      b_tensor.element ==  library::NumericTypeID::kS8 ||
      c_tensor.element == library::NumericTypeID::kS8
      ) {

    return Status::kErrorNotSupported;
  }

  // S4 not supported in cuDNN 
  if (a_tensor.element ==  library::NumericTypeID::kS4 || 
      b_tensor.element ==  library::NumericTypeID::kS4 ||
      c_tensor.element == library::NumericTypeID::kS4
      ) {

    return Status::kErrorNotSupported;
  }

  return Status::kSuccess;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace profiler
} // namespace cutlass

#endif
