"""
High-dimensional toy experiment: verify the phase transition and the curse of dimensionality in high dimensions.
Includes the complete set of baselines: BB-only, Scratch, Weighted, Concat, and Residual.
Modification: adds a validation-based selection mechanism ("safe" residual) to prevent overfitting/blow-up in high dimensions.
"""
import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
import sys, os

sys.path.append("..")
from utils.seed_utils import set_global_seed
from utils.logging_utils import CSVLogger
from utils.plot_utils import plot_learning_curves

def true_hd(X):
    """Ground-truth regression function; only the first two coordinates are used.

    f(x) = sin(2*pi*x_0) + 0.5 * cos(2*pi*x_1), evaluated row-wise over X.
    """
    two_pi = 2 * np.pi
    return np.sin(two_pi * X[:, 0]) + 0.5 * np.cos(two_pi * X[:, 1])

def bias_hd(X):
    """Structured bias term cos(3*pi*x_0), evaluated row-wise on the first coordinate."""
    first_coord = X[:, 0]
    return np.cos(3 * np.pi * first_coord)

def get_f0(delta, dim):
    """Return a biased black-box predictor f0(X) = true_hd(X) + c * bias_hd(X).

    The coefficient c rescales bias_hd so its root-mean-square over
    Uniform[0, 1]^dim equals ``delta``. The RMS is estimated from 10,000 points
    drawn with a fixed seed (0), so the returned f0 is deterministic for a given
    (delta, dim). If the estimated RMS is zero, c is 0 and f0 equals true_hd.
    """
    probe = np.random.default_rng(0).uniform(0, 1, (10000, dim))
    rms = np.sqrt(np.mean(bias_hd(probe) ** 2))
    if rms > 0:
        coef = delta / rms
    else:
        coef = 0

    def f0(X):
        return true_hd(X) + coef * bias_hd(X)

    return f0

class RFF:
    """Random Fourier feature map z(x) = sqrt(2 / n_features) * cos(x @ W + b).

    W has i.i.d. N(0, 1) entries with shape (input_dim, n_features), and b is
    drawn uniformly from [0, 2*pi) with length n_features. Both are sampled from
    ``rng`` at construction time, W first and then b.
    """

    def __init__(self, input_dim, n_features, rng):
        weights = rng.normal(0, 1, (input_dim, n_features))
        phases = rng.uniform(0, 2 * np.pi, n_features)
        self.W = weights
        self.b = phases
        self.scale = np.sqrt(2 / n_features)

    def transform(self, X):
        """Map rows of X (last axis of size input_dim) to n_features cosine features."""
        projection = X @ self.W + self.b
        return self.scale * np.cos(projection)

def _mse(pred, target):
    """Mean squared error between two arrays of equal shape."""
    return np.mean((pred - target) ** 2)


def _select_blend_weight(y_val, f0_val, p_val, grid_size=21):
    """Return the w on an even grid over [0, 1] minimising validation MSE of w*f0 + (1-w)*p.

    Ties resolve to the smallest w, because np.argmin returns the first minimum.
    """
    weights = np.linspace(0, 1, grid_size)
    losses = [_mse(w * f0_val + (1 - w) * p_val, y_val) for w in weights]
    return weights[int(np.argmin(losses))]


def _run_single_rep(rng, n, dim, sigma, f0, alphas, split_seed, n_test=2000, n_features=512):
    """Simulate one dataset of size n and return the test MSE of every method.

    The random draws happen in a fixed order from ``rng``: training inputs,
    label noise, test inputs, then the RFF parameters. This keeps a replicate
    reproducible from its seed. The returned dict's "fallback" entry is 1 when
    safe selection fell back to the black box f0, and 0 otherwise.
    """
    X_full = rng.uniform(0, 1, (n, dim))
    Y_full = true_hd(X_full) + rng.normal(0, sigma, n)
    X_train, X_val, Y_train, Y_val = train_test_split(
        X_full, Y_full, test_size=0.2, random_state=split_seed
    )

    # Test targets are noiseless, so the test MSE measures error against true_hd itself.
    X_test = rng.uniform(0, 1, (n_test, dim))
    Y_test = true_hd(X_test)

    f0_train, f0_val, f0_test = f0(X_train), f0(X_val), f0(X_test)

    rff = RFF(dim, n_features, rng)
    Z_train, Z_val, Z_test = (rff.transform(X) for X in (X_train, X_val, X_test))

    out = {}

    # 1. Black box only.
    mse_bb_test = _mse(f0_test, Y_test)
    out["bb_only"] = mse_bb_test

    # 2. Scratch: ridge on the RFF features, ignoring f0.
    model_s = RidgeCV(alphas=alphas).fit(Z_train, Y_train)
    p_s_test = model_s.predict(Z_test)
    out["scratch"] = _mse(p_s_test, Y_test)

    # 3. Weighted: convex blend of f0 and scratch, with the weight chosen on validation.
    best_w = _select_blend_weight(Y_val, f0_val, model_s.predict(Z_val))
    out["weighted"] = _mse(best_w * f0_test + (1 - best_w) * p_s_test, Y_test)

    # 4. Concat: f0 appended as an extra (ridge-penalised) feature column.
    model_c = RidgeCV(alphas=alphas).fit(np.column_stack([Z_train, f0_train]), Y_train)
    out["concat"] = _mse(model_c.predict(np.column_stack([Z_test, f0_test])), Y_test)

    # 5. Residual: learn Y - f0 on the RFF features, then add f0 back.
    model_r = RidgeCV(alphas=alphas).fit(Z_train, Y_train - f0_train)
    mse_res_raw = _mse(f0_test + model_r.predict(Z_test), Y_test)
    out["residual_raw"] = mse_res_raw

    # Safe selection: keep the residual correction only if it strictly beats f0 on validation.
    mse_res_val = _mse(f0_val + model_r.predict(Z_val), Y_val)
    mse_bb_val = _mse(f0_val, Y_val)
    use_residual = mse_res_val < mse_bb_val
    out["residual_safe"] = mse_res_raw if use_residual else mse_bb_test
    out["fallback"] = 0 if use_residual else 1
    return out


def run_hd_exp():
    """Run the d=20 toy experiment over a grid of sample sizes and bias levels.

    For each (n, delta) pair, n_rep independent replicates are simulated. Each
    replicate draws n noisy samples, splits them 80/20 into train/validation,
    and scores every method on 2000 noiseless test points. The per-method mean
    test MSE is printed and logged through CSVLogger, and learning curves are
    plotted at the end. The averaged "fallback" entry is the fraction of
    replicates in which safe selection fell back to the black box.
    """
    set_global_seed(42)
    dim = 20
    n_list = [100, 200, 500, 1000, 2000]
    delta_list = [0.0, 0.2, 0.5, 1.0]
    n_rep = 10
    sigma = 0.5
    alphas = [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0]
    # Fixed key order: it determines the CSV column order.
    method_keys = ["scratch", "bb_only", "weighted", "concat", "residual_raw", "residual_safe", "fallback"]

    logger = CSVLogger("outputs/toy_hd", "toy_hd_results.csv")
    results = {}

    print(f"Running HD Experiment (d={dim}, With Safe Selection)...")

    for n in n_list:
        for delta in delta_list:
            # f0 depends only on (delta, dim); build it once rather than once per replicate.
            f0 = get_f0(delta, dim)
            metrics = {k: [] for k in method_keys}

            for i in range(n_rep):
                # Seed from the (n, i) pair. The previous i*100 + n collided across
                # sample sizes (e.g. n=100,i=1 and n=200,i=0 both gave 200). The seed
                # still depends only on (n, i), so every delta sees the same data.
                rng = np.random.default_rng([n, i])
                rep = _run_single_rep(rng, n, dim, sigma, f0, alphas, split_seed=i)
                for k in method_keys:
                    metrics[k].append(rep[k])

            avg = {k: np.mean(v) for k, v in metrics.items()}
            print(f"n={n}, delta={delta} -> Raw: {avg['residual_raw']:.4f} | Safe: {avg['residual_safe']:.4f} | Fallback: {avg['fallback']*100:.1f}%")
            row = {"n": n, "delta": delta}
            row.update(avg)
            logger.log(row)
            results[(n, delta)] = avg

    logger.save()
    plot_learning_curves(results, n_list, delta_list, ["scratch", "bb_only", "weighted", "concat", "residual_raw", "residual_safe"],
                        title_prefix="Toy HD", save_path="outputs/toy_hd/curves")

if __name__ == "__main__":
    run_hd_exp()