"""
Anchor management utilities for streaming Nyström approximation.

These functions implement diverse anchor selection and refresh strategies
for streaming scenarios under concept drift.  The refresh_anchors function
keeps a fraction of existing anchors and replenishes the remainder using a
mixture of k-means centroids, farthest-point sampling, and class-balanced
medoids.  These heuristics aim to maintain coverage of the input space
while adapting to new data.
"""

from typing import Optional
import numpy as np
from sklearn.cluster import KMeans


def farthest_points(X: np.ndarray, k: int, seed: int = 42, existing: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Select up to k points from X via greedy farthest-point sampling.

    Starting from ``existing`` (or a single random index when no seed
    indices are given), repeatedly add the row whose squared Euclidean
    distance to its nearest already-selected row is largest.

    Args:
        X: Data matrix (n_samples, d)
        k: Number of points to select
        seed: Random seed (only used to draw the first index when
            ``existing`` is empty)
        existing: Array of existing indices to include in the seed.  They
            are kept in the given order and count toward ``k``.

    Returns:
        Array of selected indices.  Its length is k, unless fewer than k
        rows are available, in which case every row index is returned
        (previously this case looped forever).
    """
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    if n == 0 or k <= 0:
        return np.empty((0,), dtype=int)
    if existing is None or len(existing) == 0:
        sel = [int(rng.integers(0, n))]
    else:
        sel = [int(i) for i in existing]
    if len(sel) >= k:
        return np.array(sel[:k], dtype=int)

    # Running squared distance from every row to its nearest selected row.
    # Updating it incrementally costs O(n * d) per added point instead of
    # materialising an (n, |sel|, d) tensor on every iteration.
    d2 = np.full(n, np.inf)
    for s in sel:
        np.minimum(d2, ((X - X[s]) ** 2).sum(axis=1), out=d2)
    chosen = np.zeros(n, dtype=bool)
    chosen[sel] = True

    # Stop once every row is chosen so that k > n cannot loop forever.
    while len(sel) < k and not chosen.all():
        # Mask selected rows so argmax always yields a new index, even when
        # all remaining rows duplicate selected ones (all distances zero).
        cand = int(np.argmax(np.where(chosen, -np.inf, d2)))
        sel.append(cand)
        chosen[cand] = True
        np.minimum(d2, ((X - X[cand]) ** 2).sum(axis=1), out=d2)
    return np.array(sel, dtype=int)


def class_balanced_medoids(X: np.ndarray, y: np.ndarray, k: int, seed: int = 42) -> np.ndarray:
    """
    Select up to k distinct indices balanced across classes.

    Each class contributes a random sample (without replacement) of up to
    ``max(1, k // n_classes)`` of its indices.  When there are more classes
    than slots, a random subset of k classes is used, so label order does
    not decide which classes are represented.  Any shortfall is filled with
    random indices that have not already been chosen.  Despite the name,
    the points are random class members, not true medoids.

    Args:
        X: Data matrix (n_samples, d) – unused except for its row count
        y: Labels (n_samples,)
        k: Number of indices to select
        seed: Random seed

    Returns:
        Array of distinct selected indices.  Its length is at most k, and
        it is shorter only when X has fewer than k rows.
    """
    rng = np.random.default_rng(seed)
    if k <= 0 or len(y) == 0:
        return np.empty((0,), dtype=int)
    classes = np.unique(y)
    if classes.size > k:
        # One slot per class cannot cover them all; choose which classes
        # get a slot at random instead of always favouring the lowest labels.
        classes = np.sort(rng.choice(classes, size=k, replace=False))
    per_class = max(1, k // classes.size)
    idxs = []
    for c in classes:
        members = np.flatnonzero(y == c)
        pick = rng.choice(members, size=min(per_class, members.size), replace=False)
        idxs.extend(pick.tolist())
    # Fill any shortfall from indices not yet chosen, so no index repeats and
    # a request larger than the data does not raise.
    if len(idxs) < k:
        pool = np.setdiff1d(np.arange(X.shape[0]), idxs)
        n_extra = min(k - len(idxs), pool.size)
        if n_extra > 0:
            idxs.extend(rng.choice(pool, size=n_extra, replace=False).tolist())
    return np.array(idxs[:k], dtype=int)


def refresh_anchors(
    X_window: np.ndarray,
    y_window: np.ndarray,
    anchors: np.ndarray,
    budget_m: int,
    refresh_frac: float = 0.25,
    seed: int = 42,
) -> np.ndarray:
    """
    Refresh a set of anchors by retaining a fraction of existing anchors and
    replenishing the remainder using a hybrid selection strategy.

    A random subset of round((1 - refresh_frac) * budget_m) existing anchors
    is kept.  That count is clamped to [0, min(budget_m, m_old)].  The rest
    of the budget is filled from the current window as follows:
    - roughly 50% come from k-means centroids;
    - roughly 30% come from farthest-point sampling;
    - the remainder come from class-balanced random samples.

    Each share is capped at the window size.  Any resulting shortfall is
    filled with random window rows.

    Args:
        X_window: Current window features (n_window, d)
        y_window: Current window labels
        anchors: Existing anchor points (array of shape (m_old, d))
        budget_m: Desired total number of anchors after refresh
        refresh_frac: Fraction of anchors to refresh (0 <= refresh_frac <= 1)
        seed: Random seed

    Returns:
        Updated array of anchors (shape (max(budget_m, 0), d)): kept anchors
        first, then the new ones.

    Raises:
        ValueError: if new anchors are needed but X_window has no rows.
    """
    rng = np.random.default_rng(seed)
    dim = X_window.shape[1]
    # Clamp so an out-of-range refresh_frac (e.g. negative) can never keep
    # more than budget_m anchors.
    keep = int(round((1.0 - refresh_frac) * budget_m))
    keep = max(0, min(keep, budget_m, anchors.shape[0]))
    if keep > 0:
        keep_indices = rng.choice(np.arange(anchors.shape[0]), size=keep, replace=False)
        kept = anchors[keep_indices]
    else:
        kept = np.empty((0, dim))
    needed = budget_m - kept.shape[0]
    if needed <= 0:
        return kept

    n_win = X_window.shape[0]
    if n_win == 0:
        raise ValueError("X_window is empty; cannot generate new anchors")

    # Partition the additions.  Each share is capped at the window size:
    # KMeans cannot fit more clusters than samples, and the samplers cannot
    # return more distinct rows than exist.
    n_kmeans = max(1, int(0.5 * needed))
    n_farthest = int(0.3 * needed)
    n_medoid = max(0, needed - n_kmeans - n_farthest)
    n_kmeans = min(n_kmeans, n_win)
    n_farthest = min(n_farthest, n_win)
    n_medoid = min(n_medoid, n_win)

    # k-means centroids
    if n_kmeans > 0:
        km = KMeans(n_clusters=n_kmeans, n_init='auto', random_state=seed).fit(X_window)
        centroids = km.cluster_centers_
    else:
        centroids = np.empty((0, dim))
    # Farthest-point sampling
    if n_farthest > 0:
        far_points = X_window[farthest_points(X_window, n_farthest, seed=seed)]
    else:
        far_points = np.empty((0, dim))
    # Class-balanced random samples
    if n_medoid > 0:
        med_points = X_window[class_balanced_medoids(X_window, y_window, n_medoid, seed=seed)]
    else:
        med_points = np.empty((0, dim))

    new_block = np.concatenate([centroids, far_points, med_points], axis=0)
    # Top up any shortfall with random window rows.  Replacement is used
    # only when the window is too small to supply enough distinct rows.
    if new_block.shape[0] < needed:
        extra_needed = needed - new_block.shape[0]
        extra_idx = rng.choice(
            np.arange(n_win), size=extra_needed, replace=extra_needed > n_win
        )
        new_block = np.concatenate([new_block, X_window[extra_idx]], axis=0)
    new_block = new_block[:needed]
    return np.concatenate([kept, new_block], axis=0)