"""Pytest fixtures for metrics tests.

Provides:
- Tensor fixtures for optimal weights (alpha)
- Feature perturbation samples (I_samples)
- Model outputs (g_z0, g_perturbed)
- Well-conditioned and ill-conditioned matrices
- Heatmap fixtures
"""

from __future__ import annotations

import pytest
import torch
from torch import Tensor


# =============================================================================
# Optimal Weights Fixtures
# =============================================================================


@pytest.fixture
def sample_alpha() -> Tensor:
    """Return a reproducible vector of K=64 standard-normal weights.

    The global torch RNG is reseeded with 42 before sampling, so repeated
    calls draw the same values (and leave the global RNG in the same state).
    """
    num_weights = 64
    torch.manual_seed(42)
    weights = torch.randn(num_weights)
    return weights


@pytest.fixture
def sample_alpha_positive() -> Tensor:
    """Return K=64 strictly positive weights.

    Values are uniform samples from [0, 1) shifted up by 0.1, drawn after
    reseeding the global torch RNG with 42.
    """
    num_weights = 64
    offset = 0.1
    torch.manual_seed(42)
    uniform = torch.rand(num_weights)
    return uniform + offset


@pytest.fixture
def sample_alpha_sparse() -> Tensor:
    """Return K=64 weights in which at most 10 entries are non-zero.

    After reseeding the global torch RNG with 42, ten distinct positions are
    picked from a random permutation and filled with standard-normal draws;
    every other entry stays zero.
    """
    num_weights = 64
    num_active = 10
    torch.manual_seed(42)
    weights = torch.zeros(num_weights)
    # Draw the positions first, then the values: the RNG call order determines
    # the resulting tensor.
    active_positions = torch.randperm(num_weights)[:num_active]
    active_values = torch.randn(num_active)
    weights[active_positions] = active_values
    return weights


# =============================================================================
# Infidelity Computation Fixtures
# =============================================================================


@pytest.fixture
def sample_I_samples() -> Tensor:
    """Return M=100 feature-space perturbation samples of width K=64.

    The [100, 64] matrix holds standard-normal draws taken after reseeding
    the global torch RNG with 42.
    """
    sample_shape = (100, 64)
    torch.manual_seed(42)
    return torch.randn(sample_shape)


@pytest.fixture
def sample_g_z0() -> float:
    """Return the scalar reference model output used by the metrics tests."""
    reference_output = 5.0
    return reference_output


@pytest.fixture
def sample_g_perturbed(
    sample_I_samples: Tensor,
    sample_alpha: Tensor,
    sample_g_z0: float,
) -> Tensor:
    """Create noisy perturbed model outputs, one per row of ``sample_I_samples``.

    The outputs are built so that ``sample_g_z0 - g_perturbed`` equals the
    linear prediction ``I_samples @ alpha`` up to small Gaussian noise, which
    yields a non-zero but small infidelity.

    Args:
        sample_I_samples: Perturbation samples; one output is produced per row.
        sample_alpha: Weights used to form the linear prediction.
        sample_g_z0: Reference model output. Taken from the fixture rather
            than hard-coded so both stay consistent if it is overridden.

    Returns:
        1-D tensor ``sample_g_z0 - I_samples @ alpha + noise``.
    """
    torch.manual_seed(42)
    num_samples = sample_I_samples.shape[0]
    # Noise scaled to 0.1 keeps the outputs close to the "perfect" predictions.
    # NOTE(review): seed 42 is also used by the alpha / I_samples fixtures in
    # this file, so this noise may be correlated with them - confirm that is
    # acceptable for the tests that consume it.
    noise = torch.randn(num_samples) * 0.1
    predicted = torch.mv(sample_I_samples, sample_alpha)
    return sample_g_z0 - predicted + noise


@pytest.fixture
def perfect_g_perturbed(
    sample_I_samples: Tensor,
    sample_alpha: Tensor,
    sample_g_z0: float,
) -> Tensor:
    """Create noise-free perturbed outputs (zero infidelity).

    The outputs satisfy ``sample_g_z0 - g_perturbed == I_samples @ alpha``
    exactly, i.e. the linear model explains the output change perfectly.

    Args:
        sample_I_samples: Perturbation samples; one output is produced per row.
        sample_alpha: Weights used to form the linear prediction.
        sample_g_z0: Reference model output. Taken from the fixture rather
            than hard-coded so the zero-infidelity property holds even if the
            reference value is overridden.

    Returns:
        1-D tensor ``sample_g_z0 - I_samples @ alpha``.
    """
    predicted = torch.mv(sample_I_samples, sample_alpha)
    return sample_g_z0 - predicted


# =============================================================================
# Matrix Fixtures for Solver Metrics
# =============================================================================


@pytest.fixture
def well_conditioned_M_I() -> Tensor:
    """Return a well-conditioned [64, 64] matrix of the form A Aᵀ + 10·I.

    A is a standard-normal matrix drawn after reseeding the global torch RNG
    with 42. Because A Aᵀ is positive semi-definite, the added ridge keeps
    every eigenvalue at or above 10 (in exact arithmetic), so the condition
    number is at most (λ_max(A Aᵀ) + 10) / 10.

    NOTE(review): the previous docstring quoted a condition number of "~10";
    the code only guarantees the bound above - confirm the actual value if a
    test depends on it.
    """
    dim = 64
    torch.manual_seed(42)
    base = torch.randn(dim, dim)
    gram = base @ base.T
    # The strong diagonal lifts the smallest eigenvalues away from zero.
    ridge = torch.eye(dim) * 10.0
    return gram + ridge


@pytest.fixture
def ill_conditioned_M_I() -> Tensor:
    """Return an ill-conditioned [64, 64] matrix U · diag(S) · Uᵀ.

    U is the orthogonal factor from the QR decomposition of a standard-normal
    matrix (global torch RNG reseeded with 42), and S holds 64 log-spaced
    eigenvalues from 1e-8 to 1, giving a nominal condition number of ~1e8.

    NOTE(review): the tensors use torch's default dtype; if that is float32,
    eigenvalues near 1e-8 fall below single-precision resolution relative to
    the largest eigenvalue, so the numerically observed condition number may
    differ from 1e8 - confirm against the tests that use this fixture.
    """
    dim = 64
    torch.manual_seed(42)
    gaussian = torch.randn(dim, dim)
    basis, _ = torch.linalg.qr(gaussian)
    # Prescribed spectrum spanning eight orders of magnitude.
    spectrum = torch.logspace(-8, 0, dim)
    scaled_basis = basis @ torch.diag(spectrum)
    return scaled_basis @ basis.T


@pytest.fixture
def rank_deficient_M_I() -> Tensor:
    """Return a rank-deficient [64, 64] matrix A Aᵀ built from a [64, 32] factor.

    The factor is standard-normal, drawn after reseeding the global torch RNG
    with 42; the product therefore has rank at most 32 out of 64.
    """
    dim = 64
    rank = 32
    torch.manual_seed(42)
    factor = torch.randn(dim, rank)
    return torch.matmul(factor, factor.T)


@pytest.fixture
def sample_b() -> Tensor:
    """Return a reproducible right-hand side vector of length K=64.

    Entries are standard-normal draws taken after reseeding the global torch
    RNG with 42.
    """
    rhs_length = 64
    torch.manual_seed(42)
    rhs = torch.randn(rhs_length)
    return rhs


# =============================================================================
# Heatmap Fixtures
# =============================================================================


@pytest.fixture
def sample_heatmap_2d() -> Tensor:
    """Return a reproducible 2D heatmap of shape [H=8, W=8].

    Values are uniform samples from [0, 1), drawn after reseeding the global
    torch RNG with 42.
    """
    height, width = 8, 8
    torch.manual_seed(42)
    heatmap = torch.rand(height, width)
    return heatmap


@pytest.fixture
def sample_heatmap_3d() -> Tensor:
    """Return a reproducible batched heatmap of shape [B=1, H=8, W=8].

    Values are uniform samples from [0, 1), drawn after reseeding the global
    torch RNG with 42.
    """
    batched_shape = (1, 8, 8)
    torch.manual_seed(42)
    return torch.rand(batched_shape)


@pytest.fixture
def uniform_heatmap() -> Tensor:
    """Return an 8x8 heatmap whose 64 cells all equal 1/64.

    The cells sum to 1; this is the maximum-entropy case for the heatmap
    metrics.
    """
    side = 8
    ones = torch.ones(side, side)
    return ones / (side * side)


@pytest.fixture
def focused_heatmap() -> Tensor:
    """Return an 8x8 heatmap concentrated in the central 2x2 patch.

    Rows and columns 3-4 hold 0.25 each and every other cell is zero, giving
    the low-entropy, high-Gini case for the heatmap metrics.
    """
    grid = torch.zeros(8, 8)
    center = slice(3, 5)
    grid[center, center] = 0.25
    return grid


@pytest.fixture
def single_point_heatmap() -> Tensor:
    """Return an 8x8 heatmap with all mass in the single cell (4, 4).

    That cell equals 1.0 and the rest are zero: the minimum-entropy,
    maximum-Gini case for the heatmap metrics.
    """
    peak_row, peak_col = 4, 4
    grid = torch.zeros(8, 8)
    grid[peak_row, peak_col] = 1.0
    return grid


# =============================================================================
# Device Fixtures
# =============================================================================


@pytest.fixture
def cuda_tensors(
    sample_alpha: Tensor,
    sample_I_samples: Tensor,
    well_conditioned_M_I: Tensor,
) -> dict[str, Tensor] | None:
    """Return CUDA copies of the alpha, I_samples and M_I fixtures.

    When CUDA is unavailable the requesting test is skipped via
    ``pytest.skip`` (which raises), so ``None`` is never actually returned.

    Returns:
        Dict with keys ``"alpha"``, ``"I_samples"`` and ``"M_I"`` mapping to
        the corresponding tensors moved to the CUDA device.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    host_tensors = {
        "alpha": sample_alpha,
        "I_samples": sample_I_samples,
        "M_I": well_conditioned_M_I,
    }
    return {key: tensor.cuda() for key, tensor in host_tensors.items()}
