"""Visualization utilities for Expected GradCAM.

This module provides utilities for visualizing CAM heatmaps, including
overlay on images, colormap application, and comparison grids.

Example:
    >>> from expected_gradcam.utils.visualization import overlay_heatmap, create_comparison
    >>>
    >>> # Overlay heatmap on image
    >>> overlay = overlay_heatmap(image, heatmap, alpha=0.5)
    >>>
    >>> # Create comparison of multiple methods
    >>> fig = create_comparison(
    ...     image,
    ...     {"GradCAM": gradcam_heatmap, "E-GradCAM": egcam_heatmap},
    ... )
    >>> fig.savefig("comparison.png")
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import torch
from torch import Tensor

if TYPE_CHECKING:
    from numpy.typing import NDArray
    from PIL import Image


# =============================================================================
# Colormap Functions
# =============================================================================


def _jet_colormap(value: "NDArray[np.floating]") -> "NDArray[np.uint8]":
    """Apply jet colormap to normalized values.

    Args:
        value: Normalized values in [0, 1] with shape [...].

    Returns:
        RGB values with shape [..., 3] in uint8.
    """
    # Jet colormap approximation
    r = np.clip(1.5 - np.abs(4 * value - 3), 0, 1)
    g = np.clip(1.5 - np.abs(4 * value - 2), 0, 1)
    b = np.clip(1.5 - np.abs(4 * value - 1), 0, 1)

    rgb = np.stack([r, g, b], axis=-1)
    return (rgb * 255).astype(np.uint8)


def _viridis_colormap(value: "NDArray[np.floating]") -> "NDArray[np.uint8]":
    """Apply viridis colormap to normalized values.

    This is a perceptually uniform colormap, good for scientific visualization.

    Args:
        value: Normalized values in [0, 1].

    Returns:
        RGB values in uint8.
    """
    # Viridis colormap control points
    colors = np.array(
        [
            [0.267004, 0.004874, 0.329415],
            [0.282327, 0.140926, 0.457517],
            [0.253935, 0.265254, 0.529983],
            [0.206756, 0.371758, 0.553117],
            [0.163625, 0.471133, 0.558148],
            [0.127568, 0.566949, 0.550556],
            [0.134692, 0.658636, 0.517649],
            [0.266941, 0.748751, 0.440573],
            [0.477504, 0.821444, 0.318195],
            [0.741388, 0.873449, 0.149561],
            [0.993248, 0.906157, 0.143936],
        ]
    )

    # Interpolate
    indices = value * (len(colors) - 1)
    lower = np.floor(indices).astype(int)
    upper = np.ceil(indices).astype(int)
    frac = indices - lower

    lower = np.clip(lower, 0, len(colors) - 1)
    upper = np.clip(upper, 0, len(colors) - 1)

    rgb = colors[lower] * (1 - frac[..., None]) + colors[upper] * frac[..., None]
    return (rgb * 255).astype(np.uint8)


def _inferno_colormap(value: "NDArray[np.floating]") -> "NDArray[np.uint8]":
    """Apply inferno colormap to normalized values.

    Args:
        value: Normalized values in [0, 1].

    Returns:
        RGB values in uint8.
    """
    # Inferno colormap control points
    colors = np.array(
        [
            [0.001462, 0.000466, 0.013866],
            [0.087411, 0.044556, 0.224813],
            [0.258234, 0.038571, 0.406152],
            [0.416331, 0.090834, 0.432943],
            [0.578304, 0.148039, 0.404411],
            [0.735683, 0.215906, 0.330245],
            [0.865006, 0.316822, 0.226055],
            [0.954506, 0.468744, 0.099874],
            [0.987622, 0.645320, 0.039886],
            [0.988362, 0.838116, 0.286601],
            [0.988362, 0.998364, 0.644924],
        ]
    )

    indices = value * (len(colors) - 1)
    lower = np.floor(indices).astype(int)
    upper = np.ceil(indices).astype(int)
    frac = indices - lower

    lower = np.clip(lower, 0, len(colors) - 1)
    upper = np.clip(upper, 0, len(colors) - 1)

    rgb = colors[lower] * (1 - frac[..., None]) + colors[upper] * frac[..., None]
    return (rgb * 255).astype(np.uint8)


def _hot_colormap(value: "NDArray[np.floating]") -> "NDArray[np.uint8]":
    """Apply hot colormap to normalized values.

    Args:
        value: Normalized values in [0, 1].

    Returns:
        RGB values in uint8.
    """
    r = np.clip(3 * value, 0, 1)
    g = np.clip(3 * value - 1, 0, 1)
    b = np.clip(3 * value - 2, 0, 1)

    rgb = np.stack([r, g, b], axis=-1)
    return (rgb * 255).astype(np.uint8)


# Registry of built-in colormaps; ``apply_colormap`` looks names up here.
COLORMAPS = dict(
    jet=_jet_colormap,
    viridis=_viridis_colormap,
    inferno=_inferno_colormap,
    hot=_hot_colormap,
)

# Static type of the ``colormap`` parameters; keep in sync with COLORMAPS.
ColormapName = Literal["jet", "viridis", "inferno", "hot"]


def apply_colormap(
    heatmap: Tensor | "NDArray[np.floating]",
    colormap: ColormapName = "jet",
    normalize: bool = True,
) -> "NDArray[np.uint8]":
    """Apply colormap to a heatmap.

    Args:
        heatmap: Heatmap tensor or array. [H, W] is the primary case; for
            inputs with extra leading axes (e.g. [B, H, W] or the common CAM
            layout [B, 1, H, W]) the first entry of each leading axis is used.
        colormap: Name of colormap to use (a key of ``COLORMAPS``).
        normalize: Whether to min-max normalize the heatmap to [0, 1]. A
            constant heatmap (or one whose min/max are NaN) maps to all zeros.
            Values are clipped to [0, 1] either way.

    Returns:
        RGB image as uint8 array [H, W, 3].

    Raises:
        ValueError: If ``colormap`` is not a known colormap name.
    """
    # Fail fast on a bad name before doing any array work.
    if colormap not in COLORMAPS:
        raise ValueError(f"Unknown colormap: {colormap}. Available: {list(COLORMAPS)}")

    if isinstance(heatmap, Tensor):
        # ``.float()`` also covers bfloat16, which ``.numpy()`` rejects.
        heatmap = heatmap.detach().cpu().float().numpy()

    heatmap = np.asarray(heatmap, dtype=np.float32)

    # Collapse leading batch/channel axes by taking the first entry of each.
    while heatmap.ndim > 2:
        heatmap = heatmap[0]

    # ``min``/``max`` raise on zero-size arrays, so skip normalizing empties.
    if normalize and heatmap.size > 0:
        vmin, vmax = heatmap.min(), heatmap.max()
        if vmax - vmin > 1e-8:
            heatmap = (heatmap - vmin) / (vmax - vmin)
        else:
            heatmap = np.zeros_like(heatmap)

    heatmap = np.clip(heatmap, 0, 1)

    return COLORMAPS[colormap](heatmap)


# =============================================================================
# Image Conversion Utilities
# =============================================================================


def tensor_to_numpy(
    tensor: Tensor,
    denormalize: bool = True,
    mean: tuple[float, ...] = (0.485, 0.456, 0.406),
    std: tuple[float, ...] = (0.229, 0.224, 0.225),
) -> "NDArray[np.uint8]":
    """Convert a tensor image to a uint8 numpy array.

    Values are scaled by 255 and clipped to [0, 255], so the (optionally
    denormalized) tensor is expected to hold values in [0, 1].

    Args:
        tensor: Image tensor [H, W], [C, H, W] or [B, C, H, W]. For a batch,
            only the first image is converted.
        denormalize: Whether to undo normalization as ``x * std + mean``
            (ImageNet statistics by default).
        mean: Mean values used for normalization.
        std: Std values used for normalization.

    Returns:
        RGB image as uint8 array [H, W, 3]. Single-channel images are
        replicated to three channels so the result can be passed to
        PIL / matplotlib directly.
    """
    img = tensor.detach().cpu()

    # Handle batch dimension
    if img.ndim == 4:
        img = img[0]

    # Treat a bare [H, W] tensor as a single-channel image.
    if img.ndim == 2:
        img = img.unsqueeze(0)

    # NumPy has no bfloat16, so ``.numpy()`` would fail; upcast half types.
    if img.dtype in (torch.bfloat16, torch.float16):
        img = img.float()

    # [C, H, W] -> [H, W, C]
    arr = img.permute(1, 2, 0).numpy()

    if denormalize:
        mean_arr = np.array(mean, dtype=np.float32)
        std_arr = np.array(std, dtype=np.float32)
        arr = arr * std_arr + mean_arr

    # A single channel that survived (e.g. denormalize=False) is grayscale:
    # replicate it so the documented [H, W, 3] shape holds.
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)

    return np.clip(arr * 255, 0, 255).astype(np.uint8)


def numpy_to_pil(image: "NDArray[np.uint8]") -> "Image.Image":
    """Wrap a numpy image array in a PIL Image.

    Pillow is imported lazily so the module loads without it.

    Args:
        image: RGB image array [H, W, 3].

    Returns:
        PIL Image built by ``PIL.Image.fromarray``.
    """
    from PIL.Image import fromarray

    pil_image = fromarray(image)
    return pil_image


def pil_to_numpy(image: "Image.Image") -> "NDArray[np.uint8]":
    """Convert a PIL Image to an RGB numpy array.

    The image is converted to RGB first (``image.convert("RGB")``), so
    grayscale, palette and RGBA inputs all yield three channels; any alpha
    channel is dropped.

    Args:
        image: PIL Image.

    Returns:
        Writable RGB image array [H, W, 3] of dtype uint8.
    """
    # ``np.asarray`` on a PIL image wraps the bytes Pillow exports and
    # typically yields a read-only array; ``np.array`` copies, so callers can
    # modify the result in place.
    return np.array(image.convert("RGB"), dtype=np.uint8)


# =============================================================================
# Overlay Functions
# =============================================================================


def overlay_heatmap(
    image: Tensor | "NDArray[np.uint8]" | "Image.Image",
    heatmap: Tensor | "NDArray[np.floating]",
    alpha: float = 0.5,
    colormap: ColormapName = "jet",
    denormalize: bool = True,
) -> "NDArray[np.uint8]":
    """Overlay a heatmap on an image.

    The heatmap is resized (bilinear) to the image resolution, min-max
    normalized and colorized via ``apply_colormap``, then alpha-blended onto
    the image.

    Args:
        image: Original image: a tensor, a uint8 numpy array or a PIL image.
            Numpy arrays may be [H, W], [H, W, 1], [H, W, 3] or [H, W, 4];
            grayscale is replicated to RGB and an alpha channel is dropped.
        heatmap: CAM heatmap [H, W]. For inputs with extra leading axes
            (e.g. [B, H, W] or [B, 1, H, W]) the first entry of each leading
            axis is used.
        alpha: Blending factor (0=image only, 1=heatmap only).
        colormap: Colormap to apply to heatmap.
        denormalize: Whether to denormalize tensor images.

    Returns:
        Blended image as uint8 array [H, W, 3].
    """
    from PIL import Image as PILImage

    # Convert image to numpy
    if isinstance(image, Tensor):
        img_np = tensor_to_numpy(image, denormalize=denormalize)
    elif isinstance(image, PILImage.Image):
        img_np = pil_to_numpy(image)
    else:
        img_np = np.asarray(image, dtype=np.uint8)

    # The colorized heatmap is always [H, W, 3]; match the image's channel
    # count so the blend below broadcasts.
    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)
    elif img_np.ndim == 3 and img_np.shape[-1] == 1:
        img_np = np.repeat(img_np, 3, axis=-1)
    elif img_np.ndim == 3 and img_np.shape[-1] == 4:
        img_np = img_np[..., :3]

    # Get heatmap as numpy; ``.float()`` also covers bfloat16, which
    # ``.numpy()`` rejects (the data is cast to float32 below anyway).
    if isinstance(heatmap, Tensor):
        heatmap_np = heatmap.detach().cpu().float().numpy()
    else:
        heatmap_np = np.asarray(heatmap)

    # Collapse leading batch/channel axes down to [H, W].
    while heatmap_np.ndim > 2:
        heatmap_np = heatmap_np[0]

    # Resize heatmap to image size (PIL sizes are (width, height)).
    h, w = img_np.shape[:2]
    heatmap_pil = PILImage.fromarray(heatmap_np.astype(np.float32), mode="F")
    heatmap_resized = np.array(heatmap_pil.resize((w, h), PILImage.BILINEAR))

    # Apply colormap
    heatmap_rgb = apply_colormap(heatmap_resized, colormap=colormap, normalize=True)

    # Blend in float to avoid uint8 overflow, then clip back to [0, 255].
    blended = (1 - alpha) * img_np.astype(np.float32) + alpha * heatmap_rgb.astype(
        np.float32
    )
    return np.clip(blended, 0, 255).astype(np.uint8)


def overlay_heatmap_pil(
    image: Tensor | "NDArray[np.uint8]" | "Image.Image",
    heatmap: Tensor | "NDArray[np.floating]",
    alpha: float = 0.5,
    colormap: ColormapName = "jet",
    denormalize: bool = True,
) -> "Image.Image":
    """Same as ``overlay_heatmap`` but wraps the result in a PIL Image.

    Args:
        image: Original image (tensor, numpy array or PIL image).
        heatmap: CAM heatmap.
        alpha: Blending factor (0=image only, 1=heatmap only).
        colormap: Colormap name.
        denormalize: Whether to denormalize tensor images.

    Returns:
        PIL Image containing the blended overlay.
    """
    blended = overlay_heatmap(
        image,
        heatmap,
        alpha=alpha,
        colormap=colormap,
        denormalize=denormalize,
    )
    return numpy_to_pil(blended)


# =============================================================================
# Comparison and Grid Visualization
# =============================================================================


def create_comparison(
    image: Tensor | "NDArray[np.uint8]" | "Image.Image",
    heatmaps: dict[str, Tensor | "NDArray[np.floating]"],
    alpha: float = 0.5,
    colormap: ColormapName = "jet",
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    denormalize: bool = True,
) -> Any:
    """Create side-by-side comparison of multiple heatmaps.

    The first panel shows the original image; each following panel shows
    one heatmap overlaid on it, titled with its key in ``heatmaps``.

    Args:
        image: Original image.
        heatmaps: Dict mapping method names to heatmaps. May be empty, in
            which case only the original image is shown.
        alpha: Overlay alpha.
        colormap: Colormap name.
        figsize: Figure size (width, height); defaults to 4 inches per panel.
        title: Overall figure title.
        denormalize: Whether to denormalize tensor images.

    Returns:
        Matplotlib figure.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib is required for create_comparison. "
            "Install with: pip install matplotlib"
        ) from e

    from PIL import Image as PILImage

    # Convert image to numpy once; the overlays below reuse it.
    if isinstance(image, Tensor):
        img_np = tensor_to_numpy(image, denormalize=denormalize)
    elif isinstance(image, PILImage.Image):
        img_np = pil_to_numpy(image)
    else:
        img_np = np.asarray(image, dtype=np.uint8)

    n_cols = len(heatmaps) + 1  # Original + one panel per method

    if figsize is None:
        figsize = (4 * n_cols, 4)

    # squeeze=False keeps ``axes`` a 2-D array even for a single panel
    # (empty ``heatmaps``); otherwise a bare Axes is returned and indexing
    # it fails.
    fig, axes = plt.subplots(1, n_cols, figsize=figsize, squeeze=False)
    row = axes[0]

    # Original image. A 2-D (grayscale) array would otherwise be drawn with
    # matplotlib's default colormap; ``cmap`` is ignored for RGB data.
    row[0].imshow(img_np, cmap="gray" if img_np.ndim == 2 else None)
    row[0].set_title("Original")
    row[0].axis("off")

    # Overlays (img_np is already denormalized, so don't do it twice).
    for ax, (name, heatmap) in zip(row[1:], heatmaps.items()):
        overlay = overlay_heatmap(img_np, heatmap, alpha, colormap, denormalize=False)
        ax.imshow(overlay)
        ax.set_title(name)
        ax.axis("off")

    if title:
        fig.suptitle(title, fontsize=14)

    fig.tight_layout()
    return fig


def create_grid(
    images: list["NDArray[np.uint8]"] | list[Tensor],
    titles: list[str] | None = None,
    ncols: int = 4,
    figsize: tuple[float, float] | None = None,
    denormalize: bool = True,
) -> Any:
    """Create a grid of images.

    Images are laid out row-major, ``ncols`` per row. Cells after the last
    image have their axes hidden.

    Args:
        images: Non-empty list of images (numpy arrays or tensors).
        titles: Optional titles; images beyond ``len(titles)`` get none.
        ncols: Number of columns (must be >= 1).
        figsize: Figure size; defaults to 3 inches per cell.
        denormalize: Whether to denormalize tensor images.

    Returns:
        Matplotlib figure.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``images`` is empty or ``ncols`` is less than 1.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib is required for create_grid. "
            "Install with: pip install matplotlib"
        ) from e

    n = len(images)
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")
    if n == 0:
        raise ValueError("create_grid requires at least one image")

    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)

    if figsize is None:
        figsize = (3 * ncols, 3 * nrows)

    # squeeze=False always yields a 2-D [nrows, ncols] array of Axes. With
    # the default squeeze, a single-column grid comes back 1-D and
    # np.atleast_2d turns it into one *row*, so axes[row, 0] went out of
    # bounds.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    for idx, img in enumerate(images):
        ax = axes[idx // ncols, idx % ncols]

        if isinstance(img, Tensor):
            img = tensor_to_numpy(img, denormalize=denormalize)

        ax.imshow(img)
        if titles and idx < len(titles):
            ax.set_title(titles[idx])
        ax.axis("off")

    # Hide unused subplots in the last row.
    for ax in axes.flat[n:]:
        ax.axis("off")

    fig.tight_layout()
    return fig


def save_visualization(
    data: "NDArray[np.uint8]" | "Image.Image" | Any,
    path: str | Path,
    dpi: int = 150,
) -> None:
    """Save visualization to file.

    The output format is taken from the file extension. Supported inputs:

    * objects with a ``savefig`` method (matplotlib figures) -- saved with
      ``bbox_inches="tight"`` at ``dpi`` and then closed via
      ``plt.close``, so the figure should not be reused afterwards;
    * PIL Images -- saved directly;
    * numpy arrays -- converted with ``PIL.Image.fromarray`` and saved.

    Parent directories of ``path`` are created only once the input has been
    accepted.

    Args:
        data: Image data or matplotlib figure.
        path: Output path.
        dpi: DPI for matplotlib figures.

    Raises:
        TypeError: If ``data`` is none of the supported types.
    """
    from PIL import Image as PILImage

    path = Path(path)

    # Matplotlib figure: pyplot is imported only here, so saving plain
    # arrays/images never pulls in (and configures) a plotting backend.
    if hasattr(data, "savefig"):
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            plt = None
        if plt is not None:
            path.parent.mkdir(parents=True, exist_ok=True)
            data.savefig(path, dpi=dpi, bbox_inches="tight")
            plt.close(data)
            return

    # Resolve the image before touching the filesystem, so a rejected input
    # leaves no stray directories behind.
    if isinstance(data, PILImage.Image):
        image = data
    elif isinstance(data, np.ndarray):
        image = PILImage.fromarray(data)
    else:
        raise TypeError(f"Cannot save type {type(data)}")

    path.parent.mkdir(parents=True, exist_ok=True)
    image.save(path)


# =============================================================================
# Heatmap Processing
# =============================================================================


def resize_heatmap(
    heatmap: Tensor | "NDArray[np.floating]",
    size: tuple[int, int],
    mode: str = "bilinear",
) -> "NDArray[np.floating]":
    """Resize a heatmap to target size.

    Each [H, W] slice is resized independently with Pillow in 32-bit float
    ("F") mode.

    Args:
        heatmap: Heatmap [H, W] or [B, H, W].
        size: Target size (height, width).
        mode: Interpolation mode: "nearest", "bilinear", "bicubic" or
            "lanczos". Any other value falls back to bilinear.

    Returns:
        Resized float32 heatmap, [height, width] or [B, height, width]
        matching the input rank. An empty batch returns an empty
        [0, height, width] array.
    """
    from PIL import Image

    # Resolve the PIL filter once instead of per slice.
    resample_filters = {
        "nearest": Image.NEAREST,
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
        "lanczos": Image.LANCZOS,
    }
    resample = resample_filters.get(mode, Image.BILINEAR)

    if isinstance(heatmap, Tensor):
        # ``.float()`` also covers bfloat16, which ``.numpy()`` rejects.
        heatmap = heatmap.detach().cpu().float().numpy()

    heatmap = np.asarray(heatmap)

    # Promote a single [H, W] map to a batch of one; undone on return.
    squeeze = heatmap.ndim == 2
    if squeeze:
        heatmap = heatmap[np.newaxis]

    out_h, out_w = size

    # np.stack cannot stack zero arrays.
    if len(heatmap) == 0:
        return np.empty((0, out_h, out_w), dtype=np.float32)

    # PIL sizes are (width, height).
    resized = [
        np.array(
            Image.fromarray(h.astype(np.float32), mode="F").resize((out_w, out_h), resample)
        )
        for h in heatmap
    ]

    result = np.stack(resized)
    return result[0] if squeeze else result


def normalize_heatmap(
    heatmap: Tensor | "NDArray[np.floating]",
    method: Literal["minmax", "percentile", "std"] = "minmax",
    percentile: float = 99,
) -> "NDArray[np.floating]":
    """Normalize heatmap values to [0, 1].

    Methods:
        * ``"minmax"``: linear map of [min, max] onto [0, 1]; a constant
          heatmap maps to all zeros.
        * ``"percentile"``: linear map of the symmetric percentile range
          [100 - p, p] onto [0, 1], clipping the tails. ``percentile`` and
          ``100 - percentile`` give the same result. A degenerate range maps
          to all zeros.
        * ``"std"``: ``(x - mean) / (2 * std) + 0.5`` clipped to [0, 1]; a
          constant heatmap maps to 0.5 everywhere.

    Args:
        heatmap: Input heatmap (tensor or array, any shape).
        method: Normalization method.
        percentile: Percentile for percentile method, in [0, 100].

    Returns:
        Normalized float heatmap in [0, 1], same shape as the input.

    Raises:
        ValueError: If ``method`` is unknown.
    """
    if method not in ("minmax", "percentile", "std"):
        raise ValueError(f"Unknown normalization method: {method}")

    if isinstance(heatmap, Tensor):
        # ``.float()`` also covers bfloat16, which ``.numpy()`` rejects.
        heatmap = heatmap.detach().cpu().float().numpy()

    heatmap = np.asarray(heatmap, dtype=np.float32)

    # Reductions below raise on zero-size arrays; nothing to normalize.
    if heatmap.size == 0:
        return heatmap

    if method == "minmax":
        vmin, vmax = heatmap.min(), heatmap.max()
        if vmax - vmin > 1e-8:
            return (heatmap - vmin) / (vmax - vmin)
        return np.zeros_like(heatmap)

    if method == "percentile":
        # Order the bounds so percentile < 50 doesn't invert the range (which
        # previously produced an all-zero map).
        lo_q, hi_q = sorted((100 - percentile, percentile))
        vmin = np.percentile(heatmap, lo_q)
        vmax = np.percentile(heatmap, hi_q)
        if vmax - vmin > 1e-8:
            return np.clip((heatmap - vmin) / (vmax - vmin), 0, 1)
        return np.zeros_like(heatmap)

    # method == "std": centre at 0.5 with +/- 1 std spanning half the range.
    mean = heatmap.mean()
    std = heatmap.std()
    if std > 1e-8:
        normalized = (heatmap - mean) / (2 * std) + 0.5
        return np.clip(normalized, 0, 1)
    return np.ones_like(heatmap) * 0.5


__all__ = [
    # Colormap
    "apply_colormap",
    "COLORMAPS",
    # Public type of every ``colormap`` parameter in this module.
    "ColormapName",
    # Image conversion
    "tensor_to_numpy",
    "numpy_to_pil",
    "pil_to_numpy",
    # Overlay
    "overlay_heatmap",
    "overlay_heatmap_pil",
    # Comparison
    "create_comparison",
    "create_grid",
    # Save
    "save_visualization",
    # Processing
    "resize_heatmap",
    "normalize_heatmap",
]
