"""UMAP-based latent space analysis and visualization.

This module provides visualization tools for analyzing latent embeddings
using UMAP projections, including comparisons between real and generated samples.
"""

from __future__ import annotations
import os
from pathlib import Path
from typing import Sequence
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from moltenflow.utils.umap import fit_transform_umap, transform_umap, fit_umap


def plot_umap_real_vs_generated(
    z_real: np.ndarray,
    z_generated: np.ndarray,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    alpha: float = 0.6,
    figsize: tuple[int, int] = (10, 8),
    save_path: str | None = None,
) -> plt.Figure:
    """Scatter real and generated latents together in one 2D UMAP projection.

    Three-dimensional inputs are reduced to two dimensions by averaging over
    axis 1 before they are projected.

    Args:
        z_real: Real latent embeddings (N_real, d) or (N_real, K, d)
        z_generated: Generated latent embeddings (N_gen, d) or (N_gen, K, d)
        umap_model: Optional pre-fitted UMAP model used to transform both sets.
            If None, a new model is fitted on the real and generated
            embeddings concatenated together.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        alpha: Point transparency
        figsize: Figure size
        save_path: Optional path to save figure; a missing parent directory
            is created first

    Returns:
        matplotlib Figure
    """
    # Collapse (N, K, d) inputs to (N, d) by averaging over axis 1.
    real = z_real.mean(axis=1) if z_real.ndim == 3 else z_real
    generated = z_generated.mean(axis=1) if z_generated.ndim == 3 else z_generated
    n_real = len(real)

    if umap_model is not None:
        # A supplied model projects each set independently.
        real_2d = transform_umap(umap_model, real)
        generated_2d = transform_umap(umap_model, generated)
    else:
        # Fit once on everything, then cut the projection back into its parts.
        stacked = np.concatenate([real, generated], axis=0)
        projected, umap_model = fit_transform_umap(
            stacked, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )
        real_2d, generated_2d = projected[:n_real], projected[n_real:]

    fig, ax = plt.subplots(figsize=figsize)

    # Real samples are drawn first so generated samples sit on top of them.
    layers = (
        (real_2d, "#2E86AB", f"Real (n={n_real})"),
        (generated_2d, "#E94F37", f"Generated (n={len(generated)})"),
    )
    for coords, color, legend_text in layers:
        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=color,
            label=legend_text,
            alpha=alpha,
            s=20,
            edgecolors="none",
        )

    ax.set_xlabel("UMAP 1", fontsize=12)
    ax.set_ylabel("UMAP 2", fontsize=12)
    ax.set_title("Latent Space: Real vs Generated", fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.set_aspect("equal", adjustable="datalim")

    plt.tight_layout()

    if save_path:
        # A bare filename has no directory component; fall back to the CWD.
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_property_hue(
    z: np.ndarray,
    properties: np.ndarray,
    property_names: Sequence[str] | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    cmap: str = "viridis",
    figsize: tuple[int, int] = (12, 5),
    save_path: str | None = None,
) -> plt.Figure:
    """Draw one UMAP panel per property, colouring points by property value.

    All panels share the same 2D projection; only the colouring differs.

    Args:
        z: Latent embeddings (N, d) or (N, K, d); 3D input is averaged over
            axis 1 before projection
        properties: Property values (N,) or (N, P)
        property_names: Names for each property column. Defaults to
            "Property 0", "Property 1", ...
        umap_model: Optional pre-fitted UMAP model. If None, one is fitted on z.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        cmap: Colormap for property values
        figsize: Base figure size. The figure is created
            ``figsize[0] * P / 2`` wide and ``figsize[1]`` tall, where P is
            the number of property columns.
        save_path: Optional path to save figure; a missing parent directory
            is created first

    Returns:
        matplotlib Figure
    """
    latent = z.mean(axis=1) if z.ndim == 3 else z

    # A single property vector is treated as a one-column matrix.
    if properties.ndim == 1:
        properties = properties.reshape(-1, 1)
    n_props = properties.shape[1]

    if property_names is None:
        property_names = [f"Property {idx}" for idx in range(n_props)]

    if umap_model is not None:
        coords = transform_umap(umap_model, latent)
    else:
        coords, umap_model = fit_transform_umap(
            latent, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )

    # squeeze=False keeps a 2D axes grid even for a single panel.
    width = figsize[0] * n_props / 2
    fig, axes_grid = plt.subplots(1, n_props, figsize=(width, figsize[1]), squeeze=False)
    panel_axes = axes_grid[0]

    # zip stops at the shortest input, so surplus names or panels are ignored.
    for ax, name, column in zip(panel_axes, property_names, properties.T):
        points = ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=column,
            cmap=cmap,
            alpha=0.7,
            s=20,
            edgecolors="none",
        )
        ax.set_xlabel("UMAP 1", fontsize=10)
        ax.set_ylabel("UMAP 2", fontsize=10)
        ax.set_title(name, fontsize=12)
        ax.set_aspect("equal", adjustable="datalim")
        plt.colorbar(points, ax=ax, shrink=0.8)

    plt.tight_layout()

    if save_path:
        # A bare filename has no directory component; fall back to the CWD.
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_with_labels(
    z: np.ndarray,
    labels: np.ndarray,
    label_names: dict[int, str] | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    alpha: float = 0.6,
    figsize: tuple[int, int] = (10, 8),
    save_path: str | None = None,
) -> plt.Figure:
    """Draw a UMAP scatter plot with one colour per categorical label.

    Each distinct value in ``labels`` gets its own colour from seaborn's
    "husl" palette and its own legend entry.

    Args:
        z: Latent embeddings (N, d) or (N, K, d); 3D input is averaged over
            axis 1 before projection
        labels: Integer labels (N,)
        label_names: Optional mapping from label int to display name. Labels
            missing from the mapping are shown as ``str(label)``.
        umap_model: Optional pre-fitted UMAP model. If None, one is fitted on z.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        alpha: Point transparency
        figsize: Figure size
        save_path: Optional path to save figure; a missing parent directory
            is created first

    Returns:
        matplotlib Figure
    """
    latent = z.mean(axis=1) if z.ndim == 3 else z

    if umap_model is not None:
        coords = transform_umap(umap_model, latent)
    else:
        coords, umap_model = fit_transform_umap(
            latent, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )

    distinct = np.unique(labels)
    palette = sns.color_palette("husl", n_colors=len(distinct))

    fig, ax = plt.subplots(figsize=figsize)

    for label, color in zip(distinct, palette):
        if label_names:
            display = label_names.get(label, str(label))
        else:
            display = str(label)

        members = labels == label
        ax.scatter(
            coords[members, 0],
            coords[members, 1],
            # Wrapped in a list so the RGB triple is read as a single colour.
            c=[color],
            label=display,
            alpha=alpha,
            s=20,
            edgecolors="none",
        )

    ax.set_xlabel("UMAP 1", fontsize=12)
    ax.set_ylabel("UMAP 2", fontsize=12)
    ax.set_title("Latent Space by Label", fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.set_aspect("equal", adjustable="datalim")

    plt.tight_layout()

    if save_path:
        # A bare filename has no directory component; fall back to the CWD.
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_pretrain_vs_finetuned(
    z_pretrain: np.ndarray,
    z_finetuned: np.ndarray,
    y_real: np.ndarray,
    property_names: Sequence[str] | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    alpha: float = 0.6,
    figsize: tuple[int, int] = (18, 6),
    save_path: str | None = None,
) -> plt.Figure:
    """Plot pretrained and finetuned latents in one UMAP space, hued by property.

    The figure has two rows (pretrained on top, finetuned below) and one
    column per property. Both embedding sets are transformed with the same
    UMAP model, and both rows are coloured by the same ``y_real`` column.

    Args:
        z_pretrain: Pretrained VAE latent embeddings (N, d) or (N, K, d)
        z_finetuned: Finetuned VAE latent embeddings (N, d) or (N, K, d)
        y_real: Real property values (N,) or (N, P)
        property_names: Names for each property column. Defaults to
            "Property 0", "Property 1", ...
        umap_model: Optional pre-fitted UMAP model (should be fitted on
            combined data). If None, one is fitted on the concatenation of
            both embedding sets.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        alpha: Point transparency
        figsize: Figure size (width, height)
        save_path: Optional path to save figure; a missing parent directory
            is created first

    Returns:
        matplotlib Figure
    """
    # Collapse (N, K, d) inputs to (N, d) by averaging over axis 1.
    pretrain = z_pretrain.mean(axis=1) if z_pretrain.ndim == 3 else z_pretrain
    finetuned = z_finetuned.mean(axis=1) if z_finetuned.ndim == 3 else z_finetuned

    # A single property vector is treated as a one-column matrix.
    if y_real.ndim == 1:
        y_real = y_real.reshape(-1, 1)
    n_props = y_real.shape[1]

    if property_names is None:
        property_names = [f"Property {idx}" for idx in range(n_props)]

    if umap_model is None:
        # One model fitted on both sets puts them in a common coordinate frame.
        stacked = np.concatenate([pretrain, finetuned], axis=0)
        umap_model = fit_umap(stacked, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed)

    # Row order of the figure: pretrained on top, finetuned underneath.
    model_rows = (
        ("Pretrained", transform_umap(umap_model, pretrain)),
        ("Finetuned", transform_umap(umap_model, finetuned)),
    )

    # squeeze=False keeps a (2, n_props) axes grid even for a single property.
    fig, axes = plt.subplots(2, n_props, figsize=figsize, squeeze=False)

    for col in range(n_props):
        prop_name = property_names[col]
        hue = y_real[:, col]

        for row, (model_label, coords) in enumerate(model_rows):
            ax = axes[row, col]
            points = ax.scatter(
                coords[:, 0],
                coords[:, 1],
                c=hue,
                cmap="viridis",
                alpha=alpha,
                s=20,
                edgecolors="none",
            )
            ax.set_xlabel("UMAP 1", fontsize=10)
            ax.set_ylabel("UMAP 2", fontsize=10)
            ax.set_title(f"{model_label}: {prop_name}", fontsize=11, fontweight="bold")
            ax.set_aspect("equal", adjustable="datalim")
            plt.colorbar(points, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()

    if save_path:
        # A bare filename has no directory component; fall back to the CWD.
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_real_analysis(
    z_real: np.ndarray,
    y_real: np.ndarray,
    y_pred: np.ndarray,
    property_names: Sequence[str] | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    figsize: tuple[int, int] = (18, 15),
    save_path: str | None = None,
) -> plt.Figure:
    """Create comprehensive real sample analysis with actual, predicted, and error plots.

    The figure has three rows (actual values, predicted values, absolute
    error) and one column per property. Every panel uses the same 2D UMAP
    projection of ``z_real``.

    Args:
        z_real: Real latent embeddings (N, d) or (N, K, d); 3D input is
            averaged over axis 1 before projection
        y_real: Real property values (N,) or (N, P)
        y_pred: Predicted property values (N,) or (N, P). Normalised to 2D
            independently of ``y_real``, so an (N,) target can be paired
            with an (N, 1) prediction and vice versa.
        property_names: Names for each property column. Defaults to
            "Property 0", "Property 1", ... Names beyond the number of
            property columns are ignored.
        umap_model: Optional pre-fitted UMAP model. If None, one is fitted
            on ``z_real``.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        figsize: Figure size (width, height)
        save_path: Optional path to save figure; a missing parent directory
            is created first

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``y_real`` and ``y_pred`` do not have the same shape
            after both have been normalised to 2D.
    """
    # Pool if 3D
    if z_real.ndim == 3:
        z_real = z_real.mean(axis=1)

    # Normalise each array on its own: expanding y_pred only when y_real is
    # 1D (as was done previously) breaks when exactly one of them is (N, 1).
    y_real = np.asarray(y_real)
    y_pred = np.asarray(y_pred)
    if y_real.ndim == 1:
        y_real = y_real[:, np.newaxis]
    if y_pred.ndim == 1:
        y_pred = y_pred[:, np.newaxis]

    # Fail before the (potentially slow) UMAP fit, and before a silent
    # broadcast in the subtraction below can produce a wrongly shaped error.
    if y_real.shape != y_pred.shape:
        raise ValueError(
            "y_real and y_pred must have the same shape, "
            f"got {y_real.shape} and {y_pred.shape}"
        )

    n_props = y_real.shape[1]

    if property_names is None:
        property_names = [f"Property {i}" for i in range(n_props)]

    # Fit UMAP
    if umap_model is None:
        z_2d, umap_model = fit_transform_umap(
            z_real, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )
    else:
        z_2d = transform_umap(umap_model, z_real)

    absolute_error = np.abs(y_real - y_pred)

    # One entry per figure row: (title suffix, per-point values, colormap).
    rows = (
        ("Actual", y_real, "viridis"),
        ("Predicted", y_pred, "viridis"),
        ("Absolute Error", absolute_error, "Reds"),
    )

    # squeeze=False keeps a (3, n_props) axes grid even for a single property.
    fig, axes = plt.subplots(len(rows), n_props, figsize=figsize, squeeze=False)

    # Pairing names with range(n_props) keeps indexing inside the axes grid
    # even when more names than property columns were supplied.
    for col, name in zip(range(n_props), property_names):
        for row, (suffix, data, cmap_name) in enumerate(rows):
            ax = axes[row, col]
            points = ax.scatter(
                z_2d[:, 0],
                z_2d[:, 1],
                c=data[:, col],
                cmap=cmap_name,
                alpha=0.7,
                s=20,
                edgecolors="none",
            )
            ax.set_xlabel("UMAP 1", fontsize=10)
            ax.set_ylabel("UMAP 2", fontsize=10)
            ax.set_title(f"{name} ({suffix})", fontsize=11)
            ax.set_aspect("equal", adjustable="datalim")
            plt.colorbar(points, ax=ax, shrink=0.8)

    fig.suptitle("Real Sample Analysis: Actual vs Predicted Properties", fontsize=14, y=0.995)
    fig.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_multitype_overlay(
    z_real: np.ndarray,
    z_uncond: np.ndarray | None = None,
    z_cond: np.ndarray | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    alpha: float = 0.5,
    figsize: tuple[int, int] = (10, 8),
    save_path: str | None = None,
    title: str = "Latent Space: Real vs Generated (Uncond & Cond)",
    real_label: str = "Real",
) -> plt.Figure:
    """Overlay real, unconditioned and conditioned samples in one UMAP plot.

    All supplied groups are concatenated (real, then unconditioned, then
    conditioned) and projected together; groups passed as None are left out.

    Args:
        z_real: Real latent embeddings (N_real, d) or (N_real, K, d)
        z_uncond: Unconditioned generated embeddings (N_uncond, d) or (N_uncond, K, d)
        z_cond: Conditioned generated embeddings (N_cond, d) or (N_cond, K, d)
        umap_model: Optional pre-fitted UMAP model. If None, one is fitted on
            the concatenation of all supplied groups.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        alpha: Point transparency
        figsize: Figure size
        save_path: Optional path to save figure; a missing parent directory
            is created first
        title: Plot title
        real_label: Label for real samples (e.g., "Train", "Val", "Test")

    Returns:
        matplotlib Figure
    """

    def _pool(latent):
        """Average a 3D array over axis 1; return anything else unchanged."""
        return latent.mean(axis=1) if latent.ndim == 3 else latent

    # (legend name, colour, pooled latents), listed in drawing order.
    groups = [(real_label, "#2E86AB", _pool(z_real))]
    if z_uncond is not None:
        groups.append(("Unconditioned", "#E94F37", _pool(z_uncond)))
    if z_cond is not None:
        groups.append(("Conditioned", "#F6AE2D", _pool(z_cond)))

    stacked = np.concatenate([latent for _, _, latent in groups], axis=0)

    if umap_model is not None:
        projected = transform_umap(umap_model, stacked)
    else:
        projected, umap_model = fit_transform_umap(
            stacked, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )

    fig, ax = plt.subplots(figsize=figsize)

    # Walk the stacked projection, handing each group its own row range.
    start = 0
    for name, color, latent in groups:
        stop = start + len(latent)
        coords = projected[start:stop]
        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=color,
            label=f"{name} (n={len(latent)})",
            alpha=alpha,
            s=20,
            edgecolors="none",
        )
        start = stop

    ax.set_xlabel("UMAP 1", fontsize=12)
    ax.set_ylabel("UMAP 2", fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.set_aspect("equal", adjustable="datalim")

    plt.tight_layout()

    if save_path:
        # A bare filename has no directory component; fall back to the CWD.
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_splits_overlay(
    z_train: np.ndarray,
    z_val: np.ndarray,
    z_test: np.ndarray,
    z_uncond: np.ndarray | None = None,
    z_cond: np.ndarray | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    alpha: float = 0.5,
    figsize: tuple[int, int] = (20, 6),
    save_path: str | None = None,
) -> plt.Figure:
    """Draw one UMAP panel per data split, each overlaid with generated samples.

    Every supplied array is projected in a single call (train, val, test,
    then any generated sets), so the three panels share one coordinate
    frame. The same generated points are repeated on every panel.

    Args:
        z_train: Train split latent embeddings (N_train, d)
        z_val: Validation split latent embeddings (N_val, d)
        z_test: Test split latent embeddings (N_test, d)
        z_uncond: Unconditioned generated embeddings (N_uncond, d)
        z_cond: Conditioned generated embeddings (N_cond, d)
        umap_model: Optional pre-fitted UMAP model. If None, one is fitted on
            the concatenation of all supplied arrays.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        alpha: Point transparency for split samples; generated samples are
            drawn at ``alpha * 0.7``
        figsize: Figure size (width, height)
        save_path: Optional path to save figure; a missing parent directory
            is created first

    Returns:
        matplotlib Figure

    Note:
        Any input with three dimensions is averaged over axis 1 first.
    """

    def _pool(latent):
        """Average a 3D array over axis 1; return anything else unchanged."""
        return latent.mean(axis=1) if latent.ndim == 3 else latent

    # (panel name, pooled latents) for the three panels, left to right.
    real_splits = [
        ("Train", _pool(z_train)),
        ("Validation", _pool(z_val)),
        ("Test", _pool(z_test)),
    ]

    # (legend name, colour, pooled latents) for whichever generated sets exist.
    generated_sets = []
    if z_uncond is not None:
        generated_sets.append(("Unconditioned", "#E94F37", _pool(z_uncond)))
    if z_cond is not None:
        generated_sets.append(("Conditioned", "#F6AE2D", _pool(z_cond)))

    ordered = [latent for _, latent in real_splits]
    ordered += [latent for _, _, latent in generated_sets]
    stacked = np.concatenate(ordered, axis=0)

    if umap_model is not None:
        projected = transform_umap(umap_model, stacked)
    else:
        projected, umap_model = fit_transform_umap(
            stacked, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )

    # Cut the stacked projection back into one chunk per input array.
    chunks = []
    offset = 0
    for latent in ordered:
        chunks.append(projected[offset : offset + len(latent)])
        offset += len(latent)
    split_coords = chunks[: len(real_splits)]
    generated_coords = chunks[len(real_splits) :]

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    for ax, (split_name, split_latent), coords in zip(axes, real_splits, split_coords):
        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c="#2E86AB",
            label=f"{split_name} (n={len(split_latent)})",
            alpha=alpha,
            s=20,
            edgecolors="none",
        )

        # Generated points are smaller and fainter than the split's own points.
        for (gen_name, color, gen_latent), gen_xy in zip(generated_sets, generated_coords):
            ax.scatter(
                gen_xy[:, 0],
                gen_xy[:, 1],
                c=color,
                label=f"{gen_name} (n={len(gen_latent)})",
                alpha=alpha * 0.7,
                s=15,
                edgecolors="none",
            )

        ax.set_xlabel("UMAP 1", fontsize=10)
        ax.set_ylabel("UMAP 2", fontsize=10)
        ax.set_title(f"{split_name} Split vs Generated", fontsize=12)
        ax.legend(loc="best", fontsize=8)
        ax.set_aspect("equal", adjustable="datalim")

    plt.suptitle("Latent Space by Data Split", fontsize=14, y=1.02)
    plt.tight_layout()

    if save_path:
        # A bare filename has no directory component; fall back to the CWD.
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def create_umap_analysis_report(
    z_real: np.ndarray,
    z_generated: np.ndarray | None = None,
    properties_real: np.ndarray | None = None,
    properties_generated: np.ndarray | None = None,
    property_names: Sequence[str] | None = None,
    output_dir: str = "outputs/umap_analysis",
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
) -> dict[str, str]:
    """Write a set of UMAP figures to disk and report where they were saved.

    One UMAP model is fitted up front and reused for every figure. Up to
    three PNG files are produced, depending on which inputs are supplied:

    * ``umap_real_vs_generated.png`` when ``z_generated`` is given
    * ``umap_real_properties.png`` when ``properties_real`` is given
    * ``umap_generated_properties.png`` when both ``z_generated`` and
      ``properties_generated`` are given

    Args:
        z_real: Real latent embeddings; 3D input is averaged over axis 1
        z_generated: Optional generated latent embeddings; 3D input is
            averaged over axis 1
        properties_real: Optional property values for real samples
        properties_generated: Optional property values for generated samples
        property_names: Names for property columns
        output_dir: Directory to save figures (created if missing)
        n_neighbors: UMAP n_neighbors parameter
        min_dist: UMAP min_dist parameter
        seed: Random seed for UMAP

    Returns:
        Dict mapping figure names ("real_vs_generated", "real_properties",
        "generated_properties") to file paths
    """
    os.makedirs(output_dir, exist_ok=True)
    out_dir = Path(output_dir)
    figure_paths: dict[str, str] = {}

    real = z_real.mean(axis=1) if z_real.ndim == 3 else z_real
    generated = z_generated
    if generated is not None and generated.ndim == 3:
        generated = generated.mean(axis=1)
    has_generated = generated is not None

    # The model is fitted on real + generated when both exist, else on real only.
    if has_generated:
        stacked = np.concatenate([real, generated], axis=0)
        _, umap_model = fit_transform_umap(
            stacked, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed
        )
    else:
        umap_model = fit_umap(real, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed)

    # Figure 1: real and generated samples overlaid.
    if has_generated:
        target = str(out_dir / "umap_real_vs_generated.png")
        plot_umap_real_vs_generated(
            real, generated, umap_model=umap_model, seed=seed, save_path=target
        )
        figure_paths["real_vs_generated"] = target
        # Each figure is closed once saved so open figures do not pile up.
        plt.close()

    # Figure 2: real samples coloured by their properties.
    if properties_real is not None:
        target = str(out_dir / "umap_real_properties.png")
        plot_umap_property_hue(
            real,
            properties_real,
            property_names=property_names,
            umap_model=umap_model,
            seed=seed,
            save_path=target,
        )
        figure_paths["real_properties"] = target
        plt.close()

    # Figure 3: generated samples coloured by their properties.
    if has_generated and properties_generated is not None:
        target = str(out_dir / "umap_generated_properties.png")
        plot_umap_property_hue(
            generated,
            properties_generated,
            property_names=property_names,
            umap_model=umap_model,
            seed=seed,
            save_path=target,
        )
        figure_paths["generated_properties"] = target
        plt.close()

    return figure_paths


# RBF kernel names for which no epsilon (shape parameter) is passed to the
# interpolator. Any kernel not listed here gets an epsilon computed by
# _compute_default_epsilon.
# NOTE(review): assumed to match the kernels scipy's RBFInterpolator documents
# as not needing epsilon — confirm against the installed scipy version.
_KERNELS_NO_EPSILON = {"cubic", "thin_plate_spline", "linear", "quintic"}


def _compute_default_epsilon(points: np.ndarray) -> float:
    """Compute a sensible default epsilon for RBF kernels that require it.

    The heuristic is half of the median pairwise Euclidean distance between
    the given points.

    Args:
        points: Point coordinates (N, D)

    Returns:
        Half the median pairwise distance, or 1.0 when that value is not a
        usable scale: fewer than two points (no pairwise distances exist),
        all points coincident (median distance 0), or a non-finite result.
    """
    from scipy.spatial.distance import pdist

    fallback = 1.0

    points = np.asarray(points, dtype=float)
    # With fewer than two rows there are no pairs, and the median of an empty
    # distance array would be NaN.
    if points.ndim == 2 and points.shape[0] < 2:
        return fallback

    epsilon = float(np.median(pdist(points)) / 2.0)

    # A zero or non-finite epsilon is not a usable shape parameter.
    if not np.isfinite(epsilon) or epsilon <= 0.0:
        return fallback
    return epsilon


def plot_umap_with_contours(
    z_2d: np.ndarray,
    values: np.ndarray,
    property_name: str = "Property",
    n_levels: int = 10,
    cmap: str = "viridis",
    alpha_contour: float = 0.4,
    alpha_scatter: float = 0.6,
    figsize: tuple[int, int] = (10, 8),
    save_path: str | None = None,
    show_scatter: bool = True,
    grid_resolution: int = 100,
    smoothing: float = 0.0,
    kernel: str = "thin_plate_spline",
) -> plt.Figure:
    """Create UMAP scatter plot with property contour overlay.

    Uses RBF interpolation to create a smooth property surface over the UMAP
    embedding, making it easier to visualize latent space organization. The
    surface is evaluated on a regular grid spanning the data extent plus a 5%
    margin on each side.

    Args:
        z_2d: 2D UMAP coordinates (N, 2)
        values: Property values for each point (N,); an (N, 1) array is
            flattened
        property_name: Name for the property (used in title/colorbar)
        n_levels: Number of contour levels
        cmap: Colormap for contours and scatter
        alpha_contour: Transparency of contour fill
        alpha_scatter: Transparency of scatter points
        figsize: Figure size
        save_path: Optional path to save figure; a missing parent directory
            is created first
        show_scatter: Whether to show scatter points on top of contours
        grid_resolution: Resolution of interpolation grid (points per axis)
        smoothing: RBF smoothing parameter (0 = exact interpolation, higher = smoother)
        kernel: RBF kernel type ('thin_plate_spline', 'gaussian', 'cubic', etc.)

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``z_2d`` is not of shape (N, 2) or ``values`` does not
            provide exactly one value per point.
    """
    from scipy.interpolate import RBFInterpolator

    coords = np.asarray(z_2d, dtype=float)
    vals = np.asarray(values, dtype=float).reshape(-1)

    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"z_2d must have shape (N, 2), got {coords.shape}")
    if len(vals) != len(coords):
        raise ValueError(
            f"values must provide one value per point: got {len(vals)} values "
            f"for {len(coords)} points"
        )

    # Create grid for interpolation
    x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
    y_min, y_max = coords[:, 1].min(), coords[:, 1].max()

    # Add small margin
    margin_x = (x_max - x_min) * 0.05
    margin_y = (y_max - y_min) * 0.05
    xi, yi = np.meshgrid(
        np.linspace(x_min - margin_x, x_max + margin_x, grid_resolution),
        np.linspace(y_min - margin_y, y_max + margin_y, grid_resolution),
    )
    grid_points = np.column_stack([xi.ravel(), yi.ravel()])

    # Use RBF interpolation for smooth contours
    # Some kernels require epsilon parameter
    rbf_kwargs = {"kernel": kernel, "smoothing": smoothing}
    if kernel not in _KERNELS_NO_EPSILON:
        rbf_kwargs["epsilon"] = _compute_default_epsilon(coords)
    rbf = RBFInterpolator(coords, vals, **rbf_kwargs)
    zi = rbf(grid_points).reshape(xi.shape)

    # The figure is created only after interpolation has succeeded, so a
    # failed interpolation does not leave an empty figure open in pyplot.
    fig, ax = plt.subplots(figsize=figsize)

    # Plot filled contours
    contour = ax.contourf(xi, yi, zi, levels=n_levels, cmap=cmap, alpha=alpha_contour)
    plt.colorbar(contour, ax=ax, label=property_name, shrink=0.8)

    # Add contour lines on exactly the boundaries used by the filled contours
    ax.contour(xi, yi, zi, levels=contour.levels, colors="gray", alpha=0.3, linewidths=0.5)

    # Overlay scatter points
    if show_scatter:
        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=vals,
            cmap=cmap,
            # Share the contour's normalisation so point colours correspond
            # to the (single) colorbar instead of an independently
            # autoscaled range.
            norm=contour.norm,
            alpha=alpha_scatter,
            s=15,
            edgecolors="white",
            linewidths=0.3,
        )

    ax.set_xlabel("UMAP 1", fontsize=12)
    ax.set_ylabel("UMAP 2", fontsize=12)
    ax.set_title(f"Latent Space: {property_name}", fontsize=14)
    ax.set_aspect("equal", adjustable="datalim")

    fig.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_umap_pretrain_vs_finetuned_with_contours(
    z_pretrain: np.ndarray,
    z_finetuned: np.ndarray,
    y_real: np.ndarray,
    property_names: Sequence[str] | None = None,
    umap_model=None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    n_levels: int = 10,
    alpha_contour: float = 0.4,
    alpha_scatter: float = 0.6,
    figsize: tuple[int, int] = (18, 6),
    save_path: str | None = None,
    grid_resolution: int = 100,
    smoothing: float = 0.0,
    kernel: str = "thin_plate_spline",
) -> plt.Figure:
    """Compare pretrained vs finetuned VAE latent spaces with property contours.

    Similar to plot_umap_pretrain_vs_finetuned but with contour overlays
    to better visualize how property gradients are organized in latent space.
    Uses RBF interpolation for smooth contours.

    The figure has two rows (pretrained on top, finetuned below) and one column
    per property. Both rows are coloured by the same ``y_real`` column, so row
    ``i`` of each embedding must describe the same sample as row ``i`` of
    ``y_real``.

    Args:
        z_pretrain: Pretrained VAE latent embeddings (N, d) or (N, K, d);
            3D input is mean-pooled over axis 1.
        z_finetuned: Finetuned VAE latent embeddings (N, d) or (N, K, d);
            3D input is mean-pooled over axis 1.
        y_real: Real property values (N, P); a 1D array is treated as (N, 1)
        property_names: Names for each property column. Defaults to
            "Property 0", "Property 1", ...
        umap_model: Optional pre-fitted UMAP model. If None, a model is fitted
            on the concatenation of both embedding sets.
        n_neighbors: UMAP n_neighbors parameter (used only when fitting)
        min_dist: UMAP min_dist parameter (used only when fitting)
        seed: Random seed for UMAP (used only when fitting)
        n_levels: Number of contour levels
        alpha_contour: Transparency of contour fill
        alpha_scatter: Transparency of scatter points
        figsize: Figure size (width, height)
        save_path: Optional path to save figure; parent directories are created
        grid_resolution: Resolution of interpolation grid
        smoothing: RBF smoothing parameter (0 = exact interpolation, higher = smoother)
        kernel: RBF kernel type ('thin_plate_spline', 'gaussian', 'cubic', etc.)

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``z_pretrain`` or ``z_finetuned`` does not have the same
            number of rows as ``y_real``.
        IndexError: If ``property_names`` is a list/tuple with fewer entries
            than ``y_real`` has columns.
    """
    from scipy.interpolate import RBFInterpolator

    # Pool if 3D
    if z_pretrain.ndim == 3:
        z_pretrain = z_pretrain.mean(axis=1)
    if z_finetuned.ndim == 3:
        z_finetuned = z_finetuned.mean(axis=1)

    # Ensure properties are 2D
    if y_real.ndim == 1:
        y_real = y_real[:, np.newaxis]

    n_samples = y_real.shape[0]
    n_props = y_real.shape[1]

    # Fail fast with a readable message. Without this check a mismatch only
    # surfaces as a ValueError inside RBFInterpolator, after the UMAP
    # fit/transform has already been paid for.
    for label, z in (("z_pretrain", z_pretrain), ("z_finetuned", z_finetuned)):
        if z.shape[0] != n_samples:
            raise ValueError(
                f"{label} has {z.shape[0]} samples but y_real has {n_samples}; "
                "each embedding row needs exactly one row of property values."
            )

    if property_names is None:
        property_names = [f"Property {i}" for i in range(n_props)]

    # Fit or use provided UMAP
    if umap_model is None:
        z_combined = np.concatenate([z_pretrain, z_finetuned], axis=0)
        umap_model = fit_umap(z_combined, n_neighbors=n_neighbors, min_dist=min_dist, seed=seed)

    # Transform both sets with the same model so the rows share one 2D space
    z_pretrain_2d = transform_umap(umap_model, z_pretrain)
    z_finetuned_2d = transform_umap(umap_model, z_finetuned)

    # 2 rows (pretrain, finetuned) x n_props columns. squeeze=False keeps
    # `axes` two-dimensional even when there is a single property column.
    fig, axes = plt.subplots(2, n_props, figsize=figsize, squeeze=False)

    def _add_contours(ax, z_2d, values, title, cmap="viridis"):
        """Draw interpolated contours plus the raw scatter points on ``ax``."""
        x_min, x_max = z_2d[:, 0].min(), z_2d[:, 0].max()
        y_min, y_max = z_2d[:, 1].min(), z_2d[:, 1].max()

        # Pad the grid by 5% of the data extent on every side
        margin_x = (x_max - x_min) * 0.05
        margin_y = (y_max - y_min) * 0.05
        x_min -= margin_x
        x_max += margin_x
        y_min -= margin_y
        y_max += margin_y

        xi = np.linspace(x_min, x_max, grid_resolution)
        yi = np.linspace(y_min, y_max, grid_resolution)
        xi, yi = np.meshgrid(xi, yi)
        grid_points = np.column_stack([xi.ravel(), yi.ravel()])

        # Use RBF interpolation for smooth contours.
        # Kernels outside _KERNELS_NO_EPSILON get an explicit epsilon derived
        # from the embedding.
        rbf_kwargs = {"kernel": kernel, "smoothing": smoothing}
        if kernel not in _KERNELS_NO_EPSILON:
            rbf_kwargs["epsilon"] = _compute_default_epsilon(z_2d)
        rbf = RBFInterpolator(z_2d, values, **rbf_kwargs)
        zi = rbf(grid_points).reshape(xi.shape)

        contour = ax.contourf(xi, yi, zi, levels=n_levels, cmap=cmap, alpha=alpha_contour)
        ax.contour(xi, yi, zi, levels=n_levels, colors="gray", alpha=0.3, linewidths=0.5)

        ax.scatter(
            z_2d[:, 0],
            z_2d[:, 1],
            c=values,
            cmap=cmap,
            alpha=alpha_scatter,
            s=15,
            edgecolors="white",
            linewidths=0.3,
        )

        ax.set_xlabel("UMAP 1", fontsize=10)
        ax.set_ylabel("UMAP 2", fontsize=10)
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.set_aspect("equal", adjustable="datalim")
        # Attach the colorbar to this figure explicitly instead of relying on
        # pyplot's notion of the "current" figure.
        fig.colorbar(contour, ax=ax, fraction=0.046, pad=0.04)

    try:
        for prop_idx in range(n_props):
            prop_name = property_names[prop_idx]
            y_vals = y_real[:, prop_idx]

            # Pretrained (top row)
            _add_contours(axes[0, prop_idx], z_pretrain_2d, y_vals, f"Pretrained: {prop_name}")

            # Finetuned (bottom row)
            _add_contours(axes[1, prop_idx], z_finetuned_2d, y_vals, f"Finetuned: {prop_name}")

        fig.tight_layout()

        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot when
        # interpolation, layout or saving fails; the caller never receives it.
        plt.close(fig)
        raise

    return fig


# Default on/off switches for the figures produced by create_standard_umap_suite.
# Every figure is enabled by default; callers override individual flags through
# that function's `plot_config` argument.
DEFAULT_UMAP_PLOT_CONFIG: dict[str, bool] = dict.fromkeys(
    (
        "pretrain_vs_finetuned",  # pretrained vs finetuned scatter comparison
        "pretrain_vs_finetuned_contours",  # same comparison with contour overlays
        "real_analysis",  # real sample analysis (needs predictions for real samples)
        "property_hue_uncond",  # predicted-property hue, unconditioned samples
        "property_hue_cond",  # predicted-property hue, conditioned samples
        "multitype_overlay",  # all real latents vs generated latents
        "splits_overlay",  # train / val / test splits overlay
    ),
    True,
)


def create_standard_umap_suite(
    output_dir: str,
    z_train: np.ndarray,
    z_val: np.ndarray,
    z_test: np.ndarray,
    y_real: np.ndarray,
    property_names: Sequence[str],
    z_pretrain: np.ndarray | None = None,
    z_uncond: np.ndarray | None = None,
    z_cond: np.ndarray | None = None,
    y_pred_real: np.ndarray | None = None,
    y_pred_uncond: np.ndarray | None = None,
    y_pred_cond: np.ndarray | None = None,
    plot_config: dict[str, bool] | None = None,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    seed: int = 42,
    contour_smoothing: float = 0.0,
    contour_kernel: str = "thin_plate_spline",
    contour_grid_resolution: int = 100,
    contour_n_levels: int = 10,
) -> dict[str, str]:
    """Render the standard set of UMAP figures into ``output_dir``.

    A single UMAP model is fitted on every latent set that was supplied
    (train, val, test, plus any of pretrain / unconditioned / conditioned) and
    handed to each plotting function, so all figures share one projection.
    A figure is produced only when its flag in the merged configuration is
    truthy and the optional inputs it needs are not None.

    Args:
        output_dir: Directory to save figures (created if missing)
        z_train: Train split latent embeddings (N_train, d)
        z_val: Validation split latent embeddings (N_val, d)
        z_test: Test split latent embeddings (N_test, d)
        y_real: Real property values for test split (N_test, P)
        property_names: Names for each property column
        z_pretrain: Optional pretrained VAE latents for test split (for comparison)
        z_uncond: Optional unconditioned generated latents
        z_cond: Optional conditioned generated latents
        y_pred_real: Optional predicted properties for test samples
        y_pred_uncond: Optional predicted properties for unconditioned samples
        y_pred_cond: Optional predicted properties for conditioned samples
        plot_config: Per-plot on/off overrides layered on top of
            DEFAULT_UMAP_PLOT_CONFIG.
        n_neighbors: UMAP n_neighbors parameter
        min_dist: UMAP min_dist parameter
        seed: Random seed for UMAP
        contour_smoothing: RBF smoothing for contour plots (0 = exact, higher = smoother)
        contour_kernel: RBF kernel type for contours (thin_plate_spline, gaussian, cubic)
        contour_grid_resolution: Resolution of contour interpolation grid
        contour_n_levels: Number of contour levels

    Returns:
        Dict mapping the name of each figure that was produced to its file path
    """
    os.makedirs(output_dir, exist_ok=True)

    # Caller-supplied flags take precedence over the module defaults.
    enabled = dict(DEFAULT_UMAP_PLOT_CONFIG)
    if plot_config is not None:
        enabled.update(plot_config)

    def _pooled(latents):
        """Mean-pool (N, K, d) latents over axis 1; pass anything else through."""
        if latents is None or latents.ndim != 3:
            return latents
        return latents.mean(axis=1)

    z_train, z_val, z_test, z_pretrain, z_uncond, z_cond = (
        _pooled(z) for z in (z_train, z_val, z_test, z_pretrain, z_uncond, z_cond)
    )

    # Fit one shared UMAP on everything available so all figures are comparable.
    optional_sets = [z for z in (z_pretrain, z_uncond, z_cond) if z is not None]
    umap_model = fit_umap(
        np.concatenate([z_train, z_val, z_test, *optional_sets], axis=0),
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        seed=seed,
    )

    written: dict[str, str] = {}

    def _emit(key, filename, plot_fn, *arrays, **extra):
        """Run one plotting function, record its output path, close the figure."""
        target = os.path.join(output_dir, filename)
        plot_fn(*arrays, umap_model=umap_model, seed=seed, save_path=target, **extra)
        written[key] = target
        plt.close()

    # 1 & 2. Pretrain vs finetuned: plain scatter, then with contours.
    if enabled.get("pretrain_vs_finetuned") and z_pretrain is not None:
        _emit(
            "pretrain_vs_finetuned",
            "umap_pretrain_vs_finetuned.png",
            plot_umap_pretrain_vs_finetuned,
            z_pretrain,
            z_test,
            y_real,
            property_names=property_names,
        )

    if enabled.get("pretrain_vs_finetuned_contours") and z_pretrain is not None:
        _emit(
            "pretrain_vs_finetuned_contours",
            "umap_pretrain_vs_finetuned_contours.png",
            plot_umap_pretrain_vs_finetuned_with_contours,
            z_pretrain,
            z_test,
            y_real,
            property_names=property_names,
            smoothing=contour_smoothing,
            kernel=contour_kernel,
            grid_resolution=contour_grid_resolution,
            n_levels=contour_n_levels,
        )

    # 3. Real sample analysis; needs predictions for the test samples.
    if enabled.get("real_analysis") and y_pred_real is not None:
        _emit(
            "real_analysis",
            "umap_real_analysis.png",
            plot_umap_real_analysis,
            z_test,
            y_real,
            y_pred_real,
            property_names=property_names,
        )

    # 4 & 5. Predicted-property hue for unconditioned, then conditioned samples.
    hue_variants = (
        ("property_hue_uncond", "umap_uncond_predicted.png", z_uncond, y_pred_uncond),
        ("property_hue_cond", "umap_cond_predicted.png", z_cond, y_pred_cond),
    )
    for key, filename, z_gen, y_gen in hue_variants:
        if enabled.get(key) and z_gen is not None and y_gen is not None:
            _emit(
                key,
                filename,
                plot_umap_property_hue,
                z_gen,
                y_gen,
                property_names=property_names,
            )

    # 6. Multitype overlay: all real latents against the generated sets.
    if enabled.get("multitype_overlay"):
        _emit(
            "multitype_overlay",
            "umap_multitype_overlay.png",
            plot_umap_multitype_overlay,
            np.concatenate([z_train, z_val, z_test], axis=0),
            z_uncond,
            z_cond,
            title="Latent Space: Real vs Generated",
        )

    # 7. Splits overlay.
    if enabled.get("splits_overlay"):
        _emit(
            "splits_overlay",
            "umap_splits_overlay.png",
            plot_umap_splits_overlay,
            z_train,
            z_val,
            z_test,
            z_uncond,
            z_cond,
        )

    return written
