# script/paper_experiment/main_experiment/run_linear_optuna.py
# -*- coding: utf-8 -*-
"""
Optuna + Lambda sweep for a single regression dataset
across multiple methods (clari_tree, greedy, streed, streed_sl, guide).

Input:
  data/{dataset}/splits/outer_0..5/{train.csv, test.csv}

For each method, tree depth, and regularization parameter:
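  - clari_tree / greedy:
        sweep over cost_complexity (lambda) values.
        Each lambda is tuned with Optuna for ridge_penalty.
        --threshold_mode sets the split stride: 1 for "full" (all thresholds),
        max(1, n_train // 20) for "threshold" (roughly 20 thresholds).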
  - streed / streed_sl:
        sweep over cost_complexity (lambda) values.
        Each lambda is tuned with Optuna for ridge_penalty and lasso_penalty.
        Threshold search strategy is controlled by --threshold_mode:
            * "threshold": uses a small fixed number of thresholds (n_thresholds = 20)
            * "full":      uses all available thresholds (n_thresholds = n_train)
  - guide:
        sweep over max_nodes values (independent of threshold_mode; threshold_mode should be changed in the processor).

All methods are evaluated on training and test splits.
Metrics and timing are recorded to CSV.

Output:
  ./results/{prefix}_depth{D}/{dataset}/results_linear_optuna_d{D}_outer{fold}.csv
  or a single ..._all.csv if outer_id is None.

CSV columns:
  dataset, outer, method, depth, lambda,
  ridge_penalty, lasso_penalty,
  leaves, r2_train, r2_valid, r2_test,
  mse_train, mse_valid, mse_test, train_time_s
  (for guide, "lambda" holds max_nodes; r2_valid and all mse_* columns are NaN)

At the end, per-(method, lambda) mean and std rows are appended
whenever all outer folds are run together (outer_id is None).

Notes:
  - GUIDE uses its own R backend and does not depend on --threshold_mode.
"""

from __future__ import annotations
from pathlib import Path
import argparse
import csv
import time
import os
import numpy as np
from tqdm import tqdm
import optuna
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pandas as pd

# --- Processors and loaders ---
from script.processors.streed import STreeDProcessor          # linear STreeD
from script.processors.claritree import (
    CLARITreeProcessor,
    GreedyProcessor,
)
from script.utils.dataio import load_xy_claritree
from script.utils.utils import scorer_r2, scorer_mse
from script.processors.guide import GuideProcessor
from script.processors.guide_utils import (
    update_guide_model_r,
    run_rscript,
    parse_r2,
    parse_guide_train_r2_and_elapsed,
    find_guide_training_out,
)


# Dataset name -> short tag used in GUIDE work-directory paths. Long names are
# shortened so GUIDE's generated R code file path stays under its length limit;
# datasets not listed here use their own name unchanged.
GUIDE_DATASET_ALIASES: dict[str, str] = dict(
    temperature_min="t_min",
    temperature_max="t_max",
    california_housing="ch",
)


def auto_n_jobs(max_cap: int = 8) -> int:
    """Return a worker count: half the detected CPUs, capped at ``max_cap``, at least 1."""
    half_cpus = (os.cpu_count() or 1) // 2
    capped = half_cpus if half_cpus < max_cap else max_cap
    return capped if capped >= 1 else 1


def time_block(fn):
    """Call ``fn()`` once and return ``(result, elapsed_seconds)`` (wall clock, perf_counter)."""
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    return result, elapsed


def standardize_and_center(X_tr, y_tr, X_val=None, y_val=None, X_te=None, y_te=None):
    """
    Standardize X and center y using training statistics only (no data leakage).

    A StandardScaler and the target mean are fitted on (X_tr, y_tr) and then
    applied unchanged to the optional validation and test splits. This supports
    a three-way split: Train (fit model), Valid (Optuna tuning), Test (final eval).

    Note: X must NOT contain an intercept column. For CLARITree/Greedy the
    intercept column is added AFTER standardization (a constant column would
    otherwise be scaled to zeros); STreeD uses the standardized data as-is.

    Each optional X and y is transformed independently, so passing e.g. X_te
    without y_te still returns the scaled X_te (y_te component is then None).

    Args:
        X_tr: Training features (2-D array, no intercept column).
        y_tr: Training targets (array).
        X_val, y_val: Optional validation features / targets.
        X_te, y_te: Optional test features / targets.

    Returns:
        If X_val is not None, a 6-tuple:
            (X_tr_scaled, y_tr_centered, X_val_scaled, y_val_centered,
             X_te_scaled, y_te_centered)
        otherwise a 4-tuple:
            (X_tr_scaled, y_tr_centered, X_te_scaled, y_te_centered)
        Components for splits that were not supplied are None.
    """
    # Fit the feature scaler and the target mean on the training split only.
    scaler_X = StandardScaler()
    X_tr_scaled = scaler_X.fit_transform(X_tr)
    y_tr_mean = np.mean(y_tr)
    y_tr_centered = y_tr - y_tr_mean

    def _scale(X):
        # Apply training statistics; missing splits pass through as None.
        return None if X is None else scaler_X.transform(X)

    def _center(y):
        return None if y is None else y - y_tr_mean

    X_val_scaled, y_val_centered = _scale(X_val), _center(y_val)
    X_te_scaled, y_te_centered = _scale(X_te), _center(y_te)

    # Return arity depends on whether a validation split was requested;
    # callers unpack 6 values in the tuning stage and 4 in the final stage.
    if X_val is not None:
        return X_tr_scaled, y_tr_centered, X_val_scaled, y_val_centered, X_te_scaled, y_te_centered
    return X_tr_scaled, y_tr_centered, X_te_scaled, y_te_centered


def get_processors():
    """
    Build a fresh processor instance for every benchmarked method (linear leaves).

    streed and streed_sl share the STreeDProcessor class but get separate
    instances; their difference is expressed later through build_params.
    """
    factories = (
        ("clari_tree", CLARITreeProcessor),
        ("greedy", GreedyProcessor),
        ("streed", STreeDProcessor),
        ("streed_sl", STreeDProcessor),
        ("guide", GuideProcessor),
    )
    return {name: factory() for name, factory in factories}


def build_params(
    depth: int,
    lam: float,
    method: str,
    n_train: int,
    threshold_mode: str,
    coarse_thresholds: int = 20,
) -> dict:
    """
    Build the default hyperparameter dictionary for a linear-leaf processor.

    Every method gets ``depth``, ``cost_complexity`` (= lam), and default
    ``ridge_penalty=1e-5`` / ``lasso_penalty=0.0`` (Optuna may override them later).
    Threshold granularity depends on ``threshold_mode``:

    * streed / streed_sl: ``n_thresholds`` = n_train for "full", otherwise
      ``coarse_thresholds``; streed_sl additionally sets ``simple=True``.
    * clari_tree / greedy: ``stride`` = 1 for "full" (all thresholds),
      otherwise ``max(1, n_train // coarse_thresholds)`` (about
      ``coarse_thresholds`` thresholds).
    * any other method (e.g. guide): only the common keys.

    Args:
        depth: Tree depth.
        lam: Cost-complexity penalty.
        method: Method name.
        n_train: Number of training rows.
        threshold_mode: "full" or anything else (treated as coarse search).
        coarse_thresholds: Approximate number of thresholds in coarse mode.

    Raises:
        ValueError: if coarse_thresholds < 1.
    """
    if coarse_thresholds < 1:
        raise ValueError(f"coarse_thresholds must be >= 1, got {coarse_thresholds}")

    full_search = threshold_mode == "full"
    params: dict = {
        "depth": int(depth),
        "cost_complexity": float(lam),
        "ridge_penalty": 1e-5,
        "lasso_penalty": 0.0,
    }

    if method in {"streed", "streed_sl"}:
        params["n_thresholds"] = n_train if full_search else coarse_thresholds
        if method == "streed_sl":
            params["simple"] = True
    elif method in {"clari_tree", "greedy"}:
        params["stride"] = 1 if full_search else max(1, n_train // coarse_thresholds)

    return params


def _prepend_ones(X):
    """Return ``X`` with a column of ones inserted as its first column."""
    return np.concatenate([np.ones((X.shape[0], 1)), X], axis=1)


def _first_column_is_ones(X) -> bool:
    """True when ``X`` has at least one column and its first column is numerically all 1s."""
    return X.shape[1] > 0 and np.allclose(X[:, 0], 1.0)


def _match_intercept_layout(X, n_cols: int):
    """
    Prepend an intercept column to an evaluation matrix that lacks exactly that column.

    The decision follows the (already intercept-augmented) training width
    ``n_cols``, so validation/test always share the training column layout even
    if one of their feature columns happens to be all 1.0. ``None`` passes through.
    """
    if X is None or X.shape[1] != n_cols - 1:
        return X
    return _prepend_ones(X)


def _score_split(proc, model, X, y):
    """Return ``(r2, mse)`` of ``model`` on ``(X, y)``, or ``(nan, nan)`` if the split is absent."""
    if X is None or y is None:
        return float("nan"), float("nan")
    y_hat = proc.predict(model, X)
    return float(scorer_r2(y, y_hat)), float(scorer_mse(y, y_hat))


def _leaves_from_artifact(art) -> int:
    """
    Number of leaves of a fitted artifact, or 0 when no feasible tree was found.

    A tree counts as infeasible when ``art.extras["feasible"]`` is falsy, or when
    ``complexity`` is None/NaN (any float-like type, not just built-in float).
    """
    feasible = art.extras.get("feasible", True) if art.extras else True
    complexity = getattr(art, "complexity", 0)
    if not feasible or complexity is None:
        return 0
    complexity = float(complexity)
    return 0 if np.isnan(complexity) else int(complexity)


def fit_and_eval(
    proc,
    method: str,
    X_tr,
    y_tr,
    depth: int,
    lam: float,
    X_val=None,
    y_val=None,
    X_te=None,
    y_te=None,
    trial=None,
    override_params: dict | None = None,
    threshold_mode: str = "threshold",
):
    """
    Train one model and evaluate it on train/valid/test (non-GUIDE methods).

    X_tr, X_val, X_te should be standardized but WITHOUT an intercept column.
    For CLARITree/Greedy an intercept column (all 1s, first position) is added
    here when the training matrix does not already have one, and the
    validation/test matrices are aligned to the same column layout. STreeD data
    is used as-is.

    Hyperparameters come from build_params; if ``trial`` is given, Optuna
    suggests ridge_penalty (and for streed/streed_sl also lasso_penalty) on a
    log scale in [1e-6, 0.5]; ``override_params`` is applied last.

    Args:
        proc: Processor exposing build / fit / predict.
        method: One of "streed", "streed_sl", "clari_tree", "greedy".
        X_tr, y_tr: Training data (for fitting model parameters).
        depth, lam: Tree depth and cost-complexity penalty.
        X_val, y_val: Validation data (Optuna objective), optional.
        X_te, y_te: Test data (final evaluation), optional.
        trial: Optional Optuna trial used to sample penalties.
        override_params: Optional parameters overriding the built ones
            (e.g. the best Optuna params for the final refit).
        threshold_mode: "full" or "threshold", forwarded to build_params.

    Returns:
        dict with ridge_penalty, lasso_penalty, leaves, r2_train, r2_valid,
        r2_test, mse_train, mse_valid, mse_test, train_time_s. Metrics of
        absent splits are NaN; train_time_s covers build + fit only.
    """
    if method in ["clari_tree", "greedy"]:
        if not _first_column_is_ones(X_tr):
            X_tr = _prepend_ones(X_tr)
        n_cols = X_tr.shape[1]
        X_val = _match_intercept_layout(X_val, n_cols)
        X_te = _match_intercept_layout(X_te, n_cols)
    # STreeD handles the intercept itself, so its data is used as-is.

    params = build_params(
        depth=depth,
        lam=lam,
        method=method,
        n_train=len(X_tr),
        threshold_mode=threshold_mode,
    )

    # Optuna tuning of the leaf-regression penalties.
    if trial is not None:
        if method in ["streed", "streed_sl"]:
            params["ridge_penalty"] = trial.suggest_float(
                "ridge_penalty", 1e-6, 5e-1, log=True
            )
            params["lasso_penalty"] = trial.suggest_float(
                "lasso_penalty", 1e-6, 5e-1, log=True
            )
        elif method in ["clari_tree", "greedy"]:
            params["ridge_penalty"] = trial.suggest_float(
                "ridge_penalty", 1e-6, 5e-1, log=True
            )

    if override_params is not None:
        params.update(override_params)

    def _fit():
        model = proc.build(**params)
        return proc.fit(model, X_tr, y_tr)

    art, t_fit = time_block(_fit)

    r2_tr, mse_tr = _score_split(proc, art.model, X_tr, y_tr)
    r2_val, mse_val = _score_split(proc, art.model, X_val, y_val)
    r2_te, mse_te = _score_split(proc, art.model, X_te, y_te)

    return {
        "ridge_penalty": params.get("ridge_penalty", float("nan")),
        "lasso_penalty": params.get("lasso_penalty", float("nan")),
        "leaves": _leaves_from_artifact(art),
        "r2_train": r2_tr,
        "r2_valid": r2_val,
        "r2_test": r2_te,
        "mse_train": mse_tr,
        "mse_valid": mse_val,
        "mse_test": mse_te,
        "train_time_s": float(t_fit),
    }


def run_optuna_for_streed(
    proc,
    X_tr,
    y_tr,
    X_val,
    y_val,
    depth,
    lam,
    method: str = "streed",
    n_trials: int = 20,
    threshold_mode: str = "threshold",
    n_jobs: int | None = None,
    seed: int | None = None,
):
    """
    Tune STreeD ridge/lasso penalties for one (depth, lambda) with Optuna.

    Each trial fits on (X_tr, y_tr) via fit_and_eval with the given ``method``
    (streed or streed_sl, so tuning matches the final model). It minimizes the
    validation MSE; the test set is never used here.

    Args:
        proc: STreeD processor.
        X_tr, y_tr, X_val, y_val: Standardized inner-train / validation data.
        depth, lam: Tree depth and cost-complexity penalty (fixed during tuning).
        method: "streed" or "streed_sl".
        n_trials: Number of Optuna trials.
        threshold_mode: Forwarded to fit_and_eval.
        n_jobs: Parallel Optuna workers; None -> auto_n_jobs().
        seed: Optional TPE sampler seed. With n_jobs > 1 the trial order is
            still nondeterministic, so full reproducibility also needs n_jobs=1.

    Returns:
        (best_params, best_value) of the study.
    """
    def objective(trial):
        res = fit_and_eval(
            proc,
            method,  # Use the actual method (streed or streed_sl) for tuning
            X_tr,
            y_tr,
            X_val=X_val,
            y_val=y_val,
            depth=depth,
            lam=lam,
            trial=trial,
            threshold_mode=threshold_mode,
        )
        # Validation MSE drives tuning; the test set is reserved for final evaluation.
        return res["mse_valid"]

    if n_jobs is None:
        n_jobs = auto_n_jobs()
    # Label with the real method so streed_sl runs are distinguishable in logs.
    print(f"[Optuna|{method}] Detected {os.cpu_count()} CPUs, using n_jobs={n_jobs}")
    # sampler=None keeps Optuna's default sampler when no seed is requested.
    sampler = optuna.samplers.TPESampler(seed=seed) if seed is not None else None
    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs)
    return study.best_params, study.best_value


def run_optuna_for_chol(
    proc,
    X_tr,
    y_tr,
    X_val,
    y_val,
    depth,
    lam,
    method,
    n_trials: int = 20,
    threshold_mode: str = "threshold",
    n_jobs: int | None = None,
    seed: int | None = None,
):
    """
    Tune the ridge penalty of CLARITree/Greedy for one (depth, lambda) with Optuna.

    Each trial fits on (X_tr, y_tr) via fit_and_eval and minimizes the
    validation MSE; the test set is never used here.

    Args:
        proc: CLARITree or Greedy processor.
        X_tr, y_tr, X_val, y_val: Standardized inner-train / validation data
            (without intercept column; fit_and_eval adds it).
        depth, lam: Tree depth and cost-complexity penalty (fixed during tuning).
        method: "clari_tree" or "greedy".
        n_trials: Number of Optuna trials.
        threshold_mode: Forwarded to fit_and_eval.
        n_jobs: Parallel Optuna workers; None -> auto_n_jobs().
        seed: Optional TPE sampler seed. With n_jobs > 1 the trial order is
            still nondeterministic, so full reproducibility also needs n_jobs=1.

    Returns:
        (best_params, best_value) of the study.
    """
    def objective(trial):
        res = fit_and_eval(
            proc,
            method,
            X_tr,
            y_tr,
            X_val=X_val,
            y_val=y_val,
            depth=depth,
            lam=lam,
            trial=trial,
            threshold_mode=threshold_mode,
        )
        # Validation MSE drives tuning; the test set is reserved for final evaluation.
        return res["mse_valid"]

    if n_jobs is None:
        n_jobs = auto_n_jobs()
    print(f"[Optuna|{method}] Detected {os.cpu_count()} CPUs, using n_jobs={n_jobs}")
    # sampler=None keeps Optuna's default sampler when no seed is requested.
    sampler = optuna.samplers.TPESampler(seed=seed) if seed is not None else None
    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs)
    return study.best_params, study.best_value


# Columns of the results CSV, in output order.
_RESULT_FIELDNAMES = (
    "dataset",
    "outer",
    "method",
    "depth",
    "lambda",
    "ridge_penalty",
    "lasso_penalty",
    "leaves",
    "r2_train",
    "r2_valid",
    "r2_test",
    "mse_train",
    "mse_valid",
    "mse_test",
    "train_time_s",
)

# Identifier columns; every other column is a metric that gets mean/std rows.
_ID_FIELDNAMES = frozenset({"dataset", "outer", "method", "depth", "lambda"})

# Dedicated, short work root so that GUIDE's generated R code file paths stay
# safely within its internal path-length limit.
_GUIDE_WORK_ROOT = Path("./results/guide_work_optuna")

# The only GUIDE metrics that are aggregated; its other metric columns are NaN.
_GUIDE_AGG_METRICS = ("leaves", "r2_train", "r2_test", "train_time_s")

# Optuna trials per (method, lambda) pair.
_N_OPTUNA_TRIALS = 20


def _list_outer_dirs(splits_root: Path, outer_id: int | None) -> list[Path]:
    """Return the ``outer_*`` fold directories sorted by fold index, optionally only one."""
    outers = sorted(
        (p for p in splits_root.iterdir() if p.is_dir() and p.name.startswith("outer_")),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    if outer_id is None:
        return outers
    selected = [p for p in outers if p.name == f"outer_{outer_id}"]
    if not selected:
        # Otherwise the run would silently produce a header-only CSV.
        print(f"[warn] outer_{outer_id} not found under {splits_root}; no folds to run")
    return selected


def _accumulate_metrics(agg, key, row, metrics=None) -> None:
    """
    Append metric values of ``row`` to the aggregation bucket ``agg[key]``.

    The bucket is created with one list per non-identifier column of ``row``;
    only ``metrics`` are appended when given, otherwise every bucket column is.
    """
    bucket = agg.get(key)
    if bucket is None:
        bucket = agg[key] = {k: [] for k in row if k not in _ID_FIELDNAMES}
    for k in (bucket if metrics is None else metrics):
        bucket[k].append(row[k])


def _write_summary_rows(writer, agg, dataset, depth) -> None:
    """Write one "mean" and one "std" row (sample std, ddof=1) per (method, lambda) key."""
    def _mean(xs):
        return float(np.mean(xs)) if xs else float("nan")

    def _std(xs):
        if not xs:
            # Nothing measured: report NaN (matching the mean), not a fake 0.0.
            return float("nan")
        return float(np.std(xs, ddof=1)) if len(xs) > 1 else 0.0

    for (method, lam), stats in agg.items():
        for label, reduce_fn in (("mean", _mean), ("std", _std)):
            summary = {
                "dataset": dataset,
                "outer": label,
                "method": method,
                "depth": depth,
                "lambda": lam,
            }
            summary.update((k, reduce_fn(xs)) for k, xs in stats.items())
            writer.writerow(summary)


def _guide_leaves(art, max_nodes: int) -> int:
    """GUIDE leaf count: ``art.complexity`` (default ``max_nodes``); 0 when None/NaN."""
    complexity = getattr(art, "complexity", max_nodes)
    if complexity is None:
        return 0
    complexity = float(complexity)
    return 0 if np.isnan(complexity) else int(complexity)


def _write_guide_inputs(guide_dir: Path, X_tr, y_tr, X_te, y_te) -> tuple[Path, Path]:
    """
    Write standardized train/test data as CSVs for GUIDE and return their paths.

    Columns are named X1..Xn plus "target" so GUIDE's description file can
    reference them; this gives GUIDE the same preprocessing as the other methods.
    """
    guide_dir.mkdir(parents=True, exist_ok=True)
    train_csv = guide_dir / "train_standardized.csv"
    test_csv = guide_dir / "test_standardized.csv"
    feature_cols = [f"X{i+1}" for i in range(X_tr.shape[1])]
    for path, X, y in ((train_csv, X_tr, y_tr), (test_csv, X_te, y_te)):
        frame = pd.DataFrame(X, columns=feature_cols)
        frame["target"] = y
        frame.to_csv(path, index=False, header=True)
    return train_csv, test_csv


def _retarget_guide_r_script(rfile: Path, train_csv: Path, test_csv: Path) -> None:
    """
    Point GUIDE's generated R prediction script at the standardized test CSV.

    The training-file references are replaced before and again after
    update_guide_model_r, in case that helper re-inserts a training path.
    """
    def _rewrite(replacements):
        text = rfile.read_text(encoding="utf-8", errors="ignore")
        for old, new in replacements:
            text = text.replace(old, new)
        rfile.write_text(text, encoding="utf-8")

    to_test = [
        ("train_standardized.csv", "test_standardized.csv"),
        (str(train_csv), str(test_csv)),
    ]
    _rewrite(to_test)
    update_guide_model_r(rfile)
    _rewrite([("train.csv", "test_standardized.csv"), *to_test])


def _run_guide_sweep(
    proc, writer, agg, *, dataset, k_outer, depth, guide_nodes, X_tr, y_tr, X_te, y_te
) -> None:
    """Fit GUIDE (R backend) for every max_nodes value of one fold and record rows."""
    # Short alias for long dataset names keeps GUIDE's R file paths short.
    guide_tag = GUIDE_DATASET_ALIASES.get(dataset, dataset)
    train_csv, test_csv = _write_guide_inputs(
        _GUIDE_WORK_ROOT / guide_tag / k_outer, X_tr, y_tr, X_te, y_te
    )

    for max_nodes in tqdm(guide_nodes, desc="guide max_nodes", leave=False):
        work_dir = _GUIDE_WORK_ROOT / guide_tag / f"d{depth}_{k_outer}_n{max_nodes}"
        work_dir.mkdir(parents=True, exist_ok=True)

        model = proc.build(
            csv_path=train_csv,  # standardized training data
            depth=depth,
            max_nodes=max_nodes,
            work_dir=work_dir,
        )
        try:
            art, t_fit = time_block(lambda: proc.fit(model, None, None))
        except RuntimeError as e:
            print(f"[ERROR] GUIDE fit failed for {k_outer} n={max_nodes}: {e}")
            continue

        rfile = work_dir / "guide_model.R"
        _retarget_guide_r_script(rfile, train_csv, test_csv)
        out = run_rscript(rfile)

        try:
            r2_te = parse_r2(out)
        except Exception:
            # Best effort: an unparsable R output is recorded as NaN.
            r2_te = float("nan")

        train_out_path = find_guide_training_out(work_dir)
        r2_tr, elapsed = parse_guide_train_r2_and_elapsed(
            train_out_path if train_out_path is not None else out
        )
        # Prefer GUIDE's self-reported time; fall back to the wall-clock fit time.
        fit_time = elapsed if elapsed else t_fit

        row = {
            "dataset": dataset,
            "outer": k_outer,
            "method": "guide",
            "depth": depth,
            "lambda": max_nodes,
            "ridge_penalty": float("nan"),
            "lasso_penalty": float("nan"),
            "leaves": _guide_leaves(art, max_nodes),
            "r2_train": r2_tr,
            "r2_valid": float("nan"),  # GUIDE doesn't use a validation set
            "r2_test": r2_te,
            "mse_train": float("nan"),
            "mse_valid": float("nan"),  # GUIDE doesn't use a validation set
            "mse_test": float("nan"),
            "train_time_s": fit_time,
        }
        writer.writerow(row)
        _accumulate_metrics(agg, ("guide", float(max_nodes)), row, _GUIDE_AGG_METRICS)


def _run_linear_sweep(
    proc, method, lambdas, tune, writer, agg, *,
    dataset, k_outer, depth, threshold_mode, tune_data, final_data,
) -> None:
    """
    For each lambda: tune penalties on (inner train, valid), refit on the full
    training split with the best params, evaluate on test, and record the row.

    ``tune`` is run_optuna_for_streed or run_optuna_for_chol; ``tune_data`` is
    (X_tr, y_tr, X_val, y_val) and ``final_data`` is (X_tr, y_tr, X_te, y_te).
    """
    X_fit, y_fit, X_eval, y_eval = final_data
    for lam in tqdm(lambdas, desc=f"{method} lambdas", leave=False):
        best_params, _ = tune(
            proc,
            *tune_data,
            depth,
            lam,
            method=method,
            n_trials=_N_OPTUNA_TRIALS,
            threshold_mode=threshold_mode,
        )
        # The test set is only touched here, after tuning.
        res = fit_and_eval(
            proc,
            method,
            X_fit,
            y_fit,
            X_te=X_eval,
            y_te=y_eval,
            depth=depth,
            lam=lam,
            override_params=best_params,
            threshold_mode=threshold_mode,
        )
        row = {
            "dataset": dataset,
            "outer": k_outer,
            "method": method,
            "depth": depth,
            "lambda": lam,
            **res,
        }
        writer.writerow(row)
        _accumulate_metrics(agg, (method, float(lam)), row)


def run_for_dataset(
    dataset_dir: Path,
    depth: int,
    lambdas_streed: list[float],
    lambdas_clari_tree: list[float],
    lambdas_greedy: list[float],
    guide_nodes: list[int],
    out_csv: Path,
    outer_id: int | None = None,
    methods: list[str] | None = None,
    threshold_mode: str = "full",
):
    """
    Run the full method/lambda sweep for one dataset and write results to ``out_csv``.

    For each outer fold under ``dataset_dir/splits/outer_*`` (or only
    ``outer_{outer_id}``), train.csv is split 80/20 (random_state=42) into inner
    train / valid for Optuna, and the final models are refit on the whole
    train.csv and scored on test.csv. Standardization always uses the statistics
    of the data the model is fit on. When ``outer_id`` is None, mean/std rows per
    (method, lambda) are appended after all folds.

    Args:
        dataset_dir: Dataset folder containing ``splits/``.
        depth: Tree depth for all methods.
        lambdas_streed / lambdas_clari_tree / lambdas_greedy: Lambda grids.
        guide_nodes: GUIDE max_nodes values.
        out_csv: Output CSV path (parent directories are created).
        outer_id: Run only this fold when given.
        methods: Optional subset of method names to run.
        threshold_mode: "full" or "threshold" (ignored by GUIDE).

    Raises:
        ValueError: if ``methods`` contains unknown names.
    """
    processors = get_processors()
    # If a subset of methods is requested, filter here.
    if methods is not None:
        invalid = [m for m in methods if m not in processors]
        if invalid:
            raise ValueError(
                f"Unknown method(s): {invalid}. "
                f"Valid methods are: {list(processors.keys())}"
            )
        processors = {name: proc for name, proc in processors.items() if name in methods}

    outers = _list_outer_dirs(dataset_dir / "splits", outer_id)
    dataset = dataset_dir.name

    # Lambda grid and Optuna tuner for each non-GUIDE method.
    linear_sweeps = {
        "streed": (lambdas_streed, run_optuna_for_streed),
        "streed_sl": (lambdas_streed, run_optuna_for_streed),
        "clari_tree": (lambdas_clari_tree, run_optuna_for_chol),
        "greedy": (lambdas_greedy, run_optuna_for_chol),
    }

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with out_csv.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(_RESULT_FIELDNAMES))
        writer.writeheader()

        # (method, lambda) -> metric name -> per-fold values, for mean/std rows.
        agg: dict[tuple[str, float], dict[str, list[float]]] = {}

        for odir in tqdm(outers, desc=f"{dataset}|outer folds"):
            k_outer = odir.name
            train_csv = odir / "train.csv"
            test_csv = odir / "test.csv"
            if not train_csv.exists() or not test_csv.exists():
                print(f"[skip] {dataset}/{k_outer} missing train/test.csv")
                continue

            X_tr_full, y_tr_full = load_xy_claritree(train_csv)
            X_te, y_te = load_xy_claritree(test_csv)

            # Stage 1 data: 80/20 inner split of train.csv for Optuna, scaled
            # with inner-train statistics only (no leakage into tuning).
            X_tr, X_val, y_tr, y_val = train_test_split(
                X_tr_full, y_tr_full, test_size=0.2, random_state=42
            )
            X_tr_opt, y_tr_opt, X_val_opt, y_val_opt, _, _ = standardize_and_center(
                X_tr, y_tr, X_val=X_val, y_val=y_val
            )
            tune_data = (X_tr_opt, y_tr_opt, X_val_opt, y_val_opt)

            # Stage 2 data: whole train.csv for the final model; test.csv scaled
            # with those statistics (4-tuple since no validation split is passed).
            X_tr_final, y_tr_final, X_te_final, y_te_final = standardize_and_center(
                X_tr_full, y_tr_full, X_te=X_te, y_te=y_te
            )
            final_data = (X_tr_final, y_tr_final, X_te_final, y_te_final)

            for method, proc in tqdm(
                processors.items(),
                desc=f"{dataset}|{k_outer}|methods",
                leave=False,
            ):
                if method == "guide":
                    # GUIDE delegates to R and ignores threshold_mode.
                    _run_guide_sweep(
                        proc, writer, agg,
                        dataset=dataset, k_outer=k_outer, depth=depth,
                        guide_nodes=guide_nodes,
                        X_tr=X_tr_final, y_tr=y_tr_final,
                        X_te=X_te_final, y_te=y_te_final,
                    )
                elif method in linear_sweeps:
                    lambdas, tune = linear_sweeps[method]
                    _run_linear_sweep(
                        proc, method, lambdas, tune, writer, agg,
                        dataset=dataset, k_outer=k_outer, depth=depth,
                        threshold_mode=threshold_mode,
                        tune_data=tune_data, final_data=final_data,
                    )

        # Summary rows only when all folds are run together.
        if outer_id is None:
            _write_summary_rows(writer, agg, dataset, depth)


def parse_nodes(arg: str) -> list[int]:
    """
    Parse the --guide_nodes argument into a list of GUIDE max_nodes values.

    Accepts a comma-separated list whose items are either a single integer
    ("8") or an inclusive ascending range ("1-4"); both forms may be mixed,
    e.g. "1-4,8,16" -> [1, 2, 3, 4, 8, 16]. Surrounding whitespace and empty
    items (such as a trailing comma) are ignored; order and duplicates are kept.

    Raises:
        ValueError: on a non-integer item, a descending range such as "8-2",
            or when no value is given at all.
    """
    nodes: list[int] = []
    for item in arg.split(","):
        item = item.strip()
        if not item:
            continue
        lo_text, sep, hi_text = item.partition("-")
        if not sep:
            nodes.append(int(item))
            continue
        lo, hi = int(lo_text), int(hi_text)
        if lo > hi:
            # Previously this silently produced an empty sweep.
            raise ValueError(f"Invalid --guide_nodes range {item!r}: start exceeds end")
        nodes.extend(range(lo, hi + 1))
    if not nodes:
        raise ValueError(f"--guide_nodes contains no values: {arg!r}")
    return nodes


def _parse_float_list(arg: str) -> list[float]:
    """Parse a comma-separated list of floats such as ``"0.001,0.01,0.1"``.

    Whitespace around items and empty items (e.g. a trailing comma) are ignored.

    Raises:
        ValueError: If an item is not a valid float or no values are present.
    """
    tokens = (tok.strip() for tok in arg.split(","))
    values = [float(tok) for tok in tokens if tok]
    if not values:
        raise ValueError(f"no values found in {arg!r}")
    return values


def main():
    """Command-line entry point.

    Parse the arguments, resolve the output CSV path, and call
    ``run_for_dataset`` for the dataset given by ``--data_dir``.
    Malformed list arguments are reported through ``parser.error``, which
    prints a usage message and exits, instead of raising a raw traceback.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Optuna+Lambda sweep for "
            "clari_tree/greedy/streed/streed_sl/guide on one dataset (linear)."
        )
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Dataset folder, e.g. data/airfoil (must contain /splits/outer_0..5/)",
    )
    parser.add_argument(
        "--depth",
        type=int,
        required=True,
        help="Tree depth for all runs",
    )
    parser.add_argument(
        "--lambdas",
        type=str,
        required=True,
        help="Comma-separated list of lambdas, e.g. '0.001,0.01,0.1,1.0'",
    )
    parser.add_argument(
        "--lambdas_streed",
        type=str,
        default=None,
        help=(
            "Optional comma-separated list of lambdas specifically for "
            "streed/streed_sl. Defaults to --lambdas if not provided."
        ),
    )
    parser.add_argument(
        "--lambdas_clari_tree",
        type=str,
        default=None,
        help=(
            "Optional comma-separated list of lambdas specifically for "
            "clari_tree. Defaults to --lambdas if not provided."
        ),
    )
    parser.add_argument(
        "--lambdas_greedy",
        type=str,
        default=None,
        help=(
            "Optional comma-separated list of lambdas specifically for "
            "greedy. Defaults to --lambdas if not provided."
        ),
    )
    parser.add_argument(
        "--guide_nodes",
        type=str,
        default="1-32",
        help="Range of max_nodes for GUIDE, e.g. '1-32' or '2,4,8,16'",
    )
    parser.add_argument(
        "--outer_id",
        type=int,
        default=None,
        help="Run only the specified outer fold (0..5). If omitted, run all folds.",
    )
    parser.add_argument(
        "--methods",
        type=str,
        default=None,
        help=(
            "Comma-separated list of methods to run, subset of "
            "clari_tree,greedy,streed,streed_sl,guide. "
            "Defaults to all methods."
        ),
    )
    parser.add_argument(
        "--threshold_mode",
        type=str,
        default="full",
        choices=["threshold", "full"],
        help="Threshold search mode for streed/clari_tree/greedy.",
    )
    parser.add_argument(
        "--results_dir",
        type=str,
        default=None,
        help="Output directory. Default: ./results/{prefix}_depth{depth}/{dataset}",
    )
    args = parser.parse_args()

    if args.outer_id is not None and args.outer_id < 0:
        parser.error("--outer_id must be a non-negative fold index")

    def lambda_list(option, raw, default):
        """Parse an optional lambda option, falling back to ``default`` when unset."""
        if raw is None:
            return default
        try:
            return _parse_float_list(raw)
        except ValueError as err:
            parser.error(f"{option}: {err}")  # exits with a usage message

    # Base lambda list (required for backward compatibility)
    base_lambdas = lambda_list("--lambdas", args.lambdas, None)

    # Method-specific lambda sequences (default to base_lambdas if not set)
    lambdas_streed = lambda_list("--lambdas_streed", args.lambdas_streed, base_lambdas)
    lambdas_clari_tree = lambda_list(
        "--lambdas_clari_tree", args.lambdas_clari_tree, base_lambdas
    )
    lambdas_greedy = lambda_list("--lambdas_greedy", args.lambdas_greedy, base_lambdas)

    dataset_dir = Path(args.data_dir)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Dataset folder {dataset_dir} not found")

    try:
        guide_nodes = parse_nodes(args.guide_nodes)
    except ValueError as err:
        parser.error(f"--guide_nodes: {err}")

    methods = (
        [m.strip() for m in args.methods.split(",") if m.strip()]
        if args.methods is not None
        else None
    )
    # An explicit but empty selection (e.g. --methods ",") is almost certainly
    # a mistake; previously it silently produced an empty list.
    if methods is not None and not methods:
        parser.error("--methods must name at least one method")

    prefix = "LRT" if args.threshold_mode == "threshold" else "LRF"

    # Decide default results directory.
    # - If user explicitly sets --results_dir, always honor it.
    # - Otherwise:
    #     * If running all methods, keep original location:
    #         ./results/{prefix}_depth{depth}/{dataset}
    #     * If running a subset of methods, write into a "patch" area to
    #       avoid mixing with existing full-method results:
    #         ./results/patch/{prefix}_depth{depth}/{dataset}
    if args.results_dir is not None:
        results_root = Path(args.results_dir)
    else:
        all_methods = set(get_processors().keys())
        is_all_methods = methods is None or set(methods) == all_methods
        if is_all_methods:
            results_root = Path(f"./results/{prefix}_depth{args.depth}/{dataset_dir.name}")
        else:
            results_root = Path(
                f"./results/patch/{prefix}_depth{args.depth}/{dataset_dir.name}"
            )

    # nicer naming: if outer_id is None, aggregate all folds
    suffix = "all" if args.outer_id is None else f"outer{args.outer_id}"

    out_csv = results_root / f"results_linear_optuna_d{args.depth}_{suffix}.csv"

    print(f"[{dataset_dir.name}] -> {out_csv}")
    print(f"threshold_mode = {args.threshold_mode}")
    run_for_dataset(
        dataset_dir,
        args.depth,
        lambdas_streed,
        lambdas_clari_tree,
        lambdas_greedy,
        guide_nodes,
        out_csv,
        outer_id=args.outer_id,
        methods=methods,
        threshold_mode=args.threshold_mode,
    )
    print("All done.")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
