"""Pareto front computation utilities.

This module provides functions for identifying Pareto-optimal points
in multi-objective optimization problems.
"""

from __future__ import annotations

import warnings
from typing import Sequence

import numpy as np
from paretoset import paretoset


def pareto_front(
    points: np.ndarray,
    sense: Sequence[str] | None = None,
    maximize: list[bool] | None = None,
) -> np.ndarray:
    """Return boolean mask for Pareto-optimal points.

    A point is Pareto-optimal (non-dominated) if no other point is at least as
    good in every objective and strictly better in at least one objective.
    The computation is delegated to ``paretoset.paretoset``.

    NOTE: ``paretoset`` is called with its default ``distinct=True``. Per the
    paretoset documentation, only the first of several exactly-duplicated
    rows is then marked as Pareto-optimal.

    Args:
        points: Array-like of shape (N, D) where N is the number of points and
            D is the number of objectives. Converted to float64.
        sense: Sequence of "max" or "min" for each objective dimension.
            Use "max" if higher is better, "min" if lower is better.
            Defaults to "max" for every dimension.
        maximize: DEPRECATED. List of bools (True = maximize) kept for backward
            compatibility. Passing it emits a ``DeprecationWarning``. It is
            ignored when ``sense`` is also given. Use ``sense`` instead.

    Returns:
        Boolean mask of shape (N,) where True indicates Pareto-optimal points.

    Raises:
        ValueError: If ``points`` is not 2D, if the resolved ``sense`` does not
            have exactly one entry per dimension, or if any entry is not
            "max"/"min".

    Example:
        >>> points = np.array([[1, 2], [2, 1], [1.5, 1.5], [0.5, 0.5]])
        >>> mask = pareto_front(points, sense=["max", "max"])
        >>> points[mask].tolist()
        [[1.0, 2.0], [2.0, 1.0], [1.5, 1.5]]
    """
    points = np.asarray(points, dtype=np.float64)

    if points.ndim != 2:
        raise ValueError(f"points must be 2D array, got shape {points.shape}")

    n_points, n_dims = points.shape

    if n_points == 0:
        return np.array([], dtype=bool)

    if maximize is not None:
        # The parameter has been documented as deprecated; surface that to
        # callers at runtime rather than only in the docstring.
        warnings.warn(
            "pareto_front(maximize=...) is deprecated; use sense=[...] with 'max'/'min' entries instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    if sense is None:
        if maximize is not None:
            # Backward compatibility: translate the legacy bool list.
            sense = ["max" if m else "min" for m in maximize]
        else:
            # Default: assume all objectives are to be maximized
            sense = ["max"] * n_dims
    else:
        # paretoset documents `sense` as a list; normalise tuples/other sequences.
        sense = list(sense)

    if len(sense) != n_dims:
        raise ValueError(f"sense length ({len(sense)}) must match number of dimensions ({n_dims})")

    # Validate sense values
    for i, s in enumerate(sense):
        if s not in ("max", "min"):
            raise ValueError(f"sense[{i}] must be 'max' or 'min', got '{s}'")

    # Use paretoset library for efficient computation
    return paretoset(points, sense=sense)


def get_pareto_points(
    points: np.ndarray,
    sense: Sequence[str] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Extract Pareto-optimal points from an array.

    Args:
        points: Array-like of shape (N, D). Nested Python lists are accepted
            and converted to an ndarray (original dtype preserved).
        sense: List of "max" or "min" for each objective. Forwarded to
            ``pareto_front``, which maximizes every objective when None.

    Returns:
        Tuple of (pareto_points, pareto_indices) where:
        - pareto_points: Array of shape (M, D) containing Pareto-optimal points
        - pareto_indices: Array of original indices of Pareto points

    Raises:
        ValueError: Propagated from ``pareto_front`` for non-2D input or an
            invalid ``sense``.
    """
    # Convert up front: boolean-mask indexing below requires an ndarray, but
    # pareto_front accepts any array-like and only converts its own copy.
    points = np.asarray(points)
    mask = pareto_front(points, sense=sense)
    indices = np.flatnonzero(mask)
    return points[mask], indices


def pareto_rank(
    points: np.ndarray,
    sense: Sequence[str] | None = None,
    max_ranks: int | None = None,
) -> np.ndarray:
    """Compute Pareto rank (front number) for each point.

    Rank 0 = first Pareto front (non-dominated)
    Rank 1 = second front (dominated only by rank 0)
    etc.

    Fronts are peeled off repeatedly with ``pareto_front``. Exact duplicate
    rows are ranked once and all copies share that rank. ``pareto_front``
    (via paretoset's default ``distinct=True``) may mark only one copy of a
    duplicated row, which would otherwise push identical points into later
    fronts.

    Args:
        points: Array-like of shape (N, D)
        sense: List of "max" or "min" for each objective
        max_ranks: If set, stop after computing this many ranks

    Returns:
        int32 array of shape (N,) with the rank for each point. Points that
        were never assigned a front keep the value -1. This happens when
        ``max_ranks`` was reached first, or when a front computation selected
        no points.

    Raises:
        ValueError: If non-empty ``points`` is not 2D, or (from
            ``pareto_front``) if ``sense`` is invalid.
    """
    points = np.asarray(points, dtype=np.float64)
    n_points = points.shape[0]

    if n_points == 0:
        return np.array([], dtype=np.int32)

    if points.ndim != 2:
        raise ValueError(f"points must be 2D array, got shape {points.shape}")

    # Rank each distinct row once, then broadcast the ranks back to all copies.
    unique_points, inverse = np.unique(points, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)  # NumPy versions differ in the inverse's shape

    unique_ranks = np.full(unique_points.shape[0], -1, dtype=np.int32)
    remaining = np.ones(unique_points.shape[0], dtype=bool)
    current_rank = 0

    while remaining.any():
        if max_ranks is not None and current_rank >= max_ranks:
            break

        remaining_idx = np.flatnonzero(remaining)
        front_mask = pareto_front(unique_points[remaining_idx], sense=sense)

        if not front_mask.any():
            # Nothing was removed this round (e.g. NaN-laden data); looping
            # again would never terminate, so leave the rest at -1.
            break

        front_idx = remaining_idx[front_mask]
        unique_ranks[front_idx] = current_rank
        remaining[front_idx] = False
        current_rank += 1

    return unique_ranks[inverse]


def get_pareto_neighbors(
    points: np.ndarray,
    sense: Sequence[str] | None = None,
    k_neighbors: int = 5,
    normalize: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Get Pareto front and K nearest neighbors to expand candidate selection.

    This function identifies the true Pareto front and then finds the K nearest
    neighbors to each Pareto point, creating an expanded "near-Pareto" selection.
    This is useful when the true Pareto front has too few candidates for
    meaningful optimization experiments.

    Neighbors are searched among all points with Euclidean distance. Each query
    asks for ``k_neighbors + 1`` results (capped at N) because a Pareto point
    normally finds itself at distance 0. Neighbors of one Pareto point may
    themselves be Pareto points.

    Args:
        points: Array of shape (N, D) with objective values
        sense: List of "max" or "min" for each objective
        k_neighbors: Number of nearest neighbors per Pareto point to include.
            Values <= 0 disable the neighbor search.
        normalize: Whether to min-max normalize each dimension to [0, 1] before
            computing distances. Constant dimensions are left unscaled.
            Recommended when properties have different scales.

    Returns:
        Tuple of (pareto_mask, selection_mask, neighbor_mask) where:
        - pareto_mask: Boolean mask for true Pareto-optimal points
        - selection_mask: Boolean mask for all selected points (Pareto +
          neighbors); always a superset of pareto_mask
        - neighbor_mask: Boolean mask for near-Pareto points only (not true Pareto)

    Example:
        >>> points = np.random.randn(100, 2)
        >>> pareto_mask, selection_mask, neighbor_mask = get_pareto_neighbors(
        ...     points, sense=["max", "min"], k_neighbors=3
        ... )
        >>> print(f"Pareto: {pareto_mask.sum()}, Total selected: {selection_mask.sum()}")  # doctest: +SKIP
    """
    points = np.asarray(points, dtype=np.float64)
    n_points = points.shape[0]

    if n_points == 0:
        # Three distinct arrays so a caller mutating one does not affect the others.
        return (
            np.array([], dtype=bool),
            np.array([], dtype=bool),
            np.array([], dtype=bool),
        )

    # Get true Pareto front
    pareto_mask = pareto_front(points, sense=sense)
    pareto_indices = np.flatnonzero(pareto_mask)

    if k_neighbors <= 0 or pareto_indices.size == 0:
        # No neighbors requested or no Pareto points
        return pareto_mask, pareto_mask.copy(), np.zeros(n_points, dtype=bool)

    # Imported only when the neighbor search actually runs.
    from sklearn.neighbors import NearestNeighbors

    # Prepare points for distance computation
    if normalize:
        mins = points.min(axis=0)
        ranges = points.max(axis=0) - mins
        # Avoid division by zero for constant dimensions
        ranges = np.where(ranges > 0, ranges, 1.0)
        points_norm = (points - mins) / ranges
    else:
        points_norm = points

    # +1 because each Pareto point is expected to find itself.
    n_neighbors_query = min(k_neighbors + 1, n_points)
    nn = NearestNeighbors(n_neighbors=n_neighbors_query, metric="euclidean")
    nn.fit(points_norm)
    _, neighbor_indices = nn.kneighbors(points_norm[pareto_indices])

    selection_mask = np.zeros(n_points, dtype=bool)
    selection_mask[neighbor_indices.ravel()] = True
    # With many exact duplicates at distance 0, tie-breaking can leave a Pareto
    # point out of its own neighbor list; force it in so the selection is
    # always Pareto + neighbors.
    selection_mask[pareto_indices] = True

    # Neighbor mask is selection minus Pareto
    neighbor_mask = selection_mask & ~pareto_mask

    return pareto_mask, selection_mask, neighbor_mask
