
import numpy as np
from sklearn.mixture import GaussianMixture

def find_optimal_k(K_values, proxy_values):
    """
    Selects the optimal K using a curvature-based elbow method.

    The (K, proxy) points are sorted by K, and the discrete second derivative
    of the proxy curve is estimated at every interior point with the standard
    three-point formula, which is valid for non-uniform K spacing:

        P''(K_i) ~= (s_i - s_{i-1}) / ((K_{i+1} - K_{i-1}) / 2)

    where ``s_i`` is the slope between ``K_i`` and ``K_{i+1}``. For uniformly
    spaced K, the denominator reduces to the grid step. The K with the
    largest estimated second derivative is returned.

    Args:
        K_values (list or np.ndarray): 1-D sequence of candidate K values.
            Values must be unique; order does not matter.
        proxy_values (list or np.ndarray): Proxy scores, one per entry of
            ``K_values``.

    Returns:
        int: Optimal value of K.

    Raises:
        ValueError: If the inputs are not 1-D sequences of equal length, if
            fewer than three candidates are given (curvature needs an interior
            point), or if ``K_values`` contains duplicates.
    """
    K = np.asarray(K_values)
    P = np.asarray(proxy_values, dtype=float)

    if K.ndim != 1 or K.shape != P.shape:
        raise ValueError(
            "K_values and proxy_values must be 1-D sequences of equal length"
        )
    if K.size < 3:
        raise ValueError("at least three K values are required to estimate curvature")

    # Sort by K so that finite differences run along increasing K.
    order = np.argsort(K, kind="stable")
    K = K[order]
    P = P[order]

    delta_K = np.diff(K)
    if np.any(delta_K == 0):
        # Repeated K values would make the slopes divide by zero.
        raise ValueError("K_values must be unique")

    # slope[i] is the slope between K[i] and K[i + 1].
    slope = np.diff(P) / delta_K

    # Half the distance between the two neighbours of each interior point.
    # This is the correct denominator of the second difference on a
    # non-uniform grid.
    half_span = (K[2:] - K[:-2]) / 2.0
    curvature = np.diff(slope) / half_span

    # curvature[j] belongs to interior point K[j + 1], hence the +1 offset.
    optimal_idx = int(np.argmax(curvature)) + 1
    return int(K[optimal_idx])


def estimate_posterior_dropout(score_matrix, n_components=2, random_state=0):
    """
    Estimates the posterior-based dropout ratio using a GMM on the flattened scores.

    Every entry of ``score_matrix`` is treated as one 1-D sample. A Gaussian
    mixture is fitted to these samples. The component with the smallest mean
    is taken as the "non-salient" one, and the returned ratio is the average
    posterior probability of that component over all samples.

    Args:
        score_matrix (array-like): Importance scores, nominally a 2D array of
            shape (N, K). Any array-like is accepted; it is flattened.
        n_components (int): Number of GMM components (default=2).
        random_state (int or None): Seed passed to ``GaussianMixture`` for
            reproducible fits (default=0, matching the previous fixed seed).

    Returns:
        float: Posterior-based dropout ratio, in [0, 1].
    """
    # np.asarray accepts nested lists as well as ndarrays; reshape to the
    # (n_samples, 1) layout that GaussianMixture expects.
    scores = np.asarray(score_matrix, dtype=float).reshape(-1, 1)

    # Fit GMM to the score distribution.
    gmm = GaussianMixture(
        n_components=n_components,
        covariance_type='full',
        random_state=random_state,
    )
    gmm.fit(scores)

    # Identify the non-salient component (the one with the smallest mean).
    non_salient = int(np.argmin(gmm.means_.ravel()))

    # The dropout ratio is the average posterior of the non-salient component.
    probs = gmm.predict_proba(scores)
    return float(np.mean(probs[:, non_salient]))

def calculate_tau(K, r):
    """
    Computes the temperature parameter τ from the number of clusters K and
    the dropout ratio r, using τ = 1 / log(1 + rK).

    ``log1p`` is used so that τ stays accurate when ``r * K`` is very small.
    When ``r * K == 0``, the result is ``inf`` (the limit of the formula), and
    no divide-by-zero warning is emitted.

    Args:
        K (int, float or array-like): Number of clusters.
        r (float or array-like): Dropout ratio (e.g., from 0 to 1).

    Returns:
        float or np.ndarray: Temperature τ. A Python float for scalar inputs;
        otherwise an array with the broadcast shape of ``K`` and ``r``.

    Raises:
        ValueError: If any ``r * K`` is negative, for which the temperature is
            undefined or negative.
    """
    x = np.asarray(r, dtype=float) * np.asarray(K, dtype=float)
    if np.any(x < 0):
        raise ValueError("r * K must be non-negative to compute a temperature")

    # log1p(0) == 0 gives an intentional inf; silence the divide warning.
    with np.errstate(divide="ignore"):
        tau = 1.0 / np.log1p(x)

    return float(tau) if np.ndim(tau) == 0 else tau

