"""
Feature extraction via binary expansion for CoBET methods.
"""
import numpy as np
import itertools
from scipy.stats import rankdata


def bits_from_uniform(u, K):
    """
    Convert uniform samples to K-bit binary representation.

    Each sample is mapped to the dyadic cell index ``floor(u * 2^K)``,
    clipped to ``[0, 2^K - 1]``. The K binary digits of that index are
    returned, most significant bit first.

    Parameters
    ----------
    u : array_like, shape (n,)
        Uniform(0,1) samples. Values outside [0, 1) are clipped into the
        extreme cells. ``u >= 1`` maps to all ones and ``u < 0`` to all zeros.
    K : int
        Number of bits (dyadic depth), K >= 0.

    Returns
    -------
    bits : np.ndarray, shape (K, n)
        Binary representation, most significant bit first. For ``K == 0`` the
        result has shape ``(0, n)``.
    """
    u = np.asarray(u)
    M = 1 << K  # 2^K dyadic cells
    # Clip on both sides: a negative code would otherwise expose its
    # two's-complement bits (all ones) instead of landing in cell 0.
    z = np.clip((u * M).astype(int), 0, M - 1)

    # One shift amount per output row (MSB first), broadcast against z so the
    # leading axis is always K, including the degenerate K == 0 case.
    shifts = np.arange(K - 1, -1, -1).reshape((K,) + (1,) * z.ndim)
    bits = ((z >> shifts) & 1).astype(int)
    return bits  # (K, n)


def all_nonempty_subsets_indices(K):
    """
    List every non-empty subset of {1, 2, ..., K}.

    Subsets are ordered by size first. Within one size they follow
    ``itertools.combinations`` (lexicographic) order.

    Parameters
    ----------
    K : int
        Size of the ground set {1, ..., K}.

    Returns
    -------
    subsets : list of tuples
        The 2^K - 1 non-empty subsets, each a sorted tuple of 1-based indices.
        Empty when K <= 0.

    Examples
    --------
    >>> all_nonempty_subsets_indices(2)
    [(1,), (2,), (1, 2)]
    """
    ground_set = range(1, K + 1)
    by_size = (itertools.combinations(ground_set, size) for size in range(1, K + 1))
    return list(itertools.chain.from_iterable(by_size))


def features_by_u(u, K, subsets):
    """
    Construct centered indicator features from uniform samples.

    For each subset S, compute the product of the bits indexed by S (i.e. the
    indicator that all those bits equal 1), then subtract 2^(-|S|). That
    constant is the indicator's mean when u is exactly Uniform(0,1) and the
    indices in S are distinct.

    Parameters
    ----------
    u : array_like, shape (n,)
        Uniform(0,1) samples.
    K : int
        Dyadic depth.
    subsets : sequence of tuples
        Subsets to use (from all_nonempty_subsets_indices). Each index is
        1-based and must lie in 1..K, with no repeats within a subset.

    Returns
    -------
    F : np.ndarray, shape (len(subsets), n)
        Centered feature matrix.

    Raises
    ------
    ValueError
        If a subset contains an index outside 1..K or a repeated index.
    """
    u = np.asarray(u)
    bits = bits_from_uniform(u, K)  # (K, n), most significant bit first
    n = u.shape[0]
    F = np.empty((len(subsets), n), dtype=float)

    for i, S in enumerate(subsets):
        rows = [r - 1 for r in S]  # Convert 1-indexed to 0-indexed
        # Guard explicitly: index 0 would become row -1 and silently select the
        # last bit via negative indexing.
        if any(not 0 <= row < K for row in rows):
            raise ValueError(f"subset {S!r} contains an index outside 1..{K}")
        # A repeated index leaves the product unchanged but shrinks the assumed
        # mean 2^(-|S|), so the centering would be wrong.
        if len(set(rows)) != len(rows):
            raise ValueError(f"subset {S!r} contains a repeated index")
        ind = np.prod(bits[rows, :], axis=0)
        F[i, :] = ind.astype(float) - (2.0 ** (-len(S)))

    return F


def ranks_to_uniforms(X):
    """
    Map each column of a data matrix to pseudo-uniforms via its ranks.

    Every column is ranked independently; tied values receive their average
    rank. Ranks are divided by ``n + 1`` so all outputs lie strictly in (0, 1).

    Parameters
    ----------
    X : np.ndarray, shape (n, d)
        Data matrix.

    Returns
    -------
    U : np.ndarray, shape (n, d)
        Float64 pseudo-uniform samples.
    """
    n_rows, _ = X.shape  # also rejects inputs that are not 2-D
    column_ranks = rankdata(X, axis=0)  # ranks computed down each column
    return column_ranks / (n_rows + 1.0)


def build_AB_features(X, Y, K, subsets=None):
    """
    Build feature matrices A and B from data matrices X and Y.

    Each matrix is rank-transformed to pseudo-uniforms separately, column by
    column. Every column then yields one block of centered binary-expansion
    features. The blocks are stacked in column order.

    Parameters
    ----------
    X : array_like, shape (n_X, d)
        First data matrix.
    Y : array_like, shape (n_Y, d)
        Second data matrix. Must have the same number of columns as X.
    K : int
        Dyadic depth.
    subsets : iterable of tuples, optional
        Subsets to use. If None, uses all non-empty subsets of {1, ..., K}.

    Returns
    -------
    A : np.ndarray, shape (d * len(subsets), n_X)
        Features for X. With the default subsets, len(subsets) = 2^K - 1.
    B : np.ndarray, shape (d * len(subsets), n_Y)
        Features for Y.

    Raises
    ------
    ValueError
        If X or Y is not 2-D, or if their column counts differ.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError("X and Y must be 2-D arrays of shape (n, d)")
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"X and Y must have the same number of columns, "
            f"got {X.shape[1]} and {Y.shape[1]}"
        )

    # Materialize once: the subsets are reused for every column of both
    # matrices, and features_by_u needs len().
    if subsets is None:
        subsets = all_nonempty_subsets_indices(K)
    else:
        subsets = list(subsets)

    def _stacked_features(data):
        # Rank-transform, featurize each column, and stack the per-column blocks.
        U = ranks_to_uniforms(data)
        blocks = [features_by_u(U[:, r], K, subsets) for r in range(U.shape[1])]
        if not blocks:  # d == 0: np.vstack would reject an empty list
            return np.empty((0, U.shape[0]), dtype=float)
        return np.vstack(blocks)

    A = _stacked_features(X)  # (d * len(subsets), n_X)
    B = _stacked_features(Y)  # (d * len(subsets), n_Y)
    return A, B


def block_view(M, r, base_dim):
    """
    Extract coordinate block r from stacked feature matrix.

    Parameters
    ----------
    M : np.ndarray, shape (d * base_dim, n)
        Stacked feature matrix.
    r : int
        Coordinate index (0-indexed), 0 <= r < M.shape[0] // base_dim.
    base_dim : int
        Features per coordinate (typically 2^K - 1). Must be positive.

    Returns
    -------
    M_r : np.ndarray, shape (base_dim, n)
        Features for coordinate r. This is a basic slice, so it is a view
        into M rather than a copy.

    Raises
    ------
    ValueError
        If base_dim is not positive.
    IndexError
        If block r does not fit entirely inside M.
    """
    if base_dim <= 0:
        raise ValueError(f"base_dim must be positive, got {base_dim}")
    start, stop = r * base_dim, (r + 1) * base_dim
    # Without this check, negative or too-large r gives an empty or truncated
    # slice instead of failing.
    if r < 0 or stop > M.shape[0]:
        raise IndexError(
            f"block {r} (rows {start}:{stop}) is out of range for "
            f"{M.shape[0]} rows"
        )
    return M[start:stop, :]
