#include "CommonLib/Quant.h"
#include "cpp/fast_quant.h"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <cstdint>
#include <stdexcept>
#include <string>

namespace py = pybind11;

int32_t quantize_py(int32_t W, int32_t min, int32_t max, float lambda, float pv,
                    float ws, py::array_t<float> rate_estimations) {
  auto rate_estimations_ref = rate_estimations.unchecked<1>();
  float *rate_estimations_ptr = rate_estimations.mutable_data();
  return fast_quant::quantize_single(W, min, max, lambda, pv, ws,
                                     rate_estimations_ptr);
}

class DeepCabacRdQuantiserPy {
public:
  DeepCabacRdQuantiserPy(float lambda, float delta, int32_t min, int32_t max)
      : quantiser(lambda, delta, min, max) {}
  DeepCabacRdQuantiserPy()
      : quantiser(0, 1, 0,0) {}

  void set_grid(int32_t min, int32_t max, float delta) {
    quantiser.min = min;
    quantiser.max = max;
    quantiser.delta = delta;
  }

  int32_t quantize_single(float W, float posterior_variance) {
    return quantiser.quantize_single(W, posterior_variance);
  }

  py::array_t<int32_t>
  quantize_multiple(py::array_t<float> W,
                    float posterior_variance,
                    py::array_t<float> delta, 
                    int32_t min_idx, 
                    int32_t max_idx) {

    auto winfo = W.request();
    auto w_ptr = (float*)winfo.ptr;


    auto delta_info = delta.request();
    auto delta_ptr = (float*)delta_info.ptr;

    auto quantized = py::array_t<int32_t, py::array::c_style>(winfo.size);
    int32_t *raw_ptr = (int32_t *)quantized.request().ptr;


    quantiser.quantize_multiple(w_ptr, winfo.size, posterior_variance, delta_ptr, min_idx, max_idx, raw_ptr);
    return quantized;

  }


  py::array_t<int32_t>
  quantize_multiple(py::array_t<float> W,
                    py::array_t<float> posterior_variance,
                    py::array_t<float> delta, 
                    py::array_t<int32_t> min_idx, 
                    py::array_t<int32_t> max_idx) {
    auto winfo = W.request();
    auto w_ptr = (float*)winfo.ptr;

    auto pv_info = posterior_variance.request();
    auto pv_ptr = (float*)pv_info.ptr;

    auto delta_info = delta.request();
    auto delta_ptr = (float*)delta_info.ptr;

    auto min_info = min_idx.request();
    auto min_ptr = (int32_t *)min_info.ptr;

    auto max_info = max_idx.request();
    auto max_ptr = (int32_t *)max_info.ptr;

    auto quantized = py::array_t<int32_t, py::array::c_style>(winfo.size);
    int32_t *raw_ptr = (int32_t *)quantized.request().ptr;


    if (winfo.size != pv_info.size) {
      std::cout << winfo.size << " " << pv_info.size << std::endl;
      throw std::range_error(
          "W and posterior_variance must have the same size");
    }

    quantiser.quantize_multiple(w_ptr, winfo.size, pv_ptr, delta_ptr, min_ptr, max_ptr, raw_ptr);
    return quantized;

  }


  py::array_t<int32_t>
  quantize_multiple(py::array_t<float> W,
                    py::array_t<float> posterior_variance) {
    auto winfo = W.request();
    auto w_ptr = (float*)winfo.ptr;

    auto pv_info = posterior_variance.request();
    auto pv_ptr = (float*)pv_info.ptr;
    auto quantized = py::array_t<int32_t, py::array::c_style>(winfo.size);
    int32_t *raw_ptr = (int32_t *)quantized.request().ptr;

    if (pv_info.size == 1) {
      quantiser.quantize_multiple(w_ptr, winfo.size, pv_ptr[0], raw_ptr);
      return quantized;
    }

    if (winfo.size != pv_info.size) {
      std::cout << winfo.size << " " << pv_info.size << std::endl;
      throw std::range_error(
          "W and posterior_variance must have the same size");
    }

    quantiser.quantize_multiple(w_ptr, winfo.size, pv_ptr, raw_ptr);
    return quantized;
  }

  py::array_t<int32_t> quantize_multiple(py::array_t<float> W,
                                         float posterior_variance) {
    auto winfo = W.request();
    auto w_ptr = (float*)winfo.ptr;

    auto quantized = py::array_t<int32_t, py::array::c_style>(winfo.size);
    int32_t *raw_ptr = (int32_t *)quantized.request().ptr;
    quantiser.quantize_multiple(w_ptr, winfo.size, posterior_variance,
                                raw_ptr);
    return quantized;
  };
  float get_lm() const { return quantiser.lambda; }
  void set_lm(float value) { quantiser.lambda = value; }
  float get_delta() const { return quantiser.delta; }
  void set_delta(float value) { quantiser.delta = value; }
  float get_min() const { return quantiser.min; }
  void set_min(float value) { quantiser.min = value; }
  float get_max() const { return quantiser.max; }
  void set_max(float value) { quantiser.max = value; }

private:
  DeepCabacRdQuantiser quantiser;
};

// Defines the `_core` extension module and exposes DeepCabacRdQuantiserPy to
// Python under the name "DeepCabacRdQuantiser".
PYBIND11_MODULE(_core, m) {
  using Quantiser = DeepCabacRdQuantiserPy;
  using FloatArray = py::array_t<float>;

  py::class_<Quantiser> quantiser_cls(m, "DeepCabacRdQuantiser");

  // Constructors: fully specified, then defaulted.
  quantiser_cls.def(
      py::init<float, float, int32_t, int32_t>(), py::arg("lambda"),
      py::arg("delta"), py::arg("min"), py::arg("max"),
      R"pbdoc(Initialise the quantiser with lambda, delta, min and max.
Supply the rate-tradeoff parameter with lambda. 
Delta is the scale of weight indices (so a gridpoint is obtained by multiplying its index 
with delta). Min and max are the minimum respectively maximum indices one can quantise to,
they essentially bound the grid size.
          )pbdoc");
  quantiser_cls.def(
      py::init<>(),
      R"pbdoc(Initialise the quantiser with default values.)pbdoc");

  // "quantize" overloads. The registration order below is deliberately the
  // same as before, since it influences which overload pybind11 picks.
  quantiser_cls.def(
      "quantize", &Quantiser::quantize_single, py::arg("w"),
      py::arg("posterior_variance"),
      R"pbdoc(Quantize a (set of) weights with a (set of) posterior variances.)pbdoc");
  quantiser_cls.def(
      "quantize",
      py::overload_cast<FloatArray, float>(&Quantiser::quantize_multiple),
      py::arg("W"), py::arg("posterior_variance"));
  // Disabled: fully array-valued overload.
  // quantiser_cls.def("quantize",
  //     py::overload_cast<py::array_t<float>, py::array_t<float>, py::array_t<float>, py::array_t<int32_t>, py::array_t<int32_t>>(
  //         &DeepCabacRdQuantiserPy::quantize_multiple),
  //     py::arg("W"), py::arg("posterior_variance"), py::arg("delta"), py::arg("min_idx"), py::arg("max_idx"));
  quantiser_cls.def(
      "quantize",
      py::overload_cast<FloatArray, float, FloatArray, int32_t, int32_t>(
          &Quantiser::quantize_multiple),
      py::arg("W"), py::arg("posterior_variance"), py::arg("delta"),
      py::arg("min_idx"), py::arg("max_idx"));
  quantiser_cls.def(
      "quantize",
      py::overload_cast<FloatArray, FloatArray>(&Quantiser::quantize_multiple),
      py::arg("W"), py::arg("posterior_variance"));

  // Properties. Only `lm` is live; the grid is configured through set_grid.
  quantiser_cls.def_property("lm", &Quantiser::get_lm, &Quantiser::set_lm);
  // Disabled properties. Note that the max_idx line passes get_max as its
  // setter, which needs correcting before it is re-enabled.
  // .def_property("delta", &DeepCabacRdQuantiserPy::get_delta, &DeepCabacRdQuantiserPy::set_delta)
  // .def_property("min_idx", &DeepCabacRdQuantiserPy::get_min, &DeepCabacRdQuantiserPy::set_min)
  // .def_property("max_idx", &DeepCabacRdQuantiserPy::get_max, &DeepCabacRdQuantiserPy::get_max)

  quantiser_cls.def(
      "set_grid", &Quantiser::set_grid, py::arg("min"), py::arg("max"),
      py::arg("delta"),
      R"pbdoc(Set the grid size of the quantiser. The grid is defined by the minimum and maximum index)pbdoc");

  // Class-level docstring.
  quantiser_cls.doc() =
      R"pbdoc(Rate-Distortion quantiser that uses the entropy-model of DeepCABAC to estimate 
the entropy of the resulting bit string. Therefore, solves 
argmin_{min <= i <= max} rate(i) * lambda + 1/pv * (w - i)^2 * delta^2
           )pbdoc";
}
