#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <stdexcept>
#include <vector>

#include <immintrin.h>
#include <cblas.h>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

// Two-level stringification: going through MACRO_STRINGIFY lets the
// preprocessor expand its argument first, so the *value* of a macro (not its
// name) is what ends up inside the string literal.
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)

namespace py = pybind11;

// Number of aggregations mixed by the fused kernels (sum, max, min). It is
// also the required second dimension of their `weightings` argument.
#define NUM_AGGS 3

/// Scalar reference kernel: `sums[k] += a * b[j, k]` for every column k.
///
/// @param dense_cols number of floats in `sums` and in one row of `b`
/// @param sums       accumulator of `dense_cols` floats, updated in place
/// @param a          scalar applied to the selected row
/// @param b          base pointer of a row-major matrix with `dense_cols`
///                   columns
/// @param j          index of the row of `b` to accumulate
inline void default_sum(const int32_t dense_cols, float *sums, const float a,
                        const float *b, const int32_t j) {
  // Widen before multiplying: a 32-bit `j * dense_cols` overflows (undefined
  // behaviour) once the matrix holds more than 2^31 - 1 elements. Computing
  // the row pointer once also keeps the multiply out of the loop.
  const float *row = b + static_cast<std::ptrdiff_t>(j) * dense_cols;
  for (int32_t k = 0; k < dense_cols; k++) {
    sums[k] += a * row[k];
  }
}

/// AVX kernel: `sums[k] += a * b[j, k]`, eight floats per iteration.
///
/// The loop always loads and stores full 8-float vectors, so `dense_cols`
/// must be a multiple of 8; otherwise it reads and writes past the end of
/// both buffers. Unaligned loads/stores are used, so no alignment is required.
///
/// @param dense_cols number of floats in `sums` and in one row of `b`
/// @param sums       accumulator of `dense_cols` floats, updated in place
/// @param a          scalar broadcast to all lanes and applied to the row
/// @param b          base pointer of a row-major matrix with `dense_cols`
///                   columns
/// @param j          index of the row of `b` to accumulate
inline void sum(const int32_t dense_cols, float *sums, const float a,
                const float *b, const int32_t j) {
  // Widen before multiplying: a 32-bit `j * dense_cols` overflows (undefined
  // behaviour) once the matrix holds more than 2^31 - 1 elements.
  const float *row = b + static_cast<std::ptrdiff_t>(j) * dense_cols;
  const __m256 a_vec = _mm256_set1_ps(a);
  for (int32_t k = 0; k < dense_cols; k += 8) {
    const __m256 s = _mm256_loadu_ps(sums + k);
    const __m256 b_vec = _mm256_loadu_ps(row + k);

    // sum: separate multiply and add, written back to the accumulator
    const __m256 a_times_b = _mm256_mul_ps(a_vec, b_vec);
    _mm256_storeu_ps(sums + k, _mm256_add_ps(s, a_times_b));
  }
}

/// AVX kernel: `maxs[k] = max(maxs[k], b[j, k])`, eight floats per iteration.
///
/// The loop always loads and stores full 8-float vectors, so `dense_cols`
/// must be a multiple of 8.
///
/// @param dense_cols number of floats in `maxs` and in one row of `b`
/// @param maxs       running maxima of `dense_cols` floats, updated in place
/// @param a          ignored; present so all kernels share one signature
/// @param b          base pointer of a row-major matrix with `dense_cols`
///                   columns
/// @param j          index of the row of `b` to fold in
inline void max(const int32_t dense_cols, float *maxs, const float a,
                const float *b, const int32_t j) {
  (void)a; // the edge value does not take part in a max reduction

  // Widen before multiplying: a 32-bit `j * dense_cols` overflows (undefined
  // behaviour) once the matrix holds more than 2^31 - 1 elements.
  const float *row = b + static_cast<std::ptrdiff_t>(j) * dense_cols;
  for (int32_t k = 0; k < dense_cols; k += 8) {
    const __m256 ma = _mm256_loadu_ps(maxs + k);
    const __m256 b_vec = _mm256_loadu_ps(row + k);
    _mm256_storeu_ps(maxs + k, _mm256_max_ps(ma, b_vec));
  }
}

/// AVX kernel: `mins[k] = min(mins[k], b[j, k])`, eight floats per iteration.
///
/// The loop always loads and stores full 8-float vectors, so `dense_cols`
/// must be a multiple of 8.
///
/// @param dense_cols number of floats in `mins` and in one row of `b`
/// @param mins       running minima of `dense_cols` floats, updated in place
/// @param a          ignored; present so all kernels share one signature
/// @param b          base pointer of a row-major matrix with `dense_cols`
///                   columns
/// @param j          index of the row of `b` to fold in
inline void min(const int32_t dense_cols, float *mins, const float a,
                const float *b, const int32_t j) {
  (void)a; // the edge value does not take part in a min reduction

  // Widen before multiplying: a 32-bit `j * dense_cols` overflows (undefined
  // behaviour) once the matrix holds more than 2^31 - 1 elements.
  const float *row = b + static_cast<std::ptrdiff_t>(j) * dense_cols;
  for (int32_t k = 0; k < dense_cols; k += 8) {
    const __m256 mi = _mm256_loadu_ps(mins + k);
    const __m256 b_vec = _mm256_loadu_ps(row + k);
    _mm256_storeu_ps(mins + k, _mm256_min_ps(mi, b_vec));
  }
}

/// Fused AVX kernel: folds row `j` of `b` into three accumulators at once.
///
///   sums[k] += a * b[j, k]
///   maxs[k]  = max(maxs[k], b[j, k])
///   mins[k]  = min(mins[k], b[j, k])
///
/// Each row chunk of `b` is loaded once and reused by all three reductions.
/// The loop always works on full 8-float vectors, so `dense_cols` must be a
/// multiple of 8.
///
/// @param dense_cols number of floats in each accumulator and in a row of `b`
/// @param sums       running weighted sums, updated in place
/// @param maxs       running maxima, updated in place
/// @param mins       running minima, updated in place
/// @param a          scalar applied to the row for the sum only
/// @param b          base pointer of a row-major matrix with `dense_cols`
///                   columns
/// @param j          index of the row of `b` to fold in
inline void manual_loop_fused(const int32_t dense_cols, float *sums,
                              float *maxs, float *mins, const float a,
                              const float *b, const int32_t j) {
  // Going faster than this (at least on Skylake MBP) is difficult
  // unfortunately the store pipeline is not as fat as the load pipeline
  // -> stores are the bottleneck :-(

  // Widen before multiplying: a 32-bit `j * dense_cols` overflows (undefined
  // behaviour) once the matrix holds more than 2^31 - 1 elements.
  const float *row = b + static_cast<std::ptrdiff_t>(j) * dense_cols;
  const __m256 a_vec = _mm256_set1_ps(a);
  for (int32_t k = 0; k < dense_cols; k += 8) {
    const __m256 s = _mm256_loadu_ps(sums + k);
    const __m256 ma = _mm256_loadu_ps(maxs + k);
    const __m256 mi = _mm256_loadu_ps(mins + k);
    const __m256 b_vec = _mm256_loadu_ps(row + k);

    // sum
    const __m256 a_times_b = _mm256_mul_ps(a_vec, b_vec);
    const __m256 res = _mm256_add_ps(s, a_times_b);

    // max
    const __m256 max_vec = _mm256_max_ps(ma, b_vec);

    // min
    const __m256 min_vec = _mm256_min_ps(mi, b_vec);

    _mm256_storeu_ps(sums + k, res);
    _mm256_storeu_ps(maxs + k, max_vec);
    _mm256_storeu_ps(mins + k, min_vec);
  }
}

#define csr(inner_kernel)                                                      \
  void csr_##inner_kernel(                                                     \
      int32_t n_row, int32_t n_col, py::array_t<int32_t> row_ptr,              \
      py::array_t<int32_t> col_idx, py::array_t<float> nonzero,                \
      py::array_t<float> dense, py::array_t<float> out) {                      \
                                                                               \
    if (dense.ndim() != 2 || out.ndim() != 2) {                                \
      throw std::invalid_argument("Dense matrices should be 2D");              \
    }                                                                          \
    if (dense.shape(0) != n_col) {                                             \
      throw std::invalid_argument("dense rows doesn't match sparse cols");     \
    }                                                                          \
    if (out.shape(0) != n_row || out.shape(1) != dense.shape(1)) {             \
      throw std::invalid_argument("Out matrix is wrong shape");                \
    }                                                                          \
    if (dense.shape(1) % 8 != 0) {                                             \
      throw std::invalid_argument("feature size must be divisible by 8");      \
    }                                                                          \
                                                                               \
    auto dense_cols = dense.shape(1);                                          \
                                                                               \
    auto ci = static_cast<int32_t *>(col_idx.request().ptr);                   \
    auto rp = static_cast<int32_t *>(row_ptr.request().ptr);                   \
    auto nz = static_cast<float *>(nonzero.request().ptr);                     \
    auto c = static_cast<float *>(out.request().ptr);                          \
    auto b = static_cast<float *>(dense.request().ptr);                        \
                                                                               \
    for (int32_t i = 0; i < n_row; i++) {                                      \
      for (int32_t jj = rp[i]; jj < rp[i + 1]; jj++) {                         \
        const int32_t j = ci[jj];                                              \
        const float a = nz[jj];                                                \
        inner_kernel(dense_cols, c + (i * dense_cols), a, b, j);               \
      }                                                                        \
    }                                                                          \
  }

csr(sum);
csr(default_sum);
csr(max);
csr(min);

void csr_fused(int32_t n_row, int32_t n_col, py::array_t<int32_t> row_ptr,
               py::array_t<int32_t> col_idx, py::array_t<float> nonzero,
               py::array_t<float> dense, py::array_t<float> weightings,
               py::array_t<float> out) {

  if (dense.ndim() != 2 || out.ndim() != 2) {
    throw std::invalid_argument("Dense matrices should be 2D");
  }
  if (dense.shape(0) != n_col) {
    throw std::invalid_argument("dense rows doesn't match sparse cols");
  }
  if (out.shape(0) != n_row || out.shape(1) != dense.shape(1)) {
    throw std::invalid_argument("Out matrix is wrong shape");
  }
  if (dense.shape(1) % 8 != 0) {
    throw std::invalid_argument("feature size must be divisible by 8");
  }
  if (weightings.shape(0) != n_row || weightings.shape(1) != NUM_AGGS) {
    throw std::invalid_argument("Weightings have wrong shape");
  }

  auto dense_cols = dense.shape(1);

  auto ci = static_cast<int32_t *>(col_idx.request().ptr);
  auto rp = static_cast<int32_t *>(row_ptr.request().ptr);
  auto nz = static_cast<float *>(nonzero.request().ptr);
  auto c = static_cast<float *>(out.request().ptr);
  auto b = static_cast<float *>(dense.request().ptr);
  auto w = static_cast<float *>(weightings.request().ptr);

  std::vector<float> sums(dense_cols, 0.0);
  std::vector<float> maxs(dense_cols, 0.0);
  std::vector<float> mins(dense_cols, 0.0);

  // this bit is the CSR matmul
  for (int32_t i = 0; i < n_row; i++) {

    std::fill(sums.begin(), sums.end(), 0.0f);
    std::fill(maxs.begin(), maxs.end(), std::numeric_limits<float>::lowest());
    std::fill(mins.begin(), mins.end(), std::numeric_limits<float>::max());

    for (int32_t jj = rp[i]; jj < rp[i + 1]; jj++) {
      const int32_t j = ci[jj];
      const float a = nz[jj];
      //   default_innerloop(dense_cols, sums, a, b, j);
      // manual_loop(dense_cols, sums.data(), a, b, j);
      manual_loop_fused(dense_cols, sums.data(), maxs.data(), mins.data(), a, b,
                        j);
    }
    for (int32_t k = 0; k < dense_cols; k++) {
      c[(i * dense_cols) + k] = w[(i * NUM_AGGS) + 0] * sums[k] +
                                w[(i * NUM_AGGS) + 1] * maxs[k] +
                                w[(i * NUM_AGGS) + 2] * mins[k];
    }
  }
}


void csr_fused_materialize(int32_t n_row, int32_t n_col, py::array_t<int32_t> row_ptr,
               py::array_t<int32_t> col_idx, py::array_t<float> nonzero,
               py::array_t<float> dense, py::array_t<float> weightings,
               py::array_t<float> out_mat, py::array_t<float> out_weighted) {

  if (dense.ndim() != 2 || out_weighted.ndim() != 2) {
    throw std::invalid_argument("Dense matrices should be 2D");
  }
  if (dense.shape(0) != n_col) {
    throw std::invalid_argument("dense rows doesn't match sparse cols");
  }


  // 2 out matrices, just in case the optimizer does some funny things and realizes
  // that some stuff can be elided otherwise :-)
  if (out_mat.shape(0) != 3 || out_mat.shape(1) != n_row || out_mat.shape(2) != dense.shape(1)) {
    throw std::invalid_argument("Out (weighted) matrix is wrong shape");
  }
  if (out_weighted.shape(0) != n_row || out_weighted.shape(1) != dense.shape(1)) {
    throw std::invalid_argument("Out (weighted) matrix is wrong shape");
  }

  if (dense.shape(1) % 8 != 0) {
    throw std::invalid_argument("feature size must be divisible by 8");
  }
  if (weightings.shape(0) != n_row || weightings.shape(1) != NUM_AGGS) {
    throw std::invalid_argument("Weightings have wrong shape");
  }

  auto dense_cols = dense.shape(1);

  auto ci = static_cast<int32_t *>(col_idx.request().ptr);
  auto rp = static_cast<int32_t *>(row_ptr.request().ptr);
  auto nz = static_cast<float *>(nonzero.request().ptr);
  auto om = static_cast<float *>(out_mat.request().ptr);
  auto ow = static_cast<float *>(out_weighted.request().ptr);
  auto b = static_cast<float *>(dense.request().ptr);
  auto w = static_cast<float *>(weightings.request().ptr);

  auto sums = om + 0 * (dense_cols * n_row);
  auto maxs = om + 1 * (dense_cols * n_row);
  auto mins = om + 2 * (dense_cols * n_row);

  std::fill(sums, sums + (dense_cols * n_row), 0.0f);
  std::fill(maxs, maxs + (dense_cols * n_row), std::numeric_limits<float>::lowest());
  std::fill(mins, mins + (dense_cols * n_row), std::numeric_limits<float>::max());

  // this bit is the CSR matmul
  for (int32_t i = 0; i < n_row; i++) {
    auto shift = i * dense_cols;
    auto s = sums + shift;
    auto ma = maxs + shift;
    auto mi = mins + shift;
    for (int32_t jj = rp[i]; jj < rp[i + 1]; jj++) {
      const int32_t j = ci[jj];
      const float a = nz[jj];
      manual_loop_fused(dense_cols, s, ma, mi, a, b,
                        j);
    }
    for (int32_t k = 0; k < dense_cols; k++) {
      ow[(i * dense_cols) + k] = w[(i * NUM_AGGS) + 0] * s[k] +
                                w[(i * NUM_AGGS) + 1] * ma[k] +
                                w[(i * NUM_AGGS) + 2] * mi[k];
    }
  }
}


// Python bindings for module `aggfuse_cpu`.
//
// Every argument is given a name so the functions can be called with
// keywords; positional calls keep working unchanged.
PYBIND11_MODULE(aggfuse_cpu, m) {
  m.doc() = "Aggregator fusion in C++";

  // The four single-aggregator drivers share one signature, so they are
  // registered through a small helper instead of repeating the argument list.
  using CsrKernelFn = void (*)(int32_t, int32_t, py::array_t<int32_t>,
                               py::array_t<int32_t>, py::array_t<float>,
                               py::array_t<float>, py::array_t<float>);
  const auto def_csr_kernel = [&m](const char *name, CsrKernelFn fn,
                                   const char *doc) {
    m.def(name, fn, doc, py::arg("n_row"), py::arg("n_col"),
          py::arg("row_ptr"), py::arg("col_idx"), py::arg("nonzero"),
          py::arg("dense"), py::arg("out"));
  };

  def_csr_kernel("csr_sum", &csr_sum,
                 "Accumulate nonzero * dense[col] into out, row by row (AVX). "
                 "out is used as the accumulator and is not reset.");
  def_csr_kernel("csr_default_sum", &csr_default_sum,
                 "Accumulate nonzero * dense[col] into out, row by row "
                 "(scalar loop). out is used as the accumulator and is not "
                 "reset.");
  def_csr_kernel("csr_min", &csr_min,
                 "Fold the element-wise minimum of dense[col] into out, row "
                 "by row. out is used as the accumulator and is not reset.");
  def_csr_kernel("csr_max", &csr_max,
                 "Fold the element-wise maximum of dense[col] into out, row "
                 "by row. out is used as the accumulator and is not reset.");

  m.def("aggfuse_fp32", &csr_fused,
        "Fused sum/max/min aggregation; out[i] is the weightings[i]-weighted "
        "combination of the three aggregates.",
        py::arg("n_row"), py::arg("n_col"), py::arg("row_ptr"),
        py::arg("col_idx"), py::arg("nonzero"), py::arg("dense"),
        py::arg("weightings"), py::arg("out"));
  m.def("aggfuse_fp32_mat", &csr_fused_materialize,
        "Like aggfuse_fp32, but also stores the sum, max and min aggregates "
        "in out_mat.",
        py::arg("n_row"), py::arg("n_col"), py::arg("row_ptr"),
        py::arg("col_idx"), py::arg("nonzero"), py::arg("dense"),
        py::arg("weightings"), py::arg("out_mat"), py::arg("out_weighted"));

#ifdef VERSION_INFO
  m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
#else
  m.attr("__version__") = "dev";
#endif
}
