#ifndef MVZK_UTIL_LAGRANGE_H__
#define MVZK_UTIL_LAGRANGE_H__

//#include "emp-tool/emp-tool.h"

template<typename T>
class Lagrange {
public:
  std::size_t n; // degree
  std::size_t nPoints;

  // evaluation points
  std::vector<T> evalPoints;

  // lagrange polynomial table
  std::vector<T> lagrangeTable;
  std::vector<T> lagCoeffDivLow;
  std::vector<T> lagCoeffDivHigh;


  Lagrange() {
    n = 0;
  }

  Lagrange(std::size_t n_) : n(n_) {

    // evaluation points

    nPoints = 2 * n + 1;
    evalPoints.resize(nPoints);
    T powers(3);
    for(std::size_t i = 0; i < 2*n+1; ++i) {
      evalPoints[i] = powers;
      powers = powers * 3;
    }
  }

  // evaluation from N points to another K points
  void lagrangeEvalShiftPoints(std::vector<T> &out,
      const std::vector<T> &in) {

    // check input dimension
    if(in.size() != n+1)
      emp::error("input length error");

    out.resize(n);

    std::size_t ptr = 0;
    for(std::size_t i = 0; i < n; ++i) {
      T res(0, false);
      for(std::size_t j = 0; j < n+1; ++j) {
        res = res + lagrangeTable[ptr] * in[j];
        ptr++;
      }
      out[i] = res;
    }
  }

  void initLagrangeTable() {
    // evaluate from n points to k points
    lagrangeTable.resize(n*(n+1));

    std::size_t ptr = 0;
    for(std::size_t i = 0; i < n; ++i) {
      T dest = evalPoints[i+n+1];

      for(std::size_t j = 0; j < n+1; ++j) {
        T numer(1, false);
        T denom(1, false);

        for(std::size_t l = 0; l < j; ++l) {
          numer = numer * (dest - evalPoints[l]);
          denom = denom * (evalPoints[j] - evalPoints[l]);
        }
        for(std::size_t l = j+1; l < n; ++l) {
          numer = numer * (dest - evalPoints[l]);
          denom = denom * (evalPoints[j] - evalPoints[l]);
        }
        lagrangeTable[ptr] = numer * (denom.inv());
        ptr++;
      }
    }


    lagCoeffDivLow.resize(n+1);
    for(std::size_t i = 0; i < n+1; ++i) {
      T res(1, false);
      for(std::size_t j = 0; j < n+1; ++j) {
        if(j == i) continue;
        res = res * (evalPoints[i] - evalPoints[j]);
      }
      lagCoeffDivLow[i] = res.inv();
    }

    lagCoeffDivHigh.resize(2*n+1);
    for(std::size_t i = 0; i < 2*n+1; ++i) {
      T res(1, false);
      for(std::size_t j = 0; j < 2*n+1; ++j) {
        if(j == i) continue;
        res = res * (evalPoints[i] - evalPoints[j]);
      }
      lagCoeffDivHigh[i] = res.inv();
    }

  }


  void computeLagCoeff(std::vector<T> &common_lag_coeff_low,
                       std::vector<T> &common_lag_coeff_high,
                       T dest) {

    common_lag_coeff_low.resize(n+1);
    for(std::size_t i = 0; i < n+1; ++i) {
      T res(1, false);
      for(std::size_t j = 0; j < n+1; ++j) {
        if(j == i) continue;
        res = res * (dest - evalPoints[j]);
      }
      common_lag_coeff_low[i] = res * lagCoeffDivLow[i];
    }
    
    common_lag_coeff_high.resize(2*n+1);
    for(std::size_t i = 0; i < 2*n+1; ++i) {
      T res(1, false);
      for(std::size_t j = 0; j < 2*n+1; ++j) {
        if(j == i) continue;
        res = res * (dest - evalPoints[j]);
      }
      common_lag_coeff_high[i] = res * lagCoeffDivHigh[i];
    }
  }

};

#endif
