// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "grpc/grpc_server.h"

namespace wenet {

using grpc::ServerReaderWriter;
using wenet::Request;
using wenet::Response;

GrpcConnectionHandler::GrpcConnectionHandler(
    ServerReaderWriter<Response, Request>* stream,
    std::shared_ptr<Request> request, std::shared_ptr<Response> response,
    std::shared_ptr<FeaturePipelineConfig> feature_config,
    std::shared_ptr<DecodeOptions> decode_config,
    std::shared_ptr<fst::SymbolTable> symbol_table,
    std::shared_ptr<TorchAsrModel> model,
    std::shared_ptr<fst::Fst<fst::StdArc>> fst)
    : stream_(std::move(stream)),
      request_(std::move(request)),
      response_(std::move(response)),
      feature_config_(std::move(feature_config)),
      decode_config_(std::move(decode_config)),
      symbol_table_(std::move(symbol_table)),
      model_(std::move(model)),
      fst_(std::move(fst)) {}

void GrpcConnectionHandler::OnSpeechStart() {
  LOG(INFO) << "Recieved speech start signal, start reading speech";
  got_start_tag_ = true;
  response_->set_status(Response::ok);
  response_->set_type(Response::server_ready);
  stream_->Write(*response_);
  feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
  decoder_ = std::make_shared<TorchAsrDecoder>(
      feature_pipeline_, model_, symbol_table_, *decode_config_, fst_);
  // Start decoder thread
  decode_thread_ = std::make_shared<std::thread>(
      &GrpcConnectionHandler::DecodeThreadFunc, this);
}

void GrpcConnectionHandler::OnSpeechEnd() {
  LOG(INFO) << "Recieved speech end signal";
  CHECK(feature_pipeline_ != nullptr);
  feature_pipeline_->set_input_finished();
  got_end_tag_ = true;
}

void GrpcConnectionHandler::OnPartialResult() {
  LOG(INFO) << "Partial result";
  response_->set_status(Response::ok);
  response_->set_type(Response::partial_result);
  stream_->Write(*response_);
}

void GrpcConnectionHandler::OnFinalResult() {
  LOG(INFO) << "Final result";
  response_->set_status(Response::ok);
  response_->set_type(Response::final_result);
  stream_->Write(*response_);
}

void GrpcConnectionHandler::OnFinish() {
  // Send finish tag
  response_->set_status(Response::ok);
  response_->set_type(Response::speech_end);
  stream_->Write(*response_);
}

void GrpcConnectionHandler::OnSpeechData() {
  // Read binary PCM data
  const int16_t* pdata =
      reinterpret_cast<const int16_t*>(request_->audio_data().c_str());
  int num_samples = request_->audio_data().length() / sizeof(int16_t);
  std::vector<float> pcm_data(num_samples);
  for (int i = 0; i < num_samples; i++) {
    pcm_data[i] = static_cast<float>(*pdata);
    pdata++;
  }
  VLOG(2) << "Recieved " << num_samples << " samples";
  CHECK(feature_pipeline_ != nullptr);
  CHECK(decoder_ != nullptr);
  feature_pipeline_->AcceptWaveform(pcm_data);
}

void GrpcConnectionHandler::SerializeResult(bool finish) {
  for (const DecodeResult& path : decoder_->result()) {
    Response_OneBest* one_best_ = response_->add_nbest();
    one_best_->set_sentence(path.sentence);
    if (finish) {
      for (const WordPiece& word_piece : path.word_pieces) {
        Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
        one_piece_->set_word(word_piece.word);
        one_piece_->set_start(word_piece.start);
        one_piece_->set_end(word_piece.end);
      }
    }
    if (response_->nbest_size() == nbest_) {
      break;
    }
  }
  return;
}

void GrpcConnectionHandler::DecodeThreadFunc() {
  while (true) {
    DecodeState state = decoder_->Decode();
    response_->clear_status();
    response_->clear_type();
    response_->clear_nbest();
    if (state == DecodeState::kEndFeats) {
      decoder_->Rescoring();
      SerializeResult(true);
      OnFinalResult();
      OnFinish();
      stop_recognition_ = true;
      break;
    } else if (state == DecodeState::kEndpoint) {
      decoder_->Rescoring();
      SerializeResult(true);
      OnFinalResult();
      // If it's not continuous decoidng, continue to do next recognition
      // otherwise stop the recognition
      if (continuous_decoding_) {
        decoder_->ResetContinuousDecoding();
      } else {
        OnFinish();
        stop_recognition_ = true;
        break;
      }
    } else {
      if (decoder_->DecodedSomething()) {
        SerializeResult(false);
        OnPartialResult();
      }
    }
  }
}

void GrpcConnectionHandler::operator()() {
  try {
    while (stream_->Read(request_.get())) {
      if (!got_start_tag_) {
        nbest_ = request_->decode_config().nbest_config();
        continuous_decoding_ =
            request_->decode_config().continuous_decoding_config();
        OnSpeechStart();
      } else {
        OnSpeechData();
      }
    }
    OnSpeechEnd();
    LOG(INFO) << "Read all pcm data, wait for decoding thread";
    if (decode_thread_ != nullptr) {
      decode_thread_->join();
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}

Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request" << std::endl;
  auto request = std::make_shared<Request>();
  auto response = std::make_shared<Response>();
  GrpcConnectionHandler handler(stream, request, response, feature_config_,
                                decode_config_, symbol_table_, model_, fst_);
  std::thread t(std::move(handler));
  t.join();
  return Status::OK;
}
}  // namespace wenet
