from typing import Callable

import e3nn_jax as e3nn
import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils.typing import PRECISION, FloatNxF

from .encoder import EncoderCache
from .mlp import MLP


class NumericDecoder(nn.Module):
    """Decode per-atom irreps features into per-grid-point spatial features.

    Each irrep block of the atom features is contracted with the matching
    spherical-harmonics block and the radial basis values taken from the
    encoder cache, weighted by MLP-generated radial weights, and the result
    is scatter-added onto an ``(N, spatial_feature_dim)`` output array.
    """

    spatial_feature_dim: int  # output spatial feature dimension
    # Nonlinearity forwarded to the MLPs that generate the radial weights.
    activation: Callable[[jax.Array], jax.Array] = nn.silu

    @nn.compact
    def __call__(
        self,
        atom_features: e3nn.IrrepsArray,
        cache: EncoderCache,
    ) -> FloatNxF:
        """
        Assumes that atom_features of padded atoms are constant zero.
        Nuclei partitioning is handled in the encoder/included in the radial basis values.
        Handles variable per-L multiplicities in `atom_features` by first projecting
        each irrep block to a fixed per-L multiplicity = `spatial_feature_dim`
        using `e3nn.flax.Linear` (no mixing across different irreps).

        Args:
            atom_features: per-atom features. Its '0e' part is used as the
                input of the radial-weight MLPs.
            cache: encoder cache, unpacked as
                `(N, truncated_idx, spherical_harmonics, radial_basis_vals)`.
                `radial_basis_vals` is rank 3 and contracted as
                (atom, truncated grid point, radial basis function);
                `truncated_idx` indexes the first axis of the output.

        Returns:
            Array of shape `(N, spatial_feature_dim)` holding the accumulated
            contributions of all irrep blocks.

        Raises:
            AssertionError: if the irreps (ignoring multiplicities) of the
                cached spherical harmonics differ from those of the projected
                atom features.
        """
        N, truncated_idx, spherical_harmonics, radial_basis_vals = cache

        _, _, RBF = radial_basis_vals.shape
        F_out = self.spatial_feature_dim

        inv_features = atom_features.filter('0e').array  # type: ignore

        # Project each irrep block (mul_l x irrep_l) -> (F_out x irrep_l)
        projected_irreps = e3nn.Irreps(
            ' + '.join(f'{F_out}x{ir}' for _, ir in atom_features.irreps)
        )
        atom_features_proj = e3nn.flax.Linear(projected_irreps)(atom_features)

        sph_irreps_no_mul = [ir for _, ir in spherical_harmonics.irreps]
        feat_irreps_no_mul = [ir for _, ir in atom_features_proj.irreps]
        # Raised explicitly instead of via `assert` so the check is not stripped
        # under `python -O`; without it the zip below would silently pair
        # mismatched irrep blocks. AssertionError is kept for compatibility.
        if sph_irreps_no_mul != feat_irreps_no_mul:
            raise AssertionError(
                'Irreps mismatch between encoder spherical harmonics and atom_features; '
                f'encoder={spherical_harmonics.irreps}, features={atom_features_proj.irreps}'
            )

        spatial_features = jnp.zeros((N, F_out))
        for sph_h, atom_feats_h in zip(
            spherical_harmonics.chunks,
            atom_features_proj.chunks,
        ):
            # Drop the multiplicity axis of the spherical-harmonics chunk.
            sph_h = sph_h.squeeze(-2)  # type: ignore
            # MLP maps rotation invariant features -> radial weights for each output feature.
            # NOTE: a new MLP is instantiated on every iteration, i.e. one per
            # irrep block; presumably intentional (separate weights per block),
            # so it is deliberately not hoisted out of the loop.
            rbf_to_f = MLP(
                [
                    RBF * F_out,
                    RBF * F_out,
                    RBF * F_out,
                ],
                activation=self.activation,  # type: ignore
                dtype=PRECISION.decoding,
            )(inv_features).reshape(-1, RBF, F_out) / jnp.sqrt(RBF)

            # A / a: atom,  T / t: truncated grid points, F / f: spatial features
            # M / m: magnetic quantum number, R / r: radial basis functions
            sparse_decoded_spatial_feats = jnp.einsum(
                'atr,atm,afm,arf->tf',
                radial_basis_vals,
                sph_h,
                atom_feats_h,  # masked out atoms have constant zero features
                rbf_to_f,
            )
            # Scatter-add the per-truncated-point result onto the output rows.
            spatial_features = spatial_features.at[truncated_idx].add(
                sparse_decoded_spatial_feats
            )

        return spatial_features
