// SPDX-License-Identifier: LGPL-3.0-or-later
#include "neighbor_list.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

#include "device.h"
// #include <iomanip>

// using namespace std;
// Upper limit on how many times each "idx out of bound" warning is written to
// std::cerr by the build_clist overloads. Every warning category keeps its own
// function-local static counter, so the limit applies per category.
enum {
  MAX_WARN_IDX_OUT_OF_BOUND = 10,
};

// Tell whether a 3-d cell index lies inside the half-open box
// [nat_stt, nat_end) along every one of the three directions.
bool is_loc(const std::vector<int>& idx,
            const std::vector<int>& nat_stt,
            const std::vector<int>& nat_end) {
  for (int axis = 0; axis < 3; ++axis) {
    const bool below = idx[axis] < nat_stt[axis];
    const bool above = idx[axis] >= nat_end[axis];
    if (below || above) {
      return false;
    }
  }
  return true;
}

// Flatten a 3-d index into a row-major linear index for a grid of extent
// `size` (direction 2 varies fastest; size[0] is not needed).
int collapse_index(const std::vector<int>& idx, const std::vector<int>& size) {
  int flat = idx[0];
  flat = flat * size[1] + idx[1];
  flat = flat * size[2] + idx[2];
  return flat;
}

// Inverse of collapse_index: split the row-major linear index `i_idx` into its
// three components for a grid of extent `size`, writing them to o_idx[0..2].
// o_idx must already hold at least three elements.
void expand_index(std::vector<int>& o_idx,
                  const int& i_idx,
                  const std::vector<int>& size) {
  const int outer = i_idx / size[2];  // linear index over directions 0 and 1
  o_idx[2] = i_idx % size[2];
  o_idx[0] = outer / size[1];
  o_idx[1] = outer % size[1];
}

/**
 * @brief Sort atoms into the cells of the extended (local + ghost) grid.
 *
 * Every atom is converted to internal coordinates with region.phys2Inter and
 * binned on a grid whose cell edge along direction dd is 1 / global_grid[dd].
 * The first nloc atoms are clamped into [nat_stt, nat_end); the remaining
 * atoms (indices nloc .. coord.size()/3 - 1) are clamped into
 * [ext_stt, ext_end). Whenever clamping is needed a warning is written to
 * std::cerr, at most MAX_WARN_IDX_OUT_OF_BOUND times per warning category
 * (the counters are function-local statics and therefore shared by all calls).
 *
 * NOTE(review): the cell index is measured from nat_stt but is compared with
 * nat_stt / ext_stt as if it were absolute. This is self-consistent when
 * nat_stt is all zeros, as in the callers visible in this file - confirm for
 * other callers.
 *
 * @param[out] clist       resized to the number of extended cells; clist[c]
 *                         receives the indices of the atoms binned in cell c,
 *                         where c is the row-major index inside the extended
 *                         grid.
 * @param[in]  coord       flat xyz coordinates, three values per atom.
 * @param[in]  nloc        number of leading atoms treated as local.
 * @param[in]  nat_stt     first cell index of the local grid (3 values).
 * @param[in]  nat_end     one-past-last cell index of the local grid.
 * @param[in]  ext_stt     first cell index of the extended grid.
 * @param[in]  ext_end     one-past-last cell index of the extended grid.
 * @param[in]  region      provides the physical -> internal coordinate map.
 * @param[in]  global_grid number of cells per direction that defines the cell
 *                         size in internal coordinates.
 */
void build_clist(std::vector<std::vector<int> >& clist,
                 const std::vector<double>& coord,
                 const int& nloc,
                 const std::vector<int>& nat_stt,
                 const std::vector<int>& nat_end,
                 const std::vector<int>& ext_stt,
                 const std::vector<int>& ext_end,
                 const SimulationRegion<double>& region,
                 const std::vector<int>& global_grid) {
  static int count_warning_loc_idx_lower = 0;
  static int count_warning_loc_idx_upper = 0;
  static int count_warning_ghost_idx_lower = 0;
  static int count_warning_ghost_idx_upper = 0;

  // compute region info, in terms of internal coord
  const int nall = coord.size() / 3;
  std::vector<int> ext_ncell(3);
  std::vector<double> cell_size(3);
  std::vector<double> nat_orig(3);
  std::vector<int> idx_orig_shift(3);
  for (int dd = 0; dd < 3; ++dd) {
    ext_ncell[dd] = ext_end[dd] - ext_stt[dd];
    cell_size[dd] = 1. / global_grid[dd];
    nat_orig[dd] = nat_stt[dd] * cell_size[dd];
    // offset that turns a local-grid index into an extended-grid index
    idx_orig_shift[dd] = nat_stt[dd] - ext_stt[dd];
  }
  const int ncell = ext_ncell[0] * ext_ncell[1] * ext_ncell[2];

  // allocate and reserve the cell list
  clist.resize(ncell);
  const int esti_natom_per_cell = nall / ncell + 10;
  for (std::vector<int>& cell : clist) {
    cell.clear();
    cell.reserve(esti_natom_per_cell);
  }

  // Scratch index reused for every atom; each component is overwritten before
  // it is read, so a single allocation serves both loops.
  std::vector<int> idx(3);

  // local atoms: clamp into the local grid
  for (int ii = 0; ii < nloc; ++ii) {
    double inter[3];
    region.phys2Inter(inter, &(coord[ii * 3]));
    for (int dd = 0; dd < 3; ++dd) {
      const double rel = inter[dd] - nat_orig[dd];
      idx[dd] = static_cast<int>(rel / cell_size[dd]);
      // the conversion truncates toward zero; step down for negative offsets
      if (rel < 0.) {
        idx[dd]--;
      }
      if (idx[dd] < nat_stt[dd]) {
        if (count_warning_loc_idx_lower < MAX_WARN_IDX_OUT_OF_BOUND) {
          std::cerr << "# warning: loc idx out of lower bound (ignored if "
                       "warned for more than "
                    << MAX_WARN_IDX_OUT_OF_BOUND << " times) " << std::endl;
          count_warning_loc_idx_lower++;
        }
        idx[dd] = nat_stt[dd];
      } else if (idx[dd] >= nat_end[dd]) {
        if (count_warning_loc_idx_upper < MAX_WARN_IDX_OUT_OF_BOUND) {
          std::cerr << "# warning: loc idx out of upper bound (ignored if "
                       "warned for more than "
                    << MAX_WARN_IDX_OUT_OF_BOUND << " times) " << std::endl;
          count_warning_loc_idx_upper++;
        }
        idx[dd] = nat_end[dd] - 1;
      }
      idx[dd] += idx_orig_shift[dd];
    }
    clist[collapse_index(idx, ext_ncell)].push_back(ii);
  }

  // remaining (ghost) atoms: clamp into the extended grid
  for (int ii = nloc; ii < nall; ++ii) {
    double inter[3];
    region.phys2Inter(inter, &(coord[ii * 3]));
    for (int dd = 0; dd < 3; ++dd) {
      const double rel = inter[dd] - nat_orig[dd];
      idx[dd] = static_cast<int>(rel / cell_size[dd]);
      if (rel < 0.) {
        idx[dd]--;
      }
      if (idx[dd] < ext_stt[dd]) {
        // Stay silent when the atom sits on the lower face of the extended
        // grid up to a few ulps; it is still clamped below.
        const double lower_face = ext_stt[dd] * cell_size[dd];
        if (count_warning_ghost_idx_lower < MAX_WARN_IDX_OUT_OF_BOUND &&
            std::fabs(rel - lower_face) >
                std::fabs(lower_face) *
                    std::numeric_limits<double>::epsilon() * 5.) {
          std::cerr << "# warning: ghost idx out of lower bound (ignored if "
                       "warned for more than "
                    << MAX_WARN_IDX_OUT_OF_BOUND << " times) " << std::endl;
          count_warning_ghost_idx_lower++;
        }
        idx[dd] = ext_stt[dd];
      } else if (idx[dd] >= ext_end[dd]) {
        if (count_warning_ghost_idx_upper < MAX_WARN_IDX_OUT_OF_BOUND) {
          std::cerr << "# warning: ghost idx out of upper bound (ignored if "
                       "warned for more than "
                    << MAX_WARN_IDX_OUT_OF_BOUND << " times) " << std::endl;
          count_warning_ghost_idx_upper++;
        }
        idx[dd] = ext_end[dd] - 1;
      }
      idx[dd] += idx_orig_shift[dd];
    }
    clist[collapse_index(idx, ext_ncell)].push_back(ii);
  }
}

/**
 * @brief Sort a selected subset of atoms into the cells of the local grid.
 *
 * Only the atoms listed in `sel` are binned. Coordinates are converted to
 * internal coordinates with region.phys2Inter; the cell edge along direction
 * dd is 1 / nat_end[dd]. Indices falling outside [nat_stt, nat_end) are
 * clamped, with a warning on std::cerr printed at most
 * MAX_WARN_IDX_OUT_OF_BOUND times per category (function-local static
 * counters, separate from those of the other build_clist overload).
 *
 * NOTE(review): the index is measured from nat_stt yet clamped against
 * nat_stt / nat_end and collapsed without an offset; this is consistent when
 * nat_stt is all zeros, as in the caller visible in this file - confirm for
 * other callers.
 *
 * @param[out] clist   resized to the number of local cells; clist[c] receives
 *                     the selected atom indices binned in row-major cell c.
 * @param[in]  coord   flat xyz coordinates, three values per atom.
 * @param[in]  sel     indices (into coord / 3) of the atoms to bin.
 * @param[in]  nat_stt first cell index of the local grid (3 values).
 * @param[in]  nat_end one-past-last cell index of the local grid.
 * @param[in]  region  provides the physical -> internal coordinate map.
 */
void build_clist(std::vector<std::vector<int> >& clist,
                 const std::vector<double>& coord,
                 const std::vector<int>& sel,
                 const std::vector<int>& nat_stt,
                 const std::vector<int>& nat_end,
                 const SimulationRegion<double>& region) {
  static int count_warning_loc_idx_lower = 0;
  static int count_warning_loc_idx_upper = 0;

  // compute region info, in terms of internal coord
  const int nall = coord.size() / 3;
  std::vector<int> nat_ncell(3);
  std::vector<double> cell_size(3);
  std::vector<double> nat_orig(3);
  for (int dd = 0; dd < 3; ++dd) {
    nat_ncell[dd] = nat_end[dd] - nat_stt[dd];
    cell_size[dd] = 1. / nat_end[dd];
    nat_orig[dd] = nat_stt[dd] * cell_size[dd];
  }
  const int ncell = nat_ncell[0] * nat_ncell[1] * nat_ncell[2];

  // allocate and reserve the cell list
  clist.resize(ncell);
  const int esti_natom_per_cell = nall / ncell + 10;
  for (std::vector<int>& cell : clist) {
    cell.clear();
    cell.reserve(esti_natom_per_cell);
  }

  // Scratch index reused for every atom; each component is overwritten before
  // it is read.
  std::vector<int> idx(3);

  // build the cell list
  for (const int ii : sel) {
    double inter[3];
    region.phys2Inter(inter, &(coord[ii * 3]));
    for (int dd = 0; dd < 3; ++dd) {
      const double rel = inter[dd] - nat_orig[dd];
      idx[dd] = static_cast<int>(rel / cell_size[dd]);
      // the conversion truncates toward zero; step down for negative offsets
      if (rel < 0.) {
        idx[dd]--;
      }
      if (idx[dd] < nat_stt[dd]) {
        if (count_warning_loc_idx_lower < MAX_WARN_IDX_OUT_OF_BOUND) {
          std::cerr << "# warning: loc idx out of lower bound (ignored if "
                       "warned for more than "
                    << MAX_WARN_IDX_OUT_OF_BOUND << " times) " << std::endl;
          count_warning_loc_idx_lower++;
        }
        idx[dd] = nat_stt[dd];
      } else if (idx[dd] >= nat_end[dd]) {
        if (count_warning_loc_idx_upper < MAX_WARN_IDX_OUT_OF_BOUND) {
          std::cerr << "# warning: loc idx out of upper bound (ignored if "
                       "warned for more than "
                    << MAX_WARN_IDX_OUT_OF_BOUND << " times) " << std::endl;
          count_warning_loc_idx_upper++;
        }
        idx[dd] = nat_end[dd] - 1;
      }
    }
    clist[collapse_index(idx, nat_ncell)].push_back(ii);
  }
}

// Pair up the atoms of cell `cidx` with those of cell `tidx` (both looked up
// in the same cell list) and record every pair symmetrically.
//
// The separation is coord[i] - coord[j] plus the periodic image offset
// sum_k shift[k] * boxt[3 * k + :]. A pair with squared distance below rc02
// goes to nlist0, otherwise one below rc12 goes to nlist1. Each pair is
// written to the list of both partners, but only for partners whose index is
// smaller than nlist0.size(). When both cells coincide only pairs with j > i
// are visited, so no pair is recorded twice.
void build_nlist_cell(std::vector<std::vector<int> >& nlist0,
                      std::vector<std::vector<int> >& nlist1,
                      const int& cidx,
                      const int& tidx,
                      const std::vector<std::vector<int> >& clist,
                      const std::vector<double>& coord,
                      const double& rc02,
                      const double& rc12,
                      const std::vector<int>& shift = {0, 0, 0},
                      const std::vector<double>& boxt = {0., 0., 0., 0., 0., 0.,
                                                         0., 0., 0.}) {
  const int n_owned = nlist0.size();
  const bool same_cell = (cidx == tidx);
  for (const int atom_i : clist[cidx]) {
    const bool i_owned = atom_i < n_owned;
    for (const int atom_j : clist[tidx]) {
      if (same_cell && atom_j <= atom_i) {
        continue;
      }
      // displacement including the periodic image offset
      double dr[3];
      for (int ax = 0; ax < 3; ++ax) {
        dr[ax] = coord[atom_i * 3 + ax] - coord[atom_j * 3 + ax];
        for (int bv = 0; bv < 3; ++bv) {
          dr[ax] += shift[bv] * boxt[3 * bv + ax];
        }
      }
      const double dist2 = deepmd::dot3(dr, dr);

      // choose the shell this pair belongs to, if any
      std::vector<std::vector<int> >* shell;
      if (dist2 < rc02) {
        shell = &nlist0;
      } else if (dist2 < rc12) {
        shell = &nlist1;
      } else {
        continue;
      }
      if (i_owned) {
        (*shell)[atom_i].push_back(atom_j);
      }
      if (atom_j < n_owned) {
        (*shell)[atom_j].push_back(atom_i);
      }
    }
  }
}

// One-directional variant: for every atom i of cell `cidx` in clist0, scan the
// atoms j of cell `tidx` in clist1 and append j to the list of i only.
//
// Atoms i whose index is not below nlist0.size() are skipped. The separation
// is coord[i] - coord[j] plus sum_k shift[k] * boxt[3 * k + :]. Squared
// distances below rc02 go to nlist0, otherwise those below rc12 go to nlist1.
// When both cell indices coincide an atom is not paired with itself.
void build_nlist_cell(std::vector<std::vector<int> >& nlist0,
                      std::vector<std::vector<int> >& nlist1,
                      const int& cidx,
                      const int& tidx,
                      const std::vector<std::vector<int> >& clist0,
                      const std::vector<std::vector<int> >& clist1,
                      const std::vector<double>& coord,
                      const double& rc02,
                      const double& rc12,
                      const std::vector<int>& shift = {0, 0, 0},
                      const std::vector<double>& boxt = {0., 0., 0., 0., 0., 0.,
                                                         0., 0., 0.}) {
  typedef std::vector<std::vector<int> >::size_type list_index;
  const bool same_cell = (cidx == tidx);
  for (const int atom_i : clist0[cidx]) {
    // same unsigned comparison the implicit conversion would perform
    if (static_cast<list_index>(atom_i) >= nlist0.size()) {
      continue;
    }
    for (const int atom_j : clist1[tidx]) {
      if (same_cell && atom_j == atom_i) {
        continue;
      }
      double dr[3];
      for (int ax = 0; ax < 3; ++ax) {
        dr[ax] = coord[atom_i * 3 + ax] - coord[atom_j * 3 + ax];
        for (int bv = 0; bv < 3; ++bv) {
          dr[ax] += shift[bv] * boxt[3 * bv + ax];
        }
      }
      const double dist2 = deepmd::dot3(dr, dr);
      if (dist2 < rc02) {
        nlist0[atom_i].push_back(atom_j);
        continue;
      }
      if (dist2 < rc12) {
        nlist1[atom_i].push_back(atom_j);
      }
    }
  }
}

/**
 * @brief Build the two-shell neighbor lists of the local atoms from a cell
 *        list that also contains ghost atoms.
 *
 * The atoms are binned with build_clist on the extended grid
 * [ext_stt, ext_end). For every local cell in [nat_stt, nat_end) the
 * surrounding cells within niter cells per direction are scanned, where niter
 * is the smallest number of cells covering rc1. Pairs closer than rc0 go to
 * nlist0, pairs between rc0 and rc1 go to nlist1 (see build_nlist_cell). A
 * target cell that is itself local and has a smaller collapsed index than the
 * current cell is skipped, because that cell pair is handled when the roles
 * are swapped.
 *
 * The cut-off and stencil checks are assert()s and therefore inactive when
 * NDEBUG is defined.
 *
 * @param[out] nlist0      resized to nloc; neighbors with r < rc0.
 * @param[out] nlist1      resized to nloc; neighbors with rc0 <= r < rc1.
 * @param[in]  coord       flat xyz coordinates of local + ghost atoms.
 * @param[in]  nloc        number of leading (local) atoms in coord.
 * @param[in]  rc0         inner cut-off; a non-positive value disables the
 *                         inner shell.
 * @param[in]  rc1         outer cut-off.
 * @param[in]  nat_stt_    first cell index of the local grid.
 * @param[in]  nat_end_    one-past-last cell index of the local grid.
 * @param[in]  ext_stt_    first cell index of the extended grid.
 * @param[in]  ext_end_    one-past-last cell index of the extended grid.
 * @param[in]  region      simulation region (face distances, volume,
 *                         coordinate map).
 * @param[in]  global_grid grid used by build_clist to define the cell size.
 */
void build_nlist(std::vector<std::vector<int> >& nlist0,
                 std::vector<std::vector<int> >& nlist1,
                 const std::vector<double>& coord,
                 const int& nloc,
                 const double& rc0,
                 const double& rc1,
                 const std::vector<int>& nat_stt_,
                 const std::vector<int>& nat_end_,
                 const std::vector<int>& ext_stt_,
                 const std::vector<int>& ext_end_,
                 const SimulationRegion<double>& region,
                 const std::vector<int>& global_grid) {
  // The index vectors are only read, so they are used in place.
  const std::vector<int>& nat_stt = nat_stt_;
  const std::vector<int>& nat_end = nat_end_;
  const std::vector<int>& ext_stt = ext_stt_;
  const std::vector<int>& ext_end = ext_end_;

  // compute the clist
  std::vector<std::vector<int> > clist;
  build_clist(clist, coord, nloc, nat_stt, nat_end, ext_stt, ext_end, region,
              global_grid);

  // extent of the extended grid and offset of the local grid inside it
  std::vector<int> ext_ncell(3);
  std::vector<int> idx_orig_shift(3);
  for (int dd = 0; dd < 3; ++dd) {
    ext_ncell[dd] = ext_end[dd] - ext_stt[dd];
    idx_orig_shift[dd] = nat_stt[dd] - ext_stt[dd];
  }

  // compute number of iter according to the cut-off
  assert(rc0 <= rc1);
  std::vector<int> niter(3);
  double to_face[3];
  region.toFaceDistance(to_face);
  for (int dd = 0; dd < 3; ++dd) {
    double cell_size = to_face[dd] / nat_end[dd];
    niter[dd] = rc1 / cell_size;
    if (niter[dd] * cell_size < rc1) {
      niter[dd] += 1;
    }
    assert(niter[dd] * cell_size >= rc1);
  }
  // the stencil must stay inside the extended grid
  for (int dd = 0; dd < 3; ++dd) {
    assert(nat_stt[dd] - niter[dd] >= ext_stt[dd]);
    assert(nat_end[dd] + niter[dd] <= ext_end[dd]);
  }

  // allocate the nlists; capacity is guessed from the mean density with a
  // safety margin, falling back to 10 if the guess comes out negative
  double density = nloc / region.getVolume();
  nlist0.resize(nloc);
  for (int ii = 0; ii < nloc; ++ii) {
    nlist0[ii].clear();
    int esti = 4. / 3. * 3.14 * (rc0 * rc0 * rc0) * density * 1.5 + 20;
    if (esti < 0) {
      esti = 10;
    }
    nlist0[ii].reserve(esti);
  }
  nlist1.resize(nloc);
  for (int ii = 0; ii < nloc; ++ii) {
    nlist1[ii].clear();
    int esti =
        4. / 3. * 3.14 * (rc1 * rc1 * rc1 - rc0 * rc0 * rc0) * density * 1.5 +
        20;
    if (esti < 0) {
      esti = 10;
    }
    nlist1[ii].reserve(esti);
  }

  // compute the nlists
  double rc02 = 0;
  if (rc0 > 0) {
    rc02 = rc0 * rc0;
  }
  const double rc12 = rc1 * rc1;

  // Index scratch space, allocated once; every component is assigned before
  // it is read. cidx / tidx are indices on the local grid, mcidx / mtidx the
  // same cells expressed on the extended grid.
  std::vector<int> cidx(3), mcidx(3), tidx(3), mtidx(3);
  for (cidx[0] = nat_stt[0]; cidx[0] < nat_end[0]; ++cidx[0]) {
    for (cidx[1] = nat_stt[1]; cidx[1] < nat_end[1]; ++cidx[1]) {
      for (cidx[2] = nat_stt[2]; cidx[2] < nat_end[2]; ++cidx[2]) {
        for (int dd = 0; dd < 3; ++dd) {
          mcidx[dd] = cidx[dd] + idx_orig_shift[dd];
        }
        const int clp_cidx = collapse_index(mcidx, ext_ncell);
        for (tidx[0] = cidx[0] - niter[0]; tidx[0] < cidx[0] + niter[0] + 1;
             ++tidx[0]) {
          for (tidx[1] = cidx[1] - niter[1]; tidx[1] < cidx[1] + niter[1] + 1;
               ++tidx[1]) {
            for (tidx[2] = cidx[2] - niter[2]; tidx[2] < cidx[2] + niter[2] + 1;
                 ++tidx[2]) {
              for (int dd = 0; dd < 3; ++dd) {
                mtidx[dd] = tidx[dd] + idx_orig_shift[dd];
              }
              const int clp_tidx = collapse_index(mtidx, ext_ncell);
              // local/local cell pairs are visited once, from the cell with
              // the smaller collapsed index
              if (is_loc(tidx, nat_stt, nat_end) && clp_tidx < clp_cidx) {
                continue;
              }
              build_nlist_cell(nlist0, nlist1, clp_cidx, clp_tidx, clist, coord,
                               rc02, rc12);
            }
          }
        }
      }
    }
  }
}

// assume nat grid is the global grid. only used for serial simulations
//
// Builds the two-shell neighbor lists of all atoms in a periodic box using a
// cell list, without ghost atoms:
//   nlist0[i] : neighbors j with r_ij^2 <  rc0^2 (empty shell if rc0 <= 0)
//   nlist1[i] : neighbors j with rc0^2 <= r_ij^2 < rc1^2
// `grid` gives the number of cells per direction; the local grid runs from
// {0, 0, 0} to `grid`. Neighbor cells that fall outside the grid are wrapped
// back by one period, and the matching image offset (`shift`, in units of box
// vectors) together with the box tensor is handed to build_nlist_cell so that
// distances are taken to the periodic image.
//
// Two compile-time variants:
//   * HALF_NEIGHBOR_LIST defined: serial triple loop over cells; each cell
//     pair is visited once (clp_tidx >= clp_cidx) and the single-clist
//     build_nlist_cell records the pair for both atoms.
//   * otherwise: one flat loop over all cells under `#pragma omp parallel
//     for`; every (cell, neighbor cell) combination is visited and the
//     two-clist build_nlist_cell appends only to the lists of the atoms of the
//     current cell.
//
// The sanity checks are assert()s and are inactive when NDEBUG is defined.
void build_nlist(std::vector<std::vector<int> >& nlist0,
                 std::vector<std::vector<int> >& nlist1,
                 const std::vector<double>& coord,
                 const double& rc0,
                 const double& rc1,
                 const std::vector<int>& grid,
                 const SimulationRegion<double>& region) {
  // assuming nloc == nall
  int nloc = coord.size() / 3;
  // compute the clist
  std::vector<int> nat_stt(3, 0);
  std::vector<int> nat_end(grid);
  std::vector<std::vector<int> > clist;
  // local and extended grid coincide, and the grid itself is the global grid
  build_clist(clist, coord, nloc, nat_stt, nat_end, nat_stt, nat_end, region,
              nat_end);

  // compute the region info
  int nall = coord.size() / 3;
  std::vector<int> nat_ncell(3);
  for (int dd = 0; dd < 3; ++dd) {
    nat_ncell[dd] = nat_end[dd] - nat_stt[dd];
  }

  // compute number of iter according to the cut-off
  // niter[dd] is the smallest number of cells whose total width covers rc1
  assert(rc0 <= rc1);
  std::vector<int> niter(3);
  double to_face[3];
  region.toFaceDistance(to_face);
  for (int dd = 0; dd < 3; ++dd) {
    double cell_size = to_face[dd] / nat_end[dd];
    niter[dd] = rc1 / cell_size;
    if (niter[dd] * cell_size < rc1) {
      niter[dd] += 1;
    }
    assert(niter[dd] * cell_size >= rc1);
  }
  // check the validity of the iters
  // the wrap-around below only undoes a single period, so the stencil may
  // reach at most half of the grid in each direction
  for (int dd = 0; dd < 3; ++dd) {
    assert(niter[dd] <= (nat_end[dd] - nat_stt[dd]) / 2);
  }

  // allocate the nlists
  // capacity is a guess from the mean density with a safety margin
  double density = nall / region.getVolume();
  nlist0.resize(nloc);
  for (int ii = 0; ii < nloc; ++ii) {
    nlist0[ii].clear();
    nlist0[ii].reserve(4. / 3. * 3.14 * (rc0 * rc0 * rc0) * density * 1.5 + 20);
  }
  nlist1.resize(nloc);
  for (int ii = 0; ii < nloc; ++ii) {
    nlist1[ii].clear();
    nlist1[ii].reserve(4. / 3. * 3.14 * (rc1 * rc1 * rc1 - rc0 * rc0 * rc0) *
                           density * 1.5 +
                       20);
  }

  // physical cell size
  // (copy of the 9 box-tensor components, passed on as `boxt`)
  std::vector<double> phys_cs(9);
  for (int dd = 0; dd < 9; ++dd) {
    phys_cs[dd] = region.getBoxTensor()[dd];
  }

  // compute the nlists
  // a non-positive rc0 leaves rc02 at 0, so nothing lands in nlist0
  double rc02 = 0;
  if (rc0 > 0) {
    rc02 = rc0 * rc0;
  }
  double rc12 = rc1 * rc1;

  // NOTE: both preprocessor branches open exactly three brace levels so that
  // the shared loop body and the closing braces below match either variant.
#ifdef HALF_NEIGHBOR_LIST
  std::vector<int> cidx(3);
  for (cidx[0] = nat_stt[0]; cidx[0] < nat_end[0]; ++cidx[0]) {
    for (cidx[1] = nat_stt[1]; cidx[1] < nat_end[1]; ++cidx[1]) {
      for (cidx[2] = nat_stt[2]; cidx[2] < nat_end[2]; ++cidx[2]) {
#else
  // flatten the three cell loops into one so that OpenMP can distribute cells
  int idx_range[3];
  idx_range[0] = nat_end[0] - nat_stt[0];
  idx_range[1] = nat_end[1] - nat_stt[1];
  idx_range[2] = nat_end[2] - nat_stt[2];
  int idx_total = idx_range[0] * idx_range[1] * idx_range[2];
#pragma omp parallel for
  for (int tmpidx = 0; tmpidx < idx_total; ++tmpidx) {
    // cidx is declared inside the loop body, i.e. one instance per iteration
    // NOTE(review): the decomposition subtracts cidx[0] / cidx[1] including
    // their nat_stt offsets; that is exact here because nat_stt is all zeros.
    std::vector<int> cidx(3);
    cidx[0] = nat_stt[0] + tmpidx / (idx_range[1] * idx_range[2]);
    int tmpidx1 = tmpidx - cidx[0] * idx_range[1] * idx_range[2];
    cidx[1] = nat_stt[1] + tmpidx1 / idx_range[2];
    cidx[2] = nat_stt[2] + tmpidx1 - cidx[1] * idx_range[2];
    {
      {
#endif
        int clp_cidx = collapse_index(cidx, nat_ncell);
        // tidx  : neighbor cell index, possibly outside the grid
        // shift : periodic image offset (+1 / 0 / -1) per direction
        // stidx : tidx wrapped back into the grid
        std::vector<int> tidx(3);
        std::vector<int> stidx(3);
        std::vector<int> shift(3);
        for (tidx[0] = cidx[0] - niter[0]; tidx[0] < cidx[0] + niter[0] + 1;
             ++tidx[0]) {
          shift[0] = 0;
          if (tidx[0] < 0) {
            shift[0] += 1;
          } else if (tidx[0] >= nat_ncell[0]) {
            shift[0] -= 1;
          }
          stidx[0] = tidx[0] + shift[0] * nat_ncell[0];
          for (tidx[1] = cidx[1] - niter[1]; tidx[1] < cidx[1] + niter[1] + 1;
               ++tidx[1]) {
            shift[1] = 0;
            if (tidx[1] < 0) {
              shift[1] += 1;
            } else if (tidx[1] >= nat_ncell[1]) {
              shift[1] -= 1;
            }
            stidx[1] = tidx[1] + shift[1] * nat_ncell[1];
            for (tidx[2] = cidx[2] - niter[2]; tidx[2] < cidx[2] + niter[2] + 1;
                 ++tidx[2]) {
              shift[2] = 0;
              if (tidx[2] < 0) {
                shift[2] += 1;
              } else if (tidx[2] >= nat_ncell[2]) {
                shift[2] -= 1;
              }
              stidx[2] = tidx[2] + shift[2] * nat_ncell[2];
              int clp_tidx = collapse_index(stidx, nat_ncell);
#ifdef HALF_NEIGHBOR_LIST
              // visit each cell pair once; the callee fills both directions
              if (clp_tidx < clp_cidx) {
                continue;
              }
              build_nlist_cell(nlist0, nlist1, clp_cidx, clp_tidx, clist, coord,
                               rc02, rc12, shift, phys_cs);
#else
              // two-clist overload: only the lists of atoms in the current
              // cell are appended to
              build_nlist_cell(nlist0, nlist1, clp_cidx, clp_tidx, clist, clist,
                               coord, rc02, rc12, shift, phys_cs);
#endif
            }
          }
        }
      }
    }
  }
}

// Periodic two-shell neighbor lists between two selections of atoms.
//
// Atoms listed in sel0 act as centers and atoms listed in sel1 as candidate
// neighbors: each selection gets its own cell list on the grid
// [{0, 0, 0}, grid), and for every cell the surrounding cells within niter
// cells per direction are scanned, wrapping around the box by at most one
// period. Neighbors with r^2 < rc0^2 go to nlist0 (nothing if rc0 <= 0),
// those with r^2 < rc1^2 otherwise go to nlist1. Both lists are resized to the
// total number of atoms in coord. The sanity checks are assert()s and are
// inactive when NDEBUG is defined.
void build_nlist(std::vector<std::vector<int> >& nlist0,
                 std::vector<std::vector<int> >& nlist1,
                 const std::vector<double>& coord,
                 const std::vector<int>& sel0,
                 const std::vector<int>& sel1,
                 const double& rc0,
                 const double& rc1,
                 const std::vector<int>& grid,
                 const SimulationRegion<double>& region) {
  const int natoms = coord.size() / 3;

  // one cell list per selection, on the grid starting at the origin
  const std::vector<int> nat_stt(3, 0);
  const std::vector<int> nat_end(grid);
  std::vector<std::vector<int> > clist0, clist1;
  build_clist(clist0, coord, sel0, nat_stt, nat_end, region);
  build_clist(clist1, coord, sel1, nat_stt, nat_end, region);

  std::vector<int> nat_ncell(3);
  for (int dd = 0; dd < 3; ++dd) {
    nat_ncell[dd] = nat_end[dd] - nat_stt[dd];
  }

  // number of cells per direction needed to cover the outer cut-off
  assert(rc0 <= rc1);
  std::vector<int> niter(3);
  double to_face[3];
  region.toFaceDistance(to_face);
  for (int dd = 0; dd < 3; ++dd) {
    const double cell_size = to_face[dd] / nat_end[dd];
    niter[dd] = rc1 / cell_size;
    if (niter[dd] * cell_size < rc1) {
      niter[dd] += 1;
    }
    assert(niter[dd] * cell_size >= rc1);
  }
  // the single-period wrap below requires a stencil of at most half the grid
  for (int dd = 0; dd < 3; ++dd) {
    assert(niter[dd] <= (nat_end[dd] - nat_stt[dd]) / 2);
  }

  // allocate the nlists with a density-based capacity guess
  const double density = natoms / region.getVolume();
  const double guess0 = 4. / 3. * 3.14 * (rc0 * rc0 * rc0) * density * 1.5 + 20;
  const double guess1 = 4. / 3. * 3.14 * (rc1 * rc1 * rc1 - rc0 * rc0 * rc0) *
                            density * 1.5 +
                        20;
  nlist0.resize(natoms);
  for (int ii = 0; ii < natoms; ++ii) {
    nlist0[ii].clear();
    nlist0[ii].reserve(guess0);
  }
  nlist1.resize(natoms);
  for (int ii = 0; ii < natoms; ++ii) {
    nlist1[ii].clear();
    nlist1[ii].reserve(guess1);
  }

  // box tensor, forwarded to the pair search as the periodic image basis
  std::vector<double> phys_cs(9);
  for (int dd = 0; dd < 9; ++dd) {
    phys_cs[dd] = region.getBoxTensor()[dd];
  }

  const double rc02 = (rc0 > 0) ? rc0 * rc0 : 0;
  const double rc12 = rc1 * rc1;

  // image offset that brings an out-of-grid cell index back into [0, ncell)
  const auto image_of = [](int cell, int ncell) -> int {
    if (cell < 0) {
      return 1;
    }
    if (cell >= ncell) {
      return -1;
    }
    return 0;
  };

  // cidx: current cell, tidx: raw neighbor cell, stidx: wrapped neighbor cell
  std::vector<int> cidx(3), tidx(3), stidx(3), shift(3);
  for (cidx[0] = nat_stt[0]; cidx[0] < nat_end[0]; ++cidx[0]) {
    for (cidx[1] = nat_stt[1]; cidx[1] < nat_end[1]; ++cidx[1]) {
      for (cidx[2] = nat_stt[2]; cidx[2] < nat_end[2]; ++cidx[2]) {
        const int clp_cidx = collapse_index(cidx, nat_ncell);
        for (tidx[0] = cidx[0] - niter[0]; tidx[0] < cidx[0] + niter[0] + 1;
             ++tidx[0]) {
          shift[0] = image_of(tidx[0], nat_ncell[0]);
          stidx[0] = tidx[0] + shift[0] * nat_ncell[0];
          for (tidx[1] = cidx[1] - niter[1]; tidx[1] < cidx[1] + niter[1] + 1;
               ++tidx[1]) {
            shift[1] = image_of(tidx[1], nat_ncell[1]);
            stidx[1] = tidx[1] + shift[1] * nat_ncell[1];
            for (tidx[2] = cidx[2] - niter[2]; tidx[2] < cidx[2] + niter[2] + 1;
                 ++tidx[2]) {
              shift[2] = image_of(tidx[2], nat_ncell[2]);
              stidx[2] = tidx[2] + shift[2] * nat_ncell[2];
              const int clp_tidx = collapse_index(stidx, nat_ncell);
              build_nlist_cell(nlist0, nlist1, clp_cidx, clp_tidx, clist0,
                               clist1, coord, rc02, rc12, shift, phys_cs);
            }
          }
        }
      }
    }
  }
}

// Brute-force (all pairs) two-shell neighbor lists.
//
// Every unordered pair of atoms in posi3 is tested once. With a non-null
// `region` the separation comes from region->diffNearestNeighbor, otherwise
// it is the plain coordinate difference. Pairs with r^2 < rc0^2 are recorded
// in nlist0 and pairs with r^2 < rc1^2 otherwise in nlist1, always for both
// partners. A negative rc0 disables the inner shell. Both lists are cleared
// and resized to the number of atoms.
void build_nlist(std::vector<std::vector<int> >& nlist0,
                 std::vector<std::vector<int> >& nlist1,
                 const std::vector<double>& posi3,
                 const double& rc0_,
                 const double& rc1_,
                 const SimulationRegion<double>* region) {
  const double rc0(rc0_);
  const double rc1(rc1_);
  assert(rc0 <= rc1);
  // negative rc0 means not applying rc0
  const double rc02 = (rc0 < 0) ? 0 : rc0 * rc0;
  const double rc12 = rc1 * rc1;

  const unsigned natoms = posi3.size() / 3;
  nlist0.clear();
  nlist1.clear();
  nlist0.resize(natoms);
  nlist1.resize(natoms);
  for (unsigned ia = 0; ia < natoms; ++ia) {
    nlist0[ia].reserve(60);
    nlist1[ia].reserve(60);
  }

  for (unsigned ia = 0; ia < natoms; ++ia) {
    const double* ri = &posi3[ia * 3];
    for (unsigned jb = ia + 1; jb < natoms; ++jb) {
      const double* rj = &posi3[jb * 3];
      double dr[3];
      if (region == NULL) {
        dr[0] = rj[0] - ri[0];
        dr[1] = rj[1] - ri[1];
        dr[2] = rj[2] - ri[2];
      } else {
        region->diffNearestNeighbor(rj[0], rj[1], rj[2], ri[0], ri[1], ri[2],
                                    dr[0], dr[1], dr[2]);
      }
      const double dist2 = deepmd::dot3(dr, dr);
      if (dist2 < rc02) {
        nlist0[ia].push_back(jb);
        nlist0[jb].push_back(ia);
      } else if (dist2 < rc12) {
        nlist1[ia].push_back(jb);
        nlist1[jb].push_back(ia);
      }
    }
  }
}

// Number of periods (of length ncell) that must be added to `idx` to bring it
// into [0, ncell): positive for idx < 0, negative for idx >= ncell, zero when
// idx is already inside. Expects ncell > 0.
static int compute_pbc_shift(int idx, int ncell) {
  int n_periods = 0;
  int wrapped = idx;
  while (wrapped < 0) {
    wrapped += ncell;
    ++n_periods;
  }
  while (wrapped >= ncell) {
    wrapped -= ncell;
    --n_periods;
  }
  assert(wrapped >= 0 && wrapped < ncell);
  return n_periods;
}

/**
 * @brief Extend a periodic system with the ghost images needed for a cut-off
 *        of rc.
 *
 * The box is divided into ncell[dd] = max(1, int(to_face[dd] / rc)) cells per
 * direction, and ngcell[dd] layers of ghost cells are added on each side so
 * that ngcell[dd] * cell_size[dd] >= rc. The local atoms are copied first
 * (mapping[i] = i); then, for every ghost cell, the atoms of the periodic
 * source cell are appended with their coordinates translated by the matching
 * combination of box vectors (obtained from region.inter2Phys).
 *
 * @param[out] out_c   coordinates of local atoms followed by ghost atoms.
 * @param[out] out_t   atom types, same order as out_c.
 * @param[out] mapping for every output atom, the index of the local atom it
 *                     is an image of.
 * @param[out] ncell   number of local cells per direction (resized to 3).
 * @param[out] ngcell  number of ghost cell layers per direction (resized to 3).
 * @param[in]  in_c    flat xyz coordinates of the local atoms.
 * @param[in]  in_t    types of the local atoms; must have in_c.size() / 3
 *                     entries (checked with assert).
 * @param[in]  rc      cut-off radius.
 * @param[in]  region  simulation region (face distances, coordinate maps).
 */
void copy_coord(std::vector<double>& out_c,
                std::vector<int>& out_t,
                std::vector<int>& mapping,
                std::vector<int>& ncell,
                std::vector<int>& ngcell,
                const std::vector<double>& in_c,
                const std::vector<int>& in_t,
                const double& rc,
                const SimulationRegion<double>& region) {
  const int nloc = in_c.size() / 3;
  assert(static_cast<size_t>(nloc) == in_t.size());

  ncell.resize(3);
  ngcell.resize(3);
  double to_face[3];
  double cell_size[3];
  region.toFaceDistance(to_face);
  for (int dd = 0; dd < 3; ++dd) {
    ncell[dd] = to_face[dd] / rc;
    // a box thinner than rc still gets one cell
    if (ncell[dd] == 0) {
      ncell[dd] = 1;
    }
    cell_size[dd] = to_face[dd] / ncell[dd];
    ngcell[dd] = int(rc / cell_size[dd]) + 1;
    assert(cell_size[dd] * ngcell[dd] >= rc);
  }
  const int total_ncell = (2 * ngcell[0] + ncell[0]) *
                          (2 * ngcell[1] + ncell[1]) *
                          (2 * ngcell[2] + ncell[2]);
  const int loc_ncell = (ncell[0]) * (ncell[1]) * (ncell[2]);
  // rough size of the extended system, assuming a uniform density
  const int esti_ntotal = total_ncell / loc_ncell * nloc + 10;

  // alloc
  out_c.reserve(esti_ntotal * 6);
  out_t.reserve(esti_ntotal * 2);
  mapping.reserve(esti_ntotal * 2);

  // build cell list of the local atoms (local grid == extended grid here)
  std::vector<std::vector<int> > clist;
  const std::vector<int> nat_stt(3, 0);
  build_clist(clist, in_c, nloc, nat_stt, ncell, nat_stt, ncell, region, ncell);

  // copy local atoms
  out_c.resize(static_cast<size_t>(nloc) * 3);
  out_t.resize(nloc);
  mapping.resize(nloc);
  std::copy(in_c.begin(), in_c.end(), out_c.begin());
  std::copy(in_t.begin(), in_t.end(), out_t.begin());
  for (int ii = 0; ii < nloc; ++ii) {
    mapping[ii] = ii;
  }

  // push ghost
  // ii: cell index on the grid extended by the ghost layers,
  // jj: the periodic source cell inside the box,
  // pbc_shift: number of box periods separating the two
  std::vector<int> ii(3), jj(3), pbc_shift(3, 0);
  double pbc_shift_d[3];
  for (ii[0] = -ngcell[0]; ii[0] < ncell[0] + ngcell[0]; ++ii[0]) {
    pbc_shift[0] = compute_pbc_shift(ii[0], ncell[0]);
    pbc_shift_d[0] = pbc_shift[0];
    jj[0] = ii[0] + pbc_shift[0] * ncell[0];
    for (ii[1] = -ngcell[1]; ii[1] < ncell[1] + ngcell[1]; ++ii[1]) {
      pbc_shift[1] = compute_pbc_shift(ii[1], ncell[1]);
      pbc_shift_d[1] = pbc_shift[1];
      jj[1] = ii[1] + pbc_shift[1] * ncell[1];
      for (ii[2] = -ngcell[2]; ii[2] < ncell[2] + ngcell[2]; ++ii[2]) {
        pbc_shift[2] = compute_pbc_shift(ii[2], ncell[2]);
        pbc_shift_d[2] = pbc_shift[2];
        jj[2] = ii[2] + pbc_shift[2] * ncell[2];
        // local cell, continue
        if (ii[0] >= 0 && ii[0] < ncell[0] && ii[1] >= 0 && ii[1] < ncell[1] &&
            ii[2] >= 0 && ii[2] < ncell[2]) {
          continue;
        }
        // physical translation corresponding to the integer image offset
        double shift_v[3];
        region.inter2Phys(shift_v, pbc_shift_d);
        const int cell_idx = collapse_index(jj, ncell);
        const std::vector<int>& cur_clist = clist[cell_idx];
        for (const int p_idx : cur_clist) {
          out_c.push_back(in_c[p_idx * 3 + 0] - shift_v[0]);
          out_c.push_back(in_c[p_idx * 3 + 1] - shift_v[1]);
          out_c.push_back(in_c[p_idx * 3 + 2] - shift_v[2]);
          out_t.push_back(in_t[p_idx]);
          mapping.push_back(p_idx);
        }
      }
    }
  }
}

using namespace deepmd;

/**
 * @brief Fill an InputNlist view from a vector-of-vectors neighbor list.
 *
 * Atom ii is given index ii in ilist, its neighbor count in numneigh and a
 * pointer to its neighbor array in firstneigh. The ilist, numneigh and
 * firstneigh arrays of to_nlist must already have room for
 * from_nlist.size() entries. No neighbor data is copied: firstneigh points
 * into from_nlist, which therefore has to outlive to_nlist and must not be
 * resized while to_nlist is in use.
 *
 * @param[out] to_nlist   list view to fill.
 * @param[in]  from_nlist per-atom neighbor indices.
 */
void deepmd::convert_nlist(InputNlist& to_nlist,
                           std::vector<std::vector<int> >& from_nlist) {
  to_nlist.inum = from_nlist.size();
  for (int ii = 0; ii < to_nlist.inum; ++ii) {
    to_nlist.ilist[ii] = ii;
    to_nlist.numneigh[ii] = from_nlist[ii].size();
    // data() is well defined for an empty neighbor list, whereas indexing
    // element 0 of an empty vector is undefined behaviour.
    to_nlist.firstneigh[ii] = from_nlist[ii].data();
  }
}

// Largest neighbor count over the nlist.inum entries of nlist.numneigh;
// returns 0 when the list is empty or no count is positive.
int deepmd::max_numneigh(const InputNlist& nlist) {
  int largest = 0;
  for (int atom = 0; atom < nlist.inum; ++atom) {
    const int count = nlist.numneigh[atom];
    largest = (count > largest) ? count : largest;
  }
  return largest;
}

// Brute-force neighbor list for the first nloc atoms of c_cpy.
//
// For every atom ii < nloc all nall atoms are tested and those with squared
// distance below rcut^2 (excluding ii itself) are written to
// nlist.firstneigh[ii]; nlist.ilist[ii] is set to ii and nlist.numneigh[ii]
// to the neighbor count. *max_list_size tracks the largest count seen.
//
// Returns 0 on success. If an atom has more than mem_size_ neighbors the
// function stops, stores that count in *max_list_size and returns 1; the
// neighbors of that atom are not written.
template <typename FPTYPE>
int deepmd::build_nlist_cpu(InputNlist& nlist,
                            int* max_list_size,
                            const FPTYPE* c_cpy,
                            const int& nloc,
                            const int& nall,
                            const int& mem_size_,
                            const float& rcut) {
  // taken by value before any output is written
  const int capacity = mem_size_;
  *max_list_size = 0;
  nlist.inum = nloc;
  const FPTYPE cutoff2 = rcut * rcut;

  std::vector<int> found;
  found.reserve(capacity);
  for (int center = 0; center < nlist.inum; ++center) {
    nlist.ilist[center] = center;
    found.clear();
    const FPTYPE* r_center = c_cpy + center * 3;
    for (int other = 0; other < nall; ++other) {
      if (other == center) {
        continue;
      }
      const FPTYPE* r_other = c_cpy + other * 3;
      FPTYPE dr[3];
      for (int ax = 0; ax < 3; ++ax) {
        dr[ax] = r_center[ax] - r_other[ax];
      }
      const FPTYPE dist2 = deepmd::dot3(dr, dr);
      if (dist2 < cutoff2) {
        found.push_back(other);
      }
    }

    // same unsigned comparison the implicit conversion would perform
    if (found.size() > static_cast<std::vector<int>::size_type>(capacity)) {
      *max_list_size = found.size();
      return 1;
    }
    const int n_found = found.size();
    nlist.numneigh[center] = n_found;
    if (n_found > *max_list_size) {
      *max_list_size = n_found;
    }
    std::copy(found.begin(), found.end(), nlist.firstneigh[center]);
  }
  return 0;
}

// Derive per-neighbor type and validity mask from a padded neighbor table.
//
// nlist, ntype and nmask are nloc x nnei tables stored row by row. For every
// slot holding a non-negative atom index the neighbor type is looked up in
// `type` and the mask is set to true; slots with a negative entry get type
// `ntypes` and mask false. When b_nlist_map is true the stored index is first
// translated through nlist_map and the translated index is written back into
// nlist.
void deepmd::use_nei_info_cpu(int* nlist,
                              int* ntype,
                              bool* nmask,
                              const int* type,
                              const int* nlist_map,
                              const int nloc,
                              const int nnei,
                              const int ntypes,
                              const bool b_nlist_map) {
  for (int atom = 0; atom < nloc; ++atom) {
    for (int k = 0; k < nnei; ++k) {
      const int slot = atom * nnei + k;
      const int stored = nlist[slot];
      if (stored < 0) {
        // padding entry
        ntype[slot] = ntypes;
        nmask[slot] = false;
        continue;
      }
      int neighbor = stored;
      if (b_nlist_map) {
        neighbor = nlist_map[stored];
        nlist[slot] = neighbor;
      }
      ntype[slot] = type[neighbor];
      nmask[slot] = true;
    }
  }
}

// Explicit instantiations of build_nlist_cpu: the template is defined in this
// translation unit, so the double and float versions are emitted here.
template int deepmd::build_nlist_cpu<double>(InputNlist& nlist,
                                             int* max_list_size,
                                             const double* c_cpy,
                                             const int& nloc,
                                             const int& nall,
                                             const int& mem_size,
                                             const float& rcut);

template int deepmd::build_nlist_cpu<float>(InputNlist& nlist,
                                            int* max_list_size,
                                            const float* c_cpy,
                                            const int& nloc,
                                            const int& nall,
                                            const int& mem_size,
                                            const float& rcut);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
/**
 * @brief Mirror a host neighbor list on the device.
 *
 * Allocates gpu_nlist.ilist, gpu_nlist.numneigh and gpu_nlist.firstneigh
 * (inum entries each) and copies ilist and numneigh from cpu_nlist. The
 * neighbors of atom ii are copied into gpu_memory at offset
 * ii * max_nbor_size, and gpu_nlist.firstneigh[ii] is set to that address.
 * The arrays allocated here are released by free_nlist_gpu_device.
 *
 * NOTE(review): gpu_memory is assumed to be an existing device buffer with
 * room for inum * max_nbor_size ints, and every numneigh[ii] is assumed not
 * to exceed max_nbor_size - neither is checked here; confirm at the call
 * sites.
 *
 * @param[out] gpu_nlist     device-side list to set up.
 * @param[in]  cpu_nlist     host-side list to copy from.
 * @param[in]  gpu_memory    device buffer receiving the neighbor indices.
 * @param[in]  max_nbor_size stride, in ints, between the neighbor arrays of
 *                           consecutive atoms inside gpu_memory.
 */
void deepmd::convert_nlist_gpu_device(InputNlist& gpu_nlist,
                                      InputNlist& cpu_nlist,
                                      int*& gpu_memory,
                                      const int& max_nbor_size) {
  const int inum = cpu_nlist.inum;
  gpu_nlist.inum = inum;
  malloc_device_memory(gpu_nlist.ilist, inum);
  malloc_device_memory(gpu_nlist.numneigh, inum);
  malloc_device_memory(gpu_nlist.firstneigh, inum);
  memcpy_host_to_device(gpu_nlist.ilist, cpu_nlist.ilist, inum);
  memcpy_host_to_device(gpu_nlist.numneigh, cpu_nlist.numneigh, inum);

  // Host-side staging array of device addresses. A std::vector replaces the
  // unchecked malloc/free pair, so the buffer is released on every exit path.
  std::vector<int*> host_firstneigh(inum > 0 ? inum : 0);
  for (int ii = 0; ii < inum; ii++) {
    int* const device_row = gpu_memory + ii * max_nbor_size;
    memcpy_host_to_device(device_row, cpu_nlist.firstneigh[ii],
                          cpu_nlist.numneigh[ii]);
    host_firstneigh[ii] = device_row;
  }
  memcpy_host_to_device(gpu_nlist.firstneigh, host_firstneigh.data(), inum);
}

/**
 * @brief Release the device arrays of a neighbor list set up by
 *        convert_nlist_gpu_device.
 *
 * Frees ilist, numneigh and firstneigh. The buffer that firstneigh entries
 * point into is not touched. The three pointers are reset afterwards so that
 * the structure does not keep dangling device addresses.
 *
 * @param[in,out] gpu_nlist device-side list whose arrays are released.
 */
void deepmd::free_nlist_gpu_device(InputNlist& gpu_nlist) {
  delete_device_memory(gpu_nlist.ilist);
  delete_device_memory(gpu_nlist.numneigh);
  delete_device_memory(gpu_nlist.firstneigh);
  gpu_nlist.ilist = NULL;
  gpu_nlist.numneigh = NULL;
  gpu_nlist.firstneigh = NULL;
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
