import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from matplotlib.legend_handler import HandlerTuple

# Import machine learning models and tools
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.tree import DecisionTreeRegressor
import xgboost as xgb  # Import XGBoost

# Import base classes for creating custom models
from sklearn.base import BaseEstimator, RegressorMixin
# Import Ridge Regression
from sklearn.linear_model import Ridge

# -------------------------------------------------------------------
# 2. Global Configuration and Data Generation
# -------------------------------------------------------------------
# Matplotlib plotting configuration
# Font sizes shared by every figure this script draws (base text, legend, tick labels).
plt.rcParams["font.size"] = 20
plt.rcParams["legend.fontsize"] = 16
plt.rcParams["xtick.labelsize"] = 18
plt.rcParams["ytick.labelsize"] = 18


def true_func(x):
    """
    Evaluate the noise-free target function used for synthetic data generation.

    The target is the negated normalized sinc, ``-sinc(5 * x)``, where
    ``np.sinc(t) = sin(pi * t) / (pi * t)`` and ``np.sinc(0) = 1``.

    Args:
        x (np.ndarray): Input feature values (scalars are accepted as well).

    Returns:
        np.ndarray: Target values with the same shape as ``x``.
    """
    # An alternative sigmoid-based target, 2 / (1 + exp(-3 * x)) - 0.8 * x, used to
    # live here as an unreachable string literal after the return statement.
    # Swap the expression below for it to experiment with that target.
    return -np.sinc(5 * x)

# Generate the synthetic dataset: noisy samples of true_func on a regular grid.
# Seed NumPy's global RNG first so the noise drawn below is reproducible.
np.random.seed(0)
# 1000 equally spaced x values on [-1.5, 1.5].
x_data = np.linspace(-1.5, 1.5, 1000)
# Alternative, wider input interval (disabled):
#x_data = np.linspace(-3, 3, 1000)
# Standard deviation of the additive Gaussian noise.
noise_level = 0.025
# Noise-free targets on the grid.
y_true = true_func(x_data)
# Observed targets: noise-free values plus zero-mean Gaussian noise.
y_data = y_true + np.random.normal(0, noise_level, size=x_data.shape)

# Algorithm hyperparameters
# Minimum number of data points required on each side of a valid split;
# used as the default `min_points` of the fitting functions below.
MIN_POINTS_FOR_SPLIT = 5


# -------------------------------------------------------------------
# 3. Core Algorithm Functions (Your Piecewise Linear Fitting Algorithm)
# -------------------------------------------------------------------
def linear_model(x, a, b):
    """
    Evaluate the straight line y = a * x + b.

    Args:
        x (array-like): Input feature value(s); converted with ``np.asanyarray``.
        a (float): Slope.
        b (float): Intercept.

    Returns:
        np.ndarray: Line values at ``x``.
    """
    x_arr = np.asanyarray(x)
    return a * x_arr + b


def calculate_rmse(y_true, y_pred):
    """
    Compute the Root Mean Squared Error (RMSE) between targets and predictions.

    Args:
        y_true (np.ndarray): True target values.
        y_pred (np.ndarray): Predicted target values.

    Returns:
        float: Square root of scikit-learn's ``mean_squared_error(y_true, y_pred)``.
    """
    mse = mean_squared_error(y_true, y_pred)
    return np.sqrt(mse)


def solve_ridge_regression(X, y, alpha):
    """
    Solve for linear model parameters using Ridge Regression.

    Args:
        X (np.ndarray): Design matrix of shape (n_samples, n_features). Callers in this
            file pass the two columns [x, 1], i.e. the intercept is modelled by an
            explicit column of ones.
        y (np.ndarray): Target variable.
        alpha (float): Regularization strength for Ridge Regression.

    Returns:
        np.ndarray: Model coefficients [a, b], where a is the slope and b the intercept.
            If there are fewer samples than features, or the fit raises, the flat-line
            fallback [0.0, mean(y)] is returned ([0.0, 0.0] for empty ``y``).
    """
    def _flat_line_fallback():
        # Slope 0, intercept mean(y). Always float so the dtype of the result does not
        # depend on whether y happened to be empty. Assumes the two-column [x, 1] layout.
        return np.array([0.0, np.mean(y) if len(y) > 0 else 0.0])

    # Need at least as many data points as features to fit the model.
    if len(y) < X.shape[1]:
        return _flat_line_fallback()

    try:
        # X already carries the column of ones, so Ridge must not add its own intercept.
        # NOTE(review): with fit_intercept=False every coefficient is penalized, so the
        # intercept is shrunk by alpha as well -- confirm this is intended.
        ridge_model = Ridge(alpha=alpha, fit_intercept=False)
        ridge_model.fit(X, y)
        return ridge_model.coef_
    except Exception as e:
        # Deliberate best-effort: report the failure and fall back to a flat line.
        print(f"Warning: Ridge regression failed, returning default coefficients. Error: {e}")
        return _flat_line_fallback()


def fit_single_line(x, y, ridge_alpha):
    """
    Fit one straight line to the given points via Ridge Regression.

    Args:
        x (np.ndarray): Array of x-coordinates.
        y (np.ndarray): Array of y-coordinates.
        ridge_alpha (float): Regularization strength for Ridge Regression.

    Returns:
        tuple: (slope, intercept), or (None, None) when fewer than two points are given.
    """
    # A line is undetermined with fewer than two points.
    if len(x) >= 2:
        # Design matrix with columns [x, 1]; the ones column models the intercept.
        design = np.vstack((x, np.ones_like(x))).T
        coeffs = solve_ridge_regression(design, y, ridge_alpha)
        return coeffs[0], coeffs[1]
    return None, None


def _line_through_points(p, q, fallback_intercept):
    """Return [a, b] of the line through points p and q, or [0, fallback_intercept] if their x-values coincide."""
    A = np.array([[p[0], 1], [q[0], 1]])
    B = np.array([p[1], q[1]])
    # The determinant is p_x - q_x; a near-zero value means the 2x2 system is singular.
    if np.abs(np.linalg.det(A)) > 1e-9:
        return np.linalg.solve(A, B)
    return np.array([0, fallback_intercept])


def initialize_thetas_guaranteed_intersect(X, y, random_seed=None):
    """
    Initialize parameters for two lines that share a common anchor point.

    The left line passes through the leftmost data point and the anchor; the right line
    passes through the anchor and the rightmost data point. The anchor sits at the median
    x with its y obtained by linear interpolation of the x-sorted data. Both lines are
    obtained by exactly solving a 2x2 linear system, not by regression. If a system is
    singular (the two x-values coincide), that line falls back to slope 0 through the
    corresponding boundary point's y.

    Args:
        X (array-like): Design matrix whose first column holds the x-coordinates
            (remaining columns are ignored here).
        y (array-like): Target variable, one value per row of ``X``.
        random_seed (int, optional): If given, NumPy's global RNG is reseeded with it.
            No random numbers are drawn in this function, so the only effect is that
            side effect on the global RNG state.

    Returns:
        tuple: (theta1, theta2), the parameters [a, b] of the left and right lines.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    x_coords = np.asarray(X)[:, 0]
    y_values = np.asarray(y)

    # Sort by x so the boundary points and the interpolation grid are well defined.
    order = np.argsort(x_coords)
    x_sorted, y_sorted = x_coords[order], y_values[order]

    p_left_boundary = (x_sorted[0], y_sorted[0])
    p_right_boundary = (x_sorted[-1], y_sorted[-1])
    x_anchor = np.median(x_coords)
    p_anchor = (x_anchor, np.interp(x_anchor, x_sorted, y_sorted))

    theta1 = _line_through_points(p_left_boundary, p_anchor, p_left_boundary[1])
    theta2 = _line_through_points(p_anchor, p_right_boundary, p_right_boundary[1])

    return theta1, theta2


def optimize_split_optimized(x, y, max_iter=10, tol=1e-6, min_points=MIN_POINTS_FOR_SPLIT, step_size=0.01, ridge_alpha=1.0):
    """
    Iteratively search for a split of the data into two groups, each fitted by its own line.

    Starting from two lines through a shared anchor point, each iteration assigns a point
    to the "left" group when line 1 predicts a smaller value than line 2 at its x, refits
    each group with Ridge Regression, and moves the current parameters a fraction
    ``step_size`` toward the refitted ones.

    Args:
        x (np.ndarray): Input feature.
        y (np.ndarray): Target variable.
        max_iter (int): Maximum number of iterations.
        tol (float): Convergence tolerance on the parameter change of both lines.
        min_points (int): Minimum number of data points required in each group.
        step_size (float): Blend factor between old and refitted parameters.
        ridge_alpha (float): Regularization strength for Ridge Regression.

    Returns:
        tuple: (split point x_s, (a1, b1), (a2, b2), left mask, right mask). x_s is NaN
            when the two slopes are numerically equal. When no valid split exists the
            result is (np.nan, (None, None), (None, None), None, None).
    """
    no_split = (np.nan, (None, None), (None, None), None, None)

    total = len(x)
    # Two groups of at least min_points each are impossible with fewer points.
    if total < 2 * min_points:
        return no_split

    # Design matrix with columns [x, 1].
    design = np.vstack([x, np.ones_like(x)]).T

    try:
        theta1, theta2 = initialize_thetas_guaranteed_intersect(design, y)
    except Exception as e:
        print(f"Warning: Initialization of split parameters failed. Error: {e}")
        # Start both lines from the global Ridge fit instead.
        shared = solve_ridge_regression(design, y, ridge_alpha)
        theta1, theta2 = shared.copy(), shared.copy()

    last_mask = None
    for _ in range(max_iter):
        theta1_prev = theta1.copy()
        theta2_prev = theta2.copy()

        # "Left" group: a1*x + b1 < a2*x + b2, i.e. (a1 - a2)*x + (b1 - b2) < 0.
        below = (design @ (theta1 - theta2)) < 0

        # An unchanged assignment means the iteration has converged.
        if last_mask is not None and np.array_equal(below, last_mask):
            break
        last_mask = below

        left_count = np.sum(below)
        right_count = total - left_count
        if left_count < min_points or right_count < min_points:
            # One group is too small: collapse both lines onto the global fit and stop.
            shared = solve_ridge_regression(design, y, ridge_alpha)
            theta1, theta2 = shared.copy(), shared.copy()
            break

        try:
            target1 = solve_ridge_regression(design[below], y[below], ridge_alpha)
            target2 = solve_ridge_regression(design[~below], y[~below], ridge_alpha)

            # Damped update toward the refitted parameters for stability.
            theta1 = (1 - step_size) * theta1_prev + step_size * target1
            theta2 = (1 - step_size) * theta2_prev + step_size * target2
        except Exception as e:
            print(f"Warning: Error during iterative ridge regression update, stopping iteration. Error: {e}")
            # Roll back to the parameters from the start of this iteration.
            theta1, theta2 = theta1_prev, theta2_prev
            break

        moved1 = np.linalg.norm(theta1 - theta1_prev)
        moved2 = np.linalg.norm(theta2 - theta2_prev)
        if moved1 < tol and moved2 < tol:
            break

    # Final assignment from the final parameters.
    left_mask = (design @ (theta1 - theta2)) < 0
    right_mask = ~left_mask

    a1, b1 = theta1
    a2, b2 = theta2

    # Identical lines give an empty left group, so the collapsed case is rejected here too.
    if np.sum(left_mask) < min_points or np.sum(right_mask) < min_points:
        return no_split

    # Intersection of the two lines: a1*x + b1 = a2*x + b2  =>  x = (b2 - b1) / (a1 - a2).
    if np.isclose(a1, a2):
        split_x = np.nan
    else:
        split_x = (b2 - b1) / (a1 - a2)
    return split_x, (a1, b1), (a2, b2), left_mask, right_mask


def recursive_piecewise_fit(x, y, threshold, min_points=MIN_POINTS_FOR_SPLIT, depth=0, max_depth=5, step_size=0.01, ridge_alpha=1.0):
    """
    Recursively build a piecewise linear fit of (x, y).

    The current data is fitted with one line. If that line is good enough (RMSE at or
    below ``threshold``), the data is too small to split, or ``max_depth`` is reached,
    the single line is returned. Otherwise the data is split (optimized split first,
    median split as fallback) and both parts are processed recursively.

    Args:
        x (np.ndarray): Input features for the current segment.
        y (np.ndarray): Target variables for the current segment.
        threshold (float): RMSE threshold; splitting stops at or below this value.
        min_points (int): Minimum number of data points per segment.
        depth (int): Current recursion depth.
        max_depth (int): Maximum recursion depth.
        step_size (float): Step size for the iterative split optimization.
        ridge_alpha (float): Regularization strength for Ridge Regression.

    Returns:
        list: Dictionaries with keys 'params' ((slope, intercept)) and 'x_range'
            ((min x, max x)), one per fitted segment; empty if nothing could be fitted.
    """
    count = len(x)
    if count < min_points:
        return []

    # Every remaining path needs the single-line fit of the current data.
    slope, intercept = fit_single_line(x, y, ridge_alpha)
    if slope is None:
        return []

    whole_segment = [{'params': (slope, intercept), 'x_range': (np.min(x), np.max(x))}]

    # Depth limit reached: keep the single line without evaluating it.
    if depth >= max_depth:
        return whole_segment

    rmse = calculate_rmse(y, linear_model(x, slope, intercept))
    if rmse <= threshold or count < 2 * min_points:
        return whole_segment

    _, left_params, right_params, left_sel, right_sel = optimize_split_optimized(
        x, y, min_points=min_points, step_size=step_size, ridge_alpha=ridge_alpha
    )
    have_split = left_params[0] is not None and right_params[0] is not None

    if not have_split:
        # Fallback: split at the median x, accepted only if both halves are large enough.
        left_sel = x < np.median(x)
        right_sel = ~left_sel
        have_split = np.sum(left_sel) >= min_points and np.sum(right_sel) >= min_points

    if not have_split:
        # Neither split strategy produced two valid parts.
        return whole_segment

    pieces = []
    for sel in (left_sel, right_sel):
        pieces += recursive_piecewise_fit(x[sel], y[sel], threshold, min_points, depth + 1, max_depth, step_size, ridge_alpha)
    return pieces


def predict(x_new, segments):
    """
    Make predictions on new data points using the trained piecewise model.

    Each point is predicted by the first segment (in list order) whose x-range contains
    it, with a tolerance of 1e-9 on both ends. Points outside every range are
    extrapolated with the segment whose x-range center is closest.

    Args:
        x_new (array-like): New x-coordinates for prediction.
        segments (list): Trained piecewise model; each entry is a dict with
            'params' ((slope, intercept)) and 'x_range' ((x_min, x_max)).

    Returns:
        np.ndarray: Float predictions with the same shape as ``x_new``; all zeros when
            ``segments`` is empty.
    """
    x_arr = np.asarray(x_new, dtype=float)
    predictions = np.zeros_like(x_arr, dtype=float)
    if not segments:
        return predictions

    # Points not yet claimed by a segment; keeps the "first matching segment wins" rule.
    unassigned = np.ones(x_arr.shape, dtype=bool)
    for seg in segments:
        x_min, x_max = seg['x_range']
        # Small tolerance guards against float round-off at segment borders.
        hit = unassigned & (x_arr >= x_min - 1e-9) & (x_arr <= x_max + 1e-9)
        if hit.any():
            a, b = seg['params']
            predictions[hit] = a * x_arr[hit] + b
            unassigned &= ~hit

    if unassigned.any():
        # Gaps between segments or extrapolation: use the segment with the nearest center.
        centers = np.array([np.mean(seg['x_range']) for seg in segments])
        x_miss = x_arr[unassigned]
        nearest = np.argmin(np.abs(centers[:, None] - x_miss[None, :]), axis=0)
        y_miss = np.empty_like(x_miss)
        for idx in np.unique(nearest):
            a, b = segments[idx]['params']
            pick = nearest == idx
            y_miss[pick] = a * x_miss[pick] + b
        predictions[unassigned] = y_miss

    return predictions


# -------------------------------------------------------------------
# 3.5. Wrap your algorithm in a Scikit-learn compatible class
# -------------------------------------------------------------------
class PiecewiseLinearRegressor(BaseEstimator, RegressorMixin):
    """
    Scikit-learn compatible Piecewise Linear Regressor.

    Wraps the recursive piecewise linear fitting algorithm so it can be used through the
    scikit-learn estimator API, e.g. for grid search and cross-validation. Only the first
    feature column is used.

    Attributes set by ``fit``:
        segments_ (list): Fitted segments sorted by the start of their x-range.
        fallback_prediction_ (float or None): Mean of the training targets when no valid
            segment was found, otherwise None.
    """

    def __init__(self, threshold=0.03, max_depth=5, min_points=5, step_size=0.01, ridge_alpha=1.0):
        """
        Store the hyperparameters.

        Args:
            threshold (float): RMSE threshold for stopping splitting.
            max_depth (int): Maximum recursion depth.
            min_points (int): Minimum number of data points required per sub-segment.
            step_size (float): Step size for iterative optimization.
            ridge_alpha (float): Regularization strength for Ridge Regression.
        """
        self.threshold = threshold
        self.max_depth = max_depth
        self.min_points = min_points
        self.step_size = step_size
        self.ridge_alpha = ridge_alpha

    def fit(self, X, y):
        """
        Train the piecewise linear regression model.

        Args:
            X (array-like): Training features, either 1D or 2D of shape (n_samples, n_features);
                for 2D input only the first column is used.
            y (array-like): Training targets.

        Returns:
            self: The fitted model instance.
        """
        # Coerce so lists and DataFrames work; the fitting code relies on ndarray
        # boolean-mask indexing.
        X = np.asarray(X)
        y = np.asarray(y)
        x_train = X[:, 0] if X.ndim > 1 else X

        segments = recursive_piecewise_fit(
            x=x_train,
            y=y,
            threshold=self.threshold,
            min_points=self.min_points,
            max_depth=self.max_depth,
            step_size=self.step_size,
            ridge_alpha=self.ridge_alpha
        )

        # Drop segments without parameters, then order the rest by the start of their x-range.
        self.segments_ = [seg for seg in segments if seg['params'][0] is not None]
        if not self.segments_:
            # No usable segment: predict the mean of the training targets instead.
            self.fallback_prediction_ = np.mean(y)
        else:
            self.segments_.sort(key=lambda seg: seg['x_range'][0])
            self.fallback_prediction_ = None

        return self

    def predict(self, X):
        """
        Make predictions using the trained model.

        Args:
            X (array-like): Features, either 1D or 2D of shape (n_samples, n_features);
                for 2D input only the first column is used.

        Returns:
            np.ndarray: Array of predicted values.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        # Check the fitted state before touching X so an unfitted model always reports
        # the meaningful error.
        if not hasattr(self, 'segments_'):
            raise RuntimeError("The model has not been trained. Please call the fit method first!")

        X = np.asarray(X)
        x_new = X[:, 0] if X.ndim > 1 else X

        if not self.segments_:
            # No valid segments were fitted: return the constant fallback prediction.
            return np.full(x_new.shape, self.fallback_prediction_)

        # Resolves to the module-level predict function, not this method.
        return predict(x_new, self.segments_)


# -------------------------------------------------------------------
# 4. Experimental Main Workflow
# -------------------------------------------------------------------

# --- 4.1 Data Splitting ---
TEST_SIZE = 0.3  # Fraction of the samples held out as the test set
RANDOM_STATE = 42  # Base random seed for the split and the baseline estimators
x_train, x_test, y_train, y_test = train_test_split(
    x_data,
    y_data,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

# Scikit-learn estimators expect 2D feature arrays of shape (n_samples, 1)
X_train_reshaped = x_train[:, np.newaxis]
X_test_reshaped = x_test[:, np.newaxis]

# Ordered copy of the test inputs, for plotting
sort_indices = np.argsort(x_test)
x_test_sorted = np.take(x_test, sort_indices)

# --- 4.3 Define all models and their hyperparameter grids ---
# Each entry maps a display name to a base estimator and the hyperparameter grid
# searched for it.
models_to_tune = {
    "Our Method": {
        "estimator": PiecewiseLinearRegressor(),
        "params": {
            "threshold": [0.01, 0.03, 0.05],  # RMSE threshold for stopping recursive splitting
            "max_depth": [4, 6, 8, 10, 12],  # Maximum recursion depth
            "step_size": [0.01, 0.5, 1],  # Step size during iterative optimization
            "ridge_alpha": [0.001, 0.1],  # Regularization strength for Ridge Regression
        }
    },
    "CART": {
        "estimator": DecisionTreeRegressor(random_state=RANDOM_STATE),
        "params": {
            "max_depth": [3, 5, 7, 9, 11],
            "min_samples_split": [2, 5, 10],
            "min_samples_leaf": [1, 2, 4],
        }
    },
    "XGBoost": {
        # objective='reg:squarederror' for regression tasks
        # eval_metric='rmse' for evaluation metric
        # `use_label_encoder` is intentionally not passed: it only ever concerned the
        # classifier's label handling, is deprecated, and on a regressor it is just an
        # unused parameter that triggers a warning.
        "estimator": xgb.XGBRegressor(random_state=RANDOM_STATE, objective='reg:squarederror', eval_metric='rmse'),
        "params": {
            "n_estimators": [50, 100, 200],  # Number of weak learners
            "learning_rate": [0.01, 0.1, 0.2],  # Learning rate
            "max_depth": [2, 3, 5],  # Maximum depth of weak learners (trees)
            "subsample": [0.7, 1.0],  # Proportion of samples used for training in each iteration
        }
    }
}

# --- 4.4 Perform Grid Search and Train the Best Models ---
results = []  # One row of test metrics per model from the initial tuning
trained_models = {}  # Best estimator found by the grid search, per model name
best_params_for_repetitions = {}  # Best hyperparameters, per model name

print("\n--- Starting hyperparameter tuning for all models (this may take a moment) ---")
for name, model_info in models_to_tune.items():
    print(f"Tuning model: {name}...")

    estimator = model_info["estimator"]
    params = model_info["params"]

    grid_search = GridSearchCV(
        estimator,
        params,
        cv=5,  # 5-fold cross-validation
        scoring='neg_mean_squared_error',  # Negative MSE, so higher is better
        n_jobs=-1,  # Use all available CPU cores
    )

    # Search on the training set only; the test set is kept for the final metrics.
    grid_search.fit(X_train_reshaped, y_train)

    trained_models[name] = grid_search.best_estimator_
    best_params_for_repetitions[name] = grid_search.best_params_

    y_pred = trained_models[name].predict(X_test_reshaped)

    best_params_str = str(best_params_for_repetitions[name])
    if name == "Our Method":
        # For the custom method, also report the number of fitted segments (0 if absent).
        num_segments = len(getattr(trained_models[name], 'segments_', []))
        best_params_str += f", K={num_segments}"

    results.append({
        "Model": name,
        # Same RMSE helper the fitting algorithm uses, instead of repeating the formula.
        "RMSE": calculate_rmse(y_test, y_pred),
        "MAE": mean_absolute_error(y_test, y_pred),
        "R²": r2_score(y_test, y_pred),
        "Best Params": best_params_str
    })

    print(f"Model {name} tuning completed. Best score (neg MSE on CV): {grid_search.best_score_:.4f}")
    print(f"Best parameters found: {best_params_for_repetitions[name]}\n")

# Best (lowest) test RMSE first.
results_df_initial = pd.DataFrame(results).sort_values(by="RMSE")

# -------------------------------------------------------------------
# 5. Result Presentation (Initial Grid Search Results)
# -------------------------------------------------------------------
# Banner-framed table of the initial tuning results, without column truncation.
_rule_80 = "=" * 80
print("\n" + _rule_80)
print(" " * 15 + "ALGORITHM PERFORMANCE COMPARISON (Initial Tuning Results)")
print(_rule_80)
with pd.option_context('display.max_colwidth', None, 'display.width', 1000):
    _initial_table = results_df_initial.to_string(index=False)
    print(_initial_table)
print(_rule_80)

# -------------------------------------------------------------------
# 6. Repeat Training and Evaluation N times, and Calculate Mean and Standard Deviation
# -------------------------------------------------------------------
N_REPETITIONS = 10  # Number of repetitions
# Per-model lists of metrics, one entry appended per repetition
all_run_results = {name: {'RMSE': [], 'MAE': [], 'R2': []} for name in models_to_tune}

print(f"\n--- Starting {N_REPETITIONS} repetitions with final tuned parameters ---")

for i in range(N_REPETITIONS):
    print(f"--- Repetition {i + 1}/{N_REPETITIONS} ---")

    # A different seed per repetition yields a different train/test split each time.
    current_random_state = RANDOM_STATE + i
    x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_data, y_data, test_size=TEST_SIZE,
                                                                random_state=current_random_state)

    X_train_reshaped_i = x_train_i.reshape(-1, 1)
    X_test_reshaped_i = x_test_i.reshape(-1, 1)

    # --- All models (including Our Method) are trained and evaluated with the same logic ---
    for name, model_info in models_to_tune.items():
        base_estimator = model_info["estimator"]
        best_params = best_params_for_repetitions[name].copy()  # Copy to avoid modifying the stored dict

        if name == "XGBoost":
            # The tuned grid does not contain these, so restore the regression settings
            # of the base estimator. `use_label_encoder` is deliberately not set: it is a
            # deprecated classifier-only option that a regressor reports as unused.
            best_params.setdefault('objective', 'reg:squarederror')
            best_params.setdefault('eval_metric', 'rmse')

        # Seed only estimators that expose a random_state parameter
        # (CART and XGBoost do, PiecewiseLinearRegressor does not).
        if 'random_state' in base_estimator.get_params():
            best_params['random_state'] = current_random_state

        current_model_instance = base_estimator.__class__(**best_params)

        # Train on this repetition's split, then evaluate on its held-out part.
        current_model_instance.fit(X_train_reshaped_i, y_train_i)
        y_pred_i = current_model_instance.predict(X_test_reshaped_i)

        all_run_results[name]["RMSE"].append(calculate_rmse(y_test_i, y_pred_i))
        all_run_results[name]["MAE"].append(mean_absolute_error(y_test_i, y_pred_i))
        all_run_results[name]["R2"].append(r2_score(y_test_i, y_pred_i))

# --- Calculate Mean and Standard Deviation of Metrics ---
# --- Calculate Mean and Standard Deviation of Metrics ---
final_summary_results = []
for model_name, metrics_dict in all_run_results.items():
    mean_rmse = np.mean(metrics_dict["RMSE"])
    std_rmse = np.std(metrics_dict["RMSE"])
    mean_mae = np.mean(metrics_dict["MAE"])
    std_mae = np.std(metrics_dict["MAE"])
    mean_r2 = np.mean(metrics_dict["R2"])
    std_r2 = np.std(metrics_dict["R2"])

    final_summary_results.append({
        "Model": model_name,
        # Numeric helper column used only for sorting. Keeping the real mean avoids
        # parsing it back out of the formatted string and sorting on a rounded value.
        "RMSE_Mean": mean_rmse,
        # Use LaTeX math format to represent mean and standard deviation
        "RMSE (Mean \\( \\pm \\) Std)": f"\\( {mean_rmse:.4f} \\pm {std_rmse:.4f} \\)",
        "MAE (Mean \\( \\pm \\) Std)": f"\\( {mean_mae:.4f} \\pm {std_mae:.4f} \\)",
        "R² (Mean \\( \\pm \\) Std)": f"\\( {mean_r2:.4f} \\pm {std_r2:.4f} \\)"
    })

# Sort by mean RMSE (best first), then remove the auxiliary column
final_summary_df = (
    pd.DataFrame(final_summary_results)
    .sort_values(by="RMSE_Mean")
    .drop(columns="RMSE_Mean")
)

# The final DataFrame for display
final_combined_df = final_summary_df

# -------------------------------------------------------------------
# 7. Final Result Presentation (Mean and Standard Deviation over N Repetitions)
# -------------------------------------------------------------------
# Banner-framed table of the repeated-run summary, without column truncation.
_rule_100 = "=" * 100
_final_heading = f"ALGORITHM PERFORMANCE METRICS (Mean ± Std Over {N_REPETITIONS} Repetitions)"
print("\n" + _rule_100)
print(" " * 10 + _final_heading)
print(_rule_100)
with pd.option_context('display.max_colwidth', None, 'display.width', 1200):
    _final_table = final_combined_df.to_string(index=False)
    print(_final_table)
print(_rule_100)


# -------------------------------------------------------------------
# 8. Plotting Results (Using the Best Model from the Initial Grid Search)
# -------------------------------------------------------------------
def create_residual_plot(x_train, y_train, x_test, y_test, x_plot, true_func, segments, trained_models,
                         output_path="model_comparison_with_residuals_sinc.pdf"):
    """
    Generates plots of model predictions and their residuals.

    The figure has two stacked panels sharing the x-axis: the upper panel
    shows the training/test points and each model's prediction curve, the
    lower panel shows each model's residual, computed as
    ``prediction - true_func(x_plot)`` (i.e. against the noise-free target,
    not against the observed test values). The figure is saved to
    ``output_path`` and then displayed with ``plt.show()``.

    Args:
        x_train (np.ndarray): Training set x.
        y_train (np.ndarray): Training set y.
        x_test (np.ndarray): Test set x.
        y_test (np.ndarray): Test set y.
        x_plot (np.ndarray): Ordered x values for plotting prediction curves.
        true_func (callable): The true target function.
        segments (list): Segment information for "Our Method"; it is passed
            to the module-level ``predict`` helper. If empty, the curve for
            "Our Method" is skipped with a warning.
        trained_models (dict): Dictionary of other trained baseline models,
            keyed by model name. Only "CART" and "XGBoost" are drawn, and
            only if present.
        output_path (str): File the figure is saved to. Defaults to the
            file name this function has always used.
    """
    # Define plotting styles for different models and data points.
    # Each entry is passed verbatim as keyword arguments to plot()/scatter(),
    # so the legend label travels with the style.
    STYLE_CONFIG = {
        "Our Method": {"color": "#d62728", "linewidth": 3.0, "linestyle": "-", "alpha": 1, "label": "Our method"},
        "CART": {"color": "#1f77b4", "linewidth": 2.0, "linestyle": "--", "alpha": 1, "label": "CART"},
        "XGBoost": {"color": "#2ca02c", "linewidth": 2.0, "linestyle": ":", "alpha": 1, "label": "XGBoost"},
        "Training Data": {"color": "gray", "marker": "o", "s": 45, "alpha": 0.2, "label": "Training Data"},
        "Test Data": {"color": "gray", "marker": "x", "s": 30, "alpha": 0.2, "label": "Test Data"},
        # Currently unused: the true function is not drawn, it only serves
        # as the reference for the residuals below.
        "True Function": {"color": "#7f7f7f", "linewidth": 2, "linestyle": "--", "zorder": 0, "label": "True Function"}
    }

    # Create main plot and residual plot, sharing the x-axis
    fig, (ax_main, ax_res) = plt.subplots(2, 1, figsize=(8, 6), sharex=True,
                                          gridspec_kw={'height_ratios': [3, 1.5]})

    # Plot training data and test data
    ax_main.scatter(x_train, y_train, **STYLE_CONFIG["Training Data"])
    ax_main.scatter(x_test, y_test, **STYLE_CONFIG["Test Data"])

    # Noise-free reference values used to compute every residual curve
    y_true_plot = true_func(x_plot)

    # Plot predictions and residuals for "Our Method"
    if not segments:
        print("Warning: 'Our Method' generated no segments, skipping its prediction curve in the plot.")
        # An all-NaN curve draws nothing on the axes.
        y_plot_our_sorted = np.full_like(x_plot, np.nan)
    else:
        y_plot_our_sorted = predict(x_plot, segments)
    ax_main.plot(x_plot, y_plot_our_sorted, **STYLE_CONFIG["Our Method"])
    if not np.all(np.isnan(y_plot_our_sorted)):
        res_our = y_plot_our_sorted - y_true_plot
        ax_res.plot(x_plot, res_our, **STYLE_CONFIG["Our Method"])

    # Plot predictions and residuals for other baseline models
    model_names_to_plot = ["CART", "XGBoost"]  # Explicitly specify plotting order
    for name in model_names_to_plot:
        if name in trained_models:
            model = trained_models[name]
            # The estimators expect a 2-D (n_samples, 1) feature matrix.
            y_plot_baseline = model.predict(x_plot.reshape(-1, 1))
            ax_main.plot(x_plot, y_plot_baseline, **STYLE_CONFIG[name])

            res_baseline = y_plot_baseline - y_true_plot
            ax_res.plot(x_plot, res_baseline, **STYLE_CONFIG[name])

    # Set labels and legend for the main plot
    ax_main.set_ylabel("y")
    ax_main.legend(loc='best')
    # Derive the y-limits from the data actually plotted (train + test)
    # instead of a module-level global, so the function only depends on its
    # arguments. With no data points the automatic limits are kept.
    y_observed = np.concatenate([np.ravel(y_train), np.ravel(y_test)])
    if y_observed.size:
        ax_main.set_ylim(y_observed.min() - 0.1, y_observed.max() + 0.1)

    # Set labels and zero line for the residual plot
    ax_res.axhline(0, color='black', linestyle='-', linewidth=1.5, zorder=0)
    ax_res.set_xlabel("x")
    ax_res.set_ylabel("r")

    plt.tight_layout()  # Automatically adjust subplot parameters for a tight layout
    plt.savefig(output_path, bbox_inches='tight')  # Save image
    plt.show()


print("\nGenerating plot with residuals (based on the best model from the initial run)...")

# Rebuild every tuned model from the best hyperparameters found by the first
# grid search (`best_params_for_repetitions`) and refit it on the original
# training data, so the figure shows the initially selected configuration.
plot_trained_models = {}
for name, model_info in models_to_tune.items():
    base_estimator_class = model_info["estimator"].__class__
    best_params = best_params_for_repetitions[name].copy()  # Work on a copy; defaults are added below

    # Step 1: decide whether the estimator is constructed with a seed.
    if name == "Our Method":
        # PiecewiseLinearRegressor does not accept random_state
        best_params.pop('random_state', None)
        needs_seed = False
    elif name in ("XGBoost", "CART"):
        if name == "XGBoost":
            # Fill in XGBoost settings only where the search did not set them.
            best_params.setdefault('objective', 'reg:squarederror')
            best_params.setdefault('eval_metric', 'rmse')
            best_params.setdefault('use_label_encoder', False)
        needs_seed = True
    else:
        # Generic handling for other potential models: probe a default
        # instance to see whether it exposes a random_state parameter.
        temp_instance = base_estimator_class()
        needs_seed = (hasattr(temp_instance, 'random_state')
                      or 'random_state' in temp_instance.get_params())

    # Step 2: construct the estimator accordingly.
    if needs_seed:
        current_model_instance = base_estimator_class(random_state=RANDOM_STATE, **best_params)
    else:
        current_model_instance = base_estimator_class(**best_params)

    # Step 3: train the model for plotting (using the original training data).
    current_model_instance.fit(X_train_reshaped, y_train)
    plot_trained_models[name] = current_model_instance

# Take "Our Method" out of the dictionary; what remains are the baselines.
our_best_model_initial = plot_trained_models.pop("Our Method")
baseline_models_for_plot = plot_trained_models

# Draw the comparison figure on the sorted test grid.
create_residual_plot(
    x_train=x_train, y_train=y_train,
    x_test=x_test, y_test=y_test,
    x_plot=x_test_sorted,
    true_func=true_func,
    segments=our_best_model_initial.segments_,
    trained_models=baseline_models_for_plot,
)
