#include "bench_harness.h"
#include "bench_utils.h"
#include <math.h>
#include <stdint.h>
#include <stdlib.h>
/* Exchanges point a and point b, keeping the x and y coordinate arrays
   in lockstep. Both points are loaded before anything is stored. */
static inline void swap_pts(int a, int b, double *px, double *py) {
  const double ax = px[a], ay = py[a];
  const double bx = px[b];
  px[a] = bx;
  const double by = py[b];
  py[a] = by;
  px[b] = ax;
  py[b] = ay;
}
/* Returns the coordinate of point idx along the given axis: the x value
   when axis is 0, the y value for any other axis value. */
static inline double coord_axis(int idx, int axis, const double *px,
                                const double *py) {
  if (axis != 0) {
    return py[idx];
  }
  return px[idx];
}
/* Quickselect over the inclusive index range [l, r] using a Hoare-style
   partition on the chosen axis (0 = x, otherwise y).
   On return, position k holds the element that would be there if the
   range were sorted by that coordinate.
   Every element before k compares <= it, and every element after k
   compares >= it.
   px and py are permuted together via swap_pts. */
static void select_k(int l, int r, int k, int axis, double *px, double *py) {
  int lo = l;
  int hi = r;
  for (;;) {
    if (lo >= hi) {
      return;
    }
    const double pv = coord_axis((lo + hi) >> 1, axis, px, py);
    int a = lo;
    int b = hi;
    /* lo < hi guarantees a <= b on entry, so the body runs at least once. */
    do {
      while (coord_axis(a, axis, px, py) < pv) {
        ++a;
      }
      while (coord_axis(b, axis, px, py) > pv) {
        --b;
      }
      if (a <= b) {
        swap_pts(a, b, px, py);
        ++a;
        --b;
      }
    } while (a <= b);
    /* [lo, b] <= pivot, [a, hi] >= pivot, anything between equals pivot. */
    if (k <= b) {
      hi = b;
    } else if (k >= a) {
      lo = a;
    } else {
      return;
    }
  }
}
/* Builds a 2-D kd-tree in place over the points at indices [l, r].
   Each node is the median index of its range, placed there by select_k.
   The split axis alternates with depth: x at even depths, y at odd.
   leftc/rightc receive the child indices (-1 for none), and axisv records
   each node's split axis.
   Returns the root index of the subtree, or -1 when the range is empty. */
static int build_kd(int l, int r, int depth, double *px, double *py, int *leftc,
                    int *rightc, int *axisv) {
  if (r < l) {
    return -1;
  }
  const int split_axis = depth & 1;
  const int median = (l + r) >> 1;
  select_k(l, r, median, split_axis, px, py);
  axisv[median] = split_axis;
  /* The recursive calls only touch indices other than `median`, so storing
     each child link immediately is equivalent to storing both at the end. */
  leftc[median] =
      build_kd(l, median - 1, depth + 1, px, py, leftc, rightc, axisv);
  rightc[median] =
      build_kd(median + 1, r, depth + 1, px, py, leftc, rightc, axisv);
  return median;
}
/* Squared Euclidean distance from the query (qx, qy) to point idx. */
static inline double dist2_pt(double qx, double qy, int idx, const double *px,
                              const double *py) {
  const double ox = qx - px[idx];
  const double oy = qy - py[idx];
  return ox * ox + oy * oy;
}
/* Recursive nearest-neighbour search from `node` in the kd-tree described by
   leftc/rightc/axisv.
   *bestd2 / *besti hold the best squared distance and index found so far.
   They are updated only when a strictly closer point is seen.
   The child on the query's side of the split is searched first.
   The other child is searched only when the squared distance from the query
   to the splitting line is still below the current best. */
static void nn_search(int node, double qx, double qy, double *bestd2,
                      int *besti, const double *px, const double *py,
                      const int *leftc, const int *rightc, const int *axisv) {
  if (node < 0) {
    return;
  }
  const double here = dist2_pt(qx, qy, node, px, py);
  if (here < *bestd2) {
    *besti = node;
    *bestd2 = here;
  }
  const double delta =
      (axisv[node] == 0) ? (qx - px[node]) : (qy - py[node]);
  int near_child;
  int far_child;
  if (delta < 0.0) {
    near_child = leftc[node];
    far_child = rightc[node];
  } else {
    near_child = rightc[node];
    far_child = leftc[node];
  }
  nn_search(near_child, qx, qy, bestd2, besti, px, py, leftc, rightc, axisv);
  /* *bestd2 may have shrunk during the near-side search. */
  if (delta * delta < *bestd2) {
    nn_search(far_child, qx, qy, bestd2, besti, px, py, leftc, rightc, axisv);
  }
}
/* Builds a kd-tree over the n points (px, py), permuting them in place, and
   answers qn nearest-neighbour queries.
   leftc, rightc and axisv must each hold at least n ints. They receive the
   tree links and split axes.
   Returns the sum over all queries of the Euclidean distance to the nearest
   point.
   With no points (n <= 0) or no queries (qn <= 0) there is nothing to
   measure, and the result is 0.0. */
static double run_kdtree(int n, double *px, double *py, int *leftc, int *rightc,
                         int *axisv, const double *qx, const double *qy,
                         int qn) {
  /* An empty tree previously made every query contribute sqrt(1e300). */
  if (n <= 0 || qn <= 0) {
    return 0.0;
  }
  const int root = build_kd(0, n - 1, 0, px, py, leftc, rightc, axisv);
  double acc = 0.0;
  for (int i = 0; i < qn; i++) {
    /* INFINITY instead of a finite sentinel, so any real distance replaces
       it, however large the coordinates are. */
    double bd2 = INFINITY;
    int bi = -1; /* required by nn_search's interface; not used here */
    nn_search(root, qx[i], qy[i], &bd2, &bi, px, py, leftc, rightc, axisv);
    acc += sqrt(bd2);
  }
  return acc;
}
BENCH_MAIN_SCALAR3(
    T004_Module_050, KDTREE, 4096, 16384, 65536,
    /* Setup. NOTE(review): argument roles below are inferred from how the
       code uses them; confirm against bench_harness.h.
       One query per four points, at least one. */
    int qn = n / 4;
    if (qn < 1) qn = 1;
    double *px = (double *)malloc((size_t)n * sizeof(double));
    double *py = (double *)malloc((size_t)n * sizeof(double));
    int *leftc = (int *)malloc((size_t)n * sizeof(int));
    int *rightc = (int *)malloc((size_t)n * sizeof(int));
    int *axisv = (int *)malloc((size_t)n * sizeof(int));
    double *qx = (double *)malloc((size_t)qn * sizeof(double));
    double *qy = (double *)malloc((size_t)qn * sizeof(double));
    /* Fail fast rather than dereferencing a NULL buffer later (CERT MEM32-C). */
    if (!px || !py || !leftc || !rightc || !axisv || !qx || !qy) abort();
    double ans_scalar = 0.0;
    ,
    {
      /* Seeded input: n points, then qn queries, each coordinate being
         100 * bench_rng_double_signed(). */
      bench_rng64_t rng = bench_rng_init(seed);
      for (int i = 0; i < n; i++) {
        px[i] = 100.0 * bench_rng_double_signed(&rng);
        py[i] = 100.0 * bench_rng_double_signed(&rng);
      }
      for (int i = 0; i < qn; i++) {
        qx[i] = 100.0 * bench_rng_double_signed(&rng);
        qy[i] = 100.0 * bench_rng_double_signed(&rng);
      }
    },
    ans_scalar = run_kdtree(n, px, py, leftc, rightc, axisv, qx, qy, qn),
    ans_scalar,
    free(px); free(py); free(leftc); free(rightc); free(axisv); free(qx);
    free(qy);)
