"""Pytest configuration and shared fixtures for Expected GradCAM tests.

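This module provides common fixtures used across all test modules:
- Model fixtures: simple_cnn, simple_cnn_cuda, small_resnet, imagenet_cnn
- Mock fixtures: mock_sam, mock_sam_predictor, mock_dino
- Data fixtures: sample_image, sample_features, sample_heatmap, sample_weights
  (plus batched, 224x224, and NumPy variants)
- Matrix fixtures: well_conditioned_matrix, ill_conditioned_matrix,
  singular_matrix
- Utility fixtures: device, cpu_device, cuda_device, temp_cache_dir,
  temp_output_dir, imagenet_path
- Command-line options: --runslow, --imagenet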
"""

from __future__ import annotations

import sys
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pytest
import torch
from torch import Tensor, nn

# Make the project's ``src`` directory (``<this file's directory>/../src``)
# importable without installing the package. Inserting at index 0 means it is
# searched before any installed copy of the same package.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

if TYPE_CHECKING:
    from numpy.typing import NDArray


# =============================================================================
# Device Fixtures
# =============================================================================


@pytest.fixture(scope="session")
def device() -> torch.device:
    """Return the preferred accelerator for this session: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # Older torch builds have no ``torch.backends.mps`` attribute at all.
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


@pytest.fixture
def cpu_device() -> torch.device:
    """Return the CPU device unconditionally."""
    cpu = torch.device("cpu")
    return cpu


@pytest.fixture
def cuda_device() -> torch.device:
    """Return the CUDA device, skipping the requesting test if CUDA is absent.

    ``pytest.skip`` raises, so this fixture never returns ``None``. The
    return annotation reflects that.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return torch.device("cuda")


# =============================================================================
# Simple CNN Model Fixtures
# =============================================================================


class SimpleCNN(nn.Module):
    """Tiny three-stage convolutional classifier used as a CAM test subject.

    Layout:
        - conv1: Conv(3->16, 3x3, pad 1) + BN + ReLU + MaxPool(2)
        - conv2: Conv(16->32, 3x3, pad 1) + BN + ReLU + MaxPool(2)
        - conv3: Conv(32->64, 3x3, pad 1) + BN + ReLU (target layer, no pool)
        - pool: AdaptiveAvgPool2d(1)
        - fc: Linear(64 -> num_classes)

    ``input_size`` is only stored as an attribute. Because of the adaptive
    pooling, the network accepts any spatial size of at least 4x4 (enough to
    survive the two 2x max-pools).
    """

    def __init__(self, num_classes: int = 10, input_size: int = 32) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.input_size = input_size  # informational only; forward() ignores it

        # Stages are created in attribute order so state_dict keys
        # (conv1.0.weight, ...) and parameter initialisation order stay stable.
        self.conv1 = self._stage(3, 16, downsample=True)
        self.conv2 = self._stage(16, 32, downsample=True)
        self.conv3 = self._stage(32, 64, downsample=False)

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    @staticmethod
    def _stage(in_ch: int, out_ch: int, *, downsample: bool) -> nn.Sequential:
        """Build Conv3x3 -> BN -> ReLU, optionally followed by a 2x2 max-pool."""
        layers: list[nn.Module] = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        if downsample:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    @property
    def target_layer(self) -> nn.Module:
        """Module whose activations CAM methods should hook (``conv3``)."""
        return self.conv3

    @property
    def classifier_head(self) -> nn.Module:
        """Final linear classifier (``fc``)."""
        return self.fc

    def get_features(self, x: Tensor) -> Tensor:
        """Run the three conv stages and return the ``conv3`` activations."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``[batch, num_classes]``."""
        pooled = self.pool(self.get_features(x))
        return self.fc(pooled.view(pooled.size(0), -1))


class _ResidualBlock(nn.Module):
    """ResNet-style basic block: two 3x3 conv/BN stages plus a skip connection.

    The skip path is the identity when stride and channel count are unchanged.
    Otherwise it is a strided 1x1 conv + BatchNorm projection, so both paths
    have the same shape before they are added.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x: Tensor) -> Tensor:
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Out-of-place add: the block's input tensor ``x`` is never mutated.
        return self.relu(out + identity)


class SmallResNet(nn.Module):
    """A small ResNet-like model for testing.

    Mirrors torchvision ResNet attribute naming (conv1/bn1/relu/maxpool,
    layer1-4, avgpool, fc) for architecture-detection tests. Each ``layerN``
    is an ``nn.Sequential`` holding one residual block with a real skip
    connection, so ``layerN[-1]`` is a block, as in torchvision.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.num_classes = num_classes

        # Stem
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2)

        # Residual stages (layer1-4 naming like ResNet); stride 2 halves H and W.
        self.layer1 = self._make_layer(32, 32)
        self.layer2 = self._make_layer(32, 64, stride=2)
        self.layer3 = self._make_layer(64, 128, stride=2)
        self.layer4 = self._make_layer(128, 256, stride=2)

        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(
        self, in_channels: int, out_channels: int, stride: int = 1
    ) -> nn.Sequential:
        """Create one residual stage containing a single basic block.

        A projection shortcut is used automatically when ``stride != 1`` or the
        channel count changes.
        """
        return nn.Sequential(_ResidualBlock(in_channels, out_channels, stride))

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``[batch, num_classes]``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


@pytest.fixture
def simple_cnn() -> SimpleCNN:
    """Return a fresh 10-class SimpleCNN (32x32 nominal input) in eval mode."""
    # nn.Module.eval() returns the module itself, so this can be chained.
    return SimpleCNN(num_classes=10, input_size=32).eval()


@pytest.fixture
def simple_cnn_cuda(cuda_device: torch.device) -> SimpleCNN:
    """Return an eval-mode SimpleCNN moved to CUDA (skipped if CUDA is absent)."""
    net = SimpleCNN(num_classes=10, input_size=32).eval()
    return net.to(cuda_device)


@pytest.fixture
def small_resnet() -> SmallResNet:
    """Return a fresh 10-class SmallResNet in eval mode."""
    return SmallResNet(num_classes=10).eval()


@pytest.fixture
def imagenet_cnn() -> SimpleCNN:
    """Return an eval-mode SimpleCNN configured for 1000 classes / 224x224 input."""
    return SimpleCNN(num_classes=1000, input_size=224).eval()


# =============================================================================
# Mock SAM Fixtures
# =============================================================================


class MockSAMPredictor:
    """Mock SAM predictor for testing without actual SAM installation.

    After :meth:`set_image`, :meth:`generate` returns ``num_segments``
    rectangular masks laid out row-major on a near-square grid with
    ``int(sqrt(num_segments))`` rows.
    """

    def __init__(self, num_segments: int = 5) -> None:
        self.num_segments = num_segments
        self._image_set = False
        self._image_shape: tuple[int, int] | None = None

    def set_image(self, image: "NDArray[np.uint8]") -> None:
        """Record the first two dimensions of ``image`` (treated as H, W)."""
        self._image_set = True
        self._image_shape = image.shape[:2]

    def generate(
        self,
        points_per_side: int = 32,
        **kwargs,
    ) -> list[dict]:
        """Generate mock segmentation masks.

        Args:
            points_per_side: Accepted for API compatibility; ignored.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            One dict per segment with keys ``segmentation`` (bool ``[H, W]``),
            ``area``, ``bbox`` (``[x, y, w, h]``), ``predicted_iou`` and
            ``stability_score``. The list is empty when ``num_segments <= 0``.

        Raises:
            RuntimeError: If :meth:`set_image` has not been called.
        """
        if not self._image_set or self._image_shape is None:
            raise RuntimeError("Image not set")
        # Without this guard, rows == 0 below causes a ZeroDivisionError.
        if self.num_segments <= 0:
            return []

        h, w = self._image_shape
        masks = []

        # Near-square grid; the last row may be only partially filled.
        rows = int(np.sqrt(self.num_segments))
        cols = (self.num_segments + rows - 1) // rows
        cell_h = h // rows
        cell_w = w // cols

        for i in range(self.num_segments):
            row, col = divmod(i, cols)

            # The last grid row/column absorbs any remainder pixels.
            r_start = row * cell_h
            r_end = (row + 1) * cell_h if row < rows - 1 else h
            c_start = col * cell_w
            c_end = (col + 1) * cell_w if col < cols - 1 else w

            mask = np.zeros((h, w), dtype=bool)
            mask[r_start:r_end, c_start:c_end] = True

            masks.append(
                {
                    "segmentation": mask,
                    "area": int(mask.sum()),
                    "bbox": [c_start, r_start, c_end - c_start, r_end - r_start],
                    # Descending mock scores, clamped at 0 so large segment
                    # counts never produce negative (invalid) scores.
                    "predicted_iou": max(0.0, 0.9 - i * 0.05),
                    "stability_score": max(0.0, 0.95 - i * 0.02),
                }
            )

        return masks


class MockSAMSegmenter:
    """Mock SAM segmenter matching UnifiedSAMSegmenter interface.

    Wraps a ``MockSAMPredictor`` and stacks its per-segment masks into one
    array.
    """

    def __init__(self, num_segments: int = 5) -> None:
        self.num_segments = num_segments
        self.predictor = MockSAMPredictor(num_segments)

    def segment(
        self,
        image: "NDArray[np.uint8]",
        points_per_side: int = 32,
        **kwargs,
    ) -> tuple["NDArray[np.bool_]", list[dict]]:
        """Segment ``image`` and return ``(masks, metadata)``.

        Args:
            image: Image whose first two dimensions are treated as (H, W).
            points_per_side: Forwarded to the predictor's ``generate``.
            **kwargs: Accepted for API compatibility; not forwarded.

        Returns:
            ``masks``: bool array of shape ``[N, H, W]`` (``N`` may be 0).
            ``metadata``: the per-segment dicts returned by the predictor.
        """
        self.predictor.set_image(image)
        results = self.predictor.generate(points_per_side=points_per_side)

        if not results:
            # np.stack rejects an empty sequence; return an empty [0, H, W] stack.
            h, w = image.shape[:2]
            return np.zeros((0, h, w), dtype=bool), results

        masks = np.stack([r["segmentation"] for r in results])
        return masks, results


@pytest.fixture
def mock_sam() -> MockSAMSegmenter:
    """Return a mock SAM segmenter that yields five grid segments."""
    segmenter = MockSAMSegmenter(num_segments=5)
    return segmenter


@pytest.fixture
def mock_sam_predictor() -> MockSAMPredictor:
    """Return a bare mock SAM predictor configured for five segments."""
    predictor = MockSAMPredictor(num_segments=5)
    return predictor


# =============================================================================
# Mock DINO Fixtures
# =============================================================================


class MockDINOFeatureExtractor:
    """Mock DINO feature extractor for testing.

    Outputs are reproducible. Each call draws from a private RNG seeded with
    ``seed``, so results depend only on the input's size. The global
    torch / NumPy random state is neither read nor advanced.
    """

    def __init__(
        self, feature_dim: int = 384, patch_size: int = 14, seed: int = 0
    ) -> None:
        self.feature_dim = feature_dim
        self.patch_size = patch_size
        self.seed = seed

    def extract_features(
        self,
        image: Tensor | "NDArray[np.uint8]",
    ) -> Tensor:
        """Extract mock DINO patch features.

        Args:
            image: A NumPy array whose first two dimensions are (H, W), or a
                tensor whose last two dimensions are (H, W).

        Returns:
            Unit-norm features of shape
            ``[1, (H // patch_size) * (W // patch_size), feature_dim]``.
        """
        if isinstance(image, np.ndarray):
            h, w = image.shape[:2]
        else:
            h, w = image.shape[-2:]

        # Compute patch grid size
        num_patches = (h // self.patch_size) * (w // self.patch_size)

        # Private, freshly seeded generator: same input size -> same features.
        generator = torch.Generator().manual_seed(self.seed)
        features = torch.randn(
            1, num_patches, self.feature_dim, generator=generator
        )
        # Normalize to unit vectors
        return features / features.norm(dim=-1, keepdim=True)

    def compute_affinity(
        self,
        features: Tensor,
        segment_masks: "NDArray[np.bool_]",
    ) -> "NDArray[np.floating]":
        """Compute a mock segment-to-segment affinity matrix.

        Args:
            features: Ignored by the mock.
            segment_masks: Only its length (the number of segments) is used.

        Returns:
            A symmetric float32 ``[N, N]`` matrix with ones on the diagonal
            and off-diagonal values in ``[0, 1)``, reproducible for a given
            ``seed`` and ``N``.
        """
        n_segments = len(segment_masks)

        rng = np.random.default_rng(self.seed)
        affinity = rng.random((n_segments, n_segments)).astype(np.float32)
        affinity = (affinity + affinity.T) / 2
        np.fill_diagonal(affinity, 1.0)

        return affinity


@pytest.fixture
def mock_dino() -> MockDINOFeatureExtractor:
    """Return a mock DINO extractor (384-dim features, 14-pixel patches)."""
    return MockDINOFeatureExtractor(feature_dim=384, patch_size=14)


# =============================================================================
# Data Fixtures
# =============================================================================


@pytest.fixture
def sample_image() -> Tensor:
    """Return a standard-normal image tensor of shape [1, 3, 32, 32].

    Reseeds the global torch RNG with 42 first, so values are reproducible.
    """
    torch.manual_seed(42)
    shape = (1, 3, 32, 32)
    return torch.randn(*shape)


@pytest.fixture
def sample_image_224() -> Tensor:
    """Return a standard-normal ImageNet-sized tensor of shape [1, 3, 224, 224].

    Reseeds the global torch RNG with 42 first, so values are reproducible.
    """
    torch.manual_seed(42)
    shape = (1, 3, 224, 224)
    return torch.randn(*shape)


@pytest.fixture
def sample_image_batch() -> Tensor:
    """Return four standard-normal images as one tensor of shape [4, 3, 32, 32].

    Reseeds the global torch RNG with 42 first, so values are reproducible.
    """
    torch.manual_seed(42)
    shape = (4, 3, 32, 32)
    return torch.randn(*shape)


@pytest.fixture
def sample_features() -> Tensor:
    """Return non-negative feature maps of shape [1, 64, 8, 8].

    Gaussian noise (global torch seed 42) passed through ReLU, so the values
    resemble post-activation features.
    """
    torch.manual_seed(42)
    return torch.randn(1, 64, 8, 8).relu()


@pytest.fixture
def sample_features_batch() -> Tensor:
    """Return non-negative feature maps of shape [4, 64, 8, 8].

    Gaussian noise (global torch seed 42) passed through ReLU.
    """
    torch.manual_seed(42)
    return torch.randn(4, 64, 8, 8).relu()


@pytest.fixture
def sample_heatmap() -> Tensor:
    """Return a uniform [0, 1) heatmap of shape [1, 8, 8] (global torch seed 42)."""
    torch.manual_seed(42)
    return torch.rand(1, 8, 8)


@pytest.fixture
def sample_weights() -> Tensor:
    """Return 64 standard-normal per-channel weights (global torch seed 42)."""
    torch.manual_seed(42)
    return torch.randn(64)


@pytest.fixture
def sample_numpy_image() -> "NDArray[np.uint8]":
    """Return a random uint8 HWC image of shape [32, 32, 3].

    Reseeds the legacy global NumPy RNG with 42 first.
    """
    np.random.seed(42)
    return np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)


@pytest.fixture
def sample_numpy_image_224() -> "NDArray[np.uint8]":
    """Return a random uint8 HWC image of shape [224, 224, 3].

    Reseeds the legacy global NumPy RNG with 42 first.
    """
    np.random.seed(42)
    return np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)


# =============================================================================
# Matrix Fixtures for Solver Tests
# =============================================================================


@pytest.fixture
def well_conditioned_matrix() -> Tensor:
    """Return a 64x64 symmetric positive-definite matrix ``A @ A.T + 0.1 * I``.

    ``A @ A.T`` is positive semi-definite. The 0.1 ridge on the diagonal
    makes it strictly positive definite. Uses global torch seed 42.
    """
    torch.manual_seed(42)
    n = 64
    a = torch.randn(n, n)
    gram = a @ a.T
    ridge = torch.eye(n) * 0.1
    return gram + ridge


@pytest.fixture
def ill_conditioned_matrix() -> Tensor:
    """Return a 64x64 symmetric matrix with eigenvalues log-spaced in [1e-10, 1].

    Built as ``Q @ diag(S) @ Q.T`` with ``Q`` orthogonal (from a QR
    factorisation), for a nominal condition number of 1e10. Uses global torch
    seed 42.

    NOTE(review): under the default float32 dtype the smallest eigenvalues are
    below machine precision, so the computed matrix cannot represent them
    faithfully. Confirm this is the intended stress case.
    """
    torch.manual_seed(42)
    q, _ = torch.linalg.qr(torch.randn(64, 64))
    spectrum = torch.logspace(-10, 0, 64)
    scaled = q @ torch.diag(spectrum)
    return scaled @ q.T


@pytest.fixture
def singular_matrix() -> Tensor:
    """Return a rank-deficient 64x64 matrix ``A @ A.T`` with ``A`` of shape [64, 32].

    The rank is at most 32 (generically exactly 32). Uses global torch seed 42.
    """
    torch.manual_seed(42)
    tall = torch.randn(64, 32)
    return tall @ tall.T


# =============================================================================
# Temporary Directory Fixtures
# =============================================================================


@pytest.fixture
def temp_cache_dir(tmp_path: Path) -> Path:
    """Return a freshly created, empty ``cache`` directory inside ``tmp_path``."""
    path = tmp_path.joinpath("cache")
    path.mkdir()
    return path


@pytest.fixture
def temp_output_dir(tmp_path: Path) -> Path:
    """Return a freshly created, empty ``output`` directory inside ``tmp_path``."""
    path = tmp_path.joinpath("output")
    path.mkdir()
    return path


# =============================================================================
# Markers, Skip Logic, and Command-Line Options
# =============================================================================


def pytest_configure(config):
    """Register this suite's custom markers so pytest recognises them."""
    marker_specs = (
        "slow: mark test as slow to run",
        "gpu: mark test as requiring GPU",
        "integration: mark as integration test",
        "imagenet: mark as requiring ImageNet dataset",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)


@pytest.fixture(autouse=True)
def skip_slow_tests(request):
    """Skip tests carrying the ``slow`` marker unless ``--runslow`` was passed."""
    if request.node.get_closest_marker("slow") is None:
        return
    if not request.config.getoption("--runslow", default=False):
        pytest.skip("need --runslow option to run")


def pytest_addoption(parser):
    """Register ``--runslow`` (flag) and ``--imagenet`` (dataset path) options."""
    option_specs = (
        (
            "--runslow",
            {"action": "store_true", "default": False, "help": "run slow tests"},
        ),
        (
            "--imagenet",
            {
                "type": str,
                "default": None,
                "help": "path to ImageNet dataset for integration tests",
            },
        ),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)


@pytest.fixture
def imagenet_path(request) -> Path | None:
    """Return the ``--imagenet`` value as a Path, or None when it was not given."""
    raw = request.config.getoption("--imagenet")
    return None if raw is None else Path(raw)
