"""
Partial correlation computation for Gaussian distributions.

This module provides functions for computing partial correlations
from covariance matrices and sample data, along with related
transformations (Fisher's z-transform).
"""

from __future__ import annotations

from typing import Set, Tuple, Optional, Union
import numpy as np
from scipy import linalg
import warnings


def partial_correlation(
    Sigma: np.ndarray,
    i: int,
    j: int,
    S: Union[Set[int], list, None] = None
) -> float:
    """
    Compute the partial correlation ρ_{ij|S} from a covariance matrix.

    ρ_{ij|S} = Corr(X_i, X_j | X_S)

    For Gaussian distributions, ρ_{ij|S} = 0 iff X_i ⊥ X_j | X_S.

    Algorithm:
    1. Extract the blocks of Σ for the variables {i, j} and S
    2. Compute the conditional covariance:
       Σ_{ij|S} = Σ_{ij,ij} - Σ_{ij,S} Σ_{S,S}^{-1} Σ_{S,ij}
    3. Return Σ_{ij|S}[0,1] / sqrt(Σ_{ij|S}[0,0] * Σ_{ij|S}[1,1])

    Both the marginal (empty S) and the conditioned result are clipped to
    [-1 + 1e-10, 1 - 1e-10] and returned as a plain Python float, so
    rounding error can never push the value outside the valid range.

    Args:
        Sigma: d×d covariance matrix
        i: First variable index
        j: Second variable index
        S: Conditioning set (empty set or None for marginal correlation).
            Occurrences of i and j inside S are ignored.

    Returns:
        Partial correlation ρ_{ij|S}. Returns 0.0 (with a warning) when a
        variance involved is non-positive or a conditional variance is
        numerically zero.

    Raises:
        ValueError: If Sigma is not a square 2-D matrix, if any index is
            out of bounds, or if i == j
    """
    Sigma = np.asarray(Sigma)
    if Sigma.ndim != 2 or Sigma.shape[0] != Sigma.shape[1]:
        raise ValueError(
            f"Sigma must be a square 2-D matrix, got shape {Sigma.shape}"
        )
    d = Sigma.shape[0]

    # Validate inputs
    if i < 0 or i >= d or j < 0 or j >= d:
        raise ValueError(f"Indices i={i}, j={j} out of bounds for d={d}")
    if i == j:
        raise ValueError("Cannot compute partial correlation of variable with itself")

    # Normalise the conditioning set; i and j never condition on themselves
    S = (set() if S is None else set(S)) - {i, j}

    # Check S is valid
    for s in S:
        if s < 0 or s >= d:
            raise ValueError(f"Conditioning set index {s} out of bounds")

    # Symmetric clipping bound shared by both branches
    rho_limit = 1.0 - 1e-10

    # If empty conditioning set, return marginal correlation
    if not S:
        var_i = Sigma[i, i]
        var_j = Sigma[j, j]

        if var_i <= 0 or var_j <= 0:
            warnings.warn("Non-positive variance encountered")
            return 0.0

        rho = Sigma[i, j] / np.sqrt(var_i * var_j)
        # Same clipping and return type as the conditioned branch below
        return float(np.clip(rho, -rho_limit, rho_limit))

    # Non-empty conditioning set
    S_list = sorted(S)
    vars_ij = [i, j]

    # Extract submatrices
    Sigma_ij_ij = Sigma[np.ix_(vars_ij, vars_ij)]   # Σ_{ij,ij} (2×2)
    Sigma_ij_S = Sigma[np.ix_(vars_ij, S_list)]     # Σ_{ij,S}  (2×|S|)
    Sigma_SS = Sigma[np.ix_(S_list, S_list)]        # Σ_{S,S}   (|S|×|S|)
    Sigma_S_ij = Sigma[np.ix_(S_list, vars_ij)]     # Σ_{S,ij}  (|S|×2)

    # Compute Σ_{S,S}^{-1} Σ_{S,ij} without forming an explicit inverse
    try:
        # Try Cholesky for positive definite matrices
        L = linalg.cholesky(Sigma_SS, lower=True)
        # Solve L @ L.T @ X = Sigma_S_ij by two triangular solves
        Z = linalg.solve_triangular(L, Sigma_S_ij, lower=True)
        Sigma_SS_inv_Sigma_S_ij = linalg.solve_triangular(L.T, Z, lower=False)
    except linalg.LinAlgError:
        # Σ_{S,S} is not positive definite: fall back to pseudo-inverse
        try:
            Sigma_SS_inv = linalg.pinv(Sigma_SS)
            Sigma_SS_inv_Sigma_S_ij = Sigma_SS_inv @ Sigma_S_ij
        except linalg.LinAlgError:
            warnings.warn("Singular conditioning matrix, returning 0")
            return 0.0

    # Σ_{ij|S} = Σ_{ij,ij} - Σ_{ij,S} Σ_{S,S}^{-1} Σ_{S,ij}
    Sigma_cond = Sigma_ij_ij - Sigma_ij_S @ Sigma_SS_inv_Sigma_S_ij

    # Extract conditional covariance and variances
    cond_cov = Sigma_cond[0, 1]
    cond_var_i = Sigma_cond[0, 0]
    cond_var_j = Sigma_cond[1, 1]

    # A (numerically) zero conditional variance makes the ratio meaningless
    if cond_var_i <= 1e-15 or cond_var_j <= 1e-15:
        warnings.warn("Near-zero conditional variance, returning 0")
        return 0.0

    rho = cond_cov / np.sqrt(cond_var_i * cond_var_j)

    # Clip to valid range (numerical precision)
    return float(np.clip(rho, -rho_limit, rho_limit))


def partial_correlation_from_precision(
    Theta: np.ndarray,
    i: int,
    j: int
) -> float:
    """
    Read a partial correlation off a precision matrix.

    For Gaussian distributions the partial correlation of X_i and X_j
    given every remaining variable is

        ρ_{ij|rest} = -Θ_ij / sqrt(Θ_ii * Θ_jj)

    Args:
        Theta: d×d precision matrix (inverse covariance)
        i: First variable index
        j: Second variable index

    Returns:
        Partial correlation ρ_{ij|all other variables}, clipped to
        [-1 + 1e-10, 1 - 1e-10]. Returns 0.0 (with a warning) when either
        diagonal entry involved is non-positive.

    Raises:
        ValueError: If an index is out of bounds or i == j
    """
    d = Theta.shape[0]

    indices_ok = (0 <= i < d) and (0 <= j < d)
    if not indices_ok:
        raise ValueError(f"Indices i={i}, j={j} out of bounds for d={d}")
    if i == j:
        raise ValueError("Cannot compute partial correlation of variable with itself")

    diag_i, diag_j = Theta[i, i], Theta[j, j]

    # The formula needs strictly positive diagonal entries
    if diag_i <= 0 or diag_j <= 0:
        warnings.warn("Non-positive diagonal in precision matrix")
        return 0.0

    scale = np.sqrt(diag_i * diag_j)
    raw = -Theta[i, j] / scale

    # Keep the result strictly inside (-1, 1)
    bounded = np.clip(raw, -1.0 + 1e-10, 1.0 - 1e-10)
    return float(bounded)


def sample_partial_correlation(
    X: np.ndarray,
    i: int,
    j: int,
    S: Union[Set[int], list, None] = None
) -> float:
    """
    Estimate the partial correlation ρ_{ij|S} from observations.

    The unbiased sample covariance (divisor n - 1) of the columns of X is
    formed and handed to ``partial_correlation``.

    Args:
        X: n×d data matrix (n samples, d variables)
        i: First variable index
        j: Second variable index
        S: Conditioning set; None means empty, and i / j are dropped
            from it if present

    Returns:
        Sample partial correlation estimate

    Raises:
        ValueError: If fewer than 4 samples are supplied or i / j is out
            of bounds
    """
    n, d = X.shape

    if n < 4:
        raise ValueError(f"Need at least 4 samples, got {n}")
    if not (0 <= i < d and 0 <= j < d):
        raise ValueError(f"Indices i={i}, j={j} out of bounds for d={d}")

    # Normalise the conditioning set
    S = set() if S is None else set(S) - {i, j}

    # Only a warning: the estimate is still computed with too few samples
    if n <= len(S) + 3:
        warnings.warn(f"Insufficient samples (n={n}) for conditioning set size |S|={len(S)}")

    # Sample covariance from the column-centered data
    deviations = X - X.mean(axis=0)
    scatter = deviations.T @ deviations
    Sigma_hat = scatter / (n - 1)

    return partial_correlation(Sigma_hat, i, j, S)


def fisher_z_transform(rho: float) -> float:
    """
    Map a correlation to Fisher's z scale.

    z = arctanh(ρ) = (1/2) * log((1 + ρ) / (1 - ρ))

    For sample partial correlations the transformed value is
    approximately normal with variance 1/(n - |S| - 3).

    Args:
        rho: Correlation value in (-1, 1); values at or beyond the
            boundary are pulled inside before transforming

    Returns:
        Fisher's z value
    """
    # Stay 1e-10 away from ±1 so that arctanh remains finite
    lower = -1.0 + 1e-10
    upper = 1.0 - 1e-10
    bounded = np.clip(rho, lower, upper)

    z = np.arctanh(bounded)
    return float(z)


def inverse_fisher_z(z: float) -> float:
    """
    Map a Fisher z value back to the correlation scale.

    ρ = tanh(z) = (exp(2z) - 1) / (exp(2z) + 1)

    Args:
        z: Fisher's z value

    Returns:
        The corresponding correlation as a Python float
    """
    rho = np.tanh(z)
    return float(rho)


def partial_correlation_test(
    X: np.ndarray,
    i: int,
    j: int,
    S: Union[Set[int], list, None],
    alpha: float = 0.05
) -> Tuple[bool, float, float]:
    """
    Test for conditional independence using Fisher's z-test.

    H0: ρ_{ij|S} = 0 (conditional independence)
    H1: ρ_{ij|S} ≠ 0 (conditional dependence)

    The sample partial correlation is Fisher z-transformed and scaled by
    sqrt(n - |S| - 3); the two-sided p-value is taken from the standard
    normal survival function, which stays accurate for very large
    statistics where ``1 - cdf`` would round to exactly 0.

    Args:
        X: n×d data matrix
        i: First variable index
        j: Second variable index
        S: Conditioning set (None means empty; i and j are dropped from it)
        alpha: Significance level

    Returns:
        Tuple of (is_independent, p_value, sample_correlation), all plain
        Python scalars. H0 is retained (is_independent is True) when
        p_value >= alpha. If n - |S| - 3 <= 0, a warning is issued and
        (True, 1.0, sample_correlation) is returned.
    """
    from scipy import stats

    # Unpacking also enforces that X is two-dimensional
    n, _ = X.shape

    S = set() if S is None else set(S) - {i, j}

    # Compute sample partial correlation
    rho_hat = sample_partial_correlation(X, i, j, S)

    # Degrees of freedom of the Fisher z approximation
    df = n - len(S) - 3

    if df <= 0:
        warnings.warn(f"Non-positive degrees of freedom (df={df})")
        return True, 1.0, float(rho_hat)

    # Fisher's z-transform
    z = fisher_z_transform(rho_hat)

    # Standard error of z under H0
    se = 1.0 / np.sqrt(df)

    # Test statistic
    z_stat = z / se

    # Two-sided p-value; sf avoids the cancellation in 1 - cdf
    p_value = 2.0 * stats.norm.sf(abs(z_stat))

    # Decision
    is_independent = bool(p_value >= alpha)

    return is_independent, float(p_value), float(rho_hat)


def all_partial_correlations(
    Sigma: np.ndarray,
    max_cond_size: Optional[int] = None
) -> dict:
    """
    Enumerate partial correlations for every pair and conditioning set.

    For each unordered pair (i, j) and each subset S of the remaining
    variables with |S| <= max_cond_size, ρ_{ij|S} is computed once and
    stored under both orderings of the pair.

    Args:
        Sigma: d×d covariance matrix
        max_cond_size: Maximum |S| to consider (default: d-2)

    Returns:
        Dict mapping (i, j, frozenset(S)) to ρ_{ij|S}
    """
    from itertools import combinations

    d = Sigma.shape[0]
    limit = d - 2 if max_cond_size is None else max_cond_size

    results = {}

    # combinations(range(d), 2) yields exactly the pairs with i < j
    for i, j in combinations(range(d), 2):
        candidates = [k for k in range(d) if k not in (i, j)]

        for size in range(limit + 1):
            for subset in combinations(candidates, size):
                key_set = frozenset(subset)
                value = partial_correlation(Sigma, i, j, set(subset))
                # The partial correlation is symmetric in (i, j)
                results[(i, j, key_set)] = value
                results[(j, i, key_set)] = value

    return results


def conditional_covariance_matrix(
    Sigma: np.ndarray,
    target_vars: list,
    conditioning_vars: list
) -> np.ndarray:
    """
    Compute the conditional covariance matrix.

    Σ_{T|S} = Σ_{TT} - Σ_{TS} Σ_{SS}^{-1} Σ_{ST}

    where T is the target variables and S is the conditioning set.

    Σ_{SS}^{-1} Σ_{ST} is obtained from a Cholesky factorisation of Σ_{SS}
    (two triangular solves) rather than an explicit inverse. If Σ_{SS} is
    not positive definite, the pseudo-inverse is used instead.

    Args:
        Sigma: Full covariance matrix
        target_vars: Target variable indices (any sequence of ints,
            including a NumPy array)
        conditioning_vars: Conditioning variable indices (any sequence of
            ints, including a NumPy array); may be empty

    Returns:
        Conditional covariance matrix for target variables. With an empty
        conditioning set this is simply Σ_{TT}; otherwise the result is
        symmetrised before being returned.
    """
    # Materialise as lists so emptiness can be tested for any sequence
    # type (``not ndarray`` is ambiguous for arrays with several elements)
    T = list(target_vars)
    S = list(conditioning_vars)

    Sigma_TT = Sigma[np.ix_(T, T)]
    if not S:
        return Sigma_TT

    Sigma_TS = Sigma[np.ix_(T, S)]
    Sigma_SS = Sigma[np.ix_(S, S)]
    Sigma_ST = Sigma[np.ix_(S, T)]

    try:
        # Solve (L @ L.T) X = Σ_{ST} with two triangular solves
        L = linalg.cholesky(Sigma_SS, lower=True)
        Z = linalg.solve_triangular(L, Sigma_ST, lower=True)
        Sigma_SS_inv_Sigma_ST = linalg.solve_triangular(L.T, Z, lower=False)
    except linalg.LinAlgError:
        # Σ_{SS} is singular or indefinite: fall back to pseudo-inverse
        Sigma_SS_inv_Sigma_ST = linalg.pinv(Sigma_SS) @ Sigma_ST

    Sigma_cond = Sigma_TT - Sigma_TS @ Sigma_SS_inv_Sigma_ST

    # Ensure symmetry (removes rounding asymmetry from the products)
    Sigma_cond = (Sigma_cond + Sigma_cond.T) / 2

    return Sigma_cond


def partial_correlation_from_sem_formula(
    beta_ji: float,
    var_j_given_S: float,
    var_i_given_S: float
) -> float:
    r"""
    Compute partial correlation using Lemma 7.1 formula.

    For edge (j, i) with coefficient β_ji and S = Pa(i) \ {j}:
    ρ_{ij|S} = β_ji * sqrt(Var(X_j | X_S) / Var(X_i | X_S))

    This is useful for validation and understanding the relationship
    between edge coefficients and partial correlations.

    Args:
        beta_ji: Edge coefficient from j to i
        var_j_given_S: Conditional variance of X_j given X_S
        var_i_given_S: Conditional variance of X_i given X_S

    Returns:
        Partial correlation ρ_{ij|S}, clipped to [-1 + 1e-10, 1 - 1e-10].
        Returns 0.0 when var_j_given_S is exactly zero.

    Raises:
        ValueError: If var_i_given_S is not positive or var_j_given_S is
            negative
    """
    # NOTE: the docstring is a raw string because it contains a literal
    # backslash (set difference), which is an invalid escape otherwise.
    if var_i_given_S <= 0:
        raise ValueError("var_i_given_S must be positive")
    if var_j_given_S < 0:
        raise ValueError("var_j_given_S must be non-negative")

    # A zero-variance X_j contributes nothing, whatever the coefficient
    if var_j_given_S == 0:
        return 0.0

    rho = beta_ji * np.sqrt(var_j_given_S / var_i_given_S)

    # Clip to valid range
    rho = np.clip(rho, -1.0 + 1e-10, 1.0 - 1e-10)

    return float(rho)


def z_test_threshold(
    alpha: float,
    n: int,
    cond_set_size: int = 0
) -> float:
    """
    Smallest |ρ| that Fisher's z-test rejects at level α.

    The two-sided normal critical value is divided by
    sqrt(n - |S| - 3) to get a threshold on the z scale, which is then
    mapped back to the correlation scale.

    Args:
        alpha: Significance level
        n: Sample size
        cond_set_size: Size of conditioning set |S|

    Returns:
        Threshold value for |ρ|; 1.0 when n - |S| - 3 is not positive
    """
    from scipy import stats

    dof = n - cond_set_size - 3
    if dof <= 0:
        # No usable degrees of freedom: return the maximal threshold
        return 1.0

    critical = stats.norm.ppf(1 - alpha / 2)
    z_cutoff = critical / np.sqrt(dof)

    return abs(inverse_fisher_z(z_cutoff))
