from scipy.optimize import root_scalar
import numpy as np

def solve_alpha(lambdas, n=400, p=150, n1=200, margin=1e-6):
    """Solve the fixed-point equation for ``alpha`` by bisection.

    Finds the root, on the bracket ``[0, 1 - p/n - margin]``, of

        sum_i 1 / (lambda_i * alpha + 1 - p/n - alpha)
            - (p + n*alpha - n1) / (1 - p/n - alpha)

    Args:
        lambdas: Array-like of eigenvalues; converted to a float ndarray so
            plain lists are accepted (``calculate_alpha_estimate`` passes the
            eigenvalues of the whitened covariance matrix).
        n: Sample-count parameter (``calculate_alpha_estimate`` passes the
            combined number of real and generated samples). Must exceed ``p``.
        p: Dimension parameter. Must be smaller than ``n``.
        n1: Second sample-count parameter (``calculate_alpha_estimate``
            passes the number of generated samples).
        margin: Distance kept between the upper bracket end and the pole of
            the right-hand side at ``alpha = 1 - p/n``.

    Returns:
        The root ``alpha`` as a float.

    Raises:
        ValueError: If ``p >= n`` or ``margin`` leaves an empty bracket, or
            if the equation does not change sign over the bracket (raised by
            ``scipy.optimize.root_scalar``).
        RuntimeError: If bisection reports non-convergence.
    """
    # Guard before computing p / n: n == 0 would divide by zero and p >= n
    # makes the bracket empty, both of which surface as obscure errors.
    if n <= p:
        raise ValueError(f"solve_alpha requires p < n, got p={p}, n={n}")
    gap = 1 - p / n
    upper = gap - margin
    if upper <= 0:
        raise ValueError(
            f"margin={margin} leaves an empty bracket for p={p}, n={n}"
        )

    lambdas = np.asarray(lambdas, dtype=float)

    def fixed_point_eq(alpha):
        denom = lambdas * alpha + gap - alpha
        lhs = np.sum(1.0 / denom)
        rhs = (p + n * alpha - n1) / (gap - alpha)
        return lhs - rhs

    sol = root_scalar(fixed_point_eq, bracket=[0, upper], method='bisect')
    if sol.converged:
        return sol.root
    raise RuntimeError("Failed to solve for alpha")

def empirical_covariance(x):
    """Return the unbiased sample covariance of ``x``.

    Rows of ``x`` are treated as observations and columns as variables, so a
    2-D ``(n_samples, n_features)`` input yields an
    ``(n_features, n_features)`` matrix. The sum of outer products of the
    centred rows is divided by ``n_samples - 1`` (Bessel's correction).
    """
    data = np.asarray(x)
    n_obs = data.shape[0]
    # Centre each column on its mean across observations.
    deviations = data - data.mean(axis=0)
    return (deviations.T @ deviations) / (n_obs - 1)

def calculate_alpha_estimate(real_feats, gen_feats, n=0, do_pca=False, *,
                             n_components=32):
    """Estimate ``alpha`` from generated-vs-real covariance spectra.

    The covariance of ``gen_feats`` is whitened by the inverse square root of
    the covariance of ``real_feats``. The eigenvalues of that whitened matrix
    are passed to ``solve_alpha``.

    Args:
        real_feats: Array-like, rows are samples, columns are features.
        gen_feats: Array-like, rows are samples, same number of columns as
            ``real_feats``.
        n: Sample-count argument forwarded to ``solve_alpha``. ``0``
            (default) means ``len(real_feats) + len(gen_feats)``.
        do_pca: If true, fit a PCA on ``real_feats`` and project both sets
            with it before computing covariances (requires scikit-learn).
        n_components: Number of PCA components kept when ``do_pca`` is true.
            The default of 32 preserves the previous behaviour.

    Returns:
        Tuple ``(alpha, estimate)``. ``estimate`` is
        ``(p + n*alpha - n1) / (1 - p/n - alpha)``, with ``p`` the feature
        count and ``n1`` the number of generated samples (after optional
        PCA).

    Raises:
        Any exception raised by ``solve_alpha`` (e.g. when ``p >= n``).
    """
    # Convert up front so `.shape` works for array-like (e.g. list) inputs.
    real_feats = np.asarray(real_feats)
    gen_feats = np.asarray(gen_feats)

    if n == 0:
        n = real_feats.shape[0] + gen_feats.shape[0]

    if do_pca:
        # Local import keeps scikit-learn optional for callers not using PCA.
        from sklearn.decomposition import PCA
        pca = PCA(n_components=n_components)
        real_feats = pca.fit_transform(real_feats)
        gen_feats = pca.transform(gen_feats)

    # Whitening matrix S1^{-1/2} built from the real-feature covariance.
    # Eigenvalues are floored at eps so a (near-)singular covariance does not
    # produce infinite or NaN entries.
    cov_real = empirical_covariance(real_feats)
    eigvals, eigvecs = np.linalg.eigh(cov_real)
    eps = 1e-6
    eigvals = np.maximum(eigvals, eps)
    # eigvecs @ diag(1/sqrt(eigvals)) is a column scaling; broadcasting avoids
    # materialising the dense diagonal matrix.
    inv_sqrt = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T

    cov_gen = empirical_covariance(gen_feats)
    transfer = inv_sqrt @ cov_gen @ inv_sqrt
    lambdas = np.sort(np.linalg.eigvalsh(transfer))[::-1]

    n1 = gen_feats.shape[0]
    p = gen_feats.shape[1]
    alpha = solve_alpha(lambdas, n=n, p=p, n1=n1)
    estimate = (p + n * alpha - n1) / (1 - (p / n) - alpha)
    return alpha, estimate

def calculate_frob_distance_covs(real_feats, gen_feats):
    """Return the Frobenius norm of the difference of two sample covariances.

    Each input is summarised with ``empirical_covariance``, and the result is
    ``||cov(real_feats) - cov(gen_feats)||_F``.
    """
    delta = empirical_covariance(real_feats) - empirical_covariance(gen_feats)
    return np.linalg.norm(delta, 'fro')
