"""UMAP dimensionality reduction for latent space visualization.

This module provides utilities for fitting UMAP models on latent embeddings
and projecting them to a low-dimensional space (2D by default) for
visualization.
"""

from __future__ import annotations
from typing import Tuple
import numpy as np
from umap import UMAP


def _maybe_pool_latents(z: np.ndarray) -> np.ndarray:
    """Collapse per-token latents into one vector per sample when required.

    Args:
        z: Latent array of shape (N, d) or (N, K, d).

    Returns:
        A 2D array of shape (N, d). A 2D input is handed back as-is (same
        object, no copy); a 3D input is averaged over its second axis.

    Raises:
        ValueError: If ``z`` is neither 2D nor 3D.
    """
    rank = z.ndim

    # Already one vector per sample: nothing to do.
    if rank == 2:
        return z

    # (N, K, d) -> (N, d): average over the K token axis.
    if rank == 3:
        return z.mean(axis=1)

    raise ValueError(f"Expected 2D or 3D array, got shape {z.shape}")


def fit_umap(
    z: np.ndarray,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    n_components: int = 2,
    metric: str = "euclidean",
    seed: int = 42,
) -> UMAP:
    """Build a UMAP reducer and fit it to the given latents.

    Args:
        z: Latent array of shape (N, d) or (N, K, d). A 3D input is
            mean-pooled over the K axis before fitting.
        n_neighbors: Passed to UMAP as ``n_neighbors``.
        min_dist: Passed to UMAP as ``min_dist``.
        n_components: Passed to UMAP as ``n_components`` (output
            dimensionality; defaults to 2).
        metric: Passed to UMAP as ``metric``.
        seed: Passed to UMAP as ``random_state``.

    Returns:
        The fitted UMAP model.

    Raises:
        ValueError: If ``z`` is neither 2D nor 3D.
    """
    # Validate/pool the input first so a bad shape fails before any model
    # object is constructed.
    features = _maybe_pool_latents(z)

    hyperparams = {
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "n_components": n_components,
        "metric": metric,
        "random_state": seed,
    }
    reducer = UMAP(**hyperparams)
    reducer.fit(features)
    return reducer


def transform_umap(umap_model: UMAP, z: np.ndarray) -> np.ndarray:
    """Project latents with an already-fitted UMAP model.

    Args:
        umap_model: Fitted UMAP model whose ``transform`` method is used.
        z: Latent array of shape (N, d) or (N, K, d). A 3D input is
            mean-pooled over the K axis before projection.

    Returns:
        The result of ``umap_model.transform`` on the pooled latents,
        documented as an array of shape (N, n_components).

    Raises:
        ValueError: If ``z`` is neither 2D nor 3D.
    """
    features = _maybe_pool_latents(z)
    projected = umap_model.transform(features)
    return projected


def fit_transform_umap(
    z: np.ndarray,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    n_components: int = 2,
    metric: str = "euclidean",
    seed: int = 42,
) -> Tuple[np.ndarray, UMAP]:
    """Fit a UMAP reducer on the latents and return their projection.

    Args:
        z: Latent array of shape (N, d) or (N, K, d). A 3D input is
            mean-pooled over the K axis before fitting.
        n_neighbors: Passed to UMAP as ``n_neighbors``.
        min_dist: Passed to UMAP as ``min_dist``.
        n_components: Passed to UMAP as ``n_components`` (output
            dimensionality; defaults to 2).
        metric: Passed to UMAP as ``metric``.
        seed: Passed to UMAP as ``random_state``.

    Returns:
        A ``(projection, model)`` pair: the output of ``fit_transform`` on
        the pooled latents (documented as shape (N, n_components)) and the
        fitted UMAP model that produced it.

    Raises:
        ValueError: If ``z`` is neither 2D nor 3D.
    """
    # Validate/pool the input first so a bad shape fails before any model
    # object is constructed.
    features = _maybe_pool_latents(z)

    hyperparams = {
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "n_components": n_components,
        "metric": metric,
        "random_state": seed,
    }
    reducer = UMAP(**hyperparams)
    projected = reducer.fit_transform(features)
    return projected, reducer
