"""Unit tests for the visualization module.

Tests colormap application, heatmap overlay, and comparison utilities.
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
import torch
from torch import Tensor

from expected_gradcam.utils.visualization import (
    COLORMAPS,
    apply_colormap,
    create_comparison,
    create_grid,
    normalize_heatmap,
    numpy_to_pil,
    overlay_heatmap,
    overlay_heatmap_pil,
    pil_to_numpy,
    resize_heatmap,
    save_visualization,
    tensor_to_numpy,
)


class TestColormap:
    """Tests for ``apply_colormap``: output shape/dtype, input types and errors."""

    # Side length of the square random heatmaps used throughout this class.
    _SIDE = 8

    @classmethod
    def _random_heatmap(cls) -> Tensor:
        """Return a fresh ``_SIDE x _SIDE`` tensor with values in [0, 1)."""
        return torch.rand(cls._SIDE, cls._SIDE)

    def test_jet_colormap(self):
        """The jet colormap yields an 8-bit RGB image matching the input size."""
        rgb = apply_colormap(self._random_heatmap(), colormap="jet")

        assert rgb.shape == (self._SIDE, self._SIDE, 3)
        assert rgb.dtype == np.uint8
        assert 0 <= rgb.min()
        assert rgb.max() <= 255

    def test_viridis_colormap(self):
        """The viridis colormap yields an 8-bit RGB image."""
        rgb = apply_colormap(self._random_heatmap(), colormap="viridis")

        assert rgb.shape == (self._SIDE, self._SIDE, 3)
        assert rgb.dtype == np.uint8

    def test_inferno_colormap(self):
        """The inferno colormap yields an RGB image of the input size."""
        rgb = apply_colormap(self._random_heatmap(), colormap="inferno")

        assert rgb.shape == (self._SIDE, self._SIDE, 3)

    def test_hot_colormap(self):
        """The hot colormap yields an RGB image of the input size."""
        rgb = apply_colormap(self._random_heatmap(), colormap="hot")

        assert rgb.shape == (self._SIDE, self._SIDE, 3)

    def test_invalid_colormap(self):
        """An unrecognised colormap name is rejected with ``ValueError``."""
        heatmap = self._random_heatmap()

        with pytest.raises(ValueError, match="Unknown colormap"):
            apply_colormap(heatmap, colormap="invalid")

    def test_numpy_input(self):
        """A float32 numpy heatmap is accepted as well as a tensor."""
        as_array = np.random.rand(self._SIDE, self._SIDE).astype(np.float32)
        rgb = apply_colormap(as_array, colormap="jet")

        assert rgb.shape == (self._SIDE, self._SIDE, 3)

    def test_normalization(self):
        """Heatmaps outside [0, 1] are handled when ``normalize=True``."""
        # Shift/scale the random values into the range [50, 150).
        shifted = self._random_heatmap() * 100 + 50
        rgb = apply_colormap(shifted, colormap="jet", normalize=True)

        assert rgb.shape == (self._SIDE, self._SIDE, 3)

    def test_available_colormaps(self):
        """Every name advertised in ``COLORMAPS`` can actually be applied."""
        heatmap = self._random_heatmap()
        expected_shape = (self._SIDE, self._SIDE, 3)

        for cmap_name in COLORMAPS:
            assert apply_colormap(heatmap, colormap=cmap_name).shape == expected_shape


class TestTensorToNumpy:
    """Test tensor_to_numpy conversion."""

    def test_basic_conversion(self):
        """A (C, H, W) float tensor becomes an (H, W, C) uint8 array."""
        tensor = torch.rand(3, 32, 32)
        result = tensor_to_numpy(tensor, denormalize=False)

        assert isinstance(result, np.ndarray)
        assert result.shape == (32, 32, 3)
        assert result.dtype == np.uint8

    def test_batch_dimension(self):
        """A batched (N, C, H, W) tensor is reduced to a single image."""
        tensor = torch.rand(4, 3, 32, 32)
        result = tensor_to_numpy(tensor, denormalize=False)

        # Should take first image from batch
        assert result.shape == (32, 32, 3)

    def test_denormalization(self):
        """Denormalizing an all-zero tensor produces a non-black image."""
        # In normalized space a value of 0 corresponds to the per-channel mean
        # (ImageNet stats, per the original intent of this test), so undoing the
        # normalization should map zeros to a gray image, NOT to black.
        tensor = torch.zeros(3, 32, 32)
        result = tensor_to_numpy(tensor, denormalize=True)

        # Denormalization must not change the (H, W, C) layout.
        assert result.shape == (32, 32, 3)
        # A black image would have mean 0; the channel means are strictly positive.
        assert result.mean() > 0


class TestImageConversion:
    """Round-trip helpers between numpy arrays and PIL images."""

    def test_numpy_to_pil(self):
        """An (H, W, 3) uint8 array converts to a PIL image of size (W, H)."""
        from PIL import Image

        pixels = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
        pil_img = numpy_to_pil(pixels)

        assert isinstance(pil_img, Image.Image)
        assert pil_img.size == (32, 32)

    def test_pil_to_numpy(self):
        """An RGB PIL image converts to an (H, W, 3) uint8 array."""
        from PIL import Image

        arr = pil_to_numpy(Image.new("RGB", (32, 32), color="red"))

        assert isinstance(arr, np.ndarray)
        assert arr.shape == (32, 32, 3)
        assert arr.dtype == np.uint8


class TestOverlayHeatmap:
    """Tests for ``overlay_heatmap`` across alphas, colormaps and image types."""

    @pytest.fixture
    def sample_image(self) -> np.ndarray:
        """Random 32x32 RGB uint8 image."""
        return np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

    @pytest.fixture
    def sample_heatmap(self) -> Tensor:
        """Random 8x8 heatmap (deliberately smaller than the image)."""
        return torch.rand(8, 8)

    def test_basic_overlay(self, sample_image, sample_heatmap):
        """A half-transparent overlay keeps the image size and uint8 dtype."""
        blended = overlay_heatmap(sample_image, sample_heatmap, alpha=0.5)

        assert blended.shape == (32, 32, 3)
        assert blended.dtype == np.uint8

    def test_alpha_zero(self, sample_image, sample_heatmap):
        """With alpha=0 the output should essentially be the input image."""
        blended = overlay_heatmap(sample_image, sample_heatmap, alpha=0.0)

        # Pixel values may differ slightly (resizing), so only the shape is pinned.
        assert blended.shape == sample_image.shape

    def test_alpha_one(self, sample_image, sample_heatmap):
        """With alpha=1 the output should be the colourised heatmap alone."""
        blended = overlay_heatmap(sample_image, sample_heatmap, alpha=1.0)

        assert blended.shape == (32, 32, 3)

    def test_different_colormaps(self, sample_image, sample_heatmap):
        """Several colormaps all produce an overlay of the image size."""
        cmaps = ("jet", "viridis", "inferno")
        for cmap in cmaps:
            blended = overlay_heatmap(sample_image, sample_heatmap, colormap=cmap)
            assert blended.shape == (32, 32, 3)

    def test_tensor_image_input(self, sample_heatmap):
        """A (C, H, W) tensor is accepted as the base image."""
        blended = overlay_heatmap(
            torch.rand(3, 32, 32), sample_heatmap, denormalize=False
        )

        assert blended.shape == (32, 32, 3)

    def test_pil_image_input(self, sample_heatmap):
        """A PIL image is accepted as the base image."""
        from PIL import Image

        base = Image.new("RGB", (32, 32), color="blue")
        blended = overlay_heatmap(base, sample_heatmap)

        assert blended.shape == (32, 32, 3)


class TestOverlayHeatmapPil:
    """Tests for the PIL-returning overlay variant."""

    def test_returns_pil(self):
        """``overlay_heatmap_pil`` hands back a ``PIL.Image.Image``."""
        from PIL import Image

        base = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
        cam = torch.rand(8, 8)

        assert isinstance(overlay_heatmap_pil(base, cam), Image.Image)


class TestResizeHeatmap:
    """Tests for ``resize_heatmap`` in both directions and with a batch axis."""

    def test_upscale(self):
        """An 8x8 heatmap can be enlarged to 32x32."""
        resized = resize_heatmap(torch.rand(8, 8), size=(32, 32))

        assert resized.shape == (32, 32)

    def test_downscale(self):
        """A 32x32 heatmap can be shrunk to 8x8."""
        resized = resize_heatmap(torch.rand(32, 32), size=(8, 8))

        assert resized.shape == (8, 8)

    def test_batch_resize(self):
        """A leading batch dimension is preserved while H and W are resized."""
        target = (16, 16)
        resized = resize_heatmap(torch.rand(4, 8, 8), size=target)

        assert resized.shape == (4, *target)


class TestNormalizeHeatmap:
    """Each normalization method must map values into the unit interval."""

    @staticmethod
    def _assert_unit_range(values) -> None:
        """Check that every element of *values* lies in [0, 1]."""
        assert values.min() >= 0.0
        assert values.max() <= 1.0

    def test_minmax(self):
        """Min-max normalization of values in [0, 10) stays within [0, 1]."""
        scaled = normalize_heatmap(torch.rand(8, 8) * 10, method="minmax")
        self._assert_unit_range(scaled)

    def test_percentile(self):
        """Percentile normalization of values in [0, 10) stays within [0, 1]."""
        scaled = normalize_heatmap(torch.rand(8, 8) * 10, method="percentile")
        self._assert_unit_range(scaled)

    def test_std(self):
        """Std-based normalization of signed Gaussian values stays within [0, 1]."""
        scaled = normalize_heatmap(torch.randn(8, 8), method="std")
        self._assert_unit_range(scaled)


class TestCreateComparison:
    """Test create_comparison function.

    Every figure produced here is closed explicitly: pyplot keeps figures
    alive until closed, so leaking them across the session wastes memory
    and triggers matplotlib's "too many open figures" warning.
    """

    def test_creates_figure(self):
        """A comparison of two heatmaps returns a savable matplotlib figure."""
        plt = pytest.importorskip("matplotlib.pyplot")

        image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
        heatmaps = {
            "Method A": torch.rand(8, 8),
            "Method B": torch.rand(8, 8),
        }

        fig = create_comparison(image, heatmaps)
        try:
            assert hasattr(fig, "savefig")
        finally:
            plt.close(fig)

    def test_with_title(self):
        """Passing a title still returns a savable matplotlib figure."""
        plt = pytest.importorskip("matplotlib.pyplot")

        image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
        heatmaps = {"Method A": torch.rand(8, 8)}

        fig = create_comparison(image, heatmaps, title="Test Comparison")
        try:
            assert hasattr(fig, "savefig")
        finally:
            plt.close(fig)


class TestCreateGrid:
    """Test create_grid function.

    Figures are closed after each test so they do not accumulate in pyplot's
    global figure registry.
    """

    def test_creates_grid(self):
        """Four images in two columns produce a savable matplotlib figure."""
        plt = pytest.importorskip("matplotlib.pyplot")

        images = [np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8) for _ in range(4)]

        fig = create_grid(images, ncols=2)
        try:
            assert hasattr(fig, "savefig")
        finally:
            plt.close(fig)

    def test_with_titles(self):
        """Per-image titles are accepted and a savable figure is returned."""
        plt = pytest.importorskip("matplotlib.pyplot")

        images = [np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8) for _ in range(4)]
        titles = ["A", "B", "C", "D"]

        fig = create_grid(images, titles=titles)
        try:
            assert hasattr(fig, "savefig")
        finally:
            plt.close(fig)


class TestSaveVisualization:
    """Test save_visualization function for each supported input type."""

    def test_save_numpy_image(self, tmp_path):
        """A uint8 numpy image is written to disk as PNG."""
        image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
        path = tmp_path / "test.png"

        save_visualization(image, path)

        assert path.exists()

    def test_save_pil_image(self, tmp_path):
        """A PIL image is written to disk as JPEG."""
        from PIL import Image

        img = Image.new("RGB", (32, 32), color="red")
        path = tmp_path / "test.jpg"

        save_visualization(img, path)

        assert path.exists()

    def test_save_matplotlib_figure(self, tmp_path):
        """A matplotlib figure is written to disk."""
        plt = pytest.importorskip("matplotlib.pyplot")

        fig, ax = plt.subplots()
        try:
            ax.plot([1, 2, 3])
            path = tmp_path / "test.png"

            save_visualization(fig, path)

            assert path.exists()
        finally:
            # plt.subplots registers the figure globally; close it so it does
            # not leak into (and warn during) later tests.
            plt.close(fig)

    def test_creates_parent_dirs(self, tmp_path):
        """Missing parent directories of the output path are created."""
        image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
        path = tmp_path / "subdir" / "test.png"

        save_visualization(image, path)

        assert path.exists()
