"""Type definitions for TDDFT/TDA Davidson solvers.

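This module provides semantic type aliases that make function signatures
more readable and self-documenting. Types are organized by their physical
or algorithmic meaning rather than just their shape. It also defines the
precision settings, numerical thresholds, and the state/result containers
(``DavidsonCasidaState``, ``SubspaceResult``) declared below.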

Type Naming Conventions:
    - Vector types: Named by the quantity they hold (e.g., DiagonalApprox)
    - Matrix types: Named by their role (e.g., SubspaceMatrix, EigenvectorMatrix)
"""

from dataclasses import dataclass, replace
from types import SimpleNamespace
from typing import NamedTuple, Tuple

from jaxtyping import Array, Float

# Catch-all alias: a floating-point array of any rank and shape.
FloatArray = Float[Array, '...']

# Storage shapes used by the Davidson-Casida state.
# Axis names appearing in the annotations below:
#   dim -> problem dimension, n_occ * n_virt
#   M   -> max_vectors, the maximum subspace size
TrialVectorHolder = Float[Array, 'dim M']
"""Storage matrix for trial vectors; shape is (dim, max_vectors)."""

SubspaceOverlapHolder = Float[Array, 'M M']
"""Storage for a subspace overlap matrix; shape is (max_vectors, max_vectors)."""

# Names of the two precision levels referenced by PRECISION.
_HIGH_PRECISION = 'float64'
_LOW_PRECISION = 'float32'

# Precision chosen per component; both entries currently select float64.
PRECISION = SimpleNamespace(davidson=_HIGH_PRECISION, subspace=_HIGH_PRECISION)

# TDA preconditioner: tolerance used when clamping the denominator
# (diagonal_approx - eigenvalue), so that a very small difference cannot
# lead to a division by zero.
TDA_PRECONDITIONER_EPS = 1e-8

# TDDFT preconditioner: the corresponding tolerance, chosen tighter because
# the TDDFT case is more sensitive due to the metric.
TDDFT_PRECONDITIONER_EPS = 1e-14

# Gram-Schmidt orthogonalization: a vector whose norm is below this value
# is treated as linearly dependent.
GRAM_SCHMIDT_NORM_THRESHOLD = 1e-14

# --- Vectors -----------------------------------------------------------
# 'dim' is the physical problem size, n_occ * n_virt, of the TDA/TDDFT
# problem.
DiagonalApprox = Float[Array, 'dim']
"""Approximation to the diagonal of the TDA/TDDFT matrix; shape (dim,)."""

ExcitationEnergies = Float[Array, 'n_states']
"""Excitation energies (the computed eigenvalues); shape (n_states,)."""

ResidualNorms = Float[Array, 'n_states']
"""Residual-vector norms used for the convergence check; shape (n_states,)."""

FloatTxT = Float[Array, 'T T']
"""Generic square matrix; shape (T, T)."""

# --- Full-size matrices ------------------------------------------------
# Each column of these matrices is one vector.
TrialVectorMatrix = Float[Array, 'dim n_vectors']
"""Davidson trial vectors stored column-wise; shape (dim, n_vectors)."""

EigenvectorMatrix = Float[Array, 'dim n_states']
"""Eigenvectors stored column-wise; shape (dim, n_states)."""

# --- Subspace matrices -------------------------------------------------
# These are small and dense.
SubspaceMatrix = Float[Array, 'subspace_dim subspace_dim']
"""Dense matrix living in the reduced Davidson subspace; shape (m, m)."""

SubspaceEigenvectors = Float[Array, 'subspace_dim n_states']
"""Eigenvectors expressed in the reduced subspace; shape (m, n_states)."""


@dataclass
class DavidsonCasidaState:
    """Container for the pre-allocated matrices of a Davidson-Casida run.

    Two shapes occur, with dim = n_occ * n_virt and M = max_vectors:
    the vector holders (V, W, U1, U2) are (dim, M) and the subspace
    overlap holders are (M, M).

    Attributes:
        V: X-component trial vectors, shape (dim, M).
        W: Y-component trial vectors, shape (dim, M).
        U1: Matrix-vector products A @ X + B @ Y, shape (dim, M).
        U2: Matrix-vector products B @ X + A @ Y, shape (dim, M).
        VU1: Overlap V.T @ U1 in the subspace, shape (M, M).
        VU2: Overlap V.T @ U2 in the subspace, shape (M, M).
        WU1: Overlap W.T @ U1 in the subspace, shape (M, M).
        WU2: Overlap W.T @ U2 in the subspace, shape (M, M).
        VV: Overlap V.T @ V in the subspace, shape (M, M).
        VW: Overlap V.T @ W in the subspace, shape (M, M).
        WW: Overlap W.T @ W in the subspace, shape (M, M).
    """

    # (dim, M) holders: trial vectors and their matrix-vector products.
    V: TrialVectorHolder
    W: TrialVectorHolder
    U1: TrialVectorHolder
    U2: TrialVectorHolder

    # (M, M) holders: overlaps of the trial vectors with the products ...
    VU1: SubspaceOverlapHolder
    VU2: SubspaceOverlapHolder
    WU1: SubspaceOverlapHolder
    WU2: SubspaceOverlapHolder
    # ... and of the trial vectors with each other.
    VV: SubspaceOverlapHolder
    VW: SubspaceOverlapHolder
    WW: SubspaceOverlapHolder

    def with_updates(self, **kwargs) -> 'DavidsonCasidaState':
        """Build a copy of this state in which the given fields are replaced.

        Args:
            **kwargs: Field names mapped to their new values. Fields that
                are not mentioned are carried over from ``self``.

        Returns:
            A new state object; ``self`` itself is left untouched.
        """
        updated_state = replace(self, **kwargs)
        return updated_state


class SubspaceResult(NamedTuple):
    """Reduced matrices for the Casida eigenproblem, as returned by gen_sub_ab.

    All four fields are annotated as ``SubspaceMatrix``. As a
    ``NamedTuple`` the result can be unpacked positionally in the order
    ``A, B, sigma, pi`` or accessed by field name.

    NOTE(review): how each matrix is assembled is not visible in this
    module -- see ``gen_sub_ab`` for the definitions of ``A``, ``B``,
    ``sigma`` and ``pi``.
    """

    A: SubspaceMatrix
    B: SubspaceMatrix
    sigma: SubspaceMatrix
    pi: SubspaceMatrix


TDAResult = Tuple[ExcitationEnergies, EigenvectorMatrix]
"""What the TDA solver returns: the pair (eigenvalues, eigenvectors)."""

TDDFTResult = Tuple[ExcitationEnergies, EigenvectorMatrix, EigenvectorMatrix]
"""What the TDDFT solver returns: (excitation_energies, X_amplitudes, Y_amplitudes)."""
