#ifndef MVZK_UTIL_POLY_H__
#define MVZK_UTIL_POLY_H__

#if defined(__x86_64__)
#define NTT_USE_HEXL
#else
#define NTT_USE_NAIVE
#endif

#if defined(NTT_USE_HEXL)
#include "hexl/hexl.hpp"
#endif

#include <iostream>
#include "emp-tool/emp-tool.h"
#include "src/utils/ntt.h"

template<typename T>
class Poly {
public:
  std::size_t logK;
  std::size_t logN;
  std::size_t k;
  std::size_t n;
  std::size_t nDivK;

  // N-th root of unity
  T unityRootK;
  T unityRootN;
  T unityRoot2K;
  T unityRoot2N;

  // NTT
#if defined(NTT_USE_HEXL)
  intel::hexl::NTT nttK;
  intel::hexl::NTT nttN;
#endif
  Ntt<T> nttKNaive;
  Ntt<T> nttNNaive;

  // evaluation points
  std::vector<T> evalPointsN;
  std::vector<T> evalPointsK;
  std::vector<T> evalPoints2N;

  // lagrange polynomial table
  std::vector<T> lagrangeTableKByN;
  std::vector<T> lagrangeTableNByK;
  std::vector<std::vector<T>> lagrangeTableKByTp1;
  std::vector<std::vector<T>> lagrangeTableKByTp1Inv;

  // powers of generator
  std::vector<T> fpGenPowers;
  std::vector<T> fpInvGenPowers;

  Poly() {
    k = n = 0;
  }

  Poly(std::size_t log_k_, std::size_t log_n_) 
    : logK(log_k_), logN(log_n_) {

    n = 1 << log_n_;
    k = 1 << log_k_;
    nDivK = 1 << (log_n_ - log_k_);
    std::cout << "parameters (n, k): " << n << " " << k << std::endl;

    // TODO: unncessary
    unityRoot2K = primitiveRoot(2*k);
    unityRoot2N = primitiveRoot(2*n);

    // HEXL returns 2N-th root of unity, we need N-th
    unityRootK = unityRoot2K * unityRoot2K;
    unityRootN = unityRoot2N * unityRoot2N;

#if defined(NTT_USE_HEXL)
    nttK = intel::hexl::NTT(k, T::PR, unityRoot2K.val);
    nttN = intel::hexl::NTT(n, T::PR, unityRoot2N.val);
#endif
    nttKNaive = Ntt<T>(log_k_, unityRoot2K);
    nttNNaive = Ntt<T>(log_n_, unityRoot2N);
    //std::cout << "init ntt, field size: " << T::PR << std::endl;

    // evaluation points
    evalPointsK.resize(k);
    T powers = unityRoot2K;
    for(std::size_t i = 0; i < k; ++i) {
      evalPointsK[i] = powers;
      powers = powers * unityRootK;
    }

    //auto unity_powers_n = nttN.GetRootOfUnityPowers();
    evalPointsN.resize(n);
    powers = unityRoot2N;
    for(std::size_t i = 0; i < n; ++i) {
      evalPointsN[i] = powers;
      powers = powers * unityRootN;
    }

    evalPoints2N.resize(2*n);
    powers = unityRoot2N;
    for(std::size_t i = 0; i < 2*n; ++i) {
      evalPoints2N[i] = powers;
      powers = powers * unityRoot2N;
    }


    // generate powers of offsets
    T offset = unityRoot2N;
    if (nDivK >= 2) {
      for(std::size_t i = 0; i < nDivK-2; ++i)
        offset = offset * unityRoot2N;
    }
    initFpGenPowers(offset);
  }

  // evaluation from N points to another K points
  void nttEvalN2K(std::vector<T> &out,
      const std::vector<T> &in) {

    // check input dimension
    if(in.size() != n)
      emp::error("input length error");

    std::vector<T> poly_coeff(in);
    // make input bit-reverse order
    for(std::size_t i = 0; i < n; ++i) {
      uint64_t j = reverse_bits(i, logN);
      if(i != j)
        poly_coeff[i] = in[j];
    }

    nttNNaive.backwardInplace(poly_coeff);

    // evaluate to k points
    // shift coefficients by powers of offset
    for(std::size_t i = 0; i < n; ++i) {
      poly_coeff[i] = poly_coeff[i] * fpGenPowers[i];
    }

    // evaluate to n points
    // input in regular order
    // output in bit-reverse order
#if defined(NTT_USE_HEXL)
    nttN.ComputeForward((uint64_t*)(poly_coeff.data()),
                        (uint64_t*)(poly_coeff.data()),
                        1, 1);
#else
    nttNNaive.forwardInplace(poly_coeff);
#endif
  

    // only output k of them
    out.resize(k);
    std::size_t ptr = 0;
    for(std::size_t i = 0; i < k; ++i) {
      uint64_t j = reverse_bits(ptr, logN);
      out[i] = poly_coeff[j];
      ptr += nDivK;
    }
  }

  // evaluation from K points to another N points
  void nttEvalK2N(std::vector<T> &out,
      const std::vector<T> &in) {

    // check input dimension
    if(in.size() != k)
      emp::error("input length error");

    // make input bit-reverse order
    std::vector<T> poly_coeff(in);
    for(std::size_t i = 0; i < k; ++i) {
      uint64_t j = reverse_bits(i, logK);
      if(i != j)
        poly_coeff[i] = in[j];
    }

    // polynomial interpolation degree-(k-1)
    // input in bit-reverse order
    // output in regular order
#if defined(NTT_USE_HEXL)
    nttK.ComputeInverse((uint64_t*)(poly_coeff.data()),
                        (uint64_t*)(poly_coeff.data()),
                        1, 1);
#else
    nttKNaive.backwardInplace(poly_coeff);
#endif

    // make it bit-reverse order
    std::vector<T> buf(poly_coeff);
    for(std::size_t i = 0; i < k; ++i) {
      uint64_t j = reverse_bits(i, logK);
      if(i != j)
        poly_coeff[i] = buf[j];
    }

    // evaluate to n points
    // invoke dim-k NTT for nDivK times
    out.resize(n);
    nttNNaive.forwardInplaceFixDegree(out, poly_coeff, logK);
  }

  // evaluation from N points to another N points
  // from g, g^3, ..., g^{2n-1}
  // to g^2, g^4, ..., g^2n
  void nttEvalN2N(std::vector<T> &out,
      const std::vector<T> &in) {

    // check input dimension
    if(in.size() != n)
      emp::error("input length error");

    std::vector<T> poly_coeff(in);
    // make input bit-reverse order
    for(std::size_t i = 0; i < n; ++i) {
      uint64_t j = reverse_bits(i, logN);
      if(i != j)
        poly_coeff[i] = in[j];
    }

    nttNNaive.backwardInplace(poly_coeff);

    // evaluate to k points
    // shift coefficients by powers of offset
    for(std::size_t i = 1; i < n; ++i) {
      poly_coeff[i] = poly_coeff[i] * evalPoints2N[i-1];
    }

    // evaluate to n points
    // input in regular order
    // output in bit-reverse order
#if defined(NTT_USE_HEXL)
    nttN.ComputeForward((uint64_t*)(poly_coeff.data()),
                        (uint64_t*)(poly_coeff.data()),
                        1, 1);
#else
    nttNNaive.forwardInplace(poly_coeff);
#endif
  
    out.resize(n);
    for(std::size_t i = 0; i < n; ++i) {
      uint64_t j = reverse_bits(i, logN);
      out[i] = poly_coeff[j];
    }
  }

  // evaluation from N points to another K points
  void lagrangeEvalN2K(std::vector<T> &out,
      const std::vector<T> &in) {

    // check input dimension
    if(in.size() != n)
      emp::error("input length error");

    out.resize(k);

    std::size_t ptr = 0;
    for(std::size_t i = 0; i < k; ++i) {
      T res(0, false);
      for(std::size_t j = 0; j < n; ++j) {
        res = res + lagrangeTableKByN[ptr] * in[j];
        ptr++;
      }
      out[i] = res;
    }
  }

  // evaluation from 1 points to another K point 
  void lagrangeEvalOne2K(std::vector<T> &out,
      T in, std::size_t index) {

    out.resize(k);

    std::size_t ptr = index;
    for(std::size_t i = 0; i < k; ++i) {
      out[i] = lagrangeTableKByN[ptr] * in;
      ptr += n;
    }
  }

  T lagrangeEvalOne2K(T in, std::size_t index, std::size_t dest_index) {

    return lagrangeTableKByN[index + dest_index * n] * in;
  }

  // evaluation from K points to another N points
  void lagrangeEvalK2N(std::vector<T> &out,
      const std::vector<T> &in) {

    // check input dimension
    if(in.size() != k)
      emp::error("input length error");

    out.resize(n);

    std::size_t ptr = 0;
    for(std::size_t i = 0; i < n; ++i) {
      T res(0, false);
      for(std::size_t j = 0; j < k; ++j) {
        res = res + lagrangeTableNByK[ptr] * in[j];
        ptr++;
      }
      out[i] = res;
    }
  }

  void initFpGenPowers(T offset) {

    fpGenPowers.resize(n);
    fpGenPowers[0] = T(1, false);
    fpGenPowers[1] = offset;
    for(std::size_t i = 2; i < n; ++i)
      fpGenPowers[i] = fpGenPowers[i-1] * offset;

    fpInvGenPowers.resize(n);
    fpInvGenPowers[0] = T(1, false);
    fpInvGenPowers[1] = offset;
    fpInvGenPowers[1] = fpInvGenPowers[1].inv();
    for(std::size_t i = 2; i < n; ++i)
      fpInvGenPowers[i] = fpInvGenPowers[i-1] * fpInvGenPowers[1];
  }

  void initLagrangeTable() {
    // evaluate from n points to k points
    lagrangeTableKByN.resize(k*n);

    std::size_t ptr = 0;
    for(std::size_t i = 0; i < k; ++i) {
      T dest = evalPointsK[i];

      for(std::size_t j = 0; j < n; ++j) {
        T numer(1, false);
        T denom(1, false);

        for(std::size_t l = 0; l < j; ++l) {
          numer = numer * (dest - evalPointsN[l]);
          denom = denom * (evalPointsN[j] - evalPointsN[l]);
        }
        for(std::size_t l = j+1; l < n; ++l) {
          numer = numer * (dest - evalPointsN[l]);
          denom = denom * (evalPointsN[j] - evalPointsN[l]);
        }
        lagrangeTableKByN[ptr] = numer * (denom.inv());
        ptr++;
      }
    }

    // evaluate from k points to n points
    lagrangeTableNByK.resize(k*n);

    ptr = 0;
    for(std::size_t i = 0; i < n; ++i) {
      T dest = evalPointsN[i];

      for(std::size_t j = 0; j < k; ++j) {
        T numer(1, false);
        T denom(1, false);

        for(std::size_t l = 0; l < j; ++l) {
          numer = numer * (dest - evalPointsK[l]);
          denom = denom * (evalPointsK[j] - evalPointsK[l]);
        }
        for(std::size_t l = j+1; l < k; ++l) {
          numer = numer * (dest - evalPointsK[l]);
          denom = denom * (evalPointsK[j] - evalPointsK[l]);
        }
        lagrangeTableNByK[ptr] = numer * (denom.inv());
        ptr++;
      }
    }

    // evaluate from t+1 points to k points
    // k points are eval_points_n t, t+1, ..., n-1
    std::size_t t = n - k;
    lagrangeTableKByTp1.resize(k);
    
    for(std::size_t mm = 0; mm < k; ++mm) {

      lagrangeTableKByTp1[mm].resize(k*(t+1));

      ptr = 0;
      for(std::size_t i = 0; i < k; ++i) {
        T dest = evalPointsN[t+i];

        T numer(1, false);
        T denom(1, false);

        for(std::size_t l = 0; l < t; ++l) {
          numer = numer * (dest - evalPointsN[l]);
          denom = denom * (evalPointsK[mm] - evalPointsN[l]);
        }
        lagrangeTableKByTp1[mm][ptr] = numer * (denom.inv());
        ptr++;

        for(std::size_t j = 0; j < t; ++j) {
          T numer(1, false);
          T denom(1, false);

          numer = numer * (dest - evalPointsK[mm]);
          denom = denom * (evalPointsN[j] - evalPointsK[mm]);

          for(std::size_t l = 0; l < j; ++l) {
            numer = numer * (dest - evalPointsN[l]);
            denom = denom * (evalPointsN[j] - evalPointsN[l]);
          }
          for(std::size_t l = j+1; l < t; ++l) {
            numer = numer * (dest - evalPointsN[l]);
            denom = denom * (evalPointsN[j] - evalPointsN[l]);
          }
          lagrangeTableKByTp1[mm][ptr] = numer * (denom.inv());
          ptr++;
        }
      }
    }

    // evaluate from t+1 points to k points
    // points are eval_points_k 0, ..., k-1
    lagrangeTableKByTp1Inv.resize(k);
    
    for(std::size_t mm = 0; mm < k; ++mm) {

      lagrangeTableKByTp1Inv[mm].resize(k*(t+1));

      ptr = 0;
      for(std::size_t i = 0; i < k; ++i) {
        T dest = evalPointsK[i];

        T numer(1, false);
        T denom(1, false);

        for(std::size_t j = 0; j <= t; ++j) {
          T numer(1, false);
          T denom(1, false);

          for(std::size_t l = 0; l < j; ++l) {
            numer = numer * (dest - evalPointsN[l]);
            denom = denom * (evalPointsN[j] - evalPointsN[l]);
          }
          for(std::size_t l = j+1; l <= t; ++l) {
            numer = numer * (dest - evalPointsN[l]);
            denom = denom * (evalPointsN[j] - evalPointsN[l]);
          }
          lagrangeTableKByTp1Inv[mm][ptr] = numer * (denom.inv());
          ptr++;
        }
      }
    }


  }

  // generate primitive root of unity
  // the code is from libfqfft
  // don't use HEXL's because it seems to be wrong there
  T primitiveRoot(uint64_t degree) {
    std::vector<T> powers_of_gen(T::PR_bit_len);
    powers_of_gen[0] = T(1, false);
    powers_of_gen[1] = T(T::Generator, false);
    for(std::size_t i = 2; i < T::PR_bit_len; ++i) {
      powers_of_gen[i] = powers_of_gen[i-1] * powers_of_gen[i-1];
    }

    // compute r = g^{group_order/degree}
    // thus r^degree = g^group_order = -1
    T exponent = T(T::PR-1, false) * (T(degree, false).inv());
    T res(1, false);
    for(std::size_t i = 0; i < T::PR_bit_len; ++i) {
      uint64_t bit_decomp = (((exponent.val>>i)&1LL)==1LL)?1LL:0LL;
      if(bit_decomp == 1LL)
        res = res * powers_of_gen[i];
    }

    // compute r * r
    res = res * res;

    return res;
  }

};

#endif
