# Copied, cleaned up, and extended from https://github.com/mariogeiger/nequip-jax/tree/main
import functools
from typing import Callable, Tuple

import e3nn_jax as e3nn
import flax.linen as nn
import jax
import jax.numpy as jnp
from e3nn_jax.experimental.linear_shtp import LinearSHTP

from egxc.utils.typing import BoolA, Float1, FloatAx3, FloatAxN, FloatAxNxRBF
from egxc.xc_energy.functionals.learnable.nn.radial_basis_fns import _trigonometric_rbf

from .gnn import BaseGNN, GraphReadout


def _default_radial_basis(r: FloatAxN, n: int) -> FloatAxNxRBF:
    """
    Evaluate the default radial basis used by the NequIP layer.

    Args:
        r: Edge lengths; a trailing axis is appended before the basis is evaluated.
        n: Number of radial basis functions, forwarded to `_trigonometric_rbf`.

    Returns:
        The output of `_trigonometric_rbf` evaluated with the polynomial
        envelope `e3nn.poly_envelope(5, 2)` and `add_constant=True`.
    """
    distances = r[..., None]
    envelope = e3nn.poly_envelope(5, 2)
    return _trigonometric_rbf(distances, n, envelope=envelope, add_constant=True)


class _NEQUIPESCNLayerFlax(nn.Module):
    """
    One NequIP interaction layer, using the large-L optimization described in
    https://doi.org/10.48550/arXiv.2302.03655 with additional parity support.

    The module only holds the hyper-parameters; the computation itself lives in
    `_impl`, which receives the flax flavours of the e3nn `Linear` and
    `MultiLayerPerceptron` modules.
    """

    # Aggregated messages are divided by sqrt(avg_num_neighbors).
    avg_num_neighbors: float
    # Target irreps of the layer (wrapped in `e3nn.Irreps` and regrouped in `_impl`).
    irreps_out_str: e3nn.Irreps
    # Number of indexed weight sets of the species-dependent skip connection.
    num_species: int = 1
    # Activations handed to `e3nn.gate`.
    even_activation: Callable[[jax.Array], jax.Array] = jax.nn.silu
    odd_activation: Callable[[jax.Array], jax.Array] = jax.nn.tanh
    gate_activation: Callable[[jax.Array], jax.Array] = jax.nn.silu
    # Radial MLP: activation, hidden width and number of hidden layers.
    mlp_activation: Callable[[jax.Array], jax.Array] = jax.nn.silu
    mlp_n_hidden: int = 64
    mlp_n_layers: int = 2
    # Radial basis function and the number of basis functions it is asked for.
    radial_basis: Callable[[jax.Array, int], FloatAxNxRBF] = _default_radial_basis
    n_radial_basis: int = 8

    @nn.compact
    def __call__(
        self,
        vectors: e3nn.IrrepsArray,
        node_feats: e3nn.IrrepsArray,
        node_species: jax.Array | None,
        senders: jax.Array,
        receivers: jax.Array,
    ):
        """Apply the layer by delegating to `_impl` with the flax e3nn modules."""
        return _impl(
            e3nn.flax.Linear,
            e3nn.flax.MultiLayerPerceptron,
            self,
            vectors=vectors,
            node_feats=node_feats,
            node_species=node_species,
            senders=senders,
            receivers=receivers,
        )


def _flatten(w):
    """Concatenate all leaves of the pytree `w`, each raveled to 1-D, into one vector."""
    leaves = jax.tree_util.tree_leaves(w)
    raveled = [leaf.ravel() for leaf in leaves]
    return jnp.concatenate(raveled)


def _unflatten(array, template):
    """
    Inverse of `_flatten`: cut the 1-D `array` into consecutive chunks and
    reshape them into a pytree with the structure and leaf shapes of `template`.
    """
    leaves, treedef = jax.tree_util.tree_flatten(template)
    pieces = []
    offset = 0
    for leaf in leaves:
        # Each leaf consumes `leaf.size` consecutive entries of the flat vector.
        stop = offset + leaf.size
        pieces.append(array[offset:stop].reshape(leaf.shape))
        offset = stop
    return jax.tree_util.tree_unflatten(treedef, pieces)


def _impl(
    Linear: Callable,
    MultiLayerPerceptron: Callable,
    self: _NEQUIPESCNLayerFlax,
    vectors: e3nn.IrrepsArray,  # [n_edges, 3]
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    node_species: jax.Array | None,  # [n_nodes] int between 0 and num_species-1
    senders: jax.Array,  # [n_edges]
    receivers: jax.Array,  # [n_edges]
):
    """
    Framework-agnostic body of the NequIP/eSCN interaction layer.

    Steps: (1) gated two-layer linear "message MLP" on the node features,
    gathered onto the edges via `senders`; (2) per-edge `LinearSHTP` whose
    weights are produced by a radial MLP of the edge length; (3) sum of the
    messages at `receivers`, normalised by sqrt(avg_num_neighbors);
    (4) species-indexed linear skip connection; (5) gated residual update MLP.

    Args:
        Linear: Constructor of the equivariant linear module
            (`e3nn.flax.Linear` when called from `_NEQUIPESCNLayerFlax`).
        MultiLayerPerceptron: Constructor of the MLP module
            (`e3nn.flax.MultiLayerPerceptron` when called from `_NEQUIPESCNLayerFlax`).
        self: The layer module, used only as the container of hyper-parameters.
        vectors: Edge vectors, shape [n_edges, 3].
        node_feats: Node features, shape [n_nodes, irreps.dim].
        node_species: Per-node species index, or None; passed as the first
            argument of the `skip_tp` linear layer.
        senders: Per-edge index of the node whose features form the message.
        receivers: Per-edge index of the node at which messages are summed.

    Returns:
        The updated node features, one row per node.
    """
    num_nodes = node_feats.shape[0]
    num_edges = vectors.shape[0]
    assert vectors.shape == (num_edges, 3)
    assert node_feats.shape == (num_nodes, node_feats.irreps.dim)
    assert senders.shape == (num_edges,)
    assert receivers.shape == (num_edges,)

    # Gate non-linearity with the activations configured on the layer.
    gate = functools.partial(
        e3nn.gate,
        even_act=self.even_activation,  # type: ignore
        odd_act=self.odd_activation,  # type: ignore
        even_gate_act=self.gate_activation,  # type: ignore
    )

    # we regroup the target irreps to make sure that gate activation
    # has the same irreps as the target
    output_irreps = e3nn.Irreps(self.irreps_out_str).regroup()

    # Message MLP
    # One extra 0e scalar is appended per non-scalar irrep of the input so that
    # `gate` has a gating scalar for each of them; `linear_up2` maps back to the
    # input irreps before the features are gathered onto the edges.
    gate_irreps = node_feats.irreps
    num_nonscalar = gate_irreps.filter(drop='0e + 0o').num_irreps  # type: ignore
    gate_irreps = gate_irreps + e3nn.Irreps(f'{num_nonscalar}x0e').simplify()
    messages = Linear(gate_irreps, name='linear_up')(node_feats)
    messages = Linear(node_feats.irreps, name='linear_up2')(gate(messages))
    messages = messages[senders]

    # Initialise the tensor-product module on a single edge only to obtain the
    # structure and total size of its parameters; the initial values themselves
    # are not used (the actual weights come from the radial MLP below).
    conv = LinearSHTP(output_irreps, mix=False)
    w_unused = conv.init(jax.random.PRNGKey(0), messages[0], vectors[0])
    w_unused_flat = _flatten(w_unused)

    # Radial part
    # The activation must vanish at 0 — presumably so that a zero radial-basis
    # input yields zero weights (assumes the MLP has no bias; TODO confirm).
    with jax.ensure_compile_time_eval():
        assert abs(self.mlp_activation(0.0)) < 1e-6  # type: ignore
    lengths = e3nn.norm(vectors).array
    # MLP: `mlp_n_layers` hidden layers of width `mlp_n_hidden`, followed by an
    # output layer with one entry per tensor-product weight (no output activation).
    mix = MultiLayerPerceptron(
        self.mlp_n_layers * (self.mlp_n_hidden,) + (w_unused_flat.size,),
        self.mlp_activation,
        output_activation=False,
    )(self.radial_basis(lengths[:, 0], self.n_radial_basis))

    # Discard 0 length edges that come from graph padding
    mix = jnp.where(lengths == 0.0, 0.0, mix)
    assert mix.shape == (num_edges, w_unused_flat.size)  # type: ignore

    # Reshape each edge's flat weight vector into the parameter pytree of `conv`
    # and apply the tensor product edge by edge.
    w = jax.vmap(_unflatten, (0, None))(mix, w_unused)
    messages = jax.vmap(conv.apply)(w, messages, vectors)
    messages = messages.astype(node_feats.dtype)  # type: ignore
    assert messages.shape == (num_edges, messages.irreps.dim)

    # Skip connection
    # Keep only the target irreps that the messages actually contain, then append
    # one 0e scalar per non-scalar irrep.
    irreps = output_irreps.filter(keep=messages.irreps)  # type: ignore
    num_nonscalar = irreps.filter(drop='0e + 0o').num_irreps  # type: ignore
    irreps = irreps + e3nn.Irreps(f'{num_nonscalar}x0e').simplify()

    # Computed from the *input* node features, before they are overwritten below.
    skip = Linear(
        irreps,
        num_indexed_weights=self.num_species,
        name='skip_tp',
        force_irreps_out=True,
    )(node_species, node_feats)

    # Message passing
    # Sum the messages arriving at each receiver and normalise.
    node_feats = e3nn.scatter_sum(messages, dst=receivers, output_size=num_nodes)  # type: ignore
    # NOTE(review): the normaliser is cast to float32 regardless of the feature
    # dtype — confirm this is intended for non-float32 inputs.
    node_feats = node_feats / jnp.sqrt(self.avg_num_neighbors).astype(jnp.float32)

    node_feats = Linear(irreps, name='linear_down')(node_feats)

    # Average of the two branches; the 1/sqrt(2) factor rescales their sum.
    out = (node_feats + skip) / jnp.sqrt(2)
    # NOTE(review): this checks `node_feats` (the `linear_down` output), not `out`.
    assert node_feats.shape == (num_nodes, node_feats.irreps.dim)

    # Update MLP
    # Gated residual block on `out`. `irreps` already includes the scalars
    # appended in the skip-connection step; further gate scalars are added here.
    num_nonscalar = irreps.filter(drop='0e + 0o').num_irreps  # type: ignore
    irreps_b_gate = irreps + e3nn.Irreps(f'{num_nonscalar}x0e').simplify()
    node_feats = Linear(irreps_b_gate, name='linear_down2')(out)
    node_feats = Linear(irreps, name='linear_down3')(gate(node_feats))
    out = (out + node_feats) / jnp.sqrt(2)

    # node_feats = node_feats.mul_to_axis()
    return out


class NequIP(BaseGNN):
    """
    NequIP-style GNN evaluated on a dense, statically shaped atom graph.

    Every ordered pair of the A atom slots is an edge. Pairs that involve a
    padding atom, as well as self-pairs, get their displacement vector replaced
    by (1, 1, 1) instead of being removed, so all shapes stay static.
    """

    avg_num_neighbors: float = 15  # computed on QM9(7) for 5 Angstrom cutoff

    @nn.compact
    def __call__(
        self,
        atom_features: e3nn.IrrepsArray,  # (A, Irreps[RBF, (l,m)])  with m,l as in Y_{l,m}
        atom_pos: FloatAx3,
        atom_mask: BoolA,
    ) -> Tuple[
        Float1, e3nn.IrrepsArray
    ]:  # energy, (A, Irreps[F_out, (l,m)])  with m,l as in Y_{l,m}
        """
        Run `self.layers` interaction layers and accumulate the readouts.

        A graph readout and a linear projection to `self.output_irreps_str` are
        taken from the input features and after every layer. Both running sums
        (`self.layers + 1` contributions each) are masked with `atom_mask` and
        divided by `self.layers`.

        Returns:
            The masked readout summed over all entries, and the masked per-atom
            output features.
        """
        # Static edge list: row-major enumeration of all A * A ordered pairs.
        n_atoms = atom_mask.shape[0]
        all_pairs = jnp.ones((n_atoms, n_atoms), dtype=bool)
        senders, receivers = jnp.nonzero(all_pairs, size=n_atoms**2)

        # Displacement pos[i] - pos[j] for pair (i, j), in units of the cutoff.
        displacements = (atom_pos[:, None] - atom_pos) / self.message_cutoff

        # An edge is valid if both atoms are real and it is not a self-pair.
        both_real = atom_mask[:, None] & atom_mask[None, :]
        off_diagonal = ~jnp.eye(n_atoms, dtype=bool)
        edge_is_valid = both_real & off_diagonal
        displacements = jnp.where(
            edge_is_valid[..., None], displacements, jnp.ones_like(displacements)
        )
        vectors = e3nn.IrrepsArray('1o', displacements.reshape(-1, 3))

        # TODO: decide whether an EquiLayerNorm on the node features is needed
        # before the first readout and after every layer.
        readout_args = (
            self.energy_graph_readout_hidden_dims,
            self.init_graph_readout_to_zero,
        )
        output_irreps = e3nn.Irreps(self.output_irreps_str)
        layer_irreps = e3nn.Irreps(self.irreps_str)

        # Contributions from the input features.
        graph_readout = GraphReadout(*readout_args)(atom_features)
        atom_features_out = e3nn.flax.Linear(output_irreps)(atom_features)

        for _ in range(self.layers):
            layer = _NEQUIPESCNLayerFlax(
                irreps_out_str=layer_irreps,
                avg_num_neighbors=self.avg_num_neighbors,
                n_radial_basis=self.n_radial_basis,
            )
            atom_features = layer(vectors, atom_features, None, senders, receivers)
            graph_readout += GraphReadout(*readout_args)(atom_features)
            atom_features_out += e3nn.flax.Linear(output_irreps)(atom_features)

        # Zero out padding atoms and normalise by the number of layers.
        node_mask = atom_mask[:, None]
        graph_readout = (graph_readout * node_mask).sum() / self.layers
        atom_features_out *= node_mask / self.layers
        return graph_readout, atom_features_out
