from typing import Any, Callable, Protocol, Tuple

import jax.numpy as jnp

from egxc.solver.fock import get_coulomb_matrix_fn, mean_field_energy
from egxc.systems import System
from egxc.utils.typing import (
    Float1,
    Float2xBxB,
    FloatBxB,
)
from egxc.xc_energy import XCModule


class DFADependentDensityErrorFn(Protocol):
    """
    Structural (duck-typed) signature of an error function that compares a
    predicted density matrix against a reference density matrix for a system
    and reduces the comparison to a single ``Float1`` value.
    """

    def __call__(
        self,
        sys: System,
        reference_density_matrix: FloatBxB | Float2xBxB,
        predicted_density_matrix: FloatBxB | Float2xBxB,
        **non_local_kwargs: Any,
    ) -> Float1:
        """Return the error of ``predicted_density_matrix`` w.r.t. ``reference_density_matrix``."""


def get_functional_density_error_decomposition_fns(
    learned_dfa: XCModule,
    reference_dfa: XCModule,
    spin_restricted: bool,
    use_density_fitting: bool,
) -> Callable[..., Tuple[Float1, Float1]]:
    """
    An adaptation of the error decomposition proposed by Kim et al. to learned
    functional distillations.
    Kim et al. Understanding and Reducing Errors in Density Functional Calculations.
    Phys. Rev. Lett. 2013, 111 (7), 073003. https://doi.org/10.1103/PhysRevLett.111.073003.

    Args:
        learned_dfa: The learned XC functional whose errors are decomposed.
        reference_dfa: The reference (target) XC functional.
        spin_restricted: Forwarded to ``get_coulomb_matrix_fn``.
        use_density_fitting: Forwarded to ``get_coulomb_matrix_fn``.

    Returns:
        A callable ``error_fn(sys, ref_dm, pred_dm, **non_local_kwargs)`` that
        returns the tuple ``(|functional error|, |density-driven error|)``.
        Because it returns a tuple, it does not satisfy
        ``DFADependentDensityErrorFn``, whose ``__call__`` returns one ``Float1``.
    """
    coulomb_matrix_fn = get_coulomb_matrix_fn(spin_restricted, use_density_fitting)

    def error_fn(
        sys: System,
        ref_dm: FloatBxB | Float2xBxB,
        pred_dm: FloatBxB | Float2xBxB,
        **non_local_kwargs: Any,
    ) -> Tuple[Float1, Float1]:
        """
        Returns a tuple (functional error, density-driven error), both as absolute values.

        functional error:     E_xc^ref[ref_dm] - E_xc^learned[ref_dm]
            Both functionals are evaluated at the same (reference) density, so
            the mean-field contribution is identical and cancels; it is therefore
            not included in this term.
        density-driven error: E^learned[ref_dm] - E^learned[pred_dm]
            where E^learned = E_xc^learned + mean_field_energy.

        ``non_local_kwargs`` are forwarded unchanged to every functional evaluation.
        """
        eri = sys.fock_tensors.ert
        H_core = sys.fock_tensors.core_hamiltonian

        # Mean-field energies at the reference and predicted densities.
        J_at_true_dens = coulomb_matrix_fn(ref_dm, eri)
        J_at_pred_dens = coulomb_matrix_fn(pred_dm, eri)
        mf_at_true_dens = mean_field_energy(ref_dm, J_at_true_dens, H_core)
        mf_at_pred_dens = mean_field_energy(pred_dm, J_at_pred_dens, H_core)

        # XC energies: reference functional at the true density, learned
        # functional at both the true and the predicted density.
        xc_target_at_true_dens = reference_dfa(ref_dm, sys.grid, **non_local_kwargs)
        xc_pred_at_true_dens = learned_dfa(ref_dm, sys.grid, **non_local_kwargs)
        xc_pred_at_pred_dens = learned_dfa(pred_dm, sys.grid, **non_local_kwargs)

        err_F = xc_target_at_true_dens - xc_pred_at_true_dens
        err_D = (xc_pred_at_true_dens + mf_at_true_dens) - (
            xc_pred_at_pred_dens + mf_at_pred_dens
        )
        # The signs above are the negatives of Kim et al.'s convention; taking the
        # absolute value makes that irrelevant.
        return jnp.abs(err_F), jnp.abs(err_D)  # TODO: abs cast here?

    return error_fn


def get_density_driven_dfa_xc_energy_error_fns(
    learned_dfa: XCModule,
    scale_per_electron: bool,
) -> DFADependentDensityErrorFn:
    """
    DFA dependent density error.
    Gould, T. A Step toward Density Benchmarking—The Energy-Relevant “Mean Field Error.”
    The Journal of Chemical Physics 2023, 159 (20), 204111. https://doi.org/10.1063/5.0175925.

    Args:
        learned_dfa: The XC functional used to evaluate both densities.
        scale_per_electron: If True, the error is divided by ``sys.n_electrons``.

    Returns:
        A callable ``error_fn(sys, ref_dm, pred_dm, **non_local_kwargs)`` that
        returns ``|E_xc[ref_dm] - E_xc[pred_dm]|`` for ``learned_dfa``, divided by
        ``sys.n_electrons`` when ``scale_per_electron`` is set.
    """

    def error_fn(
        sys: System,
        ref_dm: FloatBxB | Float2xBxB,
        pred_dm: FloatBxB | Float2xBxB,
        **non_local_kwargs: Any,
    ) -> Float1:
        """
        Returns the absolute density-driven E_xc error of the learned DFA
        (a single value, not a tuple), optionally scaled per electron.

        Only the XC part is computed here. The density-driven mean field error
        is universal, i.e. it depends solely on the density, and is not
        included in the returned value.
        """
        xc_pred_at_true_dens = learned_dfa(ref_dm, sys.grid, **non_local_kwargs)
        xc_pred_at_pred_dens = learned_dfa(pred_dm, sys.grid, **non_local_kwargs)
        err_D_xc = xc_pred_at_true_dens - xc_pred_at_pred_dens
        if scale_per_electron:
            # Rebind explicitly rather than `/=`, so an array-valued result is
            # never divided in place.
            err_D_xc = err_D_xc / sys.n_electrons
        return jnp.abs(err_D_xc)

    # NOTE(review): error_fn names its density-matrix parameters ref_dm/pred_dm,
    # while DFADependentDensityErrorFn uses reference_density_matrix /
    # predicted_density_matrix; presumably this mismatch is why the ignore is needed.
    return error_fn  # type: ignore
