"""Slide visualization helpers.

Lightweight plotting utilities for patches, thumbnails, heatmaps, and
interactive exploration. These functions consume data from ``pathfmtools.image``
and project per-tile values using the TileIndex geometry without altering any
compute logic.
"""

from __future__ import annotations

import logging
import math
from typing import Literal

import cv2
import distinctipy
import ipywidgets as widgets
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.cm import get_cmap
from matplotlib.colors import LogNorm, Normalize
from matplotlib.patches import Rectangle

from pathfmtools.image import Slide
from pathfmtools.io.schema import StoreKeys as SK

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def _read_patch_by_idx(slide: Slide, idx: int) -> np.ndarray:
    """Fetch one stored tile from the slide's store file.

    The store file is opened read-only for the duration of the lookup and is
    released by the context manager before the value is returned.

    Args:
        slide: Slide whose ``store`` and ``id_`` locate the store file.
        idx: Index into the ``SK.DS_TILES`` dataset.

    Returns:
        The result of indexing ``SK.DS_TILES`` at ``idx`` (presumably an
        ``H x W x 3`` uint8 patch -- confirm against the store schema).

    """
    store = slide.store
    with store.open_slide_store_file_readonly(slide.id_) as handle:
        tiles = handle[SK.DS_TILES]
        patch = tiles[idx]
    return patch


def _read_patch_by_point(slide: Slide, x_px: int, y_px: int) -> np.ndarray:
    """Look up the stored patch whose tile covers a pixel coordinate.

    The coordinate is floor-divided by the tile size to get a tile row/column,
    which the slide's tile index converts to a flat index.

    Args:
        slide: Slide providing ``tile_index`` and the patch store.
        x_px: X coordinate in slide pixel space.
        y_px: Y coordinate in slide pixel space.

    Returns:
        The patch returned by ``_read_patch_by_idx`` for the resolved index.

    """
    tile_index = slide.tile_index
    side = tile_index.tile_size
    row, col = y_px // side, x_px // side
    rows = np.array([row], np.int32)
    cols = np.array([col], np.int32)
    flat = tile_index.rowcol_to_idx(rows, cols)
    return _read_patch_by_idx(slide, int(flat[0]))


def _thumbnail(slide: Slide, target_h: int) -> np.ndarray:
    """Request a thumbnail of a given height that keeps the slide's aspect ratio.

    Args:
        slide: Slide whose ``slide_reader`` supplies dimensions and the thumbnail.
        target_h: Desired thumbnail height in pixels.

    Returns:
        Whatever ``slide_reader.get_thumbnail`` returns for the size
        ``(target_w, target_h)``, where ``target_w`` is the slide width scaled
        by ``target_h / height`` and rounded to the nearest integer.

    """
    reader = slide.slide_reader
    full_w, full_h = reader.width, reader.height
    target_w = int(round(full_w * (target_h / full_h)))
    return reader.get_thumbnail((target_w, target_h))


def plot_patch(slide: Slide, x: int, y: int, ax: Axes | None = None):
    """Draw the stored patch that covers a slide pixel coordinate.

    Args:
        slide: Slide to read the patch from.
        x: X coordinate in slide pixel space.
        y: Y coordinate in slide pixel space.
        ax: Axis to draw on; a new 6x6 inch figure is created when omitted.

    Returns:
        The newly created Figure, or None when ``ax`` was supplied.

    """
    tile = _read_patch_by_point(slide, x, y)
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(tile)
    ax.set_title(f"Patch containing ({x}, {y})")
    return fig


def plot_sample_patches(
    slide: Slide,
    n_patches: int,
    patches_per_row: int = 5,
    seed: int | None = None,
):
    """Show a grid of randomly chosen foreground patches.

    Args:
        slide: Slide providing the tile index and patch store.
        n_patches: Requested sample size; capped at ``n_foreground_tiles``.
        patches_per_row: Number of columns in the display grid.
        seed: Seed for ``np.random.default_rng``; None gives a fresh stream.

    Returns:
        Matplotlib Figure whose panels are titled with the ``x, y`` values
        reported by ``tile_index.idx_to_xy``; surplus panels are switched off.

    """
    tile_index = slide.tile_index
    generator = np.random.default_rng(seed)
    n_available = tile_index.n_foreground_tiles
    sample = generator.choice(n_available, size=min(n_patches, n_available), replace=False)
    # Indices are sorted in place before they are used to index the dataset.
    sample.sort()
    n_sampled = len(sample)

    grid_rows = math.ceil(n_sampled / patches_per_row)
    fig, axes = plt.subplots(
        grid_rows,
        patches_per_row,
        figsize=(4 * patches_per_row, 4 * grid_rows),
    )
    panels = list(np.atleast_1d(axes).flat)

    with slide.store.open_slide_store_file_readonly(slide.id_) as handle:
        tiles = handle[SK.DS_TILES][sample, ...]
    xs, ys = tile_index.idx_to_xy(sample.astype(np.int32))

    for pos, panel in enumerate(panels[:n_sampled]):
        panel.imshow(tiles[pos])
        panel.set_title(f"{int(xs[pos])}, {int(ys[pos])}")
    for panel in panels[n_sampled:]:
        panel.axis("off")

    fig.suptitle(f"Slide ID: {slide.id_}")
    return fig


def plot_patch_with_context(
    slide: Slide,
    patch_x: int,
    patch_y: int,
    patch_size: int,
    neighborhood_size: int = 1,
    show_full_slide: bool = True,
):
    """Render the tiles surrounding the tile that covers a coordinate.

    Args:
        slide: Slide providing the tile index and pixel reader.
        patch_x: X coordinate in slide pixel space.
        patch_y: Y coordinate in slide pixel space.
        patch_size: Accepted for interface compatibility; not read by this function.
        neighborhood_size: Radius k in tiles; the block spans up to
            (2k+1)x(2k+1) tiles, clipped to the tile grid.
        show_full_slide: If True, add a second pane with a thumbnail on which
            the neighborhood is marked.

    Returns:
        Matplotlib Figure with one pane, or two when ``show_full_slide`` is True.

    """
    ti = slide.tile_index
    side = ti.tile_size
    k = neighborhood_size

    # Tile row/column of the requested point and the clipped neighborhood window.
    row, col = patch_y // side, patch_x // side
    row_lo, row_hi = max(row - k, 0), min(row + k + 1, ti.n_tile_rows)
    col_lo, col_hi = max(col - k, 0), min(col + k + 1, ti.n_tile_cols)
    left, top = col_lo * side, row_lo * side
    width = (col_hi - col_lo) * side
    height = (row_hi - row_lo) * side

    # Copy the region before drawing the outline of the centre tile onto it.
    block_img = slide.slide_reader.read_region(left, top, width, height).copy()
    inner_x = (col - col_lo) * side
    inner_y = (row - row_lo) * side
    cv2.rectangle(
        block_img,
        (inner_x, inner_y),
        (inner_x + side, inner_y + side),
        (0, 0, 255),
        2,
    )

    if not show_full_slide:
        fig, ax = plt.subplots(1, 1, figsize=(9, 9))
        ax.imshow(block_img)
        ax.set_title(f"Neighborhood k={neighborhood_size}")
        return fig

    fig, (ax_block, ax_context) = plt.subplots(
        1, 2, figsize=(18, 9), gridspec_kw={"width_ratios": [1, 2]}
    )
    ax_block.imshow(block_img)
    ax_block.set_title(f"Neighborhood k={neighborhood_size}")

    thumb = _thumbnail(slide, target_h=5000)
    ax_context.imshow(thumb)
    # Scale factors from slide pixels to thumbnail pixels.
    scale_y = thumb.shape[0] / slide.slide_reader.height
    scale_x = thumb.shape[1] / slide.slide_reader.width
    marker = Rectangle(
        (left * scale_x, top * scale_y),
        width * scale_x,
        height * scale_y,
        facecolor="green",
        edgecolor="green",
        linewidth=4,
        alpha=0.2,
    )
    ax_context.add_patch(marker)
    ax_context.set_title("Context")
    return fig


def plot_slide(
    slide: Slide,
    target_height: int | None = None,
    plot_segmentation_mask: bool = False,
    remove_background: bool = True,
    ax: Axes | list[Axes] | None = None,
):
    """Render a slide thumbnail with optional segmentation overlay/cropping.

    Args:
        slide: Slide object.
        target_height: Target thumbnail height; if None, a thumbnail at the
            slide's full ``width x height`` is requested and a warning is logged.
        plot_segmentation_mask: If True, draw two extra panes: the mask blended
            over the thumbnail and the raw mask.
        remove_background: If True, crop to the TileIndex foreground bounding box.
        ax: Optional axis to draw on. When ``plot_segmentation_mask`` is True
            this must be an indexable collection of at least three axes.
            When omitted a new Figure is created.

    Returns:
        The created Figure, or None when ``ax`` is provided.

    """
    sr = slide.slide_reader
    w, h = sr.width, sr.height
    if target_height is not None:
        thumb = _thumbnail(slide, target_h=target_height)
    else:
        logger.warning("No target height specified. Plotting at native resolution may be slow.")
        thumb = sr.get_thumbnail((w, h))

    ti = slide.tile_index
    thumb_h, thumb_w = thumb.shape[:2]
    # Scale factors from slide pixels to thumbnail pixels.
    sy = thumb_h / h
    sx = thumb_w / w

    if remove_background:
        r0, r1, c0, c1 = ti.foreground_bbox_rowcol()
        y0, y1 = round(r0 * ti.tile_size * sy), round(r1 * ti.tile_size * sy)
        x0, x1 = round(c0 * ti.tile_size * sx), round(c1 * ti.tile_size * sx)
        thumb_crop = thumb[y0:y1, x0:x1, :]
    else:
        # Define the crop window for the un-cropped case too, so the mask
        # overlay below can use it (previously these names were unbound here).
        y0, y1, x0, x1 = 0, thumb_h, 0, thumb_w
        thumb_crop = thumb

    if ax is None:
        fig, axes = plt.subplots(
            1,
            1 if not plot_segmentation_mask else 3,
            figsize=(10 if not plot_segmentation_mask else 30, 10),
        )
    else:
        fig, axes = None, ax

    # pane 0: thumbnail
    ax0 = axes if not plot_segmentation_mask else axes[0]
    ax0.imshow(thumb_crop)
    ax0.set_title(f"{slide.id_} ({w}x{h})")

    if plot_segmentation_mask:
        # Expand the mask by repeating each cell by the rounded per-tile
        # thumbnail size, then take the same window as the thumbnail crop.
        seg = ti.segmentation_mask.astype(np.uint8) * 255
        tile_w = max(1, round(ti.tile_size * sx))
        tile_h = max(1, round(ti.tile_size * sy))
        seg_overlay = np.repeat(np.repeat(seg, tile_h, axis=0), tile_w, axis=1)
        seg_overlay = seg_overlay[y0:y1, x0:x1]

        # Rounding the per-tile size and rounding the crop bounds can disagree
        # by a few pixels; cv2.addWeighted needs identical sizes, so snap the
        # mask to the crop with nearest-neighbour resampling when they differ.
        crop_h, crop_w = thumb_crop.shape[:2]
        if seg_overlay.shape[:2] != (crop_h, crop_w):
            seg_overlay = cv2.resize(
                np.ascontiguousarray(seg_overlay),
                (crop_w, crop_h),
                interpolation=cv2.INTER_NEAREST,
            )

        # NOTE(review): OpenCV colour maps are BGR-ordered by convention while
        # imshow expects RGB -- confirm whether the channel swap is intended.
        seg_rgb = cv2.applyColorMap(seg_overlay, cv2.COLORMAP_JET)

        overlay = cv2.addWeighted(thumb_crop, 1 - 0.5, seg_rgb, 0.5, 0)
        axes[1].imshow(overlay)
        axes[1].set_title("Segmentation mask overlaid")
        axes[2].imshow(ti.segmentation_mask)
        axes[2].set_title("Segmentation mask")

    return fig


def plot_logit_heatmaps(
    slide: Slide,
    logit_dict: dict[str, np.ndarray],
    cmap: int = cv2.COLORMAP_JET,
    cmap_alpha: float = 0.35,
    heatmaps_per_row: int = 2,
    target_height: int = 2000,
):
    """Render per-class heatmaps projected onto a cropped thumbnail.

    The first panel shows the cropped thumbnail; each following panel shows one
    entry of ``logit_dict`` min-max normalized to 0-255, colorized, and blended
    over that thumbnail. Panels left over in the last grid row are switched off.

    Args:
        slide: Slide whose TileIndex geometry is used for projection.
        logit_dict: Mapping from class name to per-tile logits.
        cmap: OpenCV colormap used for heatmap coloring.
        cmap_alpha: Opacity of the heatmap overlay.
        heatmaps_per_row: Grid width for arranging class heatmaps.
        target_height: Thumbnail target height used for projection/cropping.

    Returns:
        Matplotlib Figure containing the base thumbnail and class heatmaps.

    """
    ti = slide.tile_index
    base = _thumbnail(slide, target_h=target_height)
    w_full, h_full = slide.slide_reader.width, slide.slide_reader.height

    # Crop thumbnail to foreground bbox, using the rounded per-tile thumbnail size.
    r0, r1, c0, c1 = ti.foreground_bbox_rowcol()
    tile_w = max(1, round(ti.tile_size * (base.shape[1] / w_full)))
    tile_h = max(1, round(ti.tile_size * (base.shape[0] / h_full)))
    im_resized = base[r0 * tile_h : r1 * tile_h, c0 * tile_w : c1 * tile_w]

    n_cols = heatmaps_per_row
    n_rows = math.ceil((len(logit_dict) + 1) / n_cols)
    fig_height = 10 * n_rows
    # int() truncation can reach 0 for a very tall, narrow crop; matplotlib
    # needs a positive figure width.
    fig_width = max(1, int((im_resized.shape[1] / im_resized.shape[0]) * fig_height))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height))
    panels = np.atleast_1d(axes).ravel()

    panels[0].imshow(im_resized)
    panels[0].set_title("Cropped Thumbnail")

    for panel, (name, vals) in zip(panels[1:], logit_dict.items()):
        # project to cropped thumbnail space
        overlay = ti.project_to_thumbnail(
            vals.astype(np.float32), thumb_w=base.shape[1], thumb_h=base.shape[0], crop=True
        )
        if overlay.ndim != 2:  # keep 2D
            overlay = overlay.squeeze()
        # normalize to 0-255 and colorize
        norm = cv2.normalize(overlay, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        color = cv2.applyColorMap(norm, cmap)
        superimposed = cv2.addWeighted(im_resized, 1 - cmap_alpha, color, cmap_alpha, 0)
        panel.imshow(superimposed)
        panel.set_title(name)

    # Hide grid cells that received no heatmap, as plot_sample_patches does.
    for panel in panels[len(logit_dict) + 1 :]:
        panel.axis("off")

    return fig


def plot_patch_value_heatmap(
    slide: Slide,
    value_arr: np.ndarray,
    target_height: int = 3000,
    ax: Axes | list[Axes] | None = None,
    discrete: bool = False,
    color_scale: Literal["linear", "log"] = "linear",
):
    """Render per-tile scalar values as a color image in cropped-thumbnail space.

    Args:
        slide: Slide object.
        value_arr: Per-tile scalar values passed to
            ``tile_index.project_to_thumbnail`` (after a float32 cast).
        target_height: Target thumbnail height for projection.
        ax: Optional axis or pair of axes. When a list, tuple, or array of two
            axes is given, the first shows the cropped thumbnail and the second
            shows the colorized values.
        discrete: If True, give each distinct value in ``value_arr`` its own
            color; pixels whose value is not one of those (e.g. NaN) are black.
        color_scale: Scale for continuous values ("linear" or "log").

    Returns:
        Tuple of (Figure | None, full_thumbnail, cropped_thumbnail, overlay_array).
        The Figure is None whenever ``ax`` is supplied.

    Raises:
        ValueError: If ``color_scale`` is not "linear" or "log", or if "log" is
            requested and the projected values contain nothing positive.

    """
    ti = slide.tile_index
    base = _thumbnail(slide, target_h=target_height)

    # project to cropped thumbnail
    overlay = ti.project_to_thumbnail(
        value_arr.astype(np.float32),
        thumb_w=base.shape[1],
        thumb_h=base.shape[0],
        crop=True,
    )
    # compute cropped thumbnail (same bbox)
    r0, r1, c0, c1 = ti.foreground_bbox_rowcol()
    tile_w = max(1, round(ti.tile_size * (base.shape[1] / slide.slide_reader.width)))
    tile_h = max(1, round(ti.tile_size * (base.shape[0] / slide.slide_reader.height)))
    thumb_crop = base[r0 * tile_h : r1 * tile_h, c0 * tile_w : c1 * tile_w]

    if discrete:
        # Calling a ListedColormap on a float array treats the values as
        # fractions in [0, 1], so labels >= 1 would all get the last color.
        # Map each value to its rank among the distinct labels and index the
        # palette directly instead.
        labels = np.unique(value_arr.astype(np.float32))
        labels = labels[~np.isnan(labels)]
        cm = np.zeros((*overlay.shape, 3), dtype=np.uint8)
        if labels.size:
            colors = distinctipy.get_colors(len(labels), rng=42)
            palette = (np.asarray(colors, dtype=np.float64) * 255).astype(np.uint8)
            codes = np.clip(np.searchsorted(labels, overlay), 0, labels.size - 1)
            # Exact match assumes the projection copies values unchanged
            # (no interpolation) -- TODO confirm against project_to_thumbnail.
            known = labels[codes] == overlay
            cm[known] = palette[codes[known]]
    else:
        if color_scale == "linear":
            norm = Normalize(vmin=np.nanmin(overlay), vmax=np.nanmax(overlay))
        elif color_scale == "log":
            positive = overlay[overlay > 0]
            if positive.size == 0:
                msg = "color_scale='log' requires at least one positive value"
                raise ValueError(msg)
            norm = LogNorm(vmin=np.nanmin(positive), vmax=np.nanmax(overlay))
        else:
            msg = f"Invalid color scale: {color_scale}"
            raise ValueError(msg)
        # NOTE(review): matplotlib.cm.get_cmap is deprecated in recent
        # Matplotlib releases -- consider matplotlib.colormaps["jet"].
        cm = get_cmap("jet")(norm(overlay))
        cm = (cm * 255).astype(np.uint8)[:, :, :3]

    # fig stays None unless this function creates the figure itself.
    fig = None
    if ax is None:
        fig, ax_hm = plt.subplots(figsize=(10, 10))
    elif isinstance(ax, (list, tuple, np.ndarray)):
        ax_im, ax_hm = ax
        ax_im.imshow(thumb_crop)
    else:
        ax_hm = ax

    ax_hm.imshow(cm)
    return fig, base, thumb_crop, overlay


def create_dynamic_thresholded_logit_plot(
    slide: Slide,
    logits: np.ndarray,
    target_height: int = 2000,
):
    """Interactive thresholding on a single logit map with clickable patch retrieval.

    The left pane shows the cropped thumbnail with every pixel whose projected
    value is >= the threshold tinted red; a slider changes the threshold.
    Clicking the left pane loads the stored patch for the clicked tile into the
    right pane.

    Args:
        slide: Slide to visualize.
        logits: Per-tile logit values passed to ``tile_index.project_to_thumbnail``.
        target_height: Target thumbnail height for projection.

    Returns:
        Matplotlib Figure with the click handler connected. A ``FloatSlider``
        is also shown through ``IPython.display.display``.

    """
    ti = slide.tile_index
    base = _thumbnail(slide, target_h=target_height)

    # project & crop
    overlay = ti.project_to_thumbnail(
        logits.astype(np.float32),
        thumb_w=base.shape[1],
        thumb_h=base.shape[0],
        crop=True,
    )
    r0, r1, c0, c1 = ti.foreground_bbox_rowcol()
    tile_w = max(1, round(ti.tile_size * (base.shape[1] / slide.slide_reader.width)))
    tile_h = max(1, round(ti.tile_size * (base.shape[0] / slide.slide_reader.height)))
    crop_thumb = base[r0 * tile_h : r1 * tile_h, c0 * tile_w : c1 * tile_w]

    # Value range, computed once and shared by the initial view and the slider.
    ov_min = float(np.nanmin(overlay))
    ov_max = float(np.nanmax(overlay))
    init_thr = ov_min

    def blend(th):
        """Return the thumbnail blended with a red mask where overlay >= th."""
        heat = np.zeros((*overlay.shape, 3), dtype=np.uint8)
        heat[overlay >= th] = [255, 0, 0]
        return cv2.addWeighted(crop_thumb, 0.65, heat, 0.35, 0)

    fig, axes = plt.subplots(1, 2, figsize=(20, 17))
    im0 = axes[0].imshow(blend(init_thr))
    title = axes[0].set_title(f"Threshold: {init_thr:.2f}")
    # Placeholder image, replaced when a tile is clicked.
    patch_im = axes[1].imshow(np.zeros((ti.tile_size, ti.tile_size, 3), dtype=np.uint8))

    def update_plot(th):
        im0.set_data(blend(th))
        title.set_text(f"Threshold: {th:.2f}")
        fig.canvas.draw_idle()

    def onclick(event):
        if event.inaxes is not axes[0] or event.xdata is None or event.ydata is None:
            return
        # imshow puts pixel centres on integer coordinates, so pixel j spans
        # [j - 0.5, j + 0.5); round to the nearest pixel instead of truncating.
        x = math.floor(event.xdata + 0.5)
        y = math.floor(event.ydata + 0.5)
        col = c0 + (x // tile_w)
        row = r0 + (y // tile_h)
        try:
            idx = int(ti.rowcol_to_idx(np.array([row], np.int32), np.array([col], np.int32))[0])
        except Exception:
            # Deliberate best effort: a click that cannot be resolved to a tile
            # index (presumably a non-foreground tile -- confirm) is ignored.
            return
        patch = _read_patch_by_idx(slide, idx)
        patch_im.set_data(patch)
        fig.canvas.draw_idle()

    span = ov_max - ov_min
    slider = widgets.FloatSlider(
        value=init_thr,
        min=ov_min,
        max=ov_max,
        # Avoid a zero step when every projected value is identical.
        step=span / 100 if span > 0 else 1.0,
        description="Threshold:",
    )
    slider.observe(lambda ch: update_plot(ch["new"]), names="value")
    from IPython.display import display

    display(slider)
    fig.canvas.mpl_connect("button_press_event", onclick)
    fig.tight_layout()
    return fig


def overlay_as_heatmap(
    overlay: np.ndarray,
    *,
    slide_thumbnail: np.ndarray,
    cmap: int = cv2.COLORMAP_JET,
    alpha: float = 0.35,
    crop_all_zeros: bool = False,
    return_cropped_thumbnail: bool = False,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Project a low-res overlay onto a slide thumbnail as a color heatmap.

    Args:
        overlay: 2D per-tile or coarse grid values (H_ov, W_ov).
        slide_thumbnail: Full thumbnail image (H, W, 3) uint8.
        cmap: OpenCV colormap (e.g., ``cv2.COLORMAP_JET``).
        alpha: Heatmap opacity when blending with the thumbnail.
        crop_all_zeros: If True, crop to bounding box of non-zero overlay.
        return_cropped_thumbnail: If True, also return the cropped thumbnail.

    Returns:
        (blended_image, cropped_thumbnail | None)

    Raises:
        ValueError: If ``overlay`` is not 2D or ``slide_thumbnail`` is not HxWx3.

    Notes:
        - The overlay is resized to the thumbnail size with nearest-neighbor
          interpolation.
        - When ``crop_all_zeros`` is True, both overlay and thumbnail are cropped
          consistently to the non-zero mask bounding box. If the overlay has no
          non-zero entry, nothing is cropped.
        - The value range is computed ignoring NaNs; NaN cells are colored like
          the minimum value.
    """
    if overlay.ndim != 2:
        msg = "overlay must be 2D"
        raise ValueError(msg)
    if slide_thumbnail.ndim != 3 or slide_thumbnail.shape[2] != 3:
        msg = "slide_thumbnail must be an HxWx3 image"
        raise ValueError(msg)

    thumb_h, thumb_w = slide_thumbnail.shape[:2]
    ov = cv2.resize(
        overlay.astype(np.float32),
        (thumb_w, thumb_h),
        interpolation=cv2.INTER_NEAREST,
    )

    crop_thumb = slide_thumbnail
    if crop_all_zeros:
        # Bounding box of non-zero cells from row/column occupancy, which avoids
        # building a coordinate array with one entry per non-zero pixel.
        nonzero = ov != 0
        rows = np.flatnonzero(nonzero.any(axis=1))
        cols = np.flatnonzero(nonzero.any(axis=0))
        if rows.size:
            y0, y1 = rows[0], rows[-1] + 1
            x0, x1 = cols[0], cols[-1] + 1
            ov = ov[y0:y1, x0:x1]
            crop_thumb = slide_thumbnail[y0:y1, x0:x1]
        # With no signal the overlay is all zeros; the flat-range branch below
        # then yields the uncropped thumbnail blended with an all-zero map.

    # Normalize overlay to 0..255 for colormap
    nan_mask = np.isnan(ov)
    norm = np.zeros(ov.shape, dtype=np.uint8)
    if not nan_mask.all():
        vmin = float(np.nanmin(ov))
        vmax = float(np.nanmax(ov))
        if vmax > vmin:
            if nan_mask.any():
                # The range above ignores NaNs; replace them so the min-max
                # normalization sees the same range and maps them to 0.
                ov = np.where(nan_mask, np.float32(vmin), ov)
            norm = cv2.normalize(ov, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    color = cv2.applyColorMap(norm, cmap)

    blended = cv2.addWeighted(crop_thumb, 1 - alpha, color, alpha, 0)
    return blended, crop_thumb if return_cropped_thumbnail else None


def plot_patch_cluster_grid(
    *,
    patch_cluster_assignment_grid: np.ndarray,
    ax: Axes | None = None,
    crop_nans: bool = True,
    color_rng: int = 42,
):
    """Draw a 2D grid of cluster labels as a color image.

    NaN cells stay white; every other cell is painted with the color assigned
    to its (integer-cast) label by ``distinctipy.get_colors``.

    Args:
        patch_cluster_assignment_grid: 2D array of labels with NaN as background.
        ax: Axes to draw on; a new figure is created when omitted.
        crop_nans: If True, trim the grid to the bounding box of non-NaN cells.
        color_rng: Seed forwarded to ``distinctipy.get_colors``.

    Returns:
        The newly created Figure, or None when ``ax`` was supplied.

    Raises:
        ValueError: If the grid is not 2D.
    """
    cells = patch_cluster_assignment_grid
    if cells.ndim != 2:
        msg = "patch_cluster_assignment_grid must be 2D"
        raise ValueError(msg)

    valid = ~np.isnan(cells)
    has_labels = bool(np.any(valid))
    if crop_nans and has_labels:
        rows, cols = np.where(valid)
        window = (
            slice(rows.min(), rows.max() + 1),
            slice(cols.min(), cols.max() + 1),
        )
        cells, valid = cells[window], valid[window]

    # Start from an all-white canvas and paint one label at a time.
    n_rows, n_cols = cells.shape
    canvas = np.full((n_rows, n_cols, 3), 255, dtype=np.uint8)
    if has_labels:
        labels = np.unique(cells[valid]).astype(int).tolist()
        palette = distinctipy.get_colors(len(labels), rng=color_rng)
        for label, color in zip(labels, palette):
            canvas[cells == label] = (np.array(color) * 255).astype(np.uint8)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(max(4, n_cols / 10), max(4, n_rows / 10)))
    ax.imshow(canvas)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("Patch cluster assignments")
    return fig


# Names exported by a star-import of this module.
# NOTE(review): only these two functions are listed; the other non-underscore
# helpers defined in this file (e.g. ``plot_slide``) are not exported by
# ``import *`` -- confirm that this is intentional.
__all__ = [
    "overlay_as_heatmap",
    "plot_patch_cluster_grid",
]
