"""
Appendix A experiment: Effective Span Dimension (ESD) under varying signal–design alignment in linear models.

This script tests the relationship between Effective Span Dimension (ESD) and
prediction risk in linear models. It implements two different design matrix
scenarios and calculates empirical risk and ESD for each alignment parameter value.

Key Components:
- Tests two design scenarios through CaseID parameter (1: geometric decay, 2: logarithmic decay)
- Uses Monte Carlo replication with common random numbers across all alpha values
- Compares empirical prediction risk with ESD values along an alpha path
- Plots ESD and the rescaled risk (risk · n / σ²) against alpha so their agreement can be inspected
- Outputs results as 'risk_vs_alpha_case_{CaseID}.pdf'

Usage:
python Appendix_A_LinearModel.py
"""

# Design scenario to run (edit this single assignment to switch):
#   1 = geometric decay example
#   2 = logarithmic decay example
CaseID = 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Import our modules
from esd_modular_functions import compute_esd
from linear_model_modular_functions import (
    make_design, make_beta, cache_svd, translate_theta_lambda,
    build_A_diag_exp, apply_alignment, pcr_factory
)

# Global matplotlib defaults for every figure in this script:
# 5 x 5 inch canvas and 12-pt text.
rcParams.update({
    "figure.figsize": (5, 5),
    "font.size": 12,
})

# Simulation sizes and noise level (p > n here, so the model is overparameterized)
n, p = 300, 400   # number of samples, number of features
sig2 = 1          # noise variance σ₀²
B = 20            # number of Monte Carlo replicates



# Scenario-specific inputs:
#   X0        -- design matrix, built once and reused (frozen) for every alpha
#   beta_star -- true coefficient vector
#   alpha_seq -- grid of alignment parameters to sweep
# Unknown CaseID values are rejected instead of silently falling through to
# case 2, which would otherwise produce a mislabeled output file.
if CaseID == 1:
    # geometric decay example
    X0 = make_design(n, p, par=0.95)        # frozen design
    beta_star = make_beta(p, power=0.2)     # dense, front-loaded signal
    alpha_seq = np.linspace(0, 30, 10)
elif CaseID == 2:
    # logarithmic decay example
    X0 = make_design(n, p, par="log")
    beta_star = make_beta(p, type="log")
    alpha_seq = np.linspace(0, 10, 10)
else:
    raise ValueError(
        f"Unknown CaseID {CaseID!r}; expected 1 (geometric decay) "
        f"or 2 (logarithmic decay)"
    )

# Common random numbers: draw the whole n x B noise matrix once, from a fixed
# seed, so replicate b reuses exactly the same noise column at every alpha.
np.random.seed(42)
eps_mat = np.random.normal(loc=0, scale=np.sqrt(sig2), size=(n, B))

# Sweep the alignment parameter alpha. For each value:
#   1. align the frozen design X0 and signal beta_star using A(alpha),
#   2. compute the ESD (deterministic given alpha),
#   3. estimate the prediction risk of PCR at every truncation level k by
#      Monte Carlo, reusing the same noise columns for every alpha (CRN),
#   4. record the risk-minimizing k, its risk and its MC standard error.
# Rows are collected in a list and turned into a DataFrame once at the end.
# Repeatedly concatenating onto an initially empty frame is quadratic, can
# lose numeric dtypes, and triggers a pandas FutureWarning.
records = []

for a in alpha_seq:
    A = build_A_diag_exp(p, a)
    align = apply_alignment(X0, beta_star, A)
    X_A = align["X_A"]        # design after applying the alignment
    beta_A = align["beta_A"]  # coefficient vector after applying the alignment

    # SVD cache & ESD (deterministic for this alpha)
    svd_A = cache_svd(X_A)
    tl_A = translate_theta_lambda(svd_A, beta_A)
    esd_k = compute_esd(tl_A["theta"], tl_A["lambda"], sig2 / n)

    # Oracle PCR factory
    pcr_A = pcr_factory(X_A)

    # ---------- MC loop with common random numbers ----------
    signal = X_A @ beta_A  # noiseless response; identical for every replicate
    loss_mat = np.full((pcr_A["r"], B), np.nan)  # k x B

    for b in range(B):
        y_b = signal + eps_mat[:, b]  # same eps column for *all* alpha

        # losses for every k (vectorised)
        beta_mat = pcr_A["fit_all"](y_b)  # p x r
        pred_err = X_A @ (beta_mat - beta_A[:, np.newaxis])  # n x r
        loss_mat[:, b] = np.mean(pred_err**2, axis=0)  # length = r

    risk_k = np.mean(loss_mat, axis=1)  # Monte-Carlo risk for each k
    # Standard error of the MC mean uses the sample standard deviation
    # (ddof=1, as R's sd() does); np.std defaults to ddof=0.
    risk_k_se = np.std(loss_mat, axis=1, ddof=1) / np.sqrt(B)
    k_star = np.argmin(risk_k)

    records.append({
        "alpha": a,
        "esd": esd_k,
        # 0-based column index -> 1-based k to match the R original's
        # indexing (presumably the number of retained components; confirm
        # against pcr_factory's fit_all column ordering).
        "k_opt": k_star + 1,
        "risk": risk_k[k_star],
        "risk_se": risk_k_se[k_star],
    })

risk_path = pd.DataFrame(
    records, columns=["alpha", "esd", "k_opt", "risk", "risk_se"]
)

print(risk_path.head())

# Long-format table with one row per (alpha, curve) pair: the ESD curve first,
# then the risk rescaled by n / sig2. Only the rescaled-risk rows carry a
# non-zero "error" (their standard error, rescaled the same way).
n_alpha = len(risk_path)
plot_df = pd.DataFrame({
    "alpha": np.tile(risk_path["alpha"], 2),
    "value": np.concatenate((risk_path["esd"], risk_path["risk"] * n / sig2)),
    "curve": ["ESD"] * n_alpha + ["Rescaled Risk"] * n_alpha,
    "error": np.concatenate((np.zeros(n_alpha), risk_path["risk_se"] * n / sig2)),
})

# Draw both curves on one figure: ESD as a solid steelblue line, then the
# rescaled risk as a dashed firebrick line with capped error bars taken from
# the "error" column.
plt.figure(figsize=(5, 5))

esd_rows = plot_df.loc[plot_df["curve"] == "ESD"]
plt.plot(esd_rows["alpha"], esd_rows["value"],
         linestyle="solid", color="steelblue", linewidth=1, label="ESD")

risk_rows = plot_df.loc[plot_df["curve"] == "Rescaled Risk"]
plt.errorbar(risk_rows["alpha"], risk_rows["value"], yerr=risk_rows["error"],
             linestyle="dashed", color="firebrick", linewidth=1,
             capsize=3, label="Rescaled Risk")

plt.xlabel(r"$\alpha$")
plt.ylabel(r"ESD and (risk $\cdot$ n / $\sigma^2$)")
plt.legend(loc="lower right")
plt.ylim(bottom=0)  # anchor the y-axis at zero
plt.tight_layout()

# Write the figure to a case-specific PDF, release it, and report the path.
out_pdf = f"risk_vs_alpha_case_{CaseID}.pdf"
plt.savefig(out_pdf)
plt.close()

print(f"Plot saved as {out_pdf}")