import numpy as np
from scipy.special import erfc

def filter_gaussian_mean(data, eps, tau, cher=2.5):
    """Estimate the mean of ``data`` after iteratively filtering outliers.

    Each round computes the top principal direction of the centred data and
    the corresponding eigenvalue ``lambda_`` of the empirical covariance.

    * If ``lambda_ < 1 + 3 * eps * log(1 / eps)``, the empirical mean of the
      current sample is returned.
    * Otherwise, points are ranked by their distance from the median along
      that direction, and the farthest ones are discarded. The cut is the
      first position where the number of remaining far points exceeds a
      standard-normal tail bound (``erfc``) plus a slack term.

    Both tests compare against unit variance, so the inliers are presumably
    expected to be roughly N(mu, I). TODO confirm with callers.

    Args:
        data: (N, d) array-like of samples, one row per sample.
        eps: Assumed corruption fraction. It must lie in (0, 1) for
            ``eps * log(1 / eps)`` to be a positive threshold.
        tau: Parameter of the tail slack term
            ``eps / (d * log(d * eps / tau))``. That term is only positive
            and finite when ``d * eps / tau > 1``.
        cher: Multiplicative slack applied to the tail bound.

    Returns:
        (d,) array: the empirical mean of the samples that survive filtering.
    """
    data = np.asarray(data, dtype=float)
    d = data.shape[1]
    threshold = eps * np.log(1 / eps)
    delta = 2 * eps

    # Iterative rather than recursive. Each round keeps strictly fewer rows,
    # but the number of rounds can exceed Python's recursion limit for large N.
    while True:
        N = data.shape[0]
        empirical_mean = np.mean(data, axis=0)
        centered = data - empirical_mean

        # S[0]**2 of the 1/sqrt(N)-scaled centred data is the largest eigenvalue
        # of the empirical covariance; U[:, 0] is the direction achieving it.
        U, S, _ = np.linalg.svd(centered.T / np.sqrt(N), full_matrices=False)
        lambda_ = S[0] ** 2
        if lambda_ < 1 + 3 * threshold:
            # No direction has noticeably inflated variance: nothing to filter.
            return empirical_mean
        v = U[:, 0]

        # Project the *unscaled* centred data. The tail test below compares
        # these distances with a standard normal, so they must be on the
        # original scale. Projecting the 1/sqrt(N)-scaled data made every
        # distance tiny, so no cut was ever found and the filter never fired.
        proj = centered @ v
        dist = np.abs(proj - np.median(proj))
        order = np.argsort(dist)
        dist_sorted = dist[order]

        # For each sorted position ii, compare the count of points at or beyond
        # it (N - ii) with cher * N * (Gaussian tail at dist - delta + slack).
        # The first position where the count is too large becomes the cut.
        slack = eps / (d * np.log(d * eps / tau))
        tail = erfc((dist_sorted - delta) / np.sqrt(2)) / 2
        too_many = (N - np.arange(N)) > cher * N * (tail + slack)
        hits = np.flatnonzero(too_many)
        cut = hits[0] if hits.size else N - 1

        # A cut at the first or the last position is treated as "no usable cut".
        if cut == 0 or cut == N - 1:
            return empirical_mean

        # Keep the `cut` points closest to the median and run another round.
        data = data[order[:cut]]

def find_smallest_alpha(lower_bound, beta_hat):
    """Scan a 1000-point grid for the first alpha meeting the beta condition.

    The grid is ``np.linspace(lower_bound, 1, 1000)``. The condition is
    ``sqrt(log(1 / alpha)) <= beta_hat / 2``. For ``beta_hat >= 0`` and
    ``0 < alpha <= 1``, this is equivalent to ``alpha >= exp(-beta_hat**2 / 4)``.
    When ``lower_bound < 1`` the grid increases, so the first hit is the
    smallest qualifying grid value.

    Returns:
        The first qualifying grid value, or ``None`` if no grid point qualifies.
    """
    grid = np.linspace(lower_bound, 1, 1000)
    qualifying = (
        candidate
        for candidate in grid
        if np.sqrt(np.log(1 / candidate)) <= beta_hat / 2
    )
    # next() stops at the first hit, exactly like the original early break.
    return next(qualifying, None)

def robust_mean_estimation(data, mean, threshold, alpha_min, eps, tau):
    """Replace ``mean`` by the Gaussian-filtered mean when the two are close.

    The starting alpha is ``find_smallest_alpha(1/2 + 2*alpha_min, threshold)``.
    If none exists, ``mean`` is returned unchanged. Otherwise the estimate from
    ``filter_gaussian_mean(data, eps, tau)`` is accepted while it lies within
    ``1.5 * beta_hat`` of the current estimate. ``beta_hat`` starts at
    ``threshold`` and is updated from the alpha schedule after each acceptance.

    NOTE(review): the filter's arguments do not depend on ``alpha_hat``, so
    the schedule cannot change the estimate after the first acceptance. An
    alpha-dependent call (e.g. a different ``eps``) may have been intended.
    Confirm against the algorithm this implements.

    Returns:
        ``mean`` if no starting alpha exists or the filtered mean is farther
        than ``1.5 * threshold`` from it; otherwise the filtered mean.
    """
    eps_rme = 1 / 2 - 2 * alpha_min
    beta_hat = threshold

    alpha_hat = find_smallest_alpha(1 - eps_rme, beta_hat)

    mean_hat = mean
    if alpha_hat is None:
        return mean_hat

    # The filter receives identical arguments on every pass and is
    # deterministic, so compute it once rather than once per iteration.
    mean_rme = filter_gaussian_mean(data, eps, tau)
    while np.linalg.norm(mean_hat - mean_rme) <= 3 / 2 * beta_hat:
        mean_hat = mean_rme
        beta_hat = np.sqrt(np.log(1 / alpha_hat))
        alpha_hat_prime = find_smallest_alpha(alpha_hat + alpha_min**2, beta_hat)
        # Once alpha_hat reaches 1.0, beta_hat is 0. find_smallest_alpha then
        # keeps returning its grid endpoint 1.0, and the distance above stays
        # 0 <= 0. Stop when the schedule makes no progress; otherwise the loop
        # never terminates.
        if alpha_hat_prime is None or alpha_hat_prime <= alpha_hat:
            break
        alpha_hat = alpha_hat_prime

    return mean_hat