/**
 * Copyright (c) 2020 xxx Inc.
 * File              : pytorch-worker.cc
 * Author            : 
 * Date              : 2020-05-07
 * Last Modified Date: 2020-05-07
 * Last Modified By  : 
 */
#include "vf/vf-def.h"

#if USE_PyTorch

#include <vector>

//
#include <torch/script.h>
//
#include <cxxutil/logging.h>

#include "vf/dl/dl-worker.h"

using namespace cxxutil;
using namespace std;

namespace vf {
namespace dl {

// DLWorker backend that runs TorchScript models through libtorch.
//
// Init() loads the module from param_.weight_path, allocates one input buffer
// per configured input and validates the model outputs with a warm-up pass;
// FeedData() fills the input buffers and process() runs inference per batch.
class PytorchWorker : public DLWorker {
 public:
  PytorchWorker(const DLWorkerParam& param, const vf::Context& ctx)
      : DLWorker(param, ctx) {}
  // Releases the input buffers allocated in Init().
  virtual ~PytorchWorker();

  // input_blobs_ holds raw buffers that the destructor frees; a copy would
  // share those pointers and free them twice, so copying is disabled.
  PytorchWorker(const PytorchWorker&) = delete;
  PytorchWorker& operator=(const PytorchWorker&) = delete;

  int Init() override;

 protected:
  void FeedData(DLTaskBatch* batch) override;
  void process(DLTaskBatch* batch) override;
  // Runs the module on `input_tensors` and collects its tensor outputs into
  // `*output_tensors`. Returns Status_OK, or Status_InvalidFormat (after
  // logging) when the module output is not a tensor / tensor list / tuple of
  // tensors.
  int forward(vector<torch::IValue>& input_tensors,
              c10::List<at::Tensor>* output_tensors);

 protected:
  // Total of output_dims_, computed in Init().
  int output_dim_ = 0;
  // Per-sample feature length of each model output, computed in Init().
  vector<int> output_dims_;
  // float32, strided, no-grad; placed on CUDA device ctx_.dev_id for kGPU.
  torch::TensorOptions input_options_;
#if PyTorch_OLD
  std::shared_ptr<torch::jit::script::Module> module_;
#else
  torch::jit::script::Module module_;
#endif
  // One buffer per input obtained with VFMalloc; freed in the destructor.
  vector<float*> input_blobs_;
  // Tensors wrapping input_blobs_ via torch::from_blob (non-owning).
  vector<torch::IValue> input_tensors_;
  // Shape of each input tensor: {batch_size, configured input shape...}.
  vector<vector<int64_t>> input_shapes_;
  // Host staging buffers for outputs copied back from the GPU in process().
  vector<vector<float>> output_blobs_;
};

// Returns every input buffer obtained from VFMalloc in Init() to the
// allocator of the worker's context. The tensors in input_tensors_ only wrap
// these buffers (torch::from_blob), so they do not free them themselves.
PytorchWorker::~PytorchWorker() {
  const size_t blob_count = input_blobs_.size();
  for (size_t idx = 0; idx < blob_count; ++idx) {
    VFFree(input_blobs_[idx], ctx_);
  }
}

int PytorchWorker::Init() {
  module_ = torch::jit::load(param_.weight_path);

#if PyTorch_OLD
  if (nullptr == module_) {
    LOG_ERROR << "Load model from " << param_.weight_path << " failed";
    return Status_IOError;
  }
#endif

  input_options_ = torch::TensorOptions()
                       .dtype(torch::kFloat32)
                       .layout(torch::kStrided)
                       .requires_grad(false);

  if (kGPU == ctx_.dev_type) {
#if PyTorch_OLD
    module_->to(at::Device(at::kCUDA, ctx_.dev_id));
#else
    module_.to(at::Device(at::kCUDA, ctx_.dev_id));
#endif
    input_options_ = input_options_.device(torch::kCUDA, ctx_.dev_id);
  }
  int ret = DLWorker::Init();
  if (Status_OK != ret) return ret;

  input_tensors_.resize(param_.input_names.size());
  for (size_t i = 0; i < param_.input_names.size(); ++i) {
    float* input_blob = nullptr;
    // allocate memory for input blob
    VFMalloc((void**)(&input_blob),
             param_.batch_size * input_dims_[i] * sizeof(float), ctx_);
    input_blobs_.push_back(input_blob);
    vector<int64_t> input_shape;
    input_shape.push_back(param_.batch_size);
    for (int v : param_.input_shapes[i]) input_shape.push_back(v);
    input_tensors_[i] = torch::from_blob(
        input_blob,
        at::ArrayRef<int64_t>(input_shape.data(), input_shape.size()),
        input_options_);
    input_shapes_.push_back(input_shape);
  }

  // forward one pass
  c10::List<at::Tensor> output_tensors;
  ret = forward(input_tensors_, &output_tensors);
  if (Status_OK != ret) return ret;

  if (output_tensors.size() != param_.output_names.size()) {
    LOG_ERROR << "pytorch model has " << output_tensors.size()
              << " outputs, but " << param_.output_names.size()
              << " outputs are specified in the configuration file.";
    return Status_InvalidFormat;
  }

  output_dim_ = 0;
  for (size_t i = 0; i < output_tensors.size(); ++i) {
    const at::Tensor& out_tensor = output_tensors[i];
    if (out_tensor.dim() < 2) {
      LOG_ERROR << "dimension of outuptu " << i
                << " should be at least 2, but got " << out_tensor.dim();
      return Status_InvalidFormat;
    }
    //  check output dtype
    if (strcmp(out_tensor.dtype().name(), "float")) {
      LOG_ERROR << "PyTorchWorker current only supports float output, but got "
                << out_tensor.dtype().name();
      return Status_InvalidFormat;
    }
    // check output batch size
    if (out_tensor.size(0) != param_.batch_size) {
      LOG_ERROR << "the first dimension size of output " << i
                << " should be equal to batch size(" << param_.batch_size
                << "), but got " << out_tensor.size(0);
      return Status_InvalidFormat;
    }
    int feat_dim = 0;
    for (auto k = 1; k < out_tensor.dim(); ++k) {
      feat_dim += out_tensor.size(k);
    }
    output_dims_.push_back(feat_dim);
    output_dim_ += feat_dim;
  }
  LOG_INFO << "Output dimension: " << output_dim_;

  return Status_OK;
}

// Fills the preallocated input buffers from `batch`.
//
// For every input i, PreProcess() is called with the batch's device-side
// (kGPU) or host-side (kCPU) data for that input and input_blobs_[i] as the
// destination. Any other device type throws RuntimeException. The CHECK
// requires 1 <= batch->index <= param_.batch_size.
void PytorchWorker::FeedData(DLTaskBatch* batch) {
  CHECK(batch->index > 0 && batch->index <= param_.batch_size)
      << "batch size should be in [1," << param_.batch_size << "], but got "
      << batch->index;

  for (size_t i = 0; i < input_shapes_.size(); ++i) {
    // feed data
    if (kGPU == ctx_.dev_type) {
      PreProcess(i, batch->index, batch->gpu_data(stream_, i), input_blobs_[i]);
    } else if (kCPU == ctx_.dev_type) {
      PreProcess(i, batch->index, batch->cpu_data(stream_, i), input_blobs_[i]);
    } else {
      throw RuntimeException("unknown device type for PytorchWorker: " +
                             to_string(ctx_.dev_type));
    }
  }
}

// Runs inference on the data previously written into the input buffers and
// passes the first batch->index samples of every output to WriteResult().
//
// The model always runs on a full batch of param_.batch_size rows, because
// the input buffers and the shape validated in Init() are sized for it. Rows
// at or beyond batch->index are not copied out. If forward() or the
// device-to-host copy fails, the error is logged and WriteResult() is not
// called.
void PytorchWorker::process(DLTaskBatch* batch) {
  for (size_t i = 0; i < param_.input_names.size(); ++i) {
    input_shapes_[i][0] = param_.batch_size;
    input_tensors_[i] = torch::from_blob(
        input_blobs_[i],
        at::ArrayRef<int64_t>(input_shapes_[i].data(), input_shapes_[i].size()),
        input_options_);
  }
  c10::List<at::Tensor> output_tensors;
  const int ret = forward(input_tensors_, &output_tensors);
  if (Status_OK != ret) return;  // forward() has already logged the reason.

  const size_t output_num = output_tensors.size();
  vector<const float*> output_ptrs;
  output_ptrs.reserve(output_num);
  // Owns the host tensors that output_ptrs points into (CPU path) until
  // WriteResult() has consumed them.
  vector<at::Tensor> host_outputs;
  if (kGPU == ctx_.dev_type) {
    output_blobs_.resize(output_num);
    for (size_t i = 0; i < output_num; ++i) {
      // data_ptr() of a non-contiguous tensor (e.g. a transposed view) does
      // not yield samples in row-major order, so force a dense layout first.
      const at::Tensor tensor = output_tensors.get(i).contiguous();
      const size_t tensor_size =
          static_cast<size_t>(batch->index) * output_dims_[i];
      output_blobs_[i].resize(tensor_size);
      const cudaError_t err =
          cudaMemcpy(output_blobs_[i].data(), tensor.data_ptr<float>(),
                     tensor_size * sizeof(float), cudaMemcpyDefault);
      if (cudaSuccess != err) {
        LOG_ERROR << "failed to copy output " << i
                  << " from device: " << cudaGetErrorString(err);
        return;
      }
      output_ptrs.push_back(output_blobs_[i].data());
    }
  } else {
    host_outputs.reserve(output_num);
    for (size_t i = 0; i < output_num; ++i) {
      host_outputs.push_back(output_tensors.get(i).contiguous());
      output_ptrs.push_back(host_outputs.back().data_ptr<float>());
    }
  }

  WriteResult(output_ptrs, output_dim_, output_dims_, batch);
}

// Executes the loaded module on `input_tensors` and gathers the result into
// `*output_tensors`.
//
// Accepted module outputs:
//   * a single Tensor             -> appended to *output_tensors
//   * a List[Tensor]              -> replaces *output_tensors
//   * a Tuple whose items are all Tensors -> each appended in order
// Anything else (or a tuple containing a non-Tensor item) is logged and
// reported as Status_InvalidFormat.
int PytorchWorker::forward(vector<torch::IValue>& input_tensors,
                           c10::List<at::Tensor>* output_tensors) {
#if PyTorch_OLD
  const torch::IValue result = module_->forward(input_tensors);
#else
  const torch::IValue result = module_.forward(input_tensors);
#endif

  if (result.isTensor()) {
    output_tensors->push_back(result.toTensor());
    return Status_OK;
  }

  if (result.isTensorList()) {
    *output_tensors = result.toTensorList();
    return Status_OK;
  }

  if (!result.isTuple()) {
    LOG_ERROR << "unsupported output value type: " << result.type()->str();
    return Status_InvalidFormat;
  }

  // Keep a handle on the tuple while iterating over its items.
  const auto tuple = result.toTuple();
  for (const auto& item : tuple->elements()) {
    if (!item.isTensor()) {
      LOG_ERROR << "output elements in tuple must be Tensor, but got "
                << item.type()->str();
      return Status_InvalidFormat;
    }
    output_tensors->push_back(item.toTensor());
  }
  return Status_OK;
}

// Presumably registers PytorchWorker with the DL worker factory under the
// name "pytorch". The macro is defined outside this file; the meaning of the
// empty third argument is not visible here — TODO confirm against its definition.
RegisterDLWorker(PytorchWorker, "pytorch", "");

}  // namespace dl
}  // namespace vf

#endif
