#pragma once
#include <CommonLib/Quant.h>
#include <cmath>
#include <limits>
#include <type_traits>

namespace fast_quant {
/// \brief Declaration of the batched quantization routine; the definition is
/// not part of this header section.
///
/// Each of \p min, \p max, \p lambda, \p pv and \p ws has its own template
/// type. NOTE(review): presumably this lets every one of them be either a
/// single scalar or a per-element array resolved through get_value() --
/// confirm against the definition.
/// \param W Weight buffer of \p size elements. It is taken as non-const, so it
/// is presumably updated in place -- TODO confirm
/// \param min Lower bound of the allowed values (scalar or per-element, see
/// note above)
/// \param max Upper bound of the allowed values (scalar or per-element, see
/// note above)
/// \param lambda Presumably the scaling factor for the rate term, as in
/// quantize_single() -- confirm
/// \param pv Presumably the posterior variance -- confirm
/// \param ws Presumably the weight scale -- confirm
/// \param rate_estimations Rate estimation table (read-only)
/// \param size Number of weights; presumably the element count of \p W
/// \return An int32_t whose meaning is not visible from this declaration.
template <typename T1, typename T2, typename T3, typename T4, typename T5>
int32_t quantize_multiple(float *W, T1 min, T2 max, T3 lambda, T4 pv, T5 ws,
                          const float *rate_estimations, size_t size);


/// \brief Read a parameter that is either one shared scalar or a per-element
/// array.
///
/// \tparam T Scalar type convertible to float, or a pointer to such values
/// \param arg The scalar itself, or a pointer to the first element
/// \param i Element index; only used when \p arg is a pointer. The caller must
/// ensure it is in range, no bounds check is done here.
/// \return `arg[i]` for pointer arguments, otherwise \p arg converted to float
template <typename T> inline float get_value(T arg, size_t i) {
  // T is deduced from a by-value parameter, so a caller's array has already
  // decayed to a pointer by the time we get here: std::is_array alone can
  // never match a deduced T. Test for pointers, and keep is_array for callers
  // that name an array type explicitly.
  if constexpr (std::is_pointer<T>::value || std::is_array<T>::value) {
    return static_cast<float>(arg[i]);
  } else {
    return static_cast<float>(arg);
  }
}

/// \brief Quantize a single weight. Works completely in indices.
///
/// The cost of a candidate quantized value w_q can be calculated as
///    cost = rate * lambda + 1/pv * (w - w_q)^2 * weight_scale^2
/// where pv is the posterior variance.
/// NOTE(review): the definition is not in this header; presumably the
/// candidate with the lowest cost is chosen -- confirm.
/// \param w The weight to quantize
/// \param min The minimum value the weight can take
/// \param max The maximum value the weight can take
/// \param lambda Scaling factor for the rate term
/// \param posterior_variance How much the weight can vary
/// \param weight_scale The real distance between two quantized values, Delta
/// \param rate_estimations The rate estimations for each weight. Must have
/// length max - min + 1
/// \return The quantized weight (presumably expressed as an integer index
/// between \p min and \p max -- confirm)
int32_t quantize_single(float w, int32_t min, int32_t max, float lambda,
                        float posterior_variance, float weight_scale,
                        const float *rate_estimations);
} // namespace fast_quant

/// \brief Quantiser object bundling a rate weight, a step size, an index range
/// and a RateEstimation instance.
///
/// Only the constructor is defined in this header; every other member function
/// is declared here and defined elsewhere, so the notes on them below are
/// limited to what the signatures show.
class DeepCabacRdQuantiser {
public:
  /// \brief Store the four settings in the public members of the same name.
  /// The rate estimator member is default-constructed.
  DeepCabacRdQuantiser(float lambda, float delta, int32_t min, int32_t max)
      : lambda(lambda), delta(delta), min(min), max(max) {}

  /// \brief Quantize one weight.
  /// NOTE(review): presumably uses the stored min/max/delta -- confirm.
  int32_t quantize_single(float w, float posterior_variance);

  /// \brief Quantize one weight with an explicit range and weight scale.
  int32_t quantize_single(float w, float posterior_variance, int min, int max,
                          float weight_scale);

  /// \brief Quantize \p size weights using one shared posterior variance.
  /// NOTE(review): the returned pointer is presumably \p out_buffer -- confirm.
  int32_t *quantize_multiple(const float *W, size_t size,
                             float posterior_variance, int32_t *out_buffer);

  /// \brief Quantize \p size weights; \p posterior_variance is passed as an
  /// array (presumably one entry per weight -- confirm).
  int32_t *quantize_multiple(const float *W, size_t size,
                             const float *posterior_variance,
                             int32_t *out_buffer);

  /// \brief Quantize \p size weights; posterior variance, weight scale and
  /// index bounds are all passed as arrays.
  int32_t *quantize_multiple(const float *W, size_t size,
                             const float *posterior_variance,
                             const float *weight_scale,
                             const int32_t *min_idx,
                             const int32_t *max_idx,
                             int32_t *out_buffer);

  /// \brief Quantize \p size weights; the weight scale is passed as an array,
  /// while posterior variance and index bounds are single values.
  int32_t *quantize_multiple(const float *W, size_t size,
                             float posterior_variance,
                             const float *weight_scale,
                             int32_t min_idx,
                             int32_t max_idx,
                             int32_t *out_buffer);

  float lambda; ///< Presumably the scaling factor for the rate term -- confirm
  float delta;  ///< Presumably the quantization step size -- confirm
  int32_t min;  ///< Lower bound given at construction
  int32_t max;  ///< Upper bound given at construction

private:
  RateEstimation rate_estimator; ///< Project type, presumably from CommonLib/Quant.h
};