/**
 * Copyright (c) 2020 xxx Inc.
 * File              : image-preprocess.cu
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#include "image-preprocess.h"

#include <cxxutil/logging.h>
#include <cxxutil/strutil.h>

using namespace cxxutil;
using namespace std;

namespace vf {
namespace dl {

template <>
struct PreProcessFunc<kGPU, VF_PIX_FMT_None, kCopy> {
  // Pass-through "preprocessing": copies N contiguous float items, each of the
  // per-item `shape`, from src to dst asynchronously on `stream`, without
  // transforming the values. `scale`, `mean` and `std` are part of the shared
  // Map signature but are not used by this specialization.
  //
  // cudaMemcpyDefault lets the runtime infer the copy direction from the
  // pointer values (this requires unified virtual addressing).
  //
  // The element count is accumulated in size_t. The previous int product
  // (N * prod(shape)) could overflow for large batches before being widened
  // by sizeof(float), which silently produced a wrong copy size.
  inline static void Map(TStream stream, int N, const std::vector<int>& shape,
                         const float* src, float* dst, float scale,
                         const float* mean, const float* std) {
    CHECK(N >= 0) << "batch size must be non-negative, but got " << N;
    size_t count = static_cast<size_t>(N);
    for (size_t i = 0; i < shape.size(); i++) {
      CHECK(shape[i] >= 0) << "shape must be non-negative, but got "
                           << to_string(shape);
      count *= static_cast<size_t>(shape[i]);
    }
    CuCHECK(cudaMemcpyAsync(dst, src, count * sizeof(float),
                            cudaMemcpyDefault, stream));
  }
};

// Route each of the formats listed below to the pass-through kCopy
// implementation (PreProcessFunc<kGPU, VF_PIX_FMT_None, kCopy>) on the GPU.
// NOTE(review): the argument order of RegisterPreProcessFunc appears to be
// (device, source format, format selecting the specialization, op). The third
// argument matches the specialization's template argument. Confirm against
// image-preprocess.h.
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_1D, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_2D, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_3D, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_4D, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_GRAY, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGR, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGB, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGRPlanar, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGBPlanar, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGRA, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGBA, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGRAPlanar, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGBAPlanar, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_None, kCopy);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_None, kCopy);

// Per-pixel scale/mean/std kernel. The actual arithmetic and memory layout
// are delegated to the functor KFunc (e.g. ScaleMeanStdHWCKernel or
// ScaleMeanStdCHWKernel, as instantiated further down in this file).
//
// Launch layout: 2-D grid. x enumerates pixels within one image [0, HW) and
// y enumerates images within the batch [0, N). Threads that fall outside
// that range do nothing, so the grid may be rounded up.
template <typename KFunc>
__global__ void ScaleMeanStdKernel(int N, int C, int HW, const float* src,
                                   float* dst, float scale, const float* mean,
                                   const float* std) {
  const int pixel = threadIdx.x + blockIdx.x * blockDim.x;
  const int image = threadIdx.y + blockIdx.y * blockDim.y;
  const bool in_range = (pixel < HW) && (image < N);
  if (in_range) {
    KFunc::Map(image, pixel, C, HW, src, dst, scale, mean, std);
  }
}

// Host-side launcher for ScaleMeanStdKernel<KFunc> on `stream`.
//
// Launch layout:
// - Thread blocks are 32x16.
// - grid.x covers the HW pixels of one image and grid.y covers the N images,
//   matching the (pixel, image) indexing inside ScaleMeanStdKernel.
// - gridDim.y is limited to 65535 by the hardware, so N can be at most
//   65535 * 16.
//
// Parameters:
//   N   - number of images in the batch (must be >= 0).
//   C   - channel count, forwarded to the kernel functor.
//   HW  - pixels per image (must be >= 0).
//
// An empty batch or empty image is a no-op. Launching a grid with a zero
// dimension is rejected by the runtime (cudaErrorInvalidConfiguration), which
// CUDA_POST_KERNEL_CHECK would presumably report as a failure.
template <typename KFunc>
inline void PreProcessFuncGPUT(TStream stream, int N, int C, int HW,
                               const float* src, float* dst, float scale,
                               const float* mean, const float* std) {
  CHECK(N >= 0 && HW >= 0) << "N and HW must be non-negative, but got N=" << N
                           << ", HW=" << HW;
  if (N == 0 || HW == 0) return;

  // Named `block`/`grid` rather than `blockDim`/`gridDim` so the locals do not
  // shadow the CUDA built-in variables of the same names.
  const dim3 block(32, 16, 1);
  const dim3 grid((HW + block.x - 1) / block.x,
                  (N + block.y - 1) / block.y, 1);

  ScaleMeanStdKernel<KFunc><<<grid, block, 0, cudaStream_t(stream)>>>(
      N, C, HW, src, dst, scale, mean, std);

  CUDA_POST_KERNEL_CHECK(ScaleMeanStdKernel);
}

// ScaleMeanStd for interleaved 3-channel images. `shape` must be {H, W, 3}.
// Normalization is performed by ScaleMeanStdHWCKernel over N images of
// H*W pixels each. This struct itself does no channel reordering, and both
// the BGR and RGB registrations below route to it.
template <>
struct PreProcessFunc<kGPU, VF_PIX_FMT_BGR, kScaleMeanStd> {
  inline static void Map(TStream stream, int N, const std::vector<int>& shape,
                         const float* src, float* dst, float scale,
                         const float* mean, const float* std) {
    CHECK(shape.size() == 3 && shape[2] == 3)
        << "shape must be 3 dimensional and the last must be 3, but got "
        << to_string(shape);
    const int height = shape[0];
    const int width = shape[1];
    const int channels = shape[2];
    PreProcessFuncGPUT<ScaleMeanStdHWCKernel>(stream, N, channels,
                                              height * width, src, dst, scale,
                                              mean, std);
  }
};

RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGR, VF_PIX_FMT_BGR, kScaleMeanStd);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGB, VF_PIX_FMT_BGR, kScaleMeanStd);

// ScaleMeanStd for interleaved 4-channel images. `shape` must be {H, W, 4}.
// Normalization is performed by ScaleMeanStdHWCKernel over N images of
// H*W pixels each. This struct itself does no channel reordering, and both
// the BGRA and RGBA registrations below route to it.
template <>
struct PreProcessFunc<kGPU, VF_PIX_FMT_BGRA, kScaleMeanStd> {
  inline static void Map(TStream stream, int N, const std::vector<int>& shape,
                         const float* src, float* dst, float scale,
                         const float* mean, const float* std) {
    CHECK(shape.size() == 3 && shape[2] == 4)
        << "shape must be 3 dimensional and the last must be 4, but got "
        << to_string(shape);
    const int height = shape[0];
    const int width = shape[1];
    const int channels = shape[2];
    PreProcessFuncGPUT<ScaleMeanStdHWCKernel>(stream, N, channels,
                                              height * width, src, dst, scale,
                                              mean, std);
  }
};
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGRA, VF_PIX_FMT_BGRA, kScaleMeanStd);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGBA, VF_PIX_FMT_BGRA, kScaleMeanStd);

// ScaleMeanStd for planar 3-channel images. `shape` must be {3, H, W}.
// Normalization is performed by ScaleMeanStdCHWKernel over N images of
// H*W pixels each. This struct itself does no channel reordering, and both
// the BGRPlanar and RGBPlanar registrations below route to it.
template <>
struct PreProcessFunc<kGPU, VF_PIX_FMT_BGRPlanar, kScaleMeanStd> {
  inline static void Map(TStream stream, int N, const std::vector<int>& shape,
                         const float* src, float* dst, float scale,
                         const float* mean, const float* std) {
    CHECK(shape.size() == 3 && shape[0] == 3)
        << "shape must be 3 dimensional and the first must be 3, but got "
        << to_string(shape);
    const int channels = shape[0];
    const int height = shape[1];
    const int width = shape[2];
    PreProcessFuncGPUT<ScaleMeanStdCHWKernel>(stream, N, channels,
                                              height * width, src, dst, scale,
                                              mean, std);
  }
};
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGRPlanar, VF_PIX_FMT_BGRPlanar,
                       kScaleMeanStd);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGBPlanar, VF_PIX_FMT_BGRPlanar,
                       kScaleMeanStd);

// ScaleMeanStd for planar 4-channel images. `shape` must be {4, H, W}.
// Normalization is performed by ScaleMeanStdCHWKernel over N images of
// H*W pixels each. This struct itself does no channel reordering, and both
// the BGRAPlanar and RGBAPlanar registrations below route to it.
template <>
struct PreProcessFunc<kGPU, VF_PIX_FMT_BGRAPlanar, kScaleMeanStd> {
  inline static void Map(TStream stream, int N, const std::vector<int>& shape,
                         const float* src, float* dst, float scale,
                         const float* mean, const float* std) {
    CHECK(shape.size() == 3 && shape[0] == 4)
        << "shape must be 3 dimensional and the first must be 4, but got "
        << to_string(shape);
    const int channels = shape[0];
    const int height = shape[1];
    const int width = shape[2];
    PreProcessFuncGPUT<ScaleMeanStdCHWKernel>(stream, N, channels,
                                              height * width, src, dst, scale,
                                              mean, std);
  }
};
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_BGRAPlanar, VF_PIX_FMT_BGRAPlanar,
                       kScaleMeanStd);
RegisterPreProcessFunc(kGPU, VF_PIX_FMT_RGBAPlanar, VF_PIX_FMT_BGRAPlanar,
                       kScaleMeanStd);

}  // namespace dl
}  // namespace vf
