// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint.h: fixed-point arithmetic, with basic operations and
// a few math functions such as tanh.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "../internal/detect_platform.h"

namespace gemmlowp {

// Part 1: Low-level integer-arithmetic primitives.
// The implementations here are generic implementations valid for
// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
// (e.g. NEON int32x4_t) may be supported by providing
// specializations for them in separate files.
//
// The purpose of these primitives is two-fold:
//  - They will be used to implement higher-level fixed-point
//    abstractions, namely the FixedPoint class and its arithmetic
//    operators.
//  - They will be directly used to implement some more involved
//    fixed-point computations, e.g. the fixed-point implementation
//    of math functions such as tanh.

// Compile-time description of a raw integer type used as fixed-point
// storage: the scalar type each lane is made of (ScalarRawType) and the
// number of lanes (kLanes). The primary template is deliberately empty, so
// only raw types that have a specialization can be used.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Scalar 32-bit storage: one lane of std::int32_t.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static constexpr int kLanes = 1;
};

// Scalar 16-bit storage: one lane of std::int16_t.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static constexpr int kLanes = 1;
};

// Broadcasts the scalar `x` to every lane of tRawType. This generic version
// just converts `x` to tRawType, which is sufficient for the single-lane
// scalar raw types declared in this file.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  tRawType broadcast = x;
  return broadcast;
}

// Returns the bit-wise AND of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  tIntegerType conjunction = a & b;
  return conjunction;
}

// Returns the bit-wise OR of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  tIntegerType disjunction = a | b;
  return disjunction;
}

// Returns the bit-wise exclusive OR of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  tIntegerType difference_bits = a ^ b;
  return difference_bits;
}

// Returns `a` with every bit inverted.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  tIntegerType inverted = ~a;
  return inverted;
}

// Plain (non-saturating) integer addition. The caller must make sure the
// sum does not overflow; signed overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  tIntegerType sum = a + b;
  return sum;
}

// Plain (non-saturating) integer multiplication. The caller must make sure
// the product does not overflow; signed overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  tIntegerType product = a * b;
  return product;
}

// Plain (non-saturating) integer subtraction. The caller must make sure the
// difference does not overflow; signed overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  tIntegerType difference = a - b;
  return difference;
}

// Plain (non-saturating) integer negation. The caller must make sure the
// result is representable; signed overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  tIntegerType negated = -a;
  return negated;
}

// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type.
// `offset` is expected to lie in [0, 31]: within that range the widened
// product below always fits in 64 bits.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  // The power of two must be formed in 64 bits. Written as `1 << offset` it
  // would be a 32-bit int, which cannot hold 2^31: the factor would come out
  // negative and the saturated result would get the wrong sign.
  const std::int64_t factor = std::int64_t{1} << offset;
  const std::int64_t wide_shifted = wide_a * factor;
  const auto lowest = std::numeric_limits<tIntegerType>::min();
  const auto highest = std::numeric_limits<tIntegerType>::max();
  if (wide_shifted < lowest) {
    return lowest;
  }
  if (wide_shifted > highest) {
    return highest;
  }
  return static_cast<tIntegerType>(wide_shifted);
}

// Arithmetic right-shift by `offset` bits, without rounding. For negative
// `a` this relies on the compiler's (implementation-defined but, in
// practice, consistent) sign-propagating behavior of `>>`.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  tIntegerType shifted = a >> offset;
  return shifted;
}

// Bit-wise select: wherever a bit of `if_mask` is set the result takes the
// corresponding bit of `then_val`, elsewhere it takes the bit of `else_val`.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType bits_from_then = BitAnd(if_mask, then_val);
  const tIntegerType bits_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two partial results never share a set bit, so XOR merges them.
  return BitXor(bits_from_then, bits_from_else);
}

// Returns an all-ones mask when the input scalar is non-zero and an
// all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static constexpr tIntegerType kNoBits = 0;
  if (a) {
    return BitNot(kNoBits);
  }
  return kNoBits;
}

// Returns an all-ones mask when the input scalar is zero and an all-zeros
// mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  const bool is_zero = !a;
  return MaskIfNonZero<tIntegerType>(is_zero);
}

// Returns an all-ones mask when a == b and an all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a == b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Returns an all-ones mask when a != b and an all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a != b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Returns an all-ones mask when a > b and an all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  const bool holds = (a > b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Returns an all-ones mask when a >= b and an all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a >= b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Returns an all-ones mask when a < b and an all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
  const bool holds = (a < b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Returns an all-ones mask when a <= b and an all-zeros mask otherwise.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
  const bool holds = (a <= b);
  return MaskIfNonZero<tIntegerType>(holds);
}

// Reports whether every input scalar is non-zero. For the single-scalar
// case handled here that is just the truth value of `a`.
// Callers are expected to pass mask values (all bits set or none set);
// behavior for other inputs is currently left undefined by contract.
template <typename tIntegerType>
bool All(tIntegerType a) {
  const bool every_lane_set = a;
  return every_lane_set;
}

// Reports whether at least one input scalar is non-zero. For the
// single-scalar case handled here that is just the truth value of `a`.
// Callers are expected to pass mask values (all bits set or none set);
// behavior for other inputs is currently left undefined by contract.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  const bool some_lane_set = a;
  return some_lane_set;
}

// Returns (a+b)/2 rounded to the nearest integer, computed in a wider type
// so that the intermediate sum cannot overflow. The scalar specializations
// below round exact halves away from zero.
// Intended as the counterpart of VRHADD in the ARM NEON instruction set.
//
// The primary template only exists to produce a compile error for raw types
// that have no specialization.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// 32-bit version: the sum is formed in 64 bits.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t total =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Nudge by one unit in the direction of the sign before the truncating
  // division, so that odd sums round away from zero.
  const std::int64_t nudge = (total < 0) ? -1 : 1;
  return static_cast<std::int32_t>((total + nudge) / 2);
}

// 16-bit version: the sum is formed in 32 bits.
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t total =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t nudge = (total < 0) ? -1 : 1;
  return static_cast<std::int16_t>((total + nudge) / 2);
}

// Returns a+b clamped to the representable range of IntegerType instead of
// overflowing. The primary template only exists to produce a compile error
// for raw types that have no specialization; scalar specializations are
// provided below for int16 and int8.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// 16-bit version: the sum is formed in 32 bits, then clamped to
// [INT16_MIN, INT16_MAX].
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t lowest = std::numeric_limits<std::int16_t>::min();
  const std::int32_t highest = std::numeric_limits<std::int16_t>::max();
  const std::int32_t sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  return static_cast<std::int16_t>(std::min(highest, std::max(lowest, sum)));
}

// 8-bit version: the sum is formed in 16 bits (it always fits), then clamped
// to [INT8_MIN, INT8_MAX].
template <>
inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
  const std::int16_t lowest = std::numeric_limits<std::int8_t>::min();
  const std::int16_t highest = std::numeric_limits<std::int8_t>::max();
  const std::int16_t sum = static_cast<std::int16_t>(a + b);
  return static_cast<std::int8_t>(std::min(highest, std::max(lowest, sum)));
}

// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
template <typename IntegerType, bool Is16Bit>
struct AddSaturatingIf16BitImpl {
  static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
};
template <typename IntegerType>
struct AddSaturatingIf16BitImpl<IntegerType, true> {
  static IntegerType Run(IntegerType a, IntegerType b) {
    return SaturatingAdd(a, b);
  }
};
template <typename IntegerType>
IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
  using ScalarType =
      typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
                                                                             b);
}

// Fixed-point multiplication of two values that are both interpreted as
// numbers in the half-open interval [-1, 1). The result is rounded to the
// nearest representable value, and the single overflowing case, -1 * -1, is
// saturated to the maximum value (because +1 is not in [-1, 1)).
//
// [Worked example for std::int32_t.]
//
// IntegerType is assumed to be signed, and it alone determines the mapping
// to [-1, 1). For std::int32_t that mapping is
//   real_value = integer_value / 2^31.
// Ignoring rounding and saturation, the function therefore evaluates
//   ((a / 2^31) * (b / 2^31)) * 2^31  =  (a * b) / 2^31.
//
// Why "doubling": a plain "multiply-high" keeps the top half of the full
// product, i.e. (a * b) / 2^32. Here we want twice that, since
//   1/2^31 = 2 * 1/2^32,
// which makes use of all 32 bits of the int32 destination.
//
// [End of the int32 example.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
//
// The primary template only exists to produce a compile error for raw types
// that have no specialization.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  static_cast<void>(b);
  return a;
}

// 32-bit version. Implements the same computation as the ARMv7 NEON
// VQRDMULH instruction, using a 64-bit intermediate product.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  typedef std::numeric_limits<std::int32_t> Limits;
  // Only min * min (that is, -1 * -1) cannot be represented.
  const bool saturate = (a == Limits::min()) && (b == Limits::min());
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Rounding nudge of (almost) half a unit, signed like the product because
  // the division below truncates toward zero.
  const std::int32_t nudge = (product < 0) ? (1 - (1 << 30)) : (1 << 30);
  const std::int32_t doubled_high_half =
      static_cast<std::int32_t>((product + nudge) / (1ll << 31));
  if (saturate) {
    return Limits::max();
  }
  return doubled_high_half;
}

// 16-bit version: same scheme with a 32-bit intermediate product.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  typedef std::numeric_limits<std::int16_t> Limits;
  const bool saturate = (a == Limits::min()) && (b == Limits::min());
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  const std::int16_t nudge = (product < 0) ? (1 - (1 << 14)) : (1 << 14);
  const std::int16_t doubled_high_half =
      static_cast<std::int16_t>((product + nudge) / (1 << 15));
  if (saturate) {
    return Limits::max();
  }
  return doubled_high_half;
}

// Divides `x` by 2^exponent and rounds the quotient to the nearest integer
// (a rounding arithmetic right shift). Exact halves are rounded away from
// zero. `exponent` must be in [0, 31]; this is checked with assert.
template <typename IntegerType, typename ExponentType>
inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
  assert(exponent >= 0);
  assert(exponent <= 31);
  const IntegerType kZero = Dup<IntegerType>(0);
  const IntegerType kOne = Dup<IntegerType>(1);
  // The low `exponent` bits of x are what the shift throws away.
  const IntegerType low_bits_mask = Dup<IntegerType>((1ll << exponent) - 1);
  const IntegerType discarded_bits = BitAnd(x, low_bits_mask);
  // Rounding threshold: just under half the divisor for x >= 0, and one more
  // than that for negative x, so that ties move away from zero either way.
  const IntegerType negative_bias = BitAnd(MaskIfLessThan(x, kZero), kOne);
  const IntegerType half_threshold =
      Add(ShiftRight(low_bits_mask, 1), negative_bias);
  // Add one to the truncated quotient when the discarded part is large
  // enough.
  const IntegerType round_up =
      BitAnd(MaskIfGreaterThan(discarded_bits, half_threshold), kOne);
  return Add(ShiftRight(x, exponent), round_up);
}

// Returns the product of a run-time integer value by a compile-time power
// of two, with either a positive exponent (equivalent to an arithmetic
// left shift, saturating) or a negative exponent (equivalent to an arithmetic
// right shift, rounding to nearest).
template <int Exponent, typename IntegerType,
          int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
struct ImplSaturatingRoundingMultiplyByPOT {};

template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
  static IntegerType eval(IntegerType x) { return x; }
};

template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
  static IntegerType eval(IntegerType x) {
    using ScalarIntegerType =
        typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    const IntegerType min =
        Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
    const IntegerType max =
        Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
    const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

    const std::int32_t threshold =
        ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
    const IntegerType positive_mask =
        MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
    const IntegerType negative_mask =
        MaskIfLessThan(x, Dup<IntegerType>(-threshold));

    IntegerType result = ShiftLeft(x, Exponent);
    result = SelectUsingMask(positive_mask, max, result);
    result = SelectUsingMask(negative_mask, min, result);
    return result;
  }
};

template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
  static IntegerType eval(IntegerType x) {
    return RoundingDivideByPOT<IntegerType>(x, -Exponent);
  }
};

template <int Exponent, typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
  return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
}

// Part 2: the FixedPoint class.

// A FixedPoint object represents a fixed-point value stored in the underlying
// integer type tRawType, if tRawType is a plain scalar integer type.
// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
// case a FixedPoint object represents a corresponding SIMD vector of fixed
// point values.
//
// tIntegerBits describes the range of the fixed-point format: if
// tIntegerBits == m then the range of representable values is the half-open
// interval [-2^m; 2^m) where the open boundary on the right side means that
// 2^m is not representable (how close the maximum representable value is to
// it, depends on bit-depth of tRawType).
//
// In "Q format notation",
//   https://en.wikipedia.org/wiki/Q_(number_format)
// we are describing the format
//   Qm.n
// where
//   m = tIntegerBits
// and
//   n = NumberOfBits(tRawType) - (m + 1)
// Note that the (m + 1) in the above line is because we adopt the convention
// that we count the integer bits exclusively of the sign bit; so (m + 1) is
// the total number of integer bits inclusive of the sign bit.
//
// Accordingly, the number of integral representable values in our range
//   [-2^m ; 2^m)
// is equal to 2^(m+1).
// Fixed-point value (or SIMD vector of values) stored in a raw integer of
// type tRawType, with tIntegerBits integer bits (sign bit not counted) and
// the remaining bits used as fractional bits.
template <typename tRawType, int tIntegerBits>
class FixedPoint {
 public:
  typedef tRawType RawType;

  typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
  typedef typename RawTypeTraits::ScalarRawType ScalarRawType;

  // Bit layout of one scalar: 1 sign bit + kIntegerBits + kFractionalBits.
  static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
  static constexpr int kIntegerBits = tIntegerBits;
  static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
                "bad IntegerBits");

  // The same fixed-point format, on the scalar type.
  typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;

  // Smallest raw scalar value.
  static const ScalarRawType ScalarRawMin() {
    return std::numeric_limits<ScalarRawType>::min();
  }

  // Largest raw scalar value.
  static const ScalarRawType ScalarRawMax() {
    return std::numeric_limits<ScalarRawType>::max();
  }

  // Smallest raw value, duplicated across all lanes of RawType.
  static RawType RawMin() { return Dup<RawType>(ScalarRawMin()); }

  // Largest raw value, duplicated across all lanes of RawType.
  static RawType RawMax() { return Dup<RawType>(ScalarRawMax()); }

  // Wraps an existing raw value without any conversion.
  static FixedPoint FromRaw(RawType x) {
    FixedPoint retval;
    retval.raw() = x;
    return retval;
  }

  // Wraps a raw scalar value, duplicated across all lanes.
  static FixedPoint FromScalarRaw(ScalarRawType x) {
    FixedPoint retval;
    retval.raw() = Dup<RawType>(x);
    return retval;
  }

  // Duplicates a scalar fixed-point value of the same format across all
  // lanes.
  static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
    return FromScalarRaw(x.raw());
  }

  // Returns the constant 2^Exponent. It is encoded as a single set bit at
  // position kOffset, which must be a valid non-sign bit of the scalar type;
  // otherwise the constant is not representable and compilation fails.
  template <int Exponent>
  static FixedPoint ConstantPOT() {
    static constexpr int kOffset = kFractionalBits + Exponent;
    static_assert(
        kOffset >= 0 && kOffset < kTotalBits - 1,
        "Constant not exactly representable in this fixed-point format");
    return FromScalarRaw(ScalarRawType(1) << kOffset);
  }

  // Returns the value 0.
  static FixedPoint Zero() { return FromScalarRaw(0); }

  // Returns the value 1, or the closest representable value below it when
  // the format has no integer bits (1 itself is then out of range).
  static FixedPoint One() {
    // The inner conditional keeps the shift count of the unused branch in
    // range when kIntegerBits == 0.
    return FromScalarRaw(
        kIntegerBits == 0
            ? ScalarRawMax()
            : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
  }

  // Converts a double to this format: scales by 2^kFractionalBits, rounds to
  // the nearest integer and clamps to the raw scalar range.
  static FixedPoint FromDouble(double x) {
    const double min_bound = static_cast<double>(ScalarRawMin());
    const double max_bound = static_cast<double>(ScalarRawMax());
    const double scaled =
        std::round(x * static_cast<double>(1ll << kFractionalBits));
    return FromScalarRaw(static_cast<ScalarRawType>(
        std::min(std::max(scaled, min_bound), max_bound)));
  }

  // Access to the underlying raw value.
  RawType raw() const { return i_; }
  RawType& raw() { return i_; }

 private:
  RawType i_;
};

// Part 3: implementation of arithmetic operators for the
// FixedPoint class, and a few related functions.

// Multiplies two FixedPoint values by applying
// SaturatingRoundingDoublingHighMul to their raw integers. The result type
// has as many integer bits as both operands together: the operand ranges
// [-2^IntegerBits, 2^IntegerBits) multiply, so the exponents add.
template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
    FixedPoint<tRawType, tIntegerBits_a> a,
    FixedPoint<tRawType, tIntegerBits_b> b) {
  typedef FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> ProductType;
  return ProductType::FromRaw(
      SaturatingRoundingDoublingHighMul(a.raw(), b.raw()));
}

// Exact multiplication by 2^tExponent: the raw bits are left untouched and
// only the number of integer bits in the type is changed.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
    FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tExponent + tIntegerBits> ResultType;
  return ResultType::FromRaw(a.raw());
}

// Multiplication by 2^tExponent that keeps the fixed-point format
// (IntegerBits) unchanged. Because the format stays the same the operation
// cannot be exact: it saturates or rounds the raw value as needed.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
    FixedPoint<tRawType, tIntegerBits> a) {
  const tRawType scaled_raw =
      SaturatingRoundingMultiplyByPOT<tExponent>(a.raw());
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(scaled_raw);
}

// Generic arithmetic operators.
//
// The two helper macros below generate function templates that lift a
// raw-integer primitive to FixedPoint: operands are unwrapped with raw(), the
// primitive is applied, and the result is re-wrapped with FromRaw() in the
// same fixed-point format as the operands. The macros are undefined again
// right after use.

// Defines `FuncName(FixedPoint) -> FixedPoint`, implemented by
// `ImplFuncName(raw)`.
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
  template <typename tRawType, int tIntegerBits>                               \
  FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
      FixedPoint<tRawType, tIntegerBits> a) {                                  \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
  }

// Defines `FuncName(FixedPoint, FixedPoint) -> FixedPoint`, implemented by
// `ImplFuncName(raw, raw)`. Both operands must have the same format.
#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
  template <typename tRawType, int tIntegerBits>            \
  FixedPoint<tRawType, tIntegerBits> FuncName(              \
      FixedPoint<tRawType, tIntegerBits> a,                 \
      FixedPoint<tRawType, tIntegerBits> b) {               \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
        ImplFuncName(a.raw(), b.raw()));                    \
  }

// Unary minus and bit-wise complement.
MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
// Non-saturating addition/subtraction and bit-wise binary operators.
MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
// Rounded average of two values.
MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)

#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC

// The two helper macros below generate FixedPoint overloads of the mask
// primitives. Unlike the arithmetic operators, these return the raw type
// (a bit mask), not a FixedPoint. Each overload forwards to the raw-integer
// function of the same name. The macros are undefined again right after use.

// Defines `FuncName(FixedPoint) -> tRawType`.
#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    return FuncName(a.raw());                               \
  }

// Defines `FuncName(FixedPoint, FixedPoint) -> tRawType`. Both operands must
// have the same format.
#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
                    FixedPoint<tRawType, tIntegerBits> b) { \
    return FuncName(a.raw(), b.raw());                      \
  }

// Tests against zero.
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
// Pairwise comparisons.
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)

#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW

// FixedPoint flavor of the bit-wise select: picks bits of `then_val` where
// the raw mask `if_mask` is set and bits of `else_val` elsewhere. Both values
// and the result share the same fixed-point format.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
    tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
    FixedPoint<tRawType, tIntegerBits> else_val) {
  const tRawType selected_raw =
      SelectUsingMask(if_mask, then_val.raw(), else_val.raw());
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(selected_raw);
}

// Equality of two FixedPoint values of the same format: true when all of
// their raw scalars compare equal.
template <typename tRawType, int tIntegerBits>
bool operator==(FixedPoint<tRawType, tIntegerBits> a,
                FixedPoint<tRawType, tIntegerBits> b) {
  const tRawType equal_mask = MaskIfEqual(a.raw(), b.raw());
  return All(equal_mask);
}

// Inequality of two FixedPoint values of the same format, defined as the
// negation of operator==.
template <typename tRawType, int tIntegerBits>
bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
                FixedPoint<tRawType, tIntegerBits> b) {
  const bool are_equal = (a == b);
  return !are_equal;
}

// Saturating addition of two FixedPoint values of the same format, performed
// by SaturatingAdd on the raw integers.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
    FixedPoint<tRawType, tIntegerBits> a,
    FixedPoint<tRawType, tIntegerBits> b) {
  const tRawType raw_sum = SaturatingAdd(a.raw(), b.raw());
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(raw_sum);
}

// Addition of two FixedPoint values of the same format, performed by
// AddSaturatingIf16Bit on the raw integers (saturating for 16-bit scalars,
// plain otherwise).
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
    FixedPoint<tRawType, tIntegerBits> a,
    FixedPoint<tRawType, tIntegerBits> b) {
  const tRawType raw_sum = AddSaturatingIf16Bit(a.raw(), b.raw());
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(raw_sum);
}

// Converts a scalar FixedPoint value to double by dividing its raw integer
// by 2^kFractionalBits. Rejected at compile time for SIMD raw types.
template <typename tRawType, int tIntegerBits>
double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
  static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
                "not applicable to SIMD types");
  constexpr int kFractionalBits =
      FixedPoint<tRawType, tIntegerBits>::kFractionalBits;
  const double scale = static_cast<double>(1ll << kFractionalBits);
  return x.raw() / scale;
}

// Converts a value from a format with tIntegerBitsSrc integer bits to one
// with tIntegerBitsDst integer bits. The raw integer is multiplied by
// 2^(tIntegerBitsSrc - tIntegerBitsDst), saturating or rounding as needed,
// so that the represented value is preserved as closely as possible.
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(
    FixedPoint<tRawType, tIntegerBitsSrc> x) {
  constexpr int kRawShift = tIntegerBitsSrc - tIntegerBitsDst;
  return FixedPoint<tRawType, tIntegerBitsDst>::FromRaw(
      SaturatingRoundingMultiplyByPOT<kRawShift>(x.raw()));
}

// CheckedFixedPointConstant makes it possible to write fixed-point constants
// next to the real number they stand for, without compiling floating-point
// arithmetic into production code, while still verifying that the two agree
// when asserts are enabled.
//
// The raw integer supplied for a constant is always an int32 that encodes a
// 32-bit fixed-point value, whatever the actual Scalar type is. Generic code
// can therefore use the same constants in the 32-bit and the 16-bit cases.
// In the 16-bit case the raw integer is rounding-shifted right by 16 bits
// internally.
//
// RescaleConstantInitializer performs that adaptation: it rounding-shifts
// `int32_value` right by (32 - bit width of the scalar type) and returns the
// result as FixedPointType's scalar raw type.
template <typename FixedPointType>
inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
    std::int32_t int32_value) {
  using Scalar = typename FixedPointType::ScalarRawType;
  constexpr int kScalarBits = 8 * sizeof(Scalar);
  constexpr int kBitsToDrop = 32 - kScalarBits;
  const std::int32_t narrowed =
      RoundingDivideByPOT<std::int32_t>(int32_value, kBitsToDrop);
  return static_cast<Scalar>(narrowed);
}
// GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawInt32Value,
// DoubleValue) evaluates to a FixedPointType constant built from the int32
// raw encoding, after adapting it to the scalar width with
// RescaleConstantInitializer.
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Checking variant: builds the constant from `raw_value` and asserts that it
// equals the value obtained by converting `double_value` with FromDouble.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
  assert(result == FixedPointType::FromDouble(double_value));
  return result;
}
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
                                             ScalarRawInt32Value, DoubleValue) \
  (gemmlowp::CheckedFixedPointConstant<FixedPointType>(                        \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
          ScalarRawInt32Value),                                                \
      DoubleValue))

#else
// Non-checking variant: DoubleValue is not used in the expansion, so no
// floating-point expression is evaluated.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
                                             ScalarRawInt32Value, DoubleValue) \
  (FixedPointType::FromScalarRaw(                                              \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
          ScalarRawInt32Value)))
#endif

// Implementation of exponential function.

// Returns exp(a) for a in [-1/4, 0), using a 4th-order Taylor expansion
// around the midpoint -1/8:
//   exp(a) = exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24),  x = a + 1/8.
template <typename tRawType>
FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F;
  const F exp_of_minus_one_eighth =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
  const F one_third =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
  // Change of variable to center the expansion. With 0 integer bits and a
  // 32-bit scalar, 1/8 is the raw value 1 << 28.
  const F one_eighth = F::template ConstantPOT<-3>();
  const F x = a + one_eighth;
  // Powers of x.
  const F x_squared = x * x;
  const F x_cubed = x_squared * x;
  const F x_fourth = x_squared * x_squared;
  // (x^4/4 + x^3) / 3, then add x^2 and halve:
  //   x^4/24 + x^3/6 + x^2/2.
  const F x_fourth_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x_fourth);
  const F cubic_and_quartic = (x_fourth_over_4 + x_cubed) * one_third;
  const F terms_of_order_2_to_4 =
      SaturatingRoundingMultiplyByPOT<-1>(cubic_and_quartic + x_squared);
  // exp(-1/8) * (1 + series) is evaluated as exp(-1/8) + exp(-1/8) * series,
  // since 1 itself is not representable with 0 integer bits.
  const F series_without_one = x + terms_of_order_2_to_4;
  return AddSaturatingIf16Bit(exp_of_minus_one_eighth,
                              exp_of_minus_one_eighth * series_without_one);
}

// Returns exp(x) for x <= 0 (an input of exactly zero yields One()).
//
// The input is split as a = r - remainder, where r lies in [-1/4, 0) and
// remainder is a non-negative multiple of 1/4. exp(r) is computed by the
// polynomial helper, and exp(-remainder) is applied as a product of
// precomputed constants exp(-2^k), one for each bit set in remainder.
// The result has 0 integer bits.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> exp_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  static constexpr int kFractionalBits = InputF::kFractionalBits;
  static constexpr int kIntegerBits = InputF::kIntegerBits;
  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
  // All raw bits below the one that encodes 1/4.
  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
  // (a mod 1/4) - 1/4, which lies in [-1/4, 0).
  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
      Rescale<0>(a_mod_quarter_minus_one_quarter));
  // What still has to be subtracted from the reduced argument to get back
  // to a; kept in the input format's raw encoding.
  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();

  // For one bit of `remainder` (the one representing 2^Exponent): if the
  // input format is wide enough to have that bit and it is set, multiply the
  // result by FixedPointMultiplier, the Q0 encoding of exp(-2^Exponent).
  // kShiftAmount falls back to 0 when the bit does not exist so that the
  // shift stays valid in the branch that is never taken.
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
  if (kIntegerBits > Exponent) {                                            \
    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    static constexpr int kShiftAmount =                                     \
        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    result = SelectUsingMask(                                               \
        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
        result * kMultiplier, result);                                      \
  }

  // Constants below are Q0 representations of negative exp fractionals:
  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);  // exp(-1/4)
  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);  // exp(-1/2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);   // exp(-1)
  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);   // exp(-2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);    // exp(-4)
  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);      // exp(-8)
  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);         // exp(-16)

#undef GEMMLOWP_EXP_BARREL_SHIFTER

  // Formats with more than 5 integer bits can hold inputs below -32; for
  // those the result is forced to zero. clampB is chosen so that
  // -(1 << clampB) is the 32-bit raw encoding of -32 in the input format.
  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
  if (kIntegerBits > 5) {
    const InputF clamp =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
  }

  // exp(0) is handled explicitly and yields One().
  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
  return result;
}

// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).

// Returns (1 - a) / (1 + a) for a in (0, 1).
//
// The quotient is obtained as 2 / (1 + a) - 1, where the reciprocal of
// (1 + a) / 2 is computed by Newton-Raphson division.
template <typename tRawType>
FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F0;
  typedef FixedPoint<tRawType, 2> F2;
  // (1 + a) / 2
  const F0 half_denominator = RoundingHalfSum(a, F0::One());
  // Newton-Raphson division
  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  const F2 k48_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  const F2 kneg_32_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  // Initial linear estimate of the reciprocal, then three refinement steps:
  //   r <- r + r * (1 - d * r).
  F2 reciprocal = k48_over_17 + half_denominator * kneg_32_over_17;
  for (int step = 0; step != 3; ++step) {
    const F2 denominator_times_reciprocal = half_denominator * reciprocal;
    const F2 residual = F2::One() - denominator_times_reciprocal;
    reciprocal = reciprocal + Rescale<2>(reciprocal * residual);
  }
  return Rescale<0>(reciprocal - F2::One());
}

// Returns -tanh(a) for a < 0, evaluated as (1 - e) / (1 + e) with
// e = exp(2a).
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  // Doubling is exact: it only reinterprets the raw bits.
  const FixedPoint<tRawType, 1 + tIntegerBits> twice_a = ExactMulByPot<1>(a);
  const FixedPoint<tRawType, 0> exp_of_twice_a =
      exp_on_negative_values(twice_a);
  return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_of_twice_a);
}

// Returns tanh(a) for any a. The computation is carried out on -|a| and the
// sign is restored afterwards; an input of exactly zero yields zero.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  const tRawType input_is_negative = MaskIfLessThan(a, InputF::Zero());
  const tRawType input_is_zero = MaskIfZero(a);
  // -|a|: keep a when it is already negative, negate it otherwise.
  const InputF minus_abs_a = SelectUsingMask(input_is_negative, a, -a);
  // tanh(|a|), a non-negative magnitude.
  const ResultF magnitude = neg_tanh_on_negative_values(minus_abs_a);
  const ResultF signed_result =
      SelectUsingMask(input_is_negative, -magnitude, magnitude);
  return SelectUsingMask(input_is_zero, ResultF::Zero(), signed_result);
}

// Implementation of logistic function.

// Returns 1 / (1 + a) for a in (0, 1).
//
// The reciprocal of (1 + a) / 2 is computed by Newton-Raphson division and
// then halved.
template <typename tRawType>
FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F0;
  typedef FixedPoint<tRawType, 2> F2;
  // (1 + a) / 2
  const F0 half_denominator = RoundingHalfSum(a, F0::One());
  // Newton-Raphson division
  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  const F2 k48_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  const F2 kneg_32_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  // Initial linear estimate of the reciprocal, then three refinement steps:
  //   r <- r + r * (1 - d * r).
  F2 reciprocal = k48_over_17 + half_denominator * kneg_32_over_17;
  for (int step = 0; step != 3; ++step) {
    const F2 denominator_times_reciprocal = half_denominator * reciprocal;
    const F2 residual = F2::One() - denominator_times_reciprocal;
    reciprocal = reciprocal + Rescale<2>(reciprocal * residual);
  }
  // Halving is exact (type-level only); the rescale brings the value back to
  // 0 integer bits.
  return Rescale<0>(ExactMulByPot<-1>(reciprocal));
}

// Returns logistic(a) = 1 / (1 + exp(-a)) for a > 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> logistic_on_positive_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  const FixedPoint<tRawType, 0> exp_of_minus_a = exp_on_negative_values(-a);
  return one_over_one_plus_x_for_x_in_0_1(exp_of_minus_a);
}

// Returns logistic(a) = 1 / (1 + exp(-a)) for any a. The positive-input
// routine is evaluated on |a|; negative inputs use the identity
// logistic(-t) = 1 - logistic(t), and an input of exactly zero yields 1/2.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  const tRawType input_is_positive = MaskIfGreaterThan(a, InputF::Zero());
  const tRawType input_is_zero = MaskIfZero(a);
  // |a|: keep a when it is positive, negate it otherwise.
  const InputF magnitude = SelectUsingMask(input_is_positive, a, -a);
  const ResultF value_for_positive = logistic_on_positive_values(magnitude);
  const ResultF value_for_negative = ResultF::One() - value_for_positive;
  const ResultF kOneHalf =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
  const ResultF value_for_nonzero = SelectUsingMask(
      input_is_positive, value_for_positive, value_for_negative);
  return SelectUsingMask(input_is_zero, kOneHalf, value_for_nonzero);
}

}  // end namespace gemmlowp

#ifdef GEMMLOWP_NEON
#include "./fixedpoint_neon.h"
#elif defined(GEMMLOWP_AVX2)
#include "./fixedpoint_avx.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
#elif defined(GEMMLOWP_MSA)
#include "./fixedpoint_msa.h"
#endif

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
