/**
 * Copyright (c) 2020 xxx Inc.
 * File              : yuv.cu
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#include "yuv.h"

namespace vf {

/**
 * Converts one destination pixel per thread from the YUV source image.
 *
 * Launch layout: blockIdx.z selects the destination image, while the x/y
 * block and thread indices select the pixel (dx, dy) inside that destination.
 * Threads that fall outside the selected destination's w/h exit without
 * doing any work, so the grid may be sized for the largest destination.
 *
 * @param src      Source image descriptor, forwarded to YUV2RGBConvFunc.
 * @param dst_N    Number of valid entries in `dsts` (and in `mparams`, which
 *                 is indexed with the same destination index).
 * @param mparams  Per-destination coordinate mapping parameters; entry
 *                 `dst_idx` is forwarded to YUV2RGBConvFunc.
 * @param dsts     Table of destination image descriptors.
 * @param swap     Forwarded unchanged to YUV2RGBConvFunc.
 */
template <typename T1, typename T2, int sPixFmt, int dPixFmt>
__global__ void YUV2RGBKernel(const PixelDescr<T1>* src, int dst_N,
                              const CoordMapParam* mparams,
                              const PixelDescr<T2>** dsts, bool swap) {
  const int dst_idx = blockIdx.z;
  // Validate the destination index BEFORE reading dsts[dst_idx]; otherwise a
  // grid with more z-blocks than destinations reads past the pointer table.
  if (dst_idx >= dst_N) return;

  const int dx = blockIdx.x * blockDim.x + threadIdx.x;
  const int dy = blockIdx.y * blockDim.y + threadIdx.y;
  const PixelDescr<T2>* dst = dsts[dst_idx];
  // The grid is rounded up to whole blocks, so edge threads may lie outside
  // this destination image.
  if (dx >= dst->w || dy >= dst->h) return;

  YUV2RGBConvFunc<T2, sPixFmt>(dst_idx, dx, dy, src, mparams + dst_idx, dst,
                               swap);
}

/**
 * GPU specialization of PixelConvert: launches YUV2RGBKernel on the given
 * stream to fill `dst_N` destination images from one source image.
 */
template <typename T1, typename T2, int sPixFmt, int dPixFmt>
struct PixelConvert<kGPU, T1, T2, sPixFmt, dPixFmt> {
  /**
   * Enqueues the conversion kernel; does not synchronize the stream itself.
   *
   * @param stream   Stream handle, cast to cudaStream_t for the launch.
   * @param src      Source image descriptor, passed to the kernel as-is.
   * @param dst_N    Number of destination images (grid z extent).
   * @param dst_h    Height used to size the grid in y; the kernel clips each
   *                 thread against the individual destination's own h.
   * @param dst_w    Width used to size the grid in x; the kernel clips each
   *                 thread against the individual destination's own w.
   * @param mparams  Per-destination mapping parameters, passed as-is.
   * @param dsts     Table of destination descriptors, passed as-is.
   *
   * NOTE(review): src, mparams and dsts are dereferenced inside the kernel,
   * so they presumably must be device-accessible -- confirm against callers.
   */
  inline static void Map(TStream stream, const PixelDescr<T1>* src, int dst_N,
                         int dst_h, int dst_w, const CoordMapParam* mparams,
                         PixelDescr<T2> const** dsts) {
    // Nothing to convert. A zero-sized grid dimension is an invalid launch
    // configuration, and a negative extent would wrap in the unsigned
    // ceil-division below, so bail out before launching.
    if (dst_N <= 0 || dst_h <= 0 || dst_w <= 0) return;

    const dim3 threads(32, 16, 1);
    // Ceil-divide so partially covered tiles at the right/bottom edges still
    // get a block; one z-slice per destination image.
    const dim3 blocks((dst_w + threads.x - 1) / threads.x,
                      (dst_h + threads.y - 1) / threads.y, dst_N);

    YUV2RGBKernel<T1, T2, sPixFmt, dPixFmt>
        <<<blocks, threads, 0, cudaStream_t(stream)>>>(src, dst_N, mparams,
                                                       dsts, isbgr(dPixFmt));

    CUDA_POST_KERNEL_CHECK(YUV2RGBKernel);
  }
};

// GPU conversion registrations. Each line names a device and two pixel
// formats; presumably the order is (device, source format, destination
// format) -- confirm against the RegisterYUVCvtFunc definition.

// YUV420P source -> BGR / RGB destinations and their planar variants
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_BGR);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_RGB);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_BGRPlanar);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_RGBPlanar);

// YUV420P source -> BGRA / RGBA destinations and their planar variants
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_BGRA);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_RGBA);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_BGRAPlanar);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_YUV420P, VF_PIX_FMT_RGBAPlanar);

// NV12 source -> BGR / RGB destinations and their planar variants
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_BGR);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_RGB);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_BGRPlanar);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_RGBPlanar);

// NV12 source -> BGRA / RGBA destinations and their planar variants
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_BGRA);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_RGBA);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_BGRAPlanar);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV12, VF_PIX_FMT_RGBAPlanar);

// NV21 source -> BGR / RGB destinations and their planar variants
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_BGR);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_RGB);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_BGRPlanar);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_RGBPlanar);

// NV21 source -> BGRA / RGBA destinations and their planar variants
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_BGRA);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_RGBA);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_BGRAPlanar);
RegisterYUVCvtFunc(kGPU, VF_PIX_FMT_NV21, VF_PIX_FMT_RGBAPlanar);

}  // namespace vf
