
#include "bench_harness.h"
#include "bench_utils.h"
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Number of NTT-friendly primes whose residues are combined through CRT. */
enum { KERNEL_NUM_PRIMES = 3 };

/*
 * Largest supported log2 of the transform length.
 *
 * A length-2^k NTT modulo p needs 2^k to divide p - 1.  For the primes used
 * below the limiting one is 1004535809 = 479 * 2^21 + 1, so any transform
 * longer than 2^21 would use incorrect twiddle factors.
 */
enum { KERNEL_MAX_LOG2_LEN = 21 };

/*
 * Returns base^exp mod `mod`.
 *
 * Requires mod < 2^32 so that every intermediate product fits in 64 bits.
 */
static uint64_t kernel_pow_mod(uint64_t base, uint64_t exp, uint64_t mod) {
  uint64_t result = 1ULL;
  base %= mod;
  while (exp != 0ULL) {
    if (exp & 1ULL) {
      result = (result * base) % mod;
    }
    base = (base * base) % mod;
    exp >>= 1;
  }
  return result;
}

/*
 * Absolute value of `v` as an unsigned 32-bit number.
 *
 * The negation is done in unsigned arithmetic so that INT_MIN is handled
 * without signed overflow.
 */
static uint32_t kernel_abs_u32(int v) {
  if (v >= 0) {
    return (uint32_t)v;
  }
  return (uint32_t)(0U - (unsigned int)v);
}

/*
 * In-place iterative radix-2 NTT of a[0..len_total) modulo `mod`.
 *
 * rev       bit-reversal permutation table for len_total entries
 * root      generator used to derive the twiddle factors; passing the
 *           modular inverse of the forward generator yields the (unscaled)
 *           inverse transform
 *
 * len_total must be a power of two dividing mod - 1, and mod < 2^31 so that
 * the sum of two residues cannot overflow uint32_t.
 */
static void kernel_ntt(uint32_t *a, int len_total, const int *rev,
                       uint32_t mod, uint64_t root) {
  for (int i = 0; i < len_total; i++) {
    const int j = rev[i];
    if (i < j) {
      const uint32_t tmp = a[i];
      a[i] = a[j];
      a[j] = tmp;
    }
  }

  for (int len = 2; len <= len_total; len <<= 1) {
    const int half = len >> 1;
    /* Primitive len-th root of unity for this butterfly stage. */
    const uint64_t wlen =
        kernel_pow_mod(root, (uint64_t)(mod - 1U) / (uint64_t)len, mod);

    for (int start = 0; start < len_total; start += len) {
      uint64_t w = 1ULL;
      for (int j = 0; j < half; j++) {
        const uint32_t lo = a[start + j];
        const uint32_t hi =
            (uint32_t)(((uint64_t)a[start + j + half] * w) % mod);
        uint32_t sum = lo + hi;
        if (sum >= mod) {
          sum -= mod;
        }
        a[start + j] = sum;
        a[start + j + half] = (lo >= hi) ? (lo - hi) : (lo + mod - hi);
        w = (w * wlen) % mod;
      }
    }
  }
}

/* Releases the work buffers; free(NULL) is a no-op so no guards are needed. */
static void kernel_free_buffers(int *rev, uint32_t **f, uint32_t **g) {
  free(rev);
  for (int k = 0; k < KERNEL_NUM_PRIMES; k++) {
    free(f[k]);
    free(g[k]);
  }
}

/*
 * Multiplies two integer polynomials with a three-prime NTT and folds the
 * product coefficients into a checksum.
 *
 * Only the first len_ab = max(n / 4, 1) entries of pa and pb are used; each
 * is replaced by its absolute value.  The product is computed modulo three
 * primes, every coefficient is rebuilt with Garner-style CRT, truncated to
 * 64 bits, and accumulated as sum(coef[i] * (i + 1)) with unsigned 64-bit
 * wraparound.  The checksum converted to double is stored in *ans_out.
 *
 * n        number of elements available in pa and pb
 * pa, pb   input coefficient arrays (not modified)
 * ans_out  receives the result; 0.0 is stored when n <= 0, when pa or pb is
 *          NULL, when the transform would exceed 2^KERNEL_MAX_LOG2_LEN
 *          points, or when memory allocation fails
 */
void kernel_run(int n, const int *pa, const int *pb, double *ans_out) {
  static const uint32_t mods[KERNEL_NUM_PRIMES] = {998244353U, 1004535809U,
                                                   469762049U};
  static const uint32_t roots[KERNEL_NUM_PRIMES] = {3U, 3U, 3U};

  if (ans_out == NULL) {
    return;
  }
  *ans_out = 0.0;
  /* An empty or missing input has no element 0 to read. */
  if (n <= 0 || pa == NULL || pb == NULL) {
    return;
  }

  int len_ab = n / 4;
  if (len_ab < 1) {
    len_ab = 1;
  }

  /* Smallest power of two >= 2 * len_ab, bounded by what the primes allow. */
  int lg = 0;
  while (lg <= KERNEL_MAX_LOG2_LEN && (1 << lg) < 2 * len_ab) {
    lg++;
  }
  if (lg > KERNEL_MAX_LOG2_LEN) {
    return;
  }
  const int L = 1 << lg;
  const size_t count = (size_t)L;

  int *rev = malloc(count * sizeof *rev);
  uint32_t *F[KERNEL_NUM_PRIMES] = {NULL, NULL, NULL};
  uint32_t *G[KERNEL_NUM_PRIMES] = {NULL, NULL, NULL};
  int alloc_ok = (rev != NULL);
  for (int k = 0; k < KERNEL_NUM_PRIMES; k++) {
    F[k] = malloc(count * sizeof *F[k]);
    G[k] = malloc(count * sizeof *G[k]);
    if (F[k] == NULL || G[k] == NULL) {
      alloc_ok = 0;
    }
  }
  if (!alloc_ok) {
    kernel_free_buffers(rev, F, G);
    return;
  }

  /* Bit-reversal table: rev[i] is i with its lg low bits mirrored. */
  rev[0] = 0;
  for (int i = 1; i < L; i++) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
  }

  /* Load |pa|, |pb| reduced modulo each prime, zero-padded up to L. */
  for (int i = 0; i < L; i++) {
    uint32_t va = 0U;
    uint32_t vb = 0U;
    if (i < len_ab) {
      va = kernel_abs_u32(pa[i]);
      vb = kernel_abs_u32(pb[i]);
    }
    for (int k = 0; k < KERNEL_NUM_PRIMES; k++) {
      F[k][i] = va % mods[k];
      G[k][i] = vb % mods[k];
    }
  }

  /* Cyclic convolution modulo each prime; the result is left in F[k]. */
  for (int k = 0; k < KERNEL_NUM_PRIMES; k++) {
    const uint32_t mod = mods[k];
    const uint64_t root = (uint64_t)roots[k];
    uint32_t *fa = F[k];
    uint32_t *gb = G[k];

    kernel_ntt(fa, L, rev, mod, root);
    kernel_ntt(gb, L, rev, mod, root);

    for (int i = 0; i < L; i++) {
      fa[i] = (uint32_t)(((uint64_t)fa[i] * (uint64_t)gb[i]) % mod);
    }

    /* Inverse transform: same butterflies with root^-1, then scale by L^-1
     * (both inverses obtained through Fermat's little theorem). */
    const uint64_t root_inv = kernel_pow_mod(root, (uint64_t)(mod - 2U), mod);
    kernel_ntt(fa, L, rev, mod, root_inv);

    const uint64_t inv_len =
        kernel_pow_mod((uint64_t)L, (uint64_t)(mod - 2U), mod);
    for (int i = 0; i < L; i++) {
      fa[i] = (uint32_t)(((uint64_t)fa[i] * inv_len) % mod);
    }
  }

  /* CRT constants.  Every modulus is below 2^30, so all products of two
   * residues below stay within 64 bits. */
  const uint64_t m1 = (uint64_t)mods[0];
  const uint64_t m2 = (uint64_t)mods[1];
  const uint64_t m3 = (uint64_t)mods[2];
  const uint64_t m12 = m1 * m2;
  const uint64_t inv_m1_mod_m2 = kernel_pow_mod(m1, m2 - 2ULL, m2);
  const uint64_t inv_m12_mod_m3 = kernel_pow_mod(m12, m3 - 2ULL, m3);

  /* len_ab >= 1 and L >= 2 * len_ab, hence 1 <= out_len < L. */
  const int out_len = 2 * len_ab - 1;

  uint64_t checksum = 0ULL;
  for (int i = 0; i < out_len; i++) {
    const uint64_t r1 = (uint64_t)F[0][i];
    const uint64_t r2 = (uint64_t)F[1][i];
    const uint64_t r3 = (uint64_t)F[2][i];

    /* Lift from mod m1 to mod m1*m2. */
    const uint64_t r1_mod_m2 = r1 % m2;
    const uint64_t diff12 =
        (r2 >= r1_mod_m2) ? (r2 - r1_mod_m2) : (r2 + m2 - r1_mod_m2);
    const uint64_t t12 = (diff12 * inv_m1_mod_m2) % m2;
    const uint64_t x12 = r1 + t12 * m1;

    /* Lift from mod m1*m2 to mod m1*m2*m3. */
    const uint64_t x12_mod_m3 = x12 % m3;
    const uint64_t diff123 =
        (r3 >= x12_mod_m3) ? (r3 - x12_mod_m3) : (r3 + m3 - x12_mod_m3);
    const uint64_t t123 = (diff123 * inv_m12_mod_m3) % m3;

    /* Only the low 64 bits of the coefficient are kept; unsigned wraparound
     * here and in the checksum is intentional. */
    const uint64_t coef = x12 + t123 * m12;
    checksum += coef * (uint64_t)(i + 1);
  }

  *ans_out = (double)checksum;
  kernel_free_buffers(rev, F, G);
}
/*
 * Benchmark driver for kernel_run, produced by the BENCH_MAIN_SCALAR3 macro.
 *
 * The macro is defined outside this file (presumably in bench_harness.h), so
 * the role of each argument below is inferred from its content and position
 * -- confirm against the macro definition.  `n` and `seed` are not declared
 * here and are therefore expected to be supplied by the macro expansion.
 *
 * NOTE(review): the two malloc results are used without a NULL check.
 */
BENCH_MAIN_SCALAR3(
    /* Identifier, label, and three numeric values (presumably the problem
     * sizes the harness substitutes for n -- TODO confirm). */
    T003_Code_037, NTT, 4096, 16384, 65536,
    /* Declarations: two n-element input arrays and the scalar result slot. */
    int *pa = (int *)malloc((size_t)n * sizeof(int));
    int *pb = (int *)malloc((size_t)n * sizeof(int)); double ans_scalar = 0.0;
    ,
    /* Input generation: fill both arrays with values in [0, 999] drawn from
     * the harness RNG initialised with `seed`. */
    {
      bench_rng64_t rng = bench_rng_init(seed);
      for (int i = 0; i < n; i++) {
        pa[i] = (int)(bench_rng_next(&rng) % 1000ULL);
        pb[i] = (int)(bench_rng_next(&rng) % 1000ULL);
      }
    },
    /* Kernel call, the result expression, then release of the input arrays. */
    kernel_run(n, pa, pb, &ans_scalar), ans_scalar, free(pa);
    free(pb);)
