from typing import Literal

import jax
import jax.numpy as jnp

from egxc.utils.typing import Bool, Float1, FloatN, NnParams, PyTree
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.learnable.nn import (
    BaseGNN,
    NumericDecoder,
    NumericEncoder,
    SpatialReweighting,
)

# Selects how the non-local (GNN-derived) information enters the grid-based
# XC energy in EGXC: not at all, via a learned reweighting factor (with or
# without the local features appended), or by injection into the local model.
NonLocalGridFeatureMode = Literal[
    'local_only', 'reweighting_with_mGGA_feats',
    'reweighting_without_mGGA_feats', 'injection',
]


class EGXC(BaseEnergyFunctional):
    """
    A learnable exchange-correlation functional based on the
    Equivariant Graph Non-Local Exchange-Correlation (EG-XC) model.

    The energy is ``e_graph + sum(weights * n * gamma * e_xc)``, where
    ``e_graph`` is the GNN's graph-level energy (zeroed when
    ``graph_readout`` is False) and ``n`` is the first local feature.
    ``non_local_grid_feature_mode`` selects how the non-local information
    reaches the grid:

    - ``'local_only'``: no decoder; ``gamma = 1`` and ``e_xc`` comes from
      ``local_model.xc_energy_density``.
    - ``'reweighting_with_mGGA_feats'`` / ``'reweighting_without_mGGA_feats'``:
      decoded grid features are fed to ``non_local_reweighting`` to produce
      ``gamma``. In the ``with`` variant, the stacked local features are
      concatenated to them first.
    - ``'injection'``: decoded grid features are passed, together with the
      local features, to ``local_model.non_local_xc_energy_density``;
      ``gamma = 1``.
    """

    local_model: BaseEnergyFunctional
    encoder: NumericEncoder
    gnn: BaseGNN
    non_local_grid_feature_mode: NonLocalGridFeatureMode

    # Required for every mode except 'local_only' (and must be None there).
    decoder: NumericDecoder | None = None
    # Required for the two 'reweighting_*' modes.
    non_local_reweighting: SpatialReweighting | None = None
    # If False, the GNN's graph-level energy is replaced by 0.0.
    graph_readout: bool = True

    is_graph_based = True

    def setup(self):
        """Check that the configured sub-modules match the selected mode.

        Raises:
            ValueError: if the mode is not a member of
                ``NonLocalGridFeatureMode``, or if ``decoder``,
                ``non_local_reweighting`` or ``local_model`` do not satisfy
                what that mode requires.
        """
        # Explicit raises (not ``assert``) so validation survives ``python -O``.
        from typing import get_args

        mode = self.non_local_grid_feature_mode
        # Exact membership: a substring test would accept e.g. 'reweighting_foo'.
        if mode not in get_args(NonLocalGridFeatureMode):
            raise ValueError(
                f'Unknown non-local grid feature mode: {self.non_local_grid_feature_mode}'
            )
        if mode == 'local_only':
            if self.decoder is not None:
                raise ValueError(
                    "decoder must be None when non_local_grid_feature_mode is 'local_only'"
                )
            return
        if self.decoder is None:
            raise ValueError(
                f"decoder is required when non_local_grid_feature_mode is '{mode}'"
            )
        if mode.startswith('reweighting'):
            if self.non_local_reweighting is None:
                raise ValueError(
                    f"non_local_reweighting is required when non_local_grid_feature_mode is '{mode}'"
                )
        elif mode == 'injection':
            if not hasattr(self.local_model, 'non_local_xc_energy_density'):
                raise ValueError(
                    "local_model must define non_local_xc_energy_density when "
                    "non_local_grid_feature_mode is 'injection'"
                )

    def __call__(
        self, weights: FloatN, *local_feats: FloatN, **non_local_kwargs: jax.Array
    ) -> Float1:
        """Evaluate the EG-XC exchange-correlation energy.

        Args:
            weights: per-grid-point weights, multiplied into the integrand
                before summation (also passed to the encoder).
            *local_feats: local grid features. The first one is used as ``n``
                (it is passed to the encoder and multiplies the integrand);
                all of them are forwarded to the local model.
            **non_local_kwargs: must contain ``'nuc_pos'``, ``'atom_mask'``
                and ``'grid_coords'``; other keys are not read.

        Returns:
            ``e_graph + sum(weights * n * gamma * e_xc)``.

        Raises:
            ValueError: if ``non_local_grid_feature_mode`` is unknown.
        """
        mode = self.non_local_grid_feature_mode
        n = local_feats[0]  # TODO: generalize to arbitrary number of local features
        nuc_pos = non_local_kwargs['nuc_pos']
        atom_mask = non_local_kwargs['atom_mask']
        grid_coords = non_local_kwargs['grid_coords']

        # Embedding: build atom features from the grid quantities; ``cache``
        # is handed to the decoder below.
        atom_features, cache = self.encoder(nuc_pos, atom_mask, grid_coords, weights, n)
        # GNN: returns a graph-level energy and updated atom features.
        e_graph_xc, atom_features = self.gnn(atom_features, nuc_pos, atom_mask)
        if not self.graph_readout:
            e_graph_xc = 0.0

        if mode == 'local_only':
            gamma = 1.0
            e_xc = self.local_model.xc_energy_density(*local_feats)
        else:
            non_local_grid_feats = self.decoder(atom_features, cache)  # type: ignore
            if mode.startswith('reweighting'):
                e_xc = self.local_model.xc_energy_density(*local_feats)
                reweighting_feats = non_local_grid_feats
                if mode == 'reweighting_with_mGGA_feats':
                    # Append the local features as extra trailing-axis channels.
                    reweighting_feats = jnp.concatenate(
                        [reweighting_feats, jnp.stack(local_feats, axis=-1)], axis=-1
                    )
                gamma = self.non_local_reweighting(reweighting_feats)  # type: ignore
            elif mode == 'injection':
                gamma = 1.0
                e_xc = self.local_model.non_local_xc_energy_density(
                    non_local_grid_feats, *local_feats
                )
            else:
                raise ValueError(
                    f'Unknown non-local grid feature mode: {self.non_local_grid_feature_mode}'
                )

        # Compute final energy.
        # NOTE(review): assumes gamma broadcasts elementwise against weights/n
        # (e.g. same 1-D shape); an (N, 1) gamma would broadcast to (N, N).
        # TODO confirm the SpatialReweighting output shape.
        out = e_graph_xc + (weights * n * gamma * e_xc).sum()
        return out

    def graph_readout_decay_mask(self, params: NnParams) -> PyTree[Bool]:
        """Return the boolean parameter mask computed by ``self.gnn``.

        This method only delegates to ``gnn.graph_readout_decay_mask``.
        Presumably the mask selects graph-readout parameters for weight
        decay; confirm in the GNN implementation.
        """
        return self.gnn.graph_readout_decay_mask(params)
