#ifndef MVZK_UTIL_NTT_H__
#define MVZK_UTIL_NTT_H__

#include <vector>
#include <algorithm>
#include <cmath>
#include "emp-tool/emp-tool.h"
#include "src/utils/utils.h"

// Number-theoretic transform over an element type T, with the powers of the
// root of unity (and of its inverse) precomputed once at construction.
//
// Requirements on T, as used by this class:
//   - default-constructible and copyable (tables are resized, then assigned);
//   - constructible as T(1, false) -- used here as the multiplicative
//     identity (NOTE(review): meaning of the second argument is defined by T;
//     confirm against the field implementation);
//   - constructible from std::size_t (used to build 1/dim);
//   - provides inv() and binary operators +, - and *.
//
// NOTE(review): unity_root is presumably a primitive 2^log_dim-th root of
// unity in T; this class does not verify that.
template<typename T>
class Ntt {
public:
  // log2 of the transform size; 0 for a default-constructed object.
  std::size_t logDim = 0;
  // Root of unity handed to the constructor.
  T unityRoot;
  // Transform size, 2^logDim; 0 for a default-constructed object.
  std::size_t dim = 0;
  // Multiplicative inverse of dim in T, applied at the end of the inverse
  // transform.
  T invDim;

  // unity_root_powers[reverse_bits(i, logDim)] == unityRoot^i for 0 <= i < dim.
  std::vector<T> unity_root_powers;
  // Same layout as unity_root_powers, for the inverse of unityRoot.
  std::vector<T> inv_unity_root_powers;

  // Creates an empty transform: dim == 0 and both tables are empty.
  // logDim and dim are zero-initialized so that the size checks in the
  // public wrappers never read indeterminate values.
  Ntt() {}

  // Builds a transform of size 2^log_dim and precomputes both power tables.
  //
  // log_dim     log2 of the transform size; must be smaller than the bit
  //             width of std::size_t.
  // unity_root  root of unity whose powers are tabulated.
  Ntt(std::size_t log_dim, T unity_root)
    : logDim(log_dim), unityRoot(unity_root) {

    // Shift a std::size_t one: shifting the int literal 1 overflows as soon
    // as logDim reaches the width of int.
    dim = static_cast<std::size_t>(1) << logDim;
    T t_dim(dim);
    invDim = t_dim.inv();

    fillBitReversedPowers(unity_root_powers, unity_root);

    T inv_unity_root = unity_root.inv();
    fillBitReversedPowers(inv_unity_root_powers, inv_unity_root);
  }

  // forward NTT
  // from libfqfft,
  // optimized by precomputed power of unity root
  // output in bit-reverse order
  //
  // res     transformed in place; must hold at least dim elements (not
  //         checked here, see forwardInplace for the checked entry point).
  // powers  twiddle table with at least dim entries, laid out like
  //         unity_root_powers.
  void nttInternalInplace(std::vector<T> &res, 
                          const std::vector<T> &powers) {

    // m: number of butterfly groups in the current stage,
    // t: distance between the two elements of a butterfly.
    std::size_t m = 1, t = dim>>1;

    for(std::size_t s = 1; s <= logDim; ++s) {

      // l: index of the first element of the current group.
      std::size_t l = 0;

      for(std::size_t i = 0; i < m; ++i) {

        // One twiddle factor per group.
        T w = powers[m+i];
        for(std::size_t j = l; j < l + t; ++j) {
          T u = res[j];
          T v = res[j+t] * w;
          res[j] = u + v;
          res[j+t] = u - v;
        }

        l = l + (t<<1);
      }

      t = t >> 1;
      m = m << 1;
    }

  }

  // Inverse transform: runs the stages in the opposite order to the forward
  // transform (t grows, m shrinks), then scales every element by invDim.
  // output in regular order
  //
  // res     transformed in place; must hold at least dim elements (not
  //         checked here, see backwardInplace for the checked entry point).
  // powers  twiddle table with at least dim entries, laid out like
  //         inv_unity_root_powers.
  void inttInternalInplace(std::vector<T> &res, 
                          const std::vector<T> &powers) {

    std::size_t m = dim>>1, t = 1;

    for(std::size_t s = 1; s <= logDim; ++s) {

      std::size_t l = 0;

      for(std::size_t i = 0; i < m; ++i) {

        T w = powers[m+i];
        for(std::size_t j = l; j < l + t; ++j) {
          T u = res[j];
          T v = res[j+t];
          res[j] = u + v;
          // Here the twiddle is applied after the subtraction, whereas the
          // forward transform applies it before the butterfly.
          v = u - v;
          res[j+t] = v * w;
        }

        l = l + (t<<1);
      }

      t = t << 1;
      m = m >> 1;
    }

    // 1/n
    for(std::size_t i = 0; i < dim; ++i) {
      res[i] = res[i] * invDim;
    }

  }


  // input dimension lower than output dimension
  //
  // Expands an input of 2^log_dim_in elements into a dim-element result:
  // each input element is first replicated 2^(logDim - log_dim_in) times,
  // then log_dim_in butterfly stages are applied.
  //
  // res         resized to dim and overwritten.
  // in          source elements; the first 2^log_dim_in entries are read.
  // log_dim_in  log2 of the input size.
  // powers      twiddle table with at least dim entries, laid out like
  //             unity_root_powers.
  //
  // Preconditions (not checked here, see forwardInplaceFixDegree for the
  // checked entry point): log_dim_in <= logDim and
  // in.size() >= 2^log_dim_in. When log_dim_in == logDim the replication
  // step is skipped, so `in` is not copied and the stages operate on
  // whatever res already contains.
  void nttInternalInplace(std::vector<T> &res, 
                          const std::vector<T> &in,
                          std::size_t log_dim_in,
                          const std::vector<T> &powers) {

    // All shifts below use a std::size_t one so they cannot overflow int.
    const std::size_t one = 1;
    std::size_t dim_in = one << log_dim_in;

    res.resize(dim);

    // first round
    std::size_t duplicity_of_initial_elements = one << (logDim - log_dim_in);
    if(duplicity_of_initial_elements > 1) {
      for(std::size_t i = 0, l = 0; i < dim; i += duplicity_of_initial_elements, l++) {
        for(std::size_t j = 0; j < duplicity_of_initial_elements; ++j) {
          res[i+j] = in[l];
        }
      }
    }

    std::size_t m = dim_in>>1, t = duplicity_of_initial_elements;

    for(std::size_t s = 1; s <= log_dim_in; ++s) {

      std::size_t l = 0;

      for(std::size_t i = 0; i < m; ++i) {

        // Unlike the full-size transform, the twiddle depends on the
        // position jj inside the group, so it is looked up per butterfly.
        for(std::size_t j = l, jj = 0; j < l + t; ++j, ++jj) {
          uint64_t idx = reverse_bits(m*(2*jj+1), logDim);
          T u = res[j];
          T v = res[j+t] * powers[idx];
          res[j] = u + v;
          res[j+t] = u - v;
        }

        l = l + (t<<1);
      }

      t = t << 1;
      m = m >> 1;
    }

  }



  // Checked forward transform of a full-size vector, in place.
  // Reports "ntt: input dimension mismatch" through emp::error unless
  // res.size() == dim.
  void forwardInplace(std::vector<T> &res) {
    if(res.size() != dim)
      emp::error("ntt: input dimension mismatch");

    nttInternalInplace(res, unity_root_powers);
  }

  // Checked forward transform of an input smaller than the transform size.
  //
  // res         receives the dim-element result.
  // in          input elements; must hold at least 2^log_dim_in entries and
  //             fewer than dim entries.
  // log_dim_in  log2 of the input size; must not exceed logDim.
  //
  // Reports "ntt: input dimension mismatch" through emp::error when any of
  // these conditions is violated. Without the log_dim_in and lower-bound
  // checks the internal routine would shift by a wrapped-around amount or
  // read past the end of `in`. Together the two size bounds also rule out
  // log_dim_in == logDim.
  void forwardInplaceFixDegree(std::vector<T> &res,
                               const std::vector<T> &in,
                               std::size_t log_dim_in) {
    // Short-circuit order matters: the shift is only evaluated once
    // log_dim_in is known to be at most logDim.
    if(log_dim_in > logDim ||
       in.size() >= dim ||
       in.size() < (static_cast<std::size_t>(1) << log_dim_in))
      emp::error("ntt: input dimension mismatch");

    nttInternalInplace(res, in, log_dim_in, unity_root_powers);
  }

  // Checked inverse transform of a full-size vector, in place.
  // Reports "ntt: input dimension mismatch" through emp::error unless
  // res.size() == dim.
  void backwardInplace(std::vector<T> &res) {
    if(res.size() != dim)
      emp::error("ntt: input dimension mismatch");

    inttInternalInplace(res, inv_unity_root_powers);
  }

private:
  // Fills table so that table[reverse_bits(i, logDim)] == base^i for
  // 0 <= i < dim. Requires dim and logDim to be set already.
  void fillBitReversedPowers(std::vector<T> &table, T base) {
    table.resize(dim);
    table[0] = T(1, false);
    T power(1, false);
    for(std::size_t i = 1; i < dim; ++i) {
      power = power * base;
      uint64_t j = reverse_bits(i, logDim);
      table[j] = power;
    }
  }
};

#endif
