from scipy.optimize import root_scalar
import numpy as np

def solve_alpha(lambdas, n=400, p=150, n1=200):
    """Solve the scalar fixed-point equation for ``alpha`` by bisection.

    With ``c = 1 - p / n``, finds the root in ``alpha`` of::

        sum_i 1 / (lambdas[i] * alpha + c - alpha)
            - (p + n * alpha - n1) / (c - alpha)

    on the bracket ``[max(0, (n1 - p) / n + 1e-6), c - 1e-6]``.

    Side effect: after the solver returns, prints ``n1 * (n - p) / n**2``
    labelled "baseline". That value is the exact root when ``lambdas``
    consists of ``p`` ones.

    Args:
        lambdas: Array-like of the values entering the sum (any sequence
            accepted by ``np.asarray``).
        n: Scalar used in ``p / n`` and ``n * alpha``.
        p: Scalar used in ``p / n`` and in the right-hand-side numerator.
        n1: Scalar subtracted in the right-hand-side numerator.

    Returns:
        The root reported by ``scipy.optimize.root_scalar``.

    Raises:
        ValueError: If the bracket is empty (lower end not strictly below
            the upper end, e.g. when ``p >= n``), or, from SciPy, if the
            function has the same sign at both bracket ends.
        RuntimeError: If the solver reports that it did not converge.
    """
    # Accept any array-like: a plain list would otherwise fail on
    # ``lambdas * alpha`` because the solver passes Python floats.
    lambdas = np.asarray(lambdas)

    def fixed_point_eq(alpha):
        denom = lambdas * alpha + 1 - (p / n) - alpha
        lhs = np.sum(1.0 / denom)
        rhs = (p + n * alpha - n1) / (1 - (p / n) - alpha)
        return lhs - rhs

    # Stay 1e-6 inside the open interval so the right-hand side is never
    # evaluated at its pole alpha = 1 - p / n.
    lower = max(0.0, (n1 - p) / n + 1e-6)
    upper = 1 - p / n - 1e-6

    # ``not lower < upper`` also rejects NaN endpoints.
    if not lower < upper:
        raise ValueError(
            f"Empty search interval for alpha: lower={lower}, upper={upper} "
            f"(n={n}, p={p}, n1={n1})"
        )

    sol = root_scalar(fixed_point_eq, bracket=[lower, upper], method='bisect')
    print("baseline: ", n1 * (n - p) / (n * n))
    if sol.converged:
        return sol.root
    raise RuntimeError("Failed to solve for alpha")
        
def empirical_covariance(x):
    """Return the covariance of ``x`` computed along axis 0.

    The input is converted with ``np.asarray``, its mean over axis 0 is
    subtracted, and the product of the centered data's transpose with the
    centered data is divided by ``x.shape[0] - 1``.

    Args:
        x: Array-like; axis 0 is the axis that is averaged over.

    Returns:
        The centered cross-product matrix divided by ``x.shape[0] - 1``.
    """
    samples = np.asarray(x)
    deviations = samples - np.mean(samples, axis=0)
    dof = samples.shape[0] - 1
    scatter = deviations.T @ deviations
    return scatter / dof
    
def calculate_alpha_estimate(real_feats, gen_feats):
    """Estimate ``alpha`` from two feature sets via ``solve_alpha``.

    Steps, as implemented below:
      1. Compute the covariance of ``real_feats`` and form its inverse
         square root from an eigendecomposition, with eigenvalues floored
         at ``1e-6``.
      2. Compute the covariance of ``gen_feats`` and multiply it on both
         sides by that inverse square root.
      3. Pass the eigenvalues of the result (sorted in descending order)
         to ``solve_alpha`` with ``n1 = gen_feats.shape[0]``,
         ``n = n1 + real_feats.shape[0]`` and ``p = gen_feats.shape[1]``.

    Args:
        real_feats: 2-D array-like; axis 0 is treated as the sample axis
            (it is averaged over and its length is used as a count).
        gen_feats: 2-D array-like laid out the same way. Presumably it has
            the same number of columns as ``real_feats``; the matrix
            products below require it.

    Returns:
        The value returned by ``solve_alpha``.
    """
    # Convert once up front so array-likes work for the ``.shape`` reads
    # below, consistent with what ``empirical_covariance`` already accepts.
    real_feats = np.asarray(real_feats)
    gen_feats = np.asarray(gen_feats)

    cov_orig = empirical_covariance(real_feats)
    eigvals1, eigvecs1 = np.linalg.eigh(cov_orig)
    # Floor the eigenvalues so the inverse square root stays finite.
    eps = 1e-6
    eigvals1 = np.maximum(eigvals1, eps)
    # Scaling column j by 1/sqrt(eigval_j) equals multiplying by the
    # diagonal matrix, without materialising it or paying for a matmul.
    S1_inv_sqrt = (eigvecs1 * (1.0 / np.sqrt(eigvals1))) @ eigvecs1.T

    cov_aug = empirical_covariance(gen_feats)
    M = S1_inv_sqrt @ cov_aug @ S1_inv_sqrt
    eigvals_transfer = np.linalg.eigvalsh(M)
    lambdas = np.sort(eigvals_transfer)[::-1]

    n1 = gen_feats.shape[0]
    n = n1 + real_feats.shape[0]
    p = gen_feats.shape[1]
    return solve_alpha(lambdas, n=n, p=p, n1=n1)