// SPDX-License-Identifier: LGPL-3.0-or-later
#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <vector>

#include "deepmd.hpp"
#include "test_utils.h"

// Typed fixture for multi-frame (nframes = 2) inference with a periodic box,
// using the model converted from ../../tests/infer/deeppot.pbtxt. Both frames
// carry identical coordinates, boxes and reference values.
template <class VALUETYPE>
class TestInferDeepPotANFrames : public ::testing::Test {
 protected:
  // Coordinates of 6 atoms per frame, flattened as [frame][atom][3].
  std::vector<VALUETYPE> coord = {
      12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68,
      3.36,  3.00, 1.81, 3.51,  2.51, 2.60, 4.27,  3.22, 1.56,
      12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68,
      3.36,  3.00, 1.81, 3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
  // Atom types of a single frame; the same list is used for every frame.
  std::vector<int> atype = {0, 1, 1, 0, 1, 1};
  // One 9-value cell per frame; here a cube of edge 13.
  std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.,
                                13., 0., 0., 0., 13., 0., 0., 0., 13.};
  // Reference per-atom energies, flattened as [frame][atom].
  std::vector<VALUETYPE> expected_e = {
      -9.275780747115504710e+01, -1.863501786584258468e+02,
      -1.863392472863538103e+02, -9.279281325486221021e+01,
      -1.863671545232153903e+02, -1.863619822847602165e+02,
      -9.275780747115504710e+01, -1.863501786584258468e+02,
      -1.863392472863538103e+02, -9.279281325486221021e+01,
      -1.863671545232153903e+02, -1.863619822847602165e+02};
  // Reference per-atom forces, flattened as [frame][atom][3].
  std::vector<VALUETYPE> expected_f = {
      -3.034045420701179663e-01, 8.405844663871177014e-01,
      7.696947487118485642e-02,  7.662001266663505117e-01,
      -1.880601391333554251e-01, -6.183333871091722944e-01,
      -5.036172391059643427e-01, -6.529525836149027151e-01,
      5.432962643022043459e-01,  6.382357912332115024e-01,
      -1.748518296794561167e-01, 3.457363524891907125e-01,
      1.286482986991941552e-03,  3.757251165286925043e-01,
      -5.972588700887541124e-01, -5.987006197104716154e-01,
      -2.004450304880958100e-01, 2.495901655353461868e-01,
      -3.034045420701179663e-01, 8.405844663871177014e-01,
      7.696947487118485642e-02,  7.662001266663505117e-01,
      -1.880601391333554251e-01, -6.183333871091722944e-01,
      -5.036172391059643427e-01, -6.529525836149027151e-01,
      5.432962643022043459e-01,  6.382357912332115024e-01,
      -1.748518296794561167e-01, 3.457363524891907125e-01,
      1.286482986991941552e-03,  3.757251165286925043e-01,
      -5.972588700887541124e-01, -5.987006197104716154e-01,
      -2.004450304880958100e-01, 2.495901655353461868e-01};
  // Reference per-atom virials, flattened as [frame][atom][9].
  std::vector<VALUETYPE> expected_v = {
      -2.912234126853306959e-01, -3.800610846612756388e-02,
      2.776624987489437202e-01,  -5.053761003913598976e-02,
      -3.152373041953385746e-01, 1.060894290092162379e-01,
      2.826389131596073745e-01,  1.039129970665329250e-01,
      -2.584378792325942586e-01, -3.121722367954994914e-01,
      8.483275876786681990e-02,  2.524662342344257682e-01,
      4.142176771106586414e-02,  -3.820285230785245428e-02,
      -2.727311173065460545e-02, 2.668859789777112135e-01,
      -6.448243569420382404e-02, -2.121731470426218846e-01,
      -8.624335220278558922e-02, -1.809695356746038597e-01,
      1.529875294531883312e-01,  -1.283658185172031341e-01,
      -1.992682279795223999e-01, 1.409924999632362341e-01,
      1.398322735274434292e-01,  1.804318474574856390e-01,
      -1.470309318999652726e-01, -2.593983661598450730e-01,
      -4.236536279233147489e-02, 3.386387920184946720e-02,
      -4.174017537818433543e-02, -1.003500282164128260e-01,
      1.525690815194478966e-01,  3.398976109910181037e-02,
      1.522253908435125536e-01,  -2.349125581341701963e-01,
      9.515545977581392825e-04,  -1.643218849228543846e-02,
      1.993234765412972564e-02,  6.027265332209678569e-04,
      -9.563256398907417355e-02, 1.510815124001868293e-01,
      -7.738094816888557714e-03, 1.502832772532304295e-01,
      -2.380965783745832010e-01, -2.309456719810296654e-01,
      -6.666961081213038098e-02, 7.955566551234216632e-02,
      -8.099093777937517447e-02, -3.386641099800401927e-02,
      4.447884755740908608e-02,  1.008593228579038742e-01,
      4.556718179228393811e-02,  -6.078081273849572641e-02,
      -2.912234126853306959e-01, -3.800610846612756388e-02,
      2.776624987489437202e-01,  -5.053761003913598976e-02,
      -3.152373041953385746e-01, 1.060894290092162379e-01,
      2.826389131596073745e-01,  1.039129970665329250e-01,
      -2.584378792325942586e-01, -3.121722367954994914e-01,
      8.483275876786681990e-02,  2.524662342344257682e-01,
      4.142176771106586414e-02,  -3.820285230785245428e-02,
      -2.727311173065460545e-02, 2.668859789777112135e-01,
      -6.448243569420382404e-02, -2.121731470426218846e-01,
      -8.624335220278558922e-02, -1.809695356746038597e-01,
      1.529875294531883312e-01,  -1.283658185172031341e-01,
      -1.992682279795223999e-01, 1.409924999632362341e-01,
      1.398322735274434292e-01,  1.804318474574856390e-01,
      -1.470309318999652726e-01, -2.593983661598450730e-01,
      -4.236536279233147489e-02, 3.386387920184946720e-02,
      -4.174017537818433543e-02, -1.003500282164128260e-01,
      1.525690815194478966e-01,  3.398976109910181037e-02,
      1.522253908435125536e-01,  -2.349125581341701963e-01,
      9.515545977581392825e-04,  -1.643218849228543846e-02,
      1.993234765412972564e-02,  6.027265332209678569e-04,
      -9.563256398907417355e-02, 1.510815124001868293e-01,
      -7.738094816888557714e-03, 1.502832772532304295e-01,
      -2.380965783745832010e-01, -2.309456719810296654e-01,
      -6.666961081213038098e-02, 7.955566551234216632e-02,
      -8.099093777937517447e-02, -3.386641099800401927e-02,
      4.447884755740908608e-02,  1.008593228579038742e-01,
      4.556718179228393811e-02,  -6.078081273849572641e-02};
  // Atoms per frame; derived from expected_e in SetUp().
  int natoms = 0;
  int nframes = 2;
  // Per-frame totals reduced from the per-atom references in SetUp().
  std::vector<double> expected_tot_e;
  std::vector<VALUETYPE> expected_tot_v;

  deepmd::hpp::DeepPot dp;

  // Converts the text-format model to "deeppot.pb", loads it, and precomputes
  // the per-frame energy and virial totals.
  void SetUp() override {
    std::string file_name = "../../tests/infer/deeppot.pbtxt";
    deepmd::hpp::convert_pbtxt_to_pb(file_name, "deeppot.pb");

    dp.init("deeppot.pb");

    natoms = static_cast<int>(expected_e.size()) / nframes;
    // Guard the reference tables: forces carry 3 and virials 9 components
    // per atom per frame.
    EXPECT_EQ(static_cast<size_t>(nframes) * natoms * 3, expected_f.size());
    EXPECT_EQ(static_cast<size_t>(nframes) * natoms * 9, expected_v.size());
    expected_tot_e.resize(nframes);
    expected_tot_v.resize(static_cast<size_t>(nframes) * 9);
    std::fill(expected_tot_e.begin(), expected_tot_e.end(), 0.);
    std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
    for (int kk = 0; kk < nframes; ++kk) {
      for (int ii = 0; ii < natoms; ++ii) {
        expected_tot_e[kk] += expected_e[kk * natoms + ii];
      }
      for (int ii = 0; ii < natoms; ++ii) {
        for (int dd = 0; dd < 9; ++dd) {
          expected_tot_v[kk * 9 + dd] +=
              expected_v[kk * natoms * 9 + ii * 9 + dd];
        }
      }
    }
  }

  // Deletes the binary model written by SetUp().
  void TearDown() override { remove("deeppot.pb"); }
};

TYPED_TEST_SUITE(TestInferDeepPotANFrames, ValueTypes);

// Runs inference on all frames at once and lets DeepPot build the neighbor
// list from coord and box. Checks the per-frame energies and virials and the
// per-atom forces against the fixture references.
TYPED_TEST(TestInferDeepPotANFrames, cpu_build_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  int& natoms = this->natoms;
  int& nframes = this->nframes;
  std::vector<double>& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::hpp::DeepPot& dp = this->dp;
  std::vector<double> ener;
  std::vector<VALUETYPE> force, virial;
  dp.compute(ener, force, virial, coord, atype, box);

  // Fatal size checks: the element-wise loops below index these vectors, so
  // continuing after a size mismatch would read out of bounds.
  ASSERT_EQ(ener.size(), static_cast<size_t>(nframes));
  ASSERT_EQ(force.size(), static_cast<size_t>(nframes) * natoms * 3);
  ASSERT_EQ(virial.size(), static_cast<size_t>(nframes) * 9);

  for (int ii = 0; ii < nframes; ++ii) {
    EXPECT_LT(fabs(ener[ii] - expected_tot_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

// Whole-system inference with per-atom outputs. DeepPot builds its own
// neighbor list from the fixture coordinates and box. The per-frame totals and
// the per-atom forces, energies and virials are compared with the references.
TYPED_TEST(TestInferDeepPotANFrames, cpu_build_nlist_atomic) {
  using VALUETYPE = TypeParam;
  const int nf = this->nframes;
  const int na = this->natoms;
  const std::vector<double>& ref_tot_e = this->expected_tot_e;
  const std::vector<VALUETYPE>& ref_tot_v = this->expected_tot_v;
  const std::vector<VALUETYPE>& ref_f = this->expected_f;
  const std::vector<VALUETYPE>& ref_e = this->expected_e;
  const std::vector<VALUETYPE>& ref_v = this->expected_v;
  deepmd::hpp::DeepPot& model = this->dp;

  std::vector<double> tot_ener;
  std::vector<VALUETYPE> frc, tot_vir, at_ener, at_vir;
  model.compute(tot_ener, frc, tot_vir, at_ener, at_vir, this->coord,
                this->atype, this->box);

  // Output shapes.
  EXPECT_EQ(tot_ener.size(), nf);
  EXPECT_EQ(frc.size(), nf * na * 3);
  EXPECT_EQ(tot_vir.size(), nf * 9);
  EXPECT_EQ(at_ener.size(), nf * na);
  EXPECT_EQ(at_vir.size(), nf * na * 9);

  // Per-frame energy.
  for (int kk = 0; kk < nf; ++kk) {
    EXPECT_LT(fabs(tot_ener[kk] - ref_tot_e[kk]), EPSILON);
  }
  // Per-atom force components.
  for (int kk = 0; kk < nf * na * 3; ++kk) {
    EXPECT_LT(fabs(frc[kk] - ref_f[kk]), EPSILON);
  }
  // Per-frame virial components.
  for (int kk = 0; kk < nf * 9; ++kk) {
    EXPECT_LT(fabs(tot_vir[kk] - ref_tot_v[kk]), EPSILON);
  }
  // Per-atom energy.
  for (int kk = 0; kk < nf * na; ++kk) {
    EXPECT_LT(fabs(at_ener[kk] - ref_e[kk]), EPSILON);
  }
  // Per-atom virial components.
  for (int kk = 0; kk < nf * na * 9; ++kk) {
    EXPECT_LT(fabs(at_vir[kk] - ref_v[kk]), EPSILON);
  }
}

// Runs inference through the explicit-neighbor-list entry point. The extended
// system (nloc local atoms followed by ghost atoms) and its neighbor list are
// built from the first frame with cutoff rc. The extended coordinates are then
// replicated for every frame. compute() is called twice, with a trailing
// argument of 0 and then 1, and both results must match the references. The
// trailing argument is presumably the `ago` flag telling DeepPot whether the
// neighbor list changed; confirm against deepmd.hpp.
TYPED_TEST(TestInferDeepPotANFrames, cpu_lmp_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  int& natoms = this->natoms;
  int& nframes = this->nframes;
  std::vector<double>& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::hpp::DeepPot& dp = this->dp;
  float rc = dp.cutoff();
  std::vector<VALUETYPE> coord_first(coord.begin(), coord.begin() + 3 * natoms);
  std::vector<VALUETYPE> box_first(box.begin(), box.begin() + 9);
  int nloc = coord_first.size() / 3;
  std::vector<VALUETYPE> coord_cpy;
  std::vector<int> atype_cpy, mapping;
  std::vector<std::vector<int> > nlist_data;
  _build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping,
                          coord_first, atype, box_first, rc);
  int nall = coord_cpy.size() / 3;
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::hpp::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);
  // Every frame reuses the same extended coordinates.
  std::vector<VALUETYPE> coord_cpy2(static_cast<size_t>(nframes) * nall * 3);
  for (int ii = 0; ii < nframes; ++ii) {
    for (int jj = 0; jj < nall * 3; ++jj) {
      coord_cpy2[ii * nall * 3 + jj] = coord_cpy[jj];
    }
  }

  std::vector<double> ener;
  std::vector<VALUETYPE> force_, virial, force;

  // Folds ghost-atom forces back onto their local owners, then checks output
  // sizes and values. ASSERT_* only returns from this lambda, so a size
  // mismatch skips the element loops (avoiding out-of-bounds reads) without
  // aborting the remaining compute() call.
  auto fold_back_and_check = [&]() {
    _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3, nframes);

    ASSERT_EQ(ener.size(), static_cast<size_t>(nframes));
    ASSERT_EQ(force.size(), static_cast<size_t>(nframes) * natoms * 3);
    ASSERT_EQ(virial.size(), static_cast<size_t>(nframes) * 9);

    for (int ii = 0; ii < nframes; ++ii) {
      EXPECT_LT(fabs(ener[ii] - expected_tot_e[ii]), EPSILON);
    }
    for (int ii = 0; ii < nframes * natoms * 3; ++ii) {
      EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
    }
    for (int ii = 0; ii < nframes * 3 * 3; ++ii) {
      EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
    }
  };

  dp.compute(ener, force_, virial, coord_cpy2, atype_cpy, box, nall - nloc,
             inlist, 0);
  fold_back_and_check();

  // Zero the outputs so stale values from the first call cannot mask a
  // failure of the second one.
  std::fill(ener.begin(), ener.end(), 0.0);
  std::fill(force_.begin(), force_.end(), 0.0);
  std::fill(virial.begin(), virial.end(), 0.0);
  dp.compute(ener, force_, virial, coord_cpy2, atype_cpy, box, nall - nloc,
             inlist, 1);
  fold_back_and_check();
}

// Explicit-neighbor-list inference with per-atom outputs. The first frame is
// expanded into local + ghost atoms with a neighbor list (cutoff rc), and the
// expanded coordinates are stacked once per frame. After each compute() call
// (trailing argument 0, then 1), ghost contributions are folded back onto the
// local atoms and every output is compared with the fixture references.
TYPED_TEST(TestInferDeepPotANFrames, cpu_lmp_nlist_atomic) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& ref_e = this->expected_e;
  const std::vector<VALUETYPE>& ref_f = this->expected_f;
  const std::vector<VALUETYPE>& ref_v = this->expected_v;
  const std::vector<double>& ref_tot_e = this->expected_tot_e;
  const std::vector<VALUETYPE>& ref_tot_v = this->expected_tot_v;
  const int natoms = this->natoms;
  const int nframes = this->nframes;
  deepmd::hpp::DeepPot& dp = this->dp;
  float rc = dp.cutoff();

  // Extended system and neighbor list, built from frame 0 only.
  std::vector<VALUETYPE> frame0_coord(coord.begin(),
                                      coord.begin() + 3 * natoms);
  std::vector<VALUETYPE> frame0_box(box.begin(), box.begin() + 9);
  const int nloc = frame0_coord.size() / 3;
  std::vector<VALUETYPE> ext_coord;
  std::vector<int> ext_atype, mapping;
  std::vector<std::vector<int> > nlist_rows;
  _build_nlist<VALUETYPE>(nlist_rows, ext_coord, ext_atype, mapping,
                          frame0_coord, atype, frame0_box, rc);
  const int nall = ext_coord.size() / 3;
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::hpp::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_rows);

  // Stack the extended coordinates once per frame.
  std::vector<VALUETYPE> stacked_coord;
  stacked_coord.reserve(static_cast<size_t>(nframes) * nall * 3);
  for (int frame = 0; frame < nframes; ++frame) {
    stacked_coord.insert(stacked_coord.end(), ext_coord.begin(),
                         ext_coord.end());
  }

  std::vector<double> ener;
  std::vector<VALUETYPE> force_, atom_ener_, atom_vir_, virial;
  std::vector<VALUETYPE> force, atom_ener, atom_vir;

  // Folds ghost contributions back onto local atoms and verifies everything.
  auto fold_and_verify = [&]() {
    _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3, nframes);
    _fold_back<VALUETYPE>(atom_ener, atom_ener_, mapping, nloc, nall, 1,
                          nframes);
    _fold_back<VALUETYPE>(atom_vir, atom_vir_, mapping, nloc, nall, 9,
                          nframes);

    EXPECT_EQ(ener.size(), nframes);
    EXPECT_EQ(force.size(), nframes * natoms * 3);
    EXPECT_EQ(virial.size(), nframes * 9);
    EXPECT_EQ(atom_ener.size(), nframes * natoms);
    EXPECT_EQ(atom_vir.size(), nframes * natoms * 9);

    for (int kk = 0; kk < nframes; ++kk) {
      EXPECT_LT(fabs(ener[kk] - ref_tot_e[kk]), EPSILON);
    }
    for (int kk = 0; kk < nframes * natoms * 3; ++kk) {
      EXPECT_LT(fabs(force[kk] - ref_f[kk]), EPSILON);
    }
    for (int kk = 0; kk < nframes * 9; ++kk) {
      EXPECT_LT(fabs(virial[kk] - ref_tot_v[kk]), EPSILON);
    }
    for (int kk = 0; kk < nframes * natoms; ++kk) {
      EXPECT_LT(fabs(atom_ener[kk] - ref_e[kk]), EPSILON);
    }
    for (int kk = 0; kk < nframes * natoms * 9; ++kk) {
      EXPECT_LT(fabs(atom_vir[kk] - ref_v[kk]), EPSILON);
    }
  };

  dp.compute(ener, force_, virial, atom_ener_, atom_vir_, stacked_coord,
             ext_atype, box, nall - nloc, inlist, 0);
  fold_and_verify();

  // Reset all raw outputs, then run again with the trailing argument set to 1.
  std::fill(ener.begin(), ener.end(), 0.0);
  std::fill(force_.begin(), force_.end(), 0.0);
  std::fill(virial.begin(), virial.end(), 0.0);
  std::fill(atom_ener_.begin(), atom_ener_.end(), 0.0);
  std::fill(atom_vir_.begin(), atom_vir_.end(), 0.0);
  dp.compute(ener, force_, virial, atom_ener_, atom_vir_, stacked_coord,
             ext_atype, box, nall - nloc, inlist, 1);
  fold_and_verify();
}

// Same as cpu_lmp_nlist, but the neighbor list is built with twice the model
// cutoff (rc * 2). The results must still match the references, which checks
// that neighbors beyond the model cutoff do not change the outputs. compute()
// is called with a trailing argument of 0 and then 1 (presumably the `ago`
// flag; confirm against deepmd.hpp).
TYPED_TEST(TestInferDeepPotANFrames, cpu_lmp_nlist_2rc) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  int& natoms = this->natoms;
  int& nframes = this->nframes;
  std::vector<double>& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::hpp::DeepPot& dp = this->dp;
  float rc = dp.cutoff();
  std::vector<VALUETYPE> coord_first(coord.begin(), coord.begin() + 3 * natoms);
  std::vector<VALUETYPE> box_first(box.begin(), box.begin() + 9);
  int nloc = coord_first.size() / 3;
  std::vector<VALUETYPE> coord_cpy;
  std::vector<int> atype_cpy, mapping;
  std::vector<std::vector<int> > nlist_data;
  _build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping,
                          coord_first, atype, box_first, rc * 2);
  int nall = coord_cpy.size() / 3;
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::hpp::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);
  // Every frame reuses the same extended coordinates.
  std::vector<VALUETYPE> coord_cpy2(static_cast<size_t>(nframes) * nall * 3);
  for (int ii = 0; ii < nframes; ++ii) {
    for (int jj = 0; jj < nall * 3; ++jj) {
      coord_cpy2[ii * nall * 3 + jj] = coord_cpy[jj];
    }
  }

  // compute() sizes its outputs itself, so they start empty.
  std::vector<double> ener;
  std::vector<VALUETYPE> force_, virial, force;

  // Folds ghost-atom forces back onto their local owners, then checks output
  // sizes and values. ASSERT_* only returns from this lambda, so a size
  // mismatch skips the element loops instead of reading out of bounds.
  auto fold_back_and_check = [&]() {
    _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3, nframes);

    ASSERT_EQ(ener.size(), static_cast<size_t>(nframes));
    ASSERT_EQ(force.size(), static_cast<size_t>(nframes) * natoms * 3);
    ASSERT_EQ(virial.size(), static_cast<size_t>(nframes) * 9);

    for (int ii = 0; ii < nframes; ++ii) {
      EXPECT_LT(fabs(ener[ii] - expected_tot_e[ii]), EPSILON);
    }
    for (int ii = 0; ii < nframes * natoms * 3; ++ii) {
      EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
    }
    for (int ii = 0; ii < nframes * 3 * 3; ++ii) {
      EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
    }
  };

  dp.compute(ener, force_, virial, coord_cpy2, atype_cpy, box, nall - nloc,
             inlist, 0);
  fold_back_and_check();

  // Zero the outputs so stale values from the first call cannot mask a
  // failure of the second one.
  std::fill(ener.begin(), ener.end(), 0.0);
  std::fill(force_.begin(), force_.end(), 0.0);
  std::fill(virial.begin(), virial.end(), 0.0);
  dp.compute(ener, force_, virial, coord_cpy2, atype_cpy, box, nall - nloc,
             inlist, 1);
  fold_back_and_check();
}

// Type-selection test. nvir extra "virtual" atoms of type 2 are prepended to
// the first frame. The total energies and virials must still equal the
// references, and the virtual atoms must receive zero force. This presumes
// that type 2 is not handled by the model; confirm against the model's type
// map.
TYPED_TEST(TestInferDeepPotANFrames, cpu_lmp_nlist_type_sel) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  int& natoms = this->natoms;
  int& nframes = this->nframes;
  std::vector<double>& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::hpp::DeepPot& dp = this->dp;
  float rc = dp.cutoff();

  // Add virtual atoms that sit on the positions of the first nvir atoms.
  // All three components of each position are copied.
  const int nvir = 2;
  std::vector<VALUETYPE> coord_vir(coord.begin(), coord.begin() + nvir * 3);
  std::vector<int> atype_vir(nvir, 2);
  coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end());
  atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end());
  natoms += nvir;
  // Virtual atoms are expected to feel no force; insert zeros in both frames
  // of the [frame][atom][3] reference.
  std::vector<VALUETYPE> expected_f_vir(nvir * 3, 0.0);
  expected_f.insert(expected_f.begin(), expected_f_vir.begin(),
                    expected_f_vir.end());
  expected_f.insert(expected_f.begin() + natoms * 3, expected_f_vir.begin(),
                    expected_f_vir.end());
  std::vector<VALUETYPE> coord_first(coord.begin(), coord.begin() + 3 * natoms);
  std::vector<VALUETYPE> box_first(box.begin(), box.begin() + 9);

  // Build the neighbor list of the (extended) first frame.
  int nloc = coord_first.size() / 3;
  std::vector<VALUETYPE> coord_cpy;
  std::vector<int> atype_cpy, mapping;
  std::vector<std::vector<int> > nlist_data;
  _build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping,
                          coord_first, atype, box_first, rc);
  int nall = coord_cpy.size() / 3;
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::hpp::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);
  // Every frame reuses the same extended coordinates.
  std::vector<VALUETYPE> coord_cpy2(static_cast<size_t>(nframes) * nall * 3);
  for (int ii = 0; ii < nframes; ++ii) {
    for (int jj = 0; jj < nall * 3; ++jj) {
      coord_cpy2[ii * nall * 3 + jj] = coord_cpy[jj];
    }
  }

  // dp compute; the outputs are sized by compute() itself.
  std::vector<double> ener;
  std::vector<VALUETYPE> force_, virial;
  dp.compute(ener, force_, virial, coord_cpy2, atype_cpy, box, nall - nloc,
             inlist, 0);
  // Fold ghost-atom forces back onto their local owners.
  std::vector<VALUETYPE> force;
  _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3, nframes);

  // Fatal size checks guard the element-wise loops below.
  ASSERT_EQ(ener.size(), static_cast<size_t>(nframes));
  ASSERT_EQ(force.size(), static_cast<size_t>(nframes) * natoms * 3);
  ASSERT_EQ(virial.size(), static_cast<size_t>(nframes) * 9);

  for (int ii = 0; ii < nframes; ++ii) {
    EXPECT_LT(fabs(ener[ii] - expected_tot_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

// Type-selection test through the per-atom compute() overload. nvir extra
// "virtual" atoms of type 2 are prepended to the first frame. The total
// energies and virials must still equal the references, and the virtual atoms
// must receive zero force. This presumes that type 2 is not handled by the
// model; confirm against the model's type map.
TYPED_TEST(TestInferDeepPotANFrames, cpu_lmp_nlist_type_sel_atomic) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  int& natoms = this->natoms;
  int& nframes = this->nframes;
  std::vector<double>& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::hpp::DeepPot& dp = this->dp;
  float rc = dp.cutoff();

  // Add virtual atoms that sit on the positions of the first nvir atoms.
  // All three components of each position are copied.
  const int nvir = 2;
  std::vector<VALUETYPE> coord_vir(coord.begin(), coord.begin() + nvir * 3);
  std::vector<int> atype_vir(nvir, 2);
  coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end());
  atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end());
  natoms += nvir;
  // Virtual atoms are expected to feel no force; insert zeros in both frames
  // of the [frame][atom][3] reference.
  std::vector<VALUETYPE> expected_f_vir(nvir * 3, 0.0);
  expected_f.insert(expected_f.begin(), expected_f_vir.begin(),
                    expected_f_vir.end());
  expected_f.insert(expected_f.begin() + natoms * 3, expected_f_vir.begin(),
                    expected_f_vir.end());
  std::vector<VALUETYPE> coord_first(coord.begin(), coord.begin() + 3 * natoms);
  std::vector<VALUETYPE> box_first(box.begin(), box.begin() + 9);

  // Build the neighbor list of the (extended) first frame.
  int nloc = coord_first.size() / 3;
  std::vector<VALUETYPE> coord_cpy;
  std::vector<int> atype_cpy, mapping;
  std::vector<std::vector<int> > nlist_data;
  _build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping,
                          coord_first, atype, box_first, rc);
  int nall = coord_cpy.size() / 3;
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::hpp::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);
  // Every frame reuses the same extended coordinates.
  std::vector<VALUETYPE> coord_cpy2(static_cast<size_t>(nframes) * nall * 3);
  for (int ii = 0; ii < nframes; ++ii) {
    for (int jj = 0; jj < nall * 3; ++jj) {
      coord_cpy2[ii * nall * 3 + jj] = coord_cpy[jj];
    }
  }

  // dp compute; the outputs are sized by compute() itself.
  // NOTE(review): atomic_energy / atomic_virial are requested but never
  // compared against references. TODO: fold them back and check them; the
  // virtual atoms presumably receive zero values, to be confirmed.
  std::vector<double> ener;
  std::vector<VALUETYPE> force_, virial, atomic_energy, atomic_virial;
  dp.compute(ener, force_, virial, atomic_energy, atomic_virial, coord_cpy2,
             atype_cpy, box, nall - nloc, inlist, 0);
  // Fold ghost-atom forces back onto their local owners.
  std::vector<VALUETYPE> force;
  _fold_back<VALUETYPE>(force, force_, mapping, nloc, nall, 3, nframes);

  // Fatal size checks guard the element-wise loops below.
  ASSERT_EQ(ener.size(), static_cast<size_t>(nframes));
  ASSERT_EQ(force.size(), static_cast<size_t>(nframes) * natoms * 3);
  ASSERT_EQ(virial.size(), static_cast<size_t>(nframes) * 9);

  for (int ii = 0; ii < nframes; ++ii) {
    EXPECT_LT(fabs(ener[ii] - expected_tot_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

// Typed fixture for multi-frame (nframes = 2) inference without periodic
// boundary conditions: the box is left empty. Uses the model converted from
// ../../tests/infer/deeppot.pbtxt. Both frames are identical copies.
template <class VALUETYPE>
class TestInferDeepPotANFramesNoPbc : public ::testing::Test {
 protected:
  // Coordinates of 6 atoms per frame, flattened as [frame][atom][3].
  std::vector<VALUETYPE> coord = {
      12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68,
      3.36,  3.00, 1.81, 3.51,  2.51, 2.60, 4.27,  3.22, 1.56,
      12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68,
      3.36,  3.00, 1.81, 3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
  // Atom types of a single frame, shared by all frames.
  std::vector<int> atype = {0, 1, 1, 0, 1, 1};
  // Empty box: no periodic cell is supplied.
  std::vector<VALUETYPE> box = {};
  // Reference per-atom energies, [frame][atom].
  std::vector<VALUETYPE> expected_e = {
      -9.255934839310273787e+01, -1.863253376736990106e+02,
      -1.857237299341402945e+02, -9.279308539717486326e+01,
      -1.863708105823244239e+02, -1.863635196514972563e+02,
      -9.255934839310273787e+01, -1.863253376736990106e+02,
      -1.857237299341402945e+02, -9.279308539717486326e+01,
      -1.863708105823244239e+02, -1.863635196514972563e+02};
  // Reference per-atom forces, [frame][atom][3].
  std::vector<VALUETYPE> expected_f = {
      -2.161037360255332107e+00, 9.052994347015581589e-01,
      1.635379623977007979e+00,  2.161037360255332107e+00,
      -9.052994347015581589e-01, -1.635379623977007979e+00,
      -1.167128117249453811e-02, 1.371975700096064992e-03,
      -1.575265180249604477e-03, 6.226508593971802341e-01,
      -1.816734122009256991e-01, 3.561766019664774907e-01,
      -1.406075393906316626e-02, 3.789140061530929526e-01,
      -6.018777878642909140e-01, -5.969188242856223736e-01,
      -1.986125696522633155e-01, 2.472764510780630642e-01,
      -2.161037360255332107e+00, 9.052994347015581589e-01,
      1.635379623977007979e+00,  2.161037360255332107e+00,
      -9.052994347015581589e-01, -1.635379623977007979e+00,
      -1.167128117249453811e-02, 1.371975700096064992e-03,
      -1.575265180249604477e-03, 6.226508593971802341e-01,
      -1.816734122009256991e-01, 3.561766019664774907e-01,
      -1.406075393906316626e-02, 3.789140061530929526e-01,
      -6.018777878642909140e-01, -5.969188242856223736e-01,
      -1.986125696522633155e-01, 2.472764510780630642e-01};
  // Reference per-atom virials, [frame][atom][9].
  std::vector<VALUETYPE> expected_v = {
      -7.042445481792056761e-01, 2.950213647777754078e-01,
      5.329418202437231633e-01,  2.950213647777752968e-01,
      -1.235900311906896754e-01, -2.232594111831812944e-01,
      5.329418202437232743e-01,  -2.232594111831813499e-01,
      -4.033073234276823849e-01, -8.949230984097404917e-01,
      3.749002169013777030e-01,  6.772391014992630298e-01,
      3.749002169013777586e-01,  -1.570527935667933583e-01,
      -2.837082722496912512e-01, 6.772391014992631408e-01,
      -2.837082722496912512e-01, -5.125052659994422388e-01,
      4.858210330291591605e-02,  -6.902596153269104431e-03,
      6.682612642430500391e-03,  -5.612247004554610057e-03,
      9.767795567660207592e-04,  -9.773758942738038254e-04,
      5.638322117219018645e-03,  -9.483806049779926932e-04,
      8.493873281881353637e-04,  -2.941738570564985666e-01,
      -4.482529909499673171e-02, 4.091569840186781021e-02,
      -4.509020615859140463e-02, -1.013919988807244071e-01,
      1.551440772665269030e-01,  4.181857726606644232e-02,
      1.547200233064863484e-01,  -2.398213304685777592e-01,
      -3.218625798524068354e-02, -1.012438450438508421e-02,
      1.271639330380921855e-02,  3.072814938490859779e-03,
      -9.556241797915024372e-02, 1.512251983492413077e-01,
      -8.277872384009607454e-03, 1.505412040827929787e-01,
      -2.386150620881526407e-01, -2.312295470054945568e-01,
      -6.631490213524345034e-02, 7.932427266386249398e-02,
      -8.053754366323923053e-02, -3.294595881137418747e-02,
      4.342495071150231922e-02,  1.004599500126941436e-01,
      4.450400364869536163e-02,  -5.951077548033092968e-02,
      -7.042445481792056761e-01, 2.950213647777754078e-01,
      5.329418202437231633e-01,  2.950213647777752968e-01,
      -1.235900311906896754e-01, -2.232594111831812944e-01,
      5.329418202437232743e-01,  -2.232594111831813499e-01,
      -4.033073234276823849e-01, -8.949230984097404917e-01,
      3.749002169013777030e-01,  6.772391014992630298e-01,
      3.749002169013777586e-01,  -1.570527935667933583e-01,
      -2.837082722496912512e-01, 6.772391014992631408e-01,
      -2.837082722496912512e-01, -5.125052659994422388e-01,
      4.858210330291591605e-02,  -6.902596153269104431e-03,
      6.682612642430500391e-03,  -5.612247004554610057e-03,
      9.767795567660207592e-04,  -9.773758942738038254e-04,
      5.638322117219018645e-03,  -9.483806049779926932e-04,
      8.493873281881353637e-04,  -2.941738570564985666e-01,
      -4.482529909499673171e-02, 4.091569840186781021e-02,
      -4.509020615859140463e-02, -1.013919988807244071e-01,
      1.551440772665269030e-01,  4.181857726606644232e-02,
      1.547200233064863484e-01,  -2.398213304685777592e-01,
      -3.218625798524068354e-02, -1.012438450438508421e-02,
      1.271639330380921855e-02,  3.072814938490859779e-03,
      -9.556241797915024372e-02, 1.512251983492413077e-01,
      -8.277872384009607454e-03, 1.505412040827929787e-01,
      -2.386150620881526407e-01, -2.312295470054945568e-01,
      -6.631490213524345034e-02, 7.932427266386249398e-02,
      -8.053754366323923053e-02, -3.294595881137418747e-02,
      4.342495071150231922e-02,  1.004599500126941436e-01,
      4.450400364869536163e-02,  -5.951077548033092968e-02};
  // Atoms per frame, derived in SetUp().
  int natoms;
  int nframes = 2;
  // Per-frame totals, reduced from the per-atom references in SetUp().
  std::vector<double> expected_tot_e;
  std::vector<VALUETYPE> expected_tot_v;

  deepmd::hpp::DeepPot dp;

  void SetUp() override {
    // Turn the text-format model into a binary graph and load it.
    std::string pbtxt_path = "../../tests/infer/deeppot.pbtxt";
    deepmd::hpp::convert_pbtxt_to_pb(pbtxt_path, "deeppot.pb");
    dp.init("deeppot.pb");

    // Sanity-check the reference tables against the frame count.
    natoms = expected_e.size() / nframes;
    EXPECT_EQ(nframes * natoms * 3, expected_f.size());
    EXPECT_EQ(nframes * natoms * 9, expected_v.size());

    // Reduce per-atom references to per-frame totals. The flat index
    // idx = frame * natoms + atom visits every atom of every frame in order.
    expected_tot_e.assign(nframes, 0.);
    expected_tot_v.assign(static_cast<size_t>(nframes) * 9, 0.);
    for (int idx = 0; idx < nframes * natoms; ++idx) {
      const int frame = idx / natoms;
      expected_tot_e[frame] += expected_e[idx];
      for (int dd = 0; dd < 9; ++dd) {
        expected_tot_v[frame * 9 + dd] += expected_v[idx * 9 + dd];
      }
    }
  }

  // Delete the binary graph written by SetUp().
  void TearDown() override { remove("deeppot.pb"); }
};

TYPED_TEST_SUITE(TestInferDeepPotANFramesNoPbc, ValueTypes);

// Non-periodic multi-frame inference: the box is empty and DeepPot builds the
// neighbor list itself. Checks the per-frame energies and virials and the
// per-atom forces against the fixture references.
TYPED_TEST(TestInferDeepPotANFramesNoPbc, cpu_build_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  int& natoms = this->natoms;
  int& nframes = this->nframes;
  std::vector<double>& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::hpp::DeepPot& dp = this->dp;
  std::vector<double> ener;
  std::vector<VALUETYPE> force, virial;
  dp.compute(ener, force, virial, coord, atype, box);

  // Fatal size checks: the element-wise loops below index these vectors, so
  // continuing after a size mismatch would read out of bounds.
  ASSERT_EQ(ener.size(), static_cast<size_t>(nframes));
  ASSERT_EQ(force.size(), static_cast<size_t>(nframes) * natoms * 3);
  ASSERT_EQ(virial.size(), static_cast<size_t>(nframes) * 9);

  for (int ii = 0; ii < nframes; ++ii) {
    EXPECT_LT(fabs(ener[ii] - expected_tot_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < nframes * 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}
