"""
resampler.py - Resampling algorithms for imbalanced regression

Independent implementation of resampling techniques for regression problems
with imbalanced target distributions.

References:
    [1] Torgo, L. et al. (2015). Resampling strategies for regression.
    [2] Branco, P., Torgo, L., and Ribeiro, R.P. (2019). Pre-processing approaches 
        for imbalanced distributions in regression.
"""

import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.neighbors import KernelDensity
from typing import Tuple, Optional, Union, Literal


# =============================================================================
# Relevance Functions
# =============================================================================

def _stable_sigmoid(y: np.ndarray, slope: float, center: float) -> np.ndarray:
    """
    Compute sigmoid with numerical stability.
    
    Formula: 1 / (1 + exp(-slope * (y - center)))
    
    Uses different formulations based on sign of exponent to avoid overflow.
    """
    y = np.squeeze(np.asarray(y))
    exponent = -slope * (y - center)
    exponent = np.clip(exponent, -700, 700)
    
    # Stable computation: different formula for positive/negative exponents
    positive_mask = exponent >= 0
    result = np.empty_like(exponent, dtype=np.float64)
    result[positive_mask] = 1.0 / (1.0 + np.exp(exponent[positive_mask]))
    result[~positive_mask] = np.exp(exponent[~positive_mask]) / (1.0 + np.exp(exponent[~positive_mask]))
    
    return result


def sigmoid_relevance(y: np.ndarray, cl: Optional[float], ch: Optional[float]) -> np.ndarray:
    """
    Score target values with sigmoid curves centred near the extreme cut-offs.

    A margin of 0.1% of the standard deviation of ``y`` is subtracted from
    ``cl`` and added to ``ch`` before they are used as sigmoid centres. Each
    curve's slope magnitude is ``ln(1e4 - 1) / |centre|``; the lower curve is
    given a negative slope and the upper curve a positive one.

    Parameters
    ----------
    y : array-like
        Target values.
    cl : float or None
        Lower extreme centre. Pass None when only upper extremes matter.
    ch : float or None
        Upper extreme centre. Pass None when only lower extremes matter.
        ``cl`` and ``ch`` must not both be None.

    Returns
    -------
    ndarray
        The ``_stable_sigmoid`` curve for the single requested extreme, or the
        sum of the lower and upper curves when both cut-offs are given.
    """
    targets = np.squeeze(np.asarray(y))
    margin = 0.001 * np.std(targets)
    log_odds = np.log(1e4 - 1)

    # Only the upper tail is relevant.
    if cl is None:
        upper_center = ch + margin
        return _stable_sigmoid(targets, abs(log_odds / upper_center), upper_center)

    # Only the lower tail is relevant.
    if ch is None:
        lower_center = cl - margin
        return _stable_sigmoid(targets, -abs(log_odds / lower_center), lower_center)

    # Both tails: add the decreasing lower curve to the increasing upper one.
    lower_center = cl - margin
    upper_center = ch + margin
    upper_slope = abs(log_odds / upper_center)
    lower_slope = -abs(log_odds / lower_center)
    lower_curve = _stable_sigmoid(targets, lower_slope, lower_center)
    upper_curve = _stable_sigmoid(targets, upper_slope, upper_center)
    return lower_curve + upper_curve


def pdf_relevance(y: np.ndarray, bandwidth: float = 1.0) -> np.ndarray:
    """
    Derive relevance from an inverted, min-max scaled kernel density estimate.

    A Gaussian ``KernelDensity`` is fitted on ``y`` and evaluated at the same
    points; the densities are rescaled to [0, 1] and subtracted from 1, so the
    lowest-density value scores 1 and the highest-density value scores 0.

    Parameters
    ----------
    y : array-like
        Target values.
    bandwidth : float
        Bandwidth passed to ``KernelDensity``.

    Returns
    -------
    ndarray
        Relevance per sample. A constant array of 0.5 is returned (after
        printing a warning) when the densities span at most 1e-12 or when the
        rescaled values contain non-finite entries.
    """
    targets = np.squeeze(np.asarray(y))
    column = targets.reshape(len(targets), 1)

    estimator = KernelDensity(bandwidth=bandwidth, kernel='gaussian')
    estimator.fit(column)
    # score_samples returns log-densities; exponentiate to get densities.
    log_density = estimator.score_samples(column)
    density = np.exp(log_density)

    lowest = density.min()
    density_range = density.max() - lowest

    # Degenerate spread: min-max scaling would divide by (almost) zero.
    if density_range <= 1e-12:
        print(f"Warning: PDF values are nearly identical (range: {density_range:.2e}). Returning uniform relevance.")
        return np.full_like(density, 0.5)

    scaled = (density - lowest) / density_range
    relevance = 1 - scaled

    if not np.all(np.isfinite(relevance)):
        print("Warning: Invalid values found in relevance calculation. Replacing with default.")
        return np.full_like(density, 0.5)

    return relevance


# =============================================================================
# Core Utilities
# =============================================================================

def _get_knn_indices(features: np.ndarray, k: int) -> np.ndarray:
    """
    Find k nearest neighbors for each sample.
    
    Returns indices of k closest samples (excluding self).
    """
    distances = squareform(pdist(features))
    sorted_indices = np.argsort(distances, axis=1)
    return sorted_indices[:, 1:k+1]


def _split_by_relevance(
    X: np.ndarray, 
    y: np.ndarray, 
    relevance: np.ndarray, 
    threshold: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Partition dataset into normal and rare domains.
    
    Rare: relevance >= threshold
    Normal: relevance < threshold
    """
    X = np.asarray(X)
    y = np.squeeze(np.asarray(y))
    relevance = np.squeeze(np.asarray(relevance))
    
    assert len(X) == len(y) == len(relevance), 'X, y, relevance must have same length'
    
    rare_idx = np.where(relevance >= threshold)[0]
    normal_idx = np.where(relevance < threshold)[0]
    
    if len(rare_idx) >= len(normal_idx):
        raise AssertionError(
            f'Rare domain must be smaller than normal domain. '
            f'Current: rare={len(rare_idx)}, normal={len(normal_idx)}, '
            f'threshold={threshold:.4f}, relevance_range=[{relevance.min():.4f}, {relevance.max():.4f}]. '
            f'Try increasing relevance_threshold (use a higher percentile).'
        )
    
    return X[normal_idx], y[normal_idx], X[rare_idx], y[rare_idx]


def _calculate_new_sizes(
    n_normal: int, 
    n_rare: int, 
    over: Union[float, str], 
    under: Optional[float]
) -> Tuple[int, int]:
    """
    Determine target sizes after resampling.
    """
    if isinstance(over, (int, float)) and not isinstance(over, str):
        assert isinstance(under, (int, float)), 'under must also be a float if over is a float'
        assert 0 <= under <= 1, 'under must be between 0 and 1'
        assert over >= 0, 'over must be non-negative'
        return int((1 - under) * n_normal), int((1 + over) * n_rare)
    
    elif over == 'balance':
        size = int((n_normal + n_rare) / 2)
        return size, size
    
    elif over == 'extreme':
        return n_rare, n_normal
    
    elif over == 'average':
        mid = int((n_normal + n_rare) / 2)
        return int((mid + n_rare) / 2), int((mid + n_normal) / 2)
    
    else:
        raise ValueError(f"Invalid 'over': {over}. Use float or 'balance'/'extreme'/'average'")


def _subsample(X: np.ndarray, y: np.ndarray, size: int, seed: int = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Randomly select subset of samples.
    """
    assert len(X) == len(y), 'X and y must have same length'
    
    if size >= len(y):
        raise ValueError(f'size ({size}) must be smaller than data length ({len(y)})')
    
    np.random.seed(seed)
    indices = np.random.choice(range(len(y)), size, replace=False)
    return X[indices], y[indices]


def _generate_synthetic_via_interpolation(
    X: np.ndarray, 
    y: np.ndarray, 
    k: int, 
    count: int, 
    nominal_features: np.ndarray = None, 
    seed: int = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create synthetic samples by interpolating between samples and their neighbors.

    For each new sample:
    1. Pick a random existing sample
    2. Pick one of its k neighbors
    3. Interpolate features and target

    Parameters
    ----------
    X : ndarray (n_samples, n_features)
        Feature matrix of the domain to oversample.
    y : ndarray (n_samples,)
        Matching target values.
    k : int
        Number of nearest neighbors considered per sample.
    count : int
        Number of synthetic samples to create.
    nominal_features : array-like, optional
        Column indices of categorical features. For these columns the new
        value is copied from either the base sample or the neighbor instead
        of being blended.
    seed : int, optional
        Seed applied to NumPy's global random generator before sampling.

    Returns
    -------
    X_synthetic, y_synthetic : tuple of ndarrays
        When ``X`` holds a single sample there is nothing to interpolate
        towards, so that sample is replicated ``count`` times.

    Raises
    ------
    AssertionError
        If ``X`` and ``y`` differ in length.
    """
    X = np.asarray(X)
    # atleast_1d: squeezing a single-sample target would otherwise produce a
    # 0-d array on which len() and indexing fail.
    y = np.atleast_1d(np.squeeze(np.asarray(y)))
    if len(X) != len(y):
        raise AssertionError('X and y must have same length')

    np.random.seed(seed=seed)

    # A lone sample has no neighbor; replicate it rather than failing when
    # drawing from an empty neighbor list.
    if len(y) == 1:
        return np.repeat(X, count, axis=0), np.repeat(y, count)

    neighbors = _get_knn_indices(X, k)
    base_samples = np.random.choice(range(len(y)), count, replace=True)

    # Resolve the categorical columns once instead of once per synthetic row.
    if nominal_features is None:
        nominal_cols = []
    else:
        nominal_cols = [col for col in range(X.shape[1]) if col in nominal_features]

    X_synthetic, y_synthetic = [], []

    for idx in base_samples:
        # Select random neighbor
        neighbor_idx = np.random.choice(neighbors[idx])

        # One shared interpolation weight for every feature
        weight = np.random.rand() * np.ones_like(X[idx])

        # Categorical features take a 0/1 weight, i.e. the value of either
        # the neighbor (0) or the base sample (1).
        for col in nominal_cols:
            weight[col] = np.random.choice([0, 1])

        # Interpolate features: weight 0 -> neighbor, weight 1 -> base sample
        x_new = X[neighbor_idx] + (X[idx] - X[neighbor_idx]) * weight

        # Interpolate target, weighting each endpoint by the distance of the
        # new point to the *other* endpoint (closer endpoint counts more).
        d1 = np.linalg.norm(x_new - X[idx])
        d2 = np.linalg.norm(x_new - X[neighbor_idx])
        total = d1 + d2
        if total == 0:
            # Base sample and neighbor coincide in feature space; the
            # distance weights are both zero, so use the plain average
            # instead of letting the target collapse to 0.
            y_new = 0.5 * (y[idx] + y[neighbor_idx])
        else:
            y_new = (d2 * y[idx] + d1 * y[neighbor_idx]) / (total + 1e-10)

        X_synthetic.append(x_new)
        y_synthetic.append(y_new)

    return np.array(X_synthetic), np.array(y_synthetic)


def _generate_synthetic_via_noise(
    X: np.ndarray, 
    y: np.ndarray, 
    delta: float, 
    count: int, 
    nominal_features: np.ndarray = None, 
    seed: int = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create synthetic samples by adding Gaussian noise to existing samples.
    
    Noise magnitude = delta * feature_std
    """
    X, y = np.asarray(X), np.squeeze(np.asarray(y))
    assert len(X) == len(y), 'X and y must have same length'
    
    np.random.seed(seed=seed)
    indices = np.random.choice(range(len(y)), count, replace=True)
    
    X_selected = X[indices]
    y_selected = y[indices]
    
    # Compute standard deviations
    feature_stds = np.std(X, axis=0)
    target_std = np.std(y)
    
    # Generate noise
    X_noise = np.array([
        [np.random.normal(0.0, std * delta) for std in feature_stds]
        for _ in range(count)
    ])
    y_noise = np.random.normal(0.0, target_std * delta, count)
    
    X_synthetic = X_selected + X_noise
    y_synthetic = y_selected + y_noise
    
    # Handle categorical features
    if nominal_features is not None:
        for col in range(X.shape[1]):
            if col in nominal_features:
                vals, freqs = np.unique(X[:, col], return_counts=True)
                probs = freqs / freqs.sum()
                X_synthetic[:, col] = np.random.choice(vals, size=count, p=probs, replace=True)
    
    return X_synthetic, y_synthetic


def _expand_rare_domain(
    X_rare: np.ndarray, 
    y_rare: np.ndarray, 
    target_size: int, 
    method: str, 
    k: int = None, 
    delta: float = None, 
    nominal: np.ndarray = None, 
    seed: int = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Grow the rare domain to ``target_size`` samples with synthetic data.

    Parameters
    ----------
    X_rare, y_rare : ndarray
        Features and targets of the rare domain.
    target_size : int
        Desired number of samples. If the domain already has at least this
        many samples it is returned unchanged.
    method : str
        'smoter' (neighbor interpolation, needs ``k``) or 'gaussian'
        (noise perturbation, needs ``delta``).
    k : int, optional
        Neighbor count for 'smoter'.
    delta : float, optional
        Noise multiplier for 'gaussian'.
    nominal : ndarray, optional
        Categorical column indices, forwarded to the generator.
    seed : int, optional
        Seed for NumPy's global random generator, also forwarded to the
        generator.

    Returns
    -------
    X_expanded, y_expanded : tuple of ndarrays
        Original rare samples followed by the synthetic ones.

    Raises
    ------
    ValueError
        If ``method`` is unknown or its required parameter is None.
    """
    n_missing = target_size - len(y_rare)

    # Already large enough: nothing to generate.
    if n_missing <= 0:
        return X_rare, y_rare

    np.random.seed(seed=seed)

    if method == 'smoter':
        if k is None:
            raise ValueError("k required for SMOTER")
        generated = _generate_synthetic_via_interpolation(
            X_rare, y_rare, k, n_missing, nominal, seed
        )
    elif method == 'gaussian':
        if delta is None:
            raise ValueError("delta required for Gaussian noise")
        generated = _generate_synthetic_via_noise(
            X_rare, y_rare, delta, n_missing, nominal, seed
        )
    else:
        raise ValueError(f"Unknown method: {method}")

    X_extra, y_extra = generated

    # Stack features row-wise; targets are flattened and joined end to end.
    X_expanded = np.concatenate((X_rare, X_extra), axis=0)
    y_expanded = np.concatenate((np.ravel(y_rare), np.ravel(y_extra)))
    return X_expanded, y_expanded


# =============================================================================
# Main API Functions
# =============================================================================

def smoter(
    X: np.ndarray,
    y: np.ndarray,
    relevance: np.ndarray,
    relevance_threshold: float = 0.5,
    k: int = 5,
    over: Union[float, Literal['balance', 'extreme', 'average']] = 'balance',
    under: Optional[float] = None,
    nominal: Optional[np.ndarray] = None,
    random_state: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    SMOTER: Synthetic Minority Over-sampling Technique for Regression.

    Handles imbalanced regression by:
    - Oversampling rare domain via k-NN interpolation
    - Undersampling normal domain

    Parameters
    ----------
    X : ndarray (n_samples, n_features)
        Feature matrix.
    y : ndarray (n_samples,)
        Target values.
    relevance : ndarray (n_samples,)
        Relevance score per sample. Higher = more important.
    relevance_threshold : float
        Boundary between normal (< threshold) and rare (>= threshold) domains.
    k : int
        Neighbors for interpolation.
    over : float or str
        Resampling degree. Float for explicit ratio, or 'balance'/'extreme'/'average'.
    under : float, optional
        Undersampling ratio (required if over is float). 0 keeps every
        normal sample.
    nominal : ndarray, optional
        Indices of categorical features.
    random_state : int, optional
        Seed for reproducibility.

    Returns
    -------
    X_resampled, y_resampled : tuple of ndarrays
        Rare samples (original plus synthetic) first, followed by the
        retained normal samples.
    """
    X = np.asarray(X)
    y = np.squeeze(np.asarray(y))
    relevance = np.squeeze(np.asarray(relevance))

    # Partition data
    X_normal, y_normal, X_rare, y_rare = _split_by_relevance(X, y, relevance, relevance_threshold)
    n_normal, n_rare = len(y_normal), len(y_rare)

    # Target sizes
    new_normal, new_rare = _calculate_new_sizes(n_normal, n_rare, over, under)

    # Rare samples below and above the overall median are oversampled
    # separately so interpolation never mixes the two tails.
    median_y = np.median(y)
    X_parts, y_parts = [], []

    for mask in (y_rare < median_y, y_rare >= median_y):
        n_part = mask.sum()
        if n_part == 0:
            continue
        # Each tail receives its proportional share of the rare target size.
        # NOTE: int() truncates, so the two shares can total one less than
        # new_rare.
        part_size = int(n_part / n_rare * new_rare)
        X_part, y_part = _expand_rare_domain(
            X_rare[mask], y_rare[mask], part_size,
            method='smoter', k=k, nominal=nominal, seed=random_state
        )
        X_parts.append(X_part)
        y_parts.append(y_part)

    # Combine rare parts (fall back to the untouched rare domain if empty)
    if X_parts:
        X_rare_new = np.vstack(X_parts)
        y_rare_new = np.concatenate(y_parts)
    else:
        X_rare_new, y_rare_new = X_rare, y_rare

    # Undersample normal domain. When nothing has to be removed (e.g.
    # under=0) keep it as-is: _subsample rejects a size that is not smaller
    # than the data, which previously made under=0 raise.
    if new_normal < n_normal:
        X_normal_new, y_normal_new = _subsample(X_normal, y_normal, new_normal, random_state)
    else:
        X_normal_new, y_normal_new = X_normal, y_normal

    # Combine final result
    X_final = np.append(X_rare_new, X_normal_new, axis=0)
    y_final = np.append(y_rare_new, y_normal_new)

    return (X_final, y_final)


def gaussian_noise(
    X: np.ndarray,
    y: np.ndarray,
    relevance: np.ndarray,
    relevance_threshold: float = 0.5,
    delta: float = 0.05,
    over: Union[float, Literal['balance', 'extreme', 'average']] = None,
    under: Optional[float] = None,
    nominal: Optional[np.ndarray] = None,
    random_state: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Resample imbalanced regression data by adding Gaussian noise.

    Handles imbalanced regression by:
    - Oversampling rare domain via Gaussian perturbation
    - Undersampling normal domain

    Parameters
    ----------
    X : ndarray (n_samples, n_features)
        Feature matrix.
    y : ndarray (n_samples,)
        Target values.
    relevance : ndarray (n_samples,)
        Relevance score per sample.
    relevance_threshold : float
        Boundary between normal (< threshold) and rare (>= threshold) domains.
    delta : float
        Noise magnitude as fraction of std.
    over : float or str, optional
        Resampling degree. Float for explicit ratio, or
        'balance'/'extreme'/'average'. None is treated as 'balance'.
    under : float, optional
        Undersampling ratio (required if over is float). 0 keeps every
        normal sample.
    nominal : ndarray, optional
        Indices of categorical features.
    random_state : int, optional
        Seed for reproducibility.

    Returns
    -------
    X_resampled, y_resampled : tuple of ndarrays
        Rare samples (original plus synthetic) first, followed by the
        retained normal samples.
    """
    if over is None:
        over = 'balance'

    X = np.asarray(X)
    y = np.squeeze(np.asarray(y))
    relevance = np.squeeze(np.asarray(relevance))

    # Partition data
    X_normal, y_normal, X_rare, y_rare = _split_by_relevance(X, y, relevance, relevance_threshold)
    n_normal, n_rare = len(y_normal), len(y_rare)

    # Target sizes
    new_normal, new_rare = _calculate_new_sizes(n_normal, n_rare, over, under)

    # Rare samples below and above the overall median are oversampled
    # separately, so noise scales are estimated per tail.
    median_y = np.median(y)
    X_parts, y_parts = [], []

    for mask in (y_rare < median_y, y_rare >= median_y):
        n_part = mask.sum()
        if n_part == 0:
            continue
        # Each tail receives its proportional share of the rare target size.
        # NOTE: int() truncates, so the two shares can total one less than
        # new_rare.
        part_size = int(n_part / n_rare * new_rare)
        X_part, y_part = _expand_rare_domain(
            X_rare[mask], y_rare[mask], part_size,
            method='gaussian', delta=delta, nominal=nominal, seed=random_state
        )
        X_parts.append(X_part)
        y_parts.append(y_part)

    # Combine rare parts (fall back to the untouched rare domain if empty)
    if X_parts:
        X_rare_new = np.vstack(X_parts)
        y_rare_new = np.concatenate(y_parts)
    else:
        X_rare_new, y_rare_new = X_rare, y_rare

    # Undersample normal domain. When nothing has to be removed (e.g.
    # under=0) keep it as-is: _subsample rejects a size that is not smaller
    # than the data, which previously made under=0 raise.
    if new_normal < n_normal:
        X_normal_new, y_normal_new = _subsample(X_normal, y_normal, new_normal, random_state)
    else:
        X_normal_new, y_normal_new = X_normal, y_normal

    # Combine final result
    X_final = np.append(X_rare_new, X_normal_new, axis=0)
    y_final = np.append(y_rare_new, y_normal_new)

    return (X_final, y_final)
