"""
Relevant references:
(1) Gould, T. A Step toward Density Benchmarking—The Energy-Relevant “Mean Field Error.” The Journal of Chemical Physics 2023, 159 (20), 204111. https://doi.org/10.1063/5.0175925.
"""

from dataclasses import dataclass
from functools import partial
from typing import Callable, Literal, TypeAlias, get_args

import jax
import jax.numpy as jnp

from egxc.solver.fock import get_coulomb_matrix_fn, mean_field_energy
from egxc.systems import Grid, System
from egxc.training.loss.field import (
    BaseFieldLossConfig,
    FieldMeasures,
    delta_field_fn,
    field_loss,
)
from egxc.utils.typing import (
    Float1,
    Float2xBxB,
    FloatBxB,
)

# Measures that are specific to densities (dispatched in `density_loss`):
# 'dipole' -> dipole-moment difference, 'coulomb' -> Coulomb energy error,
# 'mean_field' -> density-driven mean-field error.
DensityOnlyMeasures: TypeAlias = Literal['dipole', 'coulomb', 'mean_field']
# Full set of measures usable with `DensityFieldLossConfig`: the generic field
# measures plus the density-only ones above.
DensityMeasures: TypeAlias = FieldMeasures | DensityOnlyMeasures


@dataclass(frozen=True, slots=True)
class DensityFieldLossConfig(BaseFieldLossConfig[DensityMeasures]):
    """
    Configuration for `density_loss`.

    Extends the generic field-loss config (parameterized over ``DensityMeasures``) with the
    settings needed to build Coulomb matrices for the 'coulomb' and 'mean_field' measures.
    The dataclass is frozen so that it can be used as a static (hashable) ``jax.jit`` argument.
    """

    # Forwarded as the ``spin_restricted`` flag to ``get_coulomb_matrix_fn``.
    spin_restricted: bool
    # Forwarded as ``use_density_fitting`` to ``get_coulomb_matrix_fn``.
    is_density_fitted: bool


@partial(jax.jit, static_argnames=['reference_basis_is_same'])
def dipole_difference(
    target: FloatBxB | Float2xBxB,
    prediction: FloatBxB | Float2xBxB,
    grid: Grid,
    reference_basis_is_same: bool,
):
    """
    Returns the magnitude of the dipole moment difference between the target and predicted density.

    The difference field produced by ``delta_field_fn`` on ``grid.aos`` is integrated against
    ``grid.coords`` with the quadrature weights ``grid.weights``. This gives the dipole
    difference vector, and its Euclidean norm is returned.

    The norm is computed with a "double where" instead of ``jnp.linalg.norm``. The returned
    value is identical, but the gradient at an exactly vanishing dipole difference is 0
    instead of NaN. The derivative of ``sqrt`` at 0 is undefined, and a NaN gradient there
    would poison training whenever prediction == target.
    """
    delta_field = delta_field_fn(target, prediction, grid.aos, reference_basis_is_same)
    delta_mu = jnp.einsum(
        'i,ir,i->r', delta_field, grid.coords, grid.weights
    )  # shape (3,)
    sq_norm = jnp.sum(delta_mu**2)
    is_nonzero = sq_norm > 0
    # Feed sqrt a safe value where the norm is zero so its gradient stays finite there.
    safe_sq_norm = jnp.where(is_nonzero, sq_norm, 1.0)
    return jnp.where(is_nonzero, jnp.sqrt(safe_sq_norm), 0.0)  # magnitude of delta_mu


def get_coulomb_energy_error_fn(
    spin_restricted: bool,
    use_density_fitting: bool,
    scale_per_electron: bool,
) -> Callable[[System, FloatBxB | Float2xBxB, FloatBxB | Float2xBxB], Float1]:
    """
    Builds a function measuring the density-driven Coulomb energy error.

    Args:
        spin_restricted: forwarded to ``get_coulomb_matrix_fn``.
        use_density_fitting: forwarded to ``get_coulomb_matrix_fn``.
        scale_per_electron: if True, the error is divided by ``sys.n_electrons``.

    Returns:
        ``coulomb_energy_error_fn(sys, target, prediction)``. It returns the scalar
        ``0.5 * |sum(J[target] * target) - sum(J[prediction] * prediction)|``, optionally
        per electron. ``J[D]`` is the Coulomb matrix built from ``D`` and
        ``sys.fock_tensors.ert``.
    """
    coulomb_matrix_fn = get_coulomb_matrix_fn(spin_restricted, use_density_fitting)

    def coulomb_energy_error_fn(
        sys: System,
        target: FloatBxB | Float2xBxB,
        prediction: FloatBxB | Float2xBxB,
    ) -> Float1:
        """
        Returns the absolute Coulomb energy difference between the target and predicted
        density matrices (divided by ``sys.n_electrons`` if ``scale_per_electron``).
        The Coulomb energy error is universal, i.e. it depends solely on the density.
        """
        J_at_pred_dens = coulomb_matrix_fn(prediction, sys.fock_tensors.ert)
        J_at_true_dens = coulomb_matrix_fn(target, sys.fock_tensors.ert)
        coulomb_at_pred_dens = (J_at_pred_dens * prediction).sum()
        coulomb_at_true_dens = (J_at_true_dens * target).sum()
        # The factor 0.5 turns the contraction J.D into the Coulomb energy.
        out = 0.5 * jnp.abs(coulomb_at_true_dens - coulomb_at_pred_dens)
        if scale_per_electron:
            out /= sys.n_electrons
        return out

    return coulomb_energy_error_fn


def get_density_mean_field_error_fn(
    spin_restricted: bool,
    use_density_fitting: bool,
    scale_per_electron: bool,
) -> Callable[[System, FloatBxB | Float2xBxB, FloatBxB | Float2xBxB], Float1]:
    """
    Returns a Density-driven Mean Field error (DMF) function.

    This error is independent of the density functional approximation and hence universal.
    * Allows for cancellation of numerical and basis set errors in kinetic, electron-nuclei,
      and electron-electron repulsion.
    * To first order in the density error, this coincides with an integral over the
      xc-potential error times the density error.
    * Largely uncorrelated with the L1, L2 and coulomb metrics (TODO: confirm the exact
      statement against the reference below).

    Args:
        spin_restricted: forwarded to ``get_coulomb_matrix_fn``.
        use_density_fitting: forwarded to ``get_coulomb_matrix_fn``.
        scale_per_electron: if True, the error is divided by ``sys.n_electrons``.

    Returns:
        ``mean_field_error_fn(sys, target, prediction)``. It returns the scalar
        ``|E_mf[target] - E_mf[prediction]|``, optionally per electron. ``E_mf[D]`` is
        ``mean_field_energy(D, J[D], sys.fock_tensors.core_hamiltonian)`` and ``J[D]`` is
        the Coulomb matrix of ``D``.

    Gould, T. A Step toward Density Benchmarking—The Energy-Relevant “Mean Field Error.”
    The Journal of Chemical Physics 2023, 159 (20), 204111. https://doi.org/10.1063/5.0175925.
    """
    coulomb_matrix_fn = get_coulomb_matrix_fn(spin_restricted, use_density_fitting)

    def mean_field_error_fn(
        sys: System,
        target: FloatBxB | Float2xBxB,
        prediction: FloatBxB | Float2xBxB,
    ) -> Float1:
        """
        Returns the absolute density-driven mean field error between the target and predicted
        density matrices (divided by ``sys.n_electrons`` if ``scale_per_electron``).
        The mean field error is universal, i.e. it depends solely on the density.
        """
        J_at_pred_dens = coulomb_matrix_fn(prediction, sys.fock_tensors.ert)
        J_at_true_dens = coulomb_matrix_fn(target, sys.fock_tensors.ert)
        H_core = sys.fock_tensors.core_hamiltonian
        mf_at_true_dens = mean_field_energy(target, J_at_true_dens, H_core)
        mf_at_pred_dens = mean_field_energy(prediction, J_at_pred_dens, H_core)
        err_D_mf = mf_at_true_dens - mf_at_pred_dens
        if scale_per_electron:
            err_D_mf /= sys.n_electrons
        # TODO: abs cast here? (discards the sign of the DMF)
        return jnp.abs(err_D_mf)

    return mean_field_error_fn


@partial(jax.jit, static_argnames=['config', 'reference_basis_is_same'])
def density_loss(
    target: FloatBxB | Float2xBxB,
    prediction: FloatBxB | Float2xBxB,
    sys: System,
    config: DensityFieldLossConfig,
    reference_basis_is_same: bool,
):
    if config.measure in get_args(FieldMeasures):
        return field_loss(target, prediction, sys, config, reference_basis_is_same)
    elif config.measure == 'dipole':
        return dipole_difference(target, prediction, sys.grid, reference_basis_is_same)
    elif config.measure == 'coulomb':
        _loss_fn = get_coulomb_energy_error_fn(
            config.spin_restricted, config.is_density_fitted, config.scale_per_electron
        )
        return _loss_fn(sys, target, prediction)
    elif config.measure == 'mean_field':
        _loss_fn = get_density_mean_field_error_fn(
            config.spin_restricted, config.is_density_fitted, config.scale_per_electron
        )
        return _loss_fn(sys, target, prediction)
    else:
        raise ValueError(f'Invalid density measure: {config.measure}')
