/**
 * Copyright (c) 2020 xxx Inc.
 * File              : nvdecoder-backend.cc
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#include "vf/io/decoder-backend.h"

#if USE_CUDA && ON_x86

#include <npp.h>

//
#include <cxxutil/logging.h>
#include <cxxutil/types.h>

#include "NvCodec/NvDecoder/NvDecoder.h"
#include "vf/devices/cuda.h"

using namespace cxxutil;
using namespace std;

namespace vf {
namespace io {

// Holds the output planes of one post-processing stage: the per-plane pointer,
// width, height and step that consumers read, plus the device buffers this
// object allocated itself.
//
// `out_data` normally points at the internally owned buffer, but callers in
// this file also overwrite entries to alias memory owned by someone else. The
// owned pointers are therefore tracked separately in `out_data_internal`, and
// only those are ever freed.
class StageContext {
 public:
  StageContext() = default;

  // The object owns raw device allocations; a copy would free them twice.
  StageContext(const StageContext &) = delete;
  StageContext &operator=(const StageContext &) = delete;

  // Releases every internally owned plane buffer.
  ~StageContext() {
    for (int i = 0; i < MAX_CHANNELS; ++i) {
      ResizeOutput(i, 0);
    }
  }

  // Makes plane `dim` own a buffer of exactly `alloc_len` bytes and points
  // `out_data[dim]` at it.
  //
  // dim        plane index; values outside [0, MAX_CHANNELS) are ignored.
  // alloc_len  requested size in bytes; a value <= 0 releases the buffer.
  //
  // An existing buffer of the requested size is reused. If the allocation
  // leaves the pointer null, the recorded length stays 0 so that the next call
  // retries instead of treating the plane as allocated.
  void ResizeOutput(int dim, int alloc_len) {
    if (dim < 0 || dim >= MAX_CHANNELS) return;

    if (alloc_len != out_data_len[dim]) {
      if (out_data_len[dim] > 0) {
        cuda::FreeCuda(out_data_internal[dim], dev_id);
        out_data_internal[dim] = nullptr;
        out_data_len[dim] = 0;
      }
      if (alloc_len > 0) {
        cuda::MallocCuda((void **)(&out_data_internal[dim]), alloc_len,
                         dev_id);
        out_data_len[dim] =
            (nullptr != out_data_internal[dim]) ? alloc_len : 0;
      }
    }

    // Always re-point the public pointer at the owned buffer, including when
    // the size is unchanged: a caller may have aliased out_data[dim] to
    // foreign memory since the last call, and writing through that alias
    // would corrupt a buffer this object does not own.
    out_data[dim] = out_data_internal[dim];
  }

 public:
  int dev_id = -1;                           // device passed to Malloc/FreeCuda
  uint8_t *out_data[MAX_CHANNELS] = {nullptr};  // plane pointers for consumers
  int out_w[MAX_CHANNELS] = {0};             // plane widths
  int out_h[MAX_CHANNELS] = {0};             // plane heights
  int out_step[MAX_CHANNELS] = {0};          // plane row steps

 private:
  uint8_t *out_data_internal[MAX_CHANNELS] = {nullptr};  // owned buffers
  int out_data_len[MAX_CHANNELS] = {0};  // byte size of each owned buffer
};

// Decoder backend that decodes packets with NvDecoder and post-processes the
// decoded frames with NPP before handing them out as VFFrame objects.
class NVDecoderBackend : public DecoderBackend {
 public:
  using DecoderBackend::DecoderBackend;
  // Destruction closes the backend, which deletes the decoder.
  virtual ~NVDecoderBackend() { Close(); }

 public:
  // Initialises the base backend and creates the NvDecoder for the stream.
  // Returns Status_OK on success, otherwise the failing status code.
  int Init(AVStream *video_stream, AVCodec *codec, PixelFormat tgt_pix_fmt,
           int tgt_w, int tgt_h, int sample_rate = 1,
           int64_t timestamp_intervals = -1, bool with_mvs = false) override;
  // Closes the base backend and deletes the decoder.
  void Close() override;
  // Feeds one packet to the decoder. Returns Status_Pending while frames of
  // the previous packet have not been consumed through Next().
  int PutPacket(const AVPacket *packet, int64_t timestamp) override;
  // Produces the next frame that passes the sampling rules, or returns
  // Status_Pending when no decoded frame is left.
  int Next(VFFrame &vf_frame) override;

 private:
  // Validates the current decoded frame and fills `vf_frame` from it.
  int PostProcessFrame(VFFrame &vf_frame);
  // Splits the current decoded frame into three planes stored in dframe_.
  int DeInterleave();
  // Declared here; no definition is visible in this file section.
  int Resize();
  // Produces the plane layout requested by vf_frame.descr.pix_fmt in cframe_.
  int Convert(VFFrame &vf_frame);

 private:
  uint8_t **pp_frames_ = nullptr;  // frame array returned by the last Decode
  uint8_t *cur_frame_ = nullptr;   // frame currently being post-processed
  int decoded_frame_idx_ = 0;      // index of the next frame in pp_frames_
  int decoded_frame_num_ = 0;      // number of frames in pp_frames_
  int cur_frame_pitch_ = 0;        // pitch reported by the decoder

  NvDecoder *nv_dec_ = nullptr;  // created in Init, deleted in Close
  std::mutex mutex_;             // handed to the NvDecoder constructor

  StageContext dframe_;  // output of DeInterleave (planar frame)
  StageContext cframe_;  // output of Convert (frame handed to the caller)

 public:
  // Enumerates usable GPUs and returns one BackendInfo per accepted device.
  static vector<BackendInfo> GetBackendInfo();
  // Releases the primary contexts retained by GetBackendInfo.
  static void ReleaseBackendInfo();
  // Name under which this backend is registered.
  static const string kBackendName;
};

// Translates an FFmpeg codec id into the corresponding NVDEC codec enum.
// Ids without an entry in the table yield cudaVideoCodec_NumCodecs.
inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) {
  struct CodecPair {
    AVCodecID ff_id;
    cudaVideoCodec nv_id;
  };
  static const CodecPair kCodecTable[] = {
      {AV_CODEC_ID_MPEG1VIDEO, cudaVideoCodec_MPEG1},
      {AV_CODEC_ID_MPEG2VIDEO, cudaVideoCodec_MPEG2},
      {AV_CODEC_ID_MPEG4, cudaVideoCodec_MPEG4},
      {AV_CODEC_ID_VC1, cudaVideoCodec_VC1},
      {AV_CODEC_ID_H264, cudaVideoCodec_H264},
      {AV_CODEC_ID_HEVC, cudaVideoCodec_HEVC},
      {AV_CODEC_ID_VP8, cudaVideoCodec_VP8},
      {AV_CODEC_ID_VP9, cudaVideoCodec_VP9},
      {AV_CODEC_ID_MJPEG, cudaVideoCodec_JPEG},
  };

  for (const CodecPair &entry : kCodecTable) {
    if (entry.ff_id == id) return entry.nv_id;
  }
  // Sentinel for "no NVDEC equivalent".
  return cudaVideoCodec_NumCodecs;
}

// Initialises the base backend, then creates the NvDecoder for `video_stream`.
//
// tgt_w / tgt_h  requested output size; a non-positive value means "not
//                requested". When only one of them is given, the other one is
//                taken from the stream's coded size.
//
// Returns Status_OK on success; the base class status if its Init fails;
// Status_Error for a missing stream or a codec without an NVDEC mapping;
// Status_InternalError if constructing the decoder throws.
int NVDecoderBackend::Init(AVStream *video_stream, AVCodec *codec,
                           PixelFormat tgt_pix_fmt, int tgt_w, int tgt_h,
                           int sample_rate, int64_t timestamp_intervals,
                           bool with_mvs) {
  int ret = DecoderBackend::Init(video_stream, codec, tgt_pix_fmt, tgt_w, tgt_h,
                                 sample_rate, timestamp_intervals, with_mvs);
  if (Status_OK != ret) return ret;

  if (nullptr == video_stream || nullptr == video_stream->codecpar) {
    LOG_ERROR << "NvDecoder init failed: missing stream or codec parameters";
    return Status_Error;
  }

  // Reject codecs NVDEC cannot handle up front instead of passing the
  // sentinel value on to the decoder.
  const cudaVideoCodec nv_codec =
      FFmpeg2NvCodecId(video_stream->codecpar->codec_id);
  if (cudaVideoCodec_NumCodecs == nv_codec) {
    LOG_ERROR << "NvDecoder init failed: unsupported codec id "
              << static_cast<int>(video_stream->codecpar->codec_id);
    return Status_Error;
  }

  Dim resize_dim;
  resize_dim.w = tgt_w > 0 ? tgt_w : 0;
  resize_dim.h = tgt_h > 0 ? tgt_h : 0;

  src_w_ = video_stream->codecpar->width;
  src_h_ = video_stream->codecpar->height;

  // Only one side requested: keep the source size for the other one.
  if (resize_dim.w > 0 && resize_dim.h == 0) resize_dim.h = src_h_;
  if (resize_dim.w == 0 && resize_dim.h > 0) resize_dim.w = src_w_;

  // A repeated Init must not leak the decoder created by the previous one.
  if (nullptr != nv_dec_) {
    cuda::ChooseGPU(binfo_.dev_id);
    DeletePointer(nv_dec_);
    nv_dec_ = nullptr;
  }
  pp_frames_ = nullptr;
  cur_frame_ = nullptr;

  // Decoder construction can throw; report that as a status code, the same
  // way PutPacket reports decode exceptions.
  try {
    nv_dec_ = new NvDecoder(CUcontext(binfo_.hw_ctx), src_w_, src_h_, true,
                            nv_codec, &mutex_, false, false, nullptr,
                            &resize_dim);
  } catch (const std::exception &ex) {
    LOG_ERROR << "NvDecoder creation failed with internal error: "
              << ex.what();
    nv_dec_ = nullptr;
    return Status_InternalError;
  }

  decoded_frame_num_ = 0;
  decoded_frame_idx_ = 0;
  dframe_.dev_id = binfo_.dev_id;
  cframe_.dev_id = binfo_.dev_id;

  return Status_OK;
}

void NVDecoderBackend::Close() {
  DecoderBackend::Close();

  cuda::ChooseGPU(binfo_.dev_id);
  DeletePointer(nv_dec_);
}

int NVDecoderBackend::PutPacket(const AVPacket *packet, int64_t timestamp) {
  if (0 == packet->size) return Status_OK;

  if (decoded_frame_idx_ < decoded_frame_num_) return Status_Pending;

  try {
    if (false == nv_dec_->Decode(packet->data, packet->size, &pp_frames_,
                                 &decoded_frame_num_)) {
      LOG_ERROR << "Error while sending packet to NvDecoder";
      return Status_Error;
    }
  } catch (exception &ex) {
    LOG(ERROR) << "NvDecode failed with internal error: " << ex.what();
    return Status_InternalError;
  }
  decoded_frame_idx_ = 0;
  cur_timestamp_ = timestamp;
  return Status_OK;
}

// Walks the decoded frames that have not been consumed yet and post-processes
// the first one that passes both sampling rules: its id is a multiple of
// sample_rate_, and at least timestamp_interval_ has elapsed relative to the
// timestamp currently stored in `vf_frame`.
// Every visited frame consumes one frame id, whether it is emitted or skipped.
// Returns Status_Pending when no decoded frame qualifies.
int NVDecoderBackend::Next(VFFrame &vf_frame) {
  for (; decoded_frame_idx_ < decoded_frame_num_; ++decoded_frame_idx_) {
    vf_frame.id = cur_frame_id_++;

    const bool on_sample_grid = (vf_frame.id % sample_rate_ == 0);
    if (!on_sample_grid) continue;

    const bool interval_elapsed =
        (cur_timestamp() - vf_frame.timestamp >= timestamp_interval_);
    if (!interval_elapsed) continue;

    // Selected: PostProcessFrame advances decoded_frame_idx_ itself.
    vf_frame.timestamp = cur_timestamp();
    return PostProcessFrame(vf_frame);
  }

  return Status_Pending;
}

// Takes the next decoded frame, checks that its size and bit depth are what
// this backend can deliver, runs the deinterleave/convert stages required by
// vf_frame.descr.pix_fmt and fills in the remaining fields of `vf_frame`.
// Returns Status_OK, or the status of the first check or stage that failed.
int NVDecoderBackend::PostProcessFrame(VFFrame &vf_frame) {
  const int decoded_w = nv_dec_->GetWidth();
  const int decoded_h = nv_dec_->GetHeight();
  const int depth = nv_dec_->GetBitDepth();
  cur_frame_pitch_ = nv_dec_->GetDeviceFramePitch();

  // The frame counts as consumed even if one of the checks below rejects it.
  cur_frame_ = pp_frames_[decoded_frame_idx_];
  ++decoded_frame_idx_;

  if (!(decoded_w > 0 && decoded_h > 0)) {
    LOG_ERROR << "decode failed: invalid decoded frame size(" << decoded_w
              << "," << decoded_h << ")";
    return Status_Error;
  }

  // Sizes that are still unknown are taken from the decoded frame.
  src_w_ = (src_w_ > 0) ? src_w_ : decoded_w;
  src_h_ = (src_h_ > 0) ? src_h_ : decoded_h;
  tgt_w_ = (tgt_w_ > 0) ? tgt_w_ : decoded_w;
  tgt_h_ = (tgt_h_ > 0) ? tgt_h_ : decoded_h;

  const bool size_matches = (tgt_w_ == decoded_w) && (tgt_h_ == decoded_h);
  if (!size_matches) {
    LOG_ERROR << "expect output size is " << tgt_w_ << "x" << tgt_h_
              << ", but got " << decoded_w << "x" << decoded_h;
    return Status_Error;
  }

  if (8 != depth) {
    LOG_ERROR << "NvDecode error: unsupported bit depth " << depth;
    return Status_Error;
  }

  cuda::ChooseGPU(binfo_.dev_id);

  // NV12 and GRAY are served straight from the decoded frame; every other
  // format goes through the planar intermediate first.
  const bool served_from_decoded_frame =
      (vf_frame.descr.pix_fmt == VF_PIX_FMT_NV12) ||
      (vf_frame.descr.pix_fmt == VF_PIX_FMT_GRAY);
  if (!served_from_decoded_frame) {
    const int deinterleave_status = DeInterleave();
    if (Status_OK != deinterleave_status) return deinterleave_status;
  }

  const int convert_status = Convert(vf_frame);
  if (Status_OK != convert_status) return convert_status;

  vf_frame.ctx.dev_type = kGPU;
  vf_frame.ctx.dev_id = binfo_.dev_id;

  vf_frame.descr =
      PixelDescr<uint8_t>::Parse(vf_frame.descr.pix_fmt, tgt_h_, tgt_w_,
                                 cframe_.out_step, cframe_.out_data);

  // This backend provides neither a picture type nor motion vectors.
  vf_frame.pic_type = 0;
  vf_frame.mvs = nullptr;
  vf_frame.mvs_length = 0;

  return Status_OK;
}

// Builds an NppiSize from `w` x `h` with the lowest bit of each dimension
// cleared by a shift down and back up.
inline NppiSize GetNppiSize(int w, int h) {
  const int even_w = (w >> 1) << 1;
  const int even_h = (h >> 1) << 1;
  return NppiSize{even_w, even_h};
}

int NVDecoderBackend::DeInterleave() {
  int err = NPP_SUCCESS;
  // resize
  int out_sw, out_sh;
  av_pix_fmt_get_chroma_sub_sample(AV_PIX_FMT_YUV420P, &out_sw, &out_sh);

  dframe_.out_w[0] = tgt_w_;
  dframe_.out_h[0] = tgt_h_;
  dframe_.out_w[1] = dframe_.out_w[2] = tgt_w_ >> out_sw;
  dframe_.out_h[1] = dframe_.out_h[2] = tgt_h_ >> out_sh;

  for (int i = 0; i < 3; ++i) {
    dframe_.out_step[i] = FFALIGN(dframe_.out_w[i], 32);
    dframe_.ResizeOutput(i, dframe_.out_h[i] * dframe_.out_step[i]);
  }
  err = nppiYCbCr420_8u_P2P3R(cur_frame_, cur_frame_pitch_,
                              cur_frame_ + tgt_h_ * cur_frame_pitch_,
                              cur_frame_pitch_, dframe_.out_data,
                              dframe_.out_step, (NppiSize){tgt_w_, tgt_h_});
  if (err != NPP_SUCCESS) {
    LOG_ERROR << "nppi resize failed with error code: " << err;
    return Status_Error;
  }

  return Status_OK;
}

// Fills cframe_ with the plane layout requested by vf_frame.descr.pix_fmt:
//   GRAY / NV12  point directly into the current decoded frame,
//   YUV420P      reuses the three planes held by dframe_,
//   BGR / RGB    are converted by NPP from dframe_ into a buffer owned by
//                cframe_.
// Returns Status_OK, Status_InvalidFormat for any other pixel format, or
// Status_Error when NPP reports a failure.
int NVDecoderBackend::Convert(VFFrame &vf_frame) {
  int npp_status = NPP_SUCCESS;

  switch (vf_frame.descr.pix_fmt) {
    case VF_PIX_FMT_GRAY:
    case VF_PIX_FMT_NV12:
      // Plane 0 is the luma plane of the decoded frame for both formats.
      cframe_.out_data[0] = cur_frame_;
      cframe_.out_step[0] = cur_frame_pitch_;
      cframe_.out_w[0] = tgt_w_;
      cframe_.out_h[0] = tgt_h_;
      if (VF_PIX_FMT_NV12 == vf_frame.descr.pix_fmt) {
        // NV12 additionally exposes the plane that starts tgt_h_ rows in.
        cframe_.out_data[1] = cur_frame_ + cur_frame_pitch_ * tgt_h_;
        cframe_.out_step[1] = cur_frame_pitch_;
        cframe_.out_w[1] = tgt_w_;
        cframe_.out_h[1] = tgt_h_ >> 1;
      }
      break;

    case VF_PIX_FMT_BGR:
    case VF_PIX_FMT_RGB: {
      // Packed output with three bytes per pixel and no row padding.
      cframe_.out_w[0] = tgt_w_;
      cframe_.out_h[0] = tgt_h_;
      cframe_.out_step[0] = tgt_w_ * 3;
      cframe_.ResizeOutput(0, cframe_.out_h[0] * cframe_.out_step[0]);

      const NppiSize roi = GetNppiSize(cframe_.out_w[0], cframe_.out_h[0]);
      if (VF_PIX_FMT_BGR == vf_frame.descr.pix_fmt) {
        npp_status = nppiYCbCr420ToBGR_8u_P3C3R(
            dframe_.out_data, dframe_.out_step, cframe_.out_data[0],
            cframe_.out_step[0], roi);
      } else {
        npp_status = nppiYCbCr420ToRGB_8u_P3C3R(
            dframe_.out_data, dframe_.out_step, cframe_.out_data[0],
            cframe_.out_step[0], roi);
      }
      break;
    }

    case VF_PIX_FMT_YUV420P:
      // The deinterleaved planes are already in the requested layout.
      for (int plane = 0; plane < 3; ++plane) {
        cframe_.out_data[plane] = dframe_.out_data[plane];
        cframe_.out_step[plane] = dframe_.out_step[plane];
        cframe_.out_w[plane] = dframe_.out_w[plane];
        cframe_.out_h[plane] = dframe_.out_h[plane];
      }
      break;

    default:
      LOG_ERROR << "Unsupported hardware pixel format: "
                << vf_frame.descr.pix_fmt;
      return Status_InvalidFormat;
  }

  if (NPP_SUCCESS != npp_status) {
    LOG_ERROR << "nppi convert failed with error code: " << npp_status;
    return Status_Error;
  }
  return Status_OK;
}

// Backend name stored in every BackendInfo built by GetBackendInfo() and
// passed to RegisterDecoderBackend at the end of this file.
const string NVDecoderBackend::kBackendName = "nvdecoder";

// Initialises the CUDA driver API, enumerates the GPUs and returns one
// BackendInfo per accepted device.
//
// A device is accepted when its name contains one of the recognised model
// strings (which also selects `limits`), its primary context can be set to
// CU_CTX_SCHED_BLOCKING_SYNC or already has exactly those flags, and the
// primary context can be retained. The retained context is stored in
// `hw_ctx`; ReleaseBackendInfo() gives it back.
//
// Any failure skips only the affected device. The result is empty when the
// driver cannot be initialised or no device qualifies. Problems are reported
// on std::cerr.
vector<BackendInfo> NVDecoderBackend::GetBackendInfo() {
  vector<BackendInfo> infos;
  int ret = cuInit(0);
  if (CUDA_SUCCESS != ret) {
    std::cerr << "cuInit failed with error code: " << ret
              << ", gpu decode will be unavailable" << std::endl;
    return infos;
  }
  int gpu_num = 0;
  ret = cuDeviceGetCount(&gpu_num);
  if (CUDA_SUCCESS != ret) {
    std::cerr << "cuDeviceGetCount failed with error code: " << ret
              << ", gpu decode will be unavailable" << std::endl;
    return infos;
  } else if (gpu_num <= 0) {
    std::cerr << "cuDeviceGetCount get negative gpu number: " << gpu_num
              << ", gpu decode will be unavailable" << std::endl;
    return infos;
  }

  // create device contexts
  for (int dev_id = 0; dev_id < gpu_num; ++dev_id) {
    BackendInfo info;
    info.dev_type = kGPU;
    info.dev_id = dev_id;
    info.limits = -1;
    info.priority = 10;
    info.name = NVDecoderBackend::kBackendName;
    info.codecs = {AV_CODEC_ID_MPEG1VIDEO, AV_CODEC_ID_MPEG2VIDEO,
                   AV_CODEC_ID_MPEG4,      AV_CODEC_ID_VC1,
                   AV_CODEC_ID_H264,       AV_CODEC_ID_H265,
                   AV_CODEC_ID_HEVC,       AV_CODEC_ID_VP8,
                   AV_CODEC_ID_VP9,        AV_CODEC_ID_MJPEG};

    // `ret` is deliberately not tested here: ChooseGPU's result is not
    // captured, so `ret` would still hold the status of an earlier driver
    // call, and one failing device would cause every later one to be skipped.
    cuda::ChooseGPU(dev_id);

    CUdevice cu_dev = 0;
    ret = cuDeviceGet(&cu_dev, dev_id);
    if (CUDA_SUCCESS != ret) {
      std::cerr << "cuDeviceGet on GPU " << dev_id << ": "
                << " failed, skip this device" << std::endl;
      continue;
    }
    char dev_name_buf[80];
    ret = cuDeviceGetName(dev_name_buf, sizeof(dev_name_buf), cu_dev);
    if (CUDA_SUCCESS != ret) {
      std::cerr << "cuDeviceGetName on GPU " << dev_id << ": "
                << " failed, skip this device" << std::endl;
      continue;
    }

    // The device model decides `limits`; unrecognised models are not offered.
    string dev_name = dev_name_buf;
    if (dev_name.find("P4") != string::npos ||
        dev_name.find("P100") != string::npos ||
        dev_name.find("V100") != string::npos) {
      info.limits = 16;
    } else if (dev_name.find("T4") != string::npos) {
      info.limits = 32;
    } else {
      std::cerr << "Unsupported cuda device: " << dev_name << std::endl;
      info.limits = 0;
      continue;
    }

    // create context
    int dev_active = 0;
    unsigned int dev_flags = 0;
    ret = cuDevicePrimaryCtxGetState(cu_dev, &dev_flags, &dev_active);
    if (CUDA_SUCCESS != ret) {
      std::cerr << "cuDevicePrimaryCtxGetState on GPU " << dev_id << ": "
                << " failed, skip this device" << std::endl;
      continue;
    }

    // The flags are only changed while the primary context is inactive; an
    // active context with different flags makes the device unusable here.
    const unsigned int desired_flags = CU_CTX_SCHED_BLOCKING_SYNC;
    if (dev_active && dev_flags != desired_flags) {
      std::cerr << "Primary context already active with incompatible flags, "
                   "skip this device"
                << std::endl;
      continue;
    } else if (dev_flags != desired_flags) {
      ret = cuDevicePrimaryCtxSetFlags(cu_dev, desired_flags);
      if (CUDA_SUCCESS != ret) {
        std::cerr << "cuDevicePrimaryCtxSetFlags on GPU " << dev_id << ": "
                  << " failed, skip this device" << std::endl;
        continue;
      }
    }

    CUcontext cu_ctx;
    ret = cuDevicePrimaryCtxRetain(&cu_ctx, cu_dev);
    if (CUDA_SUCCESS != ret) {
      std::cerr << "cuDevicePrimaryCtxRetain on GPU " << dev_id << ": "
                << " failed, skip this device" << std::endl;
      continue;
    }

    info.hw_ctx = cu_ctx;
    infos.emplace_back(info);
  }

  if (infos.empty()) {
    std::cerr << "no valid GPUs found" << std::endl;
  }
  return infos;
}

// Releases the CUDA primary context of each GPU listed among the registered
// backends. A device is released at most once per call, however many backend
// entries refer to it; a device whose release fails is tried again if another
// entry refers to it. Failures are reported on std::cerr.
void NVDecoderBackend::ReleaseBackendInfo() {
  map<int, bool> done;
  for (BackendInfo &binfo : DecoderBackend::backends()) {
    const bool holds_gpu_ctx =
        (binfo.dev_type == kGPU) && (nullptr != binfo.hw_ctx);
    if (!holds_gpu_ctx) continue;

    // Skip devices that were already released earlier in this loop.
    const auto hit = done.find(binfo.dev_id);
    if (done.end() != hit && hit->second) continue;

    CUdevice device = 0;
    int status = cuDeviceGet(&device, binfo.dev_id);
    if (CUDA_SUCCESS != status) {
      std::cerr << "cuDeviceGet on GPU " << binfo.dev_id << ": "
                << " failed, skip this device" << std::endl;
      continue;
    }

    status = cuDevicePrimaryCtxRelease(device);
    if (CUDA_SUCCESS != status) {
      std::cerr << "cuDevicePrimaryCtxRelease on GPU " << binfo.dev_id
                << " with error code: " << status << std::endl;
      continue;
    }

    done[binfo.dev_id] = true;
  }
}

// Registers NVDecoderBackend under kBackendName together with a short
// human-readable description. RegisterDecoderBackend is declared outside this
// file — presumably a registration macro; confirm in decoder-backend.h.
RegisterDecoderBackend(NVDecoderBackend, NVDecoderBackend::kBackendName,
                       "gpu decoder-backend, the most compatible one.");

}  // namespace io
}  // namespace vf

#endif
