import numpy as np
import matplotlib.pyplot as plt
import os, sys
from Optimization.RiemannianBCD import RiemannianBlockCoordinateDescent


# -------------------------
# 1) Tau & Model Design
# -------------------------

def build_tau_signal_exp_decay(r: int, E_sig: float, eta_sig: float) -> np.ndarray:
    """Return ``r`` signal spike values that decay geometrically and sum to ``E_sig``.

    Entry ``i`` (``i = 0 .. r-1``) is proportional to ``eta_sig ** i``; a common
    scale factor makes the entries add up to ``E_sig``.
    """
    exponents = np.arange(r, dtype=float)
    geometric = np.power(eta_sig, exponents)
    normaliser = float(geometric.sum())
    return geometric * (E_sig / normaliser)

def build_tau_residual_powerlaw(k_res: int, E_res: float, alpha: float) -> np.ndarray:
    """Return ``k_res`` residual spike values following a power law, summing to ``E_res``.

    Entry ``j`` (1-based) is proportional to ``j ** (-alpha)``: ``alpha = 0``
    gives a flat spectrum, larger ``alpha`` concentrates energy in the first entries.
    """
    ranks = np.arange(1, k_res + 1, dtype=float)
    weights = np.power(ranks, -alpha)
    weights = weights / float(weights.sum())
    return weights * E_res

def tau_to_cov_eigs(tau: np.ndarray) -> np.ndarray:
    """Map each spike ``tau`` to the covariance eigenvalue ``(1 + sqrt(tau)) ** 2``.

    The resulting standard deviation along each axis is ``1 + sqrt(tau)``, so the
    squared 1-D Wasserstein-2 distance of that axis from N(0, 1) is exactly ``tau``.
    """
    shifted_std = np.add(1.0, np.sqrt(tau))
    return shifted_std ** 2

def fixed_orthogonal_basis(d: int, seed: int) -> np.ndarray:
    """Return a reproducible ``d x d`` orthogonal matrix with determinant ``+1``.

    A Gaussian matrix drawn from ``seed`` is orthogonalised by QR. If the result
    is a reflection (negative determinant), its first column is negated so the
    returned matrix is a proper rotation.
    """
    gaussian = np.random.default_rng(seed).standard_normal((d, d))
    basis, _upper = np.linalg.qr(gaussian)
    is_reflection = np.linalg.det(basis) < 0
    if is_reflection:
        basis[:, 0] = -basis[:, 0]
    return basis

def make_sigmas_from_taus(d: int, taus: dict, Q_basis: np.ndarray) -> dict:
    """Build one covariance matrix per entry of ``taus``, all sharing the basis ``Q_basis``.

    For each key the spikes are converted to eigenvalues with ``tau_to_cov_eigs``
    and assembled as ``Q diag(lam) Q^T``; the product is then symmetrised to
    remove round-off asymmetry. ``d`` is part of the signature but is not used.
    Returns a dict with the same keys (and key order) as ``taus``.
    """
    def _covariance(tau):
        eigvals = tau_to_cov_eigs(tau)
        raw = Q_basis @ np.diag(eigvals) @ Q_basis.T
        return (raw + raw.T) / 2.0

    return {key: _covariance(tau) for key, tau in taus.items()}

def sample_gaussian_pair(d: int, Sigma: np.ndarray, n: int, rng: np.random.Generator):
    """Draw ``n`` samples from N(0, I_d) and ``n`` samples from N(0, Sigma).

    The standard-normal block for ``X`` is drawn from ``rng`` first and the one
    for ``Y`` second, so the pair is fully determined by the generator state.
    A ``1e-12`` ridge is added to ``Sigma`` before the Cholesky factorisation.

    Returns:
        ``(X, Y)``, each of shape ``(n, d)``.
    """
    shape = (n, d)
    X = rng.standard_normal(shape)
    chol = np.linalg.cholesky(Sigma + np.eye(d) * 1e-12)
    Y = np.matmul(rng.standard_normal(shape), chol.T)
    return X, Y

# -------------------------
# 2) Population Calculations
# -------------------------

def population_wpp_from_tau(tau: np.ndarray, k_values: np.ndarray) -> np.ndarray:
    """Population WPP value for every subspace dimension in ``k_values``.

    The value for dimension ``k`` is the sum of the ``k`` largest entries of
    ``tau``. Dimensions larger than ``len(tau)`` saturate at ``sum(tau)``, and
    ``k <= 0`` gives ``0.0`` (an empty subspace captures nothing).

    Args:
        tau: 1-D array of per-axis spike values.
        k_values: integer subspace dimensions.

    Returns:
        Float array with one entry per element of ``k_values``.
    """
    tau_sorted = np.sort(np.asarray(tau, dtype=float))[::-1]
    # Prefix sums with a leading 0 so that csum[k] is exactly the top-k total;
    # this makes k = 0 (and an empty tau) well-defined instead of wrapping to csum[-1].
    csum = np.concatenate(([0.0], np.cumsum(tau_sorted)))
    idx = np.clip(np.asarray(k_values, dtype=int), 0, len(tau_sorted))
    return csum[idx]

def sw2_pop_diag(s_diag: np.ndarray, n_mc: int, rng: np.random.Generator) -> float:
    """Monte-Carlo estimate of the sliced W2^2 between N(0, I_m) and N(0, diag(s_diag ** 2)).

    For a direction ``theta`` uniform on the unit sphere, the squared coordinates
    ``theta_i ** 2`` are Dirichlet(1/2, ..., 1/2), sampled here by normalising
    Gamma(1/2) draws. Along ``theta`` the projections are N(0, 1) and
    N(0, sum_i theta_i**2 * s_i**2), whose squared W2 distance is
    ``(sqrt(variance) - 1) ** 2``; the estimate averages this over ``n_mc`` directions.
    Returns ``0.0`` when ``s_diag`` is empty.
    """
    dim = len(s_diag)
    if dim == 0:
        return 0.0
    gammas = rng.gamma(shape=0.5, scale=1.0, size=(n_mc, dim))
    sq_dirs = gammas / gammas.sum(axis=1, keepdims=True)
    proj_var = (sq_dirs * s_diag ** 2).sum(axis=1)
    gaps = np.sqrt(proj_var) - 1.0
    return float((gaps ** 2).mean())

def population_lb_from_tau(tau: np.ndarray, d: int, k_values: np.ndarray, n_mc_sw: int, seed: int) -> np.ndarray:
    """Population lower-bound curve: WPP(k) plus ``(d - k)`` times the residual sliced W2^2.

    For each ``k < d`` the residual block consists of the ``d - k`` smallest
    spikes (standard deviations ``1 + sqrt(tau)``), and its sliced W2^2 against
    N(0, I) is estimated by ``sw2_pop_diag``. For ``k >= d`` there is no
    residual and the bound equals the WPP value.

    Every ``k < d`` gets its own generator, seeded with a draw from a master
    generator (seeded by ``seed``) plus ``k``. Master draws are consumed only
    for ``k < d``, in the order of ``k_values``.
    """
    descending = np.sort(tau)[::-1]
    wpp = population_wpp_from_tau(tau, k_values)
    bound = np.zeros_like(wpp)
    master = np.random.default_rng(seed)

    for pos, k in enumerate(k_values):
        if k < d:
            residual_std = 1.0 + np.sqrt(descending[k:])
            child_seed = master.integers(0, 2**32 - 1) + int(k)
            child = np.random.default_rng(child_seed)
            residual_sw2 = sw2_pop_diag(residual_std, n_mc=n_mc_sw, rng=child)
            bound[pos] = wpp[pos] + (d - k) * residual_sw2
        else:
            bound[pos] = wpp[pos]
    return bound

# -------------------------
# 3) Finite Sample Helpers
# -------------------------

def extend_U_warmstart(U_prev: np.ndarray, seed: int) -> np.ndarray:
    """Append one random direction to the frame ``U_prev`` and re-orthonormalise.

    A Gaussian column drawn from ``seed`` is placed after the columns of
    ``U_prev`` and the stack is orthonormalised by QR. The result is a
    ``d x (k + 1)`` orthonormal frame whose first ``k`` columns span the same
    subspace as ``U_prev`` (when ``U_prev`` has full column rank), up to the
    column signs chosen by the QR factorisation.
    """
    dim, width = U_prev.shape
    fresh = np.random.default_rng(seed).standard_normal((dim, 1))
    ortho, _ = np.linalg.qr(np.concatenate((U_prev, fresh), axis=1))
    return ortho[:, :width + 1]

def orthogonal_complement(U: np.ndarray, seed: int) -> np.ndarray:
    """Return a ``d x (d - k)`` orthonormal basis orthogonal to the columns of ``U``.

    ``U`` is padded with ``d - k`` Gaussian columns drawn from ``seed``. A QR
    factorisation of the padded matrix gives a full orthonormal basis whose
    trailing ``d - k`` columns are orthogonal to span(U). This assumes ``U`` has
    full column rank; in practice it is an orthonormal frame.
    """
    rows, cols = U.shape
    padding = np.random.default_rng(seed).standard_normal((rows, rows - cols))
    full_basis = np.linalg.qr(np.hstack((U, padding)))[0]
    return full_basis[:, cols:]

def w2_1d_sorted(x: np.ndarray, y: np.ndarray) -> float:
    """Squared 2-Wasserstein distance between two equal-size 1-D empirical samples.

    In one dimension the optimal coupling matches order statistics, so the
    distance is the mean squared gap between the sorted samples.
    """
    gaps = np.sort(x) - np.sort(y)
    return float(np.mean(gaps * gaps))

def sliced_w2_residual(X_perp: np.ndarray, Y_perp: np.ndarray, L: int, seed: int) -> float:
    """Monte-Carlo sliced W2^2 between the rows of ``X_perp`` and ``Y_perp``.

    ``L`` directions are drawn as normalised Gaussian vectors from ``seed``.
    Both samples are projected onto each direction and compared with
    ``w2_1d_sorted``, and the average over directions is returned. The result
    is ``0.0`` when the inputs have no columns.
    """
    gen = np.random.default_rng(seed)
    _, dim = X_perp.shape
    if dim == 0:
        return 0.0
    total = 0.0
    for _ in range(L):
        direction = gen.standard_normal(dim)
        direction = direction / (np.linalg.norm(direction) + 1e-12)
        total += w2_1d_sorted(X_perp @ direction, Y_perp @ direction)
    return total / float(L)

def run_single_trial_curve(X: np.ndarray, Y: np.ndarray, k_values: np.ndarray,
                           rbcd: RiemannianBlockCoordinateDescent,
                           sw_L: int, seed: int):
    """Empirical WPP and lower-bound curves for one sample pair ``(X, Y)``.

    For each ``k`` in ``k_values``:
    - A ``d x k`` frame is optimised with ``rbcd.run_RBCD``, using uniform
      weights on the ``n`` rows. Its returned objective is recorded as the WPP
      estimate.
    - For ``k < d``, the lower bound adds ``(d - k)`` times a sliced W2^2
      estimate computed on the orthogonal complement of the optimised frame.
    - For ``k >= d``, the lower bound equals the WPP estimate.

    Frames are warm-started from the previous ``k``. The first frame comes from
    ``rbcd.InitialStiefel``. Each later frame is grown one random column at a
    time, with the seed derived from ``seed`` and the new column count, until it
    has ``k`` columns. If ``k`` is smaller than the current width, the frame is
    truncated instead.

    Assumes ``X`` and ``Y`` have the same number of rows: both weight vectors
    use ``X``'s ``n``, and the sliced estimate pairs sorted projections one-to-one.

    Returns:
        ``(wpp_hat, lb_hat)``: float arrays with one entry per ``k``.
    """
    n, d = X.shape
    a = np.ones(n) / n
    b = np.ones(n) / n

    wpp_hat = np.zeros(len(k_values))
    lb_hat = np.zeros(len(k_values))

    U = None
    for idx, k in enumerate(k_values):
        if U is None:
            U = rbcd.InitialStiefel(d, k)
        elif U.shape[1] > k:
            # k went down: the leading k columns are still an orthonormal frame.
            U = U[:, :k]
        else:
            # Grow one column at a time so non-consecutive or repeated k_values
            # still yield a d x k frame. For consecutive k this is exactly one
            # call with seed + 17 * k, as before.
            while U.shape[1] < k:
                U = extend_U_warmstart(U, seed=seed + 17 * (U.shape[1] + 1))

        # NOTE(review): assumes run_RBCD returns (_, U_opt, _, objective, _) with the
        # objective being the projected (squared) W2 value — confirm against RBCD.
        _, U, _, f_val, _ = rbcd.run_RBCD(a, b, X, Y, k, U)
        wpp_hat[idx] = float(f_val)

        if k < d:
            U_perp = orthogonal_complement(U, seed=seed + 999 + 31 * k)
            Xp = X @ U_perp
            Yp = Y @ U_perp
            sw2 = sliced_w2_residual(Xp, Yp, L=sw_L, seed=seed + 2024 + 101 * k)
            lb_hat[idx] = wpp_hat[idx] + (d - k) * sw2
        else:
            lb_hat[idx] = wpp_hat[idx]

    return wpp_hat, lb_hat

# -------------------------
# 4) Plotting Function
# -------------------------

def _plot_error_panel(ax, k_values, errors, title, show_legend):
    """Draw the population / finite-sample / total error curves on one axis."""
    # (key in `errors`, line style, colour, legend label); colours are shared by
    # line type across all panels so a single legend is meaningful.
    curve_styles = (
        ('pop_err', ':', 'tab:blue', 'Population Error (Bias)'),
        ('fs_err', '--', 'tab:orange', 'Finite Sample Error'),
        ('total_err', '-', 'tab:green', 'Total Error'),
    )
    for key, linestyle, color, label in curve_styles:
        ax.plot(k_values, errors[key], linestyle=linestyle, linewidth=2.5,
                color=color, label=label)
    ax.set_title(title)
    ax.set_xlabel(r'Subspace Dimension $k^\star$')
    ax.set_ylabel('Absolute Error')
    ax.grid(True, alpha=0.3)
    if show_legend:
        ax.legend()


def plot_trade_off_graph(results, k_values, save_path):
    """Plot a 2 x n_alpha grid of error-decomposition curves and save it to ``save_path``.

    Row 0 shows the WPP errors and row 1 the LB errors. Each column is one alpha,
    sorted ascending. Every panel plots the population error (bias), the
    finite-sample error and the total error against ``k_values``. The legend is
    shown only in the first column to avoid clutter.

    Args:
        results: mapping ``alpha -> {'wpp': errs, 'lb': errs}``, where each
            ``errs`` has keys ``'pop_err'``, ``'fs_err'`` and ``'total_err'``.
        k_values: x-axis values shared by all panels.
        save_path: output file. matplotlib infers the format from its extension.
    """
    alphas = sorted(results)
    n_cols = len(alphas)

    fig, axes = plt.subplots(2, n_cols, figsize=(6 * n_cols, 10), squeeze=False)
    try:
        for col, alpha in enumerate(alphas):
            data = results[alpha]
            first_col = col == 0
            _plot_error_panel(axes[0, col], k_values, data['wpp'],
                              f'WPP Trade-off ($\\alpha={alpha}$)', first_col)
            _plot_error_panel(axes[1, col], k_values, data['lb'],
                              f'LB Trade-off ($\\alpha={alpha}$)', first_col)

        fig.tight_layout()
        # Honour save_path so the file written matches the path reported below.
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0.02)
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
    print(f"Trade-off graph saved to {save_path}")

# -------------------------
# 5) Main Execution
# -------------------------

def _error_components(W2_true, pop_curve, emp_runs):
    """Split estimation error into population, finite-sample and total parts.

    ``pop_err`` is ``|W2_true - pop_curve|``. ``fs_err`` is the mean over trials
    of ``|pop_curve - emp|``. ``total_err`` is the mean over trials of
    ``|W2_true - emp|``. ``emp_runs`` is a list of per-trial curves, stacked
    into a ``(n_trials, n_k)`` array.
    """
    emp = np.array(emp_runs)
    return {
        'pop_err': np.abs(W2_true - pop_curve),
        'fs_err': np.mean(np.abs(pop_curve - emp), axis=0),
        'total_err': np.mean(np.abs(W2_true - emp), axis=0),
    }


def main():
    """Run the finite-sample trade-off experiment and plot its error decomposition.

    The steps are:
    1. Build the spike spectra and the population WPP / LB curves for each alpha.
    2. Simulate ``n_trials`` finite-sample runs.
    3. Decompose the errors into population, finite-sample and total parts.
    4. Plot the results.
    """
    # Experiment parameters
    d, r = 60, 10
    E_total, rho, eta_sig = 120.0, 0.15, 1.0
    alphas = [0.0, 1.0, 1.5]
    n, n_trials = 400, 5
    k_max = 30
    k_values = np.arange(1, k_max + 1)
    sw_L, pop_sw_mc = 1000, 20000

    # 1) Spike spectra and population curves (raw values, not errors yet)
    E_sig = (1.0 - rho) * E_total
    E_res = rho * E_total
    tau_sig = build_tau_signal_exp_decay(r, E_sig, eta_sig)

    taus, pop_curves = {}, {}
    print("Computing Population Curves...")
    for alpha in alphas:
        alpha_offset = int(100 * alpha)
        tau = np.concatenate([tau_sig, build_tau_residual_powerlaw(d - r, E_res, alpha)])
        taus[alpha] = tau
        pop_curves[alpha] = {
            'wpp': population_wpp_from_tau(tau, k_values),
            'lb': population_lb_from_tau(tau, d, k_values, n_mc_sw=pop_sw_mc,
                                         seed=777 + alpha_offset),
            'true_w2': float(np.sum(tau)),
        }

    # 2) Finite-sample simulation
    Q_basis = fixed_orthogonal_basis(d, seed=12345)
    Sigmas = make_sigmas_from_taus(d, taus, Q_basis)
    # raw_emp[alpha][metric] collects one curve per trial
    raw_emp = {alpha: {'wpp': [], 'lb': []} for alpha in alphas}

    rbcd = RiemannianBlockCoordinateDescent(
        eta=2.0, tau=0.1, max_iter=200, threshold=1e-6,
        verbose=False, use_gpu=False,
    )

    print(f"Running {n_trials} trials...")
    for t in range(n_trials):
        # X is shared by every alpha within a trial; only Y changes with alpha.
        X = np.random.default_rng(10000 + t).standard_normal((n, d))
        for alpha in alphas:
            alpha_offset = int(100 * alpha)
            rng_y = np.random.default_rng(20000 + 97 * t + alpha_offset)
            _, Y = sample_gaussian_pair(d, Sigmas[alpha], n, rng_y)
            wpp_hat, lb_hat = run_single_trial_curve(
                X=X, Y=Y, k_values=k_values,
                rbcd=rbcd, sw_L=sw_L,
                seed=30000 + 999 * t + alpha_offset,
            )
            raw_emp[alpha]['wpp'].append(wpp_hat)
            raw_emp[alpha]['lb'].append(lb_hat)

    # 3) Error decomposition: results[alpha][metric][component] -> array over k
    results = {}
    for alpha in alphas:
        W2_true = pop_curves[alpha]['true_w2']
        results[alpha] = {
            metric: _error_components(W2_true, pop_curves[alpha][metric], raw_emp[alpha][metric])
            for metric in ('wpp', 'lb')
        }

    # 4) Plot
    plot_trade_off_graph(results, k_values, save_path="finite_sample_tradeoff_clean.png")

# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()