"""
1D Toy experiment: Verify deep model assisted statistical inference
Includes complete baselines: Scratch, BB_Only, Weighted, Concat, Residual
Modification: Introduced Validation Selection mechanism to eliminate negative transfer when delta=0
"""

import numpy as np
import pandas as pd
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import mean_squared_error
import sys, os

sys.path.append("..")
from utils.seed_utils import set_global_seed
from utils.logging_utils import CSVLogger
from utils.plot_utils import plot_learning_curves, plot_phase_diagram

def true_function(x):
    """Evaluate the ground-truth regression target at ``x``.

    The target is ``sin(4*pi*x) * cos(2*pi*x) + 0.5*x``. ``x`` may be a
    scalar or a NumPy array, in which case it is evaluated element-wise.
    """
    oscillation = np.sin(4 * np.pi * x) * np.cos(2 * np.pi * x)
    linear_trend = 0.5 * x
    return oscillation + linear_trend

def systematic_bias(x):
    """Evaluate the unscaled systematic bias ``0.8*x**2 - 0.5*cos(2*pi*x)``.

    ``x`` may be a scalar or a NumPy array (evaluated element-wise).
    """
    quadratic_part = 0.8 * (x**2)
    cosine_part = 0.5 * np.cos(2 * np.pi * x)
    return quadratic_part - cosine_part

def get_f0(delta):
    """Build the black-box predictor ``f0`` for a given bias level ``delta``.

    The systematic bias is normalised by its root-mean-square value over a
    10000-point uniform grid on [0, 1] and then multiplied by ``delta``.
    The returned callable evaluates the true function plus that scaled bias.
    """
    reference_grid = np.linspace(0, 1, 10000)
    bias_rms = np.sqrt(np.mean(systematic_bias(reference_grid) ** 2))
    # Avoid dividing by a (near-)zero RMS; fall back to no normalisation.
    if bias_rms < 1e-9:
        bias_rms = 1.0
    bias_scale = delta / bias_rms

    def f0(x):
        return true_function(x) + bias_scale * systematic_bias(x)

    return f0

def generate_data(n, sigma, delta, rng):
    """Draw a noisy sample of size ``n`` and build the matching black box.

    Inputs are drawn uniformly on [0, 1] from ``rng`` and returned as an
    ``(n, 1)`` column. Targets are the true function plus Gaussian noise
    with standard deviation ``sigma``.

    Returns:
        Tuple ``(X, Y, f0)`` where ``f0`` is the black-box predictor built
        by ``get_f0(delta)``.
    """
    # Draw order matters for reproducibility: inputs first, then noise.
    inputs = rng.uniform(0, 1, n).reshape(-1, 1)
    noise = rng.normal(0, sigma, n)
    targets = true_function(inputs).ravel() + noise
    return inputs, targets, get_f0(delta)

def _fit_kernel_ridge(features, targets, param_grid):
    """Grid-search an RBF kernel ridge regressor with 3-fold CV and return the fitted search."""
    search = GridSearchCV(KernelRidge(kernel='rbf'), param_grid, cv=3, n_jobs=-1)
    search.fit(features, targets)
    return search


def _select_blend_weight(y_val, f0_val, model_val_pred, candidates):
    """Return the first candidate weight minimising validation MSE of ``w*f0 + (1-w)*model``."""
    best_w = 0
    best_val_loss = float('inf')
    for w in candidates:
        loss = np.mean((y_val - (w * f0_val + (1 - w) * model_val_pred))**2)
        # Strict comparison keeps the earliest candidate on ties.
        if loss < best_val_loss:
            best_val_loss = loss
            best_w = w
    return best_w


def run_experiment():
    """Run the 1D toy sweep over sample sizes ``n`` and bias levels ``delta``.

    For each ``(n, delta)`` pair the experiment is repeated ``n_rep`` times.
    Every repetition draws data, splits it into sub-train/validation parts and
    records the test MSE (against the noiseless true function on a fixed
    1000-point grid) of these methods:

    * ``scratch``       - kernel ridge fitted on the sub-train data only.
    * ``bb_only``       - the black box ``f0`` used as is.
    * ``weighted``      - convex blend of ``f0`` and scratch, weight chosen on validation.
    * ``concat``        - kernel ridge on ``[x, f0(x)]`` features.
    * ``residual_raw``  - ``f0`` plus a kernel ridge fit of ``Y - f0``.
    * ``residual_safe`` - ``residual_raw`` if it beats ``f0`` on validation, else ``f0``.

    ``fallback`` records whether the safe variant fell back to ``f0`` (1) or not (0).
    Per-pair averages are printed, logged via ``CSVLogger`` and finally passed
    to ``plot_learning_curves``.
    """
    set_global_seed(42)
    n_list = [20, 50, 100, 200, 500, 1000, 2000]
    delta_list = [0.0, 0.1, 0.3, 0.8, 1.5]
    n_rep = 20
    sigma = 0.2

    # Initialize logger
    logger = CSVLogger("outputs/toy_1d", "toy_1d_results.csv")
    results = {}

    print("Running 1D Toy Experiment (With Explicit Safe Selection Reporting)...")

    param_grid = {"alpha": [1e-4, 1e-3, 1e-2, 1e-1], "gamma": [1, 5, 10, 20]}
    blend_weights = np.linspace(0, 1, 21)
    metric_names = ["scratch", "bb_only", "weighted", "concat", "residual_raw", "residual_safe", "fallback"]

    # The evaluation grid and its noiseless targets never change, so build them once.
    X_test = np.linspace(0, 1, 1000).reshape(-1, 1)
    Y_test = true_function(X_test).ravel()

    for n in n_list:
        # The smallest sample size holds out half of the data for validation.
        test_size = 0.5 if n <= 20 else 0.2

        for delta in delta_list:
            metrics = {k: [] for k in metric_names}

            for i in range(n_rep):
                # Per-repetition generator seeded from the repetition index and n.
                rng = np.random.default_rng(i*100 + n)
                X_full, Y_full, f0 = generate_data(n, sigma, delta, rng)

                # Split Train/Val
                X_train, X_val, Y_train, Y_val = train_test_split(
                    X_full, Y_full, test_size=test_size, random_state=i
                )

                f0_train, f0_val = f0(X_train).ravel(), f0(X_val).ravel()
                f0_test = f0(X_test).ravel()

                # BB Only
                mse_bb_test = mean_squared_error(Y_test, f0_test)
                metrics["bb_only"].append(mse_bb_test)

                # Scratch (learn from the sub-train data only)
                model_s = _fit_kernel_ridge(X_train, Y_train, param_grid)
                p_s = model_s.predict(X_test)
                metrics["scratch"].append(mean_squared_error(Y_test, p_s))

                # Weighted (select best w on Val); reuses the scratch test
                # predictions instead of predicting on X_test a second time.
                best_w = _select_blend_weight(Y_val, f0_val, model_s.predict(X_val), blend_weights)
                p_weighted = best_w * f0_test + (1-best_w) * p_s
                metrics["weighted"].append(mean_squared_error(Y_test, p_weighted))

                # Concat: append the black-box output as an extra input feature.
                X_train_aug = np.hstack([X_train, f0_train.reshape(-1, 1)])
                X_test_aug = np.hstack([X_test, f0_test.reshape(-1, 1)])
                model_c = _fit_kernel_ridge(X_train_aug, Y_train, param_grid)
                p_concat = model_c.predict(X_test_aug)
                metrics["concat"].append(mean_squared_error(Y_test, p_concat))

                # Residual (raw): fit the residual Y - f0 on sub-train and
                # always add the correction to f0.
                Z_train = Y_train - f0_train
                model_r = _fit_kernel_ridge(X_train, Z_train, param_grid)
                p_res_raw = f0_test + model_r.predict(X_test)
                mse_res_test = mean_squared_error(Y_test, p_res_raw)
                metrics["residual_raw"].append(mse_res_test)

                # Residual (safe): keep the correction only if it strictly
                # improves on the black box on the validation split.
                pred_res_val = model_r.predict(X_val)
                mse_res_val = np.mean((Y_val - (f0_val + pred_res_val))**2)
                mse_bb_val = np.mean((Y_val - f0_val)**2)
                accept_residual = mse_res_val < mse_bb_val

                metrics["residual_safe"].append(mse_res_test if accept_residual else mse_bb_test)
                metrics["fallback"].append(0 if accept_residual else 1)

            # Compute averages
            avg = {k: np.mean(v) for k, v in metrics.items()}
            print(f"n={n}, d={delta} | Raw: {avg['residual_raw']:.4f} | Safe: {avg['residual_safe']:.4f} | Fallback: {avg['fallback']*100:.1f}%")

            row = {"n": n, "delta": delta}
            row.update(avg)
            logger.log(row)
            results[(n, delta)] = avg

    logger.save()
    plot_learning_curves(results, n_list, delta_list, ["scratch", "bb_only", "weighted", "concat", "residual_raw", "residual_safe"], 
                        title_prefix="1D Toy", save_path="outputs/toy_1d/curves")

# Run the full experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_experiment()