# SPDX-License-Identifier: LGPL-3.0-or-later
"""The model that takes the coordinates, cell and atom types as input
and predicts some property. The models are automatically generated from
atomic models by the `deepmd.dpmodel.make_model` method.

The `make_model` method performs the reduction, auto-differentiation and
communication of the atomic properties according to the output variable
definition `deepmd.dpmodel.OutputVariableDef`.

All models should inherit from :class:`deepmd.pt.model.model.model.BaseModel`.
Models generated by `make_model` already do so.
"""

import copy
import json
from typing import (
    Optional,
)

import numpy as np

from deepmd.pt.model.atomic_model import (
    DPAtomicModel,
    PairTabAtomicModel,
)
from deepmd.pt.model.descriptor.base_descriptor import (
    BaseDescriptor,
)
from deepmd.pt.model.task import (
    BaseFitting,
)
from deepmd.utils.spin import (
    Spin,
)

from .dipole_model import (
    DipoleModel,
)
from .dos_model import (
    DOSModel,
)
from .dp_linear_model import (
    LinearEnergyModel,
)
from .dp_model import (
    DPModelCommon,
)
from .dp_zbl_model import (
    DPZBLModel,
)
from .ener_model import (
    EnergyModel,
)
from .frozen import (
    FrozenModel,
)
from .make_hessian_model import (
    make_hessian_model,
)
from .make_model import (
    make_model,
)
from .model import (
    BaseModel,
)
from .polar_model import (
    PolarModel,
)
from .property_model import (
    PropertyModel,
)
from .spin_model import (
    SpinEnergyModel,
    SpinModel,
)

from .pretrain_descriptor_model import (
    PretrainDescriptorModel,
)


def _get_standard_model_components(model_params, ntypes):
    """Construct the descriptor and the fitting network of a standard model.

    ``model_params["descriptor"]`` and ``model_params["fitting_net"]`` (when
    present) are updated in place with derived settings, so the caller
    should pass a private copy of its configuration.

    Parameters
    ----------
    model_params : dict
        Model section of the input; must contain ``"descriptor"`` and
        ``"type_map"``.
    ntypes : int
        Number of atom types, written into the descriptor settings.

    Returns
    -------
    tuple
        ``(descriptor, fitting, fitting_type)``; ``fitting_type`` is the
        fitting ``"type"`` string, ``"ener"`` when none was given.

    Raises
    ------
    ValueError
        If a model-level ``"type_embedding"`` section is present.
    """
    if "type_embedding" in model_params:
        raise ValueError(
            "In the PyTorch backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details."
        )

    # descriptor: inject the number of types and a private copy of the type map
    descriptor_params = model_params["descriptor"]
    descriptor_params["ntypes"] = ntypes
    descriptor_params["type_map"] = copy.deepcopy(model_params["type_map"])
    descriptor = BaseDescriptor(**descriptor_params)

    # fitting: settings derived from the descriptor just built
    fitting_params = model_params.get("fitting_net", {})
    fitting_type = fitting_params.get("type", "ener")
    fitting_params["type"] = fitting_type
    fitting_params["ntypes"] = descriptor.get_ntypes()
    fitting_params["type_map"] = copy.deepcopy(model_params["type_map"])
    fitting_params["mixed_types"] = descriptor.mixed_types()
    if fitting_type in ("dipole", "polar"):
        # dipole/polar fittings additionally receive the embedding width
        fitting_params["embedding_width"] = descriptor.get_dim_emb()
    fitting_params["dim_descrpt"] = descriptor.get_dim_out()
    if "direct" in fitting_type:
        # "direct" fitting types: the output dimension is the embedding width
        fitting_params["out_dim"] = descriptor.get_dim_emb()
        if "ener" in fitting_type:
            fitting_params["return_energy"] = True
    fitting = BaseFitting(**fitting_params)
    return descriptor, fitting, fitting_type


def get_spin_model(model_params):
    """Build a :class:`SpinEnergyModel` around a standard backbone model.

    The input dict is deep-copied; the copy is adapted for spin and then
    handed to :func:`get_standard_model`:

    * ``spin["use_spin"]`` given as an empty list, or as a list whose first
      element is an ``int`` (type indices), is expanded into a per-type
      boolean list over the original ``type_map``;
    * ``type_map`` is extended with a ``"<type>_spin"`` entry per type;
    * pair/atom exclusion types are recomputed by :class:`Spin` from the
      user-provided ones, and the pair exclusions are also given to the
      descriptor;
    * ``descriptor["env_protection"]`` is set to 0.01 when missing or 0.0;
    * for an ``se_e2_a`` descriptor, ``sel`` is duplicated.

    Parameters
    ----------
    model_params : dict
        Model section of the input, including a ``"spin"`` section with
        ``"use_spin"`` and ``"virtual_scale"``.

    Returns
    -------
    SpinEnergyModel
        The spin model wrapping the backbone.
    """
    model_params = copy.deepcopy(model_params)
    spin_params = model_params["spin"]
    use_spin = spin_params["use_spin"]
    type_map = model_params["type_map"]
    # NOTE: bool is a subclass of int, so a bool list also enters this branch.
    if not use_spin or isinstance(use_spin[0], int):
        spin_mask = np.full(len(type_map), False, dtype=bool)
        spin_mask[use_spin] = True
        spin_params["use_spin"] = spin_mask.tolist()

    # include virtual spin and placeholder types (in-place list extension)
    type_map += [name + "_spin" for name in type_map]

    spin = Spin(
        use_spin=spin_params["use_spin"],
        virtual_scale=spin_params["virtual_scale"],
    )
    pair_exclude_types = spin.get_pair_exclude_types(
        exclude_types=model_params.get("pair_exclude_types", None)
    )
    model_params["pair_exclude_types"] = pair_exclude_types

    descriptor_params = model_params["descriptor"]
    # for descriptor data stat
    descriptor_params["exclude_types"] = pair_exclude_types
    model_params["atom_exclude_types"] = spin.get_atom_exclude_types(
        exclude_types=model_params.get("atom_exclude_types", None)
    )

    # a missing env_protection counts as 0.0; both are replaced by 0.01
    if descriptor_params.get("env_protection", 0.0) == 0.0:
        descriptor_params["env_protection"] = 0.01
    if descriptor_params["type"] == "se_e2_a":
        # only se_e2_a has its sel expanded (list duplicated)
        descriptor_params["sel"] += descriptor_params["sel"]

    return SpinEnergyModel(
        backbone_model=get_standard_model(model_params), spin=spin
    )


def get_linear_model(model_params):
    """Build a :class:`LinearEnergyModel` from a list of sub-model configs.

    Parameters
    ----------
    model_params : dict
        Model section of the input. ``"type_map"`` and ``"models"`` are
        required; ``"weights"`` defaults to ``"mean"`` and the atom/pair
        exclusion lists default to empty. The dict is deep-copied, so the
        caller's configuration is not mutated.

        Each entry of ``"models"`` is either a DP sub-model (recognized by a
        ``"descriptor"`` key) or a pair-table sub-model with
        ``"type": "pairtab"`` and ``"tab_file"``/``"rcut"``/``"sel"`` keys.
        A DP sub-model without its own ``"type_map"`` inherits the linear
        model's type map.

    Returns
    -------
    LinearEnergyModel
        The linear combination of the sub-models.

    Raises
    ------
    ValueError
        If a sub-model is neither a DP model nor a pair-table model.
    """
    model_params = copy.deepcopy(model_params)
    weights = model_params.get("weights", "mean")
    type_map = model_params["type_map"]
    ntypes = len(type_map)
    list_of_models = []
    for sub_model_params in model_params["models"]:
        if "descriptor" in sub_model_params:
            # The atomic model is built with the parent type map and ntypes,
            # so a sub-model that omits its type map must use the parent's
            # (previously this raised KeyError). The component builder also
            # sets the descriptor's ntypes.
            sub_model_params.setdefault("type_map", type_map)
            descriptor, fitting, _ = _get_standard_model_components(
                sub_model_params, ntypes
            )
            list_of_models.append(
                DPAtomicModel(descriptor, fitting, type_map=type_map)
            )
        elif sub_model_params.get("type") == "pairtab":
            list_of_models.append(
                PairTabAtomicModel(
                    sub_model_params["tab_file"],
                    sub_model_params["rcut"],
                    sub_model_params["sel"],
                    type_map=type_map,
                )
            )
        else:
            # explicit raise: an ``assert`` would vanish under ``python -O``
            raise ValueError(
                "Sub-models in LinearEnergyModel must be a DPModel or a PairTable Model"
            )

    atom_exclude_types = model_params.get("atom_exclude_types", [])
    pair_exclude_types = model_params.get("pair_exclude_types", [])
    return LinearEnergyModel(
        models=list_of_models,
        type_map=type_map,
        weights=weights,
        atom_exclude_types=atom_exclude_types,
        pair_exclude_types=pair_exclude_types,
    )


def get_zbl_model(model_params):
    """Build a :class:`DPZBLModel` combining a DP model with a pair table.

    Parameters
    ----------
    model_params : dict
        Model section of the input. Besides the standard descriptor/fitting
        settings it must provide ``"use_srtab"`` (path of the pair-table
        file), ``"sw_rmin"`` and ``"sw_rmax"``; the atom/pair exclusion
        lists default to empty. The dict itself is not mutated.

    Returns
    -------
    DPZBLModel
        The combined model. Its ``model_def_script`` is the JSON dump of
        ``model_params`` exactly as passed in by the caller.
    """
    model_params_old = model_params
    model_params = copy.deepcopy(model_params)
    type_map = model_params["type_map"]
    descriptor, fitting, _ = _get_standard_model_components(
        model_params, len(type_map)
    )
    dp_model = DPAtomicModel(descriptor, fitting, type_map=type_map)
    # pairtab: reuses the DP descriptor's cutoff radius and sel
    pt_model = PairTabAtomicModel(
        model_params["use_srtab"],
        descriptor.get_rcut(),
        descriptor.get_sel(),
        type_map=type_map,
    )

    rmin = model_params["sw_rmin"]
    rmax = model_params["sw_rmax"]
    atom_exclude_types = model_params.get("atom_exclude_types", [])
    pair_exclude_types = model_params.get("pair_exclude_types", [])
    model = DPZBLModel(
        dp_model,
        pt_model,
        rmin,
        rmax,
        type_map=type_map,
        atom_exclude_types=atom_exclude_types,
        pair_exclude_types=pair_exclude_types,
    )
    # Serialize the caller's input, not the working copy: the copy was
    # filled by _get_standard_model_components with derived keys
    # (ntypes, type_map, dim_descrpt, mixed_types, ...).
    model.model_def_script = json.dumps(model_params_old)
    return model


def _can_be_converted_to_float(value) -> bool:
    """Return whether ``float(value)`` succeeds.

    Parameters
    ----------
    value
        Any object, e.g. one element of a user-supplied ``preset_out_bias``.

    Returns
    -------
    bool
        ``True`` if ``float(value)`` does not raise, ``False`` otherwise.
        Integers too large for a float (``OverflowError``) are reported as
        not convertible instead of letting the exception escape.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def _convert_preset_out_bias_to_array(preset_out_bias, type_map):
    """Turn the entries of ``preset_out_bias`` into numpy arrays, in place.

    Every value list must have one entry per type in ``type_map``. ``None``
    entries are kept; a list entry becomes ``np.array(entry)``; any other
    float-convertible entry becomes a one-element array.

    Parameters
    ----------
    preset_out_bias : dict or None
        Maps an output name to a per-type list of biases.
    type_map : list
        Type names; only its length is used.

    Returns
    -------
    dict or None
        The same (mutated) object that was passed in.

    Raises
    ------
    ValueError
        If a value list has the wrong length, or an entry is neither
        ``None``, a list, nor float-convertible.
    """
    if preset_out_bias is None:
        return preset_out_bias
    n_types = len(type_map)
    for key, per_type_bias in preset_out_bias.items():
        if len(per_type_bias) != n_types:
            raise ValueError(
                "length of the preset_out_bias should be the same as the type_map"
            )
        for idx, entry in enumerate(per_type_bias):
            if entry is None:
                continue
            if isinstance(entry, list):
                values = entry
            elif _can_be_converted_to_float(entry):
                values = [float(entry)]
            else:
                raise ValueError(
                    f"unsupported type/value of the {idx}th element of "
                    f"preset_out_bias['{key}'] "
                    f"{type(entry)}"
                )
            per_type_bias[idx] = np.array(values)
    return preset_out_bias


def get_standard_model(model_params):
    """Build a standard descriptor + fitting model from its input section.

    The model class is :class:`PretrainDescriptorModel` when
    ``model_params["stage"] == "pretrain"``; otherwise it is chosen from the
    fitting type (``dipole``, ``polar``, ``dos``, ``ener``/
    ``direct_force_ener``, ``property``).

    Parameters
    ----------
    model_params : dict
        Model section of the input. It is deep-copied before use; the
        original dict is what gets recorded in ``model_def_script``.

    Returns
    -------
    The constructed model instance.

    Raises
    ------
    RuntimeError
        If the fitting type has no associated model class (non-pretrain).
    """
    raw_params = model_params
    params = copy.deepcopy(model_params)
    type_map = params["type_map"]
    descriptor, fitting, fitting_type = _get_standard_model_components(
        params, len(type_map)
    )
    atom_exclude_types = params.get("atom_exclude_types", [])
    pair_exclude_types = params.get("pair_exclude_types", [])
    preset_out_bias = _convert_preset_out_bias_to_array(
        params.get("preset_out_bias"), type_map
    )

    if params.get("stage") == "pretrain":
        modelcls = PretrainDescriptorModel
    else:
        model_class_by_fitting = {
            "dipole": DipoleModel,
            "polar": PolarModel,
            "dos": DOSModel,
            "ener": EnergyModel,
            "direct_force_ener": EnergyModel,
            "property": PropertyModel,
        }
        modelcls = model_class_by_fitting.get(fitting_type)
        if modelcls is None:
            raise RuntimeError(f"Unknown fitting type: {fitting_type}")

    model = modelcls(
        descriptor=descriptor,
        fitting=fitting,
        type_map=type_map,
        atom_exclude_types=atom_exclude_types,
        pair_exclude_types=pair_exclude_types,
        preset_out_bias=preset_out_bias,
    )
    # Serialize the untouched input: the working copy carries derived keys
    # and numpy arrays (preset_out_bias) that json cannot encode.
    model.model_def_script = json.dumps(raw_params)
    return model


def get_model(model_params):
    """Build a PyTorch model from the ``model`` section of the input.

    ``model_params["type"]`` (default ``"standard"``) selects the builder:

    * ``"standard"``: :func:`get_spin_model` if ``"spin"`` is present,
      else :func:`get_zbl_model` if ``"use_srtab"`` is present, else
      :func:`get_standard_model`;
    * ``"linear_ener"``: :func:`get_linear_model`;
    * anything else: the class returned by ``BaseModel.get_class_by_type``
      builds the model through its ``get_model``.
    """
    model_type = model_params.get("type", "standard")
    if model_type == "linear_ener":
        return get_linear_model(model_params)
    if model_type != "standard":
        return BaseModel.get_class_by_type(model_type).get_model(model_params)
    # standard models: spin takes precedence over ZBL
    if "spin" in model_params:
        return get_spin_model(model_params)
    if "use_srtab" in model_params:
        return get_zbl_model(model_params)
    return get_standard_model(model_params)


# Public API of this package; every model class imported above is exported.
__all__ = [
    "BaseModel",
    "get_model",
    "DPModelCommon",
    "EnergyModel",
    "DipoleModel",
    "PolarModel",
    "DOSModel",
    "PropertyModel",
    "FrozenModel",
    "SpinModel",
    "SpinEnergyModel",
    "DPZBLModel",
    "make_model",
    "make_hessian_model",
    "LinearEnergyModel",
    "PretrainDescriptorModel",
]
