from __future__ import annotations

import numpy as np


def center_cols(A: np.ndarray) -> np.ndarray:
    """Subtract the axis-0 mean from every entry of ``A``.

    The mean is taken along axis 0 with the reduced axis kept, so it
    broadcasts back against ``A`` and the result has the same shape as the
    input. For a 2-D array this removes each column's mean.
    """
    column_means = A.mean(axis=0, keepdims=True)
    centered = A - column_means
    return centered


def make_psd(A: np.ndarray, min_eig: float = 1e-10) -> np.ndarray:
    """Symmetrize a square matrix and clip its eigenvalues from below.

    The input is first replaced by its symmetric part ``0.5 * (A + A.T)``.
    Its eigenvalues are then floored at ``min_eig`` and the matrix is rebuilt
    from the clipped spectrum. With ``min_eig == 0`` this is plain
    eigenvalue clipping onto the PSD cone; a positive ``min_eig`` additionally
    lifts small or negative eigenvalues up to that floor.

    Args:
        A: Square 2-D array. It does not need to be exactly symmetric.
        min_eig: Lower bound applied to every eigenvalue.

    Returns:
        A symmetric array of the same shape whose eigenvalues are (up to
        rounding) at least ``min_eig``.

    Raises:
        ValueError: If ``A`` is not a square 2-D array.
    """
    # Without this check a (1, n) input would silently broadcast through
    # ``A + A.T`` into an n-by-n matrix and produce a meaningless result.
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("make_psd expects a square 2D array.")

    sym = 0.5 * (A + A.T)
    w, V = np.linalg.eigh(sym)
    w = np.maximum(w, min_eig)
    # Scaling the eigenvector columns is equivalent to V @ diag(w) without
    # materializing the dense diagonal matrix.
    A_psd = (V * w) @ V.T
    # Re-symmetrize to avoid tiny numerical asymmetries breaking downstream Cholesky.
    return 0.5 * (A_psd + A_psd.T)


def safe_inv_spd(A: np.ndarray, ridge: float = 1e-10) -> np.ndarray:
    """Invert a symmetric positive-definite matrix with ridge regularization.

    The input is symmetrized, ``ridge`` is added to its diagonal, and the
    inverse is assembled from the eigendecomposition with every eigenvalue
    floored at ``ridge`` before it is inverted.

    Args:
        A: Square 2-D array, expected to be (close to) symmetric positive
            definite.
        ridge: Value added to the diagonal and used as the eigenvalue floor.
            With ``ridge == 0`` a singular input leads to division by zero.

    Returns:
        A symmetric array of the same shape approximating the inverse of the
        regularized matrix.

    Raises:
        ValueError: If ``A`` is not a square 2-D array.
    """
    # Without this check a (1, n) or 1-D input would silently broadcast
    # against the identity below and yield a meaningless "inverse".
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("safe_inv_spd expects a square 2D array.")

    regularized = 0.5 * (A + A.T) + ridge * np.eye(A.shape[0])
    w, V = np.linalg.eigh(regularized)
    w = np.maximum(w, ridge)
    # Dividing the eigenvector columns by the eigenvalues is equivalent to
    # V @ diag(1 / w) without materializing the dense diagonal matrix.
    Ainv = (V / w) @ V.T
    # Remove rounding-level asymmetry so the result is exactly symmetric.
    return 0.5 * (Ainv + Ainv.T)


def safe_cholesky_spd(
    A: np.ndarray,
    *,
    ridge: float = 1e-10,
    max_tries: int = 5,
    jitter_mult: float = 10.0,
) -> np.ndarray:
    """Return a lower-triangular Cholesky factor, adding diagonal jitter as needed.

    The symmetric part of ``A`` is factorized with a small multiple of the
    identity added. Each failed attempt multiplies the jitter by
    ``jitter_mult``. If all ``max_tries`` attempts fail, the eigenvalues of
    the symmetrized matrix are floored at the current jitter value and the
    rebuilt matrix is factorized instead.

    Args:
        A: Square 2-D array.
        ridge: Smallest jitter ever used; the starting jitter is the larger of
            this and ``1e-15`` times the mean absolute diagonal (at least 1).
        max_tries: Number of jittered Cholesky attempts before the fallback.
        jitter_mult: Growth factor applied to the jitter after each failure.

    Returns:
        The factor returned by ``np.linalg.cholesky`` for the jittered (or
        eigenvalue-clipped) matrix.

    Raises:
        ValueError: If ``A`` is not a square 2-D array.
        np.linalg.LinAlgError: If even the eigenvalue-clipped matrix cannot be
            factorized; the error from the last jittered attempt is re-raised
            when one exists.
    """
    is_square_matrix = A.ndim == 2 and A.shape[0] == A.shape[1]
    if not is_square_matrix:
        raise ValueError("safe_cholesky_spd expects a square 2D array.")

    sym = (A + A.T) / 2.0
    dim = sym.shape[0]

    # Tie the starting jitter to the typical diagonal magnitude so that very
    # large matrices entries do not make the ridge negligible.
    if dim > 0:
        typical_diag = float(np.mean(np.abs(np.diag(sym))))
    else:
        typical_diag = 0.0
    magnitude_floor = 1e-15 * max(1.0, typical_diag)
    jitter = max(float(ridge), magnitude_floor)
    identity = np.eye(dim)

    last_failure: Exception | None = None
    for _attempt in range(max_tries):
        try:
            return np.linalg.cholesky(sym + jitter * identity)
        except np.linalg.LinAlgError as exc:
            last_failure = exc
        # Only reached after a failed factorization: escalate and retry.
        jitter *= float(jitter_mult)

    # Last resort: floor the spectrum at the escalated jitter and rebuild.
    eigvals, eigvecs = np.linalg.eigh(sym)
    floored = np.maximum(eigvals, jitter)
    rebuilt = eigvecs @ np.diag(floored) @ eigvecs.T
    rebuilt = (rebuilt + rebuilt.T) / 2.0
    try:
        return np.linalg.cholesky(rebuilt)
    except np.linalg.LinAlgError:
        # Prefer the error from the jittered attempts; it describes the
        # caller's matrix rather than the rebuilt one.
        if last_failure is None:
            raise np.linalg.LinAlgError("Matrix is not positive definite")
        raise last_failure


def cross_cov_IX(I: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Sample cross-covariance between instrument and covariate columns.

    Both inputs are centered along axis 0 (rows are treated as observations)
    and the product ``I_centered.T @ X_centered`` is divided by ``n - 1``,
    where ``n`` is the number of rows of ``I``. The divisor is floored at 1
    so a single-row input does not divide by zero.
    """
    instruments_centered = center_cols(I)
    covariates_centered = center_cols(X)
    dof = max(I.shape[0] - 1, 1)
    cross_products = instruments_centered.T @ covariates_centered
    return cross_products / dof


def cross_cov_IY(I: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Sample cross-covariance between instrument columns and outcomes.

    Both inputs are centered along axis 0 (rows are treated as observations)
    and the product ``Ic.T @ Yc`` is divided by ``n - 1``, where ``n`` is the
    number of rows of ``I``. The divisor is floored at 1 so a single-row
    input does not divide by zero.

    ``Y`` may be 1-D (a single outcome) or 2-D (one outcome per column).
    """
    Ic = center_cols(I)
    # Center each outcome column by its own mean, as cross_cov_IX does for X.
    # Subtracting one grand mean over all entries of a multi-column Y leaves
    # large per-column offsets that only cancel through Ic's zero column
    # sums, which loses precision when the columns sit at different levels.
    # For 1-D Y this is the same as subtracting the overall mean.
    Yc = center_cols(Y)
    n = I.shape[0]
    return (Ic.T @ Yc) / max(n - 1, 1)
