"""Optimal weights computation α* = M_I^{-1} * b for Expected GradCAM.

The optimal weights minimize explanation infidelity and are computed by
solving a linear system involving the second moment matrix.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import torch

from expected_gradcam.exceptions import ComputationError, SingularMatrixError
from expected_gradcam.core.second_moment import (
    compute_second_moment_matrix,
    compute_cross_moment,
)
from expected_gradcam.types.results import SolverDiagnostics


if TYPE_CHECKING:
    from torch import Tensor


def solve_linear_system(
    M_I: "Tensor",
    b: "Tensor",
    regularization_eps: float = 1e-6,
) -> "Tensor":
    """Solve M_I @ α = b for α with numerical stability.

    The system is regularized before solving:
        (M_I + ε*I) @ α = b

    A direct solve is attempted first (fast for well-conditioned systems).
    If it raises, a least-squares solve of the same regularized system is
    used instead.

    Args:
        M_I: Second moment matrix [K, K].
        b: Cross-moment vector [K].
        regularization_eps: Regularization constant ε added to the diagonal.

    Returns:
        Solution α [K]. This is always 1-D, including when K == 1.

    Raises:
        SingularMatrixError: If both the direct and least-squares solves fail.
        ComputationError: If the solution contains NaN or Inf values.
    """
    K = M_I.shape[0]

    # Build the identity in M_I's dtype/device so the regularized matrix keeps
    # the caller's precision instead of relying on implicit type promotion.
    identity = torch.eye(K, dtype=M_I.dtype, device=M_I.device)
    M_I_reg = M_I + regularization_eps * identity

    try:
        alpha = torch.linalg.solve(M_I_reg, b)
    except RuntimeError:
        # torch.linalg.LinAlgError (raised for singular inputs) subclasses
        # RuntimeError, so this also catches singular-matrix failures.
        try:
            result = torch.linalg.lstsq(M_I_reg, b.unsqueeze(1))
        except RuntimeError as e:
            raise SingularMatrixError() from e
        # Drop only the column dimension added above. A bare squeeze() would
        # also collapse a K == 1 solution into a 0-d tensor.
        alpha = result.solution.squeeze(1)

    # Reject NaN/Inf solutions.
    if torch.isnan(alpha).any() or torch.isinf(alpha).any():
        raise ComputationError(
            "optimal_weights",
            details="Solution contains NaN or Inf values",
        )

    return alpha


def solve_linear_system_robust(
    M_I: "Tensor",
    b: "Tensor",
    method: Literal["pinv", "adaptive_reg", "subspace", "regularized"] = "pinv",
    rcond: float = 1e-6,
    regularization_eps: float = 1e-6,
) -> tuple["Tensor", SolverDiagnostics]:
    """Robustly solve M_I @ α = b for rank-deficient M_I.

    An eigendecomposition of M_I is always computed with
    ``torch.linalg.eigh``, so M_I is treated as symmetric. It is used for the
    diagnostics and by the "adaptive_reg" and "subspace" methods. Eigenvalues
    strictly greater than ``rcond * max_eigenvalue`` count as "significant".

    Methods:
        - "pinv": Moore-Penrose pseudo-inverse (``torch.linalg.pinv`` with
          the given ``rcond``). Minimum-norm solution. Recommended for
          rank-deficient systems from data-aware perturbations.

        - "adaptive_reg": Tikhonov regularization with eigenvalue-based ε.
          Sets ε = 0.01 * (smallest significant eigenvalue), falling back
          to ``regularization_eps`` when no eigenvalue is significant.

        - "subspace": Projects b onto the significant eigenspace and solves
          there. Components outside that eigenspace are zero.

        - "regularized": Fixed Tikhonov with the provided regularization_eps.

    Args:
        M_I: Second moment matrix [K, K].
        b: Cross-moment vector [K].
        method: Solver method.
        rcond: Relative cutoff for significant eigenvalues / singular values.
        regularization_eps: Fixed ε for "regularized", and the fallback ε for
            "adaptive_reg".

    Returns:
        Tuple of (alpha [K], diagnostics). ``diagnostics.residual_norm`` is
        ||M_I α - b|| / (||b|| + 1e-10) measured against the unregularized M_I.

    Raises:
        ValueError: If an unknown method is specified. This is checked before
            any linear algebra is performed.
        ComputationError: If the solution contains NaN or Inf.
        Errors raised by ``torch.linalg`` (``eigh``, ``solve``) are not caught.
    """
    # Validate up front so a typo does not cost an O(K^3) eigendecomposition
    # (or surface as an unrelated linear-algebra error).
    if method not in ("pinv", "adaptive_reg", "subspace", "regularized"):
        raise ValueError(
            f"Unknown solver method: {method}. "
            "Use 'pinv', 'adaptive_reg', 'subspace', or 'regularized'."
        )

    K = M_I.shape[0]

    # Eigendecomposition for diagnostics and the spectral methods.
    eigenvalues, eigenvectors = torch.linalg.eigh(M_I)

    # Effective rank and condition number over the significant spectrum.
    max_ev = eigenvalues.max()
    threshold = rcond * max_ev
    significant_mask = eigenvalues > threshold
    effective_rank = int(significant_mask.sum().item())

    pos_eigenvalues = eigenvalues[significant_mask]
    if len(pos_eigenvalues) > 0:
        condition_number = float((max_ev / pos_eigenvalues.min()).item())
        min_pos_ev = float(pos_eigenvalues.min().item())
    else:
        condition_number = float("inf")
        min_pos_ev = 0.0

    reg_eps_used: float | None = None

    if method == "pinv":
        # Moore-Penrose pseudo-inverse.
        M_I_pinv = torch.linalg.pinv(M_I, rcond=rcond)
        alpha = torch.mv(M_I_pinv, b)

    elif method == "adaptive_reg":
        # Tikhonov with ε tied to the smallest significant eigenvalue.
        if len(pos_eigenvalues) > 0:
            eps = 0.01 * pos_eigenvalues.min()
        else:
            eps = regularization_eps
        reg_eps_used = float(eps.item()) if torch.is_tensor(eps) else eps
        identity = torch.eye(K, dtype=M_I.dtype, device=M_I.device)
        alpha = torch.linalg.solve(M_I + eps * identity, b)

    elif method == "subspace":
        # Solve in the significant eigenspace only.
        V_sig = eigenvectors[:, significant_mask]  # [K, r]
        lambda_sig = eigenvalues[significant_mask]  # [r]
        b_proj = V_sig.T @ b  # [r]
        alpha_proj = b_proj / lambda_sig  # [r]
        alpha = V_sig @ alpha_proj  # [K]

    else:  # "regularized" (method was validated above)
        # Fixed Tikhonov.
        reg_eps_used = regularization_eps
        identity = torch.eye(K, dtype=M_I.dtype, device=M_I.device)
        alpha = torch.linalg.solve(M_I + regularization_eps * identity, b)

    # Check for NaN/Inf.
    if torch.isnan(alpha).any() or torch.isinf(alpha).any():
        raise ComputationError(
            f"optimal_weights ({method})",
            details=f"Solution contains NaN/Inf. Rank: {effective_rank}/{K}, "
            f"condition: {condition_number:.2e}",
        )

    # Relative residual of the original (unregularized) system.
    residual = torch.mv(M_I, alpha) - b
    b_norm = torch.norm(b)
    residual_norm = float((torch.norm(residual) / (b_norm + 1e-10)).item())

    diagnostics = SolverDiagnostics(
        method=method,
        condition_number=condition_number,
        effective_rank=effective_rank,
        K=K,
        regularization_eps=reg_eps_used,
        residual_norm=residual_norm,
        eigenvalue_min=min_pos_ev,
        eigenvalue_max=float(max_ev.item()),
    )

    return alpha, diagnostics


def compute_optimal_weights(
    M_I: "Tensor",
    I_samples: "Tensor",
    phi_samples: "Tensor",
    regularization_eps: float = 1e-6,
) -> "Tensor":
    """Compute optimal GradCAM weights α* = M_I^{-1} * b.

    Mathematical specification:
        α* = M_I^{-1} * E[I * <I, φ>]

    These weights minimize explanation infidelity:
        INFD(α) = E[(I^T α - (g(z_0) - g(z_0 - I)))²]

    The cross-moment b is estimated from the samples with
    ``compute_cross_moment``. The regularized system is then solved with
    ``solve_linear_system``. The caller supplies M_I precomputed.

    Args:
        M_I: Second moment matrix [K, K].
        I_samples: Perturbation samples [M, K].
        phi_samples: Attribution vectors [M, K].
        regularization_eps: Regularization for stability.

    Returns:
        Optimal weights α* [K].
    """
    return solve_linear_system(
        M_I,
        compute_cross_moment(I_samples, phi_samples),
        regularization_eps,
    )


def compute_optimal_weights_full(
    I_samples: "Tensor",
    phi_samples: "Tensor",
    regularization_eps: float = 1e-6,
) -> tuple["Tensor", "Tensor", "Tensor"]:
    """Compute optimal weights together with the moments used to derive them.

    Builds the second moment matrix M_I and the cross-moment b from the
    samples, then solves the regularized system for α*.

    Args:
        I_samples: Perturbation samples [M, K].
        phi_samples: Attribution vectors [M, K].
        regularization_eps: Regularization constant.

    Returns:
        Tuple of (alpha_opt [K], M_I [K, K], b [K]).
    """
    second_moment = compute_second_moment_matrix(I_samples)
    cross_moment = compute_cross_moment(I_samples, phi_samples)
    weights = solve_linear_system(second_moment, cross_moment, regularization_eps)
    return weights, second_moment, cross_moment


def verify_optimality(
    alpha: "Tensor",
    M_I: "Tensor",
    I_samples: "Tensor",
    phi_samples: "Tensor",
    tolerance: float = 1e-4,
) -> tuple[bool, float]:
    """Check the first-order optimality condition M_I @ α ≈ b.

    Here b = E[I * <I, φ>] is recomputed from the samples. The check uses the
    relative residual ||M_I @ α - b|| / (||b|| + 1e-10). The small constant
    guards against division by zero when b vanishes.

    Args:
        alpha: Weights to verify [K].
        M_I: Second moment matrix [K, K].
        I_samples: Perturbation samples [M, K].
        phi_samples: Attribution vectors [M, K].
        tolerance: The check passes when the relative residual is strictly
            below this value.

    Returns:
        Tuple of (passed, residual_norm).
    """
    target = compute_cross_moment(I_samples, phi_samples)

    gap_norm = torch.norm(torch.mv(M_I, alpha) - target)
    scale = torch.norm(target) + 1e-10
    relative_residual = float((gap_norm / scale).item())

    return relative_residual < tolerance, relative_residual


def compute_infidelity(
    alpha: "Tensor",
    I_samples: "Tensor",
    g_z0: "Tensor",
    g_perturbed: "Tensor",
) -> float:
    """Compute explanation infidelity for given weights.

    Mathematical specification:
        INFD(α) = E[(I^T α - (g(z_0) - g(z_0 - I)))²]

    The expectation is estimated by the mean over the M samples.

    Args:
        alpha: Importance weights [K].
        I_samples: Perturbation samples [M, K].
        g_z0: Model output at reference g(z_0) (scalar).
        g_perturbed: Model outputs at perturbed points. It is normally [M].
            Any tensor with exactly M elements (e.g. [M, 1]) is flattened to
            [M]. Other shapes are used as-is under broadcasting.

    Returns:
        Infidelity value (lower is better).
    """
    # Predicted change: I^T @ α for each sample.
    predicted_change = torch.mv(I_samples, alpha)  # [M]

    # Model outputs often arrive as [M, 1]. Subtracting that from an [M]
    # prediction would silently broadcast to [M, M] and average the wrong
    # pairs, so align it with the predictions first.
    if (
        torch.is_tensor(g_perturbed)
        and g_perturbed.numel() == predicted_change.numel()
    ):
        g_perturbed = g_perturbed.reshape(predicted_change.shape)

    # Actual change: g(z_0) - g(z_0 - I).
    actual_change = g_z0 - g_perturbed  # [M]

    # Mean squared error over samples.
    squared_error = (predicted_change - actual_change) ** 2
    return float(squared_error.mean().item())
