"""Unit tests for the optimal_weights module.

Tests linear system solvers and optimal weight computation.
"""

from __future__ import annotations

import pytest
import torch
from torch import Tensor

from expected_gradcam.core.optimal_weights import (
    compute_optimal_weights,
    compute_optimal_weights_full,
    solve_linear_system,
    solve_linear_system_robust,
    verify_optimality,
)
from expected_gradcam.core.second_moment import compute_second_moment_matrix


class TestSolveLinearSystemRobust:
    """Exercise solve_linear_system_robust across methods and matrix conditioning.

    All fixtures reseed torch with 42, so every test sees the same 64x64
    matrices and the same length-64 right-hand side.
    """

    @pytest.fixture
    def well_conditioned_matrix(self) -> Tensor:
        """Return a seeded 64x64 Gram matrix with 0.1 added to the diagonal."""
        torch.manual_seed(42)
        base = torch.randn(64, 64)
        gram = base @ base.T
        # The diagonal shift keeps the smallest eigenvalue away from zero.
        return gram + torch.eye(64) * 0.1

    @pytest.fixture
    def ill_conditioned_matrix(self) -> Tensor:
        """Return Q diag(S) Q^T with S log-spaced between 1e-10 and 1."""
        torch.manual_seed(42)
        basis, _ = torch.linalg.qr(torch.randn(64, 64))
        spectrum = torch.logspace(-10, 0, 64)
        scaled_basis = basis @ torch.diag(spectrum)
        return scaled_basis @ basis.T

    @pytest.fixture
    def singular_matrix(self) -> Tensor:
        """Return a 64x64 Gram matrix built from a 64x32 factor (rank <= 32)."""
        torch.manual_seed(42)
        factor = torch.randn(64, 32)
        return factor @ factor.T

    @pytest.fixture
    def sample_rhs(self) -> Tensor:
        """Return a seeded length-64 right-hand side vector."""
        torch.manual_seed(42)
        rhs = torch.randn(64)
        return rhs

    def test_well_conditioned_pinv(
        self, well_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """pinv yields a length-64 solution and a condition number below 1e10."""
        result = solve_linear_system_robust(
            well_conditioned_matrix, sample_rhs, method="pinv"
        )
        solution, diagnostics = result

        assert solution.shape == torch.Size([64])
        assert diagnostics.condition_number < 1e10

    def test_well_conditioned_regularized(
        self, well_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """The regularized method yields a length-64 solution."""
        solution, _ = solve_linear_system_robust(
            well_conditioned_matrix, sample_rhs, method="regularized"
        )

        assert solution.shape == torch.Size([64])

    def test_solution_accuracy(
        self, well_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """The pinv solution reproduces the right-hand side to 1e-4 relative error."""
        solution, _ = solve_linear_system_robust(
            well_conditioned_matrix, sample_rhs, method="pinv"
        )

        # Relative residual ||A x - b|| / ||b||.
        residual_vec = well_conditioned_matrix @ solution - sample_rhs
        error = torch.norm(residual_vec) / torch.norm(sample_rhs)

        assert error < 1e-4

    def test_ill_conditioned_with_regularization(
        self, ill_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """pinv with regularization_eps=1e-6 gives a NaN-free, Inf-free solution."""
        solution, _ = solve_linear_system_robust(
            ill_conditioned_matrix,
            sample_rhs,
            method="pinv",
            regularization_eps=1e-6,
        )

        assert solution.shape == torch.Size([64])
        has_nan = torch.isnan(solution).any()
        has_inf = torch.isinf(solution).any()
        assert not has_nan
        assert not has_inf

    def test_singular_matrix_handling(
        self, singular_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """pinv on a rank-deficient matrix is NaN-free and reports reduced rank."""
        solution, diagnostics = solve_linear_system_robust(
            singular_matrix, sample_rhs, method="pinv"
        )

        assert solution.shape == torch.Size([64])
        has_nan = torch.isnan(solution).any()
        assert not has_nan
        # The matrix has rank at most 32, so the reported rank must be below 64.
        assert diagnostics.effective_rank < 64

    def test_adaptive_regularization(
        self, ill_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """The adaptive_reg method reports a regularization_eps in its diagnostics."""
        solution, diagnostics = solve_linear_system_robust(
            ill_conditioned_matrix,
            sample_rhs,
            method="adaptive_reg",
        )

        assert solution.shape == torch.Size([64])
        # The diagnostics are expected to carry the epsilon that was applied.
        assert diagnostics.regularization_eps is not None

    def test_subspace_method(
        self, singular_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """The subspace method gives a NaN-free length-64 solution."""
        solution, _ = solve_linear_system_robust(
            singular_matrix, sample_rhs, method="subspace"
        )

        assert solution.shape == torch.Size([64])
        has_nan = torch.isnan(solution).any()
        assert not has_nan

    def test_diagnostics_returned(
        self, well_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """Diagnostics expose condition number, rank, method name and size K."""
        _, diagnostics = solve_linear_system_robust(
            well_conditioned_matrix, sample_rhs, method="pinv"
        )

        for field in ("condition_number", "effective_rank"):
            assert getattr(diagnostics, field) is not None
        assert diagnostics.method == "pinv"
        assert diagnostics.K == 64

    def test_diagnostics_properties(
        self, well_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """A well-conditioned system is flagged full rank and well conditioned."""
        _, diagnostics = solve_linear_system_robust(
            well_conditioned_matrix, sample_rhs, method="pinv"
        )

        assert diagnostics.is_full_rank
        assert diagnostics.is_well_conditioned

    def test_batch_rhs(self, well_conditioned_matrix: Tensor) -> None:
        """Solving column by column over five right-hand sides stacks to (64, 5)."""
        torch.manual_seed(42)
        n_rhs = 5
        rhs_batch = torch.randn(64, n_rhs)

        # The solver is called once per column; only the solution is kept.
        per_column = [
            solve_linear_system_robust(
                well_conditioned_matrix, rhs_batch[:, col], method="pinv"
            )[0]
            for col in range(n_rhs)
        ]

        stacked = torch.stack(per_column, dim=1)
        assert stacked.shape == torch.Size([64, n_rhs])

    def test_invalid_method_raises(
        self, well_conditioned_matrix: Tensor, sample_rhs: Tensor
    ) -> None:
        """An unrecognised method name raises ValueError."""
        with pytest.raises(ValueError, match="Unknown solver method"):
            solve_linear_system_robust(
                well_conditioned_matrix, sample_rhs, method="invalid"
            )


class TestSolveLinearSystem:
    """Smoke test for the plain solve_linear_system entry point."""

    def test_basic_solve(self) -> None:
        """A seeded 64x64 shifted Gram system yields a NaN-free length-64 solution."""
        torch.manual_seed(42)
        base = torch.randn(64, 64)
        gram = base @ base.T
        system_matrix = gram + torch.eye(64) * 0.1
        rhs = torch.randn(64)

        solution = solve_linear_system(system_matrix, rhs)

        assert solution.shape == torch.Size([64])
        has_nan = torch.isnan(solution).any()
        assert not has_nan


class TestComputeOptimalWeights:
    """Check compute_optimal_weights on seeded random samples."""

    @pytest.fixture
    def sample_data(self) -> tuple[Tensor, Tensor, Tensor]:
        """Return (M_I, I_samples, phi_samples) for 100 samples of width 64."""
        torch.manual_seed(42)
        n_samples, n_features = 100, 64

        # Uniform perturbation samples and the second moment matrix built from them.
        perturbations = torch.rand(n_samples, n_features)
        second_moment = compute_second_moment_matrix(perturbations)

        # Gaussian attribution samples, drawn after the perturbations.
        attributions = torch.randn(n_samples, n_features)

        return second_moment, perturbations, attributions

    def test_optimal_weights_computation(
        self, sample_data: tuple[Tensor, Tensor, Tensor]
    ) -> None:
        """The weights come back as a NaN-free length-64 vector."""
        # sample_data is already ordered (M_I, I_samples, phi_samples).
        weights = compute_optimal_weights(*sample_data)

        assert weights.shape == torch.Size([64])
        has_nan = torch.isnan(weights).any()
        assert not has_nan

    def test_weights_solve_system(
        self, sample_data: tuple[Tensor, Tensor, Tensor]
    ) -> None:
        """verify_optimality reports a residual below 0.1 for the computed weights."""
        weights = compute_optimal_weights(*sample_data)

        # Only the residual is checked here; the pass flag is ignored.
        _, residual = verify_optimality(weights, *sample_data)

        assert residual < 0.1

    def test_with_regularization(
        self, sample_data: tuple[Tensor, Tensor, Tensor]
    ) -> None:
        """Passing regularization_eps=1e-4 still gives a NaN-free length-64 vector."""
        weights = compute_optimal_weights(*sample_data, regularization_eps=1e-4)

        assert weights.shape == torch.Size([64])
        has_nan = torch.isnan(weights).any()
        assert not has_nan


class TestComputeOptimalWeightsFull:
    """Check the three outputs of compute_optimal_weights_full."""

    def test_returns_all_values(self) -> None:
        """The weights, matrix and vector come back with shapes (K,), (K, K), (K,)."""
        torch.manual_seed(42)
        n_samples, n_features = 100, 64

        perturbations = torch.rand(n_samples, n_features)
        attributions = torch.randn(n_samples, n_features)

        alpha_opt, second_moment, rhs = compute_optimal_weights_full(
            perturbations, attributions
        )

        assert alpha_opt.shape == torch.Size([n_features])
        assert second_moment.shape == torch.Size([n_features, n_features])
        assert rhs.shape == torch.Size([n_features])


class TestVerifyOptimality:
    """Test verify_optimality function.

    Each test checks both values returned by verify_optimality: the residual
    and the boolean pass flag. Both calls pass tolerance=0.01 explicitly so the
    flag is judged against the same bound in each test.
    """

    def test_optimal_weights_pass(self):
        """Test that optimal weights pass verification."""
        torch.manual_seed(42)
        M = 100
        K = 64

        I_samples = torch.rand(M, K)
        phi_samples = torch.randn(M, K)

        # The right-hand side vector is not needed by verify_optimality.
        alpha_opt, M_I, _ = compute_optimal_weights_full(I_samples, phi_samples)

        passed, residual = verify_optimality(
            alpha_opt, M_I, I_samples, phi_samples, tolerance=0.01
        )

        # Should pass with tight tolerance
        assert residual < 0.01
        # NOTE(review): assumes `passed` compares the residual against
        # `tolerance` -- confirm against verify_optimality.
        assert passed, (
            "verify_optimality rejected optimal weights "
            f"(residual={float(residual):.3e})"
        )

    def test_random_weights_fail(self):
        """Test that random weights fail verification."""
        torch.manual_seed(42)
        M = 100
        K = 64

        I_samples = torch.rand(M, K)
        phi_samples = torch.randn(M, K)
        M_I = compute_second_moment_matrix(I_samples)

        # Random weights (not optimal)
        random_weights = torch.randn(K)

        passed, residual = verify_optimality(
            random_weights, M_I, I_samples, phi_samples, tolerance=0.01
        )

        # Random weights should have large residual
        assert residual > 0.1
        # NOTE(review): assumes `passed` compares the residual against
        # `tolerance` -- confirm against verify_optimality.
        assert not passed, (
            "verify_optimality accepted random weights "
            f"(residual={float(residual):.3e})"
        )
