# SPDX-License-Identifier: LGPL-3.0-or-later
import os

import numpy as np

import deepmd.tf.op  # noqa: F401
from deepmd.tf.common import (
    make_default_mesh,
    select_idx_map,
)
from deepmd.tf.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    op_module,
    tf,
)
from deepmd.tf.infer.deep_dipole import DeepDipoleOld as DeepDipole
from deepmd.tf.infer.ewald_recp import (
    EwaldRecp,
)
from deepmd.tf.utils.data import (
    DeepmdData,
)
from deepmd.tf.utils.sess import (
    run_sess,
)


class DipoleChargeModifier(DeepDipole):
    """Parameters
    ----------
    model_name
            The model file for the DeepDipole model
    model_charge_map
            Gives the amount of charge for the wfcc
    sys_charge_map
            Gives the amount of charge for the real atoms
    ewald_h
            Grid spacing of the reciprocal part of Ewald sum. Unit: A
    ewald_beta
            Splitting parameter of the Ewald sum. Unit: A^{-1}
    """

    def __init__(
        self,
        model_name: str,
        model_charge_map: list[float],
        sys_charge_map: list[float],
        ewald_h: float = 1,
        ewald_beta: float = 1,
    ) -> None:
        """Constructor."""
        # the dipole model is loaded with prefix 'dipole_charge'
        self.modifier_prefix = "dipole_charge"
        # init dipole model
        DeepDipole.__init__(
            self, model_name, load_prefix=self.modifier_prefix, default_tf_graph=True
        )
        self.model_name = model_name
        self.model_charge_map = model_charge_map
        self.sys_charge_map = sys_charge_map
        self.sel_type = list(self.get_sel_type())
        # init ewald recp
        self.ewald_h = ewald_h
        self.ewald_beta = ewald_beta
        self.er = EwaldRecp(self.ewald_h, self.ewald_beta)
        # dimension of dipole
        self.ext_dim = 3
        self.t_ndesc = self.graph.get_tensor_by_name(
            os.path.join(self.modifier_prefix, "descrpt_attr/ndescrpt:0")
        )
        self.t_sela = self.graph.get_tensor_by_name(
            os.path.join(self.modifier_prefix, "descrpt_attr/sel:0")
        )
        [self.ndescrpt, self.sel_a] = run_sess(self.sess, [self.t_ndesc, self.t_sela])
        self.sel_r = [0 for ii in range(len(self.sel_a))]
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        assert self.ndescrpt == self.ndescrpt_a + self.ndescrpt_r
        self.force = None
        self.ntypes = len(self.sel_a)

    def build_fv_graph(self) -> tf.Tensor:
        """Build the computational graph for the force and virial inference."""
        with tf.variable_scope("modifier_attr"):
            t_mdl_name = tf.constant(self.model_name, name="mdl_name", dtype=tf.string)
            t_modi_type = tf.constant(
                self.modifier_prefix, name="type", dtype=tf.string
            )
            t_mdl_charge_map = tf.constant(
                " ".join([str(ii) for ii in self.model_charge_map]),
                name="mdl_charge_map",
                dtype=tf.string,
            )
            t_sys_charge_map = tf.constant(
                " ".join([str(ii) for ii in self.sys_charge_map]),
                name="sys_charge_map",
                dtype=tf.string,
            )
            t_ewald_h = tf.constant(self.ewald_h, name="ewald_h", dtype=tf.float64)
            t_ewald_b = tf.constant(
                self.ewald_beta, name="ewald_beta", dtype=tf.float64
            )
        with self.graph.as_default():
            return self._build_fv_graph_inner()

    def _build_fv_graph_inner(self):
        self.t_ef = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_ef")
        nf = 10
        nfxnas = 64 * nf
        nfxna = 192 * nf
        nf = -1
        nfxnas = -1
        nfxna = -1
        self.t_box_reshape = tf.reshape(self.t_box, [-1, 9])
        t_nframes = tf.shape(self.t_box_reshape)[0]

        # (nframes x natoms) x ndescrpt
        self.descrpt = self.graph.get_tensor_by_name(
            os.path.join(self.modifier_prefix, "o_rmat:0")
        )
        self.descrpt_deriv = self.graph.get_tensor_by_name(
            os.path.join(self.modifier_prefix, "o_rmat_deriv:0")
        )
        self.nlist = self.graph.get_tensor_by_name(
            os.path.join(self.modifier_prefix, "o_nlist:0")
        )
        self.rij = self.graph.get_tensor_by_name(
            os.path.join(self.modifier_prefix, "o_rij:0")
        )
        # self.descrpt_reshape = tf.reshape(self.descrpt, [nf, 192 * self.ndescrpt])
        # self.descrpt_deriv = tf.reshape(self.descrpt_deriv, [nf, 192 * self.ndescrpt * 3])

        # nframes x (natoms_sel x 3)
        self.t_ef_reshape = tf.reshape(self.t_ef, [t_nframes, -1])
        # nframes x (natoms x 3)
        self.t_ef_reshape = self._enrich(self.t_ef_reshape, dof=3)
        # (nframes x natoms) x 3
        self.t_ef_reshape = tf.reshape(self.t_ef_reshape, [nfxna, 3])
        # nframes x (natoms_sel x 3)
        self.t_tensor_reshape = tf.reshape(self.t_tensor, [t_nframes, -1])
        # nframes x (natoms x 3)
        self.t_tensor_reshape = self._enrich(self.t_tensor_reshape, dof=3)
        # (nframes x natoms) x 3
        self.t_tensor_reshape = tf.reshape(self.t_tensor_reshape, [nfxna, 3])
        # (nframes x natoms) x ndescrpt
        [self.t_ef_d] = tf.gradients(
            self.t_tensor_reshape, self.descrpt, self.t_ef_reshape
        )
        # nframes x (natoms x ndescrpt)
        self.t_ef_d = tf.reshape(self.t_ef_d, [nf, self.t_natoms[0] * self.ndescrpt])
        # t_ef_d is force (with -1), prod_forc takes deriv, so we need the opposite
        self.t_ef_d_oppo = -self.t_ef_d

        force = op_module.prod_force_se_a(
            self.t_ef_d_oppo,
            self.descrpt_deriv,
            self.nlist,
            self.t_natoms,
            n_a_sel=self.nnei_a,
            n_r_sel=self.nnei_r,
        )
        virial, atom_virial = op_module.prod_virial_se_a(
            self.t_ef_d_oppo,
            self.descrpt_deriv,
            self.rij,
            self.nlist,
            self.t_natoms,
            n_a_sel=self.nnei_a,
            n_r_sel=self.nnei_r,
        )
        force = tf.identity(force, name="o_dm_force")
        virial = tf.identity(virial, name="o_dm_virial")
        atom_virial = tf.identity(atom_virial, name="o_dm_av")
        return force, virial, atom_virial

    def _enrich(self, dipole, dof=3):
        coll = []
        sel_start_idx = 0
        for type_i in range(self.ntypes):
            if type_i in self.sel_type:
                di = tf.slice(
                    dipole,
                    [0, sel_start_idx * dof],
                    [-1, self.t_natoms[2 + type_i] * dof],
                )
                sel_start_idx += self.t_natoms[2 + type_i]
            else:
                di = tf.zeros(
                    [tf.shape(dipole)[0], self.t_natoms[2 + type_i] * dof],
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                )
            coll.append(di)
        return tf.concat(coll, axis=1)

    def _slice_descrpt_deriv(self, deriv):
        coll = []
        start_idx = 0
        for type_i in range(self.ntypes):
            if type_i in self.sel_type:
                di = tf.slice(
                    deriv,
                    [0, start_idx * self.ndescrpt],
                    [-1, self.t_natoms[2 + type_i] * self.ndescrpt],
                )
                coll.append(di)
            start_idx += self.t_natoms[2 + type_i]
        return tf.concat(coll, axis=1)

    def eval(
        self,
        coord: np.ndarray,
        box: np.ndarray,
        atype: np.ndarray,
        eval_fv: bool = True,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate the modification.

        Parameters
        ----------
        coord
            The coordinates of atoms
        box
            The simulation region. PBC is assumed
        atype
            The atom types
        eval_fv
            Evaluate force and virial

        Returns
        -------
        tot_e
            The energy modification
        tot_f
            The force modification
        tot_v
            The virial modification
        """
        atype = np.array(atype, dtype=int)
        coord, atype, imap = self.sort_input(coord, atype)
        # natoms = coord.shape[1] // 3
        natoms = atype.size
        nframes = coord.shape[0]
        box = np.reshape(box, [nframes, 9])
        atype = np.reshape(atype, [natoms])
        sel_idx_map = select_idx_map(atype, self.sel_type)
        nsel = len(sel_idx_map)
        # setup charge
        charge = np.zeros([natoms])  # pylint: disable=no-explicit-dtype
        for ii in range(natoms):
            charge[ii] = self.sys_charge_map[atype[ii]]
        charge = np.tile(charge, [nframes, 1])

        # add wfcc
        all_coord, all_charge, dipole = self._extend_system(coord, box, atype, charge)

        # print('compute er')
        batch_size = 5
        tot_e = []
        all_f = []
        all_v = []
        for ii in range(0, nframes, batch_size):
            e, f, v = self.er.eval(
                all_coord[ii : ii + batch_size],
                all_charge[ii : ii + batch_size],
                box[ii : ii + batch_size],
            )
            tot_e.append(e)
            all_f.append(f)
            all_v.append(v)
        tot_e = np.concatenate(tot_e, axis=0)
        all_f = np.concatenate(all_f, axis=0)
        all_v = np.concatenate(all_v, axis=0)
        # print('finish  er')
        # reshape
        tot_e.reshape([nframes, 1])

        tot_f = None
        tot_v = None
        if self.force is None:
            self.force, self.virial, self.av = self.build_fv_graph()
        if eval_fv:
            # compute f
            ext_f = all_f[:, natoms * 3 :]
            corr_f = []
            corr_v = []
            corr_av = []
            for ii in range(0, nframes, batch_size):
                f, v, av = self._eval_fv(
                    coord[ii : ii + batch_size],
                    box[ii : ii + batch_size],
                    atype,
                    ext_f[ii : ii + batch_size],
                )
                corr_f.append(f)
                corr_v.append(v)
                corr_av.append(av)
            corr_f = np.concatenate(corr_f, axis=0)
            corr_v = np.concatenate(corr_v, axis=0)
            corr_av = np.concatenate(corr_av, axis=0)
            tot_f = all_f[:, : natoms * 3] + corr_f
            for ii in range(nsel):
                orig_idx = sel_idx_map[ii]
                tot_f[:, orig_idx * 3 : orig_idx * 3 + 3] += ext_f[
                    :, ii * 3 : ii * 3 + 3
                ]
            tot_f = self.reverse_map(np.reshape(tot_f, [nframes, -1, 3]), imap)
            # reshape
            tot_f = tot_f.reshape([nframes, natoms, 3])
            # compute v
            dipole3 = np.reshape(dipole, [nframes, nsel, 3])
            ext_f3 = np.reshape(ext_f, [nframes, nsel, 3])
            ext_f3 = np.transpose(ext_f3, [0, 2, 1])
            # fd_corr_v = -np.matmul(ext_f3, dipole3).T.reshape([nframes, 9])
            # fd_corr_v = -np.matmul(ext_f3, dipole3)
            # fd_corr_v = np.transpose(fd_corr_v, [0, 2, 1]).reshape([nframes, 9])
            fd_corr_v = -np.matmul(ext_f3, dipole3).reshape([nframes, 9])
            # print(all_v, '\n', corr_v, '\n', fd_corr_v)
            tot_v = all_v + corr_v + fd_corr_v
            # reshape
            tot_v = tot_v.reshape([nframes, 9])

        return tot_e, tot_f, tot_v

    def _eval_fv(self, coords, cells, atom_types, ext_f):
        # reshape the inputs
        cells = np.reshape(cells, [-1, 9])
        nframes = cells.shape[0]
        coords = np.reshape(coords, [nframes, -1])
        natoms = coords.shape[1] // 3

        # sort inputs
        coords, atom_types, imap, sel_at, sel_imap = self.sort_input(
            coords, atom_types, sel_atoms=self.get_sel_type()
        )

        # make natoms_vec and default_mesh
        natoms_vec = self.make_natoms_vec(atom_types)
        assert natoms_vec[0] == natoms
        default_mesh = make_default_mesh(True, False)

        # evaluate
        tensor = []
        feed_dict_test = {}
        feed_dict_test[self.t_natoms] = natoms_vec
        feed_dict_test[self.t_type] = np.tile(atom_types, [nframes, 1]).reshape([-1])
        feed_dict_test[self.t_coord] = coords.reshape([-1])
        feed_dict_test[self.t_box] = cells.reshape([-1])
        feed_dict_test[self.t_mesh] = default_mesh.reshape([-1])
        feed_dict_test[self.t_ef] = ext_f.reshape([-1])
        # print(run_sess(self.sess, tf.shape(self.t_tensor), feed_dict = feed_dict_test))
        fout, vout, avout = run_sess(
            self.sess, [self.force, self.virial, self.av], feed_dict=feed_dict_test
        )
        # print('fout: ', fout.shape, fout)
        fout = self.reverse_map(np.reshape(fout, [nframes, -1, 3]), imap)
        fout = np.reshape(fout, [nframes, -1])
        return fout, vout, avout

    def _extend_system(self, coord, box, atype, charge):
        natoms = coord.shape[1] // 3
        nframes = coord.shape[0]
        # sel atoms and setup ref coord
        sel_idx_map = select_idx_map(atype, self.sel_type)
        nsel = len(sel_idx_map)
        coord3 = coord.reshape([nframes, natoms, 3])
        ref_coord = coord3[:, sel_idx_map, :]
        ref_coord = np.reshape(ref_coord, [nframes, nsel * 3])

        batch_size = 8
        all_dipole = []
        for ii in range(0, nframes, batch_size):
            dipole = DeepDipole.eval(
                self, coord[ii : ii + batch_size], box[ii : ii + batch_size], atype
            )
            all_dipole.append(dipole)
        dipole = np.concatenate(all_dipole, axis=0)
        assert dipole.shape[0] == nframes
        dipole = np.reshape(dipole, [nframes, nsel * 3])

        wfcc_coord = ref_coord + dipole
        # wfcc_coord = dipole
        wfcc_charge = np.zeros([nsel])  # pylint: disable=no-explicit-dtype
        for ii in range(nsel):
            orig_idx = self.sel_type.index(atype[sel_idx_map[ii]])
            wfcc_charge[ii] = self.model_charge_map[orig_idx]
        wfcc_charge = np.tile(wfcc_charge, [nframes, 1])

        wfcc_coord = np.reshape(wfcc_coord, [nframes, nsel * 3])
        wfcc_charge = np.reshape(wfcc_charge, [nframes, nsel])

        all_coord = np.concatenate((coord, wfcc_coord), axis=1)
        all_charge = np.concatenate((charge, wfcc_charge), axis=1)

        return all_coord, all_charge, dipole

    def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
        """Modify data.

        Parameters
        ----------
        data
            Internal data of DeepmdData.
            Be a dict, has the following keys
            - coord         coordinates
            - box           simulation box
            - type          atom types
            - find_energy   tells if data has energy
            - find_force    tells if data has force
            - find_virial   tells if data has virial
            - energy        energy
            - force         force
            - virial        virial
        data_sys : DeepmdData
            The data system.
        """
        if (
            "find_energy" not in data
            and "find_force" not in data
            and "find_virial" not in data
        ):
            return

        get_nframes = None
        coord = data["coord"][:get_nframes, :]
        if not data_sys.pbc:
            raise RuntimeError("Open systems (nopbc) are not supported")
        box = data["box"][:get_nframes, :]
        atype = data["type"][:get_nframes, :]
        atype = atype[0]
        nframes = coord.shape[0]

        tot_e, tot_f, tot_v = self.eval(coord, box, atype)

        # print(tot_f[:,0])

        if "find_energy" in data and data["find_energy"] == 1.0:
            data["energy"] -= tot_e.reshape(data["energy"].shape)
        if "find_force" in data and data["find_force"] == 1.0:
            data["force"] -= tot_f.reshape(data["force"].shape)
        if "find_virial" in data and data["find_virial"] == 1.0:
            data["virial"] -= tot_v.reshape(data["virial"].shape)
