#include "fast_quant.h"

#include <vector>

#include "CommonLib/Quant.h"

namespace fast_quant {
/// @brief Quantise `size` weights, writing each selected index back into W.
///
/// For element i the arguments min, max, lambda, pv and ws are each resolved
/// through get_value(arg, i) and forwarded to quantize_single together with
/// W[i] and rate_estimations.
///
/// NOTE(review): get_value is not defined in this part of the file; presumably
/// it yields either a scalar or the i-th entry of an array - confirm.
/// NOTE(review): the selected index is stored in W[i] (converted to float)
/// because W is the only writable argument; confirm this is the intended
/// output location.
/// @param W Weights; read and then overwritten at positions [0, size)
/// @param min Source of the minimum index for each element
/// @param max Source of the maximum index for each element
/// @param lambda Source of the rate/distortion trade-off for each element
/// @param pv Source of the posterior variance for each element
/// @param ws Source of the weight scale for each element
/// @param rate_estimations Forwarded unchanged to every quantize_single call
/// @param size Number of elements to process
/// @return Always 0
template <typename T1, typename T2, typename T3, typename T4, typename T5>
int32_t quantize_multiple(float *W, T1 min, T2 max, T3 lambda, T4 pv, T5 ws,
                          const float *rate_estimations, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    const int32_t idx = quantize_single(
        W[i], get_value(min, i), get_value(max, i), get_value(lambda, i),
        get_value(pv, i), get_value(ws, i), rate_estimations);
    // Keep the result instead of discarding it.
    W[i] = static_cast<float>(idx);
  }
  // A non-void function must return a value on every path.
  return 0;
}

/// @brief Rate-distortion cost of mapping a weight onto a quantisation index
/// @param w  Weight, expressed in index space
/// @param i  Candidate quantisation index
/// @param lambda Weighting applied to the rate term
/// @param rate_estimation Estimated rate of the candidate index
/// @return rate_estimation * lambda + (w - i)^2
float inline cost(float w, int32_t i, float lambda, float rate_estimation) {
  const float rate_term = rate_estimation * lambda;
  // The squared error is formed in double precision; the sum is narrowed to
  // float only when it is returned.
  const double error = w - static_cast<float>(i);
  return static_cast<float>(rate_term + error * error);
}


/// @brief Exhaustive rate-distortion search for the quantisation index of one weight.
///
/// Every index in [min, max] is scored with cost() using the adjusted
/// trade-off lambda * posterior_variance / weight_scale^2, and the index with
/// the lowest score is chosen. When several indices tie, the smallest wins.
/// @param w Weight to quantise; it is compared directly against candidate
///          indices, so it must already be expressed in index space
/// @param min Smallest candidate index
/// @param max Largest candidate index (inclusive)
/// @param lambda Trade-off between rate and distortion
/// @param posterior_variance Multiplies lambda in the adjusted trade-off
/// @param weight_scale Its square divides lambda in the adjusted trade-off
/// @param rate_estimations Estimated rates; entry [i - min] is used for index
///                         i, so max - min + 1 entries are read
/// @return The lowest-cost index, or 0 when no candidate scores below the
///         largest finite float (for instance when min > max)
int32_t quantize_single(float w, int32_t min, int32_t max, float lambda,
                        float posterior_variance, float weight_scale,
                        const float *rate_estimations) {
  // The division is carried out in double precision and then narrowed.
  const double scale = weight_scale;
  const float adjusted_lambda =
      static_cast<float>(lambda * posterior_variance / (scale * scale));

  int32_t winner = 0;
  float lowest = std::numeric_limits<float>::max();
  const float *rate = rate_estimations;
  for (int32_t candidate = min; candidate <= max; ++candidate, ++rate) {
    const float score = cost(w, candidate, adjusted_lambda, *rate);
    if (!(score < lowest)) {
      continue;  // not strictly better (also skips NaN scores)
    }
    lowest = score;
    winner = candidate;
  }
  return winner;
}
} // namespace fast_quant

int32_t DeepCabacRdQuantiser::quantize_single(float w,
                                              float posterior_variance) {
  return this->quantize_single(w, posterior_variance, this->min, this->max, this->delta);
}

/// @brief Quantise one weight with explicit index bounds and weight scale.
///
/// The weight is divided by weight_scale_, a rate estimate is requested from
/// rate_estimator for every index in [min_, max_], the lowest-cost index is
/// selected by fast_quant::quantize_single, and rate_estimator.update() is
/// called with that index before it is returned.
/// @param w Weight to quantise
/// @param posterior_variance Forwarded to fast_quant::quantize_single
/// @param min_ Smallest candidate index
/// @param max_ Largest candidate index; must be strictly greater than min_
/// @param weight_scale_ Divisor applied to w before the search; also forwarded
///                      to fast_quant::quantize_single
/// @return The selected index
/// @throws std::range_error if min_ >= max_
int32_t DeepCabacRdQuantiser::quantize_single(float w,
                                              float posterior_variance,
                                              int min_, int max_, float weight_scale_) {
  if (min_ >= max_)  {
    // Report the arguments that failed the check, not the member bounds.
    std::cout << "Illegal min and max values: (" << min_ << "," << max_ << ")"  << std::endl;
    std::cout << "Called with w: " << w << " pv: " << posterior_variance << std::endl;
    throw std::range_error("min must be strictly less than max");
  }
  w = w / weight_scale_;

  // The vector owns the scratch buffer, so it is released on every exit path,
  // including when estimate() throws. The element count is computed in 64 bits
  // so that max_ - min_ cannot overflow int.
  const auto candidate_count =
      static_cast<std::size_t>(static_cast<int64_t>(max_) - min_ + 1);
  std::vector<float> rate_estimations(candidate_count);
  for (int32_t i = min_; i <= max_; i++) {
    rate_estimations[static_cast<std::size_t>(i - min_)] = rate_estimator.estimate(i);
  }

  const int32_t best_idx = fast_quant::quantize_single(
      w, min_, max_, lambda, posterior_variance, weight_scale_,
      rate_estimations.data());
  rate_estimator.update(best_idx);
  return best_idx;
}

int32_t *
DeepCabacRdQuantiser::quantize_multiple(const float *W, size_t size,
                                        const float *posterior_variance,
                                        int32_t *out_buffer) {
  for (size_t i = 0; i < size; ++i) {
    out_buffer[i] = quantize_single(W[i], posterior_variance[i]);
  }
  return out_buffer;
}

int32_t *DeepCabacRdQuantiser::quantize_multiple(const float *W, size_t size,
                                                 float posterior_variance,
                                                 int32_t *out_buffer) {
  for (size_t i = 0; i < size; ++i) {
    out_buffer[i] = quantize_single(W[i], posterior_variance);
  }
  return out_buffer;
}

int32_t *
DeepCabacRdQuantiser::quantize_multiple(const float *W, size_t size,
                                        const float *posterior_variance,
                                        const float *weight_scale,
                                        const int32_t *min_idx, 
                                        const int32_t *max_idx,
                                        int32_t *out_buffer) {
  for (size_t i = 0; i < size; ++i) {
    out_buffer[i] = quantize_single(W[i], posterior_variance[i], min_idx[i], max_idx[i], weight_scale[i]);
  }
  return out_buffer;
}

int32_t *
DeepCabacRdQuantiser::quantize_multiple(const float *W, size_t size,
                                        float posterior_variance,
                                        const float *weight_scale,
                                        int32_t min_idx, 
                                        int32_t max_idx,
                                        int32_t *out_buffer) {
  for (size_t i = 0; i < size; ++i) {
    out_buffer[i] = quantize_single(W[i], posterior_variance, min_idx, max_idx, weight_scale[i]);
  }
  return out_buffer;
}