"""
Utility functions for DRO-Lite importance weighting with EMA smoothing
and ESS-adaptive clipping.

These functions provide a more robust density-ratio estimation pipeline
for streaming scenarios, where raw estimates from a logistic discriminator
can vary widely from one window to the next.  Clipping based on the ESS
(effective sample size) stabilizes the importance weights.  The clipping
bounds are tightened only when a few points dominate, so less information
is discarded than with a fixed, aggressive clip.
"""

from typing import Dict, Optional, Tuple

import numpy as np
from sklearn.linear_model import LogisticRegression


def compute_density_ratio(
    X_target: np.ndarray,
    X_source: Optional[np.ndarray],
    *,
    correct_class_prior: bool = True,
) -> np.ndarray:
    """
    Estimate the density ratio r(x) = p(x|target) / p(x|source) at each
    target sample using a logistic discriminator.

    A classifier trained to separate target (label 1) from source (label 0)
    yields odds p(target|x) / p(source|x) = r(x) * (n_target / n_source).
    The odds are therefore multiplied by n_source / n_target to remove the
    class-prior factor.  Without this correction, the ratios would be
    scaled by the relative sizes of the two sample sets.

    If no source points are provided (``None`` or empty), all ratios default
    to 1.0.  An empty target set yields an empty array.

    Args:
        X_target: Array of target samples (shape (n_target, d))
        X_source: Array of source samples (shape (n_source, d)), or None
        correct_class_prior: If True (default), apply the n_source / n_target
            class-prior correction.  If False, return the raw discriminator
            odds p(target|x) / p(source|x).

    Returns:
        Array of density ratios (shape (n_target,))
    """
    n_target = len(X_target)
    # No history → uniform weights.  No target points → nothing to score,
    # and fitting on a single class would fail, so return early there too.
    if X_source is None or len(X_source) == 0 or n_target == 0:
        return np.ones(n_target, dtype=float)
    n_source = len(X_source)

    X_combined = np.vstack([X_target, X_source])
    # Labels: 1 for target, 0 for source
    y_combined = np.concatenate([np.ones(n_target), np.zeros(n_source)])
    clf = LogisticRegression(max_iter=1000, random_state=42)
    clf.fit(X_combined, y_combined)
    # Column 1 is p(y=1 | x), i.e. the target class (classes_ sorted ascending).
    p_target = clf.predict_proba(X_target)[:, 1]
    # Keep the denominator away from 0 so the odds stay finite.
    p_source = np.clip(1.0 - p_target, 1e-8, 1.0 - 1e-8)
    ratios = p_target / p_source
    if correct_class_prior:
        # Undo the n_target / n_source prior baked into the classifier odds.
        ratios = ratios * (n_source / n_target)
    return ratios


def ess(weights: np.ndarray, eps: float = 1e-12) -> float:
    """
    Return the effective sample size (ESS) of a weight vector.

    ESS = (sum w)^2 / sum(w^2).  It is (essentially) len(w) for perfectly
    uniform weights and about 1 when a single point carries all the mass.
    Negative entries are treated as zero.

    Args:
        weights: Weight vector (expected to be non-negative)
        eps: Added to both the sum and the sum of squares so that an
            all-zero or empty vector does not divide by zero

    Returns:
        ESS of the weights
    """
    nonneg = np.maximum(np.asarray(weights, dtype=float), 0.0)
    total = np.sum(nonneg) + eps
    total_sq = np.sum(nonneg * nonneg) + eps
    return total * total / total_sq


def adaptive_clip(
    weights: np.ndarray,
    target_ess_ratio: float = 0.7,
    n: Optional[int] = None,
    *,
    tight_bounds: Tuple[float, float] = (0.25, 4.0),
    loose_bounds: Tuple[float, float] = (0.1, 10.0),
) -> np.ndarray:
    """
    Clip weights adaptively based on their effective sample size (ESS).

    If ESS < target_ess_ratio * n, the weights are judged too concentrated,
    and the tighter ``tight_bounds`` are applied to reduce variance.
    Otherwise the looser ``loose_bounds`` are used.

    The bounds are interpreted relative to an average weight of 1.  Weights
    are rescaled to mean 1 before clipping, so the result does not depend on
    the arbitrary overall scale of the input.

    Args:
        weights: Raw weight vector; negative entries are treated as 0
        target_ess_ratio: Desired ESS ratio relative to the number of points
        n: Optional explicit number of points; defaults to len(weights).
            Used both for the ESS threshold and as the target sum of the
            output.
        tight_bounds: (lo, hi) clip bounds used when ESS is below target
        loose_bounds: (lo, hi) clip bounds used otherwise

    Returns:
        Clipped weight vector rescaled to sum to ~n (average weight ~1 when
        n == len(weights))
    """
    w = np.asarray(weights, dtype=float).clip(min=0.0)
    if n is None:
        n = len(w)
    cur_ess = ess(w)
    # Low ESS means a few points dominate → use the tighter bounds.
    lo, hi = tight_bounds if cur_ess < target_ess_ratio * n else loose_bounds
    # Put weights on a mean-1 scale so lo/hi mean "fraction / multiple of the
    # average weight" regardless of how the raw ratios happened to be scaled.
    total = w.sum()
    if total > 0:
        w = w * (w.size / total)
    w = np.clip(w, lo, hi)
    # Final normalization so the weights sum to ~n.
    s = w.sum() + 1e-12
    return (n * w) / s


def dro_lite_weights(
    X_target: np.ndarray,
    X_source: np.ndarray,
    ema_state: Dict[str, Optional[np.ndarray]],
    ema: float = 0.6,
    target_ess_ratio: float = 0.7,
) -> np.ndarray:
    """
    Compute DRO-Lite weights for a target set of points relative to a source
    (history) set of points.

    This performs logistic density-ratio estimation, then exponential moving
    average (EMA) smoothing across calls, then ESS-adaptive clipping.

    Args:
        X_target: Current window features
        X_source: Historical window features
        ema_state: Mutable dict carrying state between calls.  Key 'prev' is
            read (if present) and overwritten with this call's smoothed,
            pre-clipping ratios.
        ema: Weight given to the current window's raw ratios in the EMA
            (0 ≤ ema ≤ 1).  Higher values favour the new estimate; lower
            values favour the previously smoothed weights.  ema=1 disables
            smoothing.
        target_ess_ratio: Target ESS ratio controlling clipping aggressiveness

    Returns:
        Importance weights for the target samples (normalized so that
        sum=number of samples)

    Raises:
        ValueError: If ema is outside [0, 1].
    """
    if not 0.0 <= ema <= 1.0:
        raise ValueError(f"ema must be in [0, 1], got {ema!r}")
    # Compute raw ratios via logistic discriminator
    raw = compute_density_ratio(X_target, X_source)
    prev = ema_state.get('prev')
    # The EMA is element-wise, so it only applies when the previous window had
    # the same number of samples.  Otherwise restart from the raw estimate
    # instead of raising a broadcast error (or silently broadcasting a
    # length-1 'prev').
    if prev is None or np.shape(prev) != raw.shape:
        smoothed = raw
    else:
        smoothed = ema * raw + (1.0 - ema) * prev
    ema_state['prev'] = smoothed.copy()
    # ESS-adaptive clipping
    return adaptive_clip(smoothed, target_ess_ratio=target_ess_ratio, n=len(smoothed))