# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
    Optional,
    Union,
)

import numpy as np

from deepmd.tf.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    tf,
)
from deepmd.tf.utils.spin import (
    Spin,
)
from deepmd.utils.data_system import (
    DeepmdDataSystem,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

# from deepmd.tf.descriptor import DescrptLocFrame
# from deepmd.tf.descriptor import DescrptSeA
# from deepmd.tf.descriptor import DescrptSeT
# from deepmd.tf.descriptor import DescrptSeAEbd
# from deepmd.tf.descriptor import DescrptSeAEf
# from deepmd.tf.descriptor import DescrptSeR
from .descriptor import (
    Descriptor,
)


@Descriptor.register("hybrid")
class DescrptHybrid(Descriptor):
    """Concatenate a list of descriptors to form a new descriptor.

    Parameters
    ----------
    list : list[Union[Descriptor, dict[str, Any]]]
        Build a descriptor from the concatenation of the list of descriptors.
        Each entry can be either a ``Descriptor`` object or a dictionary of
        keyword arguments from which a ``Descriptor`` is constructed.
    ntypes : int, optional
        Number of atom types; forwarded to sub-descriptors built from dicts.
    spin : Spin, optional
        Spin setting; forwarded to sub-descriptors built from dicts.
    **kwargs
        Accepted for interface compatibility; not used by this class.
    """

    def __init__(
        self,
        list: list[Union[Descriptor, dict[str, Any]]],
        ntypes: Optional[int] = None,
        spin: Optional[Spin] = None,
        **kwargs,
    ) -> None:
        """Constructor.

        Raises
        ------
        RuntimeError
            If ``list`` is None or empty.
        NotImplementedError
            If an entry is neither a ``Descriptor`` nor a ``dict``.
        AssertionError
            If the sub-descriptors disagree on the number of atom types.
        """
        # The parameter is named ``list`` to match the "list" key used by
        # serialize/deserialize; it shadows the builtin, so alias it at once.
        descrpt_list = list
        # ``not`` rejects None as well as any empty sequence (list or tuple).
        if not descrpt_list:
            raise RuntimeError(
                "cannot build descriptor from an empty list of descriptors."
            )
        formatted_descript_list = []
        for ii in descrpt_list:
            if isinstance(ii, Descriptor):
                formatted_descript_list.append(ii)
            elif isinstance(ii, dict):
                # Construct the sub-descriptor from its config dict, sharing
                # the hybrid's ntypes/spin settings.
                formatted_descript_list.append(
                    Descriptor(**ii, ntypes=ntypes, spin=spin)
                )
            else:
                raise NotImplementedError(
                    f"unsupported descriptor entry of type {type(ii).__name__}; "
                    "expected a Descriptor or a dict"
                )
        self.descrpt_list = formatted_descript_list
        self.numb_descrpt = len(self.descrpt_list)
        # get_ntypes() reports the first sub-descriptor's value, so every
        # sub-descriptor must agree with it.
        ref_ntypes = self.descrpt_list[0].get_ntypes()
        for ii in range(1, self.numb_descrpt):
            assert (
                self.descrpt_list[ii].get_ntypes() == ref_ntypes
            ), f"number of atom types in {ii}th descriptor does not match others"

    def get_rcut(self) -> float:
        """Returns the cut-off radius (the largest among the sub-descriptors)."""
        # Convert the NumPy scalar to a plain float to honour the annotation.
        return float(np.max([ii.get_rcut() for ii in self.descrpt_list]))

    def get_ntypes(self) -> int:
        """Returns the number of atom types."""
        return self.descrpt_list[0].get_ntypes()

    def get_dim_out(self) -> int:
        """Returns the output dimension of this descriptor.

        This is the sum of the output dimensions of all sub-descriptors.
        """
        return sum(ii.get_dim_out() for ii in self.descrpt_list)

    def get_nlist(
        self,
    ) -> tuple[tf.Tensor, tf.Tensor, list[int], list[int]]:
        """Get the neighbor information of the descriptor, returns the
        nlist of the descriptor with the largest cut-off radius.

        If several sub-descriptors share the largest cut-off radius, the
        first one is used.

        Returns
        -------
        nlist
            Neighbor list
        rij
            The relative distance between the neighbor and the center atom.
        sel_a
            The number of neighbors with full information
        sel_r
            The number of neighbors with only radial information
        """
        maxr_idx = int(np.argmax([ii.get_rcut() for ii in self.descrpt_list]))
        return self.get_nlist_i(maxr_idx)

    def get_nlist_i(self, ii: int) -> tuple[tf.Tensor, tf.Tensor, list[int], list[int]]:
        """Get the neighbor information of the ii-th descriptor.

        Parameters
        ----------
        ii : int
            The index of the descriptor

        Returns
        -------
        nlist
            Neighbor list
        rij
            The relative distance between the neighbor and the center atom.
        sel_a
            The number of neighbors with full information
        sel_r
            The number of neighbors with only radial information
        """
        descrpt = self.descrpt_list[ii]
        return (
            descrpt.nlist,
            descrpt.rij,
            descrpt.sel_a,
            descrpt.sel_r,
        )

    def compute_input_stats(
        self,
        data_coord: list,
        data_box: list,
        data_atype: list,
        natoms_vec: list,
        mesh: list,
        input_dict: dict,
        mixed_type: bool = False,
        real_natoms_vec: Optional[list] = None,
        **kwargs,
    ) -> None:
        """Compute the statistics (avg and std) of the training data. The input will be normalized by the statistics.

        Each sub-descriptor computes its own statistics from the same data.

        Parameters
        ----------
        data_coord
            The coordinates. Can be generated by deepmd.tf.model.make_stat_input
        data_box
            The box. Can be generated by deepmd.tf.model.make_stat_input
        data_atype
            The atom types. Can be generated by deepmd.tf.model.make_stat_input
        natoms_vec
            The vector for the number of atoms of the system and different types of atoms. Can be generated by deepmd.tf.model.make_stat_input
        mesh
            The mesh for neighbor searching. Can be generated by deepmd.tf.model.make_stat_input
        input_dict
            Dictionary for additional input
        mixed_type
            Whether to perform the mixed_type mode.
            If True, the input data has the mixed_type format (see doc/model/train_se_atten.md),
            in which frames in a system may have different natoms_vec(s), with the same nloc.
        real_natoms_vec
            If mixed_type is True, it takes in the real natoms_vec for each frame.
        **kwargs
            Additional keyword arguments, forwarded to every sub-descriptor.
        """
        for ii in self.descrpt_list:
            ii.compute_input_stats(
                data_coord,
                data_box,
                data_atype,
                natoms_vec,
                mesh,
                input_dict,
                mixed_type=mixed_type,
                real_natoms_vec=real_natoms_vec,
                **kwargs,
            )

    def merge_input_stats(self, stat_dict) -> None:
        """Merge the statistics computed from compute_input_stats to obtain the self.davg and self.dstd.

        Parameters
        ----------
        stat_dict
            The dict of statistics computed from compute_input_stats, including:

            sumr
                The sum of radial statistics.
            suma
                The sum of relative coord statistics.
            sumn
                The sum of neighbor numbers.
            sumr2
                The sum of square of radial statistics.
            suma2
                The sum of square of relative coord statistics.
        """
        # NOTE(review): the same stat_dict is handed to every sub-descriptor;
        # confirm this is intended when sub-descriptors have different sel/rcut.
        for ii in self.descrpt_list:
            ii.merge_input_stats(stat_dict)

    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box_: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        reuse: Optional[bool] = None,
        suffix: str = "",
    ) -> tf.Tensor:
        """Build the computational graph for the descriptor.

        Each sub-descriptor is built under the suffix ``f"{suffix}_{idx}"``
        and the per-atom outputs are concatenated along the feature axis.

        Parameters
        ----------
        coord_
            The coordinate of atoms
        atype_
            The type of atoms
        natoms
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_ : tf.Tensor
            The box of the system
        mesh
            For historical reasons, only the length of the Tensor matters.
            if size of mesh == 6, pbc is assumed.
            if size of mesh == 0, no-pbc is assumed.
        input_dict
            Dictionary for additional inputs
        reuse
            The weights in the networks should be reused when get the variable.
        suffix
            Name suffix to identify this descriptor

        Returns
        -------
        descriptor
            The output descriptor, reshaped to [-1, natoms[0], get_dim_out()]
        """
        # These constants are only created so that named "rcut"/"ntypes"
        # nodes exist in the graph under "descrpt_attr{suffix}"; the Python
        # handles are not used afterwards.
        with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
            t_rcut = tf.constant(
                self.get_rcut(), name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION
            )
            t_ntypes = tf.constant(self.get_ntypes(), name="ntypes", dtype=tf.int32)
        all_dout = []
        for idx, ii in enumerate(self.descrpt_list):
            dout = ii.build(
                coord_,
                atype_,
                natoms,
                box_,
                mesh,
                input_dict,
                suffix=suffix + f"_{idx}",
                reuse=reuse,
            )
            # Flatten to one row per atom so outputs can be concatenated.
            dout = tf.reshape(dout, [-1, ii.get_dim_out()])
            all_dout.append(dout)
        dout = tf.concat(all_dout, axis=1)
        dout = tf.reshape(dout, [-1, natoms[0], self.get_dim_out()])
        return dout

    def prod_force_virial(
        self, atom_ener: tf.Tensor, natoms: tf.Tensor
    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Compute force and virial.

        The force, virial and atomic virial are the sums of the
        contributions returned by each sub-descriptor.

        Parameters
        ----------
        atom_ener
            The atomic energy
        natoms
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force
            The force on atoms
        virial
            The total virial
        atom_virial
            The atomic virial
        """
        # descrpt_list is guaranteed non-empty by the constructor, so the
        # accumulators are always initialized on the first iteration.
        for idx, ii in enumerate(self.descrpt_list):
            ff, vv, av = ii.prod_force_virial(atom_ener, natoms)
            if idx == 0:
                force = ff
                virial = vv
                atom_virial = av
            else:
                force += ff
                virial += vv
                atom_virial += av
        return force, virial, atom_virial

    def enable_compression(
        self,
        min_nbor_dist: float,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        table_extrapolate: float = 5.0,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
        suffix: str = "",
    ) -> None:
        """Receive the statistics (distance, max_nbor_size and env_mat_range) of the
        training data.

        Compression is enabled on every sub-descriptor, each under the
        suffix ``f"{suffix}_{idx}"``.

        Parameters
        ----------
        min_nbor_dist : float
            The nearest distance between atoms
        graph : tf.Graph
            The graph of the model
        graph_def : tf.GraphDef
            The graph_def of the model
        table_extrapolate : float, default: 5.
            The scale of model extrapolation
        table_stride_1 : float, default: 0.01
            The uniform stride of the first table
        table_stride_2 : float, default: 0.1
            The uniform stride of the second table
        check_frequency : int, default: -1
            The overflow check frequency
        suffix : str, optional
            The suffix of the scope
        """
        for idx, ii in enumerate(self.descrpt_list):
            ii.enable_compression(
                min_nbor_dist,
                graph,
                graph_def,
                table_extrapolate,
                table_stride_1,
                table_stride_2,
                check_frequency,
                suffix=f"{suffix}_{idx}",
            )

    def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
        """Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
            The mixed precision setting used in the embedding net; forwarded
            unchanged to every sub-descriptor.
        """
        for ii in self.descrpt_list:
            ii.enable_mixed_precision(mixed_prec)

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the embedding net variables with the given dict.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str, optional
            The suffix of the scope
        """
        for idx, ii in enumerate(self.descrpt_list):
            ii.init_variables(graph, graph_def, suffix=f"{suffix}_{idx}")

    def get_tensor_names(self, suffix: str = "") -> tuple[str, ...]:
        """Get names of tensors.

        Parameters
        ----------
        suffix : str
            The suffix of the scope

        Returns
        -------
        tuple[str, ...]
            Names of tensors, sub-descriptor by sub-descriptor in list order
        """
        tensor_names = []
        for idx, ii in enumerate(self.descrpt_list):
            tensor_names.extend(ii.get_tensor_names(suffix=f"{suffix}_{idx}"))
        return tuple(tensor_names)

    def pass_tensors_from_frz_model(
        self,
        *tensors: tf.Tensor,
    ) -> None:
        """Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def.

        Parameters
        ----------
        *tensors : tf.Tensor
            passed tensors, laid out consecutively per sub-descriptor in the
            same order as ``get_tensor_names``
        """
        jj = 0
        for ii in self.descrpt_list:
            # Only the count matters here, so the suffix is irrelevant.
            n_tensors = len(ii.get_tensor_names())
            ii.pass_tensors_from_frz_model(*tensors[jj : jj + n_tensors])
            jj += n_tensors

    @property
    def explicit_ntypes(self) -> bool:
        """Explicit ntypes with type embedding."""
        return any(ii.explicit_ntypes for ii in self.descrpt_list)

    @classmethod
    def update_sel(
        cls,
        train_data: DeepmdDataSystem,
        type_map: Optional[list[str]],
        local_jdata: dict,
    ) -> tuple[dict, Optional[float]]:
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        train_data : DeepmdDataSystem
            data used to do neighbor statistics
        type_map : list[str], optional
            The name of each type of atoms
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            The updated local data
        float
            The minimum distance between two atoms, taken over all
            sub-descriptors that report one (None if none do)
        """
        local_jdata_cpy = local_jdata.copy()
        new_list = []
        min_nbor_dist = None
        for sub_jdata in local_jdata["list"]:
            new_sub_jdata, min_nbor_dist_ = Descriptor.update_sel(
                train_data, type_map, sub_jdata
            )
            if min_nbor_dist_ is not None:
                # Keep the smallest reported distance instead of the last one.
                min_nbor_dist = (
                    min_nbor_dist_
                    if min_nbor_dist is None
                    else min(min_nbor_dist, min_nbor_dist_)
                )
            new_list.append(new_sub_jdata)
        local_jdata_cpy["list"] = new_list
        return local_jdata_cpy, min_nbor_dist

    def serialize(self, suffix: str = "") -> dict:
        """Serialize the hybrid descriptor to a dict.

        Parameters
        ----------
        suffix : str, optional
            The suffix of the scope; sub-descriptor ``idx`` is serialized
            with ``f"{suffix}_{idx}"``.

        Returns
        -------
        dict
            The serialized data, with each sub-descriptor under ``"list"``.

        Raises
        ------
        NotImplementedError
            If this object has a ``type_embedding`` attribute.
        """
        if hasattr(self, "type_embedding"):
            raise NotImplementedError("hybrid + type embedding is not supported")
        return {
            "@class": "Descriptor",
            "type": "hybrid",
            "@version": 1,
            "list": [
                descrpt.serialize(suffix=f"{suffix}_{idx}")
                for idx, descrpt in enumerate(self.descrpt_list)
            ],
        }

    @classmethod
    def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
        """Deserialize a hybrid descriptor from a dict produced by ``serialize``.

        Parameters
        ----------
        data : dict
            The serialized data; it is shallow-copied and not modified.
        suffix : str, optional
            The suffix of the scope; sub-descriptor ``idx`` is deserialized
            with ``f"{suffix}_{idx}"``.

        Returns
        -------
        DescrptHybrid
            The deserialized descriptor.

        Raises
        ------
        NotImplementedError
            If any sub-descriptor has a ``type_embedding`` attribute.
        """
        data = data.copy()
        class_name = data.pop("@class")
        assert class_name == "Descriptor"
        class_type = data.pop("type")
        assert class_type == "hybrid"
        check_version_compatibility(data.pop("@version"), 1, 1)
        obj = cls(
            list=[
                Descriptor.deserialize(ii, suffix=f"{suffix}_{idx}")
                for idx, ii in enumerate(data["list"])
            ],
        )
        # search for type embedding
        for ii in obj.descrpt_list:
            if hasattr(ii, "type_embedding"):
                raise NotImplementedError("hybrid + type embedding is not supported")
        return obj

    def get_dim_rot_mat_1(self) -> int:
        """Returns the first dimension of the rotation matrix. The rotation is of shape
        dim_1 x 3.

        Returns
        -------
        int
            the first dimension of the rotation matrix, summed over all
            sub-descriptors
        """
        return sum(ii.get_dim_rot_mat_1() for ii in self.descrpt_list)

    def get_rot_mat(self) -> tf.Tensor:
        """Get rotational matrix.

        The sub-descriptors' rotation matrices are concatenated along axis 2.
        """
        all_rot_mat = [ii.get_rot_mat() for ii in self.descrpt_list]
        return tf.concat(all_rot_mat, axis=2)
