from typing import Literal, Tuple

import e3nn_jax as e3nn
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxtyping

from egxc.utils import linalg
from egxc.utils.typing import (
    PRECISION,
    BoolA,
    FloatAx3,
    FloatAxN,
    FloatAxNxRBF,
    FloatN,
    FloatNx3,
    RBFType,
)

from .radial_basis_fns import radial_basis_values

# Integer array annotation with a single axis labelled 'T'.
# NOTE(review): the index array the encoder actually returns comes from
# `jax.lax.top_k` on an (A, N) matrix and is therefore (A, T) — confirm alias.
IntT = jaxtyping.Int[jaxtyping.Array, 'T']
# Intermediates returned by `NumericEncoder.__call__` next to the features:
# (number of grid points N, selected grid-point indices, spherical harmonics,
# radial basis values).
# NOTE(review): the cached radial basis values are computed on the selected
# points only, i.e. (A, T, RBF) rather than (A, N, RBF) — confirm alias.
EncoderCache = Tuple[int, IntT, e3nn.IrrepsArray, FloatAxNxRBF]

# Small positive constant guarding numerically unsafe operations
# (e.g. division by a vanishing normaliser).
EPSILON = 1e-15


class NumericEncoder(nn.Module):
    """
    Module that encodes the electron density onto a nuclei-centered point cloud.
    This module is independent of the underlying basis set, and purely quadrature based.

    Attributes:
        irreps_str: Irreps specification parsed with `e3nn.Irreps`. Only the
            irrep types are used; every irrep enters the spherical harmonics
            with multiplicity one.
        cutoff: Length by which atom and grid coordinates are divided before
            any distance is computed.
        num_radial_filters: Number of radial basis functions per point.
        radial_basis_type: Radial basis family forwarded to
            `radial_basis_values`.
        nuclei_partitioning: `None` disables the partitioning of quadrature
            points between nuclei. `'Gaussian'` selects a Gaussian claim
            function; any other non-None value is treated as exponential.
        _quadrature_points_per_atom_scaling: Controls how many grid points are
            kept per atom: `T = min(N, scaling * N // A)`.
    """

    irreps_str: str
    cutoff: float  # TODO: check units, NOTE: This cutoff is independent of the GNN cutoff
    num_radial_filters: int  # RBF dimension
    radial_basis_type: RBFType = 'trigonometric'
    nuclei_partitioning: Literal[None, 'Exponential', 'Gaussian'] = None
    _quadrature_points_per_atom_scaling: int = 8

    def setup(self) -> None:
        """
        Create the learnable smoothness parameter and the partitioning function.
        Nothing is created when `nuclei_partitioning` is None.
        """
        if self.nuclei_partitioning is not None:
            self.sigma = self.param(  # TODO: should this be a learnable parameter?
                'partitioning_smoothness',
                jax.nn.initializers.ones,
                (),
                PRECISION.quadrature,
            )

            def density_partitioning(dist: FloatAxN, atom_mask: BoolA) -> FloatAxN:
                """
                Share of every grid point that is attributed to every atom.

                `dist[a, p]` must be the distance between atom `a` and grid
                point `p`, with column `p` denoting the SAME grid point for
                all atoms: the normalisation runs over the atom axis, so the
                shares of one grid point sum to (almost exactly) one.
                """
                if self.nuclei_partitioning == 'Gaussian':
                    claim = jnp.exp(-0.5 * dist**2 / (self.sigma**2))
                else:
                    # exponential
                    claim = jnp.exp(-dist / jnp.abs(self.sigma))
                claim = claim * atom_mask[:, None]  # mask out fake (padding) atoms
                normalizer = claim.sum(0, keepdims=True)
                # EPSILON keeps the division finite where no atom claims a point
                share = claim / (normalizer + EPSILON)
                return share

            self.density_partitioning_fn = density_partitioning

    def __call__(
        self,
        atom_pos: FloatAx3,
        atom_mask: BoolA,
        grid_coords: FloatNx3,
        weights: FloatN,
        n: FloatN,
    ) -> Tuple[e3nn.IrrepsArray, EncoderCache]:
        """
        Project per-grid-point values onto atom-centered equivariant features.

        For every atom the T nearest grid points are selected, and the product
        of radial basis values, spherical harmonics, `n` and `weights` is
        summed over these points.

        Args:
            atom_pos: Atom positions, (A, 3).
            atom_mask: Per-atom mask, (A,); entries of padding atoms are falsy.
                Only used when `nuclei_partitioning` is not None.
            grid_coords: Quadrature grid coordinates, (N, 3).
            weights: Per-grid-point factors, (N,) — presumably the quadrature
                weights; confirm against the caller.
            n: Per-grid-point values, (N,) — presumably the electron density;
                confirm against the caller.

        Returns:
            The atom features as an `e3nn.IrrepsArray` and a cache tuple
            `(N, truncated_idx, spherical_harmonics, radial_basis_vals)`, the
            last three restricted to the selected points of every atom.

        TODO: Should we include other features like s, xi, tau in the embedding?
        """
        parsed_irreps = e3nn.Irreps(self.irreps_str)
        irreps = e3nn.Irreps([(1, ir) for _, ir in parsed_irreps])
        # change to units of cutoff
        grid_coords = grid_coords / self.cutoff
        atom_pos = atom_pos / self.cutoff

        # compute squared distances without allocating (A, N, 3)
        a2 = jnp.sum(atom_pos**2, axis=-1)[:, None]  # (A,1)
        g2 = jnp.sum(grid_coords**2, axis=-1)[None, :]  # (1,N)
        ag = atom_pos @ grid_coords.T  # (A,N)
        squared_dist = a2 + g2 - 2.0 * ag  # (A,N)

        # optimize calculations by using the distance based cutoff
        A, N = squared_dist.shape
        T = min(N, self._quadrature_points_per_atom_scaling * N // A)
        # pick nearest T grid points per atom (no full sort)
        _, truncated_idx = jax.lax.top_k(-squared_dist, k=T)  # (A,T)

        # gather selected grid points, then form displacement/dist only for them
        truncated_grid = grid_coords[truncated_idx]  # (A,T,3)
        displacement = atom_pos[:, None, :] - truncated_grid  # (A,T,3)
        dist = linalg.safe_norm(displacement, axis=-1)  # (A,T)
        n = n[truncated_idx]  # (A,T)
        weights = weights[truncated_idx]  # (A,T)

        radial_basis_vals = radial_basis_values(  # (A,T,RBF)
            dist, 1.0, self.num_radial_filters, self.radial_basis_type
        )
        # apply nuclei wise partitioning of the quadrature points
        if self.nuclei_partitioning is not None:
            # Column t of the truncated arrays is a different grid point for
            # every atom, so the shares cannot be normalised on them. Evaluate
            # the partition on the full (A,N) distance matrix instead, where a
            # column is one grid point, and gather the selected entries.
            # The clamp removes small negative values caused by cancellation
            # in the expanded squared distance and keeps the sqrt gradient
            # finite.
            full_dist = jnp.sqrt(jnp.maximum(squared_dist, EPSILON))  # (A,N)
            share = self.density_partitioning_fn(full_dist, atom_mask)  # (A,N)
            partitioning = jnp.take_along_axis(share, truncated_idx, axis=1)  # (A,T)
            radial_basis_vals *= partitioning[..., None]
        spherical_harmonics = e3nn.spherical_harmonics(
            irreps, displacement, normalize=True, normalization='norm'
        )
        # sum over the selected points t of every atom a
        atom_features = jnp.einsum(
            'atr,ath,at,at->arh',
            radial_basis_vals,
            spherical_harmonics.array,
            n,
            weights,
        )  # FloatAxRBFxH
        atom_features = e3nn.IrrepsArray(
            irreps, atom_features.astype(PRECISION.gnn)
        ).axis_to_mul()  # FloatAx(Irreps(RBFxH))

        return atom_features, (
            N,
            truncated_idx,
            spherical_harmonics.astype(PRECISION.decoding),
            radial_basis_vals.astype(PRECISION.decoding),
        )
