"""Clustering visualization utilities.

Decoupled from analysis; pure viz helpers consuming computed labels and data.
"""

from __future__ import annotations

import json
import logging
import math
from typing import TYPE_CHECKING, Literal

import distinctipy
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import silhouette_samples

from pathfmtools.io.schema import StoreKeys as SK

from .slide import plot_patch_cluster_grid

if TYPE_CHECKING:
    from pathlib import Path

    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from pathfmtools.image.slide_group import SlideGroup


def plot_patch_cluster_membership(
    slide_group: SlideGroup,
    labels: dict[str, np.ndarray],
    *,
    slide_id: str | None = None,
    ax: Axes | None = None,
    n_slides_per_row: int = 10,
    color_rng: int = 42,
) -> Figure | None:
    """Plot per-patch cluster assignments for a slide or all slides.

    Args:
        slide_group: The group of slides containing data.
        labels: Mapping from slide_id -> flat per-patch labels (1D arrays).
        slide_id: If provided, plot only this slide; otherwise plot a grid of all slides.
        ax: Target axes when plotting a single slide. Ignored in grid mode, where a
            new figure with one subplot per slide is created.
        n_slides_per_row: Number of subplot columns in grid mode.
        color_rng: Seed for distinct color map.

    Returns:
        In grid mode, the newly created figure. In single-slide mode, whatever
        `plot_patch_cluster_grid` returns for the given `ax`.

    Raises:
        ValueError: In grid mode, if `n_slides_per_row` is not positive or the
            slide group is empty.

    """
    if slide_id is None:
        if n_slides_per_row <= 0:
            msg = "n_slides_per_row must be positive"
            raise ValueError(msg)
        n_slides = len(slide_group)
        if n_slides == 0:
            msg = "slide_group contains no slides to plot"
            raise ValueError(msg)

        n_cols = n_slides_per_row
        n_rows = math.ceil(n_slides / n_cols)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
        # plt.subplots returns a bare Axes for a 1x1 layout; normalise to a flat array.
        flat_axes = np.atleast_1d(axes).ravel()
        for s_ix, slide in enumerate(slide_group):
            plot_patch_cluster_membership(
                slide_group,
                labels,
                slide_id=slide.id_,
                ax=flat_axes[s_ix],
                color_rng=color_rng,
            )
        # The grid is rectangular, so the last row may have cells without a slide;
        # hide them instead of leaving empty framed axes in the figure.
        for unused_ax in flat_axes[n_slides:]:
            unused_ax.axis("off")
        return fig

    slide = slide_group[slide_id]
    label_arr = labels[slide_id]
    # Convert labels to float to get NaN in background when mapping to grid
    # (to support cropping and correct background handling in visualization).
    patch_cluster_grid = slide.tile_index.to_grid(label_arr.astype(float))
    return plot_patch_cluster_grid(
        patch_cluster_assignment_grid=patch_cluster_grid,
        ax=ax,
        color_rng=color_rng,
    )


def silhouette_plot(
    data: np.ndarray,
    labels: np.ndarray,
    *,
    ax: Axes | None = None,
    sample_points: bool = True,
    sample_size: int = 10_000,
    seed: int = 42,
) -> Figure | None:
    """Render a silhouette plot for clustering quality.

    Adapted from scikit-learn's example implementation.

    Args:
        data: Samples passed to `silhouette_samples`; indexed along axis 0 by the
            same positions as `labels`.
        labels: Cluster label per sample. Any set of label values is accepted
            (they need not be contiguous or start at zero).
        ax: Axes to draw on. When omitted a new figure/axes pair is created.
        sample_points: If True, draw at most `sample_size` points per cluster
            (without replacement) before computing silhouette values.
        sample_size: Per-cluster cap used when `sample_points` is True.
        seed: Seed for both the subsampling RNG and the colour palette.

    Returns:
        The created figure when `ax` was not supplied, otherwise None.

    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    # Work with the label values that are actually present. Iterating
    # range(n_clusters) would silently drop clusters whenever labels are not
    # exactly 0..k-1 (e.g. a -1 noise label or gaps in the ids).
    cluster_values = np.unique(labels).tolist()
    n_clusters = len(cluster_values)

    if sample_points:
        rng = np.random.RandomState(seed)
        per_cluster_idxs = []
        for cluster_value in cluster_values:
            member_idxs = np.where(labels == cluster_value)[0]
            per_cluster_idxs.append(
                rng.choice(
                    member_idxs,
                    size=min(sample_size, member_idxs.size),
                    replace=False,
                ),
            )
        sample_idxs = np.concatenate(per_cluster_idxs)
        data_sample = data[sample_idxs]
        labels_sample = labels[sample_idxs]
    else:
        data_sample = data
        labels_sample = labels

    silhouette_vals: np.ndarray = silhouette_samples(data_sample, labels_sample)
    silhouette_avg: float = float(np.mean(silhouette_vals))

    # Stack one filled band per cluster along y, separated by a 10-unit gap.
    y_lower = 10
    colors = distinctipy.get_colors(n_clusters, rng=seed)
    for color, cluster_value in zip(colors, cluster_values):
        # Boolean-mask indexing yields a copy, so the in-place sort is local.
        cluster_silhouette_scores = silhouette_vals[labels_sample == cluster_value]
        cluster_silhouette_scores.sort()

        size_cluster_i = cluster_silhouette_scores.shape[0]
        y_upper = y_lower + size_cluster_i

        ax.fill_betweenx(
            np.arange(y_lower, y_upper),
            0,
            cluster_silhouette_scores,
            facecolor=color,
            edgecolor=color,
            alpha=0.7,
        )
        # Label the band at its vertical midpoint, just left of x=0.
        ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(cluster_value))
        y_lower = y_upper + 10

    ax.set_title("The silhouette plot for the various clusters.")
    ax.set_xlabel("The silhouette coefficient values")
    ax.set_ylabel("Cluster label")
    # Dashed red line marks the mean silhouette value over the plotted samples.
    ax.axvline(x=silhouette_avg, color="red", linestyle="--")
    ax.set_yticks([])
    ax.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])

    return fig


# Public API of this module (kept alphabetically sorted).
__all__ = [
    "plot_patch_cluster_membership",
    "representative_patches_grid",
    "silhouette_plot",
]


def _select_nearest_patches(
    embeddings: np.ndarray,
    labels: np.ndarray,
    centroids: np.ndarray,
    cluster_list: list[int],
    *,
    n_patches_per_cluster: int,
    seed: int,
) -> dict[int, np.ndarray]:
    """Pick, for each displayable cluster, the embedding rows nearest its centroid.

    Returns a mapping from position in `cluster_list` to the selected flat
    embedding indices, ordered nearest first. Clusters with fewer than
    `n_patches_per_cluster` members are logged as a warning and omitted.

    Raises:
        ValueError: If a cluster with enough members has an id that is not a
            valid non-negative row index into `centroids`.

    """
    logger = logging.getLogger(__name__)
    rng = np.random.RandomState(seed)
    n_centroids = centroids.shape[0]

    selected: dict[int, np.ndarray] = {}
    for block_ix, cluster_id in enumerate(cluster_list):
        member_idxs = np.where(labels == cluster_id)[0]
        if member_idxs.size < n_patches_per_cluster:
            logger.warning(
                "Cluster %s has only %s patches (< %s). Skipping.",
                cluster_id,
                member_idxs.size,
                n_patches_per_cluster,
            )
            continue

        # Explicit range check: plain indexing would let a negative id (e.g. a
        # -1 noise label) wrap around to the last centroid without any error.
        is_valid_row = isinstance(cluster_id, (int, np.integer)) and 0 <= cluster_id < n_centroids
        if not is_valid_row:
            msg = (
                f"Centroid index {cluster_id} not present; ensure centroids align with label "
                f"ids."
            )
            raise ValueError(msg)

        # Distances to centroid with tiny seeded jitter for deterministic tie-breaks
        distances = np.linalg.norm(embeddings[member_idxs] - centroids[cluster_id], axis=1)
        distances = distances + 1e-12 * rng.random(size=distances.shape[0])
        order = np.argsort(distances, kind="mergesort")  # stable
        selected[block_ix] = member_idxs[order[:n_patches_per_cluster]]
    return selected


def representative_patches_grid(
    slide_group: SlideGroup,
    embeddings: np.ndarray,
    labels: np.ndarray,
    centroids: np.ndarray,
    *,
    n_patches_per_cluster: int,
    embedding_type: Literal[
        "patch_feature_embeddings",
        "patch_zeroshot_embeddings",
    ] = "patch_feature_embeddings",
    cluster_ids: list[int] | None = None,
    n_clusters_per_row: int = 3,
    annotate_slide_ids: bool = True,
    cluster_text: dict[int, str] | None = None,
    seed: int = 42,
    save_fig_path: Path | None = None,
    save_map_path: Path | None = None,
) -> tuple[Figure, dict[int, str]]:
    """Render nearest-patch grids per cluster.

    Selects `n_patches_per_cluster` patches with smallest Euclidean distance to each
    cluster centroid, then renders a square grid of the corresponding patches.

    Args:
        slide_group: Source of slide ids and of the mapping from a flat embedding
            index back to a (slide, patch index) pair.
        embeddings: 2D array (N, D) of patch embeddings.
        labels: 1D array (N,) of cluster ids, one per embedding row.
        centroids: 2D array whose row `k` is the centroid of cluster id `k`.
        n_patches_per_cluster: Number of nearest patches to show per cluster.
        embedding_type: Forwarded to `slide_group.map_embedding_idx_to_source_patches`.
        cluster_ids: Clusters to display; defaults to every id present in `labels`.
        n_clusters_per_row: Number of cluster blocks per figure row.
        annotate_slide_ids: If True, title each patch with its short slide index.
        cluster_text: Optional extra title line per cluster id.
        seed: Seed for the tie-breaking jitter added to distances.
        save_fig_path: If given, the figure is saved here.
        save_map_path: If given, the short-index -> slide-id mapping is written
            here as JSON.

    Returns the Matplotlib Figure and a mapping of short slide indices to slide IDs
    for easy reverse lookup when `annotate_slide_ids=True`.

    Raises:
        ValueError: On malformed inputs, or when a displayed cluster id is not a
            valid non-negative row index into `centroids`.

    """
    if embeddings.ndim != 2:
        msg = "embeddings must be 2D (N, D)"
        raise ValueError(msg)
    if labels.ndim != 1:
        msg = "labels must be 1D (N,)"
        raise ValueError(msg)
    if embeddings.shape[0] != labels.shape[0]:
        msg = "embeddings and labels must have the same number of samples"
        raise ValueError(msg)
    if centroids.ndim != 2 or centroids.shape[1] != embeddings.shape[1]:
        msg = "centroids must be 2D with same feature dimension as embeddings"
        raise ValueError(msg)
    if n_patches_per_cluster <= 0:
        msg = "n_patches_per_cluster must be positive"
        raise ValueError(msg)
    if n_clusters_per_row <= 0:
        msg = "n_clusters_per_row must be positive"
        raise ValueError(msg)

    # Determine which clusters to display
    cluster_list: list[int] = (
        sorted(cluster_ids) if cluster_ids is not None else sorted(np.unique(labels).tolist())
    )

    # Do all selection (and its validation) before any figure exists, so a bad
    # cluster id fails fast instead of leaving a half-built figure behind.
    chosen_by_block = _select_nearest_patches(
        embeddings,
        labels,
        centroids,
        cluster_list,
        n_patches_per_cluster=n_patches_per_cluster,
        seed=seed,
    )

    # Prepare slide short id mapping (index -> slide_id) and reverse for annotation.
    slide_ids = slide_group.get_slide_ids()
    slide_id_to_short: dict[str, int] = {sid: ix for ix, sid in enumerate(slide_ids)}
    short_to_slide_id: dict[int, str] = {ix: sid for sid, ix in slide_id_to_short.items()}

    # Figure layout: one block per cluster.
    n_blocks = len(cluster_list)
    n_rows = math.ceil(n_blocks / n_clusters_per_row)
    # Cap figure size growth to avoid over-large canvases
    fig_w = min(n_clusters_per_row * 10, 40)
    fig_h = min(n_rows * 10, 40)
    fig = plt.figure(figsize=(fig_w, fig_h))
    outer = fig.add_gridspec(
        n_rows,
        n_clusters_per_row,
        wspace=0.05,
        hspace=0.4 if cluster_text else 0.2,
    )

    # Inner grid for patches (square-ish); identical for every cluster block.
    side = math.ceil(math.sqrt(n_patches_per_cluster))

    for block_ix, chosen in chosen_by_block.items():
        cluster_id = cluster_list[block_ix]
        # Skipped clusters keep their grid cell (left empty), so the position is
        # derived from the index in cluster_list, not from a running counter.
        row, col = divmod(block_ix, n_clusters_per_row)
        inner = outer[row, col].subgridspec(
            side,
            side,
            wspace=0,
            hspace=0.25 if annotate_slide_ids else 0,
        )

        # Optional cluster text/title overlay across the block
        title_ax = fig.add_subplot(inner[:, :])
        title_lines = [f"Cluster {cluster_id}"]
        if cluster_text is not None and cluster_id in cluster_text:
            title_lines.append(cluster_text[cluster_id])
        title_ax.set_title("\n".join(title_lines), fontsize=12, fontweight="bold", pad=20)
        title_ax.axis("off")

        axs = inner.subplots()
        axs = np.atleast_1d(axs)  # type: ignore[reportCallIssue]
        # Draw selected patches
        for ax_ix, ax in enumerate(axs.flat):  # type: ignore[reportAttributeAccessIssue]
            if ax_ix >= chosen.size:
                # side*side may exceed the number of selected patches.
                ax.axis("off")
                continue
            flat_idx = int(chosen[ax_ix])
            slide, patch_idx = slide_group.map_embedding_idx_to_source_patches(
                embedding_idx=flat_idx,
                embedding_type=embedding_type,
            )
            # Read patch pixels directly from the per-slide store
            with slide.store.open_slide_store_file_readonly(slide.id_) as f:
                ax.imshow(f[SK.DS_TILES][patch_idx])
            ax.set(xticks=[], yticks=[])
            ax.grid(False)
            if annotate_slide_ids:
                ax.set_title(str(slide_id_to_short[slide.id_]))

    # Optional side effects: save figure and mapping
    if save_fig_path is not None:
        fig.savefig(save_fig_path, dpi=200, bbox_inches="tight")
    if save_map_path is not None:
        with save_map_path.open("w") as f:
            json.dump(short_to_slide_id, f)

    return fig, short_to_slide_id
