"""Data fixtures for testing Expected GradCAM.

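Provides sample tensors and data for testing:
- Image tensors (various sizes and batch sizes), plus NumPy/PIL images
- Feature maps (random and with deterministic spatial patterns)
- Heatmaps
- Weight vectors and weight matrices
- Matrices for solver testing (well-conditioned, ill-conditioned, singular)
- Segment masks
- Synthetic datasets and feature-cache arrays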
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import Tensor

if TYPE_CHECKING:
    from numpy.typing import NDArray


# =============================================================================
# Image Fixtures
# =============================================================================


def create_sample_image(
    batch_size: int = 1,
    channels: int = 3,
    height: int = 32,
    width: int = 32,
    seed: int = 42,
    normalized: bool = True,
) -> Tensor:
    """Build a reproducible random image batch.

    Args:
        batch_size: Number of images in the batch.
        channels: Number of color channels.
        height: Image height in pixels.
        width: Image width in pixels.
        seed: Seed passed to ``torch.manual_seed`` (global torch RNG).
        normalized: When True, draw from a standard normal distribution so the
            values resemble ImageNet-normalized inputs (roughly zero mean, unit
            variance). When False, draw uniformly from [0, 1).

    Returns:
        Image tensor of shape [B, C, H, W].
    """
    torch.manual_seed(seed)
    shape = (batch_size, channels, height, width)
    # Normal samples mimic normalized inputs; uniform samples mimic raw
    # intensities in [0, 1).
    sampler = torch.randn if normalized else torch.rand
    return sampler(*shape)


def create_sample_numpy_image(
    height: int = 224,
    width: int = 224,
    channels: int = 3,
    seed: int = 42,
) -> "NDArray[np.uint8]":
    """Build a reproducible random 8-bit image in HWC layout.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        channels: Number of color channels.
        seed: Seed for NumPy's global legacy RNG (``np.random.seed``).

    Returns:
        uint8 array of shape [H, W, C] with values in [0, 255].
    """
    shape = (height, width, channels)
    np.random.seed(seed)
    # Upper bound 256 is exclusive, so every uint8 value is reachable.
    pixels = np.random.randint(0, 256, size=shape, dtype=np.uint8)
    return pixels


def create_sample_pil_image(
    height: int = 224,
    width: int = 224,
    seed: int = 42,
):
    """Build a reproducible random RGB PIL image.

    The pixels come from ``create_sample_numpy_image`` with its default
    channel count.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        seed: Random seed forwarded to ``create_sample_numpy_image``.

    Returns:
        PIL Image.
    """
    # Imported lazily so PIL is only required by tests that use this fixture.
    from PIL import Image

    return Image.fromarray(create_sample_numpy_image(height, width, seed=seed))


# =============================================================================
# Feature Map Fixtures
# =============================================================================


def create_sample_features(
    batch_size: int = 1,
    channels: int = 64,
    height: int = 8,
    width: int = 8,
    seed: int = 42,
    non_negative: bool = True,
) -> Tensor:
    """Build reproducible random feature maps.

    Args:
        batch_size: Number of samples.
        channels: Number of feature channels.
        height: Feature map height.
        width: Feature map width.
        seed: Seed passed to ``torch.manual_seed`` (global torch RNG).
        non_negative: When True, pass the standard-normal draws through ReLU
            so the maps resemble post-activation CNN features.

    Returns:
        Feature tensor of shape [B, C, H, W].
    """
    torch.manual_seed(seed)
    activations = torch.randn(batch_size, channels, height, width)
    return torch.relu(activations) if non_negative else activations


def create_sample_features_with_pattern(
    batch_size: int = 1,
    channels: int = 64,
    height: int = 8,
    width: int = 8,
    pattern: str = "gaussian",
    seed: int = 42,
) -> Tensor:
    """Create feature maps with specific spatial patterns for testing.

    Every channel of every sample starts from the same spatial pattern. Each
    (sample, channel) map is then scaled by ``|1 + 0.1 * N(0, 1)|``, so the
    channels differ slightly in magnitude but not in shape.

    Args:
        batch_size: Number of samples.
        channels: Number of feature channels.
        height: Feature map height.
        width: Feature map width.
        pattern: Pattern type:
            - "gaussian": isotropic Gaussian blob centred in the map.
            - "gradient": horizontal ramp from 0 (left) to 1 (right).
            - "checkerboard": alternating 0/1 cells.
            - "corner": ones in the top-left quarter, zeros elsewhere. The
              active region is at least one pixel, even for maps smaller
              than 4x4.
            Any other value falls back to random non-negative features from
            ``create_sample_features``.
        seed: Random seed (global torch RNG).

    Returns:
        Contiguous feature tensor [B, C, H, W].
    """
    torch.manual_seed(seed)

    if pattern == "gaussian":
        # Gaussian blob in center
        y, x = torch.meshgrid(
            torch.linspace(-1, 1, height),
            torch.linspace(-1, 1, width),
            indexing="ij",
        )
        base = torch.exp(-(x**2 + y**2) / 0.5)
        features = base.unsqueeze(0).unsqueeze(0).expand(batch_size, channels, -1, -1)

    elif pattern == "gradient":
        # Horizontal gradient
        x = torch.linspace(0, 1, width).view(1, 1, 1, width)
        features = x.expand(batch_size, channels, height, width)

    elif pattern == "checkerboard":
        # Checkerboard pattern
        y, x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
        base = ((x + y) % 2).float()
        features = base.unsqueeze(0).unsqueeze(0).expand(batch_size, channels, -1, -1)

    elif pattern == "corner":
        # Activation in top-left corner. Clamp to at least one pixel so maps
        # smaller than 4 pixels on a side don't degenerate to all zeros.
        features = torch.zeros(batch_size, channels, height, width)
        features[:, :, : max(1, height // 4), : max(1, width // 4)] = 1.0

    else:
        features = create_sample_features(batch_size, channels, height, width, seed)

    # Add small random variation per channel. The result is a fresh tensor,
    # so the expanded (non-contiguous) views above are never written to.
    noise = torch.randn(batch_size, channels, 1, 1) * 0.1
    features = features * (1 + noise).abs()

    return features.contiguous()


# =============================================================================
# Heatmap Fixtures
# =============================================================================


def create_sample_heatmap(
    batch_size: int = 1,
    height: int = 8,
    width: int = 8,
    seed: int = 42,
    normalized: bool = True,
) -> Tensor:
    """Create a sample heatmap batch of uniform random values.

    Args:
        batch_size: Number of samples.
        height: Heatmap height.
        width: Heatmap width.
        seed: Random seed (global torch RNG).
        normalized: If True, min-max normalize each sample independently to
            the [0, 1] range.

    Returns:
        Heatmap tensor [B, H, W].
    """
    torch.manual_seed(seed)
    heatmap = torch.rand(batch_size, height, width)

    if normalized:
        # Normalize per sample. A batch-wide min/max would leave every sample
        # except the ones holding the global extremes short of [0, 1].
        lo = heatmap.amin(dim=(1, 2), keepdim=True)
        hi = heatmap.amax(dim=(1, 2), keepdim=True)
        heatmap = (heatmap - lo) / (hi - lo + 1e-8)

    return heatmap


def create_gaussian_heatmap(
    height: int = 8,
    width: int = 8,
    center: tuple[float, float] | None = None,
    sigma: float = 0.3,
) -> Tensor:
    """Build an unnormalized Gaussian bump on a [0, 1] x [0, 1] grid.

    Args:
        height: Heatmap height (number of rows sampled over [0, 1]).
        width: Heatmap width (number of columns sampled over [0, 1]).
        center: (y, x) peak location in normalized [0, 1] coordinates;
            defaults to the middle of the map, (0.5, 0.5).
        sigma: Standard deviation of the Gaussian in normalized units.

    Returns:
        Heatmap tensor of shape [1, H, W] whose value is 1 exactly at the
        center (when the center lies on a grid point).
    """
    cy, cx = (0.5, 0.5) if center is None else center

    rows = torch.linspace(0, 1, height)
    cols = torch.linspace(0, 1, width)
    yy, xx = torch.meshgrid(rows, cols, indexing="ij")

    sq_dist = (yy - cy) ** 2 + (xx - cx) ** 2
    return torch.exp(-sq_dist / (2 * sigma**2))[None]


# =============================================================================
# Weight Fixtures
# =============================================================================


def create_sample_weights(
    channels: int = 64,
    seed: int = 42,
    sparse: bool = False,
    sparsity: float = 0.8,
) -> Tensor:
    """Build a reproducible vector of per-channel weights.

    Args:
        channels: Length of the weight vector.
        seed: Seed passed to ``torch.manual_seed`` (global torch RNG).
        sparse: When True, zero out a random subset of the weights.
        sparsity: Expected fraction of weights zeroed when ``sparse`` is True.
            Each weight is kept only if its uniform draw exceeds this value.

    Returns:
        Weight tensor of shape [C].
    """
    torch.manual_seed(seed)
    weights = torch.randn(channels)
    if not sparse:
        return weights

    keep = (torch.rand(channels) > sparsity).float()
    return weights * keep


def create_sample_weight_matrix(
    rows: int = 64,
    cols: int = 64,
    seed: int = 42,
) -> Tensor:
    """Build a reproducible standard-normal weight matrix.

    Args:
        rows: Number of rows.
        cols: Number of columns.
        seed: Seed passed to ``torch.manual_seed`` (global torch RNG).

    Returns:
        Weight matrix of shape [R, C].
    """
    shape = (rows, cols)
    torch.manual_seed(seed)
    matrix = torch.randn(*shape)
    return matrix


# =============================================================================
# Matrix Fixtures for Solver Testing
# =============================================================================


def create_well_conditioned_matrix(
    size: int = 64,
    condition_number: float = 10.0,
    seed: int = 42,
) -> Tensor:
    """Build a symmetric positive definite matrix with a chosen spectrum.

    The matrix is ``Q diag(lambda) Q^T``, where ``Q`` is the orthogonal factor
    of a QR decomposition of a Gaussian matrix. ``lambda`` is linearly spaced
    from ``1 / condition_number`` to 1, so the condition number equals
    ``condition_number`` up to floating-point rounding.

    Args:
        size: Matrix dimension.
        condition_number: Target ratio of largest to smallest eigenvalue.
        seed: Seed passed to ``torch.manual_seed`` (global torch RNG).

    Returns:
        Positive definite matrix of shape [size, size].
    """
    torch.manual_seed(seed)

    gaussian = torch.randn(size, size)
    orthogonal, _ = torch.linalg.qr(gaussian)
    spectrum = torch.linspace(1 / condition_number, 1, size)

    return orthogonal @ torch.diag(spectrum) @ orthogonal.T


def create_ill_conditioned_matrix(
    size: int = 64,
    condition_number: float = 1e10,
    seed: int = 42,
) -> Tensor:
    """Create an ill-conditioned symmetric matrix for testing regularization.

    The matrix is ``Q diag(lambda) Q^T``, where ``Q`` is the orthogonal factor
    of a QR decomposition of a Gaussian matrix. ``lambda`` is log-spaced from
    ``1 / condition_number`` up to 1. The condition number therefore equals
    ``condition_number`` in exact arithmetic. For very large values, float32
    rounding of the product dominates and the realized condition number will
    be smaller.

    Args:
        size: Matrix size.
        condition_number: Desired condition number (must be >= 1).
        seed: Random seed (global torch RNG).

    Returns:
        Ill-conditioned matrix [size, size].

    Raises:
        ValueError: If ``condition_number`` is less than 1 or NaN.
    """
    import math

    # `not (>= 1)` also rejects NaN, which would otherwise yield a NaN matrix.
    if not condition_number >= 1:
        raise ValueError(
            f"condition_number must be >= 1, got {condition_number}"
        )

    torch.manual_seed(seed)

    Q = torch.linalg.qr(torch.randn(size, size))[0]
    # Previously hard-coded as logspace(-10, 0); deriving the exponent from
    # the argument keeps the default output identical while honoring it.
    eigenvalues = torch.logspace(-math.log10(condition_number), 0, size)

    matrix = Q @ torch.diag(eigenvalues) @ Q.T

    return matrix


def create_singular_matrix(
    size: int = 64,
    rank: int = 32,
    seed: int = 42,
) -> Tensor:
    """Create a rank-deficient (singular) symmetric PSD matrix.

    Args:
        size: Matrix size.
        rank: Desired rank; must satisfy ``0 <= rank < size`` so the result
            is actually singular.
        seed: Random seed (global torch RNG).

    Returns:
        Singular matrix [size, size] with the given rank.

    Raises:
        ValueError: If ``rank`` is negative or not strictly less than ``size``.
    """
    if not 0 <= rank < size:
        raise ValueError(
            f"rank must satisfy 0 <= rank < size, got rank={rank}, size={size}"
        )

    torch.manual_seed(seed)

    # A has only `rank` columns, so A @ A.T has rank at most `rank`; with
    # Gaussian entries it is exactly `rank` with probability 1.
    A = torch.randn(size, rank)
    matrix = A @ A.T

    return matrix


# =============================================================================
# Segment Mask Fixtures
# =============================================================================


def create_sample_segment_masks(
    num_segments: int = 5,
    height: int = 224,
    width: int = 224,
    pattern: str = "grid",
    seed: int = 42,
) -> "NDArray[np.bool_]":
    """Create sample boolean segment masks.

    Args:
        num_segments: Number of segments. Values <= 0 yield an empty
            ``[0, H, W]`` array.
        height: Mask height.
        width: Mask width.
        pattern: Pattern type:
            - "grid": rectangular cells filled row-major on a grid with
              ``ceil(sqrt(N))`` rows. The last row/column absorbs any
              remainder pixels.
            - "concentric": disjoint rings around the image center,
              outermost first.
            - anything else (e.g. "random"): filled circles with random
              centers in the middle half of the image and random radii.
        seed: Random seed (seeds NumPy's global legacy RNG).

    Returns:
        Boolean mask array [N, H, W].
    """
    np.random.seed(seed)

    # np.stack([]) raises, so handle the empty case explicitly.
    if num_segments <= 0:
        return np.zeros((0, height, width), dtype=bool)

    # Open coordinate grids that broadcast to [H, W]; built once, not per segment.
    y, x = np.ogrid[:height, :width]

    if pattern == "grid":
        masks = []
        rows = int(np.ceil(np.sqrt(num_segments)))
        cols = int(np.ceil(num_segments / rows))

        for i in range(num_segments):
            row = i // cols
            col = i % cols

            mask = np.zeros((height, width), dtype=bool)
            r_start = row * (height // rows)
            r_end = (row + 1) * (height // rows) if row < rows - 1 else height
            c_start = col * (width // cols)
            c_end = (col + 1) * (width // cols) if col < cols - 1 else width

            mask[r_start:r_end, c_start:c_end] = True
            masks.append(mask)

    elif pattern == "concentric":
        masks = []
        cy, cx = height // 2, width // 2
        max_r = min(height, width) // 2
        dist = np.sqrt((y - cy) ** 2 + (x - cx) ** 2)

        for i in range(num_segments):
            r_outer = max_r * (num_segments - i) / num_segments
            r_inner = max_r * (num_segments - i - 1) / num_segments
            masks.append((dist <= r_outer) & (dist > r_inner))

    else:  # random
        def _randint(low: int, high: int) -> int:
            # np.random.randint raises on an empty [low, high) range, which the
            # quarter/eighth-based bounds produce for very small masks. Widen
            # only in that case, so larger masks draw exactly as before.
            return np.random.randint(low, max(high, low + 1))

        min_side = min(height, width)
        masks = []
        for _ in range(num_segments):
            cy = _randint(height // 4, 3 * height // 4)
            cx = _randint(width // 4, 3 * width // 4)
            r = _randint(min_side // 8, min_side // 4)
            masks.append(((y - cy) ** 2 + (x - cx) ** 2) <= r**2)

    return np.stack(masks)


# =============================================================================
# Dataset Fixtures
# =============================================================================


def create_sample_dataset(
    num_samples: int = 100,
    image_size: int = 32,
    num_classes: int = 10,
    seed: int = 42,
) -> tuple[Tensor, Tensor]:
    """Build a reproducible random classification dataset.

    Args:
        num_samples: Number of (image, label) pairs.
        image_size: Side length of the square 3-channel images.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
        seed: Seed passed to ``torch.manual_seed`` (global torch RNG).

    Returns:
        Tuple of (images [N, 3, H, W] from N(0, 1), integer labels [N]).
    """
    torch.manual_seed(seed)
    image_shape = (num_samples, 3, image_size, image_size)
    # Images are drawn before labels so the RNG stream order is fixed.
    images = torch.randn(*image_shape)
    labels = torch.randint(0, num_classes, (num_samples,))
    return images, labels


def create_feature_cache_data(
    num_samples: int = 100,
    channels: int = 64,
    height: int = 8,
    width: int = 8,
    seed: int = 42,
) -> "NDArray[np.floating]":
    """Build reproducible non-negative float32 features for cache tests.

    Args:
        num_samples: Number of samples.
        channels: Feature channels.
        height: Feature height.
        width: Feature width.
        seed: Seed for NumPy's global legacy RNG (``np.random.seed``).

    Returns:
        float32 array of shape [N, C, H, W] with negatives clamped to zero.
    """
    np.random.seed(seed)
    shape = (num_samples, channels, height, width)
    activations = np.random.randn(*shape).astype(np.float32)
    # Clamp negatives to zero, mimicking post-ReLU CNN activations.
    return np.maximum(activations, 0)
