"""
Implementation of the RKHS (Reproducing Kernel Hilbert Space) experiments from Appendix B.

This script mirrors the Appendix A experiment (ESD under varying alignment), but in the RKHS setting.

Key Components:
- Uses periodic cosine basis functions to create analytical eigenfunctions
- Handles design-induced variance calculations per Proposition B.3 in the paper
- Implements common random numbers for noise to reduce Monte Carlo variance
- Calculates ESD values and empirical risk along an alpha-misalignment path
- Produces visualization comparing optimal risk to ESD-based bounds
- Outputs results as a PDF figure

Usage:
python Appendix_B_RKHS.py
"""


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Import our modules
from esd_modular_functions import compute_esd

# Global matplotlib defaults: square 5×5-inch figures, 12-pt text.
rcParams.update({
    'figure.figsize': (5, 5),
    'font.size': 12,
})

# --- 1  Global simulation parameters ----------------------------------------

n = 200  # sample size  (matches linear experiment)
J = 800  # number of eigenfunctions kept
D = 80   # number of coordinates whose eigenvalues are warped by α
B = 10   # Monte‑Carlo replicates (same as linear code)

alpha_grid = np.linspace(0, 30, 11)   # alignment‑severity path

sigma0 = 1
sigma0_sq = sigma0**2   # plain observation‑noise variance

# Base (un‑aligned) spectrum λ_{j,0}=j^{−s} with mild decay so that
# the α‑warping dominates the behaviour.
s = 1.1
lambda0 = (np.arange(1, J+1))**(-s)

# Alignment weights t_j: evenly spaced in [0, 1] for the first D
# coordinates and exactly 0 for the remaining J−D.  The main loop uses
# λ_j(α) = λ_{j,0}·exp(α t_j), so only the first D eigenvalues are boosted.
t_vec = np.concatenate([np.linspace(0, 1, D), np.zeros(J-D)])

# Signal coefficients θ*_j = σ₀ · j^{−β} (smooth, fast decay)
beta = 4
theta0 = sigma0 / (np.arange(1, J+1))**(beta)

# --- 2  Analytic eigenfunctions & feature matrix Φ --------------------------

# Periodic cosine basis ψ_j(x)=√2 cos(2π j x) on [0,1] (orthonormal in
# L²[0,1] for j ≥ 1).  Φ itself is NOT rescaled: the coefficient estimates
# divide Φᵀy by n, which matches (1/n)ΦᵀΦ ≈ I for uniform design points.

np.random.seed(1)
x = np.random.uniform(0, 1, n)  # i.i.d. Uniform[0, 1) design points
# x = np.linspace(0, 1, n)   # deterministic grid also works

def psi(j, x):
    """Evaluate the j-th periodic cosine basis function √2·cos(2π j x).

    Both ``j`` and ``x`` may be scalars or NumPy arrays; ordinary NumPy
    broadcasting applies between them.
    """
    angle = 2 * np.pi * j * x
    return np.sqrt(2) * np.cos(angle)

# Evaluate all J basis functions at all n design points in one broadcast:
# Phi[i, j-1] = ψ_j(x_i).  Φ is NOT rescaled here; the coefficient
# estimates later divide Φᵀy by n, so it is (1/n)ΦᵀΦ that is ≈ I.
Phi = psi(np.arange(1, J+1)[np.newaxis, :], x[:, np.newaxis])  # n × J

# True function values f*(x_i)  (fixed across α)
f_true = Phi @ theta0  # length‑n vector

# ---------------------------------------------------------------------------
# 2b  Design‑induced variance   σ_{f,4}²  and effective noise σ²
# ---------------------------------------------------------------------------
# τ_j² = Var( f*(X) ψ_j(X) ).  We estimate each τ_j² empirically and take
# the supremum across j, as used in Proposition B.4.
# NOTE(review): the module docstring cites Proposition B.3 for this
# calculation — confirm which proposition is meant.

fpsi_mat = Phi * f_true[:, np.newaxis]  # n × J matrix  f⋅ψ_j
centre_mat = fpsi_mat - np.mean(fpsi_mat, axis=0)
tau_j_sq = np.mean(centre_mat**2, axis=0)  # empirical Var per j (ddof=0)
sigma_f4_sq = np.max(tau_j_sq)  # ⟨sup_j τ_j²⟩

sigma_eff_sq = (sigma0_sq + sigma_f4_sq)  # effective per‑coordinate noise

print(f"Design‑induced variance σ_{{f,4}}² ≈ {sigma_f4_sq:.4f}  →  σ² ≈ {sigma_eff_sq:.4g}")

# --- 3  Common‑random‑numbers noise matrix ----------------------------------

# One fixed noise draw per replicate, reused for every α so that risk
# differences across α are not swamped by fresh Monte-Carlo noise.
np.random.seed(42)  # For reproducibility
eps_mat = np.random.normal(0, sigma0, size=(n, B))  # n × B

# --- 4  Helper: cumulative‑sum transformer ----------------------------------

def cum_col(M):
    """Cumulative sum transformer for columns.

    Given an n×J matrix M with columns m_j, return the n×J matrix whose
    column k is Σ_{j≤k} m_j.

    Uses ``np.cumsum`` along the column axis. This costs O(n·J) time and
    needs no extra memory, unlike multiplying by a dense J×J
    upper-triangular matrix of ones (O(n·J²) time, O(J²) memory).

    Parameters
    ----------
    M : array_like, shape (n, J)
        Input matrix; lists are accepted and converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray, shape (n, J)
        Running column sums, promoted to at least float64. This matches the
        dtype the former product with a float64 matrix produced.
    """
    M = np.asarray(M)
    out_dtype = np.result_type(M.dtype, np.float64)
    return np.cumsum(M, axis=1, dtype=out_dtype)

# --- 5  Main loop over α -----------------------------------------------------

# One record per α. The DataFrame is built once after the loop, instead of
# being grown with pd.concat onto an empty, column-mismatched frame.
result_rows = []

for alpha in alpha_grid:
    # 5.1  Current spectrum & ordering ----------------------------------------
    lambda_alpha = lambda0 * np.exp(alpha * t_vec)
    ord_idx = np.argsort(lambda_alpha)[::-1]  # indices of decreasing λ_j(α)
    theta_sorted = theta0[ord_idx]

    # 5.2  Effective Span Dimension (deterministic)
    esd_k = compute_esd(theta0, lambda_alpha, sigma_eff_sq/n)

    # 5.3  Monte‑Carlo risk estimation ----------------------------------------
    # Responses for all B replicates at once. These are common random
    # numbers: the same eps_mat columns are reused for every α.
    y_mat = f_true[:, np.newaxis] + eps_mat  # n × B

    # Coefficient estimates θ̂_j = (1/n) Σ_i ψ_j(x_i) y_i, re-ordered so that
    # row r corresponds to the r-th largest λ_j(α).
    theta_hat_sorted = ((Phi.T @ y_mat) / n)[ord_idx, :]  # J × B

    # Squared loss of the estimator that keeps only the first k+1 sorted
    # coordinates (and sets the rest to zero):
    #     Σ_{r≤k} (θ̂_r − θ_r)²  +  Σ_{r>k} θ_r²
    # computed for every k at once with cumulative sums. This replaces
    # building a J × J truncation matrix for each replicate.
    head = np.cumsum((theta_hat_sorted - theta_sorted[:, np.newaxis])**2, axis=0)
    # Suffix sums via a reversed cumsum (avoids the cancellation in
    # "total − prefix" when the trailing θ_r² are tiny).
    theta_sq_suffix = np.cumsum((theta_sorted**2)[::-1])[::-1]
    tail = np.append(theta_sq_suffix[1:], 0.0)  # Σ_{r>k} θ_r², 0 at k=J-1
    loss_mat = head + tail[:, np.newaxis]  # k × B

    risk_k = np.mean(loss_mat, axis=1)  # average over B replicates
    # Standard error of the Monte-Carlo mean uses the sample std (ddof=1).
    risk_k_se = np.std(loss_mat, axis=1, ddof=1) / np.sqrt(B)
    k_star = np.argmin(risk_k)
    min_risk = risk_k[k_star]  # minimum risk over truncation levels
    min_risk_se = risk_k_se[k_star]  # SE at the minimising truncation level

    result_rows.append({
        "alpha": alpha,
        "esd": esd_k,
        "k_opt": k_star + 1,  # +1 to match R's 1-indexing
        "risk": min_risk,
        "risk_se": min_risk_se,
    })

results = pd.DataFrame(
    result_rows, columns=["alpha", "esd", "k_opt", "risk", "risk_se"]
)

# --- 6  Visualisation --------------------------------------------------------

# Colour, line style and legend label for each plotted curve.
colors = {
    "risk_curve": "#D55E00",        # vermillion
    "esd_sigma0_curve": "#0072B2",  # blue
    "esd_sigmaeff_curve": "#009E73"  # green
}

linetypes = {
    "risk_curve": "solid",
    "esd_sigma0_curve": "dashed",
    "esd_sigmaeff_curve": "dotted"
}

# Labels must describe exactly the quantities assembled in plot_df below.
# The σ_eff curve is plotted as 2·ESD·σ_eff²/n.
labels = {
    "risk_curve": "optimal risk",
    "esd_sigma0_curve": "(ESD-1) · σ₀²/n",
    "esd_sigmaeff_curve": "2 · ESD · σ_eff²/n"
}

# Long-form table: one row per (α, curve); only the risk curve has error bars.
plot_df = pd.DataFrame({
    "alpha": np.tile(results["alpha"], 3),
    "value": np.concatenate([
        results["risk"],
        (results["esd"]-1) * sigma0_sq / n,
        2 * results["esd"] * sigma_eff_sq / n
    ]),
    "curve": np.repeat(["risk_curve", "esd_sigma0_curve", "esd_sigmaeff_curve"], len(results)),
    "error": np.concatenate([
        results["risk_se"],
        np.zeros(len(results)),
        np.zeros(len(results))
    ])
})

# Create plot with error bars
plt.figure(figsize=(5, 5))

for curve in ["risk_curve", "esd_sigma0_curve", "esd_sigmaeff_curve"]:
    subset = plot_df[plot_df["curve"] == curve]
    if curve == "risk_curve":
        plt.errorbar(subset["alpha"], subset["value"], yerr=subset["error"],
                     color=colors[curve], linestyle=linetypes[curve],
                     linewidth=1.1, capsize=1, label=labels[curve])
    else:
        plt.plot(subset["alpha"], subset["value"],
                 color=colors[curve], linestyle=linetypes[curve],
                 linewidth=1.1, label=labels[curve])

plt.xlabel(r"$\alpha$")
plt.ylabel("Risk  &  ESD‑based bounds")
plt.legend(loc="upper left")
plt.grid(True, linestyle='--', alpha=0.7)
plt.ylim(bottom=0)  # Set y-axis to start from 0
plt.tight_layout()

# Save plot
plt.savefig("risk_vs_alpha_rkhs.pdf")
plt.close()

print("Plot saved as risk_vs_alpha_rkhs.pdf")

# For debugging: per-α summary of the whole simulation. This reads the
# results table rather than variables left over from the last loop
# iteration, so every α is reported and an empty alpha_grid does not crash.
print(results.to_string(index=False))