"""Tests for solver metrics (condition number, effective rank, residual norm)."""

from __future__ import annotations

import math

import pytest
import torch
from torch import Tensor

from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.solver import (
    ConditionNumber,
    EffectiveRank,
    EigenvalueSpread,
    ResidualNorm,
    analyze_condition,
)


class TestConditionNumber:
    """Tests for ConditionNumber metric."""

    def test_compute_well_conditioned(self, well_conditioned_M_I: Tensor):
        """Well-conditioned matrix should have small condition number."""
        metric = ConditionNumber()
        result = metric.compute(M_I=well_conditioned_M_I)

        assert isinstance(result, float)
        assert result >= 1.0  # Condition number is always >= 1
        assert result < 1e3  # Should be small for well-conditioned

    def test_compute_ill_conditioned(self, ill_conditioned_M_I: Tensor):
        """Ill-conditioned matrix should have large condition number."""
        metric = ConditionNumber()
        result = metric.compute(M_I=ill_conditioned_M_I)

        assert result >= 1e6  # Should be large

    def test_identity_condition_is_one(self):
        """Identity matrix should have condition number = 1."""
        metric = ConditionNumber()
        identity = torch.eye(64)
        result = metric.compute(M_I=identity)

        assert result == pytest.approx(1.0, abs=1e-5)

    def test_condition_number_scale_invariant(self, well_conditioned_M_I: Tensor):
        """Condition number should be scale-invariant."""
        metric = ConditionNumber()

        cond_original = metric.compute(M_I=well_conditioned_M_I)
        cond_scaled = metric.compute(M_I=well_conditioned_M_I * 100)

        # Scaling every eigenvalue by the same factor cancels in the ratio,
        # so the two values should agree to within 1% of the original.
        assert cond_scaled == pytest.approx(cond_original, rel=0.01)

    def test_missing_M_I_raises_error(self):
        """Missing M_I should raise InvalidMetricInputError."""
        metric = ConditionNumber()

        with pytest.raises(InvalidMetricInputError):
            metric.compute(M_I=None)

    def test_eps_parameter_effect(self, rank_deficient_M_I: Tensor):
        """eps parameter should affect condition number for near-singular matrices."""
        metric_no_eps = ConditionNumber(eps=0.0)
        metric_with_eps = ConditionNumber(eps=1e-6)

        # Without eps the rank-deficient matrix is (numerically) singular, so
        # its condition number is expected to be huge or inf.
        cond_no_eps = metric_no_eps.compute(M_I=rank_deficient_M_I)
        cond_with_eps = metric_with_eps.compute(M_I=rank_deficient_M_I)

        # With eps the condition number must be finite AND strictly smaller;
        # comparing against cond_no_eps is what proves eps has an effect.
        assert math.isfinite(cond_with_eps)
        assert cond_with_eps < cond_no_eps
        assert math.isfinite(cond_with_eps)


class TestEffectiveRank:
    """Tests for EffectiveRank metric."""

    def test_compute_full_rank(self, well_conditioned_M_I: Tensor):
        """Full-rank matrix should have effective rank near K."""
        metric = EffectiveRank()
        result = metric.compute(M_I=well_conditioned_M_I)

        K = well_conditioned_M_I.shape[0]
        assert result > K * 0.8  # Should be close to full rank

    def test_compute_rank_deficient(self, rank_deficient_M_I: Tensor):
        """Rank-deficient matrix should have lower effective rank."""
        metric = EffectiveRank()
        result = metric.compute(M_I=rank_deficient_M_I)

        K = rank_deficient_M_I.shape[0]
        # Rank should be ~32 for our fixture
        assert result < K * 0.6

    def test_rank_is_integer(self, well_conditioned_M_I: Tensor):
        """Effective rank should be a positive whole number (returned as a float)."""
        metric = EffectiveRank()
        result = metric.compute(M_I=well_conditioned_M_I)

        assert isinstance(result, float)  # Returned as a float...
        assert result.is_integer()  # ...but it is a count of eigenvalues
        assert result > 0  # A non-degenerate matrix has at least one

    def test_identity_full_rank(self):
        """Identity matrix should have full effective rank."""
        metric = EffectiveRank()
        K = 64
        identity = torch.eye(K)
        result = metric.compute(M_I=identity)

        assert result == K

    def test_threshold_parameter(self, ill_conditioned_M_I: Tensor):
        """threshold parameter should affect effective rank count."""
        metric_strict = EffectiveRank(threshold=1e-3)
        metric_loose = EffectiveRank(threshold=1e-10)

        rank_strict = metric_strict.compute(M_I=ill_conditioned_M_I)
        rank_loose = metric_loose.compute(M_I=ill_conditioned_M_I)

        # Stricter threshold -> fewer significant eigenvalues
        assert rank_strict <= rank_loose


class TestEigenvalueSpread:
    """Exercise the dictionary produced by the EigenvalueSpread metric."""

    def test_compute_returns_dict(self, well_conditioned_M_I: Tensor):
        """The result is a dict exposing eigenvalue_min, eigenvalue_max and log_spread."""
        spread = EigenvalueSpread().compute(M_I=well_conditioned_M_I)

        assert isinstance(spread, dict)
        for key in ("eigenvalue_min", "eigenvalue_max", "log_spread"):
            assert key in spread

    def test_identity_spread(self):
        """Every eigenvalue of the identity is 1, so min == max == 1 and log spread is 0."""
        spread = EigenvalueSpread().compute(M_I=torch.eye(64))

        unit = pytest.approx(1.0, rel=1e-5)
        assert spread["eigenvalue_min"] == unit
        assert spread["eigenvalue_max"] == unit
        assert spread["log_spread"] == pytest.approx(0.0, abs=1e-5)

    def test_eigenvalue_max_geq_min(self, well_conditioned_M_I: Tensor):
        """The reported largest eigenvalue never falls below the smallest."""
        spread = EigenvalueSpread().compute(M_I=well_conditioned_M_I)

        lowest, highest = spread["eigenvalue_min"], spread["eigenvalue_max"]
        assert highest >= lowest


class TestResidualNorm:
    """Tests for ResidualNorm metric."""

    @staticmethod
    def _random_vector(M_I: Tensor, seed: int) -> Tensor:
        """Return a seeded standard-normal vector matching M_I's row count, dtype and device.

        Deriving size/dtype/device from ``M_I`` keeps the vector compatible with
        ``torch.linalg.solve(M_I, ...)`` whatever the fixture's K or precision,
        and the explicit generator makes the test deterministic.
        """
        generator = torch.Generator(device=M_I.device).manual_seed(seed)
        return torch.randn(
            M_I.shape[0],
            generator=generator,
            dtype=M_I.dtype,
            device=M_I.device,
        )

    def test_compute_basic(
        self,
        well_conditioned_M_I: Tensor,
        sample_alpha: Tensor,
        sample_b: Tensor,
    ):
        """Basic computation should return a float."""
        metric = ResidualNorm()
        result = metric.compute(
            M_I=well_conditioned_M_I,
            alpha=sample_alpha,
            b=sample_b,
        )

        assert isinstance(result, float)
        assert result >= 0.0

    def test_exact_solution_zero_residual(self, well_conditioned_M_I: Tensor):
        """Exact solution should have near-zero residual."""
        # Solve M_I @ alpha = b exactly
        b = self._random_vector(well_conditioned_M_I, seed=0)
        alpha = torch.linalg.solve(well_conditioned_M_I, b)

        metric = ResidualNorm()
        result = metric.compute(
            M_I=well_conditioned_M_I,
            alpha=alpha,
            b=b,
        )

        assert result < 1e-5  # Should be very small

    def test_wrong_solution_large_residual(
        self,
        well_conditioned_M_I: Tensor,
        sample_b: Tensor,
    ):
        """Random alpha should give larger residual than exact solution."""
        # Exact solution
        alpha_exact = torch.linalg.solve(well_conditioned_M_I, sample_b)

        # Random (seeded) solution, sized to the fixture rather than a fixed K
        alpha_random = self._random_vector(well_conditioned_M_I, seed=1)

        metric = ResidualNorm()
        resid_exact = metric.compute(
            M_I=well_conditioned_M_I,
            alpha=alpha_exact,
            b=sample_b,
        )
        resid_random = metric.compute(
            M_I=well_conditioned_M_I,
            alpha=alpha_random,
            b=sample_b,
        )

        assert resid_random > resid_exact

    def test_relative_vs_absolute(
        self,
        well_conditioned_M_I: Tensor,
        sample_alpha: Tensor,
        sample_b: Tensor,
    ):
        """relative=True should normalize by ||b||."""
        metric_rel = ResidualNorm(relative=True)
        metric_abs = ResidualNorm(relative=False)

        resid_rel = metric_rel.compute(
            M_I=well_conditioned_M_I,
            alpha=sample_alpha,
            b=sample_b,
        )
        resid_abs = metric_abs.compute(
            M_I=well_conditioned_M_I,
            alpha=sample_alpha,
            b=sample_b,
        )

        # The relative residual is the absolute one divided by ||b||, so
        # scaling it back up must recover the absolute residual. (Which is
        # larger depends on whether ||b|| > 1, so we test equality.) A
        # relative tolerance keeps this robust for large float32 magnitudes.
        b_norm = torch.norm(sample_b).item()
        assert resid_abs == pytest.approx(resid_rel * b_norm, rel=1e-5, abs=1e-5)


class TestAnalyzeCondition:
    """Exercise the analyze_condition convenience wrapper."""

    def test_returns_full_analysis(self, well_conditioned_M_I: Tensor):
        """The analysis dict carries every condition, rank and spread summary field."""
        analysis = analyze_condition(well_conditioned_M_I)

        expected_keys = (
            "condition_number",
            "effective_rank",
            "eigenvalue_min",
            "eigenvalue_max",
            "log_spread",
            "rank_percentage",
            "is_well_conditioned",
        )
        for key in expected_keys:
            assert key in analysis

    def test_well_conditioned_flag(self, well_conditioned_M_I: Tensor):
        """A well-conditioned input sets is_well_conditioned to exactly True."""
        flag = analyze_condition(well_conditioned_M_I)["is_well_conditioned"]
        assert flag is True

    def test_ill_conditioned_flag(self, ill_conditioned_M_I: Tensor):
        """An ill-conditioned input sets is_well_conditioned to exactly False."""
        flag = analyze_condition(ill_conditioned_M_I)["is_well_conditioned"]
        assert flag is False

    def test_rank_percentage(self, rank_deficient_M_I: Tensor):
        """A rank-deficient input reports a rank percentage below 100."""
        percentage = analyze_condition(rank_deficient_M_I)["rank_percentage"]
        assert percentage < 100.0


class TestSolverMetricsCUDA:
    """Solver-metric checks that run against CUDA-resident tensors."""

    @pytest.mark.gpu
    def test_condition_number_cuda(self, cuda_tensors):
        """ConditionNumber accepts a CUDA M_I and yields a float no smaller than 1.

        Skipped when the ``cuda_tensors`` fixture provides None.
        """
        if cuda_tensors is None:
            pytest.skip("CUDA not available")

        cond = ConditionNumber().compute(M_I=cuda_tensors["M_I"])

        assert isinstance(cond, float)
        assert cond >= 1.0
