import torch
from scipy.linalg import logm as scipy_logm


def has_complex_eigenvalues(A, tol=1e-8):
    """Report whether any eigenvalue of ``A`` has a non-negligible imaginary part.

    Args:
        A: Tensor accepted by ``torch.linalg.eigvals``.
        tol: Threshold on the absolute imaginary part; values at or below it
            are treated as zero.

    Returns:
        Python bool: True if at least one eigenvalue has ``|imag| > tol``.
    """
    imag_magnitudes = torch.linalg.eigvals(A).imag.abs()
    exceeds_tol = imag_magnitudes > tol
    return exceeds_tol.any().item()


def logm_newton_schulz_pade(A, num_sqrt=6, sqrt_iters=20, series_terms=6):
    """Approximate the logarithm of a batch of matrices by inverse scaling and squaring.

    The input is square-rooted ``num_sqrt`` times with a coupled
    Newton-Schulz iteration, the logarithm of the resulting matrix is
    evaluated with the truncated series
    ``log(I + X) = X - X^2/2 + X^3/3 - ...``, and the series value is
    multiplied by ``2**num_sqrt``.

    Despite the function name, the series step is a plain truncated Taylor
    (Mercator) expansion, not a rational Pade approximant.

    NOTE(review): no convergence check is performed. The iteration is
    presumably intended for well-conditioned matrices such as symmetric
    positive definite ones -- confirm against callers.

    Args:
        A: Tensor of shape (B, N, N).
        num_sqrt: Number of successive matrix square roots taken before the
            series is evaluated.
        sqrt_iters: Newton-Schulz iterations used for each square root.
        series_terms: Number of terms kept in the log series. Defaults to 6,
            the value that used to be hard-coded.

    Returns:
        Tensor of shape (B, N, N) approximating ``logm(A)`` per batch entry.

    Raises:
        ValueError: If ``series_terms`` is smaller than 1.
    """
    if series_terms < 1:
        raise ValueError("series_terms must be at least 1")

    batch_size, N, _ = A.shape
    I = torch.eye(N, dtype=A.dtype, device=A.device).expand(batch_size, -1, -1)

    def sqrtm_newton_schulz(M, num_iters):
        # Normalise by the Frobenius norm before iterating, then undo the
        # scaling on the way out: sqrt(M) = sqrt(||M||) * sqrt(M / ||M||).
        norm_m = torch.linalg.norm(M, dim=(1, 2), keepdim=True)
        Y = M / norm_m
        Z = I.clone()
        for _ in range(num_iters):
            T = 0.5 * (3.0 * I - Z @ Y)
            Y = Y @ T
            Z = T @ Z
        return Y * torch.sqrt(norm_m)

    def log_series(I_plus_X, terms):
        # Alternating series for log(I + X); only the running power of X is
        # kept instead of a list of every power.
        X = I_plus_X - I
        power = X
        approx = torch.zeros_like(X) + X
        for k in range(2, terms + 1):
            power = power @ X
            term = power / k
            approx = approx - term if k % 2 == 0 else approx + term
        return approx

    A_scaled = A
    for _ in range(num_sqrt):
        A_scaled = sqrtm_newton_schulz(A_scaled, sqrt_iters)

    # log(A) = 2**num_sqrt * log(A ** (1 / 2**num_sqrt))
    return log_series(A_scaled, series_terms) * (2**num_sqrt)


def batch_logm_power_series(A, K=20):
    """
    Evaluate a truncated power series for the matrix logarithm of a batch:
        log(I + X) = sum_{k=1}^K (-1)^{k+1} X^k / k,   with X = A - I.

    The k=1 term is always included, so any K <= 1 returns X itself.

    Args:
        A: Tensor of shape (B, N, N); the series is only useful for
           matrices close to the identity
        K: int, highest power of X kept in the series

    Returns:
        Tensor of shape (B, N, N) with the truncated series value
    """
    batch, dim, _ = A.shape
    identity = torch.eye(dim, device=A.device, dtype=A.dtype).expand(batch, -1, -1)
    X = A - identity

    # Running power X^k and the partial sum, seeded with the k=1 term.
    power = X.clone()
    series = power.clone()

    for order in range(2, K + 1):
        power = torch.bmm(power, X)
        term = power / order
        # Even orders enter with a minus sign, odd orders with a plus sign.
        series = series - term if order % 2 == 0 else series + term

    return series


def batch_logm_approx(A, K=10, n_squarings=5):
    """
    Approximate logm for a batch via repeated square roots plus a power series.

    Procedure:
    - Take the matrix square root n_squarings times (Newton-Schulz, 10
      iterations each), giving an approximation of A^(1/2^n_squarings)
    - Evaluate the log power series on that root with batch_logm_power_series
    - Multiply the series value by 2^n_squarings

    Args:
        A: Tensor (B,N,N); positive definite-ish matrices are the intended input
        K: number of terms passed to the power series
        n_squarings: int, how many successive square roots are taken

    Returns:
        Tensor (B,N,N) approximate logm(A)
    """
    batch, dim, _ = A.shape
    identity = torch.eye(dim, device=A.device, dtype=A.dtype).expand(batch, -1, -1)

    def _newton_schulz_sqrt(M, num_iters=10):
        """Coupled Newton-Schulz square root of a Frobenius-normalised batch."""
        frob = torch.linalg.norm(M, dim=(1, 2)).view(batch, 1, 1)
        Y, Z = M / frob, identity.clone()
        for _ in range(num_iters):
            T = 0.5 * (3.0 * identity - Z @ Y)
            Y, Z = Y @ T, T @ Z
        # Undo the normalisation: sqrt(M) = sqrt(||M||) * sqrt(M / ||M||).
        return Y * torch.sqrt(frob)

    # Drive the batch towards the identity with successive square roots.
    root = A
    for _ in range(n_squarings):
        root = _newton_schulz_sqrt(root)

    # Series value on the root, rescaled to undo the square roots.
    return batch_logm_power_series(root, K=K) * (2**n_squarings)


def batch_logm_with_fallback(A_batch, approximate_logm_fn=batch_logm_approx):
    """
    Compute a batch matrix logarithm, using scipy.linalg.logm where the
    approximate routine is not trusted.

    A matrix is sent to SciPy when any of these holds:
    - has_complex_eigenvalues reports complex eigenvalues for it,
    - its condition number (torch.linalg.cond) exceeds 1e8,
    - approximate_logm_fn returned a non-finite (NaN/inf) entry for it.

    Only the real part of the SciPy result is kept; an imaginary part, if
    SciPy produces one, is discarded.

    Args:
        A_batch: torch.Tensor of shape (B, N, N). The pre-checks and the SciPy
            fallback run on a detached CPU copy; results are written on
            A_batch's own device.
        approximate_logm_fn: function(A_batch) -> batch logm (approximate),
            called once on the sub-batch of matrices that passed the checks.

    Returns:
        torch.Tensor of shape (B, N, N) with matrix logarithms, same dtype and
        device as A_batch.
    """
    device = A_batch.device
    dtype = A_batch.dtype
    batch_size = A_batch.shape[0]

    logm_results = torch.empty_like(A_batch)

    # Detached CPU copy: used for the per-matrix checks and for SciPy, and
    # .numpy() would raise on a tensor that is tracked by autograd.
    A_cpu = A_batch.detach().cpu()
    fallback_indices = [
        i
        for i in range(batch_size)
        if has_complex_eigenvalues(A_cpu[i]) or torch.linalg.cond(A_cpu[i]) > 1e8
    ]

    # Set membership keeps the split linear in the batch size.
    flagged = set(fallback_indices)
    approx_indices = [i for i in range(batch_size) if i not in flagged]

    # Approximate logm for the "safe" matrices in a single batched call.
    if approx_indices:
        approx_logm = approximate_logm_fn(A_batch[approx_indices])
        logm_results[approx_indices] = approx_logm

        # The approximate routine can diverge without raising; route any
        # matrix whose result contains NaN/inf to the SciPy fallback.
        finite_rows = (
            torch.isfinite(approx_logm).reshape(len(approx_indices), -1).all(dim=1)
        )
        fallback_indices.extend(
            i for i, ok in zip(approx_indices, finite_rows.tolist()) if not ok
        )

    # SciPy fallback, one matrix at a time on CPU.
    # NOTE(review): newer SciPy releases may deprecate the `disp` argument --
    # confirm against the SciPy version this project pins.
    for i in fallback_indices:
        logm_cpu = scipy_logm(A_cpu[i].numpy(), disp=False)[0]
        logm_results[i] = torch.tensor(logm_cpu.real, dtype=dtype, device=device)

    return logm_results
