#include "bench_harness.h"
#include "bench_utils.h"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
/* Number of elements per run: sort_full insertion-sorts chunks of this size
 * and uses it as the initial width of its merge passes. */
#define RUN_BLK 64
/*
 * Three-way comparison of two ints in natural ascending order.
 * Returns -1 when a < b, 1 when a > b, and 0 when they are equal.
 */
static inline int cmp_default(int a, int b) {
  /* Each relational test yields 0 or 1, so the difference is -1, 0 or 1
   * and cannot overflow (unlike a - b). */
  return (a > b) - (a < b);
}
/*
 * Insertion-sorts the half-open range [start, end) of the parallel arrays
 * k (keys) and v (payloads) in place, ordering by cmp applied to the keys.
 *
 * An element is shifted only while cmp(previous, current) is strictly
 * positive, so entries whose keys compare equal keep their relative order.
 * Ranges holding fewer than two elements are left untouched and cmp is never
 * called for them.
 */
static void insertion_sort_run(int *k, double *v, int start, int end,
                               int (*cmp)(int, int)) {
  for (int next = start + 1; next < end; ++next) {
    const int key = k[next];
    const double val = v[next];

    /* `slot` is the open position; slide larger keys one step to the right
     * until the element just before the slot is not greater than `key`. */
    int slot = next;
    for (; slot > start; --slot) {
      if (cmp(k[slot - 1], key) <= 0)
        break;
      k[slot] = k[slot - 1];
      v[slot] = v[slot - 1];
    }

    k[slot] = key;
    v[slot] = val;
  }
}
/*
 * Stable sort of the parallel arrays k (keys) and v (payloads), n entries
 * each, ordered by cmp applied to the keys. The arrays are sorted in place.
 *
 * Strategy: insertion-sort consecutive runs of RUN_BLK elements, then merge
 * neighbouring runs bottom-up, ping-ponging between the caller's arrays and
 * a pair of scratch buffers, doubling the run width on every pass.
 *
 * Inputs that fit in a single run are sorted without allocating anything.
 * If the scratch buffers cannot be allocated, the whole range is
 * insertion-sorted in place instead: slower, but the arrays are never left
 * unsorted, and the resulting order is the same because both paths are
 * stable.
 *
 * Parameters:
 *   k, v - key and payload arrays, each with at least n elements
 *   n    - number of entries; n <= 1 is a no-op
 *   cmp  - three-way key comparator (<0, 0, >0)
 */
static void sort_full(int *k, double *v, int n, int (*cmp)(int, int)) {
  if (n <= 1)
    return;

  /* One run only: no merge pass would execute, so skip the scratch buffers. */
  if (n <= RUN_BLK) {
    insertion_sort_run(k, v, 0, n, cmp);
    return;
  }

  const size_t count = (size_t)n;
  int *tk = (int *)malloc(count * sizeof *tk);
  double *tv = (double *)malloc(count * sizeof *tv);
  if (!tk || !tv) {
    /* free(NULL) is a no-op, so no guards are needed. */
    free(tk);
    free(tv);
    /* Out of memory: degrade to a plain in-place insertion sort rather than
     * returning with the data unsorted. */
    insertion_sort_run(k, v, 0, n, cmp);
    return;
  }

  /* Phase 1: sort each run of RUN_BLK elements. The end bound is clamped
   * before adding so that start + RUN_BLK can never overflow int. */
  for (int start = 0; start < n;) {
    const int end = (n - start > RUN_BLK) ? start + RUN_BLK : n;
    insertion_sort_run(k, v, start, end, cmp);
    start = end;
  }

  /* Phase 2: bottom-up merge passes. Indices are size_t and every bound is
   * clamped against count before the addition, so nothing overflows. */
  int *src_k = k;
  double *src_v = v;
  int *dst_k = tk;
  double *dst_v = tv;
  for (size_t width = (size_t)RUN_BLK; width < count; width <<= 1) {
    for (size_t lo = 0; lo < count;) {
      const size_t mid = (count - lo > width) ? lo + width : count;
      const size_t hi = (count - mid > width) ? mid + width : count;
      size_t i = lo;
      size_t j = mid;
      size_t out = lo;

      /* Taking the left element on ties keeps the merge stable. */
      while (i < mid && j < hi) {
        if (cmp(src_k[i], src_k[j]) <= 0) {
          dst_k[out] = src_k[i];
          dst_v[out] = src_v[i];
          i++;
        } else {
          dst_k[out] = src_k[j];
          dst_v[out] = src_v[j];
          j++;
        }
        out++;
      }

      /* Copy whichever side still has elements (also handles a trailing run
       * that has no partner, where mid == hi). */
      if (i < mid) {
        memcpy(dst_k + out, src_k + i, (mid - i) * sizeof *dst_k);
        memcpy(dst_v + out, src_v + i, (mid - i) * sizeof *dst_v);
        out += mid - i;
      }
      if (j < hi) {
        memcpy(dst_k + out, src_k + j, (hi - j) * sizeof *dst_k);
        memcpy(dst_v + out, src_v + j, (hi - j) * sizeof *dst_v);
      }

      lo = hi;
    }

    /* The freshly written buffers become the source of the next pass. */
    int *swap_k = src_k;
    src_k = dst_k;
    dst_k = swap_k;
    double *swap_v = src_v;
    src_v = dst_v;
    dst_v = swap_v;
  }

  /* After an odd number of passes the sorted data lives in the scratch
   * buffers; move it back into the caller's arrays. */
  if (src_k != k) {
    memcpy(k, src_k, count * sizeof *k);
    memcpy(v, src_v, count * sizeof *v);
  }

  free(tk);
  free(tv);
}
/*
 * Merge-join aggregate over two key/payload tables whose key arrays are
 * already in ascending order.
 *
 * For every key present on both sides, the payloads of all left rows with
 * that key are summed, likewise for the right rows, and the product of the
 * two sums is added to the result.
 *
 * Parameters:
 *   lk, lv, nL - left keys, left payloads, number of left rows
 *   rk, rv, nR - right keys, right payloads, number of right rows
 * Returns the accumulated sum of products (0.0 when no key matches).
 */
static double merge_join_sum(int *lk, double *lv, int nL, int *rk, double *rv,
                             int nR) {
  double total = 0.0;
  int li = 0;
  int ri = 0;

  while (li < nL && ri < nR) {
    const int key = lk[li];

    /* Advance whichever cursor sits on the smaller key. */
    if (key < rk[ri]) {
      ++li;
      continue;
    }
    if (key > rk[ri]) {
      ++ri;
      continue;
    }

    /* Keys match: consume the whole group of equal keys on each side. */
    double left_sum = 0.0;
    for (; li < nL && lk[li] == key; ++li)
      left_sum += lv[li];

    double right_sum = 0.0;
    for (; ri < nR && rk[ri] == key; ++ri)
      right_sum += rv[ri];

    total += left_sum * right_sum;
  }

  return total;
}
/*
 * Runs one sort-merge join: sorts both inputs by key with cmp_default, then
 * returns the merge_join_sum aggregate of the two sorted tables.
 *
 * n is the combined row count. The left table uses the first n / 2 rows of
 * lk/lv (but never fewer than one) and the right table uses the first
 * n - (left count) rows of rk/rv. Both inputs are sorted in place.
 */
static double smj_run(int n, int *lk, double *lv, int *rk, double *rv) {
  const int left_count = (n / 2 >= 1) ? n / 2 : 1;
  const int right_count = n - left_count;

  sort_full(lk, lv, left_count, cmp_default);
  sort_full(rk, rv, right_count, cmp_default);

  const double joined =
      merge_join_sum(lk, lv, left_count, rk, rv, right_count);
  return joined;
}
/*
 * Benchmark entry point for the sort-merge-join kernel, generated by the
 * BENCH_MAIN_SCALAR3 macro (not defined in this file; presumably it comes
 * from bench_harness.h). The role of each argument noted below is inferred
 * from its contents, and `n` and `seed` are assumed to be supplied by the
 * macro expansion - TODO confirm against the macro definition.
 */
BENCH_MAIN_SCALAR3(
    T004_Module_052, SMJOIN, 4096, 16384, 65536,
    /* Declarations: four buffers of n elements each plus the result slot.
     * NOTE(review): the malloc results are not checked for NULL here. */
    int *lk = (int *)malloc((size_t)n * sizeof(int));
    double *lv = (double *)malloc((size_t)n * sizeof(double));
    int *rk = (int *)malloc((size_t)n * sizeof(int));
    double *rv = (double *)malloc((size_t)n * sizeof(double));
    double ans_scalar = 0.0;
    ,
    /* Input generation: fills the first nL left rows and the first nR right
     * rows, using the same nL/nR split that smj_run computes. */
    {
      bench_rng64_t rng = bench_rng_init(seed);
      int nL = n / 2;
      if (nL < 1)
        nL = 1;
      int nR = n - nL;
      /* Left keys are the low 16 bits of an RNG draw. */
      for (int i = 0; i < nL; i++) {
        lk[i] = (int)(bench_rng_next(&rng) & 0xFFFFu);
        lv[i] = bench_rng_double_signed(&rng);
      }
      /* Right keys: when the low bit of a draw is 0, reuse the key of a
       * randomly chosen left row (which guarantees a join match); otherwise
       * take a fresh 16-bit key. */
      for (int i = 0; i < nR; i++) {
        if ((bench_rng_next(&rng) & 1ULL) == 0 && nL > 0) {
          int idx = (int)(bench_rng_next(&rng) % (unsigned long long)nL);
          rk[i] = lk[idx];
        } else {
          rk[i] = (int)(bench_rng_next(&rng) & 0xFFFFu);
        }
        rv[i] = bench_rng_double_signed(&rng);
      }
    },
    /* Kernel call, the value it produced, then release of the buffers. */
    ans_scalar = smj_run(n, lk, lv, rk, rv), ans_scalar, free(lk);
    free(lv); free(rk); free(rv);)
