from typing import Callable, Tuple

import e3nn_jax as e3nn
import flax.linen as nn
import jax

from egxc.utils.typing import (
    PRECISION,
    Bool,
    BoolA,
    Float1,
    FloatAx3,
    NnParams,
    PyTree,
)

from .mlp import MLP


class GraphReadout(nn.Module):
    """Map equivariant node features to scalar features via Linear (+ MLP).

    The input is first projected by an ``e3nn.flax.Linear`` layer onto
    ``hidden_dims[0]`` scalar (``'0e'``) channels.  When more than one hidden
    dimension is configured, the activation is applied to the projected
    scalars and the result is passed through an ``MLP`` built from the
    remaining dimensions; otherwise the projected scalars are returned as is.
    """

    hidden_dims: Tuple[int, ...]
    init_last_layer_to_zero: bool
    activation: Callable[[jax.Array], jax.Array] = jax.nn.silu

    @nn.compact
    def __call__(self, node_features: e3nn.IrrepsArray):
        """Return the plain array obtained by reading out ``node_features``."""
        # Only scalar channels are requested, so the underlying array of the
        # projection can be handed to ordinary (non-equivariant) layers.
        scalar_irreps = self.hidden_dims[0] * e3nn.Irreps('0e')
        projected = e3nn.flax.Linear(scalar_irreps)(node_features)
        scalars = projected.array

        remaining_dims = self.hidden_dims[1:]
        if not remaining_dims:
            # A single hidden dimension: the linear projection is the output.
            return scalars

        activated = self.activation(scalars)
        head = MLP(
            remaining_dims,
            activation=self.activation,
            dtype=PRECISION.graph_readout,
            init_last_layer_to_zero=self.init_last_layer_to_zero,
        )
        return head(activated)


class BaseGNN(nn.Module):
    """Shared configuration and helpers for GNN modules.

    ``__call__`` is only a signature stub here (its body is ``...``);
    presumably concrete subclasses override it -- confirm against the
    subclasses in this package.
    """

    irreps_str: str
    output_irreps_str: str
    message_cutoff: float
    layers: int
    energy_graph_readout_hidden_dims: Tuple[int, ...]
    n_radial_basis: int  # number of RBF used in the message passing layer
    init_graph_readout_to_zero: bool = True

    def __call__(
        self,
        atom_features: e3nn.IrrepsArray,  # (A, Irreps[RBF, (l,m)])
        atom_pos: FloatAx3,
        atom_mask: BoolA,
    ) -> Tuple[Float1, e3nn.IrrepsArray]:
        """Signature stub; this base implementation does nothing."""
        ...

    def graph_readout_decay_mask(self, params: NnParams) -> PyTree[Bool]:
        """
        Return a PyTree[bool] mask: True where to apply weight decay.

        The mask has the same structure as ``params``.  A leaf is marked
        ``True`` when any entry of its key path names a ``GraphReadout``
        submodule (e.g. ``'GraphReadout_0'``); every leaf below such a
        submodule is selected, regardless of the leaf's own name.

        Args:
            params: Parameter PyTree, e.g. with paths like
                ('params', 'GraphReadout_0', 'Dense_0', 'kernel').

        Returns:
            A PyTree of Python bools mirroring ``params``.
        """
        prefix = GraphReadout.__name__

        def entry_name(entry) -> str:
            # Key-path entries are wrapper objects (DictKey, GetAttrKey,
            # SequenceKey, ...), not raw strings.  ``str(DictKey('a'))`` is
            # "['a']", so matching a prefix on ``str(entry)`` never succeeds
            # for dict-based params; read the wrapped name instead.
            for attr in ('key', 'name', 'idx'):
                if hasattr(entry, attr):
                    return str(getattr(entry, attr))
            # Plain segments (e.g. already-unwrapped strings) are used as is.
            return str(entry)

        def should_decay(path) -> bool:
            return any(entry_name(seg).startswith(prefix) for seg in path)

        return jax.tree_util.tree_map_with_path(
            lambda path, _leaf: should_decay(path),
            params,
        )
