import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time

# 1. Import machine learning models and tools
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.tree import DecisionTreeRegressor
from xgboost import XGBRegressor
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, RegressorMixin

# -------------------------------------------------------------------
# 2. Global Configuration and Data Loading
# -------------------------------------------------------------------

# Specify .data file path
DATA_FILE_PATH = 'kin8nm.data'

# Attempt to load dataset. Errors are reported and then re-raised: everything below
# (the train/test split, the preprocessing pipeline) needs `X_data_raw`, `y_data` and the
# feature lists, so continuing after a failure would only crash later with an
# unrelated-looking NameError.
try:
    # According to the description of kin8nm.data, it typically has no header and is comma-separated numerical data
    df = pd.read_csv(DATA_FILE_PATH, sep=',', header=None, skipinitialspace=True)

    # The kin8nm dataset has 9 attributes, with the last column typically being the target variable (regression task)
    y_data = df.iloc[:, -1].values
    X_data_raw = df.iloc[:, :-1]  # Features are all columns except the last one (target)

    # All features in the kin8nm dataset are continuous numerical.
    numerical_features = X_data_raw.columns.tolist()
    categorical_features = []  # No categorical features

    print(f"Successfully loaded data file: {DATA_FILE_PATH}. Data shape: {df.shape}")
    print("First 5 rows of data:")
    print(df.head())
    print(f"Number of numerical features identified: {len(numerical_features)}")
    if categorical_features:
        print(f"Categorical features identified: {categorical_features}")
    else:
        print("No categorical features identified.")

except FileNotFoundError:
    print(f"Error: File '{DATA_FILE_PATH}' not found. Please check the file path.")
    raise
except Exception as e:
    print(f"Error loading or parsing data file: {e}")
    print("Please check the data file format or `pd.read_csv` parameters.")
    raise


# Default minimum number of points per segment used by the splitting routines below.
MIN_POINTS_FOR_SPLIT_BASE = 5


# -------------------------------------------------------------------
# 3. Core Algorithm Functions (Piecewise Linear Fitting Algorithm)
# -------------------------------------------------------------------

class Node:
    """
    One node of the piecewise-linear regression tree.

    A leaf holds the coefficients of the hyperplane used for prediction in its region;
    an internal node holds the coefficients of the splitting hyperplane and its children.
    """

    def __init__(self, is_leaf, region, params=None, split_coeffs=None, children=None, stop_reason=None,
                 data_indices=None):
        """
        :param is_leaf: True for a leaf node, False for an internal node.
        :param region: List of (xmin, xmax) tuples, one per feature.
        :param params: Leaf only: array of shape (n_features + 1,) with hyperplane coefficients (w0, ..., wn-1, b).
        :param split_coeffs: Internal only: array of shape (n_features + 1,) with split hyperplane coefficients (A0, ..., An-1, B).
        :param children: Internal only: list [left_child_node, right_child_node].
        :param stop_reason: Leaf only: string describing why recursion stopped here.
        :param data_indices: Indices (into the original dataset) of the points that reached this node.
        """
        self.is_leaf = is_leaf
        self.region = region
        self.params = params
        self.split_coeffs = split_coeffs
        self.children = children
        self.stop_reason = stop_reason
        self.data_indices = data_indices

    def __repr__(self):
        if not self.is_leaf:
            child_count = len(self.children) if self.children else 0
            rounded_split = np.round(self.split_coeffs, 2)
            return f"InternalNode(split_coeffs={rounded_split}, children={child_count})"
        rounded_params = np.round(self.params, 2)
        return f"LeafNode(params={rounded_params}, stop_reason='{self.stop_reason}')"

    def count_leaves(self):
        """Return the number of leaf nodes in the subtree rooted at this node."""
        leaf_total = 0
        to_visit = [self]
        while to_visit:
            current = to_visit.pop()
            if current.is_leaf:
                leaf_total += 1
            elif current.children:
                # Internal nodes without children contribute nothing
                to_visit.extend(current.children)
        return leaf_total


def linear_model_nd(X, coeffs):
    r"""
    Evaluate an N-dimensional linear model
    \( \hat{y} = w_0x_0 + w_1x_1 + \dots + w_{n-1}x_{n-1} + b \).

    :param X: Array-like of shape \( (N_{samples}, N_{features}) \).
    :param coeffs: Array-like of shape \( (N_{features} + 1,) \); the last element is the bias \( b \).
    :return: Predictions of shape \( (N_{samples},) \).
    """
    X = np.asanyarray(X)
    coeffs = np.asarray(coeffs)
    # Apply slopes and bias separately instead of materializing the [X, 1] design matrix.
    return X @ coeffs[:-1] + coeffs[-1]


def calculate_rmse(y_true, y_pred):
    """Return the root-mean-square error between ``y_true`` and ``y_pred``."""
    residuals = y_true - y_pred
    mean_squared = np.mean(residuals ** 2)
    return np.sqrt(mean_squared)


def solve_ols_nd(X, Z, alpha=0):
    r"""
    Solve for N-D hyperplane parameters by least squares, optionally with L2 (ridge) regularization.

    Solves the normal equations \( (X_d^\top X_d + \alpha I') \theta = X_d^\top Z \), where
    \( X_d = [X, \mathbf{1}] \) and \( I' \) is the identity with its bias entry zeroed, so the
    slopes are shrunk but the intercept is not.

    :param X: Feature matrix of shape \( (N_{samples}, N_{features}) \).
    :param Z: Target vector of shape \( (N_{samples},) \).
    :param alpha: Ridge strength. Values ``<= 1e-9`` are treated as 0 (plain OLS).
    :return: Coefficients \( (w_0, \dots, w_{n-1}, b) \) of shape \( (N_{features} + 1,) \).
    :raises numpy.linalg.LinAlgError: Propagated from ``np.linalg.solve`` when the system matrix
        is singular (possible only for the unregularized case).
    """
    X = np.asarray(X, dtype=float)
    Z = np.asarray(Z, dtype=float)
    n_points, n_features = X.shape
    X_des = np.hstack([X, np.ones((n_points, 1))])
    gram = X_des.T @ X_des
    rhs = X_des.T @ Z

    if alpha <= 1e-9:  # Effectively zero: skip building the penalty matrix
        return np.linalg.solve(gram, rhs)

    # Penalize the slopes w only; the bias b (last coefficient) stays unregularized.
    penalty = np.eye(n_features + 1)
    penalty[-1, -1] = 0.0
    return np.linalg.solve(gram + alpha * penalty, rhs)


def fit_single_plane(X, Z, ridge_alpha=0):
    r"""
    Fit a single (optionally ridge-regularized) hyperplane to data points.

    Thin wrapper around ``solve_ols_nd``.

    :param X: Array of shape \( (N, n_{features}) \).
    :param Z: Array of shape \( (N,) \).
    :param ridge_alpha: L2 penalty on the slopes; the bias term is not penalized.
    :return: Hyperplane parameters \( (w_0, \dots, w_{n-1}, b) \).
    """
    return solve_ols_nd(X, Z, alpha=ridge_alpha)


def _perturb_apart(theta1, theta2, rng, n_features):
    """
    Return copies of two plane vectors nudged in opposite directions so that they differ.

    A uniform random offset in [0, 0.005) per coefficient is added to ``theta1`` and
    subtracted from ``theta2``; NaN/Inf are then replaced by finite numbers. If the pair is
    still (numerically) identical, the first coefficient of ``theta1`` is shifted by 0.001.

    :param rng: Anything exposing ``rand`` (a ``RandomState`` or the ``np.random`` module).
    """
    epsilon = rng.rand(n_features + 1) * 0.005
    theta1 = np.nan_to_num(theta1 + epsilon)
    theta2 = np.nan_to_num(theta2 - epsilon)
    if n_features > 0 and np.allclose(theta1, theta2, atol=1e-9):
        theta1[0] += 0.001
    return theta1, theta2


def initialize_thetas_nd(X, Z, random_seed=None, ridge_alpha=0):
    """
    Initialize the parameters of two candidate hyperplanes for a split.

    The first feature dimension with a non-zero range is split at its median and a plane is
    fitted to each half. Fallbacks:
      * fewer than ``n_features + 1`` points, or all points identical: two horizontal
        planes at the mean target value;
      * a median split leaving either half too small to fit: a single global fit,
        perturbed in opposite directions;
      * two (nearly) identical half-fits: the same opposite perturbation.

    :param X: Feature matrix of shape (n_points, n_features).
    :param Z: Target vector of shape (n_points,).
    :param random_seed: Seed for the perturbations. A private ``RandomState`` is used, so the
        draws match the legacy ``np.random.seed(random_seed)`` stream but NumPy's global RNG
        is left untouched. If None, the global ``np.random`` generator is used.
    :param ridge_alpha: Ridge strength forwarded to ``solve_ols_nd``.
    :return: Tuple ``(theta1, theta2)`` of two distinct arrays of shape (n_features + 1,).
    """
    n_points, n_features = X.shape
    mean_z = np.mean(Z) if n_points > 0 else 0
    # A hyperplane has n_features + 1 coefficients, so fitting one needs at least that many points
    min_points_to_fit = n_features + 1

    def constant_plane():
        # Horizontal plane at the mean target; a fresh array per call so the two
        # returned planes never alias each other.
        return np.array([0.0] * n_features + [mean_z])

    if n_points < min_points_to_fit:
        return constant_plane(), constant_plane()

    # First feature whose values are not all equal; None if every point coincides.
    split_dim = next(
        (dim for dim in range(n_features) if np.max(X[:, dim]) > np.min(X[:, dim])),
        None,
    )
    if split_dim is None:
        return constant_plane(), constant_plane()

    rng = np.random.RandomState(random_seed) if random_seed is not None else np.random

    median_val = np.median(X[:, split_dim])
    mask_left = X[:, split_dim] < median_val  # ties at the median go right
    mask_right = ~mask_left

    if np.sum(mask_left) < min_points_to_fit or np.sum(mask_right) < min_points_to_fit:
        # Median split too lopsided to fit both halves: perturb a single global fit instead.
        theta_global = solve_ols_nd(X, Z, alpha=ridge_alpha)
        return _perturb_apart(theta_global, theta_global, rng, n_features)

    theta1 = solve_ols_nd(X[mask_left], Z[mask_left], alpha=ridge_alpha)
    theta2 = solve_ols_nd(X[mask_right], Z[mask_right], alpha=ridge_alpha)

    # Nearly identical planes give a near-zero decision function, so separate them
    # before the split iteration starts.
    if np.allclose(theta1, theta2, atol=1e-6):
        theta1, theta2 = _perturb_apart(theta1, theta2, rng, n_features)

    return theta1, theta2


def _run_one_split_direction(X, Z, initial_theta1, initial_theta2, max_iter, tol, min_points_per_segment, min_points_to_fit, step_size, ridge_alpha, flip_split_direction=False):
    """
    Performs iterative optimization for splitting in one direction.

    Each round (a) assigns every point to the left or right segment by the sign of a decision
    function built from the two current planes, then (b) refits one plane per segment with
    `solve_ols_nd` and blends each refit into its current plane with `step_size`. The decision
    function only compares the two planes' predicted values (it does not look at `Z`), so the
    split boundary is where the two planes give equal predictions.

    The loop stops when the assignment is unchanged from the previous round, when both planes
    move less than `tol` (L2 norm), after `max_iter` rounds, or when a refit raises
    `LinAlgError` / yields NaN or Inf (then the planes from before that round are kept).

    :param X: Feature data.
    :param Z: Target data.
    :param initial_theta1: Initial coefficients for the left plane.
    :param initial_theta2: Initial coefficients for the right plane.
    :param max_iter: Maximum number of assign/refit rounds.
    :param tol: Convergence tolerance on the change of each plane's coefficient vector.
    :param min_points_per_segment: Minimum number of points required in each segment.
    :param min_points_to_fit: Minimum number of points needed to fit one plane.
    :param step_size: Blend factor for updates; 1 replaces each plane with its refit.
    :param ridge_alpha: Ridge strength forwarded to `solve_ols_nd`.
    :param flip_split_direction: If True, points with `prediction_diff > 0` are assigned to the left; otherwise, `prediction_diff < 0` are assigned to the left.
        Here `prediction_diff` is X_des @ (theta1 - theta2); the left segment is always the one fitted by theta1.
    :return: (split_coeffs, final_theta1, final_theta2, final_left_mask, final_right_mask, total_rmse) if successful, None otherwise.
        The returned masks are recomputed from the final planes, so they can differ from the assignment used in the
        last refit (e.g. when the loop ends on `tol` or `max_iter` rather than on an unchanged assignment).
    """
    n_points = X.shape[0]
    # Design matrix [X, 1] so a coefficient vector (w..., b) can be applied with one matmul
    X_des = np.hstack([X, np.ones((n_points, 1))])

    theta1_current = initial_theta1.copy()
    theta2_current = initial_theta2.copy()

    prev_left_mask = None

    for i in range(max_iter):  # `i` itself is unused
        theta1_old, theta2_old = theta1_current.copy(), theta2_current.copy()

        # Build the decision values whose sign assigns each point to a segment.
        # `prediction_diff` means (prediction of plane 1) - (prediction of plane 2).
        # If `flip_split_direction` is `False`, points where plane 1 predicts a LOWER value than plane 2 (`prediction_diff < 0`) go left (fitted by theta1).
        # If `flip_split_direction` is `True`, points where plane 1 predicts a HIGHER value (`prediction_diff > 0`) go left (fitted by theta1).
        # Note: this compares predicted values only; it is not a measure of which plane fits the point better.
        if flip_split_direction:
            # Negated decision: negative exactly where prediction_diff > 0
            prediction_diff_for_mask = X_des @ (theta2_current - theta1_current)
        else:
            prediction_diff_for_mask = X_des @ (theta1_current - theta2_current)

        current_left_mask = prediction_diff_for_mask < 0

        # Converged: the assignment is identical to the one the planes were last updated on
        if prev_left_mask is not None and np.array_equal(current_left_mask, prev_left_mask):
            break
        prev_left_mask = current_left_mask

        n_left, n_right = np.sum(current_left_mask), n_points - np.sum(current_left_mask)

        # Check if segments contain enough data points
        if n_left < min_points_per_segment or n_right < min_points_per_segment or \
           n_left < min_points_to_fit or n_right < min_points_to_fit:
            return None # Indicates failure to split in this direction

        try:
            # Refit planes for data in each segment
            # `theta1_current` always fits points corresponding to `current_left_mask`
            theta1_target = solve_ols_nd(X[current_left_mask], Z[current_left_mask], alpha=ridge_alpha)
            # `theta2_current` always fits points corresponding to `~current_left_mask`
            theta2_target = solve_ols_nd(X[~current_left_mask], Z[~current_left_mask], alpha=ridge_alpha)

            # Damped update: convex blend of the old plane and its refit (step_size=1 means full replacement)
            theta1_current = (1 - step_size) * theta1_old + step_size * theta1_target
            theta2_current = (1 - step_size) * theta2_old + step_size * theta2_target
        except np.linalg.LinAlgError:
            theta1_current, theta2_current = theta1_old, theta2_old
            break # OLS solver failed; keep the planes from before this round

        # Check for NaN/Inf; roll back to the previous planes if found
        if np.any(np.isnan(theta1_current)) or np.any(np.isinf(theta1_current)) or \
           np.any(np.isnan(theta2_current)) or np.any(np.isinf(theta2_current)):
            theta1_current, theta2_current = theta1_old, theta2_old
            break

        # Check if parameters have converged
        if np.linalg.norm(theta1_current - theta1_old) < tol and np.linalg.norm(theta2_current - theta2_old) < tol:
            break

    # After iteration, calculate final split parameters and RMSE
    if flip_split_direction:
        final_split_coeffs = theta2_current - theta1_current # Corresponds to the decision \( (theta2-theta1)<0 \)
    else:
        final_split_coeffs = theta1_current - theta2_current # Corresponds to the decision \( (theta1-theta2)<0 \)

    # The final mask determines which points belong to which plane
    # `final_left_mask` are points where `decision_value < 0`
    final_left_mask = (X_des @ final_split_coeffs) < 0
    final_right_mask = ~final_left_mask

    # Re-check the validity of the final split (the recomputed mask may differ from the last one checked)
    if np.sum(final_left_mask) < min_points_per_segment or np.sum(final_right_mask) < min_points_per_segment or \
       np.sum(final_left_mask) < min_points_to_fit or np.sum(final_right_mask) < min_points_to_fit:
        return None

    # Calculate total RMSE of the two-piece model over all points of this node
    preds_on_left = linear_model_nd(X[final_left_mask], theta1_current)
    preds_on_right = linear_model_nd(X[final_right_mask], theta2_current)

    all_predictions = np.zeros_like(Z, dtype=float)
    all_predictions[final_left_mask] = preds_on_left
    all_predictions[final_right_mask] = preds_on_right
    total_rmse = calculate_rmse(Z, all_predictions)

    return (final_split_coeffs, theta1_current, theta2_current, final_left_mask, final_right_mask, total_rmse)


def optimize_split_optimized_nd(X, Z, max_iter=100, tol=1e-6, min_points_per_segment=MIN_POINTS_FOR_SPLIT_BASE,
                                step_size=1,
                                random_seed_for_init=None, ridge_alpha=1e-2):
    """
    Optimizes piecewise plane fitting in N-D space.
    Considers two splitting directions (`prediction_diff < 0` and `prediction_diff > 0`), choosing the better one.
    \(X\) is an array of shape \( (N, n_{features}) \). \(Z\) is an array of shape \( (N,) \).
    Returns the split hyperplane coefficients, parameters for the two segments, and the left/right masks.
    """
    n_points, n_features = X.shape
    min_points_to_fit = n_features + 1

    # Ensure enough data points for effective splitting
    if n_points < 2 * min_points_to_fit or n_points < 2 * min_points_per_segment:
        return None, None, None, None, None

    # Get initial `theta1` and `theta2`
    initial_theta1, initial_theta2 = initialize_thetas_nd(X, Z, random_seed=random_seed_for_init, ridge_alpha=ridge_alpha)

    # Check if initial parameters are valid
    if np.any(np.isnan(initial_theta1)) or np.any(np.isinf(initial_theta1)) or \
            np.any(np.isnan(initial_theta2)) or np.any(np.isinf(initial_theta2)):
        return None, None, None, None, None

    best_split_coeffs = None
    best_theta1 = None
    best_theta2 = None
    best_left_mask = None
    best_right_mask = None
    min_overall_rmse = float('inf')

    # Strategy 1: Original direction (`current_left_mask = prediction_diff < 0`)
    # Here, `theta1` is fitted for points with `prediction_diff < 0`, and `theta2` for points with `prediction_diff >= 0`.
    res1 = _run_one_split_direction(X, Z, initial_theta1, initial_theta2, max_iter, tol,
                                    min_points_per_segment, min_points_to_fit, step_size, ridge_alpha,
                                    flip_split_direction=False)

    if res1 is not None:
        split_coeffs1, theta1_1, theta2_1, mask_left_1, mask_right_1, rmse1 = res1
        if rmse1 < min_overall_rmse:
            min_overall_rmse = rmse1
            best_split_coeffs = split_coeffs1
            best_theta1 = theta1_1 # Plane fitted for `mask_left_1`
            best_theta2 = theta2_1 # Plane fitted for `mask_right_1`
            best_left_mask = mask_left_1
            best_right_mask = mask_right_1

    # Strategy 2: Reverse direction (`current_left_mask = prediction_diff > 0`)
    # In `_run_one_split_direction`, `flip_split_direction=True` means `prediction_diff` is calculated as \( (initial\_theta2 - initial\_theta1) \)
    # Then `current_left_mask = (initial_theta2 - initial_theta1) < 0`
    # Here, the returned `theta1_2` is the plane fitted for points where \( (initial\_theta2 - initial\_theta1) < 0 \)
    # The returned `theta2_2` is the plane fitted for points where \( (initial\_theta2 - initial\_theta1) >= 0 \)
    res2 = _run_one_split_direction(X, Z, initial_theta1, initial_theta2, max_iter, tol,
                                    min_points_per_segment, min_points_to_fit, step_size, ridge_alpha,
                                    flip_split_direction=True)

    if res2 is not None:
        split_coeffs2, theta1_2, theta2_2, mask_left_2, mask_right_2, rmse2 = res2
        if rmse2 < min_overall_rmse:
            min_overall_rmse = rmse2
            best_split_coeffs = split_coeffs2
            best_theta1 = theta1_2 # Plane fitted for `mask_left_2`
            best_theta2 = theta2_2 # Plane fitted for `mask_right_2`
            best_left_mask = mask_left_2
            best_right_mask = mask_right_2

    return best_split_coeffs, best_theta1, best_theta2, best_left_mask, best_right_mask


def recursive_piecewise_fit_nd(X_full, Z_full, current_indices, threshold,
                               min_points_per_segment=MIN_POINTS_FOR_SPLIT_BASE, depth=0, max_depth=5,
                               current_run_random_state=None, step_size_optimize=0.1, ridge_alpha=0):
    """
    Recursively performs piecewise linear fitting on N-D data, returning the root node of the segmented tree.

    At each node a single plane is fitted first. The node becomes a leaf if it is too small, the
    fit is not finite, `max_depth` is reached, or the single-plane RMSE is within `threshold`.
    Otherwise an optimized two-plane split is attempted (up to 3 seeds), falling back to a median
    split on the first varying feature, and both halves are fitted recursively.

    :param X_full: The complete original feature dataset.
    :param Z_full: The complete original target dataset.
    :param current_indices: Indices of data points contained in the current node (a NumPy array: it is
        indexed with boolean masks when recursing).
    :param threshold: RMSE stopping threshold. A negative value disables this pre-pruning check.
    :param min_points_per_segment: Minimum number of data points required per segment.
    :param depth: Current recursion depth.
    :param max_depth: Maximum tree depth.
    :param current_run_random_state: Base random seed for the current run (None is treated as 0).
    :param step_size_optimize: Step size in the split optimization algorithm.
    :param ridge_alpha: Ridge regression regularization parameter.
    :return: The root `Node` of the (sub)tree for `current_indices`; each leaf's `stop_reason` records why
        recursion stopped there.
    """
    X_node_data = X_full[current_indices]
    Z_node_data = Z_full[current_indices]

    n_points = X_node_data.shape[0]
    n_features = X_full.shape[1]
    min_points_to_fit = n_features + 1

    # Determine the region (hyper-rectangle boundary) of the current segment
    current_region = []
    if n_points > 0:
        for dim in range(n_features):
            current_region.append((np.min(X_node_data[:, dim]), np.max(X_node_data[:, dim])))
        mean_z_for_fallback = np.mean(Z_node_data)
    else:
        # Empty segment, return a leaf node with NaN parameters and empty region
        return Node(is_leaf=True, params=np.array([np.nan] * (n_features + 1)),
                    region=[(np.nan, np.nan)] * n_features, stop_reason='empty_segment', data_indices=current_indices)

    if n_points < min_points_to_fit:
        # Not enough points to fit a single hyperplane
        # Use the mean Z value as a fallback "plane" (horizontal plane)
        fallback_params = np.array([0.0] * n_features + [mean_z_for_fallback])
        return Node(is_leaf=True, params=fallback_params, region=current_region,
                    stop_reason='not_enough_points_for_fit', data_indices=current_indices)

    theta_single = fit_single_plane(X_node_data, Z_node_data, ridge_alpha=ridge_alpha)

    # Guard against non-finite parameters from `fit_single_plane`.
    # NOTE(review): an exactly singular system makes `np.linalg.solve` raise `LinAlgError` rather than
    # return NaN, and that exception is not caught here -- confirm whether callers should handle it.
    if np.any(np.isnan(theta_single)) or np.any(np.isinf(theta_single)):
        fallback_params = np.array([0.0] * n_features + [mean_z_for_fallback])
        return Node(is_leaf=True, params=fallback_params, region=current_region,
                    stop_reason='fit_failed_numerically', data_indices=current_indices)

    rmse_single = calculate_rmse(Z_node_data, linear_model_nd(X_node_data, theta_single))

    if depth >= max_depth:
        return Node(is_leaf=True, params=theta_single, region=current_region,
                    stop_reason='max_depth_reached', data_indices=current_indices)

    # Only enable RMSE threshold pre-pruning if `threshold` is non-negative
    if threshold >= 0 and rmse_single <= threshold:
        return Node(is_leaf=True, params=theta_single, region=current_region,
                    stop_reason='RMSE_threshold_met', data_indices=current_indices)

    # Pre-check: Not enough points to perform a meaningful split (requires at least 2 segments, each with at least `min_points_per_segment`
    # and also `min_points_to_fit` for their own plane fitting).
    if n_points < max(2 * min_points_to_fit, 2 * min_points_per_segment):
        return Node(is_leaf=True, params=theta_single, region=current_region,
                    stop_reason='not_enough_points_for_split_pre_check', data_indices=current_indices)

    max_init_attempts = 3
    split_successful = False
    split_coeffs, params_left, params_right, mask_left, mask_right = None, None, None, None, None

    for attempt_num in range(max_init_attempts):
        # Seed = base seed (0 if None) + depth * 1000 + attempt_num.
        # The seed varies with depth and attempt only: sibling branches at the same depth receive
        # identical seeds. NOTE(review): add a branch-specific offset if independent randomness per
        # branch is intended.
        seed_for_attempt = (
                               current_run_random_state if current_run_random_state is not None else 0) + depth * 1000 + attempt_num

        split_coeffs, params_left, params_right, mask_left, mask_right = \
            optimize_split_optimized_nd(X_node_data, Z_node_data,
                                        min_points_per_segment=min_points_per_segment,
                                        step_size=step_size_optimize,
                                        random_seed_for_init=seed_for_attempt,
                                        ridge_alpha=ridge_alpha)

        # Check if split was successful and produced valid segments
        if split_coeffs is not None and params_left is not None and params_right is not None and \
                np.sum(mask_left) >= min_points_per_segment and np.sum(mask_right) >= min_points_per_segment and \
                np.sum(mask_left) >= min_points_to_fit and np.sum(mask_right) >= min_points_to_fit:
            split_successful = True
            break

    # If all optimized split attempts fail, enable fallback "median split" strategy
    if not split_successful:
        # Find any dimension with variance for median splitting (the first one found is used)
        split_dim_found = False
        fallback_split_dim = -1
        for dim in range(n_features):
            if np.max(X_node_data[:, dim]) > np.min(X_node_data[:, dim]):  # Check if this dimension has variance
                fallback_split_dim = dim
                split_dim_found = True
                break

        # If all dimensions are zero (e.g., all points are identical), splitting is not possible
        if not split_dim_found:
            return Node(is_leaf=True, params=theta_single, region=current_region,
                        stop_reason='fallback_split_failed_no_varying_dimension', data_indices=current_indices)

        split_dim = fallback_split_dim
        median_val = np.median(X_node_data[:, split_dim])

        # Create split masks (points equal to the median go right)
        mask_left = X_node_data[:, split_dim] < median_val
        mask_right = ~mask_left

        # Define split hyperplane coefficients \[ (A_0 x_0 + \dots + A_{n-1} x_{n-1} + B = 0) \]
        # For a split at `median_val` in dimension `split_dim`:
        # \[ x_{\text{split_dim}} - \text{median_val} = 0 \]
        # Thus, the coefficient for \( x_{\text{split_dim}} \) is \( 1.0 \), and the bias term is \( -\text{median_val} \).
        split_coeffs = np.zeros(n_features + 1)
        split_coeffs[split_dim] = 1.0
        split_coeffs[-1] = -median_val

        # Check if fallback split is valid (whether it produced two non-empty children, and if they have enough points to eventually fit a hyperplane)
        if np.sum(mask_left) < min_points_per_segment or np.sum(mask_right) < min_points_per_segment or \
                np.sum(mask_left) < min_points_to_fit or np.sum(mask_right) < min_points_to_fit:
            # This is a final failure case, in which the current node becomes a leaf node
            return Node(is_leaf=True, params=theta_single, region=current_region,
                        stop_reason='fallback_split_failed_insufficient_points', data_indices=current_indices)

    # Recursive call, passing the full `X_full`, `Z_full`, and subset indices.
    # Note: params_left/params_right from the optimizer are not stored; each child refits its own plane.
    left_child = recursive_piecewise_fit_nd(X_full, Z_full, current_indices[mask_left], threshold,
                                            min_points_per_segment, depth + 1,
                                            max_depth, current_run_random_state, step_size_optimize, ridge_alpha)
    right_child = recursive_piecewise_fit_nd(X_full, Z_full, current_indices[mask_right], threshold,
                                             min_points_per_segment, depth + 1,
                                             max_depth, current_run_random_state, step_size_optimize, ridge_alpha)

    # Return an internal node
    return Node(is_leaf=False, split_coeffs=split_coeffs, children=[left_child, right_child], region=current_region,
                data_indices=current_indices)


def _predict_single_point_nd(x_val_arr, node):
    """
    Recursively predicts a single N-D data point.
    :param x_val_arr: Array for a single point \( [x_0, x_1, \dots, x_{n-1}, 1] \).
    :param node: The current tree node being traversed.
    :return: The predicted z value.
    """
    # If it's a leaf node, use its plane parameters directly for prediction
    if node.is_leaf:
        coeffs = node.params
        # Handle cases where parameters might be NaN (e.g., due to 'empty_segment' or 'fit_failed')
        if coeffs is None or np.any(np.isnan(coeffs)) or np.any(np.isinf(coeffs)):
            # Fallback for invalid leaf parameters: return 0.0 or mean value, here simplified to 0.0
            return 0.0
        # `linear_model_nd` expects input of shape \( (N, n_{features}) \), so remove the trailing 1 and reshape
        return linear_model_nd(x_val_arr[:-1].reshape(1, -1), coeffs)

    # If it's an internal node, decide which subtree to enter based on the splitting rule
    # `split_coeffs` is \( (A_0, \dots, A_{n-1}, B) \) for \( A_0x_0 + \dots + A_{n-1}x_{n-1} + B = 0 \)
    # `x_val_arr` is \( [x_0, \dots, x_{n-1}, 1] \), so a direct dot product with `split_coeffs` is sufficient
    decision_value = x_val_arr @ node.split_coeffs

    # If `decision_value < 0`, traverse the left child (corresponds to `prediction_diff < 0` during training)
    if decision_value < -1e-9:  # Add a small tolerance
        return _predict_single_point_nd(x_val_arr, node.children[0])
    else:  # Otherwise, traverse the right child
        return _predict_single_point_nd(x_val_arr, node.children[1])


def predict_nd(X_new, root_node):
    r"""
    Predict targets for new N-D data points with a trained piecewise-linear tree.

    :param X_new: Array-like of shape \( (N, n_{features}) \); anything ``np.asarray`` accepts
        (an ndarray, a list of rows, a DataFrame) works.
    :param root_node: Root ``Node`` produced by ``recursive_piecewise_fit_nd``, or None.
    :return: Float array of shape (N,). All zeros when ``root_node`` is None (e.g. training failed).
    """
    X_new = np.asarray(X_new)
    n_samples = X_new.shape[0]
    predictions = np.zeros(n_samples, dtype=float)

    if root_node is None:
        return predictions

    # Append the bias column once so each row reads [x_0, ..., x_{n-1}, 1]
    X_new_des = np.hstack([X_new, np.ones((n_samples, 1))])

    for i, x_val_des in enumerate(X_new_des):
        # np.ravel(...)[0] accepts both a shape-(1,) array and a bare scalar, so a leaf that
        # falls back to a plain number cannot break the loop.
        predictions[i] = np.ravel(_predict_single_point_nd(x_val_des, root_node))[0]

    return predictions


def collect_leaf_stop_reasons(node, stop_reason_counts):
    """
    Tally the ``stop_reason`` of every leaf under ``node`` into ``stop_reason_counts``.

    The tree is walked depth-first, left child before right child, so new reasons are
    inserted into the dictionary in left-to-right leaf order. A leaf whose ``stop_reason``
    is None is counted under ``'unknown_reason'``. The dictionary is updated in place.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current is None:
            continue
        if current.is_leaf:
            key = 'unknown_reason' if current.stop_reason is None else current.stop_reason
            stop_reason_counts[key] = stop_reason_counts.get(key, 0) + 1
        elif current.children:
            # Push right first so the left subtree is processed first
            pending.append(current.children[1])
            pending.append(current.children[0])


# -------------------------------------------------------------------
# 4. Main Experiment Flow
# -------------------------------------------------------------------

# --- 4.1 Data Splitting ---
TEST_SIZE = 0.5
RANDOM_STATE = 42

# Hold out TEST_SIZE of the raw feature table (with matching targets) as the test set;
# RANDOM_STATE makes the split reproducible.
X_train_full, X_test, Z_train_full, Z_test = train_test_split(
    X_data_raw,
    y_data,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

# --- Define Data Preprocessing Pipeline ---
# ColumnTransformer routes each column group to its own transformer:
#   'num' -> StandardScaler (zero mean, unit variance)
#   'cat' -> OneHotEncoder, ignoring categories unseen during fit
# For kin8nm the categorical list is empty, so only the numerical branch does work.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numerical_features),
        ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
    ]
)


# --- Define OurMethodRegressor Wrapper for GridSearchCV ---
class OurMethodRegressor(RegressorMixin, BaseEstimator):
    """
    scikit-learn compatible wrapper around ``recursive_piecewise_fit_nd`` / ``predict_nd``.

    Hyper-parameters are stored verbatim in ``__init__``. This lets the inherited
    ``BaseEstimator.get_params`` / ``set_params`` (and therefore ``clone`` and
    ``GridSearchCV``) discover them by their ``__init__`` argument names. Unlike
    a hand-written ``set_params``, the inherited one raises ``ValueError`` for an
    unknown parameter name instead of silently creating a new attribute.

    The mixin is listed before ``BaseEstimator``, as scikit-learn requires, so
    the regressor tags and ``score`` resolve correctly.

    Parameters
    ----------
    threshold : float
        Forwarded as ``threshold`` to ``recursive_piecewise_fit_nd``.
    min_points : int
        Forwarded as ``min_points_per_segment``.
    max_depth : int
        Forwarded as ``max_depth``.
    step_size : float
        Forwarded as ``step_size_optimize``.
    ridge_alpha : float
        Forwarded as ``ridge_alpha``.
    random_state : int or None
        Forwarded as ``current_run_random_state`` (the base seed).

    Attributes
    ----------
    root_node : tree node or None
        Tree returned by ``recursive_piecewise_fit_nd``. It is ``None`` before
        ``fit``, or if the fit returned no tree.
    """

    def __init__(self, threshold=0.03, min_points=MIN_POINTS_FOR_SPLIT_BASE, max_depth=5, step_size=0.1,
                 ridge_alpha=0, random_state=None):
        self.threshold = threshold
        self.min_points = min_points  # This is `min_points_per_segment`
        self.max_depth = max_depth
        self.step_size = step_size
        self.ridge_alpha = ridge_alpha
        self.random_state = random_state
        self.root_node = None  # Stores the trained tree (set by `fit`)

    def fit(self, X, y):
        """
        Fit the piecewise tree on ``X`` / ``y`` and store it in ``root_node``.

        Returns ``self`` so the estimator can be chained (scikit-learn convention).
        """
        # Inside a Pipeline, X is already a NumPy array, so `np.asarray` returns
        # the same object. It only matters when the estimator is used on its own
        # with DataFrames or lists.
        X = np.asarray(X)
        y = np.asarray(y)
        # Pass the complete X, y, the initial indices, and `self.random_state` as
        # the base seed to the recursive fit.
        self.root_node = recursive_piecewise_fit_nd(
            X, y, np.arange(len(y)),
            threshold=self.threshold,
            min_points_per_segment=self.min_points,
            max_depth=self.max_depth,
            current_run_random_state=self.random_state,
            step_size_optimize=self.step_size,
            ridge_alpha=self.ridge_alpha
        )
        return self

    def predict(self, X):
        """
        Predict targets for ``X`` with the fitted tree.

        Returns an all-zero vector of length ``len(X)`` when no tree is available
        (before ``fit``, or if fitting produced no tree).
        """
        X = np.asarray(X)
        if self.root_node is None:
            # If fitting fails (e.g., not enough points), fall back to zeros.
            return np.zeros(X.shape[0])
        return predict_nd(X, self.root_node)

    # `get_params` / `set_params` are inherited from BaseEstimator. They derive the
    # parameter names from `__init__`'s signature, which is what `clone` and
    # GridSearchCV rely on.


# --- 4.2 Define all Models and their Hyperparameter Grids (Including Our Method) ---
# Each entry maps a display name to an unfitted estimator and a grid over that
# estimator's own constructor parameters (without any Pipeline prefix).
models_to_tune = dict()

models_to_tune["Our Method (Pre-Pruning)"] = {
    "estimator": OurMethodRegressor(),
    "params": dict(
        threshold=[0, 0.5],  # Adjust according to the RMSE range of the actual dataset
        max_depth=[4, 6, 8],
        step_size=[0.01, 0.5, 1],
        ridge_alpha=[0, 0.1, 1, 10],
    ),
}

models_to_tune["CART"] = {
    "estimator": DecisionTreeRegressor(random_state=RANDOM_STATE),
    "params": dict(
        max_depth=[3, 5, 7, 9, 11],
        min_samples_split=[2, 5, 10],
        min_samples_leaf=[1, 2, 4],
    ),
}

models_to_tune["XGBoost"] = {
    "estimator": XGBRegressor(random_state=RANDOM_STATE, objective='reg:squarederror'),
    "params": dict(
        n_estimators=[50, 100, 200],
        learning_rate=[0.01, 0.05, 0.1],
        max_depth=[3, 5, 7],
        subsample=[0.7, 1.0],
    ),
}

# --- 4.3 Perform Grid Search and Train Best Models ---
# Accumulators, filled once per entry of `models_to_tune`:
#   results                     -> one row of hold-out metrics per model
#   trained_models              -> the refit best Pipeline per model
#   best_params_for_repetitions -> tuned estimator params, 'regressor__' prefix removed
results = []
trained_models = {}
best_params_for_repetitions = {}

print("\n--- Starting Hyperparameter Tuning for All Models (This may take some time) ---")
for name, model_info in models_to_tune.items():
    tic = time.time()
    print(f"Tuning {name}...")

    # Preprocessing and estimator live in one Pipeline. Grid keys therefore carry
    # the step-name prefix that Pipeline expects (e.g. 'regressor__max_depth').
    search = GridSearchCV(
        Pipeline([
            ('preprocessor', preprocessor),
            ('regressor', model_info["estimator"]),
        ]),
        {'regressor__' + key: grid for key, grid in model_info["params"].items()},
        cv=5,
        scoring='neg_mean_squared_error',
        n_jobs=-1,
    )
    search.fit(X_train_full, Z_train_full)

    best_pipeline = search.best_estimator_
    trained_models[name] = best_pipeline
    best_params_for_repetitions[name] = {
        key.replace('regressor__', ''): value
        for key, value in search.best_params_.items()
    }

    # Score the refit best pipeline on the hold-out split.
    z_pred = best_pipeline.predict(X_test)

    # Tree diagnostics only exist for our method; other models report "N/A".
    num_segments, stop_reasons_str = "N/A", "N/A"
    if "Our Method" in name:
        tree_root = best_pipeline.named_steps['regressor'].root_node
        if tree_root is None:
            num_segments, stop_reasons_str = 0, "Model training failed, no tree."
        else:
            num_segments = tree_root.count_leaves()
            leaf_reason_counts = {}
            collect_leaf_stop_reasons(tree_root, leaf_reason_counts)
            n_leaves = sum(leaf_reason_counts.values())
            if n_leaves == 0:
                stop_reasons_str = "No leaves found."
            else:
                stop_reasons_str = str({
                    reason: f"{count / n_leaves * 100:.1f}%"
                    for reason, count in leaf_reason_counts.items()
                })

    results.append({
        "Model": name,
        "RMSE": np.sqrt(mean_squared_error(Z_test, z_pred)),
        "MAE": mean_absolute_error(Z_test, z_pred),
        "R²": r2_score(Z_test, z_pred),
        "Best Params": str(best_params_for_repetitions[name]),
        "Segments (K)": num_segments,
        "Stop Reasons (Initial)": stop_reasons_str,
    })

    elapsed = time.time() - tic
    print(
        f"Tuning for {name} completed in {elapsed:.2f} seconds. "
        f"Best cross-validation score (negative mean squared error): {search.best_score_:.4f}"
    )
    print(f"Best parameters found: {best_params_for_repetitions[name]}\n")

results_df_initial = pd.DataFrame(results)

# -------------------------------------------------------------------
# 5. Results Display (First Grid Search Results)
# -------------------------------------------------------------------

# --- 5.1 Print Performance Comparison Table (First Grid Search Results) ---
print("\n" + "=" * 80)
print(" " * 15 + "Algorithm Performance Comparison (Initial Tuning Results)")
print("=" * 80)
with pd.option_context('display.max_colwidth', None, 'display.width', 1000):
    # Render genuinely missing cells as "N/A" via `na_rep`. Post-processing the
    # rendered text instead would also rewrite any cell that merely contains the
    # substring "nan".
    print(results_df_initial.to_string(index=False, na_rep='N/A'))
print("=" * 80)

# -------------------------------------------------------------------
# 6. Repeat Training and Evaluation N Times, and Calculate Mean and Standard Deviation
# -------------------------------------------------------------------
# `clone` gives each repetition fresh, unfitted copies of the configured
# estimator and of the shared preprocessor.
from sklearn.base import clone

N_REPETITIONS = 5
all_run_results = {name: {'RMSE': [], 'MAE': [], 'R2': [], 'Segments': []}
                   for name in models_to_tune.keys()}
# Add stop reason list for Our Method
for name in models_to_tune.keys():
    if "Our Method" in name:
        all_run_results[name]["StopReasons"] = []

print(f"\n--- Starting {N_REPETITIONS} Repetitions with Final Tuned Parameters ---")

for i in range(N_REPETITIONS):
    print(f"\n--- Repetition {i + 1}/{N_REPETITIONS} ---")

    current_random_state = RANDOM_STATE + i
    # For each repetition, perform new data splitting using different random seeds (X_data_raw and y_data are original data)
    X_train_full_i, X_test_i, Z_train_full_i, Z_test_i = train_test_split(X_data_raw, y_data, test_size=TEST_SIZE,
                                                                          random_state=current_random_state)

    # --- Train and Evaluate Each Model (Using Best Parameters from Grid Search) ---
    for name, model_info in models_to_tune.items():
        # Start from a clone of the configured estimator, so constructor settings
        # that were not part of the grid are kept (e.g. XGBoost's `objective`).
        # Then apply the tuned hyper-parameters on top. `best_params` only
        # contains grid keys, so rebuilding from `class(**best_params)` would
        # lose those settings.
        regressor = clone(model_info["estimator"]).set_params(**best_params_for_repetitions[name])
        # Re-seed every model that exposes `random_state` (CART, XGBoost, our method)
        # with this repetition's seed.
        if "random_state" in regressor.get_params(deep=False):
            regressor.set_params(random_state=current_random_state)

        # Create a new complete pipeline instance for each repetition. The
        # preprocessor is cloned too, so the global `preprocessor` is never refit
        # in place.
        current_model_instance = Pipeline([
            ('preprocessor', clone(preprocessor)),
            ('regressor', regressor)
        ])

        current_model_instance.fit(X_train_full_i, Z_train_full_i)
        z_pred_i = current_model_instance.predict(X_test_i)

        run_metrics = all_run_results[name]
        run_metrics["RMSE"].append(np.sqrt(mean_squared_error(Z_test_i, z_pred_i)))
        run_metrics["MAE"].append(mean_absolute_error(Z_test_i, z_pred_i))
        run_metrics["R2"].append(r2_score(Z_test_i, z_pred_i))

        if "Our Method" in name:
            our_regressor_instance = current_model_instance.named_steps['regressor']
            if our_regressor_instance.root_node is not None:
                run_metrics["Segments"].append(our_regressor_instance.root_node.count_leaves())
                # Collect stop reasons for this run
                stop_reason_counts = {}
                collect_leaf_stop_reasons(our_regressor_instance.root_node, stop_reason_counts)
                run_metrics["StopReasons"].append(stop_reason_counts)
            else:
                run_metrics["Segments"].append(0)
                run_metrics["StopReasons"].append({})

# --- Calculate Mean and Standard Deviation ---
def _mean_pm_std(values, decimals):
    """Format ``values`` as ``"mean ± std"`` with ``decimals`` digits (np.std, i.e. ddof=0)."""
    return f"{np.mean(values):.{decimals}f} ± {np.std(values):.{decimals}f}"


# Column labels use a plain "±". The previous LaTeX-style text ("\( \pm \)",
# "\[ ... \]") relied on invalid string escape sequences (SyntaxWarning on
# Python 3.12+) and printed raw LaTeX to the console.
final_summary_results = []
for model_name, metrics_dict in all_run_results.items():
    row = {
        "Model": model_name,
        "RMSE (Mean ± Std)": _mean_pm_std(metrics_dict["RMSE"], 3),
        "MAE (Mean ± Std)": _mean_pm_std(metrics_dict["MAE"], 3),
        "R² (Mean ± Std)": _mean_pm_std(metrics_dict["R2"], 3),
    }
    if "Our Method" in model_name:
        row["Segments (K) (Mean ± Std)"] = _mean_pm_std(metrics_dict["Segments"], 1)

        # Aggregate stop reasons over all repetitions
        total_stop_reasons = {}
        for reason_counts_dict in metrics_dict["StopReasons"]:
            for reason, count in reason_counts_dict.items():
                total_stop_reasons[reason] = total_stop_reasons.get(reason, 0) + count

        # Convert to percentage of all leaves across runs
        total_leaves_across_runs = sum(total_stop_reasons.values())
        if total_leaves_across_runs > 0:
            reasons_percentages = {
                reason: f"{count / total_leaves_across_runs * 100:.1f}%"
                for reason, count in sorted(total_stop_reasons.items())
            }
            row["Stop Reasons Distribution (%)"] = str(reasons_percentages)
        else:
            row["Stop Reasons Distribution (%)"] = "No leaves found across runs."
    else:
        row["Segments (K) (Mean ± Std)"] = "N/A"
        row["Stop Reasons Distribution (%)"] = "N/A"
    final_summary_results.append(row)

final_summary_df = pd.DataFrame(final_summary_results)

# -------------------------------------------------------------------
# 7. Final Results Display (Mean ± Standard Deviation from N Repetitions)
# -------------------------------------------------------------------

# The header uses a plain "±". The former "\( \pm \)" text contained invalid
# string escape sequences and printed raw LaTeX.
print("\n" + "=" * 80)
print(" " * 10 + f"Algorithm Performance Comparison (Mean ± Standard Deviation from {N_REPETITIONS} Repetitions)")
print("=" * 80)
with pd.option_context('display.max_colwidth', None, 'display.width', 1000):
    print(final_summary_df.to_string(index=False))
print("=" * 80)
