import numpy as np
import matplotlib.pyplot as plt


def true_func_sinc(x):
    """Return ``-sinc(5 * x)``, a target with both concave and convex regions.

    This is the "unstable" test case of the experiment.

    Args:
        x: Scalar or array of evaluation points.

    Returns:
        The negated value of ``np.sinc`` evaluated at ``5 * x``.
    """
    compressed = x * 5  # squeeze the curve horizontally by a factor of five
    return -np.sinc(compressed)


def true_func_twisted_sigmoid(x):
    """Return a steep logistic curve, doubled and tilted by a linear term.

    This is the well-behaved ("stable") test case of the experiment.

    Args:
        x: Scalar or array of evaluation points.

    Returns:
        ``2 * sigmoid(3 * x) - 0.8 * x`` evaluated element-wise.
    """
    logistic = 1 / (1 + np.exp(-x * 3))
    # The linear term bends the otherwise monotone sigmoid into a more interesting shape.
    linear_tilt = x * 0.8
    return 2 * logistic - linear_tilt


# Modified generate_data function to accept random_seed
def generate_data(func_name='sinc', n_samples=200, noise=0.0, random_seed=None):
    """
    Generates experimental data on an evenly spaced grid.

    The grid is [-1.5, 1.5] for 'sinc' and [-3, 3] for 'twisted_sigmoid'.

    Args:
        func_name (str): 'sinc' or 'twisted_sigmoid'.
        n_samples (int): Number of sample points.
        noise (float): Standard deviation of the additive Gaussian noise;
            no noise is added when it is not positive.
        random_seed (int, optional): Random seed for noise generation. When
            given (and noise > 0) it reseeds NumPy's global random state.

    Returns:
        tuple: ``(X_augmented, y)`` where ``X_augmented`` has shape
        ``(n_samples, 2)`` with columns ``[x, 1]`` and ``y`` holds the
        (possibly noisy) function values.

    Raises:
        ValueError: If ``func_name`` is not one of the supported names.
    """
    print(f"Generating data for function: '{func_name}'")
    if func_name == 'sinc':
        true_func = true_func_sinc
        x = np.linspace(-1.5, 1.5, n_samples)
    elif func_name == 'twisted_sigmoid':
        true_func = true_func_twisted_sigmoid
        x = np.linspace(-3, 3, n_samples)
    else:
        # The message must list the options that are actually accepted above.
        raise ValueError(
            f"func_name must be 'sinc' or 'twisted_sigmoid', got {func_name!r}"
        )

    y = true_func(x)
    if noise > 0:
        if random_seed is not None:
            np.random.seed(random_seed)  # Set seed before generating noise
        y += np.random.normal(0, noise, size=x.shape)

    # Append a constant column so the intercept is part of the parameter vector.
    X_augmented = np.vstack([x, np.ones(n_samples)]).T
    return X_augmented, y


# --- 2. Core Algorithm ---
def solve_ols(X, y):
    """Return the least-squares coefficients for ``X @ theta ~= y``.

    The solution is computed through the pseudo-inverse of the Gram matrix
    ``X.T @ X``.

    Args:
        X: 2-D design matrix, one row per sample.
        y: Target values, one per row of ``X``.

    Returns:
        Coefficient vector of length ``X.shape[1]``. A zero vector is
        returned when there are fewer rows than columns or when the
        pseudo-inverse computation raises ``LinAlgError``.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    if n_rows >= n_cols:
        try:
            gram_pinv = np.linalg.pinv(X.T @ X)
            return gram_pinv @ X.T @ y
        except np.linalg.LinAlgError:
            # Fall through to the zero-vector result below.
            pass
    return np.zeros(n_cols)


def calculate_model_losses(X, y, theta1, theta2):
    """Return half the squared error of the upper and lower two-line envelopes.

    Args:
        X: Design matrix, one row per sample.
        y: Target values.
        theta1: Parameters of the first line.
        theta2: Parameters of the second line.

    Returns:
        tuple: ``(loss_upper, loss_lower)`` where the upper envelope is the
        pointwise maximum of the two line predictions and the lower envelope
        is their pointwise minimum.
    """
    line_a = X @ theta1
    line_b = X @ theta2

    upper_residual = y - np.maximum(line_a, line_b)
    lower_residual = y - np.minimum(line_a, line_b)

    loss_upper = 0.5 * np.sum(np.square(upper_residual))
    loss_lower = 0.5 * np.sum(np.square(lower_residual))
    return loss_upper, loss_lower


def initialize_thetas_random_boundary(X, y, random_seed=None):
    """Initializes two lines by connecting points on data boundaries and a random inner point.

    A split abscissa is drawn uniformly between the smallest and largest x
    value, its ordinate is obtained by linear interpolation of the data, and
    each line is defined by joining one boundary sample to that split point.

    Args:
        X: Design matrix whose first column holds the x coordinates and whose
            second column is the constant 1. Rows may be in any order.
        y: Target values, one per row of ``X``.
        random_seed (int, optional): When given, reseeds NumPy's global
            random state before drawing the split position.

    Returns:
        tuple: ``((theta1, theta2), x_split)`` where each theta is
        ``[slope, intercept]``. ``theta1`` passes through the leftmost sample
        and the split point, ``theta2`` through the rightmost sample and the
        split point. If the split coincides with a boundary point, that line
        falls back to a horizontal line at the boundary's y value.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    x_coords = X[:, 0]
    y_values = np.asarray(y)
    min_idx, max_idx = np.argmin(x_coords), np.argmax(x_coords)
    p_min = (x_coords[min_idx], y_values[min_idx])
    p_max = (x_coords[max_idx], y_values[max_idx])

    x_split = np.random.uniform(p_min[0], p_max[0])
    # np.interp does not verify that its sample abscissas are increasing and
    # returns nonsense otherwise, so interpolate on an explicitly sorted view.
    order = np.argsort(x_coords, kind='stable')
    y_split = np.interp(x_split, x_coords[order], y_values[order])
    p_split = (x_split, y_split)
    print(f"Initialized (Boundary Anchored): splitting at x={x_split:.2f}")

    # Each line solves [x, 1] @ theta = y for its two defining points.
    A1 = np.array([[p_min[0], 1], [p_split[0], 1]])
    B1 = np.array([p_min[1], p_split[1]])
    theta1 = np.linalg.solve(A1, B1) if np.linalg.det(A1) != 0 else np.array([0, p_min[1]])

    A2 = np.array([[p_max[0], 1], [p_split[0], 1]])
    B2 = np.array([p_max[1], p_split[1]])
    theta2 = np.linalg.solve(A2, B2) if np.linalg.det(A2) != 0 else np.array([0, p_max[1]])

    return (theta1, theta2), x_split


def find_optimal_split_newton(X, y, initial_thetas, max_iter=50, step_size=1.0, tol=1e-6):
    """
    Iteratively finds the optimal split.
    If partitions become too small (a sign of non-convergence), it falls back to a single linear model.

    Each iteration assigns every sample to whichever line currently predicts
    the larger value, fits an OLS line to each of the two groups, and moves
    the current parameters toward those fits by ``step_size``.

    Args:
        X: Design matrix of shape ``(n_samples, n_features)``.
        y: Target values, one per row of ``X``.
        initial_thetas: Pair ``(theta1, theta2)`` of starting parameter
            arrays; they are copied, not modified.
        max_iter (int): Maximum number of iterations.
        step_size (float): Blend factor between the current parameters (0)
            and the fresh per-partition OLS fits (1).
        tol (float): Loss-change threshold used only to report convergence.

    Returns:
        tuple: ``(theta1, theta2, w, loss_history, model_type,
        is_single_line_fallback)`` where ``w = theta1 - theta2``,
        ``loss_history`` holds the smaller of the two envelope losses at the
        start of every executed iteration, and ``model_type`` is ``'max'`` or
        ``'min'`` depending on which envelope has the lower final loss. When
        ``is_single_line_fallback`` is True, ``theta1`` and ``theta2`` are
        both the OLS fit on all data.
    """
    theta1, theta2 = initial_thetas[0].copy(), initial_thetas[1].copy()
    n_samples, n_features = X.shape
    loss_history = []
    old_mask1 = np.zeros(n_samples, dtype=bool)
    is_single_line_fallback = False

    for i in range(max_iter):
        loss_upper, loss_lower = calculate_model_losses(X, y, theta1, theta2)
        loss_history.append(min(loss_upper, loss_lower))

        pred1, pred2 = X @ theta1, X @ theta2
        mask1 = pred1 >= pred2

        if i > 0 and (np.array_equal(mask1, old_mask1) or abs(loss_history[-1] - loss_history[-2]) < tol):
            print(f"Converged at iteration {i + 1}.")
            # Deliberately no break: iteration continues so the full loss
            # curve can be observed.

        old_mask1 = mask1

        if np.sum(mask1) < n_features or np.sum(~mask1) < n_features:
            print(f"Partition became too small at iteration {i + 1}. "
                  f"Step size {step_size} is likely too large. "
                  f"Falling back to a single OLS line for all data.")

            theta_global = solve_ols(X, y)
            theta1 = theta_global
            theta2 = theta_global.copy()
            is_single_line_fallback = True
            # Stop here: running the partition update below would overwrite
            # the fallback with a fit on the under-populated partition (for
            # which solve_ols returns zeros), so the returned parameters
            # would no longer describe the single line.
            break

        X1, y1 = X[mask1], y[mask1]
        X2, y2 = X[~mask1], y[~mask1]
        theta1_ols, theta2_ols = solve_ols(X1, y1), solve_ols(X2, y2)

        # Damped update: step_size=1 jumps straight to the per-partition fits.
        theta1 = (1 - step_size) * theta1 + step_size * theta1_ols
        theta2 = (1 - step_size) * theta2 + step_size * theta2_ols

    loss_upper_final, loss_lower_final = calculate_model_losses(X, y, theta1, theta2)
    model_type = 'max' if loss_upper_final <= loss_lower_final else 'min'
    w = theta1 - theta2

    return theta1, theta2, w, loss_history, model_type, is_single_line_fallback


# --- 3. Experiment Execution and Visualization ---

# ==========================================================
# ---              Configure your experiment here        ---
# ==========================================================
# Target function to fit: 'sinc' or 'twisted_sigmoid'.
FUNCTION_TO_TEST = 'sinc'
# Noise level passed to generate_data (e.g. 0.0, 0.05, 0.1).
NOISE_LEVEL = 0.025
# Number of sample points.
N_SAMPLES = 1000
# ==========================================================

print("--- Running Experiments ---")
# One figure is produced per seed; each seed gives a different random start point.
initial_seeds = [60]
step_sizes_to_test = [1.0, 0.5, 0.1, 0.05]

# Per-step-size plot styling, keyed by step size.
# Colors: blue, orange, green, red.
colors = {1.0: '#1f77b4', 0.5: '#ff7f0e', 0.1: '#2ca02c', 0.05: '#d62728'}
# Line styles: solid, dashed, solid, dash-dot.
styles = {1.0: '-', 0.5: '--', 0.1: '-', 0.05: '-.'}
# Line widths in points.
linewidths = {1.0: 2.5, 0.5: 2.5, 0.1: 2.5, 0.05: 2.0}


# Run one experiment per seed: generate data, initialize the two lines, run the
# solver for every configured step size, and draw the loss curves (left panel)
# and fitted models (right panel) into a single figure saved as PDF.
for i, seed in enumerate(initial_seeds):
    run_title = f"Analysis from Start {i + 1} (seed={seed}) on '{FUNCTION_TO_TEST}' data"
    print(f"\n--- {run_title} ---")

    # The same seed is used for the noise draw and for the initial split below.
    X, y = generate_data(func_name=FUNCTION_TO_TEST, noise=NOISE_LEVEL, n_samples=N_SAMPLES, random_seed=seed)

    fig, (ax_loss, ax_fit) = plt.subplots(1, 2, figsize=(8, 5.5))

    initial_thetas, init_split_point = initialize_thetas_random_boundary(X, y, random_seed=seed)

    ax_fit.scatter(X[:, 0], y, s=10, alpha=0.2, color='gray', label='Data')
    ax_fit.axvline(init_split_point, color='purple', linestyle=':', lw=3,
                   label="Initial Split")

    initial_loss = min(calculate_model_losses(X, y, initial_thetas[0], initial_thetas[1]))
    ax_loss.plot(0, initial_loss, '*', markersize=15, color='red', markeredgecolor='black',
                 label='Start Point', zorder=10)

    # Dense evaluation grid for drawing the fitted models; it depends only on
    # the data, so it is built once per figure rather than once per step size.
    x_range = np.linspace(X[:, 0].min(), X[:, 0].max(), 200)
    X_range_aug = np.vstack([x_range, np.ones_like(x_range)]).T

    for step_size in step_sizes_to_test:
        if step_size not in colors:
            continue
        print(f"--- Testing Step Size: {step_size} ---")

        theta1, theta2, w, losses, model_type, is_single_line = find_optimal_split_newton(
            X, y, initial_thetas, step_size=step_size, max_iter=75
        )

        ax_loss.plot(losses, 'o-', label=f'Step={step_size}', color=colors[step_size],
                     markersize=4, lw=linewidths[step_size])

        if is_single_line:
            final_preds = X_range_aug @ theta1
            label_text = f'Step={step_size}, single line'
            ax_fit.plot(x_range, final_preds,
                        linestyle=styles[step_size],
                        linewidth=linewidths[step_size],
                        color=colors[step_size],
                        label=label_text)
        elif abs(w[0]) > 1e-9:
            # The two lines cross where (theta1 - theta2) @ [x, 1] = 0.
            # When the slopes are (numerically) equal there is no crossing
            # and nothing is drawn for this step size.
            final_split = -w[1] / w[0]
            pred1, pred2 = X_range_aug @ theta1, X_range_aug @ theta2
            final_preds = np.maximum(pred1, pred2) if model_type == 'max' else np.minimum(pred1, pred2)

            label_text = f'Step={step_size}'
            ax_fit.plot(x_range, final_preds,
                        linestyle=styles[step_size],
                        linewidth=linewidths[step_size],
                        color=colors[step_size],
                        label=label_text)
            ax_fit.axvline(final_split, color=colors[step_size], linestyle='--', linewidth=linewidths[step_size])

    ax_loss.set_xlabel('Iteration')
    ax_loss.set_ylabel('Loss')

    ax_fit.set_xlabel('x')
    ax_fit.set_ylabel('y')
    ax_fit.set_ylim(y.min() - 0.2, y.max() + 0.2)

    # --- Logic for combining and sorting legends ---
    # Both panels label their curves 'Step=<value>'; keep one entry per step
    # size, preferring the ', single line' variant when one exists.
    handles_loss, labels_loss = ax_loss.get_legend_handles_labels()
    handles_fit, labels_fit = ax_fit.get_legend_handles_labels()

    all_legend_items = {}

    for handle, text in list(zip(handles_loss, labels_loss)) + list(zip(handles_fit, labels_fit)):
        if text.startswith('Step='):
            base_step_label = text.replace(', single line', '')
            step_val = float(base_step_label.replace('Step=', ''))

            if base_step_label not in all_legend_items:
                all_legend_items[base_step_label] = (handle, text, step_val)
            else:
                current_stored_label = all_legend_items[base_step_label][1]
                if 'single line' in text and 'single line' not in current_stored_label:
                    all_legend_items[base_step_label] = (handle, text, step_val)
        else:
            all_legend_items[text] = (handle, text, None)

    final_handles = []
    final_labels = []

    # Fixed entries first, in this order.
    specific_order_labels = ['Data', 'Initial Split', 'Start Point']
    for label in specific_order_labels:
        if label in all_legend_items:
            final_handles.append(all_legend_items[label][0])
            final_labels.append(all_legend_items[label][1])

    # Then the step-size entries, largest step first.
    step_items_for_sorting = []
    for base_label, (handle, full_label, step_val) in all_legend_items.items():
        if base_label not in specific_order_labels and step_val is not None:
            step_items_for_sorting.append((step_val, handle, full_label))

    step_items_for_sorting.sort(key=lambda item: item[0], reverse=True)

    for _, handle, full_label in step_items_for_sorting:
        final_handles.append(handle)
        final_labels.append(full_label)

    # Place the combined legend at the top center of the entire figure.
    # rect=[left, bottom, right, top]: the axes are confined below top=0.853
    # so the strip above them stays free for the legend.
    plt.tight_layout(rect=[0, 0.05, 1, 0.853])

    fig.legend(final_handles, final_labels, loc='upper center',
               bbox_to_anchor=(0.5, 0.99),  # Anchor the center of the legend near the top of the figure
               ncol=len(final_labels),      # Arrange all legend items in a single row
               fontsize='small')

    # With several seeds every run would otherwise write to the same file and
    # overwrite the previous figure; the single-seed filename is unchanged.
    seed_suffix = f"_seed{seed}" if len(initial_seeds) > 1 else ""
    pdf_filename = f"analysis_{FUNCTION_TO_TEST}{seed_suffix}.pdf"
    plt.savefig(pdf_filename, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
