import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm

# Import machine learning models and utilities
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.tree import DecisionTreeRegressor
import xgboost as xgb
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, RegressorMixin
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches
from matplotlib.colors import LightSource
import colorsys
import matplotlib.colors as mcolors
import matplotlib as mpl

# -------------------------------------------------------------------
# 2. Global Configuration and Data Generation (3D)
# -------------------------------------------------------------------
# Matplotlib Plotting Configuration
# Script-wide Matplotlib defaults: serif font and enlarged sizes so 3D plots stay readable.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 16,
        "axes.labelsize": 14,  # 3D axis labels
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 14,
    }
)

# Target function (3D)
def true_func_3d(x1, x2):
    """
    Ground-truth surface z = f(x1, x2) that the models try to approximate.

    f = 0.5*x1^3 - 2*x1*x2^2 + 3*sin(4*x1)*cos(2*x2) + 0.1*exp(-(x1^2 + x2^2))

    Works element-wise on scalars or NumPy arrays. (This surface was previously
    labelled ``complex_poly_trig_mix`` among the candidate test functions.)
    """
    cubic = 0.5 * x1 ** 3
    saddle = 2 * x1 * x2 ** 2
    ripple = 3 * np.sin(4 * x1) * np.cos(2 * x2)
    bump = 0.1 * np.exp(-(x1 ** 2 + x2 ** 2))
    return cubic - saddle + ripple + bump
#
#
# def oscillating_3d(x1, x2):  # good: well suited for plotting -33
#     return np.sin(3 * x1) + np.cos(2 * x2) + 0.5 * np.sin(5 * x1) * np.cos(4 * x2)


#
# def rational_with_singularity(x1, x2):
#     r = np.sqrt(x1 ** 2 + x2 ** 2) + 1e-6
#     return (x1 ** 2 - x2 ** 2) / (0.5 + r ** 2) + np.sin(r) * np.exp(-r)
#
#
# def asymmetric_bimodal(x1, x2):
#     gauss1 = np.exp(-((x1 - 1) ** 2 + (x2 - 1) ** 2) / 0.5)
#     gauss2 = np.exp(-((x1 + 1) ** 2 + (x2 + 1.5) ** 2) / 0.3)
#     return 2 * gauss1 - 3 * gauss2 + 0.5 * x1
# #

# Generate synthetic data (3D)
# Synthetic dataset: noisy samples of true_func_3d on a regular square grid.
np.random.seed(0)  # makes the Gaussian noise below reproducible
line_1 = 3  # half-width of the sampling domain: x1 and x2 both span [-line_1, line_1]
num_points_per_dim = 100  # grid resolution per axis -> 100 * 100 = 10,000 samples in total

# Fine grid: source of the full train/test dataset.
x_lin = np.linspace(-line_1, line_1, num_points_per_dim)
y_lin = np.linspace(-line_1, line_1, num_points_per_dim)
X_grid, Y_grid = np.meshgrid(x_lin, y_lin)

# Coarse 20 x 20 grid used for plotting surfaces.
x_lin1 = np.linspace(-line_1, line_1, 20)
y_lin1 = np.linspace(-line_1, line_1, 20)
X_grid1, Y_grid1 = np.meshgrid(x_lin1, y_lin1)

# scikit-learn expects inputs of shape (n_samples, n_features).
x_flat, y_flat = X_grid.flatten(), Y_grid.flatten()
data_input_3d = np.column_stack((x_flat, y_flat))

noise_level_3d = 0.05  # standard deviation of the additive Gaussian noise
z_true_3d = true_func_3d(x_flat, y_flat)
z_data_3d = z_true_3d + np.random.normal(0, noise_level_3d, size=z_true_3d.shape)

# Minimum points per segment for the piecewise fit: 3 points define a plane, 5 adds robustness.
MIN_POINTS_FOR_SPLIT = 5

# -------------------------------------------------------------------
# 3. Core Algorithm Functions (Piecewise Linear Fitting Algorithm - 3D)
# -------------------------------------------------------------------

class Node:
    """
    One node of the piecewise planar regression tree.

    Leaves carry ``params`` = (a, b, c) of the plane z = a*x1 + b*x2 + c and a
    ``stop_reason`` string explaining why splitting stopped. Internal nodes carry
    ``split_coeffs`` = (A, B, C) of the boundary A*x1 + B*x2 + C = 0 and a list of
    ``children``. ``region`` stores whatever region description the builder passes in.
    """

    def __init__(self, is_leaf, region, params=None, split_coeffs=None, children=None, stop_reason=None):
        self.is_leaf = is_leaf
        self.region = region
        # Leaf-only state.
        self.params = params
        self.stop_reason = stop_reason
        # Internal-node-only state.
        self.split_coeffs = split_coeffs
        self.children = children

    def __repr__(self):
        if not self.is_leaf:
            n_children = len(self.children) if self.children else 0
            return f"InternalNode(split_coeffs={np.round(self.split_coeffs, 2)}, children={n_children})"
        return f"LeafNode(params={np.round(self.params, 2)}, stop_reason='{self.stop_reason}')"

    def count_leaves(self):
        """Return the number of leaf nodes in the subtree rooted at this node."""
        total = 0
        stack = [self]
        while stack:
            node = stack.pop()
            if node.is_leaf:
                total += 1
            elif node.children:
                stack.extend(node.children)
        return total

def linear_model_3d(X, a, b, c):
    """
    Evaluate the plane z = a*x1 + b*x2 + c at every row of ``X``.

    ``X`` is array-like with at least two columns (x1, x2); the result has one value per row.
    """
    points = np.asanyarray(X)
    x1 = points[:, 0]
    x2 = points[:, 1]
    return a * x1 + b * x2 + c

def calculate_rmse(y_true, y_pred):
    """Return the root-mean-squared error between ``y_true`` and ``y_pred``."""
    residuals = y_true - y_pred
    mean_squared = np.mean(residuals ** 2)
    return np.sqrt(mean_squared)

def fit_single_plane(X, Z):
    """
    Least-squares fit of a single plane z = a*x1 + b*x2 + c.

    Parameters
    ----------
    X : ndarray of shape (n, 2)
    Z : ndarray of shape (n,)

    Returns
    -------
    tuple
        (a, b, c), or (None, None, None) when fewer than 3 points are given
        or ``np.linalg.lstsq`` raises ``LinAlgError``.
    """
    failed = (None, None, None)
    if len(Z) < 3:
        return failed
    ones_column = np.ones((X.shape[0], 1))
    design = np.hstack([X, ones_column])  # columns: x1, x2, intercept
    try:
        solution = np.linalg.lstsq(design, Z, rcond=None)[0]
    except np.linalg.LinAlgError:
        return failed
    a, b, c = solution[0], solution[1], solution[2]
    return a, b, c

def solve_ols_3d(X_des, Z):
    """
    Ordinary least squares for plane coefficients given a design matrix.

    Parameters
    ----------
    X_des : ndarray of shape (n, 3)
        Design matrix with columns [x1, x2, 1].
    Z : ndarray of shape (n,)
        Targets.

    Returns
    -------
    ndarray of shape (3,)
        The OLS solution. When there are fewer points than columns, or
        ``np.linalg.lstsq`` raises ``LinAlgError``, it returns the horizontal
        plane [0.0, 0.0, mean(Z)]. If Z is empty, it returns [0.0, 0.0, 0.0].
        The fallback is always float64, consistent with the OLS path.
    """
    def _horizontal_plane():
        # Single fallback: flat plane at the mean height, float dtype even for empty Z.
        level = np.mean(Z) if len(Z) > 0 else 0.0
        return np.array([0.0, 0.0, level], dtype=float)

    if len(Z) < X_des.shape[1]:  # underdetermined: cannot pin down every coefficient
        return _horizontal_plane()
    try:
        theta, *_ = np.linalg.lstsq(X_des, Z, rcond=None)
    except np.linalg.LinAlgError:  # e.g. SVD did not converge
        return _horizontal_plane()
    return theta

def fit_plane_from_3_points(p1, p2, p3):
    """
    Exact plane z = a*x + b*y + c through three points given as (x, y, z).

    Returns the ndarray [a, b, c], or None when the 3x3 system is singular
    (``np.linalg.solve`` raises ``LinAlgError``, e.g. when the points' (x, y)
    projections are collinear).
    """
    corners = (p1, p2, p3)
    system = np.array([[p[0], p[1], 1] for p in corners])
    heights = np.array([p[2] for p in corners])
    try:
        return np.linalg.solve(system, heights)
    except np.linalg.LinAlgError:
        return None

def initialize_thetas_guaranteed_intersect_3d(X, Z, init_strategy=0, random_seed=None):
    """
    Initialize two plane parameter vectors that share two anchor points.

    Both planes are built through the same two "anchor" data points near the centre
    of the data plus one different "boundary" point each. If the planes differ, they
    therefore intersect along the line through the anchors, which gives the iterative
    split optimization a starting boundary that cuts through the data.

    Parameters
    ----------
    X : ndarray of shape (n, 2)
        Inputs (x1, x2).
    Z : ndarray of shape (n,)
        Targets.
    init_strategy : int
        Which boundary points to use:

        - 0: the points with min/max x1.
        - 1: the points with min/max x2.
        - 2: two random points other than the anchors.

        Any other value returns two identical horizontal planes at mean(Z).
    random_seed : int or None
        Seed for a *local* ``np.random.default_rng`` generator. It is used by
        strategy 2, by the anchor fallback and by the tie-breaking perturbation.
        NumPy's global RNG is no longer reseeded. With None, the global RNG is
        used, as before.

    Returns
    -------
    theta1, theta2 : ndarray of shape (3,)
        Coefficients (a, b, c) of z = a*x1 + b*x2 + c. With fewer than 3 points,
        or when strategy 2 has too few candidates, both are [0, 0, mean(Z)].
    """
    n_points = len(Z)
    mean_z = np.mean(Z) if n_points > 0 else 0

    def flat_plane():
        # Horizontal plane z = mean(Z); a fresh array on every call.
        return np.array([0, 0, mean_z])

    if n_points < 3:  # not enough points to define even one plane
        return flat_plane(), flat_plane()

    # Local generator so seeding does not clobber NumPy's global RNG state as a side effect.
    rng = np.random if random_seed is None else np.random.default_rng(random_seed)

    def as_point(idx):
        return np.array([X[idx, 0], X[idx, 1], Z[idx]])

    # Anchor 1: the point whose x1 is closest to the median x1.
    idx_anchor1 = np.argmin(np.abs(X[:, 0] - np.median(X[:, 0])))
    # Anchor 2: among the remaining points, the one whose x2 is closest to the median x2.
    distances_y = np.abs(X[:, 1] - np.median(X[:, 1]))
    distances_y[idx_anchor1] = np.inf  # anchor 1 may not be reused as anchor 2
    idx_anchor2 = np.argmin(distances_y)

    if np.isinf(distances_y[idx_anchor2]):
        # No usable second anchor: fall back to the points with minimal / maximal x1
        # (or a random other point if those coincide). The fallback indices replace the
        # anchors everywhere below, including strategy 2's exclusion list.
        idx_anchor1 = np.argmin(X[:, 0])
        idx_anchor2 = np.argmax(X[:, 0])
        if idx_anchor1 == idx_anchor2:
            idx_anchor2 = rng.choice(np.delete(np.arange(n_points), idx_anchor1))

    if init_strategy == 0:
        idx_boundary1, idx_boundary2 = np.argmin(X[:, 0]), np.argmax(X[:, 0])
    elif init_strategy == 1:
        idx_boundary1, idx_boundary2 = np.argmin(X[:, 1]), np.argmax(X[:, 1])
    elif init_strategy == 2:
        candidates = np.setdiff1d(np.arange(n_points), [idx_anchor1, idx_anchor2])
        if len(candidates) < 2:  # not enough points left to pick two more
            return flat_plane(), flat_plane()
        idx_boundary1, idx_boundary2 = rng.choice(candidates, 2, replace=False)
    else:
        return flat_plane(), flat_plane()

    def robust_plane(points):
        """Plane through the first 3 distinct points; flat plane if fewer remain or the fit is singular."""
        distinct = []
        for p in points:
            if not any(np.allclose(p, q, atol=1e-9) for q in distinct):
                distinct.append(p)
        if len(distinct) < 3:
            return flat_plane()
        params = fit_plane_from_3_points(distinct[0], distinct[1], distinct[2])
        return flat_plane() if params is None else params

    p_anchor1, p_anchor2 = as_point(idx_anchor1), as_point(idx_anchor2)
    theta1 = robust_plane([as_point(idx_boundary1), p_anchor1, p_anchor2])
    theta2 = robust_plane([as_point(idx_boundary2), p_anchor1, p_anchor2])

    # Identical planes would give no split boundary at all: nudge them apart.
    if np.allclose(theta1, theta2, atol=1e-6):
        epsilon = rng.random(3) * 0.005
        theta1 = np.nan_to_num(theta1 + epsilon)
        theta2 = np.nan_to_num(theta2 - epsilon)
        if np.allclose(theta1, theta2, atol=1e-9):  # still identical after perturbation
            theta1[0] += 0.001
    return theta1, theta2

def optimize_split_optimized_3d(X, Z, max_iter=100, tol=1e-6, min_points_per_segment=MIN_POINTS_FOR_SPLIT,
                               random_seed_for_init=None):
    """
    Jointly fit two planes and derive the boundary that separates them.

    Alternating scheme:

    1. Assign each point to plane 1 ("left") when plane 1 predicts a *lower* value
       than plane 2 at its (x1, x2), i.e. (theta1 - theta2) . [x1, x2, 1] < 0.
       Otherwise assign it to plane 2 ("right"). The assignment depends only on the
       inputs, not on Z.
    2. Refit each plane by OLS on its own points.

    The loop stops when the assignment stops changing, both parameter vectors move
    by less than ``tol``, a non-finite update appears, or ``max_iter`` is reached.

    Returns
    -------
    tuple
        (split_coeffs, theta1, theta2, left_mask, right_mask) on success, where
        split_coeffs = theta1 - theta2 defines the boundary
        split_coeffs . [x1, x2, 1] = 0. Returns (None, None, None, None, None)
        when either side would hold fewer than max(3, min_points_per_segment)
        points, or the initial planes are not finite.
    """
    no_split = (None, None, None, None, None)
    count = X.shape[0]
    # Each side needs enough points to define a plane (3) and to satisfy the caller's minimum.
    required = max(3, min_points_per_segment)
    if count < 2 * required:
        return no_split

    design = np.hstack([X, np.ones((count, 1))])  # rows [x1, x2, 1]
    theta1, theta2 = initialize_thetas_guaranteed_intersect_3d(X, Z, init_strategy=0,
                                                               random_seed=random_seed_for_init)

    def all_finite(*vectors):
        return all(np.all(np.isfinite(v)) for v in vectors)

    if not all_finite(theta1, theta2):
        return no_split

    learning_rate = 1.0  # 1.0 jumps straight to the refit planes (no damping)
    previous_left = None
    for _ in range(max_iter):
        old1, old2 = theta1.copy(), theta2.copy()

        left = design @ (theta1 - theta2) < 0
        if previous_left is not None and np.array_equal(left, previous_left):
            break  # assignment is stable: refitting would reproduce the same planes
        previous_left = left

        n_left = np.sum(left)
        if min(n_left, count - n_left) < required:
            return no_split

        try:
            target1 = solve_ols_3d(design[left], Z[left])
            target2 = solve_ols_3d(design[~left], Z[~left])
            theta1 = (1 - learning_rate) * old1 + learning_rate * target1
            theta2 = (1 - learning_rate) * old2 + learning_rate * target2
        except np.linalg.LinAlgError:
            theta1, theta2 = old1, old2
            break

        if not all_finite(theta1, theta2):
            theta1, theta2 = old1, old2
            break

        if np.linalg.norm(theta1 - old1) < tol and np.linalg.norm(theta2 - old2) < tol:
            break

    left = design @ (theta1 - theta2) < 0
    right = ~left
    n_left = np.sum(left)
    if min(n_left, count - n_left) < required:
        return no_split
    return theta1 - theta2, theta1, theta2, left, right

def recursive_piecewise_fit_3d(X, Z, threshold, min_points=MIN_POINTS_FOR_SPLIT, depth=0, max_depth=5,
                               current_run_random_state=None):
    """
    Build a piecewise-planar regression tree for (x1, x2) -> z by recursive splitting.

    A single OLS plane is fit at each node. The node becomes a leaf when any of
    these holds:

    - it has fewer than 3 points;
    - the plane fit fails;
    - ``depth`` has reached ``max_depth``;
    - the plane's RMSE is <= ``threshold``;
    - it has too few points for two children of max(3, min_points) points each.

    Otherwise the node is split with ``optimize_split_optimized_3d``. If that fails,
    it falls back to a median split along whichever of x1/x2 has the larger extent.
    Both children are then fit recursively.

    Parameters
    ----------
    X : ndarray of shape (N, 2)
    Z : ndarray of shape (N,)
    threshold : float
        RMSE at or below which a node is not split further.
    min_points : int
        Minimum number of points per child segment.
    depth, max_depth : int
        Current and maximum tree depth.
    current_run_random_state : int or None
        Base seed for split initialisation (0 when None). The seed actually used
        is base + depth * 1000 + attempt index.

    Returns
    -------
    Node
        Root of the fitted (sub)tree. Every node records its bounding box in
        ``region``; leaves record ``params`` (a, b, c) and ``stop_reason``.
    """
    count = X.shape[0]
    if count == 0:
        return Node(is_leaf=True, params=np.array([np.nan, np.nan, np.nan]),
                    region={'x_range': (np.nan, np.nan), 'y_range': (np.nan, np.nan)},
                    stop_reason='empty_segment')

    plane_min = 3  # points needed to define a plane
    x_lo, x_hi = np.min(X[:, 0]), np.max(X[:, 0])
    y_lo, y_hi = np.min(X[:, 1]), np.max(X[:, 1])
    z_mean = np.mean(Z)
    region = {'x_range': (x_lo, x_hi), 'y_range': (y_lo, y_hi)}

    def make_leaf(params, reason):
        return Node(is_leaf=True, params=params, region=region, stop_reason=reason)

    if count < plane_min:
        return make_leaf(np.array([0.0, 0.0, z_mean]), 'not_enough_points_for_fit')

    a, b, c = fit_single_plane(X, Z)
    if a is None or np.any(np.isnan([a, b, c])):
        return make_leaf(np.array([0.0, 0.0, z_mean]), 'fit_failed_numerically')

    plane = np.array([a, b, c])
    rmse = calculate_rmse(Z, linear_model_3d(X, a, b, c))

    if depth >= max_depth:
        return make_leaf(plane, 'max_depth_reached')
    if rmse <= threshold:
        return make_leaf(plane, 'RMSE_threshold_met')

    required = max(plane_min, min_points)  # per-child minimum
    if count < 2 * required:
        return make_leaf(plane, 'not_enough_points_for_split_pre_check')

    # Try the optimized two-plane split first.
    base_seed = 0 if current_run_random_state is None else current_run_random_state
    max_init_attempts = 1
    chosen = None
    for attempt_num in range(max_init_attempts):
        coeffs, params_left, params_right, left, right = optimize_split_optimized_3d(
            X, Z, min_points_per_segment=min_points,
            random_seed_for_init=base_seed + depth * 1000 + attempt_num)
        if coeffs is None or params_left is None or params_right is None:
            continue
        if min(np.sum(left), np.sum(right)) >= required:
            chosen = (coeffs, left, right)
            break

    # Fallback: median split along the axis with the larger extent.
    if chosen is None:
        axis = 0 if (x_hi - x_lo) >= (y_hi - y_lo) else 1
        median_val = np.median(X[:, axis])
        left = X[:, axis] < median_val
        right = ~left
        coeffs = np.zeros(3)
        coeffs[axis] = 1.0
        coeffs[2] = -median_val  # boundary: x_axis - median_val = 0
        if min(np.sum(left), np.sum(right)) < required:
            return make_leaf(plane, 'fallback_split_failed_insufficient_points')
        chosen = (coeffs, left, right)

    coeffs, left, right = chosen
    children = [
        recursive_piecewise_fit_3d(X[mask], Z[mask], threshold, min_points, depth + 1,
                                   max_depth, current_run_random_state)
        for mask in (left, right)
    ]
    return Node(is_leaf=False, split_coeffs=coeffs, children=children, region=region)

def _predict_single_point(x_val_arr, node):
    """
    Predict one point by walking the tree from ``node`` down to a leaf.

    Parameters
    ----------
    x_val_arr : ndarray of length 3
        The point in augmented form [x1, x2, 1].
    node : Node
        Root of the (sub)tree to evaluate.

    Returns
    -------
    ndarray of shape (1,)
        The leaf plane evaluated at (x1, x2). A leaf whose parameters contain NaN
        yields ``np.array([0.0])``, so the return type is always a 1-element array.
        Callers index the result with ``[0]``; previously this path returned a bare
        float and made that indexing raise TypeError.
    """
    # An internal node sends the point left when A*x1 + B*x2 + C*1 < -1e-9, else right.
    while not node.is_leaf:
        coeffs = node.split_coeffs
        decision_value = coeffs[0] * x_val_arr[0] + coeffs[1] * x_val_arr[1] + coeffs[2] * x_val_arr[2]
        node = node.children[0] if decision_value < -1e-9 else node.children[1]

    a, b, c = node.params
    if np.isnan(a) or np.isnan(b) or np.isnan(c):
        return np.array([0.0])  # default for invalid plane parameters
    # The linear model takes the 2D input (x1, x2), not the augmented [x1, x2, 1].
    return linear_model_3d(x_val_arr[:2].reshape(1, -1), a, b, c)

def predict_3d(X_new, root_node):
    """
    Predict Z for new points with a trained piecewise-linear tree.

    Routing is vectorised: each internal node partitions the current batch of rows
    into its two children at once, instead of walking the tree one point at a time.
    The per-point arithmetic is unchanged:

    - split decision: A*x1 + B*x2 + C*1, sent left when < -1e-9;
    - leaf value: a*x1 + b*x2 + c.

    Parameters
    ----------
    X_new : array-like of shape (N, 2)
        Points (x1, x2).
    root_node : Node or None
        Root of the fitted tree. If None, every prediction is the mean of the
        module-level ``z_data_3d`` (0.0 if it is empty).

    Returns
    -------
    ndarray of shape (N,), dtype float
        Points reaching a leaf whose parameters contain NaN predict 0.0. The
        previous per-point code crashed on that path.
    """
    X_new = np.asarray(X_new, dtype=float)
    if root_node is None:
        return np.full(X_new.shape[0], np.mean(z_data_3d) if len(z_data_3d) > 0 else 0.0)

    predictions = np.zeros(X_new.shape[0], dtype=float)
    # Explicit work stack of (subtree, row indices routed to it).
    pending = [(root_node, np.arange(X_new.shape[0]))]
    while pending:
        node, rows = pending.pop()
        if rows.size == 0:
            continue
        x1, x2 = X_new[rows, 0], X_new[rows, 1]
        if node.is_leaf:
            a, b, c = node.params
            if not (np.isnan(a) or np.isnan(b) or np.isnan(c)):
                predictions[rows] = a * x1 + b * x2 + c
            continue  # invalid leaves keep the 0.0 default
        coeffs = node.split_coeffs
        decision = coeffs[0] * x1 + coeffs[1] * x2 + coeffs[2] * 1.0
        goes_left = decision < -1e-9  # same tolerance as the single-point path
        pending.append((node.children[0], rows[goes_left]))
        pending.append((node.children[1], rows[~goes_left]))
    return predictions

# -------------------------------------------------------------------
# 4. Experiment Main Flow (Updated for 3D Data)
# -------------------------------------------------------------------

# --- 4.1 Data Splitting ---
TEST_SIZE = 0.3  # fraction of samples held out for testing
RANDOM_STATE = 42  # seed for the initial split and for the baseline models

X_train_full, X_test, Z_train_full, Z_test = train_test_split(
    data_input_3d,
    z_data_3d,
    random_state=RANDOM_STATE,
    test_size=TEST_SIZE,
)

# --- Define OurMethodRegressor wrapper for GridSearchCV ---
class OurMethodRegressor(RegressorMixin, BaseEstimator):
    """
    scikit-learn compatible wrapper around the piecewise planar tree.

    ``fit`` builds the tree with ``recursive_piecewise_fit_3d`` and stores its root
    in ``root_node``. ``predict`` evaluates it with ``predict_3d``.

    The mixin is listed before ``BaseEstimator``, following scikit-learn's documented
    convention, so RegressorMixin's methods and estimator tags take precedence in the MRO.

    Parameters
    ----------
    threshold : float
        Node RMSE at or below which a node is not split further.
    min_points : int
        Minimum number of points per child segment.
    max_depth : int
        Maximum tree depth.
    random_state : int or None
        Base seed forwarded to ``recursive_piecewise_fit_3d``.
    """

    def __init__(self, threshold=0.03, min_points=MIN_POINTS_FOR_SPLIT, max_depth=5, random_state=None):
        self.threshold = threshold
        self.min_points = min_points
        self.max_depth = max_depth
        self.random_state = random_state  # forwarded to every fit() call
        # Fitted tree; None until fit() runs, which predict() uses to detect an unfitted model.
        self.root_node = None

    def fit(self, X, y):
        """
        Fit the piecewise linear model.

        ``X`` (n_samples, 2) and ``y`` (n_samples,) are converted to float ndarrays,
        so lists and DataFrames work as well as arrays. Returns ``self``.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        self.root_node = recursive_piecewise_fit_3d(
            X, y,
            threshold=self.threshold,
            min_points=self.min_points,
            max_depth=self.max_depth,
            current_run_random_state=self.random_state,
        )
        return self

    def predict(self, X):
        """Predict target values for ``X`` of shape (n_samples, 2)."""
        X = np.asarray(X, dtype=float)
        if self.root_node is None:
            # Unfitted model: fall back to the mean of the script-level training targets.
            return np.full(X.shape[0], np.mean(Z_train_full) if len(Z_train_full) > 0 else 0.0)
        return predict_3d(X, self.root_node)

    def get_params(self, deep=True):
        """Return the constructor hyperparameters (``deep`` is accepted for API compatibility)."""
        return {"threshold": self.threshold, "min_points": self.min_points, "max_depth": self.max_depth,
                "random_state": self.random_state}

    def set_params(self, **parameters):
        """Set hyperparameters by name and return ``self``."""
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

# --- 4.2 Define All Models and Their Hyperparameter Grids (Including Our Method) ---
# Candidate models keyed by display name. Each entry holds an unfitted estimator and
# the grid explored by GridSearchCV; insertion order is the tuning/reporting order.
models_to_tune = {}

models_to_tune["Our Method"] = {
    "estimator": OurMethodRegressor(),
    "params": {
        "threshold": [0.01, 0.03, 0.05],
        "max_depth": [4, 6, 8, 10, 12],
        # "min_points": [5, 10],  # optional extra search dimension
    },
}

models_to_tune["CART"] = {
    "estimator": DecisionTreeRegressor(random_state=RANDOM_STATE),
    "params": {
        "max_depth": [3, 5, 7, 9, 11],
        "min_samples_split": [2, 5, 10],
        "min_samples_leaf": [1, 2, 4],
    },
}

# XGBoost replaces the earlier GBDT baseline.
models_to_tune["XGBoost"] = {
    "estimator": xgb.XGBRegressor(
        random_state=RANDOM_STATE,
        objective='reg:squarederror',
        eval_metric='rmse',
        use_label_encoder=False,
    ),
    "params": {
        "n_estimators": [50, 100, 200],
        "learning_rate": [0.01, 0.1, 0.2],
        "max_depth": [2, 3, 5],
        "subsample": [0.7, 1.0],  # XGBoost-specific: row subsampling ratio
    },
}

# --- 4.3 Perform Grid Search and Train Best Models ---
# Outputs of the tuning stage, keyed by model name:
#   results                     -> rows of the initial comparison table (test-set metrics + best params)
#   trained_models              -> GridSearchCV.best_estimator_ for each model
#   best_params_for_repetitions -> best hyperparameters, reused by the repetition stage (section 6)
results = []
trained_models = {}
best_params_for_repetitions = {}

print("\n--- Starting Hyperparameter Tuning for All Models (this may take a moment) ---")
for name, model_info in models_to_tune.items():
    print(f"Tuning {name}...")
    estimator = model_info["estimator"]
    params = model_info["params"]
    # Optional preprocessing step (e.g. a scaler) that would wrap the estimator in a Pipeline.
    # NOTE(review): no entry of models_to_tune above sets "preprocessing", so only the else-branch runs.
    preprocessing = model_info.get("preprocessing")

    if preprocessing is not None:
        pipeline = Pipeline([('scaler', preprocessing), ('regressor', estimator)])
        # Pipeline parameters must be addressed as '<step>__<param>'.
        grid_params = {f'regressor__{key}': value for key, value in params.items()}
        current_estimator_for_gs = pipeline
    else:
        grid_params = params
        current_estimator_for_gs = estimator

    # 5-fold CV over the full grid; scoring is negated MSE, so larger is better.
    grid_search = GridSearchCV(current_estimator_for_gs, grid_params, cv=5, scoring='neg_mean_squared_error', n_jobs=-1)

    # Address XGBoost use_label_encoder warning before fitting.
    # NOTE(review): these branches only set use_label_encoder when it is *absent* from get_params();
    # the estimator is constructed with use_label_encoder=False, so they look like no-ops - confirm intent.
    if name == "XGBoost" and isinstance(grid_search.estimator, xgb.XGBRegressor):
        if 'use_label_encoder' not in grid_search.estimator.get_params():
            grid_search.estimator.set_params(use_label_encoder=False)
    elif name == "XGBoost" and isinstance(grid_search.estimator, Pipeline):
        if 'use_label_encoder' not in grid_search.estimator.named_steps['regressor'].get_params():
            grid_search.estimator.named_steps['regressor'].set_params(use_label_encoder=False)

    grid_search.fit(X_train_full, Z_train_full)
    # With GridSearchCV's default refit=True, best_estimator_ is refit on all of X_train_full.
    trained_models[name] = grid_search.best_estimator_
    if preprocessing is not None:
        # Strip the pipeline prefix so the params can be passed straight to the estimator class.
        extracted_params = {k.replace('regressor__', ''): v for k, v in grid_search.best_params_.items()}
        best_params_for_repetitions[name] = extracted_params
    else:
        best_params_for_repetitions[name] = grid_search.best_params_
    z_pred = trained_models[name].predict(X_test)

    best_params_str = str(best_params_for_repetitions[name])
    if name == "Our Method":
        # Report the number of linear segments K (= leaf count of the fitted tree) next to the params;
        # K=0 if the best estimator has no fitted root_node.
        if hasattr(trained_models[name], 'root_node') and trained_models[name].root_node is not None:
            num_segments = trained_models[name].root_node.count_leaves()
        else:
            num_segments = 0
        best_params_str += f", K={num_segments}"

    # Metrics on the held-out test split from section 4.1.
    results.append({
        "Model": name, "RMSE": np.sqrt(mean_squared_error(Z_test, z_pred)),
        "MAE": mean_absolute_error(Z_test, z_pred), "R²": r2_score(Z_test, z_pred),
        "Best Params": best_params_str,
    })

    print(f"Finished tuning {name}. Best score (neg MSE on CV): {grid_search.best_score_:.4f}")
    print(f"Best params found: {best_params_for_repetitions[name]}\n")
results_df_initial = pd.DataFrame(results)
results_df_initial.sort_values(by="RMSE", inplace=True)  # best (lowest test RMSE) first

# -------------------------------------------------------------------
# 5. Display Results (First Grid Search Results)
# -------------------------------------------------------------------
print("\n" + "=" * 80)
print(" " * 15 + "ALGORITHM PERFORMANCE COMPARISON (Initial Tuning Results)")
print("=" * 80)
# Show full "Best Params" strings without truncation or line wrapping.
with pd.option_context('display.max_colwidth', None, 'display.width', 1000):
    print(results_df_initial.to_string(index=False))
print("=" * 80)

# -------------------------------------------------------------------
# 6. Repeat Training and Evaluation N Times, Calculate Mean and Standard Deviation
# -------------------------------------------------------------------
N_REPETITIONS = 5
all_run_results = {name: {'RMSE': [], 'MAE': [], 'R2': []} for name in models_to_tune.keys()}

# Re-train every model with its tuned hyperparameters on N fresh random splits.
# NOTE(review): hyperparameters were tuned on the RANDOM_STATE split only. Each repetition
# draws a new split, so some of its test points were part of the tuning data.
print(f"\n--- Starting {N_REPETITIONS} repetitions with final tuned parameters ---")
for i in range(N_REPETITIONS):
    print(f"\n--- Repetition {i + 1}/{N_REPETITIONS} ---")
    current_random_state = RANDOM_STATE + i
    X_train_full_i, X_test_i, Z_train_full_i, Z_test_i = train_test_split(
        data_input_3d, z_data_3d, test_size=TEST_SIZE, random_state=current_random_state)

    for name, model_info in models_to_tune.items():
        base_estimator_class = model_info["estimator"].__class__
        best_params = best_params_for_repetitions[name].copy()  # copy: keep the tuned dict intact

        if name == "XGBoost":
            # Fixed settings that are not part of the search grid.
            best_params.setdefault('objective', 'reg:squarederror')
            best_params.setdefault('eval_metric', 'rmse')
            best_params.setdefault('use_label_encoder', False)

        # The three known models take random_state; for any other model, probe a default instance.
        if name in ("XGBoost", "CART", "Our Method"):
            accepts_random_state = True
        else:
            temp_instance = base_estimator_class()
            accepts_random_state = (hasattr(temp_instance, 'random_state')
                                    or 'random_state' in temp_instance.get_params())

        if accepts_random_state:
            current_model_instance = base_estimator_class(random_state=current_random_state, **best_params)
        else:
            current_model_instance = base_estimator_class(**best_params)

        current_model_instance.fit(X_train_full_i, Z_train_full_i)

        z_pred_i = current_model_instance.predict(X_test_i)
        all_run_results[name]["RMSE"].append(np.sqrt(mean_squared_error(Z_test_i, z_pred_i)))
        all_run_results[name]["MAE"].append(mean_absolute_error(Z_test_i, z_pred_i))
        all_run_results[name]["R2"].append(r2_score(Z_test_i, z_pred_i))

# --- Calculate Mean and Standard Deviation of Metrics ---
# Each cell reads "\( mean \pm std \)" (population std, np.std ddof=0, 4 decimals).
# Raw strings produce that literal text without invalid escape sequences (a SyntaxWarning
# since Python 3.12). The unrounded RMSE mean is kept in a helper column so rows are sorted
# numerically instead of by re-parsing the formatted string.
final_summary_results = []
for model_name, metrics_dict in all_run_results.items():
    row = {
        "Model": model_name,
        r"RMSE (Mean \( \pm \) Std)": rf"\( {np.mean(metrics_dict['RMSE']):.4f} \pm {np.std(metrics_dict['RMSE']):.4f} \)",
        r"MAE (Mean \( \pm \) Std)": rf"\( {np.mean(metrics_dict['MAE']):.4f} \pm {np.std(metrics_dict['MAE']):.4f} \)",
        r"R² (Mean \( \pm \) Std)": rf"\( {np.mean(metrics_dict['R2']):.4f} \pm {np.std(metrics_dict['R2']):.4f} \)",
        "RMSE_Mean": np.mean(metrics_dict['RMSE']),  # numeric sort key, dropped below
    }
    final_summary_results.append(row)
final_summary_df = pd.DataFrame(final_summary_results)
final_summary_df.sort_values(by="RMSE_Mean", inplace=True)
final_summary_df.drop('RMSE_Mean', axis=1, inplace=True)

# -------------------------------------------------------------------
# 7. Final Results Display (Mean and Standard Deviation over N Repetitions)
# -------------------------------------------------------------------
print("\n" + "=" * 100)
print(" " * 5 + f"ALGORITHM PERFORMANCE (MEAN \\( \\pm \\) STD OVER {N_REPETITIONS} REPETITIONS)")
print("=" * 100)
with pd.option_context('display.max_colwidth', None, 'display.width', 1200):
    print(final_summary_df.to_string(index=False))
print("=" * 100)

# -------------------------------------------------------------------
# 8. Plot Results (Using Results from First Grid Search) - Updated for 3D Plotting
# -------------------------------------------------------------------
def auto_pick_point_colors(surface_cmap: str, min_hue_sep_to_surface=0.10, min_hue_sep_between_points=0.12):
    """Pick two visually distinct marker colours that contrast with a surface colormap.

    Candidates are six Okabe-Ito colours plus a purple. Contrast is measured as
    circular HSV-hue distance to the colormap's colour at 0.55.

    Parameters
    ----------
    surface_cmap : str
        Colormap name (a ``Colormap`` instance is also accepted by ``plt.get_cmap``).
    min_hue_sep_to_surface : float
        Minimum hue distance (0-0.5) the second colour must keep from the surface hue.
    min_hue_sep_between_points : float
        Minimum hue distance the second colour must keep from the first colour.

    Returns
    -------
    tuple[str, str]
        Two hex colour strings: (first, second).
    """
    candidates = ["#0072B2", "#D55E00", "#009E73", "#CC79A7", "#56B4E9", "#E69F00", "#6A51A3"]
    # plt.get_cmap works across Matplotlib versions; cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap(surface_cmap)
    # Reference hue: the colormap's RGBA colour at 0.55 (near the middle of its range).
    sr, sg, sb, _ = cmap(0.55)
    surface_hue = colorsys.rgb_to_hsv(sr, sg, sb)[0]

    def hue(color):
        """Return the HSV hue, in [0, 1), of a Matplotlib colour spec."""
        return colorsys.rgb_to_hsv(*mcolors.to_rgb(color))[0]

    def hue_dist(h1, h2):
        """Circular distance between two hues (hue wraps at 1.0); result is in [0, 0.5]."""
        d = abs(h1 - h2)
        return min(d, 1.0 - d)

    # Most surface-contrasting candidates first; the stable sort keeps palette order on ties.
    cand_hues = sorted(
        ((c, hue(c)) for c in candidates),
        key=lambda ch: hue_dist(ch[1], surface_hue),
        reverse=True,
    )

    # The top entry has the largest distance to the surface hue, so it is the best
    # first colour whether or not it clears min_hue_sep_to_surface.
    first = cand_hues[0]

    # Second colour: the best remaining candidate that contrasts with both the surface and `first`.
    second = next(
        (
            (c, h) for c, h in cand_hues
            if c != first[0]
            and hue_dist(h, surface_hue) >= min_hue_sep_to_surface
            and hue_dist(h, first[1]) >= min_hue_sep_between_points
        ),
        None,
    )
    if second is None:
        # Fallback: maximise contrast with `first` only (max() keeps the earliest on ties).
        remain = [ch for ch in cand_hues if ch[0] != first[0]]
        second = max(remain, key=lambda ch: hue_dist(ch[1], first[1]))
    return first[0], second[0]

def create_3d_plot1(
        X_train, Z_train, X_test, Z_test,
        X_grid_plot, Y_grid_plot,
        true_func, our_method_model,
        surface_cmap: str = "viridis",
        elev: float = 28, azim: float = -60, use_ortho: bool = True,
        save_path="3d_our_method_approximation.pdf",
):
    """
    Plot Our Method's predicted surface together with the training and test data in 3D.

    Parameters
    ----------
    X_train, X_test : ndarray
        Input points indexed as ``X[:, 0]`` / ``X[:, 1]`` (the x1 / x2 coordinates).
    Z_train, Z_test : array-like
        Targets for X_train / X_test, drawn on the vertical axis.
    X_grid_plot, Y_grid_plot : ndarray
        Coordinate arrays of equal shape (presumably from ``np.meshgrid``) on which
        the model surface is evaluated.
    true_func : callable
        Not used by this function; kept for interface compatibility.
    our_method_model : fitted estimator or None
        Must provide ``predict`` on an (n, 2) array. If None, only the data are drawn.
    surface_cmap : str
        Matplotlib colormap name for the surface, the contour lines and the legend patch.
    elev, azim : float
        Camera elevation/azimuth passed to ``view_init``.
    use_ortho : bool
        Use an orthographic projection when the Axes supports it.
    save_path : str or None
        File passed to ``savefig``; ``None`` skips saving.
    """
    with plt.rc_context({
        "font.family": "serif", "font.size": 11, "axes.labelsize": 12, "axes.titlesize": 13,
        "legend.fontsize": 10, "axes.linewidth": 0.8, "savefig.transparent": True,
        "pdf.fonttype": 42, "ps.fonttype": 42,
    }):
        fig = plt.figure(figsize=(5.6, 4.6))
        ax = fig.add_subplot(111, projection="3d")
        ax.view_init(elev=elev, azim=azim)
        if use_ortho and hasattr(ax, "set_proj_type"):
            ax.set_proj_type("ortho")

        train_color, test_color = auto_pick_point_colors(surface_cmap)

        # Proxy artists so the legend shows the point styles independently of the scatter calls.
        legend_elements = [
            Line2D([0], [0], color=train_color, marker="o", linestyle="None", markersize=4, alpha=0.3,
                   label="Training Data"),
            Line2D([0], [0], color=test_color, marker="x", linestyle="None", markersize=4, alpha=0.3,
                   label="Test Data"),
        ]

        # Plot Our Method's predicted surface
        if our_method_model is not None:
            X_plot_flat = np.column_stack((X_grid_plot.ravel(), Y_grid_plot.ravel()))
            # np.asarray so that predict() results without .reshape (e.g. a Series) also work.
            Z_surf = np.asarray(our_method_model.predict(X_plot_flat)).reshape(X_grid_plot.shape)

            # plt.get_cmap works across Matplotlib versions; cm.get_cmap was removed in 3.9.
            cmap = plt.get_cmap(surface_cmap)
            ls = LightSource(azdeg=315, altdeg=45)

            # Pre-compute lit RGBA face colours from the colormap ...
            face_rgb = ls.shade(Z_surf, cmap=cmap, vert_exag=0.8, blend_mode="soft")

            # ... so Matplotlib's own shading is disabled (shade=False).
            ax.plot_surface(X_grid_plot, Y_grid_plot, Z_surf, facecolors=face_rgb, rstride=1, cstride=1,
                            linewidth=0, antialiased=True, shade=False, alpha=0.8)

            # Project contour lines onto a plane 5% of the z-range below the surface minimum
            # (offset 1.0 below when the surface is flat).
            zmin, zmax = float(np.min(Z_surf)), float(np.max(Z_surf))
            dz = 0.05 * (zmax - zmin) if zmax > zmin else 1.0
            ax.contour(X_grid_plot, Y_grid_plot, Z_surf, zdir="z", offset=zmin - dz, cmap=cmap,
                       levels=10, linewidths=0.8, alpha=0.7)

            # Legend patch uses the colormap colour at 0.55 (near the middle).
            legend_elements.append(
                mpatches.Patch(facecolor=cmap(0.55), edgecolor="none", label="Our Method"))
        else:
            print("Warning: 'Our Method' model not provided. Skipping surface plot.")

        # Plot training and test data points
        train_kw = dict(c=train_color, marker="o", s=10, alpha=0.3, linewidths=0)
        test_kw = dict(c=test_color, marker="x", s=10, alpha=0.3)
        # Thin each point cloud with a stride to limit clutter: every point is drawn
        # when a set has fewer than 2000 points, otherwise roughly 1000-2000 remain.
        step_train = max(1, len(X_train) // 1000)
        step_test = max(1, len(X_test) // 1000)
        X_train_sample, Z_train_sample = X_train[::step_train], Z_train[::step_train]
        X_test_sample, Z_test_sample = X_test[::step_test], Z_test[::step_test]
        ax.scatter(X_test_sample[:, 0], X_test_sample[:, 1], Z_test_sample, **test_kw)
        ax.scatter(X_train_sample[:, 0], X_train_sample[:, 1], Z_train_sample, **train_kw)

        ax.set_xlabel(r"$x_1$")
        ax.set_ylabel(r"$x_2$")
        ax.set_zlabel(r"$y$")
        ax.grid(False)  # Remove background grid lines
        # Make the three background panes invisible.
        for pane in (ax.xaxis.pane, ax.yaxis.pane, ax.zaxis.pane):
            pane.set_edgecolor("none")
            pane.set_alpha(0.0)
        # Flatten the box vertically; set_box_aspect only exists on newer Matplotlib (3.3+).
        if hasattr(ax, "set_box_aspect"):
            ax.set_box_aspect((1, 1, 0.6))
        ax.legend(handles=legend_elements, frameon=False, loc="upper left", bbox_to_anchor=(0.02, 0.98))
        fig.tight_layout(pad=0.3)
        if save_path is not None:
            fig.savefig(save_path, bbox_inches="tight", transparent=True)
        plt.show()

print("\nGenerating 3D approximation plot for Our Method (based on initial run)...")

# Only "Our Method" is plotted. It is rebuilt from the estimator class registered in
# models_to_tune, using the hyper-parameters stored in best_params_for_repetitions
# plus the fixed RANDOM_STATE, then refit on the full original training split.
our_method_model_for_plot = None
if "Our Method" not in trained_models:
    print("Error: 'Our Method' model not found in trained_models, cannot plot.")
else:
    estimator_class = type(models_to_tune["Our Method"]["estimator"])
    params_for_our_method = best_params_for_repetitions["Our Method"].copy()
    our_method_model_for_plot = estimator_class(random_state=RANDOM_STATE, **params_for_our_method)
    our_method_model_for_plot.fit(X_train_full, Z_train_full)

    create_3d_plot1(
        X_train=X_train_full, Z_train=Z_train_full,
        X_test=X_test, Z_test=Z_test,
        X_grid_plot=X_grid1, Y_grid_plot=Y_grid1,
        true_func=true_func_3d,
        our_method_model=our_method_model_for_plot,
        surface_cmap="viridis",
    )
