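"""
Ablation Study: Sensitivity to Validation Fraction rho

For each validation fraction rho, the sample is split into a training part and a
validation part (validation share = rho). A kernel ridge model is fit to the
residuals of the black-box predictor f0, and its correction is kept only if it
beats f0 on the validation part. The resulting test risk (MSE against the true
function) is averaged over repetitions and plotted against rho.
"""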
import numpy as np
import pandas as pd
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import sys, os

# Add project root to path
sys.path.append("..")
from toy.toy_1d import generate_data

def _fit_safe_residual(X_train, Y_train, X_val, Y_val, X_test, f0, f0_test, param_grid):
    """Fit a KRR residual correction to ``f0`` and keep it only if it helps on validation.

    Args:
        X_train, Y_train: Training split used to fit the residual model; the
            hyper-parameters are chosen by 3-fold ``GridSearchCV`` over ``param_grid``.
        X_val, Y_val: Held-out split used for the safety check.
        X_test: Points at which the final predictor is evaluated.
        f0: Black-box predictor, called as ``f0(X)`` (output is ravelled).
        f0_test: ``f0(X_test).ravel()``, precomputed by the caller.
        param_grid: Grid passed to ``GridSearchCV`` for ``KernelRidge(kernel='rbf')``.

    Returns:
        ``f0_test + residual_model.predict(X_test)`` if the corrected model has a
        strictly lower validation MSE than ``f0`` alone, otherwise ``f0_test``.
    """
    f0_train = f0(X_train).ravel()
    f0_val = f0(X_val).ravel()
    # Ravel the targets so a column-vector Y cannot broadcast against the 1-D f0
    # values into an (m, m) matrix. This is a no-op when Y is already 1-D.
    y_train = np.asarray(Y_train).ravel()
    y_val = np.asarray(Y_val).ravel()

    residual_model = GridSearchCV(KernelRidge(kernel='rbf'), param_grid, cv=3, n_jobs=-1)
    residual_model.fit(X_train, y_train - f0_train)

    # Safety check on the held-out split. The correction is kept only if it
    # strictly beats the uncorrected black box; ties fall back to f0.
    mse_res_val = np.mean((y_val - (f0_val + residual_model.predict(X_val))) ** 2)
    mse_bb_val = np.mean((y_val - f0_val) ** 2)
    if mse_res_val < mse_bb_val:
        return f0_test + residual_model.predict(X_test)
    return f0_test


def _plot_and_save(rho_list, results, n, delta, output_dir):
    """Plot the mean test risk (+/- standard error) versus rho, then save the PNG and raw CSV to ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)

    avg_risks = [np.mean(results[rho]) for rho in rho_list]
    # Standard error of the mean across repetitions. Use the sample std
    # (ddof=1) whenever more than one repetition is available.
    std_risks = [
        np.std(results[rho], ddof=1 if len(results[rho]) > 1 else 0) / np.sqrt(len(results[rho]))
        for rho in rho_list
    ]

    # Scope the font settings to this figure instead of mutating the global rcParams.
    font_rc = {
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 12,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
    }
    save_path = os.path.join(output_dir, 'ablation_rho.png')
    # Saving happens inside the context because tick labels may be created lazily at draw time.
    with plt.rc_context(font_rc):
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.errorbar(rho_list, avg_risks, yerr=std_risks, fmt='-o', capsize=5, label='Residual (Safe)')
        # Shade the band around rho ~ 0.2 that the study highlights.
        ax.axvspan(0.15, 0.3, color='green', alpha=0.1, label='Stable Region')
        ax.set_xlabel(r'Validation Fraction $\rho$')
        ax.set_ylabel('MSE Risk')
        # Raw f-string: "\d" in a plain literal is an invalid escape sequence.
        ax.set_title(rf'Sensitivity to Validation Fraction (n={n}, $\delta$={delta})')
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.6)
        fig.savefig(save_path, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved plot to {save_path}")

    # One column per rho, one row per repetition.
    pd.DataFrame(results).to_csv(os.path.join(output_dir, 'ablation_rho.csv'), index=False)


def run_rho_ablation(*, n=200, delta=0.2, sigma=0.2, n_rep=50, rho_list=None,
                     output_dir="outputs/ablation"):
    """Run the validation-fraction (rho) ablation, save the plot and raw risks, and return them.

    For each repetition ``i``, a dataset is drawn with
    ``generate_data(n, sigma, delta, np.random.default_rng(i))``. For every rho:

    1. The sample is split into train/validation with ``test_size=rho``.
    2. A KRR residual correction to the black box ``f0`` is fitted.
    3. The correction is kept only if it lowers the validation MSE.
    4. The predictor's MSE against the true function on a 1000-point grid in [0, 1] is recorded.

    All arguments are keyword-only and default to the study's original settings.

    Args:
        n: Sample size per repetition.
        delta: Black-box error level passed to ``generate_data``.
        sigma: Noise level passed to ``generate_data``.
        n_rep: Number of repetitions (seeds ``0 .. n_rep - 1``).
        rho_list: Validation fractions to scan. Defaults to
            ``[0.05, 0.1, 0.2, 0.3, 0.5, 0.7, 0.9]``.
        output_dir: Directory that receives ``ablation_rho.png`` and ``ablation_rho.csv``.

    Returns:
        dict mapping each rho to its list of per-repetition test risks.

    Raises:
        ValueError: If ``n_rep`` is less than 1.
    """
    if n_rep < 1:
        raise ValueError(f"n_rep must be >= 1, got {n_rep}")
    if rho_list is None:
        rho_list = [0.05, 0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    results = {rho: [] for rho in rho_list}

    print(f"Running Rho Ablation (n={n}, delta={delta})...")

    # KRR hyper-parameter grid searched by GridSearchCV.
    param_grid = {"alpha": [1e-4, 1e-3, 1e-2, 1e-1], "gamma": [1, 5, 10, 20]}

    # The evaluation grid and the true regression function do not depend on the
    # repetition, so build them once.
    X_test = np.linspace(0, 1, 1000).reshape(-1, 1)
    # NOTE(review): this is a hard-coded copy of the true f*(x). It must match the
    # function used inside toy.toy_1d.generate_data, so verify it if that module changes.
    Y_test_true = (np.sin(4 * np.pi * X_test) * np.cos(2 * np.pi * X_test) + 0.5 * X_test).ravel()

    for i in range(n_rep):
        rng = np.random.default_rng(i)
        X_full, Y_full, f0 = generate_data(n, sigma, delta, rng)
        f0_test = f0(X_test).ravel()

        for rho in rho_list:
            # The split seed is fixed per repetition (random_state=i), so every
            # rho uses the same data draw and the same shuffling seed.
            X_train, X_val, Y_train, Y_val = train_test_split(
                X_full, Y_full, test_size=rho, random_state=i
            )
            final_pred = _fit_safe_residual(
                X_train, Y_train, X_val, Y_val, X_test, f0, f0_test, param_grid
            )
            # Risk is the MSE against the noise-free true function.
            results[rho].append(mean_squared_error(Y_test_true, final_pred))

    _plot_and_save(rho_list, results, n, delta, output_dir)
    return results

if __name__ == "__main__":
    # Script entry point: run the ablation with its default settings. The plot
    # and CSV are written under outputs/ablation.
    run_rho_ablation()
