from typing import Tuple

import e3nn_jax as e3nn
import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils import linalg
from egxc.utils.typing import (
    BoolA,
    Float1,
    FloatAx3,
    FloatAxA,
    FloatAxAx3,
    FloatAxAx3xF,
    FloatAxF,
    FloatAxFx3,
)

from .gnn import BaseGNN, GraphReadout


def cosine_cutoff(x: jax.Array, cutoff: float) -> jax.Array:
    """Behler-style cosine cutoff function.

    Computes ``0.5 * (cos(pi * x / cutoff) + 1)`` for ``x < cutoff`` and 0
    elsewhere. The result decays smoothly from 1 at ``x = 0`` to 0 at
    ``x = cutoff``.

    Args:
        x: distances (any shape).
        cutoff: cutoff distance, in the same units as ``x``.

    Returns:
        Array with the same shape as ``x`` holding the envelope values.
    """
    envelope = 0.5 * (jnp.cos((jnp.pi / cutoff) * x) + 1.0)
    # The mask must test the *distance*, not the envelope value. The cosine
    # is periodic, so without this mask it would rise again beyond the cutoff.
    return jnp.where(x < cutoff, envelope, 0.0)


class ScalarFilter(nn.Module):
    """Learnable radial filter built from a sine basis of interatomic distances.

    Each distance d is expanded into ``sin(n * pi * d / cutoff_dist)`` for
    ``n = 1 .. n_radial_basis_fn``. A single dense layer then maps the basis
    to ``3 * atom_features`` channels.
    """

    cutoff_dist: float
    atom_features: int
    n_radial_basis_fn: int

    @nn.compact
    def __call__(self, x: FloatAxA) -> FloatAxAx3xF:
        """
        input: interatomic distances shape (N_atoms, N_atoms)
        output: scalar filter shape (N_atoms, N_atoms, 3 * n_features)
        """
        # Basis frequencies n * pi / r_cut, one per radial basis function.
        orders = jnp.arange(1, self.n_radial_basis_fn + 1)
        freqs = orders * jnp.pi / self.cutoff_dist

        # Broadcast to (N_atoms, N_atoms, n_basis) for a 2-D distance matrix.
        basis = jnp.sin(x[..., None] * freqs[None, None, :])

        projection = nn.Dense(3 * self.atom_features)
        return projection(basis)


class MessageBlock(nn.Module):
    """PaiNN message function: per-atom scalar and vector messages summed over all atoms j."""

    atom_features: int
    cutoff_dist: float
    n_radial_basis_fn: int

    @nn.compact
    def __call__(
        self, s: FloatAxF, v: FloatAxFx3, dr: FloatAxAx3, atom_mask: BoolA
    ) -> Tuple[FloatAxF, FloatAxFx3]:
        """
        Computes messages between atoms.
        Feature-wise updates across atoms.

        Args:
            s: scalar atom features
            v: equivariant atom features
            dr: distance vectors between atoms
            atom_mask: padding mask for atoms; masked atoms send no messages

        Returns:
            ds_msg: scalar messages
            dv_msg: equivariant messages
        """
        n_feat = self.atom_features
        dist = linalg.safe_norm(dr)

        # Atom-wise two-layer MLP on the scalar features of the sending atom.
        hidden = nn.silu(nn.Dense(n_feat)(s))
        phi = nn.Dense(3 * n_feat)(hidden)

        # Radial filter multiplied by the smooth cutoff envelope.
        envelope = cosine_cutoff(dist, self.cutoff_dist)  # shape (N_atoms, N_atoms)
        radial = ScalarFilter(self.cutoff_dist, n_feat, self.n_radial_basis_fn)(dist)
        W = radial * envelope[..., None]

        # Pairwise messages; the mask on index j zeroes contributions from padded senders.
        # NOTE(review): the diagonal i == j is not excluded. Its direction term
        # is zeroed below, but its filter term (the Dense bias at d = 0) is not.
        # Confirm that self-messages are intended.
        pair_msg = jnp.einsum('jf,ijf,j->ijf', phi, W, atom_mask)
        msg_s, msg_vv, msg_vs = jnp.split(pair_msg, 3, axis=-1)

        # Scalar channel: sum over senders j.
        ds_msg = msg_s.sum(axis=1)

        # Vector channel: gate neighbour vectors, plus gated unit directions.
        unit_dr = dr / (dist + 1e-9)[..., None]
        gated_vectors = jnp.einsum('jfv,ijf->ifv', v, msg_vv)
        directional = jnp.einsum('ijv,ijf->ifv', unit_dr, msg_vs)
        dv_msg = gated_vectors + directional
        return ds_msg, dv_msg


class UpdateBlock(nn.Module):
    """PaiNN update function: mixes features within each atom."""

    atom_features: int

    @nn.compact
    def __call__(self, s: FloatAxF, v: FloatAxFx3) -> Tuple[FloatAxF, FloatAxFx3]:
        """
        Updates the atom features.
        Atom-wise updates across features.

        Args:
            s: scalar atom features
            v: equivariant atom features

        Returns:
            ds_up: scalar updates
            dv_up: equivariant updates
        """
        n_feat = self.atom_features
        mix_shape = (n_feat, n_feat)

        # Two independent bias-free linear mixes over the feature axis of v.
        # Construction order fixes the auto-names (Einsum_0 -> V, Einsum_1 -> U).
        Vv = nn.Einsum(mix_shape, 'ifv, gf -> igv', use_bias=False)(v)
        Uv = nn.Einsum(mix_shape, 'ifv, gf -> igv', use_bias=False)(v)

        # Rotation-invariant inputs to the MLP: scalars and norms of V v.
        mlp_in = jnp.concatenate([s, linalg.safe_norm(Vv)], axis=-1)
        inner = jnp.sum(Vv * Uv, axis=-1)

        # Build the two Dense layers one after the other so Dense_0 keeps width F.
        hidden = nn.Dense(n_feat)(mlp_in)
        coeffs = nn.Dense(3 * n_feat)(nn.silu(hidden))
        a_ss, a_sv, a_vv = jnp.split(coeffs, 3, axis=-1)

        ds_up = a_ss + a_sv * inner
        dv_up = a_vv[..., None] * Uv
        return ds_up, dv_up


class PaiNN(BaseGNN):
    """
    Implementation of polarizable atom interaction neural network by Schütt et al.
    https://doi.org/10.48550/arXiv.2102.03150

    Only ``__call__`` is decorated with ``@nn.compact``. Flax permits at most
    one compact method per Module, so the helper methods below carry no
    decorator. They create their submodules when called from ``__call__``.

    TODO: check if layer norm is needed
    """

    n_radial_basis: int = 20

    @nn.compact
    def __call__(
        self,
        atom_features: e3nn.IrrepsArray,  # (A, Irreps[RBF, (l,m)])  with m,l as in Y_{l,m}
        atom_pos: FloatAx3,
        atom_mask: BoolA,
    ) -> Tuple[
        Float1, e3nn.IrrepsArray
    ]:  # (A, Irreps[F_out, (l,m)])  with m,l as in Y_{l,m}
        """Run PaiNN message passing and produce a graph readout plus node features.

        Args:
            atom_features: input node features as an IrrepsArray.
            atom_pos: atom positions, shape (A, 3).
            atom_mask: padding mask; masked atoms send no messages and are
                zeroed in the outputs.

        Returns:
            The graph readout of the output features, multiplied by
            ``atom_mask`` and summed, and the output node features multiplied
            by ``atom_mask``.
        """
        s, v = self.convert_irreps_to_scalar_and_vector(atom_features)
        F = s.shape[-1]
        # preprocessing: pairwise displacement vectors r_i - r_j, shape (A, A, 3)
        dr = atom_pos[:, None, :] - atom_pos[None, :, :]
        s = nn.LayerNorm()(s)
        v = self._layer_norm_vector_magnitudes(v)
        # Message passing and node updates. The LayerNorm creation order
        # (scalars first, then vectors) is kept stable across layers.
        for _ in range(self.layers):
            ds_msg, dv_msg = MessageBlock(F, self.message_cutoff, self.n_radial_basis)(
                s, v, dr, atom_mask
            )
            s = nn.LayerNorm()(s + ds_msg)
            v = self._layer_norm_vector_magnitudes(v + dv_msg)

            ds_up, dv_up = UpdateBlock(F)(s, v)
            s = nn.LayerNorm()(s + ds_up)
            v = self._layer_norm_vector_magnitudes(v + dv_up)

        # convert back to e3nn format
        atom_features_out = self.convert_to_irreps(s, v)

        graph_readout = GraphReadout(
            self.energy_graph_readout_hidden_dims, self.init_graph_readout_to_zero
        )(atom_features_out)

        return (
            (graph_readout * atom_mask[:, None]).sum(),
            atom_features_out * atom_mask[:, None],
        )  # scalar, (A, F, (l,m))  with m,l as in Y_{l,m}

    def _layer_norm_vector_magnitudes(self, v: FloatAxFx3) -> FloatAxFx3:
        """Rescale each vector feature so its length equals the LayerNorm of the vector lengths.

        Only the per-feature lengths are changed, by a rotation-invariant scalar
        factor, so directions are kept (or flipped when the normalized value is
        negative). Creates a LayerNorm submodule, so it must be called from
        ``__call__``.
        """
        # assumes linalg.safe_norm reduces over the last (xyz) axis -> (A, F); confirm
        norms = linalg.safe_norm(v)
        return v * (nn.LayerNorm()(norms) / (norms + 1e-9))[..., None]

    def convert_irreps_to_scalar_and_vector(
        self, node_features: e3nn.IrrepsArray
    ) -> Tuple[FloatAxF, FloatAxFx3]:
        """Project the input onto ``self.irreps_str`` and split it into scalar and vector parts.

        Creates an e3nn Linear submodule, so it must be called from ``__call__``.

        Returns:
            s: the '0e' channels, shape (A, F).
            v: the '1o' channels, shape (A, F, 3).
        """
        irreps = e3nn.Irreps(self.irreps_str)
        node_features = e3nn.flax.Linear(irreps)(node_features)
        node_features = (
            node_features.mul_to_axis()
        )  # (A, (F,l,m)) Fx0e + Fx1o -> (A, F, (l,m)) 1x0e + 1x1o
        s = node_features.filter('0e').array[..., 0]  # type: ignore
        v = node_features.filter('1o').array  # type: ignore
        assert s.shape[-1] == v.shape[-2], 'feature dimension mismatch'
        return s, v

    def convert_to_irreps(
        self, s: FloatAxF, v: FloatAxFx3
    ) -> e3nn.IrrepsArray:  # (A, F, (l,m))  with m,l as in Y_{l,m}
        """Combine scalar and vector features and project them onto ``self.output_irreps_str``.

        Creates an e3nn Linear submodule, so it must be called from ``__call__``.
        """
        s = e3nn.IrrepsArray('0e', s[..., None])  # type: ignore
        v = e3nn.IrrepsArray('1o', v)  # type: ignore
        out = e3nn.concatenate([s, v], axis=-1).axis_to_mul()  # type: ignore
        out = e3nn.flax.Linear(e3nn.Irreps(self.output_irreps_str))(out)
        return out
