"""Utility functions for TDDFT/TDA Davidson solvers.

This module provides helper functions for:
- Initial guess generation
- Diagonal preconditioning
- Subspace matrix construction
- TDDFT subspace eigenvalue solving

These utilities support the Davidson iterative diagonalization algorithms
for both TDA and full TDDFT calculations.
"""

from typing import Tuple

import jax.numpy as jnp

from tddft.utils.typing import (
    TDA_PRECONDITIONER_EPS,
    TDDFT_PRECONDITIONER_EPS,
    DavidsonCasidaState,
    FloatArray,
    FloatTxT,
    SubspaceResult,
)


def tda_diag_initial_guess(
    V_holder: FloatArray, n_states: int, diagonal_approx: FloatArray
) -> FloatArray:
    """Write one unit entry per guess vector at the lowest diagonal positions.

    ``diagonal_approx`` is flattened and sorted in ascending order. For guess
    ``j`` the row holding the ``j``-th smallest diagonal value is set to 1.0
    in column ``j`` of ``V_holder``. No other element of the holder is
    modified, so the columns are exact unit vectors only if the holder is
    zero in those columns on entry.

    Args:
        V_holder: Pre-allocated matrix to hold trial vectors, shape (dim, max_vectors).
        n_states: Number of initial guess vectors to generate.
        diagonal_approx: Approximate diagonal of the TDA/TDDFT matrix; any
            shape that flattens to (dim,).

    Returns:
        A copy of ``V_holder`` with a 1.0 placed in each of the first
        ``n_states`` columns.
    """
    ascending_rows = jnp.argsort(jnp.ravel(diagonal_approx))
    guess_rows = ascending_rows[:n_states]
    guess_cols = jnp.arange(n_states)
    return V_holder.at[guess_rows, guess_cols].set(1.0)


def tda_diag_preconditioner(
    residual: FloatArray, sub_eigenvalue: FloatArray, diagonal_approx: FloatArray
) -> FloatArray:
    """Simple diagonal preconditioner for TDA Davidson method.

    Solves the linear system (D - ω)·x = r approximately using the diagonal
    approximation D ≈ diag(A), where A is the TDA matrix. This generates
    improved trial vectors from residuals.

    Specifically computes: x = r / (diagonal_approx - ω)

    Args:
        residual: Residual vectors r = A·v - ω·v, shape (dim, n_states).
        sub_eigenvalue: Current eigenvalue approximations ω, shape (n_states,).
        diagonal_approx: Approximate diagonal of A matrix, shape (dim,).

    Returns:
        Preconditioned residuals, shape (dim, n_states).

    Notes:
        - Denominators with magnitude below ``TDA_PRECONDITIONER_EPS`` are
          replaced by ±eps so the division never blows up.
        - The sign of the denominator is preserved when clamping; an exactly
          zero denominator is treated as positive (+eps).
    """
    eps = TDA_PRECONDITIONER_EPS
    # Column (dim, 1) minus row (n_states,) broadcasts to (dim, n_states);
    # no explicit repeat is needed.
    denom = diagonal_approx.reshape(-1, 1) - sub_eigenvalue
    # jnp.sign(0) is 0, so `sign(denom) * eps` would leave an exact zero in
    # place and still divide by zero. Pick the signed floor explicitly.
    signed_floor = jnp.where(denom < 0, -eps, eps)
    denom = jnp.where(jnp.abs(denom) < eps, signed_floor, denom)
    return residual / denom


def tddft_diag_preconditioner(
    R_x: FloatArray, R_y: FloatArray, omega: FloatArray, diagonal_approx: FloatArray
) -> Tuple[FloatArray, FloatArray]:
    """Diagonal preconditioner for TDDFT (Casida) Davidson method.

    Solves the approximate correction equations:
        (D - ω)·X_new = R_x
        (D + ω)·Y_new = R_y

    where D ≈ diag(A) and ω are the current excitation energy approximations.

    Args:
        R_x: X-component residuals, shape (dim, n_states).
        R_y: Y-component residuals, shape (dim, n_states).
        omega: Current excitation energies ω, shape (n_states,).
        diagonal_approx: Approximate diagonal of A matrix, shape (dim,).

    Returns:
        Tuple of (X_new, Y_new) preconditioned correction vectors.

    Notes:
        Denominators with magnitude below ``TDDFT_PRECONDITIONER_EPS`` are
        replaced by ±eps, keeping their sign; an exactly zero denominator is
        treated as positive (+eps).
    """
    eps = TDDFT_PRECONDITIONER_EPS
    # Column (dim, 1); combining with the (n_states,) row of omega broadcasts
    # to (dim, n_states) without an explicit repeat.
    diag = diagonal_approx.reshape(-1, 1)

    def _clamp(denom: FloatArray) -> FloatArray:
        """Raise |denom| to at least eps, keeping its sign (0 counts as +)."""
        # jnp.sign(0) is 0, so `sign(denom) * eps` would keep an exact zero
        # and the division below would still produce inf/nan.
        signed_floor = jnp.where(denom < 0, -eps, eps)
        return jnp.where(jnp.abs(denom) < eps, signed_floor, denom)

    denom_x = _clamp(diag - omega)
    denom_y = _clamp(diag + omega)

    return R_x / denom_x, R_y / denom_y


def matrix_power(S: FloatTxT, power: float) -> FloatArray:
    """Raise a symmetric matrix to a real power through its eigen-decomposition.

    ``S`` is diagonalized with ``jnp.linalg.eigh``; each eigenvalue is raised
    to ``power`` and the matrix is reassembled from the eigenvectors.

    Args:
        S: Symmetric matrix, shape (n, n).
        power: Exponent applied to every eigenvalue.

    Returns:
        The matrix ``S`` raised to ``power``, shape (n, n).
    """
    spectrum, basis = jnp.linalg.eigh(S)
    # Broadcasting scales column j of the eigenvector matrix by spectrum[j]**power.
    scaled_basis = basis * spectrum**power
    return jnp.dot(scaled_basis, basis.T)


def block_symmetrize(A: FloatArray, m: int, n: int) -> FloatArray:
    """Mirror the block ``A[:m, m:n]`` into ``A[m:n, :m]``.

    Args:
        A: Matrix whose upper off-diagonal block is already filled.
        m: Row end of the source block / column end of the target block.
        n: Column end of the source block / row end of the target block.

    Returns:
        A copy of ``A`` in which ``A[m:n, :m]`` equals the transpose of
        ``A[:m, m:n]``; all other entries are unchanged.
    """
    upper_block = A[:m, m:n]
    return A.at[m:n, :m].set(jnp.transpose(upper_block))


def utriangle_symmetrize(A: FloatArray) -> FloatArray:
    """Symmetrize a square matrix by mirroring its upper triangle.

    The diagonal and strictly-upper part of ``A`` are kept, and the
    strictly-lower part is replaced by the transpose of the strictly-upper
    part. Whatever the input holds below the diagonal is ignored. For an
    already symmetric input the result equals the input.

    This is what ``gen_sub_ab`` needs: it builds its operands from holders
    updated via ``gen_vw(..., upper_triangle_only=True)``, which never writes
    the below-diagonal (new rows, old columns) block. Averaging such a matrix
    with its transpose would mix valid upper entries with unwritten lower
    ones.

    Args:
        A: Square matrix whose upper triangle (including diagonal) is valid.

    Returns:
        Symmetric matrix of the same shape built from the upper triangle.
    """
    upper_with_diag = jnp.triu(A)
    strict_upper = jnp.triu(A, k=1)
    return upper_with_diag + strict_upper.T


def gen_vw(
    sub_A_holder: FloatArray,
    V_holder: FloatArray,
    W_holder: FloatArray,
    size_old: int,
    size_new: int,
    *,
    symmetry: bool = True,
    upper_triangle_only: bool = False,
) -> FloatArray:
    """Extend the stored product ``V^T W`` with the blocks for newly added columns.

    The column block ``sub_A_holder[:size_new, size_old:size_new]`` is always
    written with ``V[:, :size_new].T @ W[:, size_old:size_new]``. The
    remaining block ``sub_A_holder[size_old:size_new, :size_old]`` is then
    handled according to the flags:

    * ``symmetry=True``: filled with the transpose of the upper block
      (``upper_triangle_only`` is ignored in this case).
    * ``symmetry=False, upper_triangle_only=True``: left untouched.
    * ``symmetry=False, upper_triangle_only=False``: computed explicitly as
      ``V[:, size_old:size_new].T @ W[:, :size_old]``.

    Args:
        sub_A_holder: Pre-allocated subspace matrix holder.
        V_holder: Left matrix for dot product.
        W_holder: Right matrix for dot product.
        size_old: Previous subspace size (start of new block).
        size_new: Current subspace size (end of new block).
        symmetry: If True, symmetrize by copying upper to lower block.
        upper_triangle_only: If True, skip the explicit lower-block product.

    Returns:
        Updated copy of ``sub_A_holder`` with the new block(s) filled in.
    """
    new_cols = slice(size_old, size_new)

    # Every current V vector against only the freshly added W vectors.
    column_block = V_holder[:, :size_new].T @ W_holder[:, new_cols]
    updated = sub_A_holder.at[:size_new, new_cols].set(column_block)

    if symmetry:
        return block_symmetrize(updated, size_old, size_new)
    if upper_triangle_only:
        return updated

    # Non-symmetric product: the (new rows, old columns) block needs its own
    # computation because it is not the transpose of the column block.
    row_block = V_holder[:, new_cols].T @ W_holder[:, :size_old]
    return updated.at[new_cols, :size_old].set(row_block)


def gen_sub_ab(
    state: DavidsonCasidaState, size_old: int, size_new: int
) -> Tuple[SubspaceResult, DavidsonCasidaState]:
    """Refresh the stored overlap holders and assemble the reduced Casida matrices.

    Seven holders on ``state`` (VU1, VU2, WU1, WU2, VV, WW, VW) are extended
    with ``gen_vw`` for the columns ``size_old:size_new``. From their leading
    ``size_new × size_new`` blocks the reduced matrices are formed:

        A     = utriangle_symmetrize(VU1 + WU2)
        B     = utriangle_symmetrize(VU2 + WU1)
        sigma = utriangle_symmetrize(VV - WW)
        pi    = VW - VW^T

    Args:
        state: Current Davidson-Casida iteration state.
        size_old: Previous subspace size.
        size_new: Current subspace size.

    Returns:
        Tuple of (SubspaceResult, updated_state) where SubspaceResult carries
        the reduced matrices (A, B, sigma, pi) and ``updated_state`` is
        ``state`` with the seven refreshed holders.
    """

    def _refresh(holder, left, right, *, upper_triangle_only=True):
        """Extend one holder with the left^T·right block for the new columns."""
        return gen_vw(
            holder,
            left,
            right,
            size_old,
            size_new,
            symmetry=False,
            upper_triangle_only=upper_triangle_only,
        )

    # NOTE(review): with upper_triangle_only=True only the new column block is
    # written; the (new rows, old columns) block keeps whatever the holder
    # already stored. The symmetrization below must therefore treat the upper
    # triangle as authoritative — confirm utriangle_symmetrize does so.
    VU1 = _refresh(state.VU1, state.V, state.U1)
    VU2 = _refresh(state.VU2, state.V, state.U2)
    WU1 = _refresh(state.WU1, state.W, state.U1)
    WU2 = _refresh(state.WU2, state.W, state.U2)
    VV = _refresh(state.VV, state.V, state.V)
    WW = _refresh(state.WW, state.W, state.W)
    # pi is antisymmetric and needs both triangles of V^T W, so this holder
    # is the only one computed in full.
    VW = _refresh(state.VW, state.V, state.W, upper_triangle_only=False)

    active = (slice(None, size_new), slice(None, size_new))
    sub_A = utriangle_symmetrize(VU1[active] + WU2[active])
    sub_B = utriangle_symmetrize(VU2[active] + WU1[active])
    sigma = utriangle_symmetrize(VV[active] - WW[active])
    pi = VW[active] - VW[active].T

    updated_state = state.with_updates(
        VU1=VU1, VU2=VU2, WU1=WU1, WU2=WU2, VV=VV, WW=WW, VW=VW
    )
    reduced = SubspaceResult(A=sub_A, B=sub_B, sigma=sigma, pi=pi)
    return reduced, updated_state


def solve_ax_minus_xomega_equals_b(
    A: FloatArray, omega: FloatArray, Q: FloatArray
) -> FloatArray:
    """Solve AX - XΩ = Q for X using eigen-decomposition of A.

    With A = U·diag(a)·Uᵀ (from ``jnp.linalg.eigh``) the equation decouples in
    the eigenbasis: (Uᵀ X)[i, k] = (Uᵀ Q)[i, k] / (a[i] - ω[k]).

    Each column of Q is normalized before the solve and the solution is
    rescaled afterwards. A column of Q that is entirely zero yields a zero
    column in X instead of NaN.

    Args:
        A: Matrix A, shape (n, n); treated as symmetric by ``eigh``.
        omega: Diagonal elements of Ω, shape (k,).
        Q: Right-hand side matrix, shape (n, k).

    Returns:
        Solution matrix X, shape (n, k).

    Notes:
        No safeguard is applied when some ω[k] coincides with an eigenvalue
        of A; the corresponding entries are then inf/nan.
    """
    omega = jnp.asarray(omega)
    n_vectors = omega.shape[0]

    col_norms = jnp.linalg.norm(Q, axis=0, keepdims=True)
    # A zero column would give 0/0 = NaN; dividing by 1 keeps it zero.
    safe_norms = jnp.where(col_norms > 0, col_norms, 1.0)
    Q_unit = Q / safe_norms

    a, u = jnp.linalg.eigh(A)
    ub = jnp.dot(u.T, Q_unit)

    # All columns at once: shifts[i, k] = a[i] - omega[k].
    shifts = a[:, None] - omega[None, :]
    # Columns beyond len(omega), if any, stay zero.
    ux = jnp.zeros_like(Q_unit).at[:, :n_vectors].set(ub[:, :n_vectors] / shifts)

    return jnp.dot(u, ux) * safe_norms


def tddft_subspace_eigen_solver(
    a: FloatArray, b: FloatArray, sigma: FloatArray, pi: FloatArray, k: int
) -> Tuple[FloatArray, FloatArray, FloatArray]:
    """Solve the reduced Casida eigenproblem in the Davidson subspace.

    The generalized problem

        [A  B] [x]     [σ   π] [x]
        [B  A] [y] = ω [-π -σ] [y]

    is assembled as two (2·half_size) square matrices (left-hand and
    right-hand side). Writing L for the left-hand matrix and R for the
    right-hand one, the routine:

        1. forms L^{-1/2} via ``matrix_power``,
        2. diagonalizes M = L^{-1/2} · R · L^{-1/2} with ``eigh``,
        3. takes the k largest eigenvalues λ of M, largest first, and sets
           ω = 1/λ,
        4. scales the matching eigenvectors by ω^{1/2} and maps them back
           with L^{-1/2}; the top half is x and the bottom half is y.

    Args:
        a: Subspace matrix A (half_size × half_size).
        b: Subspace matrix B (half_size × half_size).
        sigma: Metric overlap σ (half_size × half_size).
        pi: Metric overlap π (half_size × half_size).
        k: Number of excitation energies to extract.

    Returns:
        Tuple containing:
            - omega: Excitation energies, shape (k,); ascending as long as
              the selected eigenvalues of M are positive
            - x: X eigenvector components, shape (half_size, k)
            - y: Y eigenvector components, shape (half_size, k)
    """
    n = a.shape[0]

    def _two_by_two(top_left, top_right, bottom_left, bottom_right):
        """Assemble a (2n, 2n) matrix from four (n, n) quadrants."""
        full = jnp.zeros((2 * n, 2 * n))
        full = full.at[:n, :n].set(top_left)
        full = full.at[:n, n:].set(top_right)
        full = full.at[n:, :n].set(bottom_left)
        return full.at[n:, n:].set(bottom_right)

    lhs = _two_by_two(a, b, b, a)
    rhs = _two_by_two(sigma, pi, -pi, -sigma)

    lhs_inv_sqrt = matrix_power(lhs, -0.5)
    M = jnp.linalg.multi_dot([lhs_inv_sqrt, rhs, lhs_inv_sqrt])
    inverse_energies, vectors = jnp.linalg.eigh(M)

    # eigh returns ascending eigenvalues: keep the last k and flip them so
    # the largest 1/ω (i.e. the smallest ω) comes first.
    top = slice(-k, None)
    omega = 1 / inverse_energies[top][::-1]
    Z = vectors[:, top][:, ::-1] * (omega**0.5)

    T = jnp.dot(lhs_inv_sqrt, Z)
    return omega, T[:n, :], T[n:, :]
