

import numpy as np
from scipy.optimize import nnls, lsq_linear


def get_square_aligned_bins_fixed():
    """
    Build the fixed table of square-aligned degree bins.

    Bin number m (m = 1..15) covers the degrees from (m-1)^2 + 1 up to and
    including m^2, so every bin ends on a perfect square: (1, 1), (2, 4),
    (5, 9), ..., (197, 225).

    Returns:
        List of 15 (start, end) tuples with inclusive integer bounds
    """
    # Each bin starts one past the previous perfect square and ends on the next.
    return [((m - 1) ** 2 + 1, m ** 2) for m in range(1, 16)]


def get_square_aligned_bins(k_max):
    """
    Select the square-aligned bins that cover degrees 0..k_max.

    The result always begins with a dedicated (0, 0) bin, followed by the
    bins of the fixed table in order. A bin that extends past k_max is cut
    off at k_max and ends the list. Only the fixed table is consulted, so
    degrees beyond its last bin are never covered.

    Args:
        k_max: Maximum degree to include

    Returns:
        List of (start, end) tuples up to k_max
    """
    # The degree-0 bin is not part of the fixed table; it always comes first.
    selected = [(0, 0)]

    for lo, hi in get_square_aligned_bins_fixed():
        if lo > k_max:
            # This bin (and every later one) lies entirely above k_max.
            break
        if hi > k_max:
            # Partial overlap: keep only the part up to k_max, then stop.
            selected.append((lo, k_max))
            break
        selected.append((lo, hi))

    return selected


def get_binning_for_T(T):
    """
    Assemble the square-aligned binning description for temporal depth T.

    The maximum degree is k_max = (2*T + 1)^2. Each bin gets a text label
    ("s" for a single-degree bin, "s-e" otherwise) and a centre equal to the
    midpoint of its inclusive bounds.

    Args:
        T: Temporal depth

    Returns:
        dict with keys: 'bins', 'bin_labels', 'bin_centers', 'k_max', 'n_bins'
    """
    max_degree = (2 * T + 1) ** 2
    selected = get_square_aligned_bins(max_degree)

    # Single-degree bins are labelled by that degree alone, ranges as "lo-hi".
    labels = [f"{lo}" if lo == hi else f"{lo}-{hi}" for lo, hi in selected]
    centers = [(lo + hi) / 2.0 for lo, hi in selected]

    return {
        'bins': selected,
        'bin_labels': labels,
        'bin_centers': np.array(centers),
        'k_max': max_degree,
        'n_bins': len(selected)
    }





def create_noise_sensitivity_matrix(bins, delta_values, n=256):
    """
    Create design matrix A where A[i,j] relates NS at delta[i] to W_k in bin j.

    NS_delta = sum_k W_k * [1 - (1-2*delta)^k]

    Each entry is the per-degree factor [1 - (1-2*delta)^k] averaged over all
    degrees k in the bin (every degree in a bin gets equal weight). Degree 0
    needs no special handling: (1-2*delta)^0 == 1, so its factor is exactly 0.

    Args:
        bins: List of (start, end) tuples with inclusive bounds, start <= end
        delta_values: Array-like of noise rates
        n: Number of input variables (256 for 16x16 patches). Not used by the
            computation; kept so existing callers that pass it keep working.

    Returns:
        A: Design matrix of shape (len(delta_values), len(bins))

    Raises:
        ValueError: If a bin has end < start (it would contain no degrees).
    """
    deltas = np.asarray(delta_values, dtype=float).ravel()
    # Correlation factor rho = 1 - 2*delta, as a column for broadcasting.
    rho = (1.0 - 2.0 * deltas)[:, np.newaxis]

    A = np.zeros((deltas.size, len(bins)))

    for j, (start, end) in enumerate(bins):
        if end < start:
            raise ValueError(
                f"Invalid bin ({start}, {end}): end must be >= start"
            )
        degrees = np.arange(start, end + 1)
        # (n_delta, bin_width) table of 1 - rho^k, averaged across the bin.
        A[:, j] = (1.0 - rho ** degrees[np.newaxis, :]).mean(axis=1)

    return A


def fit_constrained_parseval(ns_values, A, weight_penalty=5.0, use_ridge=True, ridge_alpha=0.001):
    """
    Fit W_k coefficients under a soft Parseval constraint sum(W_k) = 1.

    The least-squares system A w ~ ns_values is extended with extra rows:
    optionally sqrt(ridge_alpha) * I against zeros (ridge shrinkage), and
    always one row of weight_penalty * ones against weight_penalty (pulls the
    sum of the weights towards 1). The extended system is solved by bounded
    least squares with 0 <= W_k <= 1, and the solution is then rescaled by
    its sum.

    Args:
        ns_values: Measured noise sensitivities
        A: Design matrix from create_noise_sensitivity_matrix
        weight_penalty: Weight for soft sum=1 constraint
        use_ridge: Whether to use ridge regularization
        ridge_alpha: Ridge regularization strength

    Returns:
        weights: Fitted W_k values, divided by (their sum + 1e-10)
        residual: Euclidean norm of A @ weights - ns_values, measured on the
            original (un-augmented) system with the rescaled weights
    """
    n_cols = A.shape[1]

    if use_ridge:
        # Ridge as extra rows: sqrt(alpha) * I w ~ 0 adds alpha * ||w||^2.
        ridge_rows = np.sqrt(ridge_alpha) * np.eye(n_cols)
        design = np.vstack([A, ridge_rows])
        target = np.append(ns_values, np.zeros(n_cols))
    else:
        design, target = A.copy(), ns_values.copy()

    # Soft Parseval row: penalty * sum(w) ~ penalty.
    parseval_row = weight_penalty * np.ones((1, n_cols))
    design = np.vstack([design, parseval_row])
    target = np.append(target, weight_penalty)

    # Box-constrained solve keeps every coefficient inside [0, 1].
    solution = lsq_linear(design, target, bounds=(0, 1), method='bvls', max_iter=10000)
    raw = solution.x

    # Rescale by the sum; the small epsilon avoids dividing by zero.
    weights = raw / (raw.sum() + 1e-10)

    # Misfit is reported against the original measurements only.
    residual = np.linalg.norm(A @ weights - ns_values)

    return weights, residual


def compute_derived_features(df, bins):
    """
    Compute derived features from W_k coefficients.

    Columns are looked up by the name "W_{start}_{end}" for each bin. Bins
    whose column is absent from df are skipped, and every column that is
    present is always combined with its own bin's bounds.

    Args:
        df: DataFrame with W_k columns
        bins: List of (start, end) tuples

    Returns:
        Copy of df with added columns: avg_degree, W_high, W_low, sum_Wk.
        If df has none of the expected W_k columns, the copy is returned
        without any added columns.
    """
    df = df.copy()

    # Keep each existing column paired with its bin so that a missing column
    # cannot shift the remaining columns onto the wrong bins.
    present = [
        (f"W_{start}_{end}", start, end)
        for start, end in bins
        if f"W_{start}_{end}" in df.columns
    ]

    if not present:
        # No W_k columns, return as-is
        return df

    high_threshold = 6  # degrees k > 6 count as "high"
    n_rows = len(df)
    avg_degree = np.zeros(n_rows)
    W_high = np.zeros(n_rows)

    for col, start, end in present:
        values = df[col].values

        # Average spectral degree: sum_k k * W_k, using the bin midpoint as k
        avg_degree += values * ((start + end) / 2.0)

        if start > high_threshold:
            W_high += values
        elif end > high_threshold:
            # Bin straddles threshold - apportion by the share of its degrees
            # that lie above the threshold.
            frac_high = (end - high_threshold) / (end - start + 1)
            W_high += values * frac_high

    df['avg_degree'] = avg_degree
    df['W_high'] = W_high

    # Low-degree weight (k <= 6); assumes the weights sum to 1 (Parseval)
    df['W_low'] = 1.0 - df['W_high']

    # Sum of W_k (should be 1.0 if Parseval satisfied)
    df['sum_Wk'] = df[[col for col, _, _ in present]].sum(axis=1)

    return df



