"""Davidson-based iterative eigensolvers for TDA and TDDFT problems.


References:
    Davidson, E. R. (1975). The iterative calculation of a few of the lowest
    eigenvalues and corresponding eigenvectors of large real-symmetric matrices.
    Journal of Computational Physics, 17(1), 87-94.
"""

from typing import Callable, List, Tuple

import jax.numpy as jnp

from tddft.utils import gram_schmidt
from tddft.utils.linalg import (
    gen_sub_ab,
    gen_vw,
    tda_diag_initial_guess,
    tda_diag_preconditioner,
    tddft_diag_preconditioner,
    tddft_subspace_eigen_solver,
    utriangle_symmetrize,
)
from tddft.utils.typing import (
    PRECISION,
    DavidsonCasidaState,
    DiagonalApprox,
    FloatArray,
    TDAResult,
    TDDFTResult,
    TrialVectorMatrix,
)

# Callback used by `davidson`: receives one block of trial vectors and returns
# the result of applying the TDA matrix A to that block (A @ X, batched).
TDAMatrixVectorProduct = Callable[[TrialVectorMatrix], TrialVectorMatrix]
# Callback used by `davidson_casida`: receives the (X, Y) trial-vector blocks
# and returns the pair (U1, U2), documented there as U1 = A·X + B·Y and
# U2 = B·X + A·Y.
CasidaMatrixVectorProduct = Callable[
    [TrialVectorMatrix, TrialVectorMatrix],
    Tuple[TrialVectorMatrix, TrialVectorMatrix],
]


def _create_casida_state(
    dim: int,
    max_vectors: int,
    dtype: jnp.dtype,
    diagonal_approx: FloatArray,
    initial_size: int,
) -> DavidsonCasidaState:
    """Build the zero-initialised work buffers for the Davidson-Casida solver.

    Args:
        dim: Number of rows of every trial-vector buffer.
        max_vectors: Number of columns reserved in every trial-vector buffer,
            and the side length of the square subspace matrices.
        dtype: Element type used for all buffers.
        diagonal_approx: Diagonal approximation forwarded to
            `tda_diag_initial_guess` to seed the first trial vectors.
        initial_size: Number of initial guess vectors requested from
            `tda_diag_initial_guess`.

    Returns:
        A `DavidsonCasidaState` constructed positionally from four
        (dim, max_vectors) buffers (the seeded V followed by three zero
        buffers shaped like it) and seven independent (max_vectors,
        max_vectors) zero matrices.
    """
    seeded_V = tda_diag_initial_guess(
        jnp.zeros((dim, max_vectors), dtype=dtype), initial_size, diagonal_approx
    )
    # W, U1 and U2 start out as zero buffers with the same shape/dtype as V.
    blank_W, blank_U1, blank_U2 = (jnp.zeros_like(seeded_V) for _ in range(3))

    # One square zero matrix plus six separate copies of it fill the remaining
    # seven positional fields of the state.
    square_zeros = jnp.zeros((max_vectors, max_vectors), dtype=dtype)
    extra_squares = [square_zeros.copy() for _ in range(6)]

    return DavidsonCasidaState(
        seeded_V, blank_W, blank_U1, blank_U2, square_zeros, *extra_squares
    )


def davidson(
    matrix_vector_product: TDAMatrixVectorProduct,
    diagonal_approx: DiagonalApprox,
    n_states: int,
    conv_tol: float = 1e-9,
    max_iter: int = 25,
) -> TDAResult:
    """Iterative Davidson diagonalization for the TDA eigenvalue problem.

    Solves A·v = λ·v for the lowest `n_states` eigenvalues using the
    Davidson iterative method with diagonal preconditioning.

    The algorithm:
        1. Initialize subspace with unit vectors on smallest diagonal entries
        2. Compute A·V for new subspace vectors
        3. Solve eigenvalue problem in reduced subspace
        4. Check convergence via residual norms ‖A·v - λ·v‖
        5. If not converged, precondition residuals and expand subspace
        6. Orthogonalize new vectors and repeat

    Args:
        matrix_vector_product: Function implementing A @ X for batch evaluation.
            It is always called with a block of exactly
            min(n_states + 8, 2*n_states, dim) columns; unused trailing
            columns are zero-padded and their results discarded.
        diagonal_approx: Approximate diagonal of A, used for initial guess and
            preconditioning. Shape: (dim,).
        n_states: Number of lowest eigenpairs to compute.
        conv_tol: Convergence threshold for maximum residual norm.
        max_iter: Maximum number of Davidson iterations.

    Returns:
        eigenvalues: Array of n_states lowest eigenvalues, shape (n_states,).
        eigenvectors: Matrix of eigenvectors as columns, shape (dim, n_states).

    Raises:
        RuntimeError: If the maximum residual norm is not below `conv_tol`
            when iteration stops. This covers running out of iterations,
            a subspace that can no longer be expanded, and a residual norm
            that is NaN.

    Notes:
        - Initial subspace size: min(n_states + 8, 2*n_states, dim)
        - Uses double Gram-Schmidt orthogonalization for numerical stability
    """
    dim = diagonal_approx.shape[0]
    dtype = PRECISION.davidson

    subspace_size_old = 0
    subspace_size = min(n_states + 8, 2 * n_states, dim)
    # Fixed column count of every block handed to matrix_vector_product.
    block_size = subspace_size
    # Each later iteration adds at most n_states vectors (one per unconverged
    # root), so this bounds the number of columns ever stored.
    max_vectors = max_iter * n_states + subspace_size

    V_holder = tda_diag_initial_guess(
        jnp.zeros((dim, max_vectors), dtype=dtype), subspace_size, diagonal_approx
    )
    W_holder = jnp.zeros_like(V_holder)
    subspace_A = jnp.zeros((max_vectors, max_vectors), dtype=dtype)

    # Placeholders so the return value / error message are well-defined even
    # if the loop body never runs (max_iter <= 0).
    max_residual_norm = float('inf')
    residual_norms: List[float] = []
    ritz_values = jnp.zeros(n_states, dtype=dtype)
    eigenvector_approx = jnp.zeros((dim, n_states), dtype=dtype)

    for iteration in range(max_iter):
        # Only the columns added since the previous iteration need A applied.
        new_trial_vectors = V_holder[:, subspace_size_old:subspace_size]
        n_new_vectors = new_trial_vectors.shape[1]
        padding_needed = block_size - n_new_vectors

        # Pad with zero columns so the callback always sees the same width.
        if padding_needed > 0:
            new_trial_vectors = jnp.concatenate(
                [new_trial_vectors, jnp.zeros((dim, padding_needed), dtype=dtype)], axis=1
            )

        AV_padded = matrix_vector_product(new_trial_vectors)
        # Drop the results belonging to the padding columns.
        W_holder = W_holder.at[:, subspace_size_old:subspace_size].set(
            AV_padded[:, :n_new_vectors]
        )

        subspace_A = gen_vw(
            subspace_A, V_holder, W_holder, subspace_size_old, subspace_size
        )
        subspace_A_symmetric = utriangle_symmetrize(
            subspace_A[:subspace_size, :subspace_size]
        )

        # eigh returns eigenvalues in ascending order, so the leading
        # n_states columns are the lowest Ritz pairs.
        all_ritz_values, all_ritz_vectors = jnp.linalg.eigh(subspace_A_symmetric)
        ritz_values = all_ritz_values[:n_states]
        ritz_vectors = all_ritz_vectors[:, :n_states]

        eigenvector_approx = jnp.dot(V_holder[:, :subspace_size], ritz_vectors)
        Av_approx = jnp.dot(W_holder[:, :subspace_size], ritz_vectors)
        residuals = Av_approx - eigenvector_approx * ritz_values

        residual_norms = jnp.linalg.norm(residuals, axis=0)
        max_residual_norm = float(jnp.max(residual_norms))  # type: ignore

        if max_residual_norm < conv_tol or iteration == (max_iter - 1):
            break

        unconverged_mask = residual_norms > conv_tol  # type: ignore
        unconverged_indices = jnp.where(unconverged_mask)[0]

        preconditioned_residuals = tda_diag_preconditioner(
            residuals[:, unconverged_indices],
            ritz_values[unconverged_indices],
            diagonal_approx,
        )

        subspace_size_old = subspace_size
        V_holder, subspace_size = gram_schmidt.fill_holder(
            V_holder, subspace_size_old, preconditioned_residuals, double=True
        )

        # No new direction survived orthogonalization: further iterations
        # would only re-apply A to zero padding and reproduce the same Ritz
        # pairs, so stop now and let the convergence check below report it.
        if subspace_size == subspace_size_old:
            break

    # Written as "not (<)" rather than ">=" so that a NaN residual norm, for
    # which every ordered comparison is False, is reported as a failure
    # instead of being returned as if it had converged.
    if not max_residual_norm < conv_tol:
        raise RuntimeError(
            f'TDA Davidson failed to converge. '
            f'Max residual norm: {max_residual_norm:.2e}, '
            f'tolerance: {conv_tol:.2e}. '
            f'Residual norms: {residual_norms}'
        )

    return ritz_values, eigenvector_approx


def davidson_casida(
    matrix_vector_product: CasidaMatrixVectorProduct,
    diagonal_approx: DiagonalApprox,
    n_states: int = 20,
    conv_tol: float = 1e-5,
    max_iter: int = 25,
    single_precision: bool = False,
) -> TDDFTResult:
    """Davidson iteration for the full TDDFT (Casida) eigenvalue problem.

    The problem being solved is
        [A  B] [X]     [1   0 ] [X]
        [B  A] [Y] = ω [0  -1 ] [Y]

    with A and B the TDDFT response matrices and ω the excitation energies.
    Trial vectors are kept as pairs (V, W); approximate solutions are formed
    as X = V·x + W·y and Y = W·x + V·y from the reduced-space solution (x, y)
    returned by `tddft_subspace_eigen_solver`.

    Args:
        matrix_vector_product: Function (X, Y) → (U1, U2) where
            U1 = A·X + B·Y and U2 = B·X + A·Y. It is called with blocks of
            min(n_states + 8, 2*n_states, dim) columns; when fewer new
            vectors exist the blocks are zero-padded and the extra output
            columns are ignored.
        diagonal_approx: Approximate diagonal used for the initial guess and
            for preconditioning. Shape: (dim,).
        n_states: Number of lowest excitation energies to compute.
        conv_tol: Threshold on the largest residual norm, where each norm is
            taken over the stacked (X, Y) residual of one state.
        max_iter: Maximum number of Davidson iterations.
        single_precision: If True, internal buffers use float32, otherwise
            float64.

    Returns:
        omega: Excitation energies, shape (n_states,).
        X: X eigenvector amplitudes, shape (dim, n_states).
        Y: Y eigenvector amplitudes, shape (dim, n_states).

    Notes:
        - No exception is raised when `conv_tol` is not reached: the most
          recent approximations are returned after `max_iter` iterations, or
          earlier if the subspace stops growing.
        - New trial pairs are appended through
          `gram_schmidt.vw_fill_holder` with `double=False`.
    """
    n_rows = diagonal_approx.shape[0]
    work_dtype = jnp.float64
    if single_precision:
        work_dtype = jnp.float32

    prev_size = 0
    cur_size = min(n_states + 8, 2 * n_states, n_rows)
    # Column count of every block given to matrix_vector_product.
    mvp_width = cur_size

    state = _create_casida_state(
        n_rows, (max_iter + 1) * n_states, work_dtype, diagonal_approx, cur_size
    )

    # Returned unchanged only if the loop body never executes.
    energies = jnp.zeros(n_states, dtype=work_dtype)
    X_approx = jnp.zeros((n_rows, n_states), dtype=work_dtype)
    Y_approx = jnp.zeros_like(X_approx)

    final_iteration = max_iter - 1
    for iteration in range(max_iter):
        fresh_cols = slice(prev_size, cur_size)

        # Active part of the trial space and the columns added last time.
        V_act = state.V[:, :cur_size]
        W_act = state.W[:, :cur_size]
        V_block = V_act[:, fresh_cols]
        W_block = W_act[:, fresh_cols]

        n_fresh = V_block.shape[1]
        n_missing = mvp_width - n_fresh
        if n_missing > 0:
            # Zero columns bring both blocks up to the fixed callback width.
            zero_cols = jnp.zeros((n_rows, n_missing), dtype=work_dtype)
            V_block = jnp.concatenate([V_block, zero_cols], axis=1)
            W_block = jnp.concatenate([W_block, zero_cols], axis=1)

        sigma_1, sigma_2 = matrix_vector_product(V_block, W_block)
        # Store only the outputs that belong to real (non-padding) columns.
        state = state.with_updates(
            U1=state.U1.at[:, fresh_cols].set(sigma_1[:, :n_fresh]),
            U2=state.U2.at[:, fresh_cols].set(sigma_2[:, :n_fresh]),
        )
        U1_act = state.U1[:, :cur_size]
        U2_act = state.U2[:, :cur_size]

        reduced, state = gen_sub_ab(state, prev_size, cur_size)
        energies, coeff_x, coeff_y = tddft_subspace_eigen_solver(
            reduced.A, reduced.B, reduced.sigma, reduced.pi, n_states
        )

        # Expand the reduced-space solution back to full-space amplitudes.
        X_approx = jnp.dot(V_act, coeff_x) + jnp.dot(W_act, coeff_y)
        Y_approx = jnp.dot(W_act, coeff_x) + jnp.dot(V_act, coeff_y)

        # Residuals of the two block rows of the Casida equation; the sign
        # of the ω term differs between them because of the (1, -1) metric.
        res_X = jnp.dot(U1_act, coeff_x) + jnp.dot(U2_act, coeff_y) - X_approx * energies
        res_Y = jnp.dot(U2_act, coeff_x) + jnp.dot(U1_act, coeff_y) + Y_approx * energies

        stacked = jnp.concatenate([res_X, res_Y], axis=0)
        norms = jnp.linalg.norm(stacked, axis=0)
        worst_norm = float(jnp.max(norms))

        if worst_norm < conv_tol or iteration == final_iteration:
            break

        # Only states still above tolerance contribute new directions.
        pending = jnp.where(norms > conv_tol)[0]
        X_corr, Y_corr = tddft_diag_preconditioner(
            res_X[:, pending],
            res_Y[:, pending],
            energies[pending],
            diagonal_approx,
        )

        prev_size = cur_size
        V_grown, W_grown, cur_size = gram_schmidt.vw_fill_holder(
            state.V, state.W, prev_size, X_corr, Y_corr, double=False
        )
        state = state.with_updates(V=V_grown, W=W_grown)

        # Nothing was appended, so another pass could not change the result.
        if cur_size == prev_size:
            break

    return energies, X_approx, Y_approx
