// SPDX-License-Identifier: LGPL-3.0-or-later
#include "pair_tab.h"

#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

#include "errors.h"

// Evaluate one tabulated pair interaction from a piecewise-cubic spline.
//
// table_info: [0] rmin (position of the first knot), [1] hh (knot spacing),
//   [2] nspline (number of spline intervals, stored as a double). Further
//   entries are not read here.
// table_data: 4 coefficients per interval, ordered (a3, a2, a1, a0), for the
//   type pair already selected by the caller.
// dr: 3-component separation vector; only its length rr is used.
//
// With u = (rr - rmin) / hh - idx the position inside interval idx:
//   ener   = a3 u^3 + a2 u^2 + a1 u + a0
//   fscale = -(dE/du) / hh   (i.e. -dE/drr)
// Distances at or beyond the end of the table give ener = fscale = 0.
// Throws deepmd::deepmd_exception if rr is below rmin.
inline void _pair_tabulated_inter(double& ener,
                                  double& fscale,
                                  const double* table_info,
                                  const double* table_data,
                                  const double* dr) {
  const double& rmin = table_info[0];
  const double& hh = table_info[1];
  const double hi = 1. / hh;
  // nspline is stored as a double; +0.1 guards against truncating x.9999...
  const unsigned nspline = unsigned(table_info[2] + 0.1);

  double r2 = dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2];
  double rr = std::sqrt(r2);
  double uu = (rr - rmin) * hi;
  if (uu < 0) {
    std::cerr << "coord go beyond table lower boundary" << std::endl;
    throw deepmd::deepmd_exception();
  }
  // Range-check in floating point BEFORE converting to an integer index:
  // converting a double whose value does not fit the target integer type
  // (e.g. a huge or infinite distance) is undefined behaviour. A NaN distance
  // also takes this branch, since every comparison with NaN is false.
  if (!(uu < static_cast<double>(nspline))) {
    fscale = ener = 0;
    return;
  }
  // 0 <= uu < nspline here, so the truncation is well defined.
  const unsigned idx = static_cast<unsigned>(uu);
  uu -= idx;
  assert(uu >= 0 && uu < 1);

  const double& a3 = table_data[4 * idx + 0];
  const double& a2 = table_data[4 * idx + 1];
  const double& a1 = table_data[4 * idx + 2];
  const double& a0 = table_data[4 * idx + 3];

  // Horner evaluation of the cubic and of its derivative with respect to u.
  double etmp = (a3 * uu + a2) * uu + a1;
  ener = etmp * uu + a0;
  fscale = (2. * a3 * uu + a2) * uu + etmp;
  // du/drr = 1/hh; the sign makes fscale the (radial) force magnitude.
  fscale *= -hi;
}

template <typename FPTYPE>
void _pair_tab_jloop(FPTYPE* energy,
                     FPTYPE* force,
                     FPTYPE* virial,
                     int& jiter,
                     const int& i_idx,
                     const int& nnei,
                     const int& i_type_shift,
                     const double* p_table_info,
                     const double* p_table_data,
                     const int& tab_stride,
                     const FPTYPE* rij,
                     const FPTYPE* scale,
                     const int* type,
                     const int* nlist,
                     const int* natoms,
                     const std::vector<int>& sel) {
  // Accumulate tabulated pair contributions between local atom i_idx and the
  // neighbours listed in the sections described by `sel`. Section k holds
  // sel[k] slots whose neighbours are expected to have type k. `jiter` is
  // the running position in i_idx's neighbour row. It is advanced past every
  // slot visited (including skipped ones), so a caller can continue with the
  // next group of sections.
  const FPTYPE i_scale = scale[i_idx];
  const int* const nei_row = nlist + i_idx * nnei;
  const int nsec = static_cast<int>(sel.size());
  for (int j_type = 0; j_type < nsec; ++j_type) {
    // Coefficient block for the (type of i, j_type) pair.
    const double* pair_table =
        p_table_data + (i_type_shift + j_type) * tab_stride;
    const int width = sel[j_type];
    for (int slot = 0; slot < width; ++slot, ++jiter) {
      const int j_idx = nei_row[jiter];
      if (j_idx < 0) {
        // Negative entries are skipped; jiter still advances via the loop.
        continue;
      }
      assert(type[j_idx] == j_type);
      const FPTYPE* const rvec = rij + (i_idx * nnei + jiter) * 3;
      double dr[3] = {rvec[0], rvec[1], rvec[2]};
      double r2 = dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2];
      double ri = 1. / sqrt(r2);
      double ener;
      double fscale;
      _pair_tabulated_inter(ener, fscale, p_table_info, pair_table, dr);
      // Atom i is credited half of the pair energy (i_scale is not applied
      // to the energy here).
      energy[i_idx] += 0.5 * ener;
      // Equal and opposite half-weighted, i_scale-scaled forces on i and j.
      for (int dd = 0; dd < 3; ++dd) {
        force[i_idx * 3 + dd] -= fscale * dr[dd] * ri * 0.5 * i_scale;
        force[j_idx * 3 + dd] += fscale * dr[dd] * ri * 0.5 * i_scale;
      }
      // Virial outer product, shared equally between i and j.
      for (int aa = 0; aa < 3; ++aa) {
        for (int bb = 0; bb < 3; ++bb) {
          const int comp = aa * 3 + bb;
          virial[i_idx * 9 + comp] +=
              0.5 * fscale * dr[aa] * dr[bb] * ri * 0.5 * i_scale;
          virial[j_idx * 9 + comp] +=
              0.5 * fscale * dr[aa] * dr[bb] * ri * 0.5 * i_scale;
        }
      }
    }
  }
}

// Exclusive prefix sum of section sizes: on return sec has n_sel.size() + 1
// entries, sec[0] == 0 and sec[k] == n_sel[0] + ... + n_sel[k - 1]. The last
// entry is therefore the total of all sizes.
inline void _cum_sum(std::vector<int>& sec, const std::vector<int>& n_sel) {
  sec.resize(n_sel.size() + 1);
  auto out = sec.begin();
  int running = 0;
  *out = running;
  for (const int width : n_sel) {
    running += width;
    *++out = running;
  }
}

template <typename FPTYPE>
void deepmd::pair_tab_cpu(FPTYPE* energy,
                          FPTYPE* force,
                          FPTYPE* virial,
                          const double* p_table_info,
                          const double* p_table_data,
                          const FPTYPE* rij,
                          const FPTYPE* scale,
                          const int* type,
                          const int* nlist,
                          const int* natoms,
                          const std::vector<int>& sel_a,
                          const std::vector<int>& sel_r) {
  // Each local atom's neighbour row holds the sel_a sections followed by the
  // sel_r sections, so its width is the sum of all section sizes.
  std::vector<int> offsets_a;
  std::vector<int> offsets_r;
  _cum_sum(offsets_a, sel_a);
  _cum_sum(offsets_r, sel_r);
  const int nnei = offsets_a.back() + offsets_r.back();

  const int nloc = natoms[0];
  const int nall = natoms[1];
  // Integer metadata is stored as doubles in table_info; +0.1 guards against
  // truncating a value like x.9999... down by one.
  const int ntypes = static_cast<int>(p_table_info[3] + 0.1);
  const int nspline = static_cast<int>(p_table_info[2] + 0.1);
  // Four cubic coefficients per spline interval for each type pair.
  const int tab_stride = 4 * nspline;

  // Results are accumulated below, so clear them first: energy for the nloc
  // local atoms, force (3 per atom) and virial (9 per atom) for all nall.
  for (int ii = 0; ii < nloc; ++ii) {
    energy[ii] = 0;
  }
  for (int ii = 0; ii < nall * 3; ++ii) {
    force[ii] = static_cast<FPTYPE>(0.);
  }
  for (int ii = 0; ii < nall * 9; ++ii) {
    virial[ii] = static_cast<FPTYPE>(0.);
  }

  // Visit local atoms in order, natoms[2 + tt] atoms of each type tt; the
  // assert checks that atoms are laid out grouped by type in that order.
  int i_idx = 0;
  for (int tt = 0; tt < ntypes; ++tt) {
    const int count_of_type = natoms[2 + tt];
    for (int nn = 0; nn < count_of_type; ++nn, ++i_idx) {
      const int i_type = type[i_idx];
      assert(i_type == tt);
      // Row of this centre type in the (ntypes x ntypes) grid of tables.
      const int i_type_shift = i_type * ntypes;
      // jiter walks the neighbour row and carries over from the sel_a
      // sections into the sel_r sections.
      int jiter = 0;
      for (const std::vector<int>* sel : {&sel_a, &sel_r}) {
        _pair_tab_jloop(energy, force, virial, jiter, i_idx, nnei,
                        i_type_shift, p_table_info, p_table_data, tab_stride,
                        rij, scale, type, nlist, natoms, *sel);
      }
    }
  }
}

// Explicit instantiations of pair_tab_cpu for single and double precision
// (FPTYPE = float / double); the spline tables are always passed as double.
template void deepmd::pair_tab_cpu<float>(float* energy,
                                          float* force,
                                          float* virial,
                                          const double* table_info,
                                          const double* table_data,
                                          const float* rij,
                                          const float* scale,
                                          const int* type,
                                          const int* nlist,
                                          const int* natoms,
                                          const std::vector<int>& sel_a,
                                          const std::vector<int>& sel_r);

template void deepmd::pair_tab_cpu<double>(double* energy,
                                           double* force,
                                           double* virial,
                                           const double* table_info,
                                           const double* table_data,
                                           const double* rij,
                                           const double* scale,
                                           const int* type,
                                           const int* nlist,
                                           const int* natoms,
                                           const std::vector<int>& sel_a,
                                           const std::vector<int>& sel_r);
