/**
 * Copyright (c) 2020 xxx Inc.
 * File              : mxnet-worker.cc
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#include "vf/dl/dl-worker.h"

#if USE_MXNet

#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>

//
#include <map>
#include <memory>

//
#include "mxnet-executor.h"

//
#include <cxxutil/logging.h>
#include <cxxutil/strutil.h>
#include <vf/security/codify.h>

using namespace std;
using namespace cxxutil;
using namespace mxnet::cpp;
using namespace vf::security;

namespace vf {
namespace dl {

/**
 * DLWorker implementation that runs inference through the MXNet C++ API.
 *
 * The worker owns the executors stored in executors_ (they are deleted in the
 * destructor) and keeps raw pointers to NDArrays in input_datas_. A copied
 * instance would delete the same executors twice and share the same input
 * arrays, so copying is disabled.
 */
class MXNetWorker : public DLWorker {
 public:
  MXNetWorker(const DLWorkerParam& param, const vf::Context& ctx)
      : DLWorker(param, ctx) {}
  virtual ~MXNetWorker();

  // Non-copyable: executors_ holds owning raw pointers.
  MXNetWorker(const MXNetWorker&) = delete;
  MXNetWorker& operator=(const MXNetWorker&) = delete;

  /// Loads weights and network, infers shapes and binds the executors.
  int Init() override;

 protected:
  /// Builds net_ from the configured output layers and computes output sizes.
  int InferShape(Symbol& sym);
  /// Creates one executor per supported batch size (1..param_.batch_size).
  int BindExecutors();
  /// Preprocesses the batch inputs into the matching executor's input arrays.
  void FeedData(DLTaskBatch* batch) override;
  /// Runs the forward pass and writes the results into the batch's tasks.
  void process(DLTaskBatch* batch) override;

 protected:
  // Arrays keyed by name: the "arg:" entries of the weight file plus one
  // placeholder per configured input.
  map<string, NDArray> args_map_;
  // Arrays keyed by name: the "aux:" entries of the weight file.
  map<string, NDArray> aux_map_;
  // Symbol that is bound by the executors: the flattened output layers,
  // concatenated when more than one output is configured.
  Symbol net_;
  // Owning pointers; executors_[b - 1] is the executor for a batch of b items.
  vector<MXNetExecutor*> executors_;
  // Non-owning; input_datas_[b - 1][i] is input i of executors_[b - 1].
  vector<vector<NDArray*>> input_datas_;
  // Device context selected in Init().
  mxnet::cpp::Context mx_ctx_ = mxnet::cpp::Context::cpu();
  // Sum of output_dims_, i.e. number of floats produced per batch item.
  int output_dim_ = 0;
  // Per-item element count of every configured output.
  vector<int> output_dims_;
  // Host buffer that receives the executor output in process().
  vector<float> output_blob_;
};

// Deletes every executor held in executors_, front to back, then empties the
// container.
MXNetWorker::~MXNetWorker() {
  for (size_t idx = 0; idx < executors_.size(); ++idx) {
    delete executors_[idx];
  }
  executors_.clear();
}

int MXNetWorker::Init() {
  if (kCPU == ctx_.dev_type) {
    LOG(INFO) << "Initialize mxnet (" << name() << ") on CPU";
    mx_ctx_ = mxnet::cpp::Context::cpu();
  } else {
    LOG(INFO) << "Initialize mxnet (" << name() << ") on GPU " << ctx_.dev_id;
    mx_ctx_ = mxnet::cpp::Context::gpu(ctx_.dev_id);
  }

  try {
    // load parameters
    const vector<uint8_t>& param_data = Codify::GetFile(param_.weight_path);
    if (param_data.size() == 0) {
      LOG(ERROR) << "Load param file" << param_.weight_path << " failed";
      return Status_IOError;
    }
    map<string, NDArray> parameters;
    NDArray::LoadFromBuffer(param_data.data(), param_data.size(), 0,
                            &parameters);
    for (const auto& k : parameters) {
      if (k.first.substr(0, 4) == "aux:") {
        auto name = k.first.substr(4, k.first.size() - 4);
        aux_map_[name] = k.second.Copy(mx_ctx_);
      }
      if (k.first.substr(0, 4) == "arg:") {
        auto name = k.first.substr(4, k.first.size() - 4);
        args_map_[name] = k.second.Copy(mx_ctx_);
      }
    }
    /*WaitAll is need when we copy data between GPU and the main memory*/
    NDArray::WaitAll();

    // load model
    vector<uint8_t> json_data = Codify::GetFile(param_.net_path);
    if (json_data.size() == 0) {
      LOG(ERROR) << "Load json file" << param_.net_path << " failed";
      return Status_IOError;
    }
    json_data.push_back('\0');
    SymbolHandle handle;
    if (MXSymbolCreateFromJSON((const char*)(json_data.data()), &(handle)) !=
        0) {
      LOG(ERROR) << "MXSymbolCreateFromJSON failed: " << MXGetLastError();
      return Status_IOError;
    }
    Symbol sym = Symbol(handle);

    int ret = InferShape(sym);
    if (Status_OK != ret) return ret;
    ret = BindExecutors();
    if (Status_OK != ret) return ret;
  } catch (exception& ex) {
    LOG(ERROR) << "Execution failed with mxnet error: " << ex.what() << ": "
               << MXGetLastError();
    return Status_IOError;
  }

  return DLWorker::Init();
}

/**
 * Builds net_ from the configured output layers, registers an input
 * placeholder per configured input and derives the per-item output sizes.
 *
 * @param sym network symbol; on return it is replaced by a group made of the
 *            flattened output layers followed by the original symbol.
 * @return Status_OK on success; Status_Error if the configuration is unusable
 *         (no output layer, non-positive batch size, fewer input names than
 *         input shapes) or if shape inference yields too few output shapes.
 */
int MXNetWorker::InferShape(Symbol& sym) {
  if (param_.output_names.empty()) {
    LOG(ERROR) << "no output layer is configured";
    return Status_Error;
  }
  // The batch size is used as a divisor below, so it must be positive.
  if (param_.batch_size <= 0) {
    LOG(ERROR) << "batch size should be positive, but got "
               << param_.batch_size;
    return Status_Error;
  }
  // input_names is indexed with the indices of input_shapes below.
  if (param_.input_names.size() < param_.input_shapes.size()) {
    LOG(ERROR) << "got " << param_.input_shapes.size()
               << " input shapes, but only " << param_.input_names.size()
               << " input names";
    return Status_Error;
  }

  Symbol internals = sym.GetInternals();
  vector<Symbol> syms;
  for (const string& layer : param_.output_names) {
    // Internal symbols are looked up as "<layer>_output"; accept both forms.
    const string output_name =
        endswith(layer, "_output") ? layer : layer + "_output";
    syms.push_back(Flatten(internals[output_name]));
  }
  if (syms.size() == 1) {
    net_ = syms[0];
  } else {
    net_ =
        Concat("vf_added_concat_symbol_concat_added_vf", syms, syms.size(), 1);
  }
  syms.push_back(sym);
  sym = mxnet::cpp::Symbol::Group(syms);

  // infer shape: every input gets a placeholder of shape
  // [batch_size, configured dims...]
  for (size_t i = 0; i < param_.input_shapes.size(); ++i) {
    vector<mx_uint> in_shape;
    in_shape.push_back(param_.batch_size);
    for (auto v : param_.input_shapes[i]) in_shape.push_back(mx_uint(v));
    args_map_[param_.input_names[i]] = NDArray(Shape(in_shape), mx_ctx_, false);
  }

  map<string, vector<mx_uint>> args_shapes;
  for (auto& pair : args_map_) {
    args_shapes[pair.first] = pair.second.GetShape();
  }

  vector<vector<mx_uint>> in_shapes;
  vector<vector<mx_uint>> aux_shapes;
  vector<vector<mx_uint>> out_shapes;
  sym.InferShape(args_shapes, &in_shapes, &aux_shapes, &out_shapes);
  // The group holds the configured outputs plus the original symbol, so more
  // shapes than configured outputs are expected.
  if (out_shapes.size() <= param_.output_names.size()) {
    LOG(ERROR) << "there should be more than " << param_.output_names.size()
               << " outputs, but got " << out_shapes.size();
    return Status_Error;
  }

  output_dim_ = 0;
  output_dims_.clear();
  for (size_t i = 0; i < param_.output_names.size(); ++i) {
    CHECK(out_shapes[i].size() > 1)
        << "output shape dimension should larger than 1, but got "
        << out_shapes[i].size();
    // Element count of the whole batch divided by the batch size.
    int feat_dim = 1;
    for (auto v : out_shapes[i]) feat_dim *= v;
    feat_dim /= param_.batch_size;
    output_dims_.push_back(feat_dim);
    output_dim_ += feat_dim;
  }

  output_blob_.resize(output_dim_);
  LOG(INFO) << "Output dimension: " << output_dim_;
  return Status_OK;
}

int MXNetWorker::BindExecutors() {
  // Create an executor after binding the model to input parameters.
  std::vector<NDArray> arg_arrays;
  std::vector<NDArray> grad_arrays;
  std::vector<OpReqType> grad_reqs;
  std::vector<NDArray> aux_arrays;
  std::map<std::string, NDArray> arg_grad_store;
  std::map<std::string, OpReqType> grad_req_type;

  net_.InferExecutorArrays(mx_ctx_, &arg_arrays, &grad_arrays, &grad_reqs,
                           &aux_arrays, args_map_, arg_grad_store,
                           grad_req_type, aux_map_);

  MXNetExecutor* base_exec = new MXNetExecutor(
      net_, mx_ctx_, arg_arrays, grad_arrays, grad_reqs, aux_arrays);

  vector<const char*> arg_names;
  for (const string& input_name : param_.input_names) {
    arg_names.push_back(input_name.c_str());
  }
  const vector<string> arg_list = net_.ListArguments();
  for (int k = 1; k < param_.batch_size; ++k) {
    vector<mx_uint> in_shapes;
    vector<mx_uint> shape_idx;
    vector<int> input_idxes;
    shape_idx.push_back(0);
    for (size_t i = 0; i < param_.input_shapes.size(); ++i) {
      int data_idx = -1;
      vector<mx_uint> in_shape = args_map_[param_.input_names[i]].GetShape();
      for (size_t j = 0; j < arg_list.size(); ++j) {
        if (arg_list[j] == param_.input_names[i]) {
          data_idx = j;
          break;
        }
      }
      if (data_idx < 0) {
        LOG(ERROR) << param_.input_names[i] << " is not found";
        return Status_Error;
      }
      input_idxes.push_back(data_idx);
      in_shapes.push_back(k);
      in_shapes.insert(in_shapes.end(), in_shape.begin() + 1, in_shape.end());
      shape_idx.push_back(in_shapes.size());
    }
    MXNetExecutor* exec =
        base_exec->ReshapeInput(mx_ctx_, param_.input_shapes.size(), in_shapes,
                                shape_idx, arg_names.data());
    executors_.push_back(exec);
    vector<NDArray*> input_data;
    for (int input_idx : input_idxes) {
      input_data.push_back(&(exec->arg_arrays[input_idx]));
    }
    input_datas_.push_back(input_data);
  }

  executors_.push_back(base_exec);
  vector<NDArray*> input_data;
  for (const string& input_name : param_.input_names) {
    input_data.push_back(&(args_map_[input_name]));
  }
  input_datas_.push_back(input_data);
  return Status_OK;
}

/**
 * Preprocesses every configured input of the batch into the input arrays
 * selected by the batch's item count.
 *
 * @param batch batch to feed; batch->index must be in [1, param_.batch_size]
 *              (enforced with CHECK).
 * @throws RuntimeException if the device type is neither kGPU nor kCPU.
 */
void MXNetWorker::FeedData(DLTaskBatch* batch) {
  CHECK(batch->index > 0 && batch->index <= param_.batch_size)
      << "batch size should be in [1," << param_.batch_size << "], but got "
      << batch->index;

  for (size_t i = 0; i < param_.input_shapes.size(); ++i) {
    const bool on_gpu = (kGPU == ctx_.dev_type);
    if (!on_gpu && kCPU != ctx_.dev_type) {
      throw RuntimeException("unknown device type for MXNetWorker: " +
                             to_string(ctx_.dev_type));
    }

    // feed data: the array's data pointer is the preprocessing destination,
    // so its constness has to be cast away explicitly.
    mx_float* dst =
        const_cast<mx_float*>(input_datas_[batch->index - 1][i]->GetData());
    if (on_gpu) {
      PreProcess(i, batch->index, batch->gpu_data(stream_, i), dst);
    } else {
      PreProcess(i, batch->index, batch->cpu_data(stream_, i), dst);
    }
  }
}

/**
 * Runs the forward pass for the batch and stores the results in its tasks.
 *
 * The first executor output is copied into output_blob_, where item t
 * occupies output_dim_ consecutive floats starting at t * output_dim_. Each
 * task's `result` is filled and `features` gets one entry per configured
 * output pointing into `result`:
 *  - with param_.output_argmax set, every output contributes the pair
 *    (index of the maximum, maximum value);
 *  - otherwise the item's raw values are copied and split per output.
 *
 * @param batch batch to process; batch->index is the number of items in it.
 */
void MXNetWorker::process(DLTaskBatch* batch) {
  // forward, using the executor stored for this item count
  MXNetExecutor* executor = executors_[batch->index - 1];
  executor->Forward();

  // The output is available in executor->outputs.
  auto net_output = executor->outputs[0];
  NDArray::WaitAll();
  net_output.SyncCopyToCPU(&output_blob_);

  // Drop the features of all tasks before any of them is rewritten.
  for (int t = 0; t < batch->index; ++t) {
    batch->tasks[t]->features.clear();
  }

  // write back results
  const size_t num_outputs = output_dims_.size();
  for (int t = 0; t < batch->index; ++t) {
    vector<float>& result = batch->tasks[t]->result;
    vector<DLTask::Feature>& features = batch->tasks[t]->features;
    float* const sample = output_blob_.data() + t * output_dim_;

    if (param_.output_argmax) {
      result.resize(num_outputs * 2);
      features.resize(num_outputs);

      float* cursor = sample;
      for (size_t k = 0; k < num_outputs; ++k) {
        float* const segment_end = cursor + output_dims_[k];
        float* const best = max_element(cursor, segment_end);
        float* const slot = result.data() + 2 * k;
        slot[0] = float(best - cursor);
        slot[1] = *best;
        features[k].ptr = slot;
        features[k].dim = 2;
        cursor = segment_end;
      }
    } else {
      result.resize(output_dim_);
      features.resize(num_outputs);

      memcpy(result.data(), sample, output_dim_ * sizeof(mx_float));

      float* segment = result.data();
      for (size_t k = 0; k < num_outputs; ++k) {
        features[k].ptr = segment;
        features[k].dim = output_dims_[k];
        segment += output_dims_[k];
      }
    }
  }
}

// Passes MXNetWorker to the RegisterDLWorker macro together with the string
// "mxnet" and an empty string.
// NOTE(review): presumably this registers the class under the backend name
// "mxnet"; the macro is defined elsewhere, so the meaning of both string
// arguments should be confirmed there.
RegisterDLWorker(MXNetWorker, "mxnet", "");
}  // namespace dl
}  // namespace vf
#endif
