// fstext/lattice-weight.h
// Copyright 2009-2012  Microsoft Corporation
//                      Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
#define KALDI_FSTEXT_LATTICE_WEIGHT_H_

#include <algorithm>
#include <limits>
#include <string>
#include <vector>

#include "base/kaldi-common.h"
#include "fst/fstlib.h"

namespace fst {

// Declares the weight type for lattices (it will be imported into namespace
// kaldi).  It has two members, value1_ and value2_, of type FloatType
// (normally float).  It is basically the same as the tropical semiring on
// value1_+value2_, except that it keeps track of value1_ and value2_
// separately.  More precisely, it is equivalent to the lexicographic semiring
// on the pair (value1_+value2_), (value1_-value2_).

// Forward declaration of the lattice weight class template, so that the
// stream operators below can be declared before the class (which names them
// as friends) is defined.
template <class FloatType>
class LatticeWeightTpl;

// Text-form output operator for LatticeWeightTpl; defined later in this file.
template <class FloatType>
inline std::ostream &operator<<(std::ostream &strm,
                                const LatticeWeightTpl<FloatType> &w);

// Text-form input operator for LatticeWeightTpl; defined later in this file.
template <class FloatType>
inline std::istream &operator>>(std::istream &strm,
                                LatticeWeightTpl<FloatType> &w);

/// Lattice weight: a pair of costs (value1_, value2_) of type FloatType.
///
/// Zero() is (+inf, +inf) and One() is (0, 0).  Member() accepts a pair only
/// if neither component is NaN or -inf, and if either both or neither of the
/// components is +inf (so that the semiring has a single zero).
template <class FloatType>
class LatticeWeightTpl {
 public:
  typedef FloatType T;  // normally float.
  typedef LatticeWeightTpl ReverseWeight;

  /// Returns the first cost component.
  inline T Value1() const { return value1_; }

  /// Returns the second cost component.
  inline T Value2() const { return value2_; }

  /// Overwrites the first cost component.
  inline void SetValue1(T f) { value1_ = f; }

  /// Overwrites the second cost component.
  inline void SetValue2(T f) { value2_ = f; }

  /// Value-initializes both components (i.e. to zero).
  LatticeWeightTpl() : value1_{}, value2_{} {}

  /// Constructs the weight (a, b).
  LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}

  LatticeWeightTpl(const LatticeWeightTpl &other)
      : value1_(other.value1_), value2_(other.value2_) {}

  LatticeWeightTpl &operator=(const LatticeWeightTpl &w) {
    value1_ = w.value1_;
    value2_ = w.value2_;
    return *this;
  }

  /// The reverse of a lattice weight is the weight itself.
  LatticeWeightTpl<FloatType> Reverse() const { return *this; }

  /// The semiring zero: both components are +infinity.
  static const LatticeWeightTpl Zero() {
    return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
                            std::numeric_limits<T>::infinity());
  }

  /// The semiring one: both components are 0.
  static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }

  /// Type name: "lattice4" when sizeof(T) == 4, otherwise "lattice8".
  static const std::string &Type() {
    static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
    return type;
  }

  /// A non-member weight: both components are quiet NaN.
  static const LatticeWeightTpl NoWeight() {
    return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
                            std::numeric_limits<FloatType>::quiet_NaN());
  }

  /// Returns true if this is a valid element of the semiring.
  bool Member() const {
    // x != x is true only for NaN.
    if (value1_ != value1_ || value2_ != value2_) return false;  // NaN
    if (value1_ == -std::numeric_limits<T>::infinity() ||
        value2_ == -std::numeric_limits<T>::infinity())
      return false;  // -infty not allowed
    if (value1_ == std::numeric_limits<T>::infinity() ||
        value2_ == std::numeric_limits<T>::infinity()) {
      if (value1_ != std::numeric_limits<T>::infinity() ||
          value2_ != std::numeric_limits<T>::infinity())
        return false;  // both must be +infty;
      // this is necessary so that the semiring has only one zero.
    }
    return true;
  }

  /// Rounds each component to the nearest multiple of delta.  If the sum of
  /// the components is -inf, +inf or NaN, both components of the result are
  /// set to that value instead.
  LatticeWeightTpl Quantize(float delta = kDelta) const {
    if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
      return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
                              -std::numeric_limits<T>::infinity());
    } else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
      return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
                              std::numeric_limits<T>::infinity());
    } else if (value1_ + value2_ != value1_ + value2_) {  // NaN
      return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
    } else {
      return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
                              floor(value2_ / delta + 0.5F) * delta);
    }
  }

  /// Semiring property bits advertised to OpenFst.
  static constexpr uint64 Properties() {
    return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
  }

  // This is used in OpenFst for binary I/O.  This is OpenFst-style,
  // not Kaldi-style, I/O.
  // NOTE(review): both values are handed to ReadType as type T, so the stored
  // width presumably follows T (compare Type(), which distinguishes
  // "lattice4" from "lattice8").  An earlier comment here said the values are
  // always stored as float; the code performs no such conversion.
  std::istream &Read(std::istream &strm) {
    ReadType(strm, &value1_);
    ReadType(strm, &value2_);
    return strm;
  }

  // This is used in OpenFst for binary I/O.  This is OpenFst-style,
  // not Kaldi-style, I/O.  Writes value1_ followed by value2_.
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, value1_);
    WriteType(strm, value2_);
    return strm;
  }

  /// Hash value: the sum of the two components' raw bit patterns, each read
  /// back through a union with size_t.
  size_t Hash() const {
    size_t ans;
    union {
      T f;
      size_t s;
    } u;
    // Zero the union first so that, when T is narrower than size_t, the bytes
    // not covered by T start out as zero.
    u.s = 0;
    u.f = value1_;
    ans = u.s;
    u.f = value2_;
    ans += u.s;
    return ans;
  }

 protected:
  // Writes one component in text form; +inf, -inf and NaN are written as
  // "Infinity", "-Infinity" and "BadNumber" respectively.
  inline static void WriteFloatType(std::ostream &strm, const T &f) {
    if (f == std::numeric_limits<T>::infinity())
      strm << "Infinity";
    else if (f == -std::numeric_limits<T>::infinity())
      strm << "-Infinity";
    else if (f != f)
      strm << "BadNumber";
    else
      strm << f;
  }

  // Internal helper function, used in ReadNoParen.  Reads one
  // whitespace-delimited token and converts it to a component; accepts the
  // special tokens produced by WriteFloatType.  Sets badbit on strm if the
  // token has characters that strtod did not consume.
  inline static void ReadFloatType(std::istream &strm, T &f) {  // NOLINT
    std::string s;
    strm >> s;
    if (s == "Infinity") {
      f = std::numeric_limits<T>::infinity();
    } else if (s == "-Infinity") {
      f = -std::numeric_limits<T>::infinity();
    } else if (s == "BadNumber") {
      f = std::numeric_limits<T>::quiet_NaN();
    } else {
      char *p;
      f = strtod(s.c_str(), &p);
      if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
    }
  }

  // Reads LatticeWeight when there are no parentheses around pair terms...
  // currently the only form supported.  The expected form is
  // "<value1><separator><value2>", optionally preceded by whitespace.
  // On a malformed first value, or if the input ends before the separator,
  // badbit is set on strm.
  inline std::istream &ReadNoParen(std::istream &strm, char separator) {
    int c;
    do {
      c = strm.get();
    } while (isspace(c));

    // Collect everything up to (not including) the separator.
    std::string s1;
    while (c != separator) {
      if (c == EOF) {
        strm.clear(std::ios::badbit);
        return strm;
      }
      s1 += c;
      c = strm.get();
    }
    std::istringstream strm1(s1);
    ReadFloatType(strm1, value1_);  // ReadFloatType is class member function
    // The first value was parsed from a temporary stream, so a failure there
    // (empty or unparsable text) has to be reported on the caller's stream.
    if (strm1.fail()) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    // read second element
    ReadFloatType(strm, value2_);
    return strm;
  }

  friend std::istream &operator>>
      <FloatType>(std::istream &, LatticeWeightTpl<FloatType> &);
  friend std::ostream &operator<<<FloatType>(
      std::ostream &, const LatticeWeightTpl<FloatType> &);

 private:
  T value1_;
  T value2_;
};

/* ScaleTupleWeight is defined for LatticeWeightTpl and CompactLatticeWeightTpl.
   It multiplies the pair (value1_, value2_), viewed as a vector, by the 2x2
   matrix "scale".  Used, for example, in applying acoustic scaling.
 */
template <class FloatType, class ScaleFloatType>
inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
    const LatticeWeightTpl<FloatType> &w,
    const std::vector<std::vector<ScaleFloatType> > &scale) {
  typedef LatticeWeightTpl<FloatType> Weight;
  const FloatType first = w.Value1();
  // An infinite first component maps straight to Zero(); without this special
  // case infinity * 0 below would produce NaNs.
  if (first == std::numeric_limits<FloatType>::infinity())
    return Weight::Zero();
  const FloatType second = w.Value2();
  const std::vector<ScaleFloatType> &row0 = scale[0];
  const std::vector<ScaleFloatType> &row1 = scale[1];
  return Weight(row0[0] * first + row0[1] * second,
                row1[0] * first + row1[1] * second);
}

/* For testing purposes, and in case it is ever useful, a similar function is
   defined for LexicographicWeight and the like, templated on
   TropicalWeight<float> etc.; it takes PairWeight, which is the base class of
   LexicographicWeight.
*/
template <class FloatType, class ScaleFloatType>
inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
                                  TropicalWeightTpl<FloatType> > &w,
                 const std::vector<std::vector<ScaleFloatType> > &scale) {
  typedef TropicalWeightTpl<FloatType> Tropical;
  typedef PairWeight<Tropical, Tropical> Pair;
  const Tropical tropical_zero = Tropical::Zero();
  // If either component is Zero(), return the all-zero pair; otherwise
  // infinity * 0 below would produce NaNs.
  if (w.Value1() == tropical_zero || w.Value2() == tropical_zero) {
    return Pair(tropical_zero, tropical_zero);
  }
  FloatType first = w.Value1().Value();
  FloatType second = w.Value2().Value();
  return Pair(Tropical(scale[0][0] * first + scale[0][1] * second),
              Tropical(scale[1][0] * first + scale[1][1] * second));
}

template <class FloatType>
inline bool operator==(const LatticeWeightTpl<FloatType> &wa,
                       const LatticeWeightTpl<FloatType> &wb) {
  // The copies are volatile on purpose: this thwarts over-aggressive compiler
  // optimizations that lead to problems, especially with NaturalLess().
  volatile FloatType a_first = wa.Value1();
  volatile FloatType a_second = wa.Value2();
  volatile FloatType b_first = wb.Value1();
  volatile FloatType b_second = wb.Value2();
  if (a_first != b_first) return false;
  return a_second == b_second;
}

template <class FloatType>
inline bool operator!=(const LatticeWeightTpl<FloatType> &wa,
                       const LatticeWeightTpl<FloatType> &wb) {
  // The copies are volatile on purpose: this thwarts over-aggressive compiler
  // optimizations that lead to problems, especially with NaturalLess().
  volatile FloatType a_first = wa.Value1();
  volatile FloatType a_second = wa.Value2();
  volatile FloatType b_first = wb.Value1();
  volatile FloatType b_second = wb.Value2();
  if (a_first != b_first) return true;
  return a_second != b_second;
}

// We define a Compare function LatticeWeightTpl even though it's
// not required by the semiring standard-- it's just more efficient
// to do it this way rather than using the NaturalLess template.

/// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.

template <class FloatType>
inline int Compare(const LatticeWeightTpl<FloatType> &w1,
                   const LatticeWeightTpl<FloatType> &w2) {
  FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
  if (f1 < f2) {  // having smaller cost means you're larger
    return 1;
  } else if (f1 > f2) {  // in the semiring [higher probability]
    return -1;
  } else if (w1.Value1() < w2.Value1()) {
  // mathematically we should be comparing (w1.value1_-w1.value2_ <
  // w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
  // w2.value1_+w2.value2_ to both sides and divide by two, and we get the
  // simpler equivalent form w1.value1_ < w2.value1_.
    return 1;
  } else if (w1.Value1() > w2.Value1()) {
    return -1;
  } else {
    return 0;
  }
}

template <class FloatType>
inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType> &w1,
                                        const LatticeWeightTpl<FloatType> &w2) {
  // Semiring addition keeps whichever weight Compare() ranks higher; w1 is
  // returned when the two compare equal.
  if (Compare(w1, w2) >= 0) return w1;
  return w2;
}

// For efficiency, override the NaturalLess template class.
template <class FloatType>
class NaturalLess<LatticeWeightTpl<FloatType> > {
 public:
  using Weight = LatticeWeightTpl<FloatType>;

  NaturalLess() {}

  // NaturalLess is a negative order (opposite to the normal ordering): its
  // "<" corresponds to ">" in the normal order, i.e. to Compare() returning 1.
  bool operator()(const Weight &w1, const Weight &w2) const {
    const int order = Compare(w1, w2);
    return order == 1;
  }
};
template <>
class NaturalLess<LatticeWeightTpl<float> > {
 public:
  using Weight = LatticeWeightTpl<float>;

  NaturalLess() {}

  // NaturalLess is a negative order (opposite to the normal ordering): its
  // "<" corresponds to ">" in the normal order, i.e. to Compare() returning 1.
  bool operator()(const Weight &w1, const Weight &w2) const {
    const int order = Compare(w1, w2);
    return order == 1;
  }
};
template <>
class NaturalLess<LatticeWeightTpl<double> > {
 public:
  using Weight = LatticeWeightTpl<double>;

  NaturalLess() {}

  // NaturalLess is a negative order (opposite to the normal ordering): its
  // "<" corresponds to ">" in the normal order, i.e. to Compare() returning 1.
  bool operator()(const Weight &w1, const Weight &w2) const {
    const int order = Compare(w1, w2);
    return order == 1;
  }
};

template <class FloatType>
inline LatticeWeightTpl<FloatType> Times(
    const LatticeWeightTpl<FloatType> &w1,
    const LatticeWeightTpl<FloatType> &w2) {
  // Semiring multiplication adds the costs component by component.
  typedef LatticeWeightTpl<FloatType> Weight;
  return Weight(w1.Value1() + w2.Value1(), w1.Value2() + w2.Value2());
}

// Divides w1 by w2.  The semiring is commutative, so it does not matter
// whether the division is on the left, on the right or "any".
template <class FloatType>
inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType> &w1,
                                          const LatticeWeightTpl<FloatType> &w2,
                                          DivideType typ = DIVIDE_ANY) {
  typedef FloatType T;
  typedef LatticeWeightTpl<T> Weight;
  const T pos_inf = std::numeric_limits<T>::infinity();
  const T neg_inf = -std::numeric_limits<T>::infinity();
  // Division subtracts the costs component by component.
  T a = w1.Value1() - w2.Value1();
  T b = w1.Value2() - w2.Value2();
  const bool has_nan = (a != a || b != b);
  if (has_nan || a == neg_inf || b == neg_inf) {
    KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
               << "[dividing by zero?]  Returning zero";
    return Weight::Zero();
  }
  // A result with only one infinite component would not be a valid number,
  // so any +infinity collapses to Zero().
  if (a == pos_inf || b == pos_inf) return Weight::Zero();
  return Weight(a, b);
}

template <class FloatType>
inline bool ApproxEqual(const LatticeWeightTpl<FloatType> &w1,
                        const LatticeWeightTpl<FloatType> &w2,
                        float delta = kDelta) {
  // Exact component-wise equality is checked first; this handles Zero().
  const bool same_first = (w1.Value1() == w2.Value1());
  if (same_first && w1.Value2() == w2.Value2()) return true;
  // Otherwise only the total costs are compared, to within delta.
  return fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
         delta;
}

template <class FloatType>
inline std::ostream &operator<<(std::ostream &strm,
                                const LatticeWeightTpl<FloatType> &w) {
  LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
  CHECK(FLAGS_fst_weight_separator.size() == 1);  // NOLINT
  strm << FLAGS_fst_weight_separator[0];  // comma by default;
  // may or may not be settable from Kaldi programs.
  LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
  return strm;
}

template <class FloatType>
inline std::istream &operator>>(std::istream &strm,
                                LatticeWeightTpl<FloatType> &w1) {
  CHECK(FLAGS_fst_weight_separator.size() == 1);  // NOLINT
  const char separator = FLAGS_fst_weight_separator[0];  // defaults to ','
  // The actual parsing is done by the weight's ReadNoParen member.
  return w1.ReadNoParen(strm, separator);
}

// CompactLattice will be an acceptor (accepting the words/output-symbols),
// with the weights and input-symbol-seqs on the arcs.
// There must be a total order on W.  We assume for the sake of efficiency
// that there is a function
// Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
// zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
// w2).

/// Weight that pairs a base weight of type WeightType with a sequence of
/// IntType symbols (the "string").
///
/// Zero() and One() both carry an empty string; Member() requires the string
/// to be empty whenever the base weight equals WeightType::Zero(), so that the
/// semiring has a single zero.
template <class WeightType, class IntType>
class CompactLatticeWeightTpl {
 public:
  typedef WeightType W;

  typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;

  // Plus is like LexicographicWeight on the pair (weight_, string_), but where
  // we use standard lexicographic order on string_ [this is not the same as
  // NaturalLess on the StringWeight equivalent, which does not define a
  // total order].
  // Times, Divide obvious... (support both left & right division..)
  // CommonDivisor would need to be coded separately.

  /// Default-constructs the base weight and leaves the string empty.
  CompactLatticeWeightTpl() {}

  /// Constructs the weight (w, s), copying both arguments.
  CompactLatticeWeightTpl(const WeightType &w, const std::vector<IntType> &s)
      : weight_(w), string_(s) {}

  // Copy and move operations are all the member-wise defaults.  They are
  // spelled out because declaring only a copy assignment operator (as this
  // class used to) suppresses the implicit move operations, which forces a
  // deep copy of string_ wherever a move would have been enough.
  CompactLatticeWeightTpl(const CompactLatticeWeightTpl &) = default;
  CompactLatticeWeightTpl(CompactLatticeWeightTpl &&) = default;
  CompactLatticeWeightTpl &operator=(const CompactLatticeWeightTpl &) = default;
  CompactLatticeWeightTpl &operator=(CompactLatticeWeightTpl &&) = default;

  /// Returns the base weight.
  const W &Weight() const { return weight_; }

  /// Returns the symbol sequence.
  const std::vector<IntType> &String() const { return string_; }

  /// Replaces the base weight.
  void SetWeight(const W &w) { weight_ = w; }

  /// Replaces the symbol sequence.
  void SetString(const std::vector<IntType> &s) { string_ = s; }

  /// The semiring zero: WeightType::Zero() with an empty string.
  static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
    return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
                                                        std::vector<IntType>());
  }

  /// The semiring one: WeightType::One() with an empty string.
  static const CompactLatticeWeightTpl<WeightType, IntType> One() {
    return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
                                                        std::vector<IntType>());
  }

  /// Returns sizeof(IntType) as a one-character string; this is only a
  /// decimal digit if sizeof(IntType) is less than 10.
  inline static std::string GetIntSizeString() {
    char buf[2];
    buf[0] = '0' + sizeof(IntType);
    buf[1] = '\0';
    return buf;
  }

  /// Type name: "compact" + WeightType::Type() + GetIntSizeString().
  static const std::string &Type() {
    static const std::string type =
        "compact" + WeightType::Type() + GetIntSizeString();
    return type;
  }

  /// A non-member weight: WeightType::NoWeight() with an empty string.
  static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
    return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
                                                        std::vector<IntType>());
  }

  /// Returns a weight with the same base weight (weight_ itself is not
  /// reversed) and the symbol sequence in reverse order.
  CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
    const std::vector<IntType> reversed(string_.rbegin(), string_.rend());
    return CompactLatticeWeightTpl<WeightType, IntType>(weight_, reversed);
  }

  /// Returns true if this is a valid element of the semiring.
  bool Member() const {
    // a semiring has only one zero, this is the important property
    // we're trying to maintain here.  So force string_ to be empty if
    // w_ == zero.
    if (!weight_.Member()) return false;
    if (weight_ == WeightType::Zero())
      return string_.empty();
    else
      return true;
  }

  /// Quantizes the base weight with the given delta; the string is unchanged.
  CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
    return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
  }

  /// Semiring property bits advertised to OpenFst (kCommutative is not set).
  static constexpr uint64 Properties() {
    return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
  }

  // This is used in OpenFst for binary I/O.  This is OpenFst-style,
  // not Kaldi-style, I/O.  Reads the base weight, then an int32 string length,
  // then that many symbols.  A negative length sets badbit on strm.
  std::istream &Read(std::istream &strm) {
    weight_.Read(strm);
    if (strm.fail()) {
      return strm;
    }
    int32 sz;
    ReadType(strm, &sz);
    if (strm.fail()) {
      return strm;
    }
    if (sz < 0) {
      KALDI_WARN << "Negative string size!  Read failure";
      strm.clear(std::ios::badbit);
      return strm;
    }
    string_.resize(sz);
    for (int32 i = 0; i < sz; i++) {
      ReadType(strm, &(string_[i]));
    }
    return strm;
  }

  // This is used in OpenFst for binary I/O.  This is OpenFst-style,
  // not Kaldi-style, I/O.  Writes the base weight, then the string length as
  // an int32, then the symbols.
  std::ostream &Write(std::ostream &strm) const {
    weight_.Write(strm);
    if (strm.fail()) {
      return strm;
    }
    int32 sz = static_cast<int32>(string_.size());
    WriteType(strm, sz);
    for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
    return strm;
  }

  /// Hash value: the base weight's hash plus a position-dependent
  /// combination of the symbols.
  size_t Hash() const {
    size_t ans = weight_.Hash();
    // any weird numbers here are largish primes
    size_t sz = string_.size(), mult = 6967;
    for (size_t i = 0; i < sz; i++) {
      ans += string_[i] * mult;
      mult *= 7499;
    }
    return ans;
  }

 private:
  W weight_;
  std::vector<IntType> string_;
};

template <class WeightType, class IntType>
inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
                       const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
  // Equal only if both the base weights and the strings are equal.
  if (!(w1.Weight() == w2.Weight())) return false;
  return w1.String() == w2.String();
}

template <class WeightType, class IntType>
inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
                       const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
  // Different if either the base weights or the strings differ.
  if (w1.Weight() != w2.Weight()) return true;
  return w1.String() != w2.String();
}

template <class WeightType, class IntType>
inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
                        const CompactLatticeWeightTpl<WeightType, IntType> &w2,
                        float delta = kDelta) {
  // The base weights only need to agree to within delta, but the strings
  // must match exactly.
  if (!ApproxEqual(w1.Weight(), w2.Weight(), delta)) return false;
  return w1.String() == w2.String();
}

// Compare is not part of the standard for weight types, but used internally for
// efficiency.  The comparison here first compares the weight; if this is the
// same, it compares the string.  The comparison on strings is: first compare
// the length, if this is the same, use lexicographical order.  We can't just
// use the lexicographical order because this would destroy the distributive
// property of multiplication over addition, taking into account that addition
// uses Compare.  The string element of "Compare" isn't super-important in
// practical terms; it's only needed to ensure that Plus always give consistent
// answers and is symmetric.  It's essentially for tie-breaking, but we need to
// make sure all the semiring axioms are satisfied otherwise OpenFst might
// break.

template <class WeightType, class IntType>
inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
                   const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
  int c1 = Compare(w1.Weight(), w2.Weight());
  if (c1 != 0) return c1;
  int l1 = w1.String().size(), l2 = w2.String().size();
  // Use opposite order on the string lengths, so that if the costs are the
  // same, the shorter string wins.
  if (l1 > l2)
    return -1;
  else if (l1 < l2)
    return 1;
  for (int i = 0; i < l1; i++) {
    if (w1.String()[i] < w2.String()[i])
      return -1;
    else if (w1.String()[i] > w2.String()[i])
      return 1;
  }
  return 0;
}

// For efficiency, override the NaturalLess template class.
template <class FloatType, class IntType>
class NaturalLess<
    CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
 public:
  using Weight = CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType>;

  NaturalLess() {}

  // NaturalLess is a negative order (opposite to the normal ordering): its
  // "<" corresponds to ">" in the normal order, i.e. to Compare() returning 1.
  bool operator()(const Weight &w1, const Weight &w2) const {
    const int order = Compare(w1, w2);
    return order == 1;
  }
};
template <>
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
 public:
  using Weight = CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32>;

  NaturalLess() {}

  // NaturalLess is a negative order (opposite to the normal ordering): its
  // "<" corresponds to ">" in the normal order, i.e. to Compare() returning 1.
  bool operator()(const Weight &w1, const Weight &w2) const {
    const int order = Compare(w1, w2);
    return order == 1;
  }
};
template <>
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
 public:
  using Weight = CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32>;

  NaturalLess() {}

  // NaturalLess is a negative order (opposite to the normal ordering): its
  // "<" corresponds to ">" in the normal order, i.e. to Compare() returning 1.
  bool operator()(const Weight &w1, const Weight &w2) const {
    const int order = Compare(w1, w2);
    return order == 1;
  }
};

// Make sure Compare is defined for TropicalWeight, so everything works
// if we substitute LatticeWeight for TropicalWeight.
inline int Compare(const TropicalWeight &w1, const TropicalWeight &w2) {
  float f1 = w1.Value(), f2 = w2.Value();
  if (f1 == f2)
    return 0;
  else if (f1 > f2)
    return -1;
  else
    return 1;
}

template <class WeightType, class IntType>
inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
    const CompactLatticeWeightTpl<WeightType, IntType> &w1,
    const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
  // Semiring addition keeps whichever weight Compare() ranks higher; w1 is
  // returned when the two compare equal.
  if (Compare(w1, w2) >= 0) return w1;
  return w2;
}

template <class WeightType, class IntType>
inline CompactLatticeWeightTpl<WeightType, IntType> Times(
    const CompactLatticeWeightTpl<WeightType, IntType> &w1,
    const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
  typedef CompactLatticeWeightTpl<WeightType, IntType> Compact;
  WeightType product = Times(w1.Weight(), w2.Weight());
  // Special case: a zero base weight always yields Zero() (with an empty
  // string), which keeps the semiring zero unique.
  if (product == WeightType::Zero()) return Compact::Zero();
  // Otherwise the strings are concatenated, w1's first.
  const std::vector<IntType> &head = w1.String();
  const std::vector<IntType> &tail = w2.String();
  std::vector<IntType> joined;
  joined.reserve(head.size() + tail.size());
  joined.insert(joined.end(), head.begin(), head.end());
  joined.insert(joined.end(), tail.begin(), tail.end());
  return Compact(product, joined);
}

// Divides w1 by w2.  The base weights are divided with Divide(); for the
// strings, w2's string is stripped from the front of w1's string
// (DIVIDE_LEFT) or from its back (DIVIDE_RIGHT).  DIVIDE_ANY is not supported.
// Every failure case (division by zero, w2's string longer than w1's, w2's
// string not matching the relevant part of w1's, DIVIDE_ANY) is reported
// through KALDI_ERR.  Dividing Zero() by a non-zero weight returns Zero().
template <class WeightType, class IntType>
inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
    const CompactLatticeWeightTpl<WeightType, IntType> &w1,
    const CompactLatticeWeightTpl<WeightType, IntType> &w2,
    DivideType div = DIVIDE_ANY) {
  if (w1.Weight() == WeightType::Zero()) {
    if (w2.Weight() != WeightType::Zero()) {
      return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
    } else {
      KALDI_ERR << "Division by zero [0/0]";
    }
  } else if (w2.Weight() == WeightType::Zero()) {
    KALDI_ERR << "Error: division by zero";
  }
  WeightType w = Divide(w1.Weight(), w2.Weight());

  // Bind by reference: the strings are only read here, so copying them (as
  // was done previously) is wasted work.
  const std::vector<IntType> &v1 = w1.String(), &v2 = w2.String();
  if (v2.size() > v1.size()) {
    KALDI_ERR << "Cannot divide, length mismatch";
  }
  typename std::vector<IntType>::const_iterator v1b = v1.begin(),
                                                v1e = v1.end(),
                                                v2b = v2.begin(),
                                                v2e = v2.end();
  if (div == DIVIDE_LEFT) {
    if (!std::equal(v2b, v2e,
                    v1b)) {  // v2 must be identical to first part of v1.
      KALDI_ERR << "Cannot divide, data mismatch";
    }
    return CompactLatticeWeightTpl<WeightType, IntType>(
        w, std::vector<IntType>(v1b + (v2e - v2b),
                                v1e));  // return last part of v1.
  } else if (div == DIVIDE_RIGHT) {
    if (!std::equal(
            v2b, v2e,
            v1e - (v2e - v2b))) {  // v2 must be identical to last part of v1.
      KALDI_ERR << "Cannot divide, data mismatch";
    }
    return CompactLatticeWeightTpl<WeightType, IntType>(
        w, std::vector<IntType>(
               v1b, v1e - (v2e - v2b)));  // return first part of v1.

  } else {
    KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
  }
  return CompactLatticeWeightTpl<WeightType,
                                 IntType>::Zero();  // keep compiler happy.
}

template <class WeightType, class IntType>
inline std::ostream &operator<<(
    std::ostream &strm, const CompactLatticeWeightTpl<WeightType, IntType> &w) {
  // Text form: the base weight, the weight separator, then the symbols joined
  // by kStringSeparator.
  strm << w.Weight();
  CHECK(FLAGS_fst_weight_separator.size() == 1);  // NOLINT
  strm << FLAGS_fst_weight_separator[0];  // comma by default.
  const std::vector<IntType> &symbols = w.String();
  const size_t num_symbols = symbols.size();
  for (size_t idx = 0; idx < num_symbols; ++idx) {
    // kStringSeparator is '_'; it is defined in string-weight.h in the
    // OpenFst code.
    if (idx > 0) strm << kStringSeparator;
    strm << symbols[idx];
  }
  return strm;
}

template <class WeightType, class IntType>
inline std::istream &operator>>(
    std::istream &strm, CompactLatticeWeightTpl<WeightType, IntType> &w) {
  // The whole weight is read as one whitespace-delimited token.
  std::string token;
  strm >> token;
  if (strm.fail()) return strm;
  CHECK(FLAGS_fst_weight_separator.size() == 1);  // NOLINT
  // The last separator (normally ",") splits the base weight from the string.
  const size_t split = token.find_last_of(FLAGS_fst_weight_separator);
  if (split == std::string::npos) {
    strm.clear(std::ios::badbit);
    return strm;
  }
  const std::string weight_text = token.substr(0, split);
  const std::string symbols_text = token.substr(split + 1);

  // Parse the base weight; its text must be consumed completely.
  std::istringstream weight_strm(weight_text);
  WeightType weight;
  weight_strm >> weight;
  w.SetWeight(weight);
  const bool weight_ok = !weight_strm.fail() && weight_strm.eof();
  if (!weight_ok) {
    strm.clear(std::ios::badbit);
    return strm;
  }

  // Parse the string part: integers, each optionally preceded by
  // kStringSeparator ('_').
  std::vector<IntType> symbols;
  const char *cur = symbols_text.c_str();
  while (*cur != '\0') {
    if (*cur == kStringSeparator) ++cur;
    char *end;
    int64_t value = strtol(cur, &end, 10);
    // Fail if no digits were consumed or the value does not fit in IntType.
    const bool no_digits = (end == cur);
    if (no_digits ||
        static_cast<int64_t>(static_cast<IntType>(value)) != value) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    symbols.push_back(static_cast<IntType>(value));
    cur = end;
  }
  w.SetString(symbols);
  return strm;
}

// Common-divisor functor for CompactLatticeWeightTpl: the result carries
// Plus() of the two base weights and the longest common prefix of the two
// strings.
template <class BaseWeightType, class IntType>
class CompactLatticeWeightCommonDivisorTpl {
 public:
  typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;

  Weight operator()(const Weight &w1, const Weight &w2) const {
    const std::vector<IntType> &str1 = w1.String();
    const std::vector<IntType> &str2 = w2.String();
    // Count how many leading symbols the two strings share.
    size_t common = 0;
    while (common < str1.size() && common < str2.size() &&
           str1[common] == str2[common]) {
      ++common;
    }
    return Weight(Plus(w1.Weight(), w2.Weight()),
                  std::vector<IntType>(str1.begin(), str1.begin() + common));
  }
};

/** Scales the pair (a, b) of floating-point weights inside a
    CompactLatticeWeight by premultiplying it (viewed as a vector)
    by a 2x2 matrix "scale".
    Assumes there is a ScaleTupleWeight function that applies to "Weight";
    this currently only works if Weight equals LatticeWeightTpl<FloatType>
    for some FloatType.
*/
template <class Weight, class IntType, class ScaleFloatType>
inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
    const CompactLatticeWeightTpl<Weight, IntType> &w,
    const std::vector<std::vector<ScaleFloatType> > &scale) {
  // Only the base weight is scaled; the string is carried over unchanged.
  const Weight scaled(ScaleTupleWeight(w.Weight(), scale));
  return CompactLatticeWeightTpl<Weight, IntType>(scaled, w.String());
}

/** Define some ConvertLatticeWeight functions that are used in various lattice
    conversions... make them all templates, some with no arguments, since some
    must be templates.*/
// Converts between lattice weights of (possibly) different float types by
// copying the two components one at a time.
template <class Float1, class Float2>
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
                                 LatticeWeightTpl<Float2> *w_out) {
  const Float1 first = w_in.Value1();
  w_out->SetValue1(first);
  const Float1 second = w_in.Value2();
  w_out->SetValue2(second);
}

template <class Float1, class Float2, class Int>
inline void ConvertLatticeWeight(
    const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int> &w_in,
    CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int> *w_out) {
  LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
                                   w_in.Weight().Value2());
  w_out->SetWeight(weight2);
  w_out->SetString(w_in.String());
}

// to convert from Lattice to standard FST
template <class Float1, class Float2>
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
                                 TropicalWeightTpl<Float2> *w_out) {
  TropicalWeightTpl<Float2> w1(w_in.Value1());
  TropicalWeightTpl<Float2> w2(w_in.Value2());
  *w_out = Times(w1, w2);
}

// Returns the total cost of a lattice weight: the sum of its two components,
// accumulated in double precision.
template <class Float>
inline double ConvertToCost(const LatticeWeightTpl<Float> &w) {
  const double first = static_cast<double>(w.Value1());
  const double second = static_cast<double>(w.Value2());
  return first + second;
}

// Returns the total cost of a compact lattice weight: the sum of the two
// components of its base weight, accumulated in double precision.  The string
// is ignored.
template <class Float, class Int>
inline double ConvertToCost(
    const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int> &w) {
  const double first = static_cast<double>(w.Weight().Value1());
  const double second = static_cast<double>(w.Weight().Value2());
  return first + second;
}

// Returns the cost of a tropical weight, i.e. its value converted to double.
template <class Float>
inline double ConvertToCost(const TropicalWeightTpl<Float> &w) {
  const double cost = w.Value();
  return cost;
}

}  // namespace fst

#endif  // KALDI_FSTEXT_LATTICE_WEIGHT_H_
