"""
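Contains functions to calculate the Fock matrix (and the associated mean-field
and exchange-correlation energies) for a given density matrix, using the
precomputed core-Hamiltonian and electron-repulsion tensors of a system.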

For a general overview of the methods used in this module, see sections 3 and 6 of:
Susi Lehtola et al. "An Overview of Self-Consistent Field Calculations Within Finite Basis Sets"
Molecules 2020, 25 (5), 1218. https://doi.org/10.3390/molecules25051218.
"""

from typing import Callable, Dict, Tuple

import einops
import flax.linen as nn
import jax.numpy as jnp

from egxc.systems.base import System
from egxc.utils.typing import (
    Float1,
    Float2xBxB,
    FloatAx3,
    FloatBxB,
    FloatBxBxBxB,
    FloatQxBxB,
)
from egxc.xc_energy import XCModule
from egxc.xc_energy.functionals.classical.hybrid import Hybrid
from egxc.xc_energy.functionals.classical.range_separated_hybrid import (
    BaseRangeSeparatedHybrid,
)


def mean_field_energy(
    density_matrix: FloatBxB | Float2xBxB,
    coulomb_matrix: FloatBxB,
    core_hamiltonian: FloatBxB | Float2xBxB,
) -> Float1:
    """
    Return the core-Hamiltonian plus Coulomb energy ``sum((H + J / 2) * P)``.

    The contraction is elementwise, ``sum_ij (H_ij + 0.5 * J_ij) * P_ij``,
    which equals ``tr((H + J / 2) P)`` when ``H + J / 2`` is symmetric.

    Args:
        density_matrix: ``(B, B)`` density matrix, or ``(2, B, B)`` with a
            leading spin axis in the spin-unrestricted case.
        coulomb_matrix: ``(B, B)`` Coulomb matrix of the total density. In the
            unrestricted case it broadcasts over the spin axis, so summing over
            both spins yields ``0.5 * tr(J P_alpha) + 0.5 * tr(J P_beta)``.
        core_hamiltonian: Core Hamiltonian, either with the same shape as
            ``density_matrix`` (``FockMatrix`` in this module repeats it over
            the spin axis when unrestricted) or ``(B, B)`` so that it
            broadcasts against ``density_matrix``.

    Returns:
        The energy as a 0-d array (full reduction over all axes).
    """
    return ((core_hamiltonian + 0.5 * coulomb_matrix) * density_matrix).sum()


def get_coulomb_matrix_fn(
    spin_restricted: bool, use_density_fitting: bool
) -> Callable[[FloatBxB | Float2xBxB, FloatQxBxB | FloatBxBxBxB], FloatBxB]:
    """
    Build a closure that evaluates the Coulomb (Hartree) matrix ``J``.

    Both flags are captured once here, so the returned function only branches
    on fixed Python booleans.

    Args:
        spin_restricted: If True, the density matrix given to the closure is
            used unchanged. If False, it is summed over its leading (spin) axis
            first, so ``J`` depends only on the total density.
        use_density_fitting: If True, the tensor given to the closure is a
            three-index tensor ``T[Q, i, j]`` and
            ``J_kl = sum_{Q,i,j} T_Qij * T_Qkl * P_ij``. If False, it is a
            four-index tensor and ``J_kl = sum_{i,j} T_ijkl * P_ij``.

    Returns:
        ``coulomb_matrix_fn(density_matrix, electron_repulsion_tensor) -> J``.
    """

    def coulomb_matrix_fn(
        density_matrix: FloatBxB | Float2xBxB,
        electron_repulsion_tensor: FloatQxBxB | FloatBxBxBxB,
    ) -> FloatBxB:
        # The Coulomb term only sees the total (spin-summed) density.
        total_density = (
            density_matrix.sum(axis=0) if not spin_restricted else density_matrix
        )
        if use_density_fitting:
            return jnp.einsum(
                'Pij,Pkl,ij->kl',
                electron_repulsion_tensor,
                electron_repulsion_tensor,
                total_density,
            )
        return jnp.einsum('ijkl,ij->kl', electron_repulsion_tensor, total_density)

    return coulomb_matrix_fn


class FockMatrix(nn.Module):
    """
    Flax module assembling the Fock matrix ``F = H_core + J + V_xc`` and the
    corresponding energies for a given density matrix.

    ``H_core`` is read from ``sys.fock_tensors.core_hamiltonian``, ``J`` is the
    Coulomb matrix of the total density and ``V_xc`` (and the xc energy) come
    from ``xc_module``.

    Attributes:
        xc_module: Provides the functional and the xc energy/potential.
        use_density_fitting: Selects the three-index (density fitting) or the
            four-index contraction for ``J`` over ``sys.fock_tensors.ert``.
        spin_restricted: If False, density matrices carry a leading spin axis
            of size 2 and ``H_core`` is repeated over that axis.
    """

    xc_module: XCModule
    use_density_fitting: bool
    spin_restricted: bool
    # TODO: check correct return Precision

    def setup(self):
        self.coulomb_matrix_fn = get_coulomb_matrix_fn(
            self.spin_restricted, self.use_density_fitting
        )

        def preprocessing(
            nuc_pos: FloatAx3,
            sys: System,
        ) -> Tuple[FloatBxB | Float2xBxB, Dict]:
            """
            Return ``(H_core, non_local_kwargs)`` for the configured functional.

            ``non_local_kwargs`` holds the extra inputs the functional needs
            beyond the density matrix and grid. The checks below are
            independent, so a functional may receive entries from several
            branches. ``H_core`` is repeated over a leading spin axis of size 2
            when the module is spin-unrestricted.
            """
            non_local_kwargs = {}
            func = self.xc_module.functional
            if isinstance(func, BaseRangeSeparatedHybrid):
                non_local_kwargs['eri_sr_tensor'] = sys.fock_tensors.eri_sr_tensor
                non_local_kwargs['eri_lr_tensor'] = sys.fock_tensors.eri_lr_tensor
                # Provide grid data required by functionals like wB97M-V (VV10 dispersion)
                non_local_kwargs['grid_coords'] = sys.grid.coords
                non_local_kwargs['grid_weights'] = sys.grid.weights
                non_local_kwargs['grid_aos'] = sys.grid.aos
                non_local_kwargs['grid_grad_aos'] = sys.grid.grad_aos
            if isinstance(func, Hybrid):
                non_local_kwargs['eri_tensor'] = sys.fock_tensors.ert
            if func.is_graph_based:
                non_local_kwargs['atom_mask'] = sys.atom_mask
                non_local_kwargs['nuc_pos'] = nuc_pos
                non_local_kwargs['grid_coords'] = sys.grid.coords
            H_core = sys.fock_tensors.core_hamiltonian
            if not self.spin_restricted:
                H_core = einops.repeat(H_core, 'i j -> spin i j', spin=2)
            return H_core, non_local_kwargs

        self.preprocessing = preprocessing

    def __call__(
        self,
        nuc_pos: FloatAx3,
        density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> FloatBxB | Float2xBxB:
        return self.fock_matrix(nuc_pos, density_matrix, sys)

    def _core_coulomb_and_kwargs(
        self,
        nuc_pos: FloatAx3,
        density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> Tuple[FloatBxB | Float2xBxB, FloatBxB, Dict]:
        """Shared setup: core Hamiltonian, Coulomb matrix and functional kwargs."""
        H_core, non_local_kwargs = self.preprocessing(nuc_pos, sys)
        J = self.coulomb_matrix_fn(density_matrix, sys.fock_tensors.ert)
        return H_core, J, non_local_kwargs

    def fock_matrix(
        self,
        nuc_pos: FloatAx3,
        density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> FloatBxB | Float2xBxB:
        """
        Calculates the Fock matrix ``H_core + J + V_xc`` for a given density matrix.
        """
        H_core, J, V_xc = self.fock_matrix_contributions(nuc_pos, density_matrix, sys)
        # In the unrestricted case J is (B, B) and broadcasts over the spin axis.
        return H_core + J + V_xc

    def fock_matrix_contributions(
        self,
        nuc_pos: FloatAx3,
        density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> Tuple[FloatBxB, FloatBxB, FloatBxB] | Tuple[Float2xBxB, FloatBxB, Float2xBxB]:
        """
        Returns the separate Fock matrix terms ``(H_core, J, V_xc)`` for a given
        density matrix.

        When spin-unrestricted, ``H_core`` has a leading spin axis of size 2,
        while ``J`` is built from the spin-summed density and has none.
        """
        P = density_matrix
        H_core, J, non_local_kwargs = self._core_coulomb_and_kwargs(nuc_pos, P, sys)
        V_xc = self.xc_module.xc_potential(
            P, sys.grid, sys.fock_tensors.basis_mask, **non_local_kwargs
        )
        return H_core, J, V_xc

    def energy(
        self,
        nuc_pos: FloatAx3,
        density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> Tuple[Float1, Float1]:
        """
        Returns the energies due to (core hamiltonian + coulomb, exchange-correlation).
        """
        P = density_matrix
        H_core, J, non_local_kwargs = self._core_coulomb_and_kwargs(nuc_pos, P, sys)
        # NOTE(review): unlike xc_potential / xc_energy_and_potential, basis_mask
        # is not passed here -- confirm xc_energy does not need it.
        e_xc = self.xc_module.xc_energy(P, sys.grid, **non_local_kwargs)
        return mean_field_energy(P, J, H_core), e_xc

    def energy_and_fock_matrix(
        self,
        nuc_pos: FloatAx3,
        density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> Tuple[Tuple[Float1, Float1], FloatBxB | Float2xBxB]:
        """
        Returns ``((core hamiltonian + coulomb energy, xc energy), Fock matrix)``
        using a single combined xc energy/potential evaluation.
        """
        P = density_matrix
        H_core, J, non_local_kwargs = self._core_coulomb_and_kwargs(nuc_pos, P, sys)
        e_xc, v_xc = self.xc_module.xc_energy_and_potential(
            P, sys.grid, sys.fock_tensors.basis_mask, **non_local_kwargs
        )
        F = H_core + J + v_xc
        return (mean_field_energy(P, J, H_core), e_xc), F
