import warnings
from typing import Sequence, Optional, Callable, Any

import math

import torch
import torch.nn as nn
from e3nn import o3

from nequip.data import AtomicDataDict
from nequip.nn import (
    SequentialGraphNetwork,
    ConvNetLayer,
    ApplyFactor,
)
from nequip.nn.embedding import (
    EdgeLengthNormalizer,
    BesselEdgeLengthEncoding,
    PolynomialCutoff,
    NodeTypeEmbed,
    SphericalHarmonicEdgeAttrs,
)

from modules import AuxdensityHeadForNequip

# String constant that is not referenced by any code visible in this file.
# NOTE(review): presumably used by other modules as a data-dict key for
# per-atom properties -- confirm against importers before renaming/removing.
ATOM_PROPERTY_KEY = "ATOM_PROPERTY"

class NequipArch(nn.Module):
    """NequIP-style message-passing backbone with an auxiliary-density readout.

    The constructor builds, in order, a spherical-harmonic edge embedding, an
    edge-length normalizer, a Bessel edge-length encoding, a constant rescaling
    of that encoding, a node type embedding and ``len(radial_mlp_depth)``
    ``ConvNetLayer`` modules.  These are chained in a
    ``SequentialGraphNetwork`` stored as ``self.backbone``.  A second
    ``SequentialGraphNetwork`` holding a single ``AuxdensityHeadForNequip`` is
    stored as ``self.task_head``.

    Args:
        r_max: Forwarded to ``EdgeLengthNormalizer``; also used to compute the
            ``2 * pi / r_max**2`` factor applied to the edge embedding.
        type_names: Atom type names.  Every entry must be alphanumeric.
        radial_mlp_depth: One value per convnet layer, forwarded as the
            ``radial_mlp_depth`` convolution kwarg of that layer.
        radial_mlp_width: One value per convnet layer, forwarded as the
            ``radial_mlp_width`` convolution kwarg of that layer.
        feature_irreps_hidden: One irreps specification per convnet layer.
        irreps_edge_sh: Forwarded to ``SphericalHarmonicEdgeAttrs``.
        type_embed_num_features: ``num_features`` of ``NodeTypeEmbed``.
        per_edge_type_cutoff: Forwarded to ``EdgeLengthNormalizer``.
        num_bessels: Forwarded to ``BesselEdgeLengthEncoding``.
        bessel_trainable: ``trainable`` flag of ``BesselEdgeLengthEncoding``.
        polynomial_cutoff_p: Argument of ``PolynomialCutoff``.
        avg_num_neighbors: Forwarded to every layer's convolution kwargs.  A
            warning is emitted when it is ``None``.
        convnet_resnet: When true, ``resnet`` is enabled on every layer except
            the first.
        convnet_nonlinearity_type: Forwarded to every ``ConvNetLayer``.
        convnet_nonlinearity_scalars: Forwarded to every ``ConvNetLayer``.
            ``None`` (the default) means ``{"e": "silu", "o": "tanh"}``.
            NOTE(review): values are presumably activation names understood by
            ``ConvNetLayer`` -- confirm against its documentation.
        convnet_nonlinearity_gates: Same handling as
            ``convnet_nonlinearity_scalars``, forwarded as
            ``nonlinearity_gates``.
        task_head_specs: Accepted for interface compatibility; this
            constructor does not read it.
        auxbasis: Forwarded to ``AuxdensityHeadForNequip``.

    Raises:
        AssertionError: If a type name is not alphanumeric, or if the three
            per-layer sequences differ in length.
    """

    def __init__(
        self,
        r_max: float,
        type_names: Sequence[str],
        # convnet params
        radial_mlp_depth: Sequence[int],
        radial_mlp_width: Sequence[int],
        feature_irreps_hidden: Sequence[str | o3.Irreps],
        # irreps and dims
        irreps_edge_sh: int | str | o3.Irreps,
        type_embed_num_features: int,
        # edge length encoding
        per_edge_type_cutoff: Optional[dict[str, float | dict[str, float]]] = None,
        num_bessels: int = 8,
        bessel_trainable: bool = False,
        polynomial_cutoff_p: int = 6,
        # edge sum normalization
        avg_num_neighbors: Optional[float] = None,
        # == things that generally shouldn't be changed ==
        # convnet
        convnet_resnet: bool = False,
        convnet_nonlinearity_type: str = "gate",
        convnet_nonlinearity_scalars: Optional[dict[str, str]] = None,
        convnet_nonlinearity_gates: Optional[dict[str, str]] = None,
        task_head_specs: Optional[dict[str, Any]] = None,

        auxbasis: str = 'def2-universal-jfit',
    ):
        super().__init__()

        # Resolve `None` sentinels to fresh dicts so that no dict object is
        # shared between separate constructor calls (mutable-default pitfall).
        if convnet_nonlinearity_scalars is None:
            convnet_nonlinearity_scalars = {"e": "silu", "o": "tanh"}
        if convnet_nonlinearity_gates is None:
            convnet_nonlinearity_gates = {"e": "silu", "o": "tanh"}

        self.type_names = type_names

        # === sanity checks and warnings ===
        assert all(
            tn.isalnum() for tn in type_names
        ), "`type_names` must contain only alphanumeric characters"

        # require every convnet layer to be specified explicitly in a list
        # infer num_layers from the list size
        assert (
            len(radial_mlp_depth) == len(radial_mlp_width) == len(feature_irreps_hidden)
        ), f"radial_mlp_depth: {radial_mlp_depth}, radial_mlp_width: {radial_mlp_width}, feature_irreps_hidden: {feature_irreps_hidden} should all have the same length"
        num_layers = len(radial_mlp_depth)

        if avg_num_neighbors is None:
            warnings.warn(
                "Found `avg_num_neighbors=None` -- it is recommended to set `avg_num_neighbors` for normalization and better numerics during training."
            )

        # === encode and embed features ===
        # Each module receives the previous module's `irreps_out` as its
        # `irreps_in`, so the construction order below is also the data flow.
        # == edge tensor embedding ==
        spharm = SphericalHarmonicEdgeAttrs(
            irreps_edge_sh=irreps_edge_sh,
            component_order='std',
        )
        # == edge scalar embedding ==
        edge_norm = EdgeLengthNormalizer(
            r_max=r_max,
            type_names=type_names,
            per_edge_type_cutoff=per_edge_type_cutoff,
            irreps_in=spharm.irreps_out,
        )
        bessel_encode = BesselEdgeLengthEncoding(
            num_bessels=num_bessels,
            trainable=bessel_trainable,
            cutoff=PolynomialCutoff(polynomial_cutoff_p),
            edge_invariant_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
            irreps_in=edge_norm.irreps_out,
        )
        # for backwards compatibility of NequIP's bessel encoding
        factor = ApplyFactor(
            in_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
            factor=(2 * math.pi) / (r_max * r_max),
            irreps_in=bessel_encode.irreps_out,
        )
        # == node scalar embedding ==
        type_embed = NodeTypeEmbed(
            type_names=type_names,
            num_features=type_embed_num_features,
            irreps_in=factor.irreps_out,
        )
        modules = {
            "spharm": spharm,
            "edge_norm": edge_norm,
            "bessel_encode": bessel_encode,
            "factor": factor,
            "type_embed": type_embed,
        }
        prev_irreps_out = type_embed.irreps_out

        # === convnet layers ===
        for layer_i in range(num_layers):
            is_first_layer = layer_i == 0
            current_convnet = ConvNetLayer(
                irreps_in=prev_irreps_out,
                feature_irreps_hidden=feature_irreps_hidden[layer_i],
                convolution_kwargs={
                    "radial_mlp_depth": radial_mlp_depth[layer_i],
                    "radial_mlp_width": radial_mlp_width[layer_i],
                    "avg_num_neighbors": avg_num_neighbors,
                    # to ensure isolated atom limit
                    "use_sc": not is_first_layer,
                },
                # the first layer never uses a resnet connection
                resnet=(not is_first_layer) and convnet_resnet,
                nonlinearity_type=convnet_nonlinearity_type,
                nonlinearity_scalars=convnet_nonlinearity_scalars,
                nonlinearity_gates=convnet_nonlinearity_gates,
            )
            prev_irreps_out = current_convnet.irreps_out
            modules[f"layer{layer_i}_convnet"] = current_convnet

        # === assemble in SequentialGraphNetwork ===
        self.backbone = SequentialGraphNetwork(modules)

        # === readout ===
        # irreps produced by the last backbone module (the type embedding when
        # there are no convnet layers)
        self.backbone_irreps_out = prev_irreps_out
        self.task_head = SequentialGraphNetwork({
            'auxdensity_atom_readout': AuxdensityHeadForNequip(
                type_names=self.type_names,
                auxbasis=auxbasis,
                field=AtomicDataDict.NODE_FEATURES_KEY,
                out_field='output:auxdensity',
                biases=True,
                irreps_in=self.backbone_irreps_out,
            ),
        })

    def convert_inputs(self, inputs):
        """Return a copy of ``inputs`` with the ``AtomicDataDict`` keys filled in.

        The copy (made with ``inputs.copy()``) keeps every original entry and
        additionally maps ``inputs['z']`` to ``AtomicDataDict.ATOM_TYPE_KEY``,
        ``inputs['pos']`` to ``AtomicDataDict.POSITIONS_KEY`` and
        ``inputs['edge_index']`` to ``AtomicDataDict.EDGE_INDEX_KEY``.
        ``inputs`` itself is not modified by this method.

        NOTE(review): assumes ``inputs`` is a dict-like mapping -- confirm
        against callers.
        """
        ret = inputs.copy()
        ret.update({
            AtomicDataDict.ATOM_TYPE_KEY: inputs['z'],
            AtomicDataDict.POSITIONS_KEY: inputs['pos'],
            AtomicDataDict.EDGE_INDEX_KEY: inputs['edge_index'],
        })
        return ret

    def forward(self, data):
        """Convert ``data``, run the backbone, then the task head.

        Returns whatever ``self.task_head`` returns for the backbone output.
        """
        data = self.convert_inputs(data)
        data = self.backbone(data)
        data = self.task_head(data)
        return data


def nequip_simple_builder(
    num_layers: int = 4,
    l_max: int = 1,
    parity: bool = True,
    num_features: int = 32,
    radial_mlp_depth: int = 2,
    radial_mlp_width: int = 64,
    **kwargs
) -> nn.Module:
    """Build a ``NequipArch`` from a handful of scalar hyperparameters.

    The same hidden irreps, radial MLP depth and radial MLP width are used for
    every layer.

    Args:
        num_layers: Length of the per-layer depth/width lists handed to
            ``NequipArch``.
        l_max: Maximum ``l`` of the edge spherical harmonics and of the hidden
            feature irreps.
        parity: When true, spherical harmonics are built with ``p=-1`` and the
            hidden irreps contain both parities (``1`` then ``-1``); otherwise
            ``p=1`` and only parity ``1`` is used.
        num_features: Multiplicity of every hidden irrep; also passed as
            ``type_embed_num_features``.
        radial_mlp_depth: Value repeated for every layer.
        radial_mlp_width: Value repeated for every layer.
        **kwargs: Forwarded unchanged to ``NequipArch``.

    Returns:
        The constructed ``NequipArch`` instance.
    """
    sh_parity = -1 if parity else 1
    edge_sh = repr(o3.Irreps.spherical_harmonics(lmax=l_max, p=sh_parity))

    # Parity is the outer loop and `l` the inner one, so all irreps of one
    # parity are listed before those of the other.
    parities = (1, -1) if parity else (1,)
    hidden_terms = []
    for p in parities:
        for l in range(l_max + 1):
            hidden_terms.append((num_features, (l, p)))
    hidden_irreps = repr(o3.Irreps(hidden_terms))

    # `num_layers - 1` leading entries plus one final entry.
    hidden_per_layer = [hidden_irreps for _ in range(num_layers - 1)]
    depth_per_layer = [radial_mlp_depth for _ in range(num_layers)]
    width_per_layer = [radial_mlp_width for _ in range(num_layers)]
    hidden_per_layer.append(hidden_irreps)

    return NequipArch(
        irreps_edge_sh=edge_sh,
        type_embed_num_features=num_features,
        feature_irreps_hidden=hidden_per_layer,
        radial_mlp_depth=depth_per_layer,
        radial_mlp_width=width_per_layer,
        **kwargs,
    )
