"""Gram-Schmidt orthogonalization algorithms for Davidson solvers.

This module provides Gram-Schmidt orthogonalization functions for:
- Standard orthogonalization (for TDA)
- Paired orthogonalization of (X, Y) trial vectors (for full TDDFT, where each
  pair is kept orthogonal to both (V, W) and its swapped partner (W, V))

Both standard and JIT-compatible versions are provided. Functions with
`_jit` suffix use `jax.lax.fori_loop` and can be compiled with `jax.jit`.
"""

from typing import Tuple

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from tddft.utils.typing import GRAM_SCHMIDT_NORM_THRESHOLD

# Alias used in the type annotations below: a floating-point JAX array of any shape.
FloatArray = Float[Array, "..."]


def __bvec(A: FloatArray, bvec: FloatArray) -> FloatArray:
    """Remove from ``bvec`` its components along the columns of ``A``.

    Returns ``bvec - A @ (A.T @ bvec)``, i.e. a single classical Gram-Schmidt
    projection. The result is orthogonal to span(A) only when the columns of
    ``A`` are orthonormal; that is assumed here, not checked.

    Args:
        A: Basis with orthonormal columns, shape (dim, m).
        bvec: Vector to orthogonalize, shape (dim, 1).

    Returns:
        The projected vector, same shape as ``bvec``.
    """
    overlaps = A.T @ bvec
    return bvec - A @ overlaps


def fill_holder(
    V: FloatArray, count: int, vecs: FloatArray, *, double: bool = True
) -> Tuple[FloatArray, int]:
    """Append orthonormalized columns of ``vecs`` to ``V`` from column ``count``.

    Python-loop implementation: the acceptance test branches on a concrete
    norm value, so it cannot be traced by ``jax.jit``. See ``fill_holder_jit``
    for the traceable variant.

    Each column of ``vecs`` is, in order, projected out of ``V[:, :count]``
    (once, or twice when ``double`` is set), then normalized and written into
    column ``count`` if its residual norm exceeds
    ``GRAM_SCHMIDT_NORM_THRESHOLD``; otherwise it is dropped. The number of
    free columns in ``V`` is not checked.

    Args:
        V: Holder matrix for orthonormal vectors, shape (dim, max_vectors).
        count: Number of columns of ``V`` already filled.
        vecs: Candidate vectors, one per column, shape (dim, n_new).
        double: Re-orthogonalize a second time for numerical stability.

    Returns:
        Tuple ``(V, new_count)`` with the accepted vectors written in and the
        updated number of filled columns.
    """
    n_passes = 2 if double else 1

    for column in vecs.T:
        candidate = column.reshape(-1, 1)

        if count > 0:
            basis = V[:, :count]
            for _ in range(n_passes):
                candidate = __bvec(basis, candidate)

        norm = jnp.linalg.norm(candidate)
        if not norm > GRAM_SCHMIDT_NORM_THRESHOLD:
            continue  # residual too small: numerically inside span(V), skip it

        V = V.at[:, count].set((candidate / norm)[:, 0])
        count += 1

    return V, count


def fill_holder_jit(
    V: FloatArray, count: int, vecs: FloatArray, *, double: bool = True
) -> Tuple[FloatArray, int]:
    """JIT-compatible version of ``fill_holder`` built on ``lax.fori_loop``.

    Inside the loop the running count is a traced value, so the active basis
    cannot be taken as ``V[:, :count]`` (JAX requires static slice bounds).
    Instead, columns at index ``>= count`` are zeroed with a mask. This keeps
    the static shape of ``V`` and gives the same projection, because zero
    columns contribute nothing to ``A @ (A.T @ v)``.

    Every column of ``vecs`` is processed. A column whose residual norm is
    not above ``GRAM_SCHMIDT_NORM_THRESHOLD`` leaves ``V`` unchanged and does
    not advance the count.

    ``double`` is consumed by Python control flow, so under ``jax.jit`` it must
    be a static argument (e.g. ``static_argnames=("double",)``).

    Capacity is not checked. If more vectors are accepted than ``V`` has free
    columns, JAX drops the out-of-bounds write while the count still
    increments.

    Args:
        V: Holder matrix for orthonormal vectors, shape (dim, max_vectors).
        count: Current number of vectors in V (column index to start filling).
        vecs: New vectors to orthogonalize and add, shape (dim, n_new).
        double: If True, apply Gram-Schmidt twice for numerical stability.

    Returns:
        Tuple of (V, new_count) where V has orthonormalized vectors added and
        new_count is the updated column count (a 0-d int32 JAX array).
    """
    nvec = vecs.shape[1]
    col_idx = jnp.arange(V.shape[1])

    def body_fn(j: int, carry: Tuple[FloatArray, int]) -> Tuple[FloatArray, int]:
        V_curr, cnt = carry

        # Static-shape stand-in for V_curr[:, :cnt]. It is all zeros when
        # cnt == 0, so no separate "cnt > 0" branch is needed.
        basis = jnp.where(col_idx < cnt, V_curr, 0.0)

        vec = vecs[:, j].reshape(-1, 1)
        vec = __bvec(basis, vec)
        if double:
            vec = __bvec(basis, vec)

        norm = jnp.linalg.norm(vec)
        should_add = norm > GRAM_SCHMIDT_NORM_THRESHOLD
        # Clamp the divisor so rejected (near-zero) vectors don't produce inf/nan.
        vec_normalized = vec / jnp.maximum(norm, GRAM_SCHMIDT_NORM_THRESHOLD)

        V_new = jnp.where(
            should_add, V_curr.at[:, cnt].set(vec_normalized[:, 0]), V_curr
        )
        cnt_new = jnp.where(should_add, cnt + 1, cnt)
        return V_new, cnt_new

    # Fix the counter's dtype so the loop-carry type is stable whether the
    # caller passes a Python int or a traced value.
    count = jnp.asarray(count, dtype=jnp.int32)
    V_final, count_final = jax.lax.fori_loop(0, nvec, body_fn, (V, count))
    return V_final, count_final


def vw_gram_schmidt(
    x: FloatArray, y: FloatArray, V: FloatArray, W: FloatArray
) -> Tuple[FloatArray, FloatArray]:
    """Project the pair (x, y) out of the paired basis (V, W).

    Each pair is treated as one stacked vector of length 2*dim. Every basis
    column pair (v_i, w_i) stands for two directions: [v_i; w_i] and its
    swapped partner [w_i; v_i]. The stacked vector [x; y] is orthogonalized
    against both families using the ordinary Euclidean inner product. This
    is NOT the indefinite metric ``V1.T @ V2 - W1.T @ W2``.

        m = V.T @ x + W.T @ y      # overlaps with [V; W]
        n = W.T @ x + V.T @ y      # overlaps with [W; V]
        x' = x - V @ m - W @ n
        y' = y - W @ m - V @ n

    The result is exactly orthogonal to both families only when the stacked
    columns [V; W] and [W; V] together form an orthonormal set. That is
    assumed here, not checked.

    Args:
        x: Vector to orthogonalize, shape (dim,) or (dim, 1).
        y: Paired vector to orthogonalize, shape (dim,) or (dim, 1).
        V: First components of the basis pairs, shape (dim, m).
        W: Second components of the basis pairs, shape (dim, m).

    Returns:
        Tuple of (x_orth, y_orth), same shapes as the inputs.
    """
    m = jnp.dot(V.T, x) + jnp.dot(W.T, y)  # coefficients along [V; W]
    n = jnp.dot(W.T, x) + jnp.dot(V.T, y)  # coefficients along [W; V]
    x = x - jnp.dot(V, m) - jnp.dot(W, n)
    y = y - jnp.dot(W, m) - jnp.dot(V, n)
    return x, y


def symmetrically_orthogonalize(
    x: FloatArray, y: FloatArray
) -> Tuple[FloatArray, FloatArray]:
    """Make the pair (x, y) orthogonal to its swapped partner (y, x).

    For stacked vectors, <[x; y], [y; x]> = 2 x.y = (||x+y||^2 - ||x-y||^2) / 2.
    Rescaling the difference so that ||x-y|| matches ||x+y|| therefore makes
    new_x . new_y = 0. The result stays inside span{[x; y], [y; x]}:

        s = (x + y) / 2
        d = (x - y) / 2 * (||x+y|| / ||x-y||)
        new_x = s + d
        new_y = s - d

    When ||x - y|| < GRAM_SCHMIDT_NORM_THRESHOLD the ratio is undefined, and
    (x + y, y) is returned instead. An earlier code comment says this
    fallback matches a reference implementation.

    NOTE(review): (x + y, y) is in general not in span{[x; y], [y; x]}, so in
    that branch any orthogonality of (x, y) against an existing basis is not
    preserved. Confirm this fallback is intended.

    Args:
        x: First vector, shape (dim,) or (dim, 1).
        y: Second vector, shape (dim,) or (dim, 1).

    Returns:
        Tuple of (new_x, new_y), same shapes as the inputs.
    """
    sum_vec = x + y
    diff_vec = x - y
    norm_sum = jnp.linalg.norm(sum_vec)
    norm_diff = jnp.linalg.norm(diff_vec)

    is_diff_tiny = norm_diff < GRAM_SCHMIDT_NORM_THRESHOLD

    # Guard the denominator itself, not only the result. jnp.where evaluates
    # both branches, so an unguarded 0/0 would still produce NaN, which trips
    # jax_debug_nans and poisons gradients. Forward values are unchanged.
    safe_norm_diff = jnp.where(is_diff_tiny, 1.0, norm_diff)
    factor = jnp.where(is_diff_tiny, 0.0, norm_sum / safe_norm_diff)
    s = 0.5 * sum_vec
    d = 0.5 * diff_vec * factor

    new_x = jnp.where(is_diff_tiny, sum_vec, s + d)
    new_y = jnp.where(is_diff_tiny, y, s - d)
    return new_x, new_y


def vw_fill_holder(
    V_holder: FloatArray,
    W_holder: FloatArray,
    m: int,
    X_new: FloatArray,
    Y_new: FloatArray,
    *,
    double: bool = False,
) -> Tuple[FloatArray, FloatArray, int]:
    """Append paired trial vectors (X_new, Y_new) to (V_holder, W_holder).

    Python-loop implementation: the acceptance test converts the norm to a
    Python float, so it cannot be traced by ``jax.jit``. See
    ``vw_fill_holder_jit`` for the traceable variant.

    For each column pair (x, y), in order:
      1. project it out of the current basis pairs ``(V_holder[:, :m],
         W_holder[:, :m])`` with ``vw_gram_schmidt`` (twice if ``double``);
      2. pass it through ``symmetrically_orthogonalize``;
      3. take the Euclidean norm of the stacked vector [x; y];
      4. if that norm exceeds ``GRAM_SCHMIDT_NORM_THRESHOLD``, normalize the
         pair, write it into column ``m`` of both holders and increment ``m``;
         otherwise drop it.

    The number of free columns in the holders is not checked.

    Args:
        V_holder: Holder for X-component trial vectors, shape (dim, max_vectors).
        W_holder: Holder for Y-component trial vectors, shape (dim, max_vectors).
        m: Current number of vector pairs in the holders.
        X_new: Candidate X components, one per column, shape (dim, n_new).
        Y_new: Candidate Y components, one per column, shape (dim, n_new).
        double: Re-orthogonalize a second time for numerical stability.

    Returns:
        Tuple ``(V_holder, W_holder, new_count)``.
    """
    n_sweeps = 2 if double else 1

    for col in range(X_new.shape[1]):
        basis_v = V_holder[:, :m]
        basis_w = W_holder[:, :m]

        x_col = X_new[:, col].reshape(-1, 1)
        y_col = Y_new[:, col].reshape(-1, 1)
        for _ in range(n_sweeps):
            x_col, y_col = vw_gram_schmidt(x_col, y_col, basis_v, basis_w)

        x_col, y_col = symmetrically_orthogonalize(x_col, y_col)

        stacked_sq = jnp.dot(x_col.T, x_col) + jnp.dot(y_col.T, y_col)
        pair_norm = float(jnp.sqrt(stacked_sq.squeeze()))
        if not pair_norm > GRAM_SCHMIDT_NORM_THRESHOLD:
            continue  # pair is numerically inside the current basis: drop it

        V_holder = V_holder.at[:, m].set((x_col / pair_norm)[:, 0])
        W_holder = W_holder.at[:, m].set((y_col / pair_norm)[:, 0])
        m += 1

    return V_holder, W_holder, m


def vw_fill_holder_jit(
    V_holder: FloatArray,
    W_holder: FloatArray,
    m: int,
    X_new: FloatArray,
    Y_new: FloatArray,
    *,
    double: bool = False,
) -> Tuple[FloatArray, FloatArray, int]:
    """JIT-compatible version of ``vw_fill_holder`` built on ``lax.fori_loop``.

    Inside the loop the running count is a traced value, so the active basis
    cannot be taken as ``V_holder[:, :m]`` (JAX requires static slice bounds).
    Instead, columns at index ``>= count`` are zeroed with a mask. This keeps
    the static holder shapes and gives the same projections in
    ``vw_gram_schmidt``, because zero columns contribute nothing.

    Every column pair is processed. A pair whose stacked norm is not above
    ``GRAM_SCHMIDT_NORM_THRESHOLD`` leaves the holders unchanged and does not
    advance the count.

    ``double`` is consumed by Python control flow, so under ``jax.jit`` it must
    be a static argument. Both holders are assumed to have the same number of
    columns. Capacity is not checked: JAX drops an out-of-bounds write while
    the count still increments.

    Args:
        V_holder: Holder for X-component trial vectors, shape (dim, max_vectors).
        W_holder: Holder for Y-component trial vectors, shape (dim, max_vectors).
        m: Current number of vector pairs in holders.
        X_new: New X-component vectors to add, shape (dim, n_new).
        Y_new: New Y-component vectors to add, shape (dim, n_new).
        double: If True, apply Gram-Schmidt twice for numerical stability.

    Returns:
        Tuple of (V_holder, W_holder, new_count) with orthonormalized vectors.
        ``new_count`` is a 0-d int32 JAX array.
    """
    nvec = X_new.shape[1]
    col_idx = jnp.arange(V_holder.shape[1])

    def body_fn(
        j: int, carry: Tuple[FloatArray, FloatArray, int]
    ) -> Tuple[FloatArray, FloatArray, int]:
        V_curr, W_curr, cnt = carry

        # Static-shape stand-ins for V_curr[:, :cnt] and W_curr[:, :cnt].
        active = col_idx < cnt
        V = jnp.where(active, V_curr, 0.0)
        W = jnp.where(active, W_curr, 0.0)

        x_tmp = X_new[:, j].reshape(-1, 1)
        y_tmp = Y_new[:, j].reshape(-1, 1)

        x_tmp, y_tmp = vw_gram_schmidt(x_tmp, y_tmp, V, W)
        if double:
            x_tmp, y_tmp = vw_gram_schmidt(x_tmp, y_tmp, V, W)

        x_tmp, y_tmp = symmetrically_orthogonalize(x_tmp, y_tmp)

        norm_sq = jnp.dot(x_tmp.T, x_tmp) + jnp.dot(y_tmp.T, y_tmp)
        xy_norm = jnp.sqrt(norm_sq.squeeze())
        should_add = xy_norm > GRAM_SCHMIDT_NORM_THRESHOLD

        # Clamp the divisor so rejected (near-zero) pairs don't produce inf/nan.
        scale = jnp.maximum(xy_norm, GRAM_SCHMIDT_NORM_THRESHOLD)
        x_normalized = x_tmp / scale
        y_normalized = y_tmp / scale

        V_new = jnp.where(
            should_add, V_curr.at[:, cnt].set(x_normalized[:, 0]), V_curr
        )
        W_new = jnp.where(
            should_add, W_curr.at[:, cnt].set(y_normalized[:, 0]), W_curr
        )
        cnt_new = jnp.where(should_add, cnt + 1, cnt)

        return V_new, W_new, cnt_new

    # Fix the counter's dtype so the loop-carry type is stable whether the
    # caller passes a Python int or a traced value.
    m = jnp.asarray(m, dtype=jnp.int32)
    V_final, W_final, m_final = jax.lax.fori_loop(
        0, nvec, body_fn, (V_holder, W_holder, m)
    )
    return V_final, W_final, m_final
