"""Integration tests with real ImageNet models.

These tests require pretrained torchvision models; the tests in TestImageNetDataset
additionally require the ImageNet dataset (its ``val`` split).
Run with: pytest --imagenet=/path/to/imagenet tests/integration/test_imagenet.py
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
import torch
from torch import Tensor, nn

# Tag every test in this module with the ``imagenet`` marker. The marker by
# itself does not skip anything: deselect with ``-m "not imagenet"``, or rely on
# conftest hooks (not visible here) to skip when ImageNet is unavailable.
pytestmark = pytest.mark.imagenet


class TestPretrainedModels:
    """Tests that run pretrained torchvision models on synthetic input.

    These tests need torchvision and its pretrained ResNet-50 weights, but not
    the ImageNet dataset itself.
    """

    @pytest.fixture
    def resnet50(self):
        """Load a pretrained ResNet-50 in eval mode; skip if torchvision is missing."""
        # Only the import is guarded: an ImportError raised while building the
        # model is a real failure and should not be reported as a skip.
        try:
            from torchvision.models import resnet50, ResNet50_Weights
        except ImportError:
            pytest.skip("torchvision not available")

        model = resnet50(weights=ResNet50_Weights.DEFAULT)
        model.eval()
        return model

    @pytest.fixture
    def sample_image(self) -> Tensor:
        """Return a deterministic random input tensor of shape (1, 3, 224, 224).

        Values are standard-normal, on roughly the same scale as an
        ImageNet-normalized image. A local generator is used so the fixture
        does not reseed torch's global RNG as a side effect.
        """
        generator = torch.Generator().manual_seed(42)
        return torch.randn(1, 3, 224, 224, generator=generator)

    @pytest.fixture
    def imagenet_transform(self):
        """Return the standard ImageNet eval preprocessing pipeline.

        Resize to 256, center-crop to 224, convert to tensor, then normalize
        with the ImageNet channel mean/std. Skips if torchvision is missing.
        """
        try:
            from torchvision import transforms
        except ImportError:
            pytest.skip("torchvision not available")

        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @pytest.mark.slow
    def test_resnet50_forward(self, resnet50, sample_image):
        """A forward pass produces one row of 1000 class logits."""
        with torch.no_grad():
            output = resnet50(sample_image)

        assert output.shape == (1, 1000)

    @pytest.mark.slow
    def test_resnet50_feature_extraction(self, resnet50, sample_image):
        """FeatureMapHook captures the (1, 2048, 7, 7) layer4 activation."""
        from expected_gradcam.hooks import FeatureMapHook

        with FeatureMapHook(resnet50.layer4) as hook:
            _ = resnet50(sample_image)

            # Assertions stay inside the context; presumably the hook may
            # release its captured tensors on exit -- confirm in hooks module.
            assert hook.features is not None
            assert hook.features.shape == (1, 2048, 7, 7)

    @pytest.mark.slow
    def test_resnet50_gradcam_style(self, resnet50, sample_image):
        """Build a Grad-CAM-style heatmap from layer4 activations and gradients."""
        from expected_gradcam.hooks import FeatureMapHook
        from expected_gradcam.core.heatmap import generate_heatmap

        sample_image.requires_grad_(True)
        target_class = 243  # arbitrary valid ImageNet class index (< 1000)

        with FeatureMapHook(resnet50.layer4) as hook:
            output = resnet50(sample_image)
            score = output[0, target_class]
            score.backward()

            features = hook.features
            gradients = hook.gradients

        assert features is not None
        assert gradients is not None

        # Grad-CAM channel weights: spatial mean of the gradients. squeeze(0)
        # drops only the batch axis, so the result is (C,) even when C == 1.
        weights = gradients.mean(dim=[2, 3]).squeeze(0)

        heatmap = generate_heatmap(weights, features)

        assert heatmap.shape[-2:] == (7, 7)
        assert (heatmap >= 0).all()


class TestImageNetDataset:
    """Tests that run on a real image from the ImageNet validation split.

    Requires ``--imagenet=/path/to/imagenet``. That directory must have a
    ``val`` subfolder containing one subdirectory per class, each holding
    ``*.JPEG`` files.
    """

    @pytest.fixture
    def imagenet_path(self, request) -> Path:
        """Return the ImageNet root given by ``--imagenet``; skip if it is unset.

        ``default=None`` makes the fixture skip rather than error when the
        option has not been registered (e.g. the defining conftest.py is not
        loaded).
        """
        path = request.config.getoption("--imagenet", default=None)
        if path is None:
            pytest.skip("ImageNet path not provided")
        return Path(path)

    @pytest.fixture
    def imagenet_transform(self):
        """Return the standard ImageNet eval preprocessing pipeline.

        Resize to 256, center-crop to 224, convert to tensor, then normalize
        with the ImageNet channel mean/std. Skips if torchvision is missing.
        """
        try:
            from torchvision import transforms
        except ImportError:
            pytest.skip("torchvision not available")

        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @pytest.fixture
    def sample_imagenet_image(self, imagenet_path, imagenet_transform) -> tuple:
        """Load the first validation image (in sorted order) from ImageNet.

        Returns:
            ``(tensor, array)``: the preprocessed image, shaped (1, 3, 224, 224),
            and the full-resolution RGB image as a NumPy array, shaped (H, W, 3).

        Skips if the ``val`` directory, its class directories, or any
        ``*.JPEG`` file in the first class directory is missing.
        """
        from PIL import Image

        val_path = imagenet_path / "val"
        if not val_path.exists():
            pytest.skip("ImageNet val directory not found")

        # iterdir()/glob() order is filesystem-dependent. Sort so the chosen
        # class and image are stable across runs and machines, and ignore
        # stray files (e.g. label CSVs) sitting next to the class directories.
        class_dirs = sorted(p for p in val_path.iterdir() if p.is_dir())
        if not class_dirs:
            pytest.skip("No class directories found")

        images = sorted(class_dirs[0].glob("*.JPEG"))
        if not images:
            pytest.skip("No images found")

        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(images[0]) as raw:
            img = raw.convert("RGB")
        tensor = imagenet_transform(img).unsqueeze(0)

        return tensor, np.array(img)

    @pytest.mark.slow
    def test_real_image_processing(self, sample_imagenet_image):
        """Preprocessing yields a (1, 3, 224, 224) tensor and the raw image is RGB."""
        tensor, numpy_img = sample_imagenet_image

        assert tensor.shape == (1, 3, 224, 224)
        assert numpy_img.shape[2] == 3

    @pytest.mark.slow
    def test_full_pipeline_real_image(self, sample_imagenet_image):
        """Run a simplified E-GradCAM pipeline end to end on a real image.

        The moment estimates are synthetic: random perturbation samples and a
        random cross moment. The test checks shapes, the [0, 1] value range and
        the absence of NaNs, not explanation quality.
        """
        try:
            from torchvision.models import resnet50, ResNet50_Weights
        except ImportError:
            pytest.skip("torchvision not available")

        from expected_gradcam.hooks import FeatureMapHook
        from expected_gradcam.core.heatmap import process_heatmap
        from expected_gradcam.core.optimal_weights import compute_optimal_weights

        tensor, _numpy_img = sample_imagenet_image

        model = resnet50(weights=ResNet50_Weights.DEFAULT)
        model.eval()

        # A single forward pass yields both the prediction and the layer4
        # features. Autograd stays enabled in case the hook attaches gradient
        # hooks to the activation (TODO confirm against hooks module).
        with FeatureMapHook(model.layer4) as hook:
            output = model(tensor)
            features = hook.features

        assert output.shape == (1, 1000)
        pred_class = output.argmax(dim=1).item()
        assert 0 <= pred_class < 1000
        assert features is not None

        # Simulated E-GradCAM computation.
        K = features.shape[1]  # number of channels: 2048 for ResNet-50 layer4

        # Synthetic perturbation samples, seeded so the test is reproducible.
        M = 20
        torch.manual_seed(42)
        I_samples = torch.rand(M, K)

        # Empirical second moment (1/M) * sum_m I_m I_m^T, shape (K, K).
        # NOTE(review): with M < K this matrix has rank <= M (singular);
        # presumably compute_optimal_weights regularizes it -- confirm.
        M_I = (I_samples.T @ I_samples) / M

        # Placeholder cross moment: random, not derived from the model output.
        cross_moment = torch.randn(K)

        weights, _diagnostics = compute_optimal_weights(M_I, cross_moment)

        heatmap = process_heatmap(
            weights,
            features,
            target_size=(224, 224),
            normalize=True,
        )

        assert heatmap.shape[-2:] == (224, 224)
        assert heatmap.min() >= 0.0
        assert heatmap.max() <= 1.0
        assert not torch.isnan(heatmap).any()


class TestMultipleArchitectures:
    """Check that feature hooks attach cleanly across several torchvision backbones."""

    @pytest.fixture(params=["resnet18", "resnet50", "vgg16", "densenet121"])
    def model_name(self, request):
        """Supply one architecture name per parametrized test instance."""
        name = request.param
        return name

    @pytest.fixture
    def model_and_layer(self, model_name):
        """Build the pretrained model called ``model_name`` and choose its hook layer.

        Returns:
            ``(model, target_layer)``, with the model already in eval mode.
        """
        try:
            import torchvision.models as models
        except ImportError:
            pytest.skip("torchvision not available")

        # Module to hook for each supported architecture. Each key is also
        # the name of the torchvision factory function that builds the model.
        layer_selectors = {
            "resnet18": lambda net: net.layer4,
            "resnet50": lambda net: net.layer4,
            "vgg16": lambda net: net.features[-1],
            "densenet121": lambda net: net.features.denseblock4,
        }
        select_layer = layer_selectors.get(model_name)
        if select_layer is None:
            pytest.skip(f"Unknown model: {model_name}")

        net = getattr(models, model_name)(weights="DEFAULT")
        hook_layer = select_layer(net)
        net.eval()
        return net, hook_layer

    @pytest.fixture
    def sample_image(self) -> Tensor:
        """Seed torch's global RNG with 42 and return a (1, 3, 224, 224) randn tensor."""
        torch.manual_seed(42)
        shape = (1, 3, 224, 224)
        return torch.randn(*shape)

    @pytest.mark.slow
    def test_feature_extraction(self, model_and_layer, sample_image):
        """Forward one image and confirm the hook captured a 4-D feature map."""
        from expected_gradcam.hooks import FeatureMapHook

        net, hook_layer = model_and_layer

        with FeatureMapHook(hook_layer) as hook:
            _ = net(sample_image)

            captured = hook.features
            assert captured is not None
            assert captured.ndim == 4  # [B, C, H, W]
