from dataclasses import dataclass
from typing import Any, Dict

import jax
import numpy as np

from deixc.orbital_transforms import dm_gradient_to_orbital_rotation_gradient
from egxc.utils.linalg import coeffs_to_density_matrix
from egxc.utils.typing import (
    NpFloatAx3,
    NpFloatB,
    NpFloatOxV,
    NpFloatRefSCF,
    NpFloatRefSCFxBxB,
    NpFloatRefSCFxOxV,
    NpUIntB,
)


@dataclass
class DeiXCTargets:
    """
    DEI-XC targets for a single sample.

    mo_coeffs (NpFloatRefSCFxBxB): Molecular orbital coefficients (a.k.a. C) along the SCF trajectory.
    total_energies (NpFloatRefSCF): Total energies along the SCF trajectory.
    xc_energies (NpFloatRefSCF): Exchange-correlation (XC) energies along the SCF trajectory.
    xc_potential_matrices (NpFloatRefSCFxBxB): XC potentials along the SCF trajectory.
    linear_response_xc_pot (NpFloatRefSCFxOxV): Linear response of the XC potential along the
        normalized direction of steepest descent from direct minimization of the XC energy.
    orbital_energies (NpFloatB): Molecular orbital eigenvalues at the final SCF iterate.
    density_hessian_diagonal (NpFloatOxV | None): Diagonal of the XC Hessian at the ground-state density.
    forces (NpFloatAx3 | None): Forces on nuclei.
    d3_dispersion_energy (float | None): D3 dispersion energy.
    d3_dispersion_forces (NpFloatAx3 | None): D3 dispersion forces on atoms.
    compute_costs (Dict[str, Any]): Log of compute time for different components.

    On construction, only the shapes of ``mo_coeffs`` and ``linear_response_xc_pot``
    are validated (see ``__post_init__``); other fields are stored as given.

    TODO: This class is presently limited to spin-restricted DFT calculations.
    """

    # Along SCF trajectory:
    mo_coeffs: NpFloatRefSCFxBxB  # a.k.a. C  (SCF, B, B)
    total_energies: NpFloatRefSCF
    xc_energies: NpFloatRefSCF
    xc_potential_matrices: NpFloatRefSCFxBxB
    linear_response_xc_pot: NpFloatRefSCFxOxV  # (SCF, O, V)
    orbital_energies: NpFloatB  # (B,)
    # From ground-state density only:
    density_hessian_diagonal: NpFloatOxV | None
    forces: NpFloatAx3 | None
    # independent of density matrix:
    d3_dispersion_energy: float | None
    d3_dispersion_forces: NpFloatAx3 | None
    # additional statistics:
    compute_costs: Dict[str, Any]

    def __post_init__(self) -> None:
        """
        Check the array shapes the rest of the class relies on.

        Raises:
            ValueError: if ``mo_coeffs`` or ``linear_response_xc_pot`` is not 3-D
                (tuple unpacking of ``.shape`` fails).
            AssertionError: if ``mo_coeffs`` is not square in its last two axes, if the
                trajectory lengths of ``mo_coeffs`` and ``linear_response_xc_pot``
                differ, or if O + V != B.
        """
        n_scf, n_basis, _ = self.mo_coeffs.shape
        assert self.mo_coeffs.shape[2] == self.mo_coeffs.shape[1], (
            f'mo_coeffs.shape: {self.mo_coeffs.shape}'
        )  # N_ao should be equal to N_mo
        assert self.linear_response_xc_pot.shape[0] == n_scf, (
            f'linear_response_xc_pot.shape: {self.linear_response_xc_pot.shape}, '
            f'expected {n_scf} SCF iterates (mo_coeffs.shape: {self.mo_coeffs.shape})'
        )
        _, n_o, n_v = self.linear_response_xc_pot.shape
        # The occupied-virtual block must partition the MO space.
        assert n_o + n_v == n_basis, (
            f'O + V = {n_o} + {n_v} does not match the number of basis functions {n_basis}'
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict keyed by field name (values are not copied)."""
        return {
            'mo_coeffs': self.mo_coeffs,  # a.k.a. C
            'total_energies': self.total_energies,
            'xc_energies': self.xc_energies,
            'xc_potential_matrices': self.xc_potential_matrices,
            'linear_response_xc_pot': self.linear_response_xc_pot,
            'orbital_energies': self.orbital_energies,
            'density_hessian_diagonal': self.density_hessian_diagonal,
            'forces': self.forces,
            'd3_dispersion_energy': self.d3_dispersion_energy,
            'd3_dispersion_forces': self.d3_dispersion_forces,
            'compute_costs': self.compute_costs,
        }

    @property
    def occupancies(self) -> NpUIntB:
        """
        Orbital occupation numbers of shape (B,): 2 for the first ``n_occ`` orbitals, 0 otherwise.

        A new uint8 array is built on every access.
        """
        # TODO: assumes spin-restricted calculations
        out = np.zeros(self.n_basis_functions, dtype=np.uint8)
        out[: self.n_occ] = 2
        return out

    @property
    def n_occ(self) -> int:
        """
        Number of occupied orbitals, read from axis 1 of ``linear_response_xc_pot``.

        Raises:
            ValueError: if ``linear_response_xc_pot`` has been set to None.
        """
        if self.linear_response_xc_pot is None:
            raise ValueError('linear_response_xc_pot is None')
        else:
            return self.linear_response_xc_pot.shape[1]

    @property
    def n_basis_functions(self) -> int:
        """Number of basis functions B (axis 1 of ``mo_coeffs``)."""
        return self.mo_coeffs.shape[1]

    def test_recomputations(
        self,
        xc_minimization_directions: NpFloatRefSCFxOxV,
        density_matrices: NpFloatRefSCFxBxB,
        xc_potential_matrices: NpFloatRefSCFxBxB,
        *,
        rtol: float = 1e-5,
        atol: float = 1e-8,
    ) -> None:
        """
        Sanity-check that gradients and density matrices can be recomputed from stored quantities.

        For every SCF iterate, this recomputes:
        - the orbital-rotation gradient from ``xc_potential_matrices`` and ``self.mo_coeffs``;
        - the density matrix from ``self.mo_coeffs`` and ``self.occupancies``.

        It then checks both against the supplied references elementwise using the
        ``np.allclose`` criterion ``|ref - recomputed| <= atol + rtol * |ref|``.
        A purely relative test would reject entries whose reference value is zero,
        or numerically zero, even when the recomputation only differs by roundoff.

        Args:
            xc_minimization_directions: Reference gradients, one (O, V) block per SCF iterate.
            density_matrices: Reference density matrices, one (B, B) matrix per SCF iterate.
            xc_potential_matrices: XC potential matrices used to recompute the gradients.
            rtol: Relative tolerance for both checks.
            atol: Absolute tolerance for both checks.

        Raises:
            AssertionError: If either recomputation disagrees with its reference.
        """
        recomputed_gradients = np.asarray(
            jax.vmap(dm_gradient_to_orbital_rotation_gradient, in_axes=(0, 0, None))(
                xc_potential_matrices, self.mo_coeffs, self.n_occ
            )
        )
        reference_gradients = np.asarray(xc_minimization_directions)
        grad_abs_error = np.abs(reference_gradients - recomputed_gradients)
        # `initial=0.0` keeps the check well-defined for empty occupied/virtual blocks
        # (zero-size arrays), where a bare np.max would raise.
        max_abs_error = np.max(grad_abs_error, initial=0.0)
        # Relative error is reported for diagnostics only; the pass/fail test below
        # uses the combined absolute + relative tolerance.
        max_rel_error = np.max(
            grad_abs_error / (np.abs(reference_gradients) + 1e-15), initial=0.0
        )
        assert np.all(grad_abs_error <= atol + rtol * np.abs(reference_gradients)), (
            f'Max relative error: {max_rel_error:.3e}, max absolute error: {max_abs_error:.3e}'
        )

        recomputed_density_matrices = np.asarray(
            jax.vmap(coeffs_to_density_matrix, in_axes=(0, None))(
                self.mo_coeffs, self.occupancies
            )
        )
        reference_density_matrices = np.asarray(density_matrices)
        assert np.allclose(
            reference_density_matrices, recomputed_density_matrices, rtol=rtol, atol=atol
        ), (
            'Density matrix mismatch, max absolute error: '
            f'{np.max(np.abs(reference_density_matrices - recomputed_density_matrices), initial=0.0):.3e}'
        )
