/**
 * Copyright (c) 2020 xxx Inc.
 * File              : mxnet-executor.h
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#ifndef VIDEOFLOW_SRC_DL_MXNET_EXECUTOR_H__
#define VIDEOFLOW_SRC_DL_MXNET_EXECUTOR_H__

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

//
#include <mxnet-cpp/MxNetCpp.h>

using mxnet::cpp::NDArray;
using mxnet::cpp::OpReqType;
using mxnet::cpp::Symbol;
using std::map;
using std::string;
using std::vector;

namespace vf {
namespace dl {

/*!
 * \brief Executor interface
 */
class MXNetExecutor {
 public:
  MXNetExecutor(const Symbol &symbol, mxnet::cpp::Context context,
                const vector<NDArray> &arg_arrays,
                const vector<NDArray> &grad_arrays,
                const vector<OpReqType> &grad_reqs,
                const vector<NDArray> &aux_arrays) {
    this->arg_arrays = arg_arrays;
    this->grad_arrays = grad_arrays;
    this->aux_arrays = aux_arrays;
    this->symbol_ = symbol;

    vector<NDArrayHandle> arg_handles;
    vector<NDArrayHandle> grad_handles;
    vector<NDArrayHandle> aux_handles;

    for (const auto &array : arg_arrays) {
      arg_handles.push_back(array.GetHandle());
    }
    for (const auto &array : grad_arrays) {
      grad_handles.push_back(array.GetHandle());
    }
    for (const auto &array : aux_arrays) {
      aux_handles.push_back(array.GetHandle());
    }

    vector<mx_uint> grad_reqs_uint;
    for (auto s : grad_reqs) grad_reqs_uint.push_back(s);
    map<string, mxnet::cpp::Context> group_to_ctx;
    vector<const char *> map_keys;
    vector<int> dev_types, dev_ids;

    CHECK_EQ(MXExecutorBindEX(
                 symbol.GetHandle(), context.GetDeviceType(),
                 context.GetDeviceId(), group_to_ctx.size(), map_keys.data(),
                 dev_types.data(), dev_ids.data(), arg_handles.size(),
                 arg_handles.data(), grad_handles.data(), grad_reqs_uint.data(),
                 aux_handles.size(), aux_handles.data(), nullptr, &handle_),
             0);

    mx_uint out_size;
    NDArrayHandle *out_array;
    CHECK_EQ(MXExecutorOutputs(handle_, &out_size, &out_array), 0);
    for (mx_uint i = 0; i < out_size; ++i) {
      outputs.push_back(NDArray(out_array[i]));
    }
  }

  explicit MXNetExecutor(ExecutorHandle handle) { handle_ = handle; }

  MXNetExecutor *ReshapeInput(mxnet::cpp::Context ctx, size_t input_num,
                              const vector<mx_uint> &in_shape,
                              const vector<mx_uint> &shape_idx,
                              char const **arg_names) {
    vector<const char *> map_keys;
    vector<int> dev_types, dev_ids;
    mx_uint num_in_args;
    NDArrayHandle *in_args = nullptr;
    NDArrayHandle *arg_grads = nullptr;
    mx_uint num_aux_states = 0;
    NDArrayHandle *aux_states = nullptr;

    ExecutorHandle new_handle = nullptr;

    CHECK_EQ(
        MXExecutorReshape(1, 0, ctx.GetDeviceType(), ctx.GetDeviceId(), 0,
                          map_keys.data(), dev_types.data(), dev_ids.data(),
                          input_num, arg_names, in_shape.data(),
                          shape_idx.data(), &num_in_args, &in_args, &arg_grads,
                          &num_aux_states, &aux_states, handle_, &new_handle),
        0);
    MXNetExecutor *exec = new MXNetExecutor(new_handle);
    for (mx_uint i = 0; i < num_in_args; ++i) {
      exec->arg_arrays.push_back(NDArray(in_args[i]));
    }
    if (nullptr != arg_grads) {
      for (mx_uint i = 0; i < num_in_args; ++i) {
        exec->grad_arrays.push_back(NDArray(arg_grads[i]));
      }
    }

    for (mx_uint i = 0; i < num_aux_states; ++i) {
      exec->aux_arrays.push_back(NDArray(aux_states[i]));
    }
    exec->symbol_ = symbol_;
    mx_uint out_size;
    NDArrayHandle *out_array;
    CHECK_EQ(MXExecutorOutputs(new_handle, &out_size, &out_array), 0);
    for (mx_uint i = 0; i < out_size; ++i) {
      exec->outputs.push_back(NDArray(out_array[i]));
    }
    return exec;
  }

  /*!
   * \brief Perform a Forward operation of Operator
   *  After this operation, user can get the result by using function head.
   */
  void Forward() {
    MXExecutorForward(handle_, 0);
    mx_uint out_size;
    NDArrayHandle *out_array;
    CHECK_EQ(MXExecutorOutputs(handle_, &out_size, &out_array), 0);
    for (mx_uint i = 0; i < out_size; ++i) {
      outputs[i] = NDArray(out_array[i]);
    }
  }

  /*!
   * \brief destructor, free the handle
   */
  ~MXNetExecutor() { MXExecutorFree(handle_); }
  vector<NDArray> arg_arrays;
  vector<NDArray> grad_arrays;
  vector<NDArray> aux_arrays;
  /*!
   * \brief arrays store the outputs of forward
   */
  vector<NDArray> outputs;
  map<string, NDArray> arg_dict() {
    return GetDict(symbol_.ListArguments(), arg_arrays);
  }
  map<string, NDArray> grad_dict() {
    return GetDict(symbol_.ListArguments(), grad_arrays);
  }
  map<string, NDArray> aux_dict() {
    return GetDict(symbol_.ListAuxiliaryStates(), aux_arrays);
  }

 private:
  MXNetExecutor(const MXNetExecutor &e);
  MXNetExecutor &operator=(const MXNetExecutor &e);
  ExecutorHandle handle_;
  Symbol symbol_;
  map<string, NDArray> GetDict(const vector<string> &names,
                               const vector<NDArray> &arrays) {
    map<string, NDArray> ret;
    std::set<string> name_set;
    for (const auto &s : names) {
      CHECK(name_set.find(s) == name_set.end())
          << "Duplicate names detected, " << s;
      name_set.insert(s);
    }
    CHECK_EQ(name_set.size(), arrays.size())
        << "names size not equal to arrays size";
    for (size_t i = 0; i < names.size(); ++i) {
      ret[names[i]] = arrays[i];
    }
    return ret;
  }
};

}  // namespace dl
}  // namespace vf
#endif
