from typing import Callable, Tuple

import numpy as np
from pyscf.scf import hf
from pyscf.soscf import newton_ah

from deixc.orbital_transforms import dm_gradient_to_orbital_rotation_gradient
from egxc.utils.typing import (
    NpFloatBxB,
    NpFloatOV,
    NpFloatOxV,
    NpFloatRefSCFxBxB,
    NpFloatRefSCFxOxV,
)


def get_gradient_hessian_diag_and_hvp(
    mf: hf.SCF,
    mo_coeff: NpFloatBxB,
    fock_matrix: NpFloatBxB,
    n_occ: int,
    n_virt: int,
) -> Tuple[NpFloatOxV, Callable[[NpFloatOV], NpFloatOV]]:
    """
    Build the RHF orbital-rotation Hessian diagonal and Hessian-vector product.

    Despite the function name, the gradient produced by ``newton_ah.gen_g_hop_rhf`` is
    discarded; only the Hessian diagonal and the Hessian-vector product are returned.

    Parameters
    ----------
    mf : pyscf.scf.hf.SCF
        RHF mean-field object; its ``mo_occ`` is used as the occupation pattern.
    mo_coeff : NpFloatBxB
        MO coefficients at which the Hessian is evaluated.
    fock_matrix : NpFloatBxB
        Fock matrix in the AO basis (passed to PySCF as ``fock_ao``).
    n_occ, n_virt : int
        Number of occupied and virtual orbitals.

    Returns
    -------
    hessian_diagonal : NpFloatOxV
        Diagonal of the orbital-rotation Hessian. It is reshaped from PySCF's flat
        (n_virt, n_occ) layout to (n_occ, n_virt) and multiplied by 2.
    hessian_vector_product : Callable
        PySCF's Hessian-vector product, returned unchanged. NOTE: it consumes and returns
        flat vectors in PySCF's (n_virt, n_occ) layout, not the (n_occ, n_virt) layout
        of ``hessian_diagonal``.
    """
    # The complete (non-trivial part of the) Hessian would have shape
    # (n_occ * n_virt, n_occ * n_virt); only its diagonal and its action on a vector
    # are built here.
    _, hessian_vector_product, hessian_diagonal = newton_ah.gen_g_hop_rhf(
        mf,
        mo_coeff,  # point on the manifold where the Hessian is evaluated
        mf.mo_occ,
        fock_ao=fock_matrix,
        with_symmetry=False,  # NOTE: we decided to not use symmetry for now
    )
    # PySCF flattens occupied-virtual quantities from a (n_virt, n_occ) array;
    # transpose to the (n_occ, n_virt) convention used by the callers.
    # TODO: Are we sure about this factor of 2?
    hessian_diagonal = 2 * hessian_diagonal.reshape(n_virt, n_occ).T
    return hessian_diagonal, hessian_vector_product


def linear_response_fn(
    direction: NpFloatOxV,
    hvp: Callable[[NpFloatOV], NpFloatOV],
    n_occ: int,
    n_virt: int,
) -> NpFloatOxV:
    """
    Linear response of the XC potential for given perturbation in MO space.

    The perturbation direction is normalized to (approximately) unit length before the
    Hessian-vector product is applied. A 1e-15 guard in the denominator means a zero
    direction yields a zero response instead of NaNs.

    Parameters
    ----------
    direction : NpFloatOxV
        Perturbation in the occupied-virtual block, laid out as (n_occ, n_virt).
    hvp : Callable
        Hessian-vector product as returned by ``newton_ah.gen_g_hop_rhf``. It expects
        and returns flat vectors in PySCF's (n_virt, n_occ) layout.
    n_occ, n_virt : int
        Number of occupied and virtual orbitals.

    Returns
    -------
    NpFloatOxV
        Real part of the Hessian-vector product, multiplied by 2, laid out as
        (n_occ, n_virt).
    """
    # The hvp operates on vectors flattened from (n_virt, n_occ). Transpose first so that
    # element (i, a) of `direction` is paired with the correct orbital rotation.
    flat = np.asarray(direction).reshape(n_occ, n_virt).T.reshape(-1)
    flat = flat / (np.linalg.norm(flat) + 1e-15)
    response = hvp(flat).real * 2
    # Map the (n_virt, n_occ)-ordered result back to the (n_occ, n_virt) convention.
    return response.reshape(n_virt, n_occ).T


def get(
    mf: hf.SCF,
    mo_coeffs: NpFloatRefSCFxBxB,
    fock_matrices: NpFloatRefSCFxBxB,
    xc_potential_matrices: NpFloatRefSCFxBxB,
) -> Tuple[NpFloatRefSCFxOxV, NpFloatRefSCFxOxV, NpFloatOxV]:
    """
    Compute the orbital-rotation Hessian projections along an SCF trajectory for RHF.

    Parameters
    ----------
    mf : pyscf.scf.hf.SCF
        Mean-field (RHF) SCF object with attributes:
        - mo_occ: occupation numbers array e.g. [2, 2, 2, 0, 0, 0]
        The number of occupied orbitals is the number of non-zero entries of ``mo_occ``.
    mo_coeffs : NpFloatRefSCFxBxB
        MO coefficients along an SCF trajectory, shape (SCF, B, B). Must be 3D.
    fock_matrices : NpFloatRefSCFxBxB
        AO-basis Fock matrices, one per SCF step (same length as ``mo_coeffs``).
    xc_potential_matrices : NpFloatRefSCFxBxB
        One matrix per SCF step (same length as ``mo_coeffs``), converted to an
        orbital-rotation gradient by ``dm_gradient_to_orbital_rotation_gradient``.
        Presumably the AO-basis XC potential; confirm against the caller.

    Returns
    -------
    gradients : NpFloatRefSCFxOxV
        Array of shape (SCF, O, V) with the orbital-rotation gradient of each step,
        as returned by ``dm_gradient_to_orbital_rotation_gradient``.
    linear_responses : NpFloatRefSCFxOxV
        Array of shape (SCF, O, V) with the Hessian-vector product (scaled by 2) along
        the normalized ``gradients[i]`` direction, evaluated at step ``i``.
    hdiag : NpFloatOxV
        2D array of shape (O, V) with the diagonal of the orbital-rotation Hessian
        (scaled by 2), evaluated at the LAST SCF step of the trajectory.

    Raises
    ------
    ValueError
        If ``mo_coeffs`` is not 3D, the trajectory is empty, or the three inputs do not
        have the same number of SCF steps.

    Notes
    -----
    1. Uses pyscf.soscf.newton_ah.gen_g_hop_rhf (via get_gradient_hessian_diag_and_hvp)
       to obtain the Hessian-vector product and diagonal at every step.
    2. Scales second derivatives by a factor of two for RHF energy definition.
    """
    # Validate explicitly (not via `assert`, which is stripped under `python -O`).
    if mo_coeffs.ndim != 3:
        raise ValueError(
            f'mo_coeffs must be a 3D array of shape (SCF, B, B), got ndim={mo_coeffs.ndim}'
        )
    n_scf = mo_coeffs.shape[0]
    if n_scf == 0:
        raise ValueError('mo_coeffs must contain at least one SCF step')
    if len(fock_matrices) != n_scf or len(xc_potential_matrices) != n_scf:
        # zip() would otherwise silently truncate and leave zero rows in the outputs.
        raise ValueError(
            'mo_coeffs, fock_matrices and xc_potential_matrices must have the same '
            f'number of SCF steps, got {n_scf}, {len(fock_matrices)} and '
            f'{len(xc_potential_matrices)}'
        )

    n_basis = mo_coeffs.shape[-1]
    n_occ = int(np.count_nonzero(mf.mo_occ))
    n_virt = n_basis - n_occ
    gradients = np.zeros((n_scf, n_occ, n_virt), dtype=np.float64)
    linear_responses = np.zeros((n_scf, n_occ, n_virt), dtype=np.float64)

    hdiag = None
    for i, (C_0, F_0, V_0) in enumerate(
        zip(mo_coeffs, fock_matrices, xc_potential_matrices)
    ):
        # hdiag is overwritten each step; only the last step's diagonal is returned.
        hdiag, hvp = get_gradient_hessian_diag_and_hvp(mf, C_0, F_0, n_occ, n_virt)
        grad_xc = dm_gradient_to_orbital_rotation_gradient(V_0, C_0, n_occ)
        gradients[i] = grad_xc
        linear_responses[i] = linear_response_fn(grad_xc, hvp, n_occ, n_virt)
    return gradients, linear_responses, hdiag  # type: ignore
