"""Heatmap generation and processing for Expected GradCAM.

Generates class activation heatmaps by combining feature maps with optimal weights:
    L^c = ReLU(Σ_k α_k * A^k)
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Union

import numpy as np
import torch
import torch.nn.functional as F

from expected_gradcam.core.weight_transform import transform_weights


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from torch import Tensor


def generate_heatmap(
    feature_maps: "Tensor",
    weights: "Tensor",
    apply_relu: bool = True,
) -> "Tensor":
    """Combine feature maps into a coarse class-activation map.

    Computes ``L^c = ReLU(Σ_k α_k * A^k)``: each channel of ``feature_maps`` is
    scaled by its weight and the channels are summed. When ``apply_relu`` is
    set, negative values (regions that lower the class score) are zeroed.

    Args:
        feature_maps: Feature maps A of shape [B, K, U, V] from the target layer.
        weights: Per-channel importance weights α with K elements.
        apply_relu: If True, clip the weighted sum at zero.

    Returns:
        Coarse heatmap of shape [B, U, V] at feature-map resolution.
    """
    # Reshape the K weights to [1, K, 1, 1] so they broadcast over batch and
    # spatial axes, then collapse the channel axis.
    channel_weights = weights.view(1, -1, 1, 1)
    combined = (feature_maps * channel_weights).sum(dim=1)
    return F.relu(combined) if apply_relu else combined


def upsample_heatmap(
    heatmap: "Tensor",
    target_size: tuple[int, int],
    mode: str = "bilinear",
    align_corners: bool = False,
) -> "Tensor":
    """Resize a heatmap to the requested spatial resolution.

    A single map [U, V] or a batch [B, U, V] is accepted; the output keeps the
    same rank as the input.

    Args:
        heatmap: Coarse heatmap of shape [B, U, V] or [U, V].
        target_size: Output spatial size (H, W).
        mode: ``torch.nn.functional.interpolate`` mode.
        align_corners: Forwarded to ``interpolate`` only for the linear-family
            modes ("linear", "bilinear", "bicubic", "trilinear"); ignored
            (passed as None) for every other mode.

    Returns:
        Upsampled heatmap of shape [B, H, W] or [H, W].
    """
    was_unbatched = heatmap.dim() == 2
    batched = heatmap.unsqueeze(0) if was_unbatched else heatmap

    # interpolate rejects an align_corners value for modes that don't use it.
    corner_aware_modes = ("linear", "bilinear", "bicubic", "trilinear")
    corner_flag = align_corners if mode in corner_aware_modes else None

    # interpolate needs an explicit channel axis: [B, U, V] -> [B, 1, U, V].
    resized = F.interpolate(
        batched.unsqueeze(1),
        size=target_size,
        mode=mode,
        align_corners=corner_flag,
    ).squeeze(1)

    return resized.squeeze(0) if was_unbatched else resized


def normalize_heatmap(
    heatmap: "Tensor",
    method: Literal["minmax", "quantile", "sum", "max"] = "minmax",
    quantile_low: float = 0.02,
    quantile_high: float = 0.98,
    eps: float = 1e-8,
) -> "Tensor":
    """Rescale a heatmap, independently per sample.

    A 2-D input [H, W] is treated as one sample. For any other rank, each entry
    along dim 0 is normalized over all of its remaining elements.

    Methods:
        - "minmax": ``(x - min) / (max - min + eps)``; lies in [0, 1].
        - "quantile": like minmax but using the ``quantile_low`` /
          ``quantile_high`` quantiles as bounds, then clamped to [0, 1].
        - "sum": ``x / (sum + eps)``; values sum to ~1. Only guaranteed to lie
          in [0, 1] when the input is non-negative.
        - "max": ``x / (max + eps)``. Only guaranteed to lie in [0, 1] when the
          input is non-negative.

    Args:
        heatmap: Heatmap tensor [B, H, W] or [H, W]. Non-contiguous tensors
            (e.g. transposed views) are accepted.
        method: Normalization method (see above).
        quantile_low: Lower quantile for "quantile" normalization.
        quantile_high: Upper quantile for "quantile" normalization.
        eps: Small constant added to denominators to avoid division by zero.

    Returns:
        Normalized heatmap with the same shape as ``heatmap``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method not in ("minmax", "quantile", "sum", "max"):
        raise ValueError(f"Unknown normalization method: {method}")

    batch = 1 if heatmap.dim() == 2 else heatmap.shape[0]
    # reshape (not view): view raises on non-contiguous inputs such as
    # transposed or permuted heatmaps.
    flat = heatmap.reshape(batch, -1)

    if method == "minmax":
        lo = flat.min(dim=1, keepdim=True).values
        hi = flat.max(dim=1, keepdim=True).values
        out = (flat - lo) / (hi - lo + eps)
    elif method == "quantile":
        lo = torch.quantile(flat, quantile_low, dim=1, keepdim=True)
        hi = torch.quantile(flat, quantile_high, dim=1, keepdim=True)
        # Values outside the quantile band would fall outside [0, 1]; clip them.
        out = ((flat - lo) / (hi - lo + eps)).clamp(0, 1)
    elif method == "sum":
        out = flat / (flat.sum(dim=1, keepdim=True) + eps)
    else:  # "max"
        out = flat / (flat.max(dim=1, keepdim=True).values + eps)

    return out.reshape(heatmap.shape)


def apply_contrast_enhancement(
    heatmap: Union["Tensor", "NDArray"],
    boost_factor: float = 0.15,
) -> Union["Tensor", "NDArray"]:
    """Stretch heatmap values away from their mean to boost contrast.

    Each value is pushed away from the mean of the whole input:
        enhanced = mean + (heatmap - mean) * (1 + boost_factor)
    and the result is clamped to [0, 1]. The mean is a single scalar taken over
    every element, so for a batched input [B, H, W] it is shared across the
    batch rather than computed per sample.

    This matches the refinement phase of progressive heatmap generation, so
    progressive and non-progressive modes produce consistent output.

    Args:
        heatmap: Normalized heatmap [H, W] or [B, H, W] in [0, 1], given as
            either a NumPy array or a PyTorch tensor.
        boost_factor: Relative contrast increase; 0.15 scales deviations from
            the mean by 1.15. Intended range [0.0, 0.5] (not enforced here).

    Returns:
        Heatmap of the same shape and container type, clamped to [0, 1].
        NumPy results are cast back to the input dtype.

    Example:
        >>> heatmap = torch.rand(224, 224)
        >>> enhanced = apply_contrast_enhancement(heatmap, boost_factor=0.15)
        >>> assert enhanced.shape == heatmap.shape
    """
    center = heatmap.mean()
    stretch = 1 + boost_factor
    spread = center + (heatmap - center) * stretch

    if not isinstance(heatmap, np.ndarray):
        # Anything that is not a NumPy array is handled as a PyTorch tensor.
        return torch.clamp(spread, 0, 1)
    return np.clip(spread, 0, 1).astype(heatmap.dtype)


def process_heatmap(
    feature_maps: "Tensor",
    weights: "Tensor",
    input_size: tuple[int, int],
    apply_relu: bool = True,
    normalize: bool = True,
    normalization_method: Literal["minmax", "quantile", "sum", "max"] = "quantile",
    weight_transform: Literal[
        "none", "double_power", "extreme_power", "feature_adaptive", "fixed_power"
    ] = "double_power",
    quantile_low: float = 0.02,
    quantile_high: float = 0.98,
    transform_exponent: float = 3.0,
) -> tuple["Tensor", "Tensor"]:
    """Run the full heatmap pipeline from feature maps to an input-sized map.

    Stages:
    1. Weight transformation via ``transform_weights`` (skipped when
       ``weight_transform == "none"``; default is WAVE5 C13 double_power).
    2. Weighted channel combination of the feature maps (``generate_heatmap``).
    3. Optional ReLU (inside ``generate_heatmap``).
    4. Upsampling to ``input_size`` (``upsample_heatmap``, default bilinear).
    5. Optional normalization (``normalize_heatmap``).

    Args:
        feature_maps: Feature maps [B, K, U, V].
        weights: Importance weights [K].
        input_size: Target size (H, W) for upsampling.
        apply_relu: Whether to clip the combined map at zero.
        normalize: Whether to run ``normalize_heatmap`` on the upsampled map.
        normalization_method: Method forwarded to ``normalize_heatmap``.
        weight_transform: Method forwarded to ``transform_weights``. Feature
            maps are only passed along for "feature_adaptive".
        quantile_low: Lower quantile for quantile normalization.
        quantile_high: Upper quantile for quantile normalization.
        transform_exponent: Exponent forwarded to ``transform_weights``.

    Returns:
        Tuple of:
        - Final heatmap [B, H, W], or [H, W] when B == 1.
        - Coarse heatmap [B, U, V] before upsampling, or [U, V] when B == 1.
    """
    if weight_transform != "none":
        adaptive_maps = feature_maps if weight_transform == "feature_adaptive" else None
        weights = transform_weights(
            weights=weights,
            method=weight_transform,
            feature_maps=adaptive_maps,
            exponent=transform_exponent,
        )

    coarse = generate_heatmap(feature_maps, weights, apply_relu)
    full_res = upsample_heatmap(coarse, input_size)

    if normalize:
        full_res = normalize_heatmap(
            full_res,
            method=normalization_method,
            quantile_low=quantile_low,
            quantile_high=quantile_high,
        )

    # A single-sample batch is returned without its leading batch axis.
    if full_res.dim() == 3 and full_res.shape[0] == 1:
        return full_res.squeeze(0), coarse.squeeze(0)
    return full_res, coarse
