from typing import Callable, Dict, Type

from .blocks import (
    ComplexAgnosticNonLinearResidualInteractionBlock,
    RadialBesselEmbeddingBlock,
    LinearNodeEmbeddingBlock,
    NonLinearBlock,
    InteractionBlock,
    LinearReadoutBlock,
    EquivariantProductBasisBlock,
    ComplexAgnosticResidualInteractionBlock,
    NonLinearReadoutBlock,
    RadialLorentzianEmbeddingBlock,
    SimpleRadialEmbeddingBlock,
)
from .loss import ClassificationLoss
from .radial import BesselBasis, LorentzianBasis, PolynomialCutoff
from .irreps_tools import tp_out_irreps_with_instructions, linear_out_irreps
from .spherical_harmonics import SphericalHarmonics
from .utils import minkowski_norm, compute_avg_num_neighbors
from .models import (
    LorentzBOTNet,
    SingleReadoutModel,
    LorentzMACEModel,
    LorentzMACELayer,
)

# Registry of the interaction block classes exported by this package, keyed by
# the class name spelled as a string. This allows a block to be looked up by
# name (presumably from a config/CLI option — NOTE(review): confirm against
# callers).
interaction_classes: Dict[str, Type[InteractionBlock]] = dict(
    ComplexAgnosticResidualInteractionBlock=ComplexAgnosticResidualInteractionBlock,
    ComplexAgnosticNonLinearResidualInteractionBlock=(
        ComplexAgnosticNonLinearResidualInteractionBlock
    ),
)

# Registry of the radial embedding blocks exported by this package, keyed by
# the class name spelled as a string. Values are typed as plain ``Callable``:
# the only requirement visible here is that each entry can be called (e.g.
# instantiated).
basis_classes: Dict[str, Callable] = dict(
    RadialBesselEmbeddingBlock=RadialBesselEmbeddingBlock,
    RadialLorentzianEmbeddingBlock=RadialLorentzianEmbeddingBlock,
    SimpleRadialEmbeddingBlock=SimpleRadialEmbeddingBlock,
)

# Public API of the package, as seen by ``from <package> import *``.
# Every entry must name an object that is actually bound in this module:
# Python raises AttributeError during a star-import for any entry it cannot
# resolve. Keep one string per line with a trailing comma, so a missing comma
# cannot silently concatenate two adjacent names into one bogus entry.
__all__ = [
    # Irreps / geometry utilities
    "tp_out_irreps_with_instructions",
    "linear_out_irreps",
    "SphericalHarmonics",
    "PolynomialCutoff",
    "BesselBasis",
    "LorentzianBasis",
    # Blocks
    "NonLinearReadoutBlock",
    "EquivariantProductBasisBlock",
    "LinearReadoutBlock",
    "RadialBesselEmbeddingBlock",
    "RadialLorentzianEmbeddingBlock",
    "SimpleRadialEmbeddingBlock",
    "LinearNodeEmbeddingBlock",
    "NonLinearBlock",
    "InteractionBlock",
    "ComplexAgnosticResidualInteractionBlock",
    "ComplexAgnosticNonLinearResidualInteractionBlock",
    # Models
    "LorentzBOTNet",
    "SingleReadoutModel",
    "LorentzMACEModel",
    "LorentzMACELayer",
    # Losses and helpers
    "ClassificationLoss",
    "compute_avg_num_neighbors",
    "minkowski_norm",
    # Name -> class registries defined in this module
    "interaction_classes",
    "basis_classes",
]
