// decoder/lattice-faster-online-decoder.cc

// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
//           2013-2014  Johns Hopkins University (Author: Daniel Povey)
//                2014  Guoguo Chen
//                2014  IMSL, PKU-HKUST (author: Wei Shi)
//                2018  Zhehuai Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc

#include <limits>
#include <queue>
#include <unordered_map>
#include <utility>

#include "decoder/lattice-faster-online-decoder.h"

namespace kaldi {

// Debugging self-check for GetBestPath().
//
// Builds a reference answer by running ShortestPath() over the raw lattice
// from GetRawLattice(), builds a candidate answer with GetBestPath(), and
// compares the two with fst::RandEquivalent (one random path, delta 0.1,
// seeded from rand()).  Prints a warning and returns false on mismatch;
// returns true otherwise.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
    bool use_final_probs) const {
  // Reference: shortest path through the complete raw lattice.  The raw
  // lattice is scoped so it is released as soon as the path is extracted.
  Lattice reference_path;
  {
    Lattice raw_lattice;
    this->GetRawLattice(&raw_lattice, use_final_probs);
    ShortestPath(raw_lattice, &reference_path);
  }

  // Candidate: the traceback-based best path.
  Lattice traced_path;
  GetBestPath(&traced_path, use_final_probs);

  const int32 kNumRandomPaths = 1;
  const BaseFloat kTolerance = 0.1;
  const bool equivalent = fst::RandEquivalent(
      reference_path, traced_path, kNumRandomPaths, kTolerance, rand());
  if (equivalent) return true;

  KALDI_WARN << "Best-path test failed";
  return false;
}

// Outputs an FST corresponding to the single best path through the lattice.
//
// The output is a linear chain built back-to-front.  First, a final state is
// created whose final weight carries the final graph cost reported by
// BestPathEnd().  Then, for every arc produced by TraceBackBestPath(), a new
// predecessor state is added with that arc leading into the current head of
// the chain.  The last head created becomes the start state.
//
// Any states already in *olat are deleted.  Returns false (leaving *olat
// empty) if BestPathEnd() found no token; returns true otherwise.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(
    Lattice *olat, bool use_final_probs) const {
  olat->DeleteStates();
  BaseFloat final_graph_cost;
  BestPathIterator cursor = BestPathEnd(use_final_probs, &final_graph_cost);
  if (cursor.Done()) return false;  // BestPathEnd() has already warned.

  StateId head = olat->AddState();
  olat->SetFinal(head, LatticeWeight(final_graph_cost, 0.0));
  while (!cursor.Done()) {
    LatticeArc arc;
    cursor = TraceBackBestPath(cursor, &arc);
    arc.nextstate = head;      // The arc leads into the current head...
    head = olat->AddState();   // ...from a freshly created predecessor.
    olat->AddArc(head, arc);
  }
  olat->SetStart(head);
  return true;
}

// Finds the best-scoring token on the most recently decoded frame and returns
// an iterator positioned on it, suitable for walking backwards with
// TraceBackBestPath().
//
// Tokens are ranked by tot_cost.  If use_final_probs is true and the
// final-cost map is non-empty, each token's final cost is added, and tokens
// absent from the map are treated as having infinite cost, so they are never
// chosen.  The chosen token's final cost (0 when final costs are not applied)
// is written to *final_cost_out when it is non-NULL.
//
// If no token qualifies, a warning is printed and the returned iterator wraps
// a NULL token.  It is an error (KALDI_ERR) to pass use_final_probs == false
// after FinalizeDecoding(), and an assertion failure to call this before any
// frame has been decoded.
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
    bool use_final_probs, BaseFloat *final_cost_out) const {
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "BestPathEnd() with use_final_probs == false";
  KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
               "You cannot call BestPathEnd if no frames were decoded.");

  // After FinalizeDecoding() the cached final costs are used.  Otherwise they
  // are computed on demand, and only when they will actually be consulted.
  unordered_map<Token *, BaseFloat> final_costs_local;
  const unordered_map<Token *, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
  const bool include_final = use_final_probs && !final_costs.empty();

  const BaseFloat kInf = std::numeric_limits<BaseFloat>::infinity();
  Token *winner = NULL;
  BaseFloat winner_cost = kInf;
  BaseFloat winner_final_cost = 0;

  // The tokens of the last frame form a singly linked list through "next".
  Token *tok = this->active_toks_.back().toks;
  while (tok != NULL) {
    BaseFloat this_final_cost = 0.0;
    BaseFloat this_cost = tok->tot_cost;
    if (include_final) {
      typename unordered_map<Token *, BaseFloat>::const_iterator it =
          final_costs.find(tok);
      if (it == final_costs.end()) {
        // No final cost: make the token ineligible.
        this_cost = kInf;
      } else {
        this_final_cost = it->second;
        this_cost += this_final_cost;
      }
    }
    if (this_cost < winner_cost) {
      winner_cost = this_cost;
      winner = tok;
      winner_final_cost = this_final_cost;
    }
    tok = tok->next;
  }

  if (winner == NULL) {
    // Not expected: most likely a code error or infinities in the
    // likelihoods.  Deliberately kept non-fatal.
    KALDI_WARN << "No final token found.";
  }
  if (final_cost_out != NULL) *final_cost_out = winner_final_cost;
  return BestPathIterator(winner, this->NumFramesDecoded() - 1);
}

// Steps one arc backwards along the best path.
//
// `iter` must not be Done() and `oarc` must be non-NULL (asserted).
//
// Among the forward links of the token's backpointer that lead to the token,
// the one with the lowest graph_cost + acoustic_cost is copied into *oarc
// (ilabel, olabel, weight).  For an emitting link (ilabel != 0), the cost
// offset of the iterator's frame is subtracted from the acoustic part of the
// weight.  oarc->nextstate is not written; the caller fills it in.
//
// The returned iterator points at the backpointer token: on the previous
// frame if the chosen link was emitting, otherwise on the same frame.
//
// A token with no backpointer yields an epsilon arc of weight One() and an
// iterator wrapping a NULL token.  KALDI_ERR is raised if no link to the
// token with cost below +infinity is found.
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(BestPathIterator iter,
                                                      LatticeArc *oarc) const {
  KALDI_ASSERT(!iter.Done() && oarc != NULL);
  Token *tok = static_cast<Token *>(iter.tok);
  const int32 frame = iter.frame;
  Token *prev_tok = tok->backpointer;

  if (prev_tok == NULL) {
    // Start of the path: nothing to trace, emit a zero-cost epsilon arc.
    oarc->ilabel = 0;
    oarc->olabel = 0;
    oarc->weight = LatticeWeight::One();  // zero costs.
    return BestPathIterator(prev_tok, frame);
  }

  int32 frame_step = 0;
  BaseFloat lowest_cost = std::numeric_limits<BaseFloat>::infinity();
  for (ForwardLinkT *link = prev_tok->links; link != NULL;
       link = link->next) {
    if (link->next_tok != tok) continue;  // This link leads elsewhere.
    BaseFloat graph_cost = link->graph_cost;
    BaseFloat acoustic_cost = link->acoustic_cost;
    const BaseFloat total_cost = graph_cost + acoustic_cost;
    if (!(total_cost < lowest_cost)) continue;  // Not an improvement.

    lowest_cost = total_cost;
    oarc->ilabel = link->ilabel;
    oarc->olabel = link->olabel;
    const bool emitting = (link->ilabel != 0);
    if (emitting) {
      KALDI_ASSERT(static_cast<size_t>(frame) < this->cost_offsets_.size());
      acoustic_cost -= this->cost_offsets_[frame];
    }
    frame_step = emitting ? -1 : 0;
    oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
  }

  if (lowest_cost == std::numeric_limits<BaseFloat>::infinity()) {
    // Did not find a link from the backpointer into `tok`.
    KALDI_ERR << "Error tracing best-path back (likely "
              << "bug in token-pruning algorithm)";
  }
  return BestPathIterator(prev_tok, frame + frame_step);
}

// Builds a raw (state-level, not determinized) lattice from the decoder's
// tokens, additionally discarding every token whose extra_cost is not
// strictly below `beam`.
//
// States are created by a breadth-first traversal of forward links.  The
// traversal starts at the start token, which is the last token in
// active_toks_[0].toks.  Each queued token carries its frame index, which
// advances by one across emitting (ilabel != 0) links.  For emitting links,
// the cost offset of the source frame is subtracted from the arc's acoustic
// cost.
//
// Final weights are set only on states whose token lies on the last frame:
//   - If use_final_probs is true and the final-cost map is non-empty, a token
//     found in the map gets LatticeWeight(final_cost, 0).  SetFinal() is not
//     called for tokens absent from the map.
//   - Otherwise every such state gets LatticeWeight::One().
//
// Arguments:
//   ofst             Output lattice; its existing states are deleted first.
//   use_final_probs  Whether to use final costs.  Passing false after
//                    FinalizeDecoding() is an error (KALDI_ERR).
//   beam             Pruning threshold applied to Token::extra_cost.
// Returns false (with a warning) if some frame has no active tokens;
// otherwise returns true iff the output lattice has at least one state.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
    Lattice *ofst, bool use_final_probs, BaseFloat beam) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;
  typedef unordered_map<Token *, StateId> TokenStateMap;

  // Note: you can't use the old interface (Decode()) if you want to
  // get the lattice with use_final_probs = false.  You'd have to do
  // InitDecoding() and then AdvanceDecoding().
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";

  // After FinalizeDecoding() the cached final costs are used.  Otherwise they
  // are computed here, and only when they will actually be consulted.
  unordered_map<Token *, BaseFloat> final_costs_local;
  const unordered_map<Token *, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);

  ofst->DeleteStates();
  // active_toks_ holds one entry per decoded frame plus one extra entry for
  // frame 0 (the start-state frame), so the number of frames is size() - 1.
  int32 num_frames = this->active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  for (int32 f = 0; f <= num_frames; f++) {
    if (this->active_toks_[f].toks == NULL) {
      KALDI_WARN << "No tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
  }

  TokenStateMap tok_map;
  std::queue<std::pair<Token *, int32> > tok_queue;
  // Seed the traversal with the initial state: the last token in the list
  // active_toks_[0].toks.
  for (Token *tok = this->active_toks_[0].toks; tok != NULL; tok = tok->next) {
    if (tok->next == NULL) {
      StateId start_state = ofst->AddState();
      tok_map[tok] = start_state;
      ofst->SetStart(start_state);
      tok_queue.push(std::pair<Token *, int32>(tok, 0));  // frame 0
    }
  }

  // Breadth-first traversal: create states and arcs for the "good" tokens.
  while (!tok_queue.empty()) {
    Token *cur_tok = tok_queue.front().first;
    int32 cur_frame = tok_queue.front().second;
    tok_queue.pop();
    // Cast avoids a signed/unsigned comparison; cur_frame >= 0 is checked
    // first, so the cast is value-preserving.
    KALDI_ASSERT(cur_frame >= 0 &&
                 static_cast<size_t>(cur_frame) <= this->cost_offsets_.size());

    typename TokenStateMap::const_iterator cur_iter = tok_map.find(cur_tok);
    KALDI_ASSERT(cur_iter != tok_map.end());
    // Copy the id now: inserting into tok_map below may rehash the table and
    // invalidate cur_iter.
    StateId cur_state = cur_iter->second;

    for (ForwardLinkT *l = cur_tok->links; l != NULL; l = l->next) {
      Token *next_tok = l->next_tok;
      if (next_tok->extra_cost < beam) {
        // Both the current and the next token are good; create the arc.
        bool emitting = (l->ilabel != 0);
        int32 next_frame = emitting ? cur_frame + 1 : cur_frame;

        // Single hash lookup: insert a placeholder and fill it in (creating
        // a new state and queueing the token) only if next_tok is new.
        std::pair<typename TokenStateMap::iterator, bool> ins =
            tok_map.insert(std::make_pair(next_tok, StateId(0)));
        if (ins.second) {
          ins.first->second = ofst->AddState();
          tok_queue.push(std::pair<Token *, int32>(next_tok, next_frame));
        }
        StateId nextstate = ins.first->second;

        BaseFloat cost_offset = 0.0;
        if (emitting) {
          // Same bound check TraceBackBestPath() performs before indexing.
          KALDI_ASSERT(static_cast<size_t>(cur_frame) <
                       this->cost_offsets_.size());
          cost_offset = this->cost_offsets_[cur_frame];
        }
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(cur_state, arc);
      }
    }

    if (cur_frame == num_frames) {
      if (use_final_probs && !final_costs.empty()) {
        typename unordered_map<Token *, BaseFloat>::const_iterator final_iter =
            final_costs.find(cur_tok);
        // Tokens without a final cost do not get SetFinal() called.
        if (final_iter != final_costs.end())
          ofst->SetFinal(cur_state, LatticeWeight(final_iter->second, 0));
      } else {
        ofst->SetFinal(cur_state, LatticeWeight::One());
      }
    }
  }
  return (ofst->NumStates() != 0);
}

// Explicitly instantiate the template for the FST types that we'll need:
// the generic fst::Fst<fst::StdArc> interface plus the concrete VectorFst and
// ConstFst types.  Since the member definitions live in this .cc file, other
// translation units can presumably only use these instantiations (verify
// against the header before adding new FST types elsewhere).
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;

}  // end namespace kaldi.
