/*
 * Copyright (c) 2024 by SageAttention team.
 *
 * Inspired by CUTLASS, https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/numeric_conversion.h
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/pipeline>

// FP8_CAST_ENABLED gates the inline-PTX FP8 conversion paths in this header.
// It is defined only when both conditions hold:
//   1. the nvcc version is 12.4 or newer (major * 10000 + minor * 100 >= 120400), and
//   2. either __CUDA_ARCH__ is undefined (no device architecture is being targeted in
//      this compilation pass) or the targeted architecture value is at least 890.
// When it is not defined, the FP8 helpers fall back to RUNTIME_ASSERT instead of
// emitting the cvt instructions.
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890))
#define FP8_CAST_ENABLED
#endif
#endif

// RUNTIME_ASSERT(x): unconditional failure hook for code paths that must never run.
//   - Device pass (__CUDA_ARCH__ defined): expands to __brkpt(); the argument x is
//     discarded, so the message is not reported on the device.
//   - Otherwise: expands to assert(0 && x). The condition is always false, so the
//     assert fires whenever assertions are enabled; x is expected to be a string
//     literal so that it appears in the assert diagnostic.
// Neither expansion carries a trailing semicolon; call sites supply it.
#if defined(__CUDA_ARCH__)
#define RUNTIME_ASSERT(x) __brkpt()
#else
#include <assert.h>
#define RUNTIME_ASSERT(x) assert(0 && x)
#endif

// Converts four floats to FP8 e4m3 and packs them into dest[0].
//
// Inputs are source0[0], source0[1], source1[0], source1[1]; each is converted with
// cvt.rn.satfinite (round-to-nearest, saturate to the finite range).
// Result layout in dest[0], least-significant byte first:
//   source0[0], source0[1], source1[0], source1[1]
// (the two-input cvt places its first operand in the upper byte of the 16-bit pair,
// which is why the operands are passed in swapped order).
//
// Without FP8_CAST_ENABLED the function only triggers RUNTIME_ASSERT.
__device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1)
{
#if !defined(FP8_CAST_ENABLED)
  RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction");
#else
  const float a0 = source0[0];
  const float a1 = source0[1];
  const float b0 = source1[0];
  const float b1 = source1[1];
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b16 pair_lo;\n"
      ".reg .b16 pair_hi;\n"
      "cvt.rn.satfinite.e4m3x2.f32   pair_lo, %2, %1;\n"  // {a1, a0} -> low 16 bits
      "cvt.rn.satfinite.e4m3x2.f32   pair_hi, %4, %3;\n"  // {b1, b0} -> high 16 bits
      "mov.b32 %0, {pair_lo, pair_hi};\n"
      "}"
      : "=r"(packed)
      : "f"(a0), "f"(a1), "f"(b0), "f"(b1));
  dest[0] = packed;
#endif
}

// Converts four floats to FP8 e5m2 and packs them into dest[0].
//
// Inputs are source0[0], source0[1], source1[0], source1[1]; each is converted with
// cvt.rn.satfinite (round-to-nearest, saturate to the finite range).
// Result layout in dest[0], least-significant byte first:
//   source0[0], source0[1], source1[0], source1[1]
// (the two-input cvt places its first operand in the upper byte of the 16-bit pair,
// which is why the operands are passed in swapped order).
//
// Fix: the second input operand previously read source1[1] instead of source0[1],
// so source0[1] was dropped and byte 1 of the result duplicated byte 3. The operand
// list now matches floatx4_to_e4m3x4.
//
// Without FP8_CAST_ENABLED the function only triggers RUNTIME_ASSERT.
__device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1)
{
#ifdef FP8_CAST_ENABLED
  asm volatile(
      "{\n"
      ".reg .b16 lo;\n"
      ".reg .b16 hi;\n"
      "cvt.rn.satfinite.e5m2x2.f32   lo, %2, %1;\n"  // {source0[1], source0[0]} -> low 16 bits
      "cvt.rn.satfinite.e5m2x2.f32   hi, %4, %3;\n"  // {source1[1], source1[0]} -> high 16 bits
      "mov.b32 %0, {lo, hi};\n"
      "}"
      : "=r"(dest[0])
      : "f"(source0[0]), "f"(source0[1]), "f"(source1[0]), "f"(source1[1]));
#else
  RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction");
#endif
}

// Converts two packed f16x2 words (four half values) to FP8 e4m3 and packs the
// four results into dest[0].
//
// source0[0] and source1[0] are each treated as an f16x2 register. The pair from
// source0[0] becomes the low 16 bits of dest[0], the pair from source1[0] the high
// 16 bits. Conversion uses cvt.rn.satfinite (round-to-nearest, saturate to finite).
//
// Without FP8_CAST_ENABLED the function only triggers RUNTIME_ASSERT.
__device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1)
{
#if !defined(FP8_CAST_ENABLED)
  RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction");
#else
  const uint32_t half_pair0 = source0[0];
  const uint32_t half_pair1 = source1[0];
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b16 pair_lo;\n"
      ".reg .b16 pair_hi;\n"
      "cvt.rn.satfinite.e4m3x2.f16x2   pair_lo, %1;\n"
      "cvt.rn.satfinite.e4m3x2.f16x2   pair_hi, %2;\n"
      "mov.b32 %0, {pair_lo, pair_hi};\n"
      "}"
      : "=r"(packed)
      : "r"(half_pair0), "r"(half_pair1));
  dest[0] = packed;
#endif
}

// Converts two packed f16x2 words (four half values) to FP8 e5m2 and packs the
// four results into dest[0].
//
// source0[0] and source1[0] are each treated as an f16x2 register. The pair from
// source0[0] becomes the low 16 bits of dest[0], the pair from source1[0] the high
// 16 bits. Conversion uses cvt.rn.satfinite (round-to-nearest, saturate to finite).
//
// Without FP8_CAST_ENABLED the function only triggers RUNTIME_ASSERT.
__device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1)
{
#if !defined(FP8_CAST_ENABLED)
  RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction");
#else
  const uint32_t half_pair0 = source0[0];
  const uint32_t half_pair1 = source1[0];
  uint32_t packed;
  asm volatile(
      "{\n"
      ".reg .b16 pair_lo;\n"
      ".reg .b16 pair_hi;\n"
      "cvt.rn.satfinite.e5m2x2.f16x2   pair_lo, %1;\n"
      "cvt.rn.satfinite.e5m2x2.f16x2   pair_hi, %2;\n"
      "mov.b32 %0, {pair_lo, pair_hi};\n"
      "}"
      : "=r"(packed)
      : "r"(half_pair0), "r"(half_pair1));
  dest[0] = packed;
#endif
}

// Expands four packed FP8 e4m3 values into two f16x2 words.
//
// source[0] is split into its low and high 16-bit halves. The low half (two e4m3
// values) is converted to an f16x2 stored in dest0[0]; the high half is converted
// to an f16x2 stored in dest1[0].
//
// Without FP8_CAST_ENABLED the function only triggers RUNTIME_ASSERT.
__device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source)
{
#if !defined(FP8_CAST_ENABLED)
  RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction");
#else
  const uint32_t fp8_quad = source[0];
  uint32_t halves_lo;
  uint32_t halves_hi;
  asm volatile(
      "{\n"
      ".reg .b16 pair_lo, pair_hi;\n"
      "mov.b32 {pair_lo, pair_hi}, %2;\n"  // unpack before any output is written
      "cvt.rn.f16x2.e4m3x2 %0, pair_lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, pair_hi;\n"
      "}\n"
      : "=r"(halves_lo), "=r"(halves_hi)
      : "r"(fp8_quad));
  dest0[0] = halves_lo;
  dest1[0] = halves_hi;
#endif
}

// Expands four packed FP8 e5m2 values into two f16x2 words.
//
// source[0] is split into its low and high 16-bit halves. The low half (two e5m2
// values) is converted to an f16x2 stored in dest0[0]; the high half is converted
// to an f16x2 stored in dest1[0].
//
// Without FP8_CAST_ENABLED the function only triggers RUNTIME_ASSERT.
__device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source)
{
#if !defined(FP8_CAST_ENABLED)
  RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction");
#else
  const uint32_t fp8_quad = source[0];
  uint32_t halves_lo;
  uint32_t halves_hi;
  asm volatile(
      "{\n"
      ".reg .b16 pair_lo, pair_hi;\n"
      "mov.b32 {pair_lo, pair_hi}, %2;\n"  // unpack before any output is written
      "cvt.rn.f16x2.e5m2x2 %0, pair_lo;\n"
      "cvt.rn.f16x2.e5m2x2 %1, pair_hi;\n"
      "}\n"
      : "=r"(halves_lo), "=r"(halves_hi)
      : "r"(fp8_quad));
  dest0[0] = halves_lo;
  dest1[0] = halves_hi;
#endif
}

// Converts a float to a signed 8-bit integer using cvt.rni.sat.s8.f32:
// round to the nearest integer (ties to even) and saturate to [-128, 127].
// The instruction writes a 32-bit register; only its low byte is returned.
__device__ __forceinline__ int8_t float_to_int8_rn(float x)
{
    uint32_t converted;
    asm volatile(
        "cvt.rni.sat.s8.f32 %0, %1;"
        : "=r"(converted)
        : "f"(x));
    // Keep the least-significant byte of the 32-bit result.
    return static_cast<int8_t>(converted);
}

// Converts the four components of x to saturated signed 8-bit integers
// (cvt.rni.sat.s8.f32) and packs them into one 32-bit word:
//   x.x -> bits 0-7, x.y -> bits 8-15, x.z -> bits 16-23, x.w -> bits 24-31.
__device__ __forceinline__ uint32_t floatx4_to_int8x4_rn(float4 x) {
    const float lane0 = x.x;
    const float lane1 = x.y;
    const float lane2 = x.z;
    const float lane3 = x.w;
    uint32_t packed;
    asm volatile(
        "{\n\t"
        "  .reg .b32 q;\n\t"
        // Lane 0 seeds the result; whatever it leaves in bits 8-31 is
        // overwritten by the three bit-field inserts below.
        "  cvt.rni.sat.s8.f32 q, %1;\n\t"
        "  mov.b32 %0, q;\n\t"
        "  cvt.rni.sat.s8.f32 q, %2;\n\t"
        "  bfi.b32 %0, q, %0, 8, 8;\n\t"
        "  cvt.rni.sat.s8.f32 q, %3;\n\t"
        "  bfi.b32 %0, q, %0, 16, 8;\n\t"
        "  cvt.rni.sat.s8.f32 q, %4;\n\t"
        "  bfi.b32 %0, q, %0, 24, 8;\n\t"
        "}"
        : "=r"(packed)
        : "f"(lane0), "f"(lane1), "f"(lane2), "f"(lane3));
    return packed;
}