/**
 * Copyright (c) 2020 xxx Inc.
 * File              : caffe-worker.cc
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#include "vf/dl/dl-worker.h"

#ifdef USE_Caffe

#undef CUDA_POST_KERNEL_CHECK

#include <google/protobuf/text_format.h>

#include <caffe/net.hpp>
#include <caffe/util/upgrade_proto.hpp>
//
#include <cxxutil/logging.h>
#include <cxxutil/strutil.h>
#include <vf/security/codify.h>

using namespace std;
using namespace caffe;
using namespace cxxutil;
using namespace vf::security;

namespace vf {
namespace dl {

// DLWorker backed by Caffe: loads a prototxt/caffemodel pair, binds the
// configured input and output blobs, and runs Net::Forward() per batch.
class CaffeWorker : public DLWorker {
 public:
  using DLWorker::DLWorker;

  // caffe_net_ is a raw owning pointer released by the destructor, so a
  // member-wise copy would delete the same network twice; forbid copying.
  CaffeWorker(const CaffeWorker&) = delete;
  CaffeWorker& operator=(const CaffeWorker&) = delete;

  virtual ~CaffeWorker();

  int Init() override;

 protected:
  int LoadNetwork();
  int InitInput();

  void run() override;
  void FeedData(DLTaskBatch* batch) override;
  void process(DLTaskBatch* batch) override;

 protected:
  // the network (owned; deleted in the destructor)
  caffe::Net<float>* caffe_net_ = nullptr;
  // input blobs of caffe_net_, in network order (not owned)
  vector<caffe::Blob<float>*> input_blobs_;
  // blobs named by param_.output_names, in that order (not owned)
  vector<const caffe::Blob<float>*> output_blobs_;
  // cached shape of each input blob; axis 0 is the current batch size
  vector<vector<int>> input_shapes_;
};

// Releases the network owned by this worker. Deleting a null pointer is a
// no-op, so no explicit null check is needed before the delete.
CaffeWorker::~CaffeWorker() {
  delete caffe_net_;
  caffe_net_ = nullptr;
}

int CaffeWorker::Init() {
  if (kCPU == ctx_.dev_type) {
    LOG(INFO) << "Initialize caffe (" << name() << ") on CPU";
    Caffe::set_mode(Caffe::CPU);
  } else {
    LOG(INFO) << "Initialize caffe (" << name() << ") on GPU " << ctx_.dev_id;
    Caffe::SetDevice(ctx_.dev_id);
    Caffe::set_mode(Caffe::GPU);
  }
  // init network
  int ret = LoadNetwork();
  if (Status_OK != ret) return ret;

  // init input
  ret = InitInput();
  if (Status_OK != ret) return ret;

  // init output
  int output_dim = 0;
  for (const string& layer : param_.output_names) {
    const caffe::Blob<float>* output_blob =
        caffe_net_->blob_by_name(layer).get();
    if (output_blob == nullptr) {
      LOG(ERROR) << layer << " layer does not exist!";
      return Status_InvalidFormat;
    }
    output_dim += output_blob->count(1);
    output_blobs_.push_back(output_blob);
  }
  LOG(INFO) << "Output dimension: " << output_dim;

  return DLWorker::Init();
}

int CaffeWorker::LoadNetwork() {
  NetParameter net_param;
  NetParameter weight_param;
  try {
    const vector<uint8_t>& net_data = Codify::GetFile(param_.net_path);
    const vector<uint8_t>& weight_data = Codify::GetFile(param_.weight_path);
    if (google::protobuf::TextFormat::ParseFromString(
            string((char*)net_data.data(), net_data.size()), &net_param) ==
        false) {
      LOG(ERROR) << "load caffe prototxt failed";
      return Status_InvalidArgument;
    }
    google::protobuf::io::CodedInputStream* coded_input =
        new google::protobuf::io::CodedInputStream(weight_data.data(),
                                                   weight_data.size());
    coded_input->SetTotalBytesLimit(INT_MAX, 536870912);
    // if (false ==
    //    weight_param.ParseFromArray(weight_data.data(), weight_data.size()))
    if (false == weight_param.ParseFromCodedStream(coded_input)) {
      LOG(ERROR) << "load caffe model failed";
      return Status_InvalidArgument;
    }
  } catch (exception& ex) {
    LOG(ERROR) << ex.what();
    return Status_InvalidArgument;
  }
  UpgradeNetAsNeeded(param_.net_path, &net_param);
  UpgradeNetAsNeeded(param_.weight_path, &weight_param);

  net_param.mutable_state()->set_phase(caffe::TEST);
  caffe_net_ = new Net<float>(net_param);
  caffe_net_->CopyTrainedLayersFrom(weight_param);
  // caffe_net_->CopyTrainedLayersFrom(param_.weight_path);
  return Status_OK;
}

int CaffeWorker::InitInput() {
  const vector<Blob<float>*>& net_input_blobs = caffe_net_->input_blobs();
  if (net_input_blobs.size() != param_.input_shapes.size()) {
    LOG(ERROR) << "network has " << net_input_blobs.size()
               << " inputs, but config specified "
               << param_.input_shapes.size();
    return Status_InvalidFormat;
  }
  input_shapes_.resize(net_input_blobs.size());
  input_blobs_.resize(net_input_blobs.size());
  for (size_t i = 0; i < net_input_blobs.size(); ++i) {
    input_blobs_[i] = net_input_blobs[i];
    input_shapes_[i] = input_blobs_[i]->shape();
    if (input_shapes_[i].size() != param_.input_shapes[i].size() + 1) {
      LOG(ERROR) << "the " << i << "-th input has "
                 << input_shapes_[i].size() - 1
                 << " dimensions, but config specified "
                 << param_.input_shapes[i].size();
      return Status_InvalidFormat;
    }
    for (size_t j = 0; j < param_.input_shapes[i].size(); ++j) {
      if (input_shapes_[i][j + 1] != param_.input_shapes[i][j]) {
        LOG(ERROR) << "the " << i << "-th input shape mismatch, network expect "
                   << vector<int>(input_shapes_[i].begin() + 1,
                                  input_shapes_[i].end())
                   << ", but config specified " << param_.input_shapes[i];
        return Status_InvalidFormat;
      }
    }

    input_shapes_[i][0] = param_.batch_size;
    input_shapes_[i][0] = 1;
    input_blobs_[i]->Reshape(input_shapes_[i]);
  }
  return Status_OK;
}

// Worker-thread entry point. It re-applies the device/mode choice made in
// Init() (presumably because Caffe keeps this state per thread), runs one
// warm-up forward pass, and then hands control to DLWorker::run().
void CaffeWorker::run() {
  if (ctx_.dev_type != kCPU) {
    Caffe::SetDevice(ctx_.dev_id);
    Caffe::set_mode(Caffe::GPU);
  } else {
    Caffe::set_mode(Caffe::CPU);
  }

  // Warm-up pass before the main processing loop starts.
  LOG(INFO) << "first forward";
  caffe_net_->Forward();

  DLWorker::run();
}

// Copies the batch's inputs into the network's input blobs.
//
// For each input, axis 0 of the blob is reshaped to batch->index (only when
// it differs from the cached shape). PreProcess() then writes the i-th input
// into the blob's GPU or CPU buffer, depending on ctx_.dev_type.
// Aborts via CHECK if the batch exceeds param_.batch_size; throws
// RuntimeException for any other device type.
void CaffeWorker::FeedData(DLTaskBatch* batch) {
  // A full batch (index == batch_size) is allowed.
  CHECK(batch->index <= param_.batch_size)
      << "batch index (" << batch->index << ") should not exceed "
      << param_.batch_size;

  for (size_t i = 0; i < input_shapes_.size(); ++i) {
    // reshape input
    if (batch->index != input_shapes_[i][0]) {
      input_shapes_[i][0] = batch->index;
      input_blobs_[i]->Reshape(input_shapes_[i]);
    }
    // feed data
    if (kGPU == ctx_.dev_type) {
      PreProcess(i, batch->index, batch->gpu_data(stream_, i),
                 input_blobs_[i]->mutable_gpu_data());
    } else if (kCPU == ctx_.dev_type) {
      PreProcess(i, batch->index, batch->cpu_data(stream_, i),
                 input_blobs_[i]->mutable_cpu_data());
    } else {
      throw RuntimeException("unknown device type for CaffeWorker: " +
                             to_string(ctx_.dev_type));
    }
  }
}

// Runs the forward pass and writes every output blob back into the per-task
// `result` buffers.
//
// Layout of each task's `result`, one slot per entry of output_blobs_ in
// order:
//   * output_argmax == true : 2 floats per blob, {argmax index, max value}
//   * output_argmax == false: the blob's per-sample values
// Each task then gets one `features` entry per output blob, pointing into
// its `result`.
//
// Feature pointers are taken only after every `result` has reached its final
// size. Growing a vector may reallocate, so pointers captured while later
// blobs were still being appended would dangle.
void CaffeWorker::process(DLTaskBatch* batch) {
  const int batch_num = batch->index;
  // Nothing to compute or write back; this also avoids dividing by zero when
  // deriving per-sample dimensions below.
  if (batch_num <= 0) return;

  // forward
  caffe_net_->Forward();

  // Per-blob slot inside each task's result: (offset, dimension).
  vector<pair<int, int>> slots;
  slots.reserve(output_blobs_.size());
  int total_dim = 0;
  for (const caffe::Blob<float>* output_blob : output_blobs_) {
    const int slot_dim =
        param_.output_argmax ? 2 : output_blob->count() / batch_num;
    slots.emplace_back(total_dim, slot_dim);
    total_dim += slot_dim;
  }

  // Size every result buffer exactly once so no reallocation happens later.
  for (int n = 0; n < batch_num; ++n) {
    vector<float>& result = batch->tasks[n]->result;
    result.clear();
    result.resize(total_dim);
  }

  // write back results
  for (size_t b = 0; b < output_blobs_.size(); ++b) {
    const caffe::Blob<float>* output_blob = output_blobs_[b];
    const int feat_dim = output_blob->count() / batch_num;
    const int offset = slots[b].first;
    const float* output_data = output_blob->cpu_data();
    for (int n = 0; n < batch_num; ++n, output_data += feat_dim) {
      float* dst = batch->tasks[n]->result.data() + offset;
      if (param_.output_argmax) {
        const float* max_iter =
            max_element(output_data, output_data + feat_dim);
        dst[0] = float(max_iter - output_data);
        dst[1] = *max_iter;
      } else {
        memcpy(dst, output_data, sizeof(float) * feat_dim);
      }
    }
  }

  // Only now take pointers into the final-size result buffers.
  for (int n = 0; n < batch_num; ++n) {
    vector<float>& result = batch->tasks[n]->result;
    vector<DLTask::Feature>& features = batch->tasks[n]->features;
    features.clear();
    for (const pair<int, int>& slot : slots) {
      features.push_back(
          DLTask::Feature{result.data() + slot.first, slot.second});
    }
  }
}

// Registers CaffeWorker under the name "caffe". See the RegisterDLWorker
// macro for how the name is used and what the empty third argument means.
RegisterDLWorker(CaffeWorker, "caffe", "");

#ifdef USE_CaffeSSD

// SSD detection variant of CaffeWorker. It expects a single detection output
// blob with 7 fields per row and returns, per task, the detections whose
// score reaches the output layer's confidence threshold.
class CaffeSSDWorker : public CaffeWorker {
 public:
  using CaffeWorker::CaffeWorker;

  int Init() override;

 protected:
  void process(DLTaskBatch* batch) override;

 protected:
  // confidence_threshold of the detection output layer; set by Init().
  // Initialized so the member never holds an indeterminate value.
  float conf_thresh_ = 0.f;
};

int CaffeSSDWorker::Init() {
  if (param_.output_names.size() != 1) {
    LOG(ERROR) << "output layers of CaffeSSDWorker should be 1, but got "
               << param_.output_names.size();
    return Status_InvalidFormat;
  }
  int ret = CaffeWorker::Init();
  if (ret != Status_OK) return ret;
  const caffe::Blob<float>* output_blob = output_blobs_[0];
  if (output_blob->num() != 1) {
    LOG(ERROR) << "output blob num should be 1, but got " << output_blob->num();
    return Status_InvalidFormat;
  }
  if (output_blob->channels() != 1) {
    LOG(ERROR) << "output blob channels should be 1, but got "
               << output_blob->channels();
    return Status_InvalidFormat;
  }
  if (output_blob->width() != 7) {
    LOG(ERROR) << "output blob width should be 7, but got "
               << output_blob->width();
    return Status_InvalidFormat;
  }

  const boost::shared_ptr<caffe::Layer<float>> out_layer =
      caffe_net_->layer_by_name(param_.output_names[0]);
  const auto& detection_output_param =
      out_layer->layer_param().detection_output_param();

  conf_thresh_ = detection_output_param.confidence_threshold();

  return ret;
}

// Runs the forward pass and distributes the detection rows to their tasks.
//
// Each output row has 7 floats: field 0 is the sample index within the batch
// and field 2 is the score. Rows scoring below conf_thresh_ are dropped; for
// the rest, fields 1..6 are appended to that sample's `result`. Afterwards
// each task exposes its whole `result` as a single feature.
void CaffeSSDWorker::process(DLTaskBatch* batch) {
  // forward
  caffe_net_->Forward();

  // write back results
  for (int n = 0; n < batch->index; ++n) {
    batch->tasks[n]->result.resize(0);
    batch->tasks[n]->features.resize(1);
  }

  const caffe::Blob<float>* output_blob = output_blobs_[0];
  int box_num = output_blob->height();
  int box_dim = output_blob->width();
  CHECK(box_dim == 7) << "detection output of caffe should have 7 fields";
  const float* output_data = output_blob->cpu_data();

  for (int i = 0; i < box_num; ++i, output_data += box_dim) {
    float score = output_data[2];
    if (score < conf_thresh_) continue;
    // The sample index comes from the network output. Validate it (the
    // negated range test also rejects NaN) before indexing batch->tasks.
    const float raw_index = output_data[0];
    if (!(raw_index >= 0.f && raw_index < float(batch->index))) {
      LOG(ERROR) << "detection references sample " << raw_index
                 << " outside batch of " << batch->index << ", ignored";
      continue;
    }
    const int index = static_cast<int>(raw_index);
    int ori_dim = int(batch->tasks[index]->result.size());
    batch->tasks[index]->result.resize(ori_dim + box_dim - 1);
    memcpy(batch->tasks[index]->result.data() + ori_dim, output_data + 1,
           sizeof(float) * (box_dim - 1));
  }
  // Pointers are taken after all appends, so no reallocation can move them.
  for (int n = 0; n < batch->index; ++n) {
    batch->tasks[n]->features[0].ptr = batch->tasks[n]->result.data();
    batch->tasks[n]->features[0].dim = int(batch->tasks[n]->result.size());
  }
}

// Registers CaffeSSDWorker under the name "caffe-ssd". See the
// RegisterDLWorker macro for how the name is used and what the empty third
// argument means.
RegisterDLWorker(CaffeSSDWorker, "caffe-ssd", "");
#endif

// Static-initialization hook. It sets the "minloglevel" flag to 2 (ERROR in
// glog's numbering) so lower-severity output, e.g. Caffe's INFO logs, is
// suppressed. When GFLAGS_NAMESPACE is defined it also calls
// InitGoogleLogging().
class GloGInit {
 public:
  GloGInit() {
#ifdef GFLAGS_NAMESPACE
    GFLAGS_NAMESPACE::InitGoogleLogging("caffe-worker");
    // The registered flag is "minloglevel". "GLOG_minloglevel" is only the
    // environment-variable spelling and is not a flag name, so passing it to
    // SetCommandLineOption() had no effect.
    GFLAGS_NAMESPACE::SetCommandLineOption("minloglevel", "2");
#else
    gflags::SetCommandLineOption("minloglevel", "2");
#endif
  }
};

// File-local instance whose constructor runs the hook at static-init time.
// The name avoids the reserved double-underscore prefix.
static GloGInit glog_init_instance;
}  // namespace dl
}  // namespace vf
#endif
