#include "src/utils/poly.h"
#include "src/utils/fields/fp59.h"

/**
 * Correctness check for the Lagrange-table evaluation routines of
 * Poly<FP59> with k = 2^4 and n = 2^6.
 *
 * 1. Round trip: k random values are taken k -> n with lagrangeEvalK2N
 *    and back n -> k with lagrangeEvalN2K; the result must equal the
 *    original values ("lagrange error" otherwise).
 * 2. Per-point: summing the lagrangeEvalOne2K contribution of each of the
 *    n evaluations must reproduce the k values produced by
 *    lagrangeEvalN2K ("lagrange individual error" otherwise).
 *
 * Mismatches are reported through emp::error.
 */
template<typename T>
void test_lagrange() {
  std::size_t log_k = 4;
  std::size_t log_n = 6;
  const std::size_t k = std::size_t(1) << log_k;
  const std::size_t n = std::size_t(1) << log_n;

  Poly<FP59> poly(log_k, log_n);
  poly.initLagrangeTable();

  // k random words become the k input field elements.
  std::vector<uint64_t> words(k);
  std::vector<T> small(k);
  std::vector<T> large(n);

  PRG prg;
  prg.random_data(words.data(), k * sizeof(uint64_t));
  for(std::size_t idx = 0; idx < k; ++idx)
    small[idx] = T(words[idx]);
  std::vector<T> expected(small);

  // Round trip k -> n -> k must give back the inputs.
  poly.lagrangeEvalK2N(large, small);
  poly.lagrangeEvalN2K(small, large);
  for(std::size_t idx = 0; idx < k; ++idx)
    if(!(small[idx] == expected[idx]))
      emp::error("lagrange error");

  // Add up the k-sized contribution of every one of the n evaluations.
  std::vector<T> summed(k, T(0, false));
  for(std::size_t pt = 0; pt < n; ++pt) {
    // Fresh buffer per point: the callee's treatment of existing
    // contents is not visible here.
    std::vector<T> contribution(k);
    poly.lagrangeEvalOne2K(contribution, large[pt], pt);
    for(std::size_t idx = 0; idx < k; ++idx)
      summed[idx] = summed[idx] + contribution[idx];
  }

  for(std::size_t idx = 0; idx < k; ++idx)
    if(!(small[idx] == summed[idx]))
      emp::error("lagrange individual error");
}

template<typename T>
void test_poly() {

  std::size_t log_k = 4, log_n = 6;
  std::size_t k = 1 << log_k, n = 1 << log_n;

  Poly<FP59> poly(log_k, log_n);
  poly.initLagrangeTable();

  // test k to n
  std::vector<uint64_t> buf(n);
  std::vector<T> vec_a(k);
  std::vector<T> vec_b(n);
  std::vector<T> vec_c(n);

  PRG prg;
  prg.random_data(buf.data(), n * sizeof(uint64_t));
  for(std::size_t i = 0; i < k; ++i) {
    vec_a[i] = T(buf[i]);
  }

  poly.nttEvalK2N(vec_b, vec_a);
  poly.lagrangeEvalK2N(vec_c, vec_a);

  for(std::size_t i = 0; i < n; ++i) {
    if(vec_b[i] != vec_c[i])
      std::cout << "error k2n" << std::endl;
      //emp::error("error k2n");
  }

  vec_a.resize(n);
  vec_b.resize(k);
  vec_c.resize(k);
  for(std::size_t i = 0; i < n; ++i) {
    vec_a[i] = T(buf[i]);
  }

  poly.nttEvalN2K(vec_b, vec_a);
  poly.lagrangeEvalN2K(vec_c, vec_a);

  for(std::size_t i = 0; i < k; ++i) {
    if(vec_b[i] != vec_c[i])
      emp::error("error n2k");
  }

}


/**
 * Micro-benchmark comparing the NTT-based and Lagrange-table evaluation
 * paths of Poly<FP59> for k = 2^6 and n = 2^8.
 *
 * Each routine is called test_n times back to back, and the value
 * time_from(start) / test_n is printed as the per-call time. After each
 * direction is timed, the NTT and Lagrange outputs are compared
 * element-wise. Mismatches are reported through emp::error.
 */
template<typename T>
void bench() {
  std::size_t log_k = 6;
  std::size_t log_n = 8;
  const std::size_t k = std::size_t(1) << log_k;
  const std::size_t n = std::size_t(1) << log_n;

  Poly<FP59> poly(log_k, log_n);
  poly.initLagrangeTable();

  // One batch of n random words feeds both directions.
  std::vector<uint64_t> words(n);
  std::vector<T> src(k);
  std::vector<T> via_ntt(n);
  std::vector<T> via_lagrange(n);

  PRG prg;
  prg.random_data(words.data(), n * sizeof(uint64_t));
  for(std::size_t idx = 0; idx < k; ++idx)
    src[idx] = T(words[idx]);

  const std::size_t test_n = 100000;

  // ---- k -> n: time both paths, then make sure they agree ----
  auto start = clock_start();
  for(std::size_t rep = test_n; rep > 0; --rep)
    poly.nttEvalK2N(via_ntt, src);
  std::cout << "ntt: k to n: " << time_from(start)/test_n << std::endl;

  start = clock_start();
  for(std::size_t rep = test_n; rep > 0; --rep)
    poly.lagrangeEvalK2N(via_lagrange, src);
  std::cout << "lagrange k to n: " << time_from(start)/test_n << std::endl;

  for(std::size_t idx = 0; idx < n; ++idx)
    if(via_ntt[idx] != via_lagrange[idx])
      emp::error("error k2n");

  // ---- n -> k: reuse the buffers with n inputs and k outputs ----
  src.resize(n);
  via_ntt.resize(k);
  via_lagrange.resize(k);
  for(std::size_t idx = 0; idx < n; ++idx)
    src[idx] = T(words[idx]);

  start = clock_start();
  for(std::size_t rep = test_n; rep > 0; --rep)
    poly.nttEvalN2K(via_ntt, src);
  std::cout << "ntt n to k: " << time_from(start)/test_n << std::endl;

  start = clock_start();
  for(std::size_t rep = test_n; rep > 0; --rep)
    poly.lagrangeEvalN2K(via_lagrange, src);
  std::cout << "lagrange n to k: " << time_from(start)/test_n << std::endl;

  for(std::size_t idx = 0; idx < k; ++idx)
    if(via_ntt[idx] != via_lagrange[idx])
      emp::error("error n2k");
}

// Runs the correctness checks first, then the timing benchmark.
int main() {
  // Lagrange round trip and per-point accumulation.
  test_lagrange<FP59>();
  // NTT results must match the Lagrange-table results in both directions.
  test_poly<FP59>();
  // Per-call timing of the NTT vs Lagrange paths.
  bench<FP59>();
  // Falling off the end of main returns 0.
}
