import numpy as np
import matplotlib.pyplot as plt
import os, sys
from Optimization.RiemannianBCD import RiemannianBlockCoordinateDescent



# -------------------------
# 1) Model construction (shared with the trade-off experiment script)
# -------------------------

def build_tau_signal_exp_decay(r: int, E_sig: float, eta_sig: float) -> np.ndarray:
    """Return ``r`` signal eigenvalues of M that decay geometrically and sum to ``E_sig``.

    The i-th value (0-based) is proportional to ``eta_sig ** i``; ``eta_sig = 1``
    yields a flat signal spectrum.
    """
    exponents = np.arange(r, dtype=float)
    unnormalized = eta_sig ** exponents
    scale = E_sig / float(unnormalized.sum())
    return scale * unnormalized

def build_tau_residual_powerlaw(k_res: int, E_res: float, alpha: float) -> np.ndarray:
    """Return ``k_res`` residual eigenvalues following the power law ``j ** -alpha``.

    Indices run over ``j = 1..k_res``. The weights are normalized to sum to one and
    then scaled so the returned values sum to ``E_res``.
    """
    ranks = np.arange(1, k_res + 1, dtype=float)
    weights = ranks ** (-alpha)
    weights = weights / float(weights.sum())
    return E_res * weights

def tau_to_cov_eigs(tau: np.ndarray) -> np.ndarray:
    """Map eigenvalues of M = (Sigma^{1/2} - I)^2 to eigenvalues of Sigma.

    Taking the branch where Sigma^{1/2} has eigenvalue 1 + sqrt(tau_i), each
    eigenvalue of Sigma is (1 + sqrt(tau_i))^2.
    """
    root = np.sqrt(tau)
    shifted = 1.0 + root
    return shifted ** 2

def fixed_orthogonal_basis(d: int, seed: int) -> np.ndarray:
    """Return a reproducible d x d orthogonal matrix with determinant +1.

    The matrix is the Q factor of a seeded standard-normal draw; if that factor
    is a reflection (det < 0), its first column is negated.
    """
    gaussian = np.random.default_rng(seed).standard_normal((d, d))
    basis = np.linalg.qr(gaussian)[0]
    if np.linalg.det(basis) < 0:
        basis[:, 0] = -basis[:, 0]
    return basis

def make_sigma_from_tau(tau: np.ndarray, Q_basis: np.ndarray) -> np.ndarray:
    """Build Sigma = Q diag(lambda) Q^T, where lambda = tau_to_cov_eigs(tau).

    The product is symmetrized at the end to remove round-off asymmetry.
    """
    eigvals = tau_to_cov_eigs(tau)
    raw = Q_basis @ np.diag(eigvals) @ Q_basis.T
    return 0.5 * (raw + raw.T)

def sample_gaussian_pair(d: int, Sigma: np.ndarray, n: int, rng: np.random.Generator):
    """Draw n samples X ~ N(0, I_d) and n samples Y ~ N(0, Sigma).

    X is drawn first and Y second from the same generator. A 1e-12 ridge is added
    to Sigma before the Cholesky factorization for numerical safety.
    """
    source = rng.standard_normal((n, d))
    chol = np.linalg.cholesky(Sigma + 1e-12 * np.eye(d))
    white = rng.standard_normal((n, d))
    target = white @ chol.T
    return source, target


# -------------------------
# 2) Residual sliced estimator (same logic as the trade-off experiment script)
# -------------------------

def orthogonal_complement(U: np.ndarray, seed: int) -> np.ndarray:
    """Return a d x (d - k) orthonormal basis of the orthogonal complement of span(U).

    Assumes U (d x k) has full column rank (e.g. orthonormal columns). Seeded
    random columns are appended and the stack is QR-factorized; the first k
    columns of Q span U, so the trailing d - k columns span its complement.
    """
    n_rows, n_cols = U.shape
    filler = np.random.default_rng(seed).standard_normal((n_rows, n_rows - n_cols))
    stacked = np.concatenate((U, filler), axis=1)
    q_factor = np.linalg.qr(stacked)[0]
    return q_factor[:, n_cols:]

def w2_1d_sorted(x: np.ndarray, y: np.ndarray) -> float:
    """Squared 1-D Wasserstein-2 distance between two equal-size empirical samples.

    In one dimension the optimal coupling matches order statistics, so this is
    the mean squared difference between the sorted samples.
    """
    gap = np.sort(x) - np.sort(y)
    return float(np.mean(gap ** 2))

def sliced_w2_residual(X_perp: np.ndarray, Y_perp: np.ndarray, L: int, seed: int) -> float:
    """Monte-Carlo sliced squared W2 between the rows of X_perp and Y_perp.

    Averages the 1-D squared W2 (``w2_1d_sorted``) over ``L`` random unit
    directions in R^m (m = number of columns) drawn from a generator seeded with
    ``seed``. Returns 0.0 when m == 0.
    """
    _, m = X_perp.shape
    if m == 0:
        return 0.0
    rng = np.random.default_rng(seed)
    total = 0.0
    for _slice in range(L):
        direction = rng.standard_normal(m)
        # Small epsilon guards against a (measure-zero) all-zero draw.
        direction = direction / (np.linalg.norm(direction) + 1e-12)
        total += w2_1d_sorted(X_perp @ direction, Y_perp @ direction)
    return total / float(L)


# -------------------------
# 3) Gaussian plug-in Wasserstein (Bures) estimator (closed form; not called by main(), whose W baseline uses RBCD with k=d)
# -------------------------

def sqrtm_psd(A: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Square root of a symmetric PSD matrix via eigendecomposition.

    The input is symmetrized first, and eigenvalues are floored at ``eps`` so
    small negative round-off values do not produce NaNs.
    """
    sym = 0.5 * (A + A.T)
    evals, evecs = np.linalg.eigh(sym)
    root_evals = np.sqrt(np.clip(evals, eps, None))
    return (evecs * root_evals) @ evecs.T

def w2_gaussian_plugin_from_samples(Y: np.ndarray) -> float:
    """Plug-in estimate of W2^2 between N(0, I) and N(0, Sigma) from samples of Y.

    Sigma is replaced by the (mean-centred, 1/n-normalized) sample covariance of
    ``Y``. The closed form tr((Sigma^{1/2} - I)^2) is then evaluated.
    """
    n, d = Y.shape
    centered = Y - Y.mean(axis=0, keepdims=True)
    cov_hat = centered.T @ centered / float(n)
    diff = sqrtm_psd(cov_hat) - np.eye(d)
    return float(np.trace(diff @ diff))


# -------------------------
# 4) Single-run (fixed k_star) evaluation
# -------------------------
def run_single_trial_fixed_k(
    X: np.ndarray,
    Y: np.ndarray,
    rbcd: RiemannianBlockCoordinateDescent,
    k_star: int,
    sw_L: int,
    seed: int,
):
    """Estimate WPP and its residual-corrected lower bound at subspace dim ``k_star``.

    Runs RBCD with uniform weights from ``rbcd.InitialStiefel(d, k_star)``.
    When ``k_star < d``, the sliced squared W2 of both samples projected onto the
    orthogonal complement of the estimated subspace is computed. That value,
    scaled by ``d - k_star``, is added to the WPP value to form the LB estimate.

    Returns:
        Tuple ``(wpp_hat, lb_hat, U_hat)``; ``lb_hat == wpp_hat`` when k_star == d.
    """
    n, d = X.shape
    weights_x = np.full(n, 1.0 / n)
    weights_y = np.full(n, 1.0 / n)

    U_init = rbcd.InitialStiefel(d, k_star)
    _, U_hat, _, f_val, _ = rbcd.run_RBCD(weights_x, weights_y, X, Y, k_star, U_init)
    wpp_hat = float(f_val)

    # Full-dimensional case: there is no residual subspace to correct for.
    if k_star >= d:
        return wpp_hat, wpp_hat, U_hat

    complement = orthogonal_complement(U_hat, seed=seed + 999)
    residual_sw2 = sliced_w2_residual(
        X @ complement, Y @ complement, L=sw_L, seed=seed + 2024
    )
    return wpp_hat, wpp_hat + (d - k_star) * residual_sw2, U_hat


def run_single_trial_full_dim_w2(
    X: np.ndarray,
    Y: np.ndarray,
    rbcd: RiemannianBlockCoordinateDescent,
):
    """Full-dimensional baseline computed with the same RBCD solver.

    Uses k = d and starts from U = I_d, so PRW/WPP reduces to the standard
    (regularized) empirical W2 in R^d. Uniform weights are used for both samples.
    Returns the solver's objective value as a float.
    """
    n_points, dim = X.shape
    weights_x = np.full(n_points, 1.0 / n_points)
    weights_y = np.full(n_points, 1.0 / n_points)

    identity_frame = np.eye(dim)  # a point on Stiefel(d, d)
    result = rbcd.run_RBCD(weights_x, weights_y, X, Y, dim, identity_frame)
    return float(result[3])


def projection_fro_error(U_hat: np.ndarray, U_true: np.ndarray) -> float:
    """Return ||U_true U_true^T - U_hat U_hat^T||_F.

    For bases with orthonormal columns this is the Frobenius distance between
    the orthogonal projectors onto the two subspaces.
    """
    diff = U_true @ U_true.T - U_hat @ U_hat.T
    return float(np.linalg.norm(diff, ord="fro"))


# -------------------------
# 5) Aggregation + plotting (quantile bands)
# -------------------------

def summarize_bands(values: np.ndarray):
    """Column-wise median and quantile bands of a (n_trials, n_points) array.

    Returns a dict with keys "median", "q10", "q90", "q25", "q75" (in that
    order). Each value is an array with one entry per point.
    """
    levels = (("q10", 0.10), ("q90", 0.90), ("q25", 0.25), ("q75", 0.75))
    summary = {"median": np.median(values, axis=0)}
    for key, q in levels:
        summary[key] = np.quantile(values, q, axis=0)
    return summary

def plot_with_bands(ax, x, bands, label: str, linestyle: str = "-", marker: str = "o"):
    """Plot a median curve with 10-90% and 25-75% quantile bands on ``ax``.

    Args:
        ax: Matplotlib-style axes (needs ``plot`` and ``fill_between``).
        x: Shared x-coordinates for the curve and bands.
        bands: Mapping with keys "median", "q10", "q90", "q25", "q75"
            (e.g. the output of ``summarize_bands``).
        label: Legend label for the median line.
        linestyle: Line style forwarded to ``ax.plot``.
        marker: Marker forwarded to ``ax.plot``.

    Returns:
        The line object drawn for the median curve.
    """
    (line,) = ax.plot(x, bands["median"], linestyle=linestyle, marker=marker, label=label)
    # fill_between draws its colours from a cycle independent of plot()'s, so
    # without an explicit colour the bands would not match their median line.
    color = line.get_color()
    ax.fill_between(x, bands["q10"], bands["q90"], color=color, alpha=0.15)
    ax.fill_between(x, bands["q25"], bands["q75"], color=color, alpha=0.25)
    return line

def make_twofig_pdf(
    n_list,
    mee_bands_dict,
    subspace_bands_dict,
    out_pdf: str,
):
    """Render the two-panel finite-sample figure and save it to ``out_pdf``.

    Left panel: estimation error of each estimator vs n. Right panel: subspace
    estimation error vs n. Both panels use log-log axes. Each entry of the band
    dicts (label -> bands) is drawn via ``plot_with_bands``.

    The figure is closed after saving, even if saving fails, so repeated calls
    do not accumulate open pyplot figures.
    """
    # (bands dict, panel title, y-axis label) for the left and right panels.
    panels = (
        (mee_bands_dict, "Mean estimation error", "Absolute error"),
        (
            subspace_bands_dict,
            "Mean subspace estimation error",
            r"$\| \Omega^* - \widehat{\Omega} \|_F$",
        ),
    )

    fig, axes = plt.subplots(1, 2, figsize=(11, 4.2))
    try:
        for ax, (bands_dict, title, ylabel) in zip(axes, panels):
            for name, bands in bands_dict.items():
                plot_with_bands(ax, n_list, bands, label=name)
            ax.set_title(title)
            ax.set_xlabel("Number of points n")
            ax.set_ylabel(ylabel)
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)
            ax.legend()

        # Operate on this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        fig.savefig(out_pdf, bbox_inches="tight", pad_inches=0.02)
    finally:
        plt.close(fig)
    print(f"Saved: {out_pdf}")


# -------------------------
# 6) Main experiment
# -------------------------

def main():
    """Run the finite-sample experiment and save ``twofigs_finite_sample.pdf``.

    For each sample size n in ``n_list`` and each of ``n_trials`` seeded trials,
    X ~ N(0, I_d) and Y ~ N(0, Sigma) are drawn. The following are recorded:
      * |W_hat - W2_true| for the full-dimensional RBCD solve (k = d, U = I),
      * |WPP_hat - W2_true| and |LB_hat - W2_true| at the fixed k_star,
      * the projector distance between the estimated subspace and the true
        rank-r signal subspace.
    Median and 10/25/75/90% quantile bands across trials are then plotted.

    Raises:
        ValueError: if ``r`` or ``k_star`` lies outside [1, d].
    """
    # ====== (A) Core settings you may tune ======
    d = 60
    r = 10                 # signal rank (ground-truth signal subspace dimension)
    k_star = 10            # fixed k* used by WPP/LB in this two-fig experiment

    # spectrum design
    E_total = 120.0
    rho = 0.15             # residual energy fraction
    eta_sig = 1.0          # signal decay (1.0 = flat signal; <1 geometric decay)
    alpha = 1.0            # residual anisotropy (power-law)

    # finite-sample grid
    n_list = np.array([25, 50, 100, 250, 500, 1000], dtype=int)
    n_trials = 20          # set to 10 if you want to match the reference caption
    sw_L = 1000            # number of residual slices for LB

    seed_basis = 12345

    # Explicit checks instead of `assert` (asserts are stripped under `python -O`).
    if not 1 <= r <= d:
        raise ValueError(f"signal rank r must satisfy 1 <= r <= d (got r={r}, d={d})")
    if not 1 <= k_star <= d:
        raise ValueError(
            f"k_star must satisfy 1 <= k_star <= d (got k_star={k_star}, d={d})"
        )

    # ====== (B) Build one Sigma with spiked + eigengap ======
    E_sig = (1.0 - rho) * E_total
    E_res = rho * E_total

    tau_sig = build_tau_signal_exp_decay(r, E_sig, eta_sig)
    tau_res = build_tau_residual_powerlaw(d - r, E_res, alpha)
    tau = np.concatenate([tau_sig, tau_res])

    # True W2^2 is tr(M) = sum tau by construction
    W2_true = float(np.sum(tau))

    Q_basis = fixed_orthogonal_basis(d, seed=seed_basis)
    Sigma = make_sigma_from_tau(tau, Q_basis)

    # Ground-truth signal subspace: tau = [tau_sig, tau_res], so the first r
    # columns of the construction basis carry the signal. Slice by r (not
    # k_star) so the reference stays correct when k_star != r.
    U_true = Q_basis[:, :r]

    # ====== (C) RBCD optimizer ======
    rbcd = RiemannianBlockCoordinateDescent(
        eta=2.0, tau=0.1, max_iter=200, threshold=1e-6,
        verbose=False, use_gpu=False
    )

    # ====== (D) Run trials across n ======
    err_W = np.zeros((len(n_list), n_trials))
    err_WPP = np.zeros((len(n_list), n_trials))
    err_LB = np.zeros((len(n_list), n_trials))
    err_sub = np.zeros((len(n_list), n_trials))

    for i, n in enumerate(n_list):
        for t in range(n_trials):
            rng = np.random.default_rng(100000 + 1000 * i + t)
            X, Y = sample_gaussian_pair(d, Sigma, n, rng)

            # baseline: full-dimensional empirical W2 via the same RBCD solver (k=d, U=I)
            W_hat = run_single_trial_full_dim_w2(X=X, Y=Y, rbcd=rbcd)
            err_W[i, t] = abs(W_hat - W2_true)

            # WPP & LB at fixed k_star + estimated subspace
            wpp_hat, lb_hat, U_hat = run_single_trial_fixed_k(
                X=X, Y=Y, rbcd=rbcd, k_star=k_star, sw_L=sw_L,
                seed=200000 + 1000 * i + t,
            )
            err_WPP[i, t] = abs(wpp_hat - W2_true)
            err_LB[i, t] = abs(lb_hat - W2_true)
            err_sub[i, t] = projection_fro_error(U_hat, U_true)

        print(f"Done n={n}")

    # summarize bands (need shape: (n_trials, n_points))
    # NOTE: "W (plug-in)" is the empirical-measure estimate from
    # run_single_trial_full_dim_w2 (RBCD, k=d), not w2_gaussian_plugin_from_samples.
    mee_bands = {
        "W (plug-in)": summarize_bands(err_W.T),
        "WPP (RBCD)": summarize_bands(err_WPP.T),
        "LB (RBCD)": summarize_bands(err_LB.T),
    }
    subspace_bands = {
        "Subspace (RBCD)": summarize_bands(err_sub.T),
    }

    make_twofig_pdf(
        n_list=n_list,
        mee_bands_dict=mee_bands,
        subspace_bands_dict=subspace_bands,
        out_pdf="twofigs_finite_sample.pdf",
    )

if __name__ == "__main__":
    # Script entry point: run the experiment, which writes the two-panel PDF.
    main()
