# An architecture that combines the best of:
# - NequIP: simplicity
# - MACE: polynomial structure
# - eSCN: performance

from typing import Callable, Optional, Union

import e3nn_jax as e3nn
import flax
import haiku as hk
import jax
import jax.numpy as jnp
from e3nn_jax.experimental.linear_shtp import LinearSHTP

from symphony import datatypes


class MarioNetteLayerFlax(flax.linen.Module):
    """Flax front-end for a single MarioNette interaction layer.

    The dataclass fields hold the layer hyperparameters; the actual
    computation is delegated to ``_impl`` together with the Flax flavours of
    ``e3nn_jax``'s ``Linear`` and ``MultiLayerPerceptron``.
    """

    # Aggregated messages are divided by sqrt(avg_num_neighbors).
    avg_num_neighbors: float
    # Number of weight sets of the species-indexed self-connection.
    num_species: int = 1
    # Irreps of the layer output.
    output_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e")
    # Irreps used for the convolution / hidden interaction features.
    interaction_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e")
    # Passed as ``max_norm`` to ``soft_normalization``.
    soft_normalization: float = 1e5
    # Scalar activations for even / odd parity scalars.
    even_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu
    odd_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.tanh
    # Radial MLP configuration.
    mlp_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu
    mlp_n_hidden: int = 64
    mlp_n_layers: int = 2
    # Radial basis: Bessel expansion size and whether to use it at all.
    n_radial_basis: int = 8
    use_bessel: bool = True

    @flax.linen.compact
    def __call__(
        self,
        vectors: e3nn.IrrepsArray,
        node_feats: e3nn.IrrepsArray,
        node_specie: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
    ):
        """Apply the layer; see ``_impl`` for shapes and the computation."""
        linear_cls = e3nn.flax.Linear
        mlp_cls = e3nn.flax.MultiLayerPerceptron
        return _impl(
            linear_cls,
            mlp_cls,
            self,
            vectors=vectors,
            node_feats=node_feats,
            node_specie=node_specie,
            senders=senders,
            receivers=receivers,
        )


class MarioNetteLayerHaiku(hk.Module):
    """Haiku front-end for a single MarioNette interaction layer.

    Stores the layer hyperparameters as attributes and delegates the
    computation to ``_impl`` with the Haiku flavours of ``e3nn_jax``'s
    ``Linear`` and ``MultiLayerPerceptron``.
    """

    def __init__(
        self,
        avg_num_neighbors: float,
        num_species: int = 1,
        output_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e"),
        interaction_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e"),
        soft_normalization: float = 1e5,
        even_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu,
        odd_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.tanh,
        mlp_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu,
        mlp_n_hidden: int = 64,
        mlp_n_layers: int = 2,
        n_radial_basis: int = 8,
        use_bessel: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        # Aggregation / self-connection settings.
        self.avg_num_neighbors = avg_num_neighbors
        self.num_species = num_species
        # Irreps of the hidden interaction features and of the output.
        self.interaction_irreps = interaction_irreps
        self.output_irreps = output_irreps
        # Non-linearities and output damping.
        self.even_activation = even_activation
        self.odd_activation = odd_activation
        self.soft_normalization = soft_normalization
        # Radial network configuration.
        self.mlp_activation = mlp_activation
        self.mlp_n_hidden = mlp_n_hidden
        self.mlp_n_layers = mlp_n_layers
        self.n_radial_basis = n_radial_basis
        self.use_bessel = use_bessel

    def __call__(
        self,
        vectors: e3nn.IrrepsArray,
        node_feats: e3nn.IrrepsArray,
        node_specie: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
    ):
        """Apply the layer; see ``_impl`` for shapes and the computation."""
        linear_cls = e3nn.haiku.Linear
        mlp_cls = e3nn.haiku.MultiLayerPerceptron
        return _impl(
            linear_cls,
            mlp_cls,
            self,
            vectors=vectors,
            node_feats=node_feats,
            node_specie=node_specie,
            senders=senders,
            receivers=receivers,
        )


def _impl(
    Linear: Callable,
    MultiLayerPerceptron: Callable,
    self: Union[MarioNetteLayerFlax, MarioNetteLayerHaiku],
    vectors: e3nn.IrrepsArray,  # [n_edges, 3]
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    node_specie: jnp.ndarray,  # [n_nodes] int between 0 and num_species-1
    senders: jnp.ndarray,  # [n_edges]
    receivers: jnp.ndarray,  # [n_edges]
) -> e3nn.IrrepsArray:
    """Run one MarioNette interaction layer; shared by the Flax and Haiku wrappers.

    Args:
        Linear: framework-specific e3nn linear module class
            (``e3nn.flax.Linear`` or ``e3nn.haiku.Linear`` in this file).
        MultiLayerPerceptron: framework-specific e3nn MLP class.
        self: the calling wrapper; only its hyperparameter attributes are read.
        vectors: edge vectors, shape ``[n_edges, 3]``.
        node_feats: input node features, shape ``[n_nodes, irreps.dim]``.
        node_specie: integer species per node (0 .. num_species-1), passed as
            the weight index of the ``skip_tp`` self-connection.
        senders, receivers: edge endpoint indices, shape ``[n_edges]``.

    Returns:
        Node features with irreps ``self.output_irreps``, shape
        ``[n_nodes, output_irreps.dim]``.

    Pipeline: species-indexed self-connection; ``linear_up``; gather sender
    features; per-edge ``LinearSHTP`` whose weights are predicted by a radial
    MLP; scatter-add onto receivers scaled by ``1/sqrt(avg_num_neighbors)``;
    ``linear_down`` -> activation -> ``linear_out`` -> soft normalization;
    fixed-weight sum with the self-connection.

    NOTE(review): submodules are created in this exact order and the MLP has
    no explicit ``name``, so its parameter path is framework-generated and
    presumably order-dependent. Avoid reordering statements here.
    """
    n_edge = vectors.shape[0]
    n_node = node_feats.shape[0]
    assert vectors.shape == (n_edge, 3)
    assert node_feats.shape == (n_node, node_feats.irreps.dim)
    assert node_specie.shape == (n_node,)
    assert senders.shape == (n_edge,)
    assert receivers.shape == (n_edge,)

    # Normalize hyperparameters to Irreps (MarioNette passes them as strings).
    interaction_irreps = e3nn.Irreps(self.interaction_irreps)
    output_irreps = e3nn.Irreps(self.output_irreps)

    # Self connection: species-indexed linear map of the *input* features to
    # output_irreps (computed before node_feats is overwritten below).
    self_connection = Linear(
        output_irreps, num_indexed_weights=self.num_species, name="skip_tp"
    )(
        node_specie, node_feats
    )  # [n_nodes, output_irreps]

    # Channel mixing that keeps the input irreps unchanged.
    node_feats = Linear(node_feats.irreps, name="linear_up")(node_feats)

    # Gather the features of each edge's sender node.
    messages = node_feats[senders]

    # LinearSHTP is used functionally: conv.init is evaluated on a single edge
    # with a fixed key only to obtain the parameter pytree's structure and
    # shapes. Its values are not used (unflatten reads only .size/.shape of the
    # template); the real per-edge weights come from the radial MLP below.
    # NOTE(review): messages[0] / vectors[0] presumably require n_edges >= 1.
    conv = LinearSHTP(interaction_irreps, mix=False)
    w_unused = conv.init(jax.random.PRNGKey(0), messages[0], vectors[0])
    w_unused_flat = flatten(w_unused)

    # Radial part: smooth envelope of the edge length, optionally multiplied by
    # a Bessel basis. No cutoff argument is passed, so e3nn's defaults apply;
    # presumably edge vectors are pre-scaled (MarioNette divides by r_max).
    lengths = e3nn.norm(vectors).array  # [n_edges, 1]
    radial = e3nn.soft_envelope(lengths)  # [n_edges, 1]
    if self.use_bessel:
        radial = (
            e3nn.bessel(lengths[:, 0], self.n_radial_basis) * radial
        )  # [n_edges, n_radial_basis]

    # Predict one full flattened LinearSHTP parameter vector per edge.
    mix = MultiLayerPerceptron(
        self.mlp_n_layers * (self.mlp_n_hidden,) + (w_unused_flat.size,),
        self.mlp_activation,
        output_activation=False,
    )(radial)

    # Discard 0 length edges that come from graph padding
    # (lengths is [n_edges, 1] and broadcasts over the weight axis).
    mix = jnp.where(lengths == 0.0, 0.0, mix)

    # vmap over edges: rebuild each edge's parameter pytree, then apply the conv.
    w = jax.vmap(unflatten, (0, None))(mix, w_unused)
    messages = jax.vmap(conv.apply)(w, messages, vectors)
    assert messages.shape == (n_edge, messages.irreps.dim)

    # Message passing: sum incoming messages per receiver node, then scale by
    # 1/sqrt(avg_num_neighbors).
    zeros = e3nn.IrrepsArray.zeros(
        messages.irreps, node_feats.shape[:1], messages.dtype
    )
    node_feats = zeros.at[receivers].add(messages)  # [n_nodes, irreps]
    node_feats = node_feats / jnp.sqrt(self.avg_num_neighbors)

    node_feats = Linear(interaction_irreps, name="linear_down")(node_feats)

    # Activation (scalar activations plus a tensor square, see `activation`),
    # then project to the output irreps.
    node_feats = activation(node_feats, self.even_activation, self.odd_activation)
    node_feats = Linear(output_irreps, name="linear_out")(node_feats)

    # Soft normalization: per-irrep norm damping with threshold
    # self.soft_normalization.
    node_feats = soft_normalization(node_feats, self.soft_normalization)

    # Residual combination with fixed weights 0.9 / 0.45; the rationale for
    # these constants is not documented here.
    node_feats = 0.9 * self_connection + 0.45 * node_feats  # [n_nodes, irreps]

    assert node_feats.irreps == output_irreps
    assert node_feats.shape == (n_node, output_irreps.dim)
    return node_feats


def activation(
    x: e3nn.IrrepsArray, even_activation, odd_activation
) -> e3nn.IrrepsArray:
    """Apply scalar activations and append the tensor square of the result.

    Scalars of ``x`` go through ``even_activation`` / ``odd_activation`` via
    ``e3nn.scalar_activation``. The activated features are then squared with
    ``e3nn.tensor_square`` (after moving the multiplicity onto an array axis)
    and concatenated after the activated features.
    """
    gated = e3nn.scalar_activation(
        x, even_act=even_activation, odd_act=odd_activation
    )
    squared = e3nn.tensor_square(gated.mul_to_axis()).axis_to_mul()
    return e3nn.concatenate([gated, squared])


def soft_normalization(x: e3nn.IrrepsArray, max_norm: float = 1.0) -> e3nn.IrrepsArray:
    """Damp each irrep of ``x`` according to its norm.

    Every irrep is scaled by ``1 / (1 + m * sus(m))`` where ``m`` is its norm
    divided by ``max_norm``; the same scaling function is used for all irreps.
    """

    def _scale(norm):
        ratio = norm / max_norm
        return 1.0 / (1.0 + ratio * e3nn.sus(ratio))

    scales = [_scale for _ in x.irreps]
    return e3nn.norm_activation(x, scales)


def flatten(w):
    """Concatenate every leaf of pytree ``w`` into a single 1-D array.

    Leaves are visited in ``jax.tree_util.tree_leaves`` order, which is the
    order ``unflatten`` consumes them in. Each leaf is raveled, so the result
    has as many entries as there are elements in all leaves together.

    A pytree with no leaves yields an empty array instead of raising, because
    ``jnp.concatenate`` rejects an empty sequence.
    """
    leaves = jax.tree_util.tree_leaves(w)
    if not leaves:
        return jnp.zeros((0,))
    # jnp.ravel also accepts Python scalars / NumPy arrays, not just jax arrays.
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])


def unflatten(array, template):
    """Split a 1-D ``array`` back into a pytree shaped like ``template``.

    Consecutive slices of ``array`` are reshaped to each template leaf's
    shape, in ``tree_flatten`` order. Only the leaves' ``size`` and ``shape``
    are read; their values are ignored.
    """
    leaves, treedef = jax.tree_util.tree_flatten(template)
    pieces = []
    offset = 0
    for leaf in leaves:
        end = offset + leaf.size
        pieces.append(array[offset:end].reshape(leaf.shape))
        offset = end
    return treedef.unflatten(pieces)


class MarioNette(hk.Module):
    """Haiku model that stacks ``num_interactions`` MarioNette layers.

    Nodes are embedded from their species, edges are described by the
    sender-to-receiver displacement divided by ``r_max``, and the final
    features are rescaled per irrep by ``alpha * alphal ** l``.
    """

    def __init__(
        self,
        num_species: int,
        r_max: float,
        avg_num_neighbors: float,
        init_embedding_dims: int,
        output_irreps: str,
        soft_normalization: float,
        num_interactions: int,
        even_activation: Callable[[jnp.ndarray], jnp.ndarray],
        odd_activation: Callable[[jnp.ndarray], jnp.ndarray],
        mlp_activation: Callable[[jnp.ndarray], jnp.ndarray],
        mlp_n_hidden: int,
        mlp_n_layers: int,
        n_radial_basis: int,
        use_bessel: bool,
        alpha: float,
        alphal: float,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        # Input embedding and edge-vector scaling.
        self.num_species = num_species
        self.init_embedding_dims = init_embedding_dims
        self.r_max = r_max
        # Stack depth and per-layer hyperparameters.
        self.num_interactions = num_interactions
        self.avg_num_neighbors = avg_num_neighbors
        self.output_irreps = output_irreps
        self.soft_normalization = soft_normalization
        self.even_activation = even_activation
        self.odd_activation = odd_activation
        self.mlp_activation = mlp_activation
        self.mlp_n_hidden = mlp_n_hidden
        self.mlp_n_layers = mlp_n_layers
        self.n_radial_basis = n_radial_basis
        self.use_bessel = use_bessel
        # Final per-irrep output scaling: alpha * alphal ** l.
        self.alpha = alpha
        self.alphal = alphal

    def __call__(
        self,
        graphs: datatypes.Fragments,
    ):
        """Embed ``graphs`` and run the interaction stack; returns node features."""
        positions = graphs.nodes.positions
        # Receiver minus sender position, divided by r_max (presumably so that
        # edges inside the cutoff have length <= 1).
        displacements = (
            positions[graphs.receivers] - positions[graphs.senders]
        ) / self.r_max
        edge_vectors = e3nn.IrrepsArray("1o", displacements)

        species = graphs.nodes.species
        # The embedding is created before the layers, preserving module order.
        node_feats = hk.Embed(self.num_species, self.init_embedding_dims)(species)
        node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats)

        # Every layer shares the same hyperparameters; output_irreps doubles as
        # the interaction irreps.
        layer_kwargs = dict(
            avg_num_neighbors=self.avg_num_neighbors,
            num_species=self.num_species,
            output_irreps=self.output_irreps,
            interaction_irreps=self.output_irreps,
            soft_normalization=self.soft_normalization,
            even_activation=self.even_activation,
            odd_activation=self.odd_activation,
            mlp_activation=self.mlp_activation,
            mlp_n_hidden=self.mlp_n_hidden,
            mlp_n_layers=self.mlp_n_layers,
            n_radial_basis=self.n_radial_basis,
            use_bessel=self.use_bessel,
        )
        for _ in range(self.num_interactions):
            layer = MarioNetteLayerHaiku(**layer_kwargs)
            node_feats = layer(
                edge_vectors, node_feats, species, graphs.senders, graphs.receivers
            )

        per_irrep_scale = self.alpha * (self.alphal ** jnp.array(node_feats.irreps.ls))
        return node_feats * per_irrep_scale
