//! Refinement module for updating tree statistics after splits
//!
//! This module contains the refinement strategies that update precomputed statistics
//! after each split/resplit/merge operation. The actual tree updates are handled by
//! the reducer in `grid/reducer.rs`.

use std::iter::once;

use crate::grid::state::{AffectedRange, FittingState};
use crate::grid::two_tensor_solver::{solve_two_tensor, DEFAULT_V_MAX, DEFAULT_V_MIN};

/// Refinement strategy together with its hyper-parameters.
///
/// Both variants carry the same regularisation, two-tensor coupling, anchoring and clamping
/// parameters; `HuberRefinement` additionally carries the threshold `c`.
pub enum RefinementStrategy {
    /// Refinement variant without a robustness threshold.
    /// NOTE(review): presumably paired with `l2_weighting` (constant weight 1) — confirm
    /// against the `weight` method, which is not visible here.
    L2Refinement {
        /// Regularisation strength.
        /// NOTE(review): presumably the value returned by `alpha()` and added to the update
        /// denominator as in `l2_update` — confirm.
        alpha: f64,
        /// Two-tensor L2 coupling between u_+ and u_-.
        tilt_tau: f64,
        /// Two-tensor L1 coupling on (u_+ - u_-).
        tilt_rho: f64,
        /// Prior sample size for parent anchoring (tau_0).
        /// Interpreted as "how many samples worth of confidence in the parent".
        prior_sample_size: f64,
        /// Clamp applied to updates.
        /// NOTE(review): presumably the log-space bound `c` of `clamp_multiplier_signed` — confirm.
        update_clamp: f64,
    },
    /// Refinement variant carrying a threshold `c`.
    /// NOTE(review): presumably paired with `huber_weighting` — confirm against `weight`.
    HuberRefinement {
        /// Regularisation strength (same role as in `L2Refinement`).
        alpha: f64,
        /// Threshold parameter.
        /// NOTE(review): presumably the residual magnitude beyond which `huber_weighting`
        /// down-weights a point — confirm.
        c: f64,
        /// Two-tensor L2 coupling between u_+ and u_-.
        tilt_tau: f64,
        /// Two-tensor L1 coupling on (u_+ - u_-).
        tilt_rho: f64,
        /// Prior sample size for parent anchoring (tau_0).
        prior_sample_size: f64,
        /// Clamp applied to updates (same role as in `L2Refinement`).
        update_clamp: f64,
    },
}

impl RefinementStrategy {
    #[inline]
    pub fn tilt_tau(&self) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { tilt_tau, .. }
            | RefinementStrategy::HuberRefinement { tilt_tau, .. } => *tilt_tau,
        }
    }

    #[inline]
    pub fn tilt_rho(&self) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { tilt_rho, .. }
            | RefinementStrategy::HuberRefinement { tilt_rho, .. } => *tilt_rho,
        }
    }
}

/// Sums the per-element values over the half-open range `start..end`, given their
/// inclusive prefix sums (`prefix[i]` = sum of elements `0..=i`).
///
/// An empty range (`start == end`) yields `0.0`, including `0..0` on an empty slice.
///
/// # Panics
/// Panics if `end > prefix.len()`, or if the range is inverted with `end == 0`.
#[inline]
pub fn prefix_range(prefix: &[f64], start: usize, end: usize) -> f64 {
    // The empty sum is zero; without this guard `end - 1` underflows for the range 0..0.
    if start == end {
        return 0.0;
    }
    let upto_end = prefix[end - 1];
    match start {
        0 => upto_end,
        // Subtract everything accumulated strictly before `start`.
        _ => upto_end - prefix[start - 1],
    }
}

/// Returns a conservative, inclusive range `(lo, hi)` of indices into the sorted `boundaries`
/// slice covering the boundaries whose intervals intersect the positions `lo_pos..=hi_pos`.
///
/// # Panics
/// Panics if `boundaries` is empty or if `lo_pos > hi_pos`.
#[inline]
fn boundary_index_range(boundaries: &[usize], lo_pos: usize, hi_pos: usize) -> (usize, usize) {
    assert!(
        !boundaries.is_empty(),
        "boundary_index_range: called with empty boundaries"
    );
    assert!(
        lo_pos <= hi_pos,
        "boundary_index_range: lower position ({}) is greater than higher position ({})!",
        lo_pos,
        hi_pos
    );

    let last_index = boundaries.len() - 1;
    // Step one boundary back from the first boundary at or after `lo_pos`.
    let first_at_or_after_lo = boundaries.partition_point(|&b| b < lo_pos);
    // First boundary strictly after `hi_pos`, clamped to a valid index.
    let first_after_hi = boundaries.partition_point(|&b| b <= hi_pos);

    (
        first_at_or_after_lo.saturating_sub(1),
        first_after_hi.min(last_index),
    )
}

/// Computes a penalised refinement step from the statistics `s_rb` and `s_bb` over `n` points.
///
/// Returns `(multiplier, error_term)` where
/// `multiplier = s_rb / (s_bb + n * alpha) + 1` and
/// `error_term = s_rb^2 * (s_bb + 2 * n * alpha) / (s_bb + n * alpha)^2`.
/// No guard is applied against a zero denominator.
#[inline]
pub fn refinement_update(alpha: f64, s_rb: f64, s_bb: f64, n: f64) -> (f64, f64) {
    let penalised = s_bb + n * alpha;
    let doubly_penalised = s_bb + 2.0 * n * alpha;

    let multiplier = s_rb / penalised + 1.0;
    let error_term = s_rb.powi(2) * doubly_penalised / penalised.powi(2);
    (multiplier, error_term)
}

/// Ridge-style update without parent anchoring: `u = s_rb / (s_bb + alpha)`.
///
/// Returns `(u, denominator, numerator)`. When the denominator is not strictly positive
/// (including NaN) the update `u` is `0.0`.
#[inline]
pub fn l2_update_unanchored(alpha: f64, s_rb: f64, s_bb: f64) -> (f64, f64, f64) {
    let denominator = s_bb + alpha;
    let step = (denominator > 0.0)
        .then(|| s_rb / denominator)
        .unwrap_or(0.0);
    (step, denominator, s_rb)
}

/// Ridge-style update anchored to a parent value `u_p` with strength `tau`:
/// `u = (tau * u_p + s_rb) / (s_bb + tau + alpha)`.
///
/// Returns `(u, denominator, numerator)`. When the denominator is not strictly positive
/// (including NaN) the update `u` is `0.0`.
#[inline]
pub fn l2_update(alpha: f64, tau: f64, u_p: f64, s_rb: f64, s_bb: f64) -> (f64, f64, f64) {
    // Fused multiply-add keeps the anchored numerator to a single rounding.
    let anchored = tau.mul_add(u_p, s_rb);
    let total = s_bb + tau + alpha;
    if total > 0.0 {
        (anchored / total, total, anchored)
    } else {
        (0.0, total, anchored)
    }
}

/// Unpenalised gain of applying `update`: `2 * update * s_rb - s_bb * update^2`.
#[inline]
pub fn l2_gain_raw(update: f64, s_rb: f64, s_bb: f64) -> f64 {
    let linear_term = 2.0 * update * s_rb;
    let quadratic_term = s_bb * update * update;
    linear_term - quadratic_term
}

/// Evaluates merging two adjacent intervals into one shared multiplier.
///
/// `v_left` / `v_right` are the current multipliers of the two sides; their statistics are
/// rescaled by `1 / v` (and `1 / v^2`) before being pooled. The penalty `alpha` is scaled by
/// the combined point count `n_left + n_right`.
///
/// Returns `(merged_update, err_reduction_merge)`, where the second component is
/// `current_err - merged_err` on the unpenalised (r^2) scale.
#[inline]
pub fn unpenalized_merge(
    alpha: f64,
    s_rb_left: f64,
    s_bb_left: f64,
    s_rb_right: f64,
    s_bb_right: f64,
    v_left: f64,
    v_right: f64,
    n_left: f64,
    n_right: f64,
) -> (f64, f64) {
    let left_scale = v_left.recip();
    let right_scale = v_right.recip();
    let left_scale_sq = left_scale * left_scale;
    let right_scale_sq = right_scale * right_scale;

    // Error of keeping both sides as they are.
    let pooled_rb = s_rb_left + s_rb_right;
    let pooled_bb = s_bb_left + s_bb_right;
    let error_now = -2.0 * pooled_rb - pooled_bb;

    // Linear and quadratic coefficients of the merged objective.
    let linear = (s_rb_left + s_bb_left) * left_scale + (s_rb_right + s_bb_right) * right_scale;
    let quadratic = s_bb_left * left_scale_sq + s_bb_right * right_scale_sq;

    let penalty = alpha * (n_left + n_right);
    let merged_update = (linear + penalty) / (quadratic + penalty);
    let error_merged = merged_update * merged_update * quadratic - 2.0 * merged_update * linear;

    (merged_update, error_now - error_merged)
}

/// Weight used for an L2 loss: every residual counts fully, so the argument is ignored.
#[inline]
pub fn l2_weighting(_residual: f64) -> f64 {
    1.0
}

/// Huber weight for residual `res` with threshold `c`: `1` while `|res| <= c`,
/// otherwise `c / |res|`.
#[inline]
pub fn huber_weighting(res: f64, c: f64) -> f64 {
    let magnitude = res.abs();
    // Kept as `<=` so that a NaN magnitude falls through to the division, as before.
    if magnitude <= c {
        return 1.0;
    }
    c / magnitude
}

/// Restricts the magnitude of a multiplier to `[exp(-c), exp(c)]` while keeping its sign.
///
/// * Non-finite `c` (infinite or NaN): `m_raw` is returned untouched.
/// * `m_raw == 0`: the lower magnitude bound `exp(-c)` is returned (positive).
/// * Otherwise the magnitude is clamped in log-space and the original sign is re-applied.
///
/// Intended for finite `c > 0`.
#[inline]
fn clamp_multiplier_signed(m_raw: f64, c: f64) -> f64 {
    if c.is_finite() {
        if m_raw == 0.0 {
            (-c).exp()
        } else {
            let log_magnitude = m_raw.abs().ln().clamp(-c, c);
            m_raw.signum() * log_magnitude.exp()
        }
    } else {
        m_raw
    }
}

impl RefinementStrategy {
    /// Builds every precomputed statistic for a fresh fit and returns the initialised state.
    ///
    /// Steps, in order:
    /// 1. Detect Stage 1 (positive-only) mode — no negative incoming residual — and seed
    ///    `lambda_plus` / `lambda_minus` accordingly.
    /// 2. Reset the per-point caches (`f_plus`, `f_minus`, `f`, `y_hat`, `residuals`, `r_tilde`).
    /// 3. For each column: sort the points, record the sort order, build the five prefix-sum
    ///    arrays, create the single interval covering all points and NaN-mark the split caches
    ///    at positions whose value repeats the previous one.
    /// 4. Compute split error reductions for all columns and flag the statistics as initialised.
    ///
    /// # Panics
    /// Panics if two values of a column of `state.x` cannot be ordered (`partial_cmp` yields
    /// `None`, e.g. for NaN).
    pub fn initialize<'a>(&'a self, mut state: FittingState<'a>) -> FittingState<'a> {
        // Floor keeping the initial lambdas (and their normaliser) strictly positive.
        const EPS_LAMBDA: f64 = 1e-10;
        // Slots of the five sufficient statistics inside the local arrays below.
        const S11: usize = 0;
        const S22: usize = 1;
        const S12: usize = 2;
        const T1: usize = 3;
        const T2: usize = 4;

        // Stage 1 positive-only mode applies when no residual is negative.
        // Spec: AI_CONTEXT/30_algorithms.md §1.4
        let positive_only = state.residuals.iter().all(|&r| r >= 0.0);

        // Seed the two-tensor lambdas from the current outer residuals (labels - stage prediction).
        // Spec: AI_CONTEXT/30_algorithms.md §1.2 (full) or §1.4.2 (Stage 1)
        if positive_only {
            // λ_+ = 1 is the canonical starting point; λ_- = 0 because there is no negative part.
            state.lambda_plus = 1.0;
            state.lambda_minus = 0.0;

            // Tilts should already be zero from construction; reset them explicitly anyway.
            state
                .tilt_values
                .iter_mut()
                .flat_map(|tilts| tilts.iter_mut())
                .for_each(|tilt| *tilt = 0.0);
        } else {
            // Weighted mean of the positive and of the negative residual parts.
            let mut weight_total = 0.0;
            let mut positive_mass = 0.0;
            let mut negative_mass = 0.0;
            for i in 0..state.n {
                let w = self.weight(state.residuals[i]);
                let residual = state.residuals[i];
                weight_total += w;
                if residual > 0.0 {
                    positive_mass += w * residual;
                } else {
                    negative_mass += w * (-residual);
                }
            }
            let normaliser = weight_total.max(EPS_LAMBDA);
            state.lambda_plus = (positive_mass / normaliser).max(EPS_LAMBDA);
            state.lambda_minus = (negative_mass / normaliser).max(EPS_LAMBDA);
        }

        // Per-point caches for (b = 1, d = 0):
        //   Stage 1:   f_+ = λ_+, f_- = 0,   f = f_+
        //   full mode: f_+ = λ_+, f_- = λ_-, f = f_+ - f_-
        state.f_plus.fill(state.lambda_plus);
        state.f_minus.fill(state.lambda_minus);
        let combined = &state.f_plus - &state.f_minus;
        state.f.assign(&combined);
        state.y_hat.assign(&state.f);
        let fresh_residuals = state.labels.to_owned() - state.y_hat.view();
        state.residuals.assign(&fresh_residuals);
        state.r_tilde.assign(&state.residuals);

        for col in 0..state.p {
            // Point indices ordered by their value in this column.
            let mut order = (0..state.n).collect::<Vec<_>>();
            order.sort_by(|&a, &b| {
                let (xa, xb) = (state.x[[a, col]], state.x[[b, col]]);
                xa.partial_cmp(&xb).unwrap()
            });

            // Inverse permutation: point index -> sorted position.
            for (rank, &point) in order.iter().enumerate() {
                state.precomputed_statistics.sort_order[col][point] = rank;
            }

            // Five per-point sufficient statistics and their running (prefix) sums:
            //   s11 = w * f_plus²          s22 = w * f_minus²       s12 = w * f_plus * f_minus
            //   t1  = w * r_tilde * f_plus t2  = w * r_tilde * f_minus
            let n_points = order.len();
            let mut totals = [0.0_f64; 5];
            let mut prefixes: [Vec<f64>; 5] = [
                Vec::with_capacity(n_points),
                Vec::with_capacity(n_points),
                Vec::with_capacity(n_points),
                Vec::with_capacity(n_points),
                Vec::with_capacity(n_points),
            ];

            for &point in &order {
                let fp = state.f_plus[point];
                let fm = state.f_minus[point];
                let rt = state.r_tilde[point];
                let w = self.weight(state.residuals[point]);

                // In Stage 1 every statistic involving f_minus is forced to zero.
                let contributions = [
                    w * fp * fp,
                    if positive_only { 0.0 } else { w * fm * fm },
                    if positive_only { 0.0 } else { w * fp * fm },
                    w * rt * fp,
                    if positive_only { 0.0 } else { w * rt * fm },
                ];

                // Per-point values do not depend on the column; write them on the first pass only.
                if col == 0 {
                    state.precomputed_statistics.c_s11[point] = contributions[S11];
                    state.precomputed_statistics.c_s22[point] = contributions[S22];
                    state.precomputed_statistics.c_s12[point] = contributions[S12];
                    state.precomputed_statistics.c_t1[point] = contributions[T1];
                    state.precomputed_statistics.c_t2[point] = contributions[T2];
                }

                for (slot, &value) in contributions.iter().enumerate() {
                    totals[slot] += value;
                    prefixes[slot].push(totals[slot]);
                }
            }

            // Prefix sums and the sorted order must be in place before error reductions run.
            let [s11, s22, s12, t1, t2] = prefixes;
            state.precomputed_statistics.prefix_sums_s11[col] = s11;
            state.precomputed_statistics.prefix_sums_s22[col] = s22;
            state.precomputed_statistics.prefix_sums_s12[col] = s12;
            state.precomputed_statistics.prefix_sums_t1[col] = t1;
            state.precomputed_statistics.prefix_sums_t2[col] = t2;
            state.precomputed_statistics.sorted_indices[col] = order;

            // A single interval covering every point. The S12 and t2 sums are stored
            // un-negated; consumers negate them (S12 = -sum(c_s12), t2 = -sum(c_t2)).
            state.precomputed_statistics.interval_stats[col] =
                vec![crate::grid::state::IntervalStats {
                    sum_s11: totals[S11],
                    sum_s22: totals[S22],
                    sum_s12: totals[S12],
                    sum_t1: totals[T1],
                    sum_t2: totals[T2],
                    n: state.n,
                }];

            // NaN-mark the split caches at every sorted position whose value repeats the
            // value immediately before it.
            let mut previous: Option<f64> = None;
            for (pos, &idx) in state.precomputed_statistics.sorted_indices[col]
                .iter()
                .enumerate()
            {
                let value = state.x[[idx, col]];
                let repeated =
                    matches!(previous, Some(prev) if (value - prev).abs() < f64::EPSILON);
                previous = Some(value);
                if !repeated {
                    continue;
                }
                state.precomputed_statistics.error_reductions_split[col][pos] = f64::NAN;
                state.precomputed_statistics.update_pairs_split_left[col][pos] =
                    (f64::NAN, f64::NAN);
                state.precomputed_statistics.update_pairs_split_right[col][pos] =
                    (f64::NAN, f64::NAN);
                state.precomputed_statistics.error_reductions_split_pairs[col][pos] =
                    (f64::NAN, f64::NAN);
            }
        }

        // Error reductions are computed only once every column has been initialised.
        let mut state = self.update_error_reductions_split_for_all_cols(state);
        state.precomputed_statistics.initialized = true;
        state
    }

    /// Refreshes the per-point statistics and every column's prefix sums after the points at
    /// sorted positions `start..end` of column `col` have changed.
    ///
    /// * `col` – column whose sorted order defines the updated points.
    /// * `start`, `end` – half-open range of sorted positions in `col`; an empty range returns
    ///   `state` untouched.
    /// * `affected_points_range` – entry `c` is the inclusive `(min_pos, max_pos)` span of
    ///   sorted positions in column `c` used for the delta update (the entry for `col` itself
    ///   is ignored). NOTE(review): each span is assumed to contain the position of every
    ///   updated point in that column; otherwise the offset computation panics.
    ///
    /// Only the `c_*` per-point values and the `prefix_sums_*` arrays are modified here.
    pub fn update_statistics<'a>(
        &self,
        mut state: FittingState<'a>,
        col: usize,
        start: usize,
        end: usize,
        affected_points_range: &[(usize, usize)],
    ) -> FittingState<'a> {
        // Slots of the five sufficient statistics inside the local arrays below.
        const S11: usize = 0;
        const S22: usize = 1;
        const S12: usize = 2;
        const T1: usize = 3;
        const T2: usize = 4;

        /// Adds `delta` to every entry at positions `from..`; does nothing when `from` lies
        /// beyond the end of `prefix`.
        fn shift_tail(prefix: &mut [f64], from: usize, delta: f64) {
            if let Some(tail) = prefix.get_mut(from..) {
                tail.iter_mut().for_each(|v| *v += delta);
            }
        }

        /// Scatters per-point `deltas` to their `offsets` inside a window of `span` positions
        /// starting at `min_pos`, folds their running sum into `prefix` over that window and
        /// shifts everything after the window by the total.
        fn scatter_into_prefix(
            prefix: &mut [f64],
            bucket: &mut Vec<f64>,
            offsets: &[usize],
            deltas: &[f64],
            min_pos: usize,
            span: usize,
        ) {
            bucket.clear();
            bucket.resize(span, 0.0);
            for (&off, &d) in offsets.iter().zip(deltas) {
                bucket[off] += d;
            }

            let mut running = 0.0;
            for (slot, &d) in prefix[min_pos..min_pos + span].iter_mut().zip(bucket.iter()) {
                running += d;
                *slot += running;
            }
            shift_tail(prefix, min_pos + span, running);
        }

        /// Rewrites `prefix[start..start + values.len()]` as running sums of `values` continuing
        /// from `prefix[start - 1]` (or zero), then shifts the tail by the resulting change.
        /// `values` must be non-empty.
        fn rebuild_segment(prefix: &mut [f64], start: usize, values: &[f64]) {
            let last = start + values.len() - 1;
            let carried = if start > 0 { prefix[start - 1] } else { 0.0 };
            let old_last = prefix[last];

            let mut running = carried;
            for (k, &v) in values.iter().enumerate() {
                running += v;
                prefix[start + k] = running;
            }

            let diff = prefix[last] - old_last;
            shift_tail(prefix, last + 1, diff);
        }

        let updated_points = &state.precomputed_statistics.sorted_indices[col][start..end];
        if updated_points.is_empty() {
            return state;
        }

        // 1) New per-point statistics and their deltas against the stored c_* values.
        //    Computing the deltas once avoids re-reading the old values for every column.
        let count = updated_points.len();
        let mut fresh: [Vec<f64>; 5] = [
            Vec::with_capacity(count),
            Vec::with_capacity(count),
            Vec::with_capacity(count),
            Vec::with_capacity(count),
            Vec::with_capacity(count),
        ];
        let mut delta: [Vec<f64>; 5] = [
            Vec::with_capacity(count),
            Vec::with_capacity(count),
            Vec::with_capacity(count),
            Vec::with_capacity(count),
            Vec::with_capacity(count),
        ];
        for &pt in updated_points {
            let fp = state.f_plus[pt];
            let fm = state.f_minus[pt];
            let rt = state.r_tilde[pt];
            let w = self.weight(state.residuals[pt]);

            let current = [w * fp * fp, w * fm * fm, w * fp * fm, w * rt * fp, w * rt * fm];
            let previous = [
                state.precomputed_statistics.c_s11[pt],
                state.precomputed_statistics.c_s22[pt],
                state.precomputed_statistics.c_s12[pt],
                state.precomputed_statistics.c_t1[pt],
                state.precomputed_statistics.c_t2[pt],
            ];
            for (slot, (&now, &before)) in current.iter().zip(previous.iter()).enumerate() {
                fresh[slot].push(now);
                delta[slot].push(now - before);
            }
        }

        // 2) Every other column: the updated points are scattered across its sorted order,
        //    so apply the deltas inside the affected window and shift the tail.
        let mut offsets: Vec<usize> = Vec::with_capacity(count);
        let mut bucket: Vec<f64> = Vec::new();
        for (c, &(min_pos, max_pos)) in affected_points_range.iter().enumerate() {
            if c == col {
                continue;
            }
            let span = max_pos - min_pos + 1;

            // Offset of each updated point inside this column's window.
            let sort_order = &state.precomputed_statistics.sort_order[c];
            offsets.clear();
            offsets.extend(updated_points.iter().map(|&pt| sort_order[pt] - min_pos));

            scatter_into_prefix(
                &mut state.precomputed_statistics.prefix_sums_s11[c],
                &mut bucket,
                &offsets,
                &delta[S11],
                min_pos,
                span,
            );
            scatter_into_prefix(
                &mut state.precomputed_statistics.prefix_sums_s22[c],
                &mut bucket,
                &offsets,
                &delta[S22],
                min_pos,
                span,
            );
            scatter_into_prefix(
                &mut state.precomputed_statistics.prefix_sums_s12[c],
                &mut bucket,
                &offsets,
                &delta[S12],
                min_pos,
                span,
            );
            scatter_into_prefix(
                &mut state.precomputed_statistics.prefix_sums_t1[c],
                &mut bucket,
                &offsets,
                &delta[T1],
                min_pos,
                span,
            );
            scatter_into_prefix(
                &mut state.precomputed_statistics.prefix_sums_t2[c],
                &mut bucket,
                &offsets,
                &delta[T2],
                min_pos,
                span,
            );
        }

        // 3) The split column itself: the updated points are contiguous there, so rebuild the
        //    running sums over [start..end) from the new values and shift the tail.
        rebuild_segment(
            &mut state.precomputed_statistics.prefix_sums_s11[col],
            start,
            &fresh[S11],
        );
        rebuild_segment(
            &mut state.precomputed_statistics.prefix_sums_s22[col],
            start,
            &fresh[S22],
        );
        rebuild_segment(
            &mut state.precomputed_statistics.prefix_sums_s12[col],
            start,
            &fresh[S12],
        );
        rebuild_segment(
            &mut state.precomputed_statistics.prefix_sums_t1[col],
            start,
            &fresh[T1],
        );
        rebuild_segment(
            &mut state.precomputed_statistics.prefix_sums_t2[col],
            start,
            &fresh[T2],
        );

        // 4) Persist the new per-point values so the next update can diff against them.
        for (j, &pt) in updated_points.iter().enumerate() {
            state.precomputed_statistics.c_s11[pt] = fresh[S11][j];
            state.precomputed_statistics.c_s22[pt] = fresh[S22][j];
            state.precomputed_statistics.c_s12[pt] = fresh[S12][j];
            state.precomputed_statistics.c_t1[pt] = fresh[T1][j];
            state.precomputed_statistics.c_t2[pt] = fresh[T2][j];
        }

        state
    }

    /// Recomputes the cached error reductions for every affected column range.
    ///
    /// For each `AffectedRange` the split error reductions over its point range are refreshed.
    /// If the column has at least one boundary, the interval statistics and the resplit caches
    /// around the touched boundaries are refreshed too, and the merge caches as well when
    /// `split_strategy_state.merge_enabled` is set.
    pub fn refresh_error_reduction_caches_for_affected<'a>(
        &self,
        mut state: FittingState<'a>,
        affected_ranges: &[AffectedRange],
    ) -> FittingState<'a> {
        for range in affected_ranges {
            let col = range.col;
            let (point_lo, point_hi) = range.point_range;

            // Split candidates depend only on the point range.
            self.update_error_reductions_split_for_col_range(&mut state, col, point_lo, point_hi);

            // Without boundaries there are no interval, resplit or merge caches to refresh.
            if state.boundaries[col].len() == 0 {
                continue;
            }

            // `interval_range` in AffectedRange refers to allowed_intervals, so the boundary
            // interval indices are derived from the point range instead.
            let (interval_lo, interval_hi) =
                state.compute_boundary_index_range(col, point_lo, point_hi);
            Self::update_interval_stats_for_col_range(&mut state, col, interval_lo, interval_hi);

            let n_boundaries = state.boundaries[col].len();
            if n_boundaries == 0 {
                continue;
            }

            // Intervals [lo, hi] touch boundaries [lo - 1, hi], clamped to the valid indices.
            let boundary_lo = interval_lo.saturating_sub(1);
            let boundary_hi = interval_hi.min(n_boundaries - 1);

            self.update_error_reductions_resplit_for_col_range(
                &mut state,
                col,
                boundary_lo,
                boundary_hi,
            );
            if state.split_strategy_state.merge_enabled {
                self.update_error_reductions_merge_for_col_range(
                    &mut state,
                    col,
                    boundary_lo,
                    boundary_hi,
                );
            }
        }

        state
    }

    /// Rebuilds the cached `IntervalStats` of column `col` from its prefix sums.
    ///
    /// Intervals `blo..=bhi + 1` are refreshed (a boundary touches the interval on each of its
    /// sides), clamped to the number of intervals currently stored. Interval `k` spans the
    /// sorted positions from boundary `k - 1` (or 0) up to boundary `k` (or the row count).
    fn update_interval_stats_for_col_range(
        state: &mut FittingState,
        col: usize,
        blo: usize,
        bhi: usize,
    ) {
        let n_rows = state.x.nrows();
        let first = blo;
        let past_last = (bhi + 2).min(state.precomputed_statistics.interval_stats[col].len());

        for interval_idx in first..past_last {
            let boundaries = &state.boundaries[col];
            let start = match interval_idx {
                0 => 0,
                k => boundaries[k - 1],
            };
            let end = boundaries.get(interval_idx).copied().unwrap_or(n_rows);
            let interval_n = end - start;

            let stats = &state.precomputed_statistics;
            let refreshed = crate::grid::state::IntervalStats {
                sum_s11: prefix_range(&stats.prefix_sums_s11[col], start, end),
                sum_s22: prefix_range(&stats.prefix_sums_s22[col], start, end),
                sum_s12: prefix_range(&stats.prefix_sums_s12[col], start, end),
                sum_t1: prefix_range(&stats.prefix_sums_t1[col], start, end),
                sum_t2: prefix_range(&stats.prefix_sums_t2[col], start, end),
                n: interval_n,
            };
            state.precomputed_statistics.interval_stats[col][interval_idx] = refreshed;
        }
    }

    /// Recomputes the split error reductions of every column over its full point range.
    fn update_error_reductions_split_for_all_cols<'a>(
        &self,
        mut state: FittingState<'a>,
    ) -> FittingState<'a> {
        let n_rows = state.x.nrows();
        let n_cols = state.boundaries.len();
        (0..n_cols).for_each(|col| {
            self.update_error_reductions_split_for_col_range(&mut state, col, 0, n_rows)
        });
        state
    }

    /// Recomputes the split error reductions of column `col` for the intervals reached by the
    /// sorted-position range `lo..hi`.
    ///
    /// Intervals are walked left to right. Those ending at or before `lo` are skipped, and the
    /// walk stops as soon as an interval ends beyond `hi`.
    fn update_error_reductions_split_for_col_range(
        &self,
        state: &mut FittingState,
        col: usize,
        lo: usize,
        hi: usize,
    ) {
        let n_rows = state.x.nrows();
        let positive_only = state.is_stage1_positive_only();

        // Interval end points: every boundary, followed by the total row count.
        let interval_ends = state.boundaries[col].iter().copied().chain(once(n_rows));

        let mut interval_start = 0usize;
        for interval_end in interval_ends {
            if interval_end > lo {
                let stats = &mut state.precomputed_statistics;
                self.update_error_reductions_split_single_interval(
                    &stats.prefix_sums_s11[col],
                    &stats.prefix_sums_s22[col],
                    &stats.prefix_sums_s12[col],
                    &stats.prefix_sums_t1[col],
                    &stats.prefix_sums_t2[col],
                    &mut stats.update_pairs_split_left[col],
                    &mut stats.update_pairs_split_right[col],
                    &mut stats.error_reductions_split[col],
                    &mut stats.error_reductions_split_pairs[col],
                    (interval_start, interval_end),
                    positive_only,
                );
            }
            interval_start = interval_end;
            if interval_start > hi {
                break;
            }
        }
    }

    fn update_error_reductions_resplit_for_col_range(
        &self,
        state: &mut FittingState,
        col: usize,
        lo_boundary_idx: usize,
        hi_boundary_idx: usize,
    ) {
        let boundary_pos = &state.boundaries[col];
        if boundary_pos.is_empty() {
            return;
        }

        // Ensure all resplit caches are sized to the current number of boundaries
        let target_len = boundary_pos.len();
        if state.precomputed_statistics.error_reductions_resplit[col].len() < target_len {
            state.precomputed_statistics.error_reductions_resplit[col].resize(target_len, f64::NAN);
        }
        if state.precomputed_statistics.update_pairs_resplit_left[col].len() < target_len {
            state.precomputed_statistics.update_pairs_resplit_left[col]
                .resize(target_len, (f64::NAN, f64::NAN));
        }
        if state.precomputed_statistics.update_pairs_resplit_right[col].len() < target_len {
            state.precomputed_statistics.update_pairs_resplit_right[col]
                .resize(target_len, (f64::NAN, f64::NAN));
        }
        if state.precomputed_statistics.error_reductions_resplit_pairs[col].len() < target_len {
            state.precomputed_statistics.error_reductions_resplit_pairs[col]
                .resize(target_len, (f64::NAN, f64::NAN));
        }

        let is_stage1 = state.is_stage1_positive_only();
        let alpha = self.alpha();
        let v_min = DEFAULT_V_MIN;
        let v_max = DEFAULT_V_MAX;

        for i in lo_boundary_idx..=hi_boundary_idx {
            // Use interval stats directly (O(1) lookup instead of prefix_range computation)
            let left_stats = &state.precomputed_statistics.interval_stats[col][i];
            let right_stats = &state.precomputed_statistics.interval_stats[col][i + 1];

            if is_stage1 {
                // Stage 1 positive-only: use 1D ridge solver
                // H^L = S_{11}^L, g^L = t_1^L
                let h_l = left_stats.sum_s11;
                let g_l = left_stats.sum_t1;
                let (u_l, _denom_l, _num_l) = l2_update_unanchored(alpha, g_l, h_l);
                let v_b_l = (1.0 + u_l).clamp(v_min, v_max);
                let u_l_clamped = v_b_l - 1.0;
                let gain_l = l2_gain_raw(u_l_clamped, g_l, h_l);

                // H^R = S_{11}^R, g^R = t_1^R
                let h_r = right_stats.sum_s11;
                let g_r = right_stats.sum_t1;
                let (u_r, _denom_r, _num_r) = l2_update_unanchored(alpha, g_r, h_r);
                let v_b_r = (1.0 + u_r).clamp(v_min, v_max);
                let u_r_clamped = v_b_r - 1.0;
                let gain_r = l2_gain_raw(u_r_clamped, g_r, h_r);

                // Store updates: (u_plus, u_minus) = (u, 0) for Stage 1
                state.precomputed_statistics.error_reductions_resplit[col][i] = gain_l + gain_r;
                state.precomputed_statistics.update_pairs_resplit_left[col][i] = (u_l_clamped, 0.0);
                state.precomputed_statistics.update_pairs_resplit_right[col][i] =
                    (u_r_clamped, 0.0);
                state.precomputed_statistics.error_reductions_resplit_pairs[col][i] =
                    (gain_l, gain_r);
            } else {
                // Full two-tensor: use 2×2 solver
                let tau = self.tilt_tau();
                let rho = self.tilt_rho();

                // Solve for left side
                let (u_plus_l, u_minus_l, gain_l) = solve_two_tensor(
                    left_stats.sum_s11,
                    left_stats.sum_s22,
                    -left_stats.sum_s12,
                    left_stats.sum_t1,
                    -left_stats.sum_t2,
                    alpha,
                    tau,
                    rho,
                    v_min,
                    v_max,
                );

                // Solve for right side
                let (u_plus_r, u_minus_r, gain_r) = solve_two_tensor(
                    right_stats.sum_s11,
                    right_stats.sum_s22,
                    -right_stats.sum_s12,
                    right_stats.sum_t1,
                    -right_stats.sum_t2,
                    alpha,
                    tau,
                    rho,
                    v_min,
                    v_max,
                );

                // Store updates as (u_plus, u_minus) pairs for each side
                // Note: We'll convert to (v_b, delta_d) when applying the split
                state.precomputed_statistics.error_reductions_resplit[col][i] = gain_l + gain_r;
                state.precomputed_statistics.update_pairs_resplit_left[col][i] =
                    (u_plus_l, u_minus_l);
                state.precomputed_statistics.update_pairs_resplit_right[col][i] =
                    (u_plus_r, u_minus_r);
                state.precomputed_statistics.error_reductions_resplit_pairs[col][i] =
                    (gain_l, gain_r);
            }
        }
    }

    /// Refreshes the cached merge gain and merged update pair for a range of boundaries.
    ///
    /// For each boundary index `i` in `lo_boundary_idx..=hi_boundary_idx` of column `col`:
    /// 1. statistics for the interval on each side of the boundary are rebuilt from the
    ///    sorted point indices via `compute_stats_using_partial_products`;
    /// 2. the union statistics are formed with `IntervalStats::union`;
    /// 3. left, right and union are each solved (clamped 1D ridge in Stage 1, otherwise
    ///    `solve_two_tensor`);
    /// 4. `error_reductions_merge[col][i]` receives the negation of
    ///    `(gain_l + gain_r) - gain_merged`, and `update_pairs_merge[col][i]` receives the
    ///    merged `(u_plus, u_minus)` pair (`u_minus` is `0.0` in Stage 1).
    ///
    /// Returns early, writing nothing, when `state.boundaries[col]` is empty. The two
    /// caches are indexed directly and are not resized here, so they must already cover
    /// every boundary index in the requested range.
    ///
    /// # Panics
    ///
    /// Panics on any out-of-range index. In debug builds it additionally panics when the
    /// union-additivity (I18) or score-dominance (I19) checks below fail.
    fn update_error_reductions_merge_for_col_range(
        &self,
        state: &mut FittingState,
        col: usize,
        lo_boundary_idx: usize,
        hi_boundary_idx: usize,
    ) {
        // Only used to detect a column without boundaries; nothing to refresh in that case.
        let boundary_pos = &state.boundaries[col];
        if boundary_pos.is_empty() {
            return;
        }

        // Compute partial products for this axis (efficient: O(n) using divide-out approach)
        let (g_plus, g_minus) = crate::grid::reducer::compute_partial_products_for_axis(col, state);

        let is_stage1 = state.is_stage1_positive_only();
        let alpha = self.alpha();
        let v_min = DEFAULT_V_MIN;
        let v_max = DEFAULT_V_MAX;

        // Shared borrow of a single field; the writes at the end of each iteration go to
        // other fields of `precomputed_statistics`, so the two borrows do not overlap.
        let sorted_indices = &state.precomputed_statistics.sorted_indices[col];
        for i in lo_boundary_idx..=hi_boundary_idx {
            // Get point ranges for left, right, and union intervals
            let (start, index, end) = state.interval_range_left_and_right(col, i);

            // Get sorted indices for each region (avoid per-boundary allocations)
            let left_region = &sorted_indices[start..index];
            let right_region = &sorted_indices[index..end];

            // Compute stats using partial products (g_{\pm}^{(-j)} regressors)
            let left_stats = crate::grid::reducer::compute_stats_using_partial_products(
                col,
                left_region,
                &g_plus,
                &g_minus,
                state,
            );
            let right_stats = crate::grid::reducer::compute_stats_using_partial_products(
                col,
                right_region,
                &g_plus,
                &g_minus,
                state,
            );
            // Union stats can be computed by additivity to avoid a third pass.
            let union_stats = crate::grid::state::IntervalStats::union(&left_stats, &right_stats);

            // Verify union additivity (I18 invariant) - debug builds only.
            // The union is recomputed directly over [start, end) and compared field by
            // field against the additive result, using an absolute tolerance.
            #[cfg(debug_assertions)]
            {
                const EPS: f64 = 1e-10;
                let union_region = &sorted_indices[start..end];
                let union_stats_direct = crate::grid::reducer::compute_stats_using_partial_products(
                    col,
                    union_region,
                    &g_plus,
                    &g_minus,
                    state,
                );
                assert!(
                    (union_stats.sum_s11 - union_stats_direct.sum_s11).abs() < EPS,
                    "I18 violation: S11 union additivity failed for col={}, boundary={}",
                    col,
                    i
                );
                assert!(
                    (union_stats.sum_s22 - union_stats_direct.sum_s22).abs() < EPS,
                    "I18 violation: S22 union additivity failed for col={}, boundary={}",
                    col,
                    i
                );
                assert!(
                    (union_stats.sum_s12 - union_stats_direct.sum_s12).abs() < EPS,
                    "I18 violation: S12 union additivity failed for col={}, boundary={}",
                    col,
                    i
                );
                assert!(
                    (union_stats.sum_t1 - union_stats_direct.sum_t1).abs() < EPS,
                    "I18 violation: t1 union additivity failed for col={}, boundary={}",
                    col,
                    i
                );
                assert!(
                    (union_stats.sum_t2 - union_stats_direct.sum_t2).abs() < EPS,
                    "I18 violation: t2 union additivity failed for col={}, boundary={}",
                    col,
                    i
                );
            }

            // Solve for optimal parameters for left, right, and union
            // Note: We compute gain_l and gain_r here even though error_reductions_resplit_pairs
            // stores (gain_l, gain_r) for this boundary. However, those are computed using
            // interval_stats (f_{\pm} regressors), while we need gains computed using partial
            // products (g_{\pm}^{(-j)} regressors). These are different, so we must recompute.
            // TODO: Consider caching these partial-product-based gains if we need them elsewhere.
            let (gain_l, gain_r, u_plus_merged, u_minus_merged, gain_merged) = if is_stage1 {
                // Stage 1 positive-only: use 1D ridge solver.
                // Each solve clamps v_b = 1 + u into [v_min, v_max] and evaluates the gain
                // at the clamped step.
                // Left side: H^L = S_{11}^L, g^L = t_1^L
                let h_l = left_stats.sum_s11;
                let g_l = left_stats.sum_t1;
                let (u_l, _denom_l, _num_l) = l2_update_unanchored(alpha, g_l, h_l);
                let v_b_l = (1.0 + u_l).clamp(v_min, v_max);
                let u_l_clamped = v_b_l - 1.0;
                let gain_l = l2_gain_raw(u_l_clamped, g_l, h_l);

                // Right side: H^R = S_{11}^R, g^R = t_1^R
                let h_r = right_stats.sum_s11;
                let g_r = right_stats.sum_t1;
                let (u_r, _denom_r, _num_r) = l2_update_unanchored(alpha, g_r, h_r);
                let v_b_r = (1.0 + u_r).clamp(v_min, v_max);
                let u_r_clamped = v_b_r - 1.0;
                let gain_r = l2_gain_raw(u_r_clamped, g_r, h_r);

                // Union: H^U = S_{11}^U, g^U = t_1^U
                let h_u = union_stats.sum_s11;
                let g_u = union_stats.sum_t1;
                let (u_u, _denom_u, _num_u) = l2_update_unanchored(alpha, g_u, h_u);
                let v_b_u = (1.0 + u_u).clamp(v_min, v_max);
                let u_u_clamped = v_b_u - 1.0;
                let gain_merged = l2_gain_raw(u_u_clamped, g_u, h_u);

                // Stage 1 has no u_minus component, hence the 0.0.
                (gain_l, gain_r, u_u_clamped, 0.0, gain_merged)
            } else {
                // Full two-tensor: use 2×2 solver.
                // The cross terms (sum_s12, sum_t2) are passed with flipped sign in all
                // three calls.
                let tau = self.tilt_tau();
                let rho = self.tilt_rho();

                // Only the gains of the two children are needed; their steps are discarded.
                let (_u_plus_l, _u_minus_l, gain_l) = solve_two_tensor(
                    left_stats.sum_s11,
                    left_stats.sum_s22,
                    -left_stats.sum_s12,
                    left_stats.sum_t1,
                    -left_stats.sum_t2,
                    alpha,
                    tau,
                    rho,
                    v_min,
                    v_max,
                );

                let (_u_plus_r, _u_minus_r, gain_r) = solve_two_tensor(
                    right_stats.sum_s11,
                    right_stats.sum_s22,
                    -right_stats.sum_s12,
                    right_stats.sum_t1,
                    -right_stats.sum_t2,
                    alpha,
                    tau,
                    rho,
                    v_min,
                    v_max,
                );

                let (u_plus_merged, u_minus_merged, gain_merged) = solve_two_tensor(
                    union_stats.sum_s11,
                    union_stats.sum_s22,
                    -union_stats.sum_s12,
                    union_stats.sum_t1,
                    -union_stats.sum_t2,
                    alpha,
                    tau,
                    rho,
                    v_min,
                    v_max,
                );

                (gain_l, gain_r, u_plus_merged, u_minus_merged, gain_merged)
            };

            // Verify score dominance (I19 invariant) - debug builds only:
            // the two children together must gain at least as much as the merged
            // interval, up to EPS.
            #[cfg(debug_assertions)]
            {
                const EPS: f64 = 1e-8;
                let gain_children = gain_l + gain_r;
                assert!(
                    gain_children >= gain_merged - EPS,
                    "I19 violation: Score dominance failed for col={}, boundary={}: g_A+g_B={}, g_U={}",
                    col, i, gain_children, gain_merged
                );
            }

            // Boundary benefit: Δ_boundary = (g_A + g_B) - g_U.
            // Positive values mean keeping the boundary improves the objective;
            // negative values mean merging improves it (boundary not worth it).
            let boundary_benefit = (gain_l + gain_r) - gain_merged;

            // The cached merge gain is the NEGATION of the boundary benefit
            // (merge improves when boundary_benefit < 0, i.e. when merge_gain > 0).
            let merge_gain = -boundary_benefit;

            // Store the merge gain and the (u_plus, u_minus) step for the merged interval
            // (the step is used when applying the merge).
            state.precomputed_statistics.error_reductions_merge[col][i] = merge_gain;
            state.precomputed_statistics.update_pairs_merge[col][i] =
                (u_plus_merged, u_minus_merged);
        }
    }

    /// Scores every admissible split position inside one interval of sorted points.
    ///
    /// `spanned_interval` is the point range `(start, end)` with `end` exclusive. Intervals
    /// holding fewer than two points are left untouched. Otherwise, for each candidate
    /// `pos` in `start + 1..end - 1`, the left part `[start, pos)` and the right part
    /// `[pos, end)` are solved independently and the outcome is written at index `pos`:
    ///
    /// * `update_left[pos]` / `update_right[pos]`: the `(u_plus, u_minus)` step per side
    ///   (`u_minus` is `0.0` in Stage 1),
    /// * `error_reductions_pairs[pos]`: `(gain_left, gain_right)`,
    /// * `error_reductions[pos]`: the sum of the two gains.
    ///
    /// A candidate whose `error_reductions[pos]` is already NaN on entry is treated as
    /// forbidden: it keeps its NaN score and its other three entries are set to NaN as
    /// well. The first and last positions of the interval are always marked NaN.
    ///
    /// The `ps_*` slices are read as inclusive prefix sums: the sum over `[start, pos)` is
    /// `ps[pos - 1]` minus `ps[start - 1]` (or minus zero when `start == 0`). Stage 1 reads
    /// only `ps_s11` and `ps_t1`; the two-tensor path reads all five and feeds the cross
    /// terms (`s12`, `t2`) to the solver with flipped sign.
    fn update_error_reductions_split_single_interval(
        &self,
        ps_s11: &[f64],
        ps_s22: &[f64],
        ps_s12: &[f64],
        ps_t1: &[f64],
        ps_t2: &[f64],
        update_left: &mut [(f64, f64)],
        update_right: &mut [(f64, f64)],
        error_reductions: &mut [f64],
        error_reductions_pairs: &mut [(f64, f64)],
        spanned_interval: (usize, usize),
        is_stage1: bool,
    ) {
        const NO_UPDATE: (f64, f64) = (f64::NAN, f64::NAN);

        let (start, end) = spanned_interval;
        if end.saturating_sub(start) < 2 {
            return;
        }
        // At least two points, so `end >= 2` and this cannot underflow.
        let last = end - 1;

        // The first point of the interval can never be a split position.
        update_left[start] = NO_UPDATE;
        update_right[start] = NO_UPDATE;
        error_reductions[start] = f64::NAN;
        error_reductions_pairs[start] = NO_UPDATE;

        let alpha = self.alpha();
        let (v_min, v_max) = (DEFAULT_V_MIN, DEFAULT_V_MAX);

        // Prefix value accumulated before the interval begins (zero at the origin).
        let offset_of = |prefix: &[f64]| if start == 0 { 0.0 } else { prefix[start - 1] };

        if is_stage1 {
            // Stage 1 positive-only (spec: AI_CONTEXT/30_algorithms.md §3.2.2).
            // Clamped 1D ridge step with H = S_11 and g = t_1: the raw step is clamped
            // so that v_b = 1 + u lies in [v_min, v_max], and the gain is evaluated at
            // the clamped step. Yields `(u, gain)`.
            let solve_side = |g: f64, h: f64| {
                let (u_raw, _denom, _num) = l2_update_unanchored(alpha, g, h);
                let u = (1.0 + u_raw).clamp(v_min, v_max) - 1.0;
                (u, l2_gain_raw(u, g, h))
            };

            let s11_offset = offset_of(ps_s11);
            let t1_offset = offset_of(ps_t1);
            let s11_total = ps_s11[last] - s11_offset;
            let t1_total = ps_t1[last] - t1_offset;

            for pos in (start + 1)..last {
                if error_reductions[pos].is_nan() {
                    // Forbidden candidate (e.g. duplicate x): propagate the NaN marker.
                    update_left[pos] = NO_UPDATE;
                    update_right[pos] = NO_UPDATE;
                    error_reductions_pairs[pos] = NO_UPDATE;
                    continue;
                }

                // Left part from the prefix sums; right part by subtraction from the
                // interval totals, which halves the number of prefix reads.
                let h_left = ps_s11[pos - 1] - s11_offset;
                let g_left = ps_t1[pos - 1] - t1_offset;
                let h_right = s11_total - h_left;
                let g_right = t1_total - g_left;

                let (u_left, gain_left) = solve_side(g_left, h_left);
                let (u_right, gain_right) = solve_side(g_right, h_right);

                update_left[pos] = (u_left, 0.0);
                update_right[pos] = (u_right, 0.0);
                error_reductions[pos] = gain_left + gain_right;
                error_reductions_pairs[pos] = (gain_left, gain_right);
            }
        } else {
            // Full two-tensor model: each side goes through the 2×2 solver.
            let tau = self.tilt_tau();
            let rho = self.tilt_rho();

            let s11_offset = offset_of(ps_s11);
            let s22_offset = offset_of(ps_s22);
            let s12_offset = offset_of(ps_s12);
            let t1_offset = offset_of(ps_t1);
            let t2_offset = offset_of(ps_t2);

            // Totals over the whole interval; s12 and t2 carry the flipped sign already.
            let s11_total = ps_s11[last] - s11_offset;
            let s22_total = ps_s22[last] - s22_offset;
            let s12_total = -(ps_s12[last] - s12_offset);
            let t1_total = ps_t1[last] - t1_offset;
            let t2_total = -(ps_t2[last] - t2_offset);

            for pos in (start + 1)..last {
                if error_reductions[pos].is_nan() {
                    // Forbidden candidate (e.g. duplicate x): propagate the NaN marker.
                    update_left[pos] = NO_UPDATE;
                    update_right[pos] = NO_UPDATE;
                    error_reductions_pairs[pos] = NO_UPDATE;
                    continue;
                }

                // Left part from the prefix sums (same sign convention as the totals).
                let prev = pos - 1;
                let s11_left = ps_s11[prev] - s11_offset;
                let s22_left = ps_s22[prev] - s22_offset;
                let s12_left = -(ps_s12[prev] - s12_offset);
                let t1_left = ps_t1[prev] - t1_offset;
                let t2_left = -(ps_t2[prev] - t2_offset);

                let (u_plus_left, u_minus_left, gain_left) = solve_two_tensor(
                    s11_left, s22_left, s12_left, t1_left, t2_left, alpha, tau, rho, v_min,
                    v_max,
                );
                // Right part by subtraction from the interval totals.
                let (u_plus_right, u_minus_right, gain_right) = solve_two_tensor(
                    s11_total - s11_left,
                    s22_total - s22_left,
                    s12_total - s12_left,
                    t1_total - t1_left,
                    t2_total - t2_left,
                    alpha,
                    tau,
                    rho,
                    v_min,
                    v_max,
                );

                update_left[pos] = (u_plus_left, u_minus_left);
                update_right[pos] = (u_plus_right, u_minus_right);
                error_reductions[pos] = gain_left + gain_right;
                error_reductions_pairs[pos] = (gain_left, gain_right);
            }
        }

        // Nor can the last point of the interval be a split position.
        update_left[last] = NO_UPDATE;
        update_right[last] = NO_UPDATE;
        error_reductions[last] = f64::NAN;
        error_reductions_pairs[last] = NO_UPDATE;
    }

    /// Evaluates a re-split of two adjacent intervals at a given split point.
    ///
    /// `spanned_intervals` is `(start, split, end)`: the left side covers the points
    /// `[start, split)` and the right side `[split, end)`. `srb` and `sbb` are prefix-sum
    /// arrays queried through `prefix_range`.
    ///
    /// Returns `(left_update, right_update, total_gain)`, where the two updates are the
    /// multipliers produced by `refinement_update_from_precomputed_statistics` and
    /// `total_gain` is the sum of the two per-side gains.
    fn get_error_reductions_both_intervals_resplit(
        &self,
        srb: &[f64],
        sbb: &[f64],
        spanned_intervals: (usize, usize, usize),
    ) -> (f64, f64, f64) {
        let (lo, mid, hi) = spanned_intervals;

        // (S_rb, S_bb) accumulated over the points in [from, to).
        let sums_over = |from: usize, to: usize| {
            (prefix_range(srb, from, to), prefix_range(sbb, from, to))
        };

        let (s_rb_left, s_bb_left) = sums_over(lo, mid);
        let (s_rb_right, s_bb_right) = sums_over(mid, hi);
        let count_left = (mid - lo) as f64;
        let count_right = (hi - mid) as f64;

        let (left_update, right_update, gain_left, gain_right) = self
            .refinement_update_from_precomputed_statistics(
                s_rb_left,
                s_bb_left,
                count_left,
                s_rb_right,
                s_bb_right,
                count_right,
            );

        (left_update, right_update, gain_left + gain_right)
    }

    /// Evaluates merging two adjacent intervals into one.
    ///
    /// `spanned_intervals` is `(start, split, end)`: the left interval covers the points
    /// `[start, split)` and the right one `[split, end)`. `srb` and `sbb` are prefix-sum
    /// arrays queried through `prefix_range`; `grid_value_left` / `grid_value_right` are
    /// forwarded unchanged as the current values of the two intervals.
    ///
    /// Returns the `(update, error_reduction)` pair computed by
    /// `merge_update_from_precomputed_statistics`.
    fn get_error_reductions_both_intervals_merge(
        &self,
        srb: &[f64],
        sbb: &[f64],
        spanned_intervals: (usize, usize, usize),
        grid_value_left: f64,
        grid_value_right: f64,
    ) -> (f64, f64) {
        let (lo, mid, hi) = spanned_intervals;

        // (S_rb, S_bb) accumulated over the points in [from, to).
        let sums_over = |from: usize, to: usize| {
            (prefix_range(srb, from, to), prefix_range(sbb, from, to))
        };

        let (s_rb_left, s_bb_left) = sums_over(lo, mid);
        let (s_rb_right, s_bb_right) = sums_over(mid, hi);
        let count_left = (mid - lo) as f64;
        let count_right = (hi - mid) as f64;

        self.merge_update_from_precomputed_statistics(
            s_rb_left,
            s_bb_left,
            s_rb_right,
            s_bb_right,
            grid_value_left,
            grid_value_right,
            count_left,
            count_right,
        )
    }

    pub fn alpha(&self) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { alpha, .. } => *alpha,
            RefinementStrategy::HuberRefinement { alpha, .. } => *alpha,
        }
    }

    pub fn weight(&self, res: f64) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { .. } => 1.0,
            RefinementStrategy::HuberRefinement { c, .. } => huber_weighting(res, *c),
        }
    }

    /// Returns the prior sample size (tau_0) for parent anchoring.
    ///
    /// This represents "how many samples worth of confidence we have that
    /// children should equal their parent". With tau_0 = 30, a child interval
    /// with 10 samples will be heavily shrunk toward the parent, while a child
    /// with 100 samples will mostly trust its own data.
    pub fn prior_sample_size(&self) -> f64 {
        match self {
            RefinementStrategy::L2Refinement {
                prior_sample_size, ..
            } => *prior_sample_size,
            RefinementStrategy::HuberRefinement {
                prior_sample_size, ..
            } => *prior_sample_size,
        }
    }

    pub fn update_clamp(&self) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { update_clamp, .. } => *update_clamp,
            RefinementStrategy::HuberRefinement { update_clamp, .. } => *update_clamp,
        }
    }

    /// Computes the multiplicative update and the gain for each side of a split.
    ///
    /// Per side, the inputs are the sums `S_rb` and `S_bb` and the sample count `n`; the
    /// ridge term of a side is `alpha * n`.
    ///
    /// * When `prior_sample_size()` (`tau_0`) is `<= 0`, each side is solved on its own
    ///   with `l2_update_unanchored`.
    /// * Otherwise both sides are anchored to the parent step
    ///   `u_p = (S_rb_left + S_rb_right) / H_u`, where `H_u` is the sum over both sides of
    ///   `S_bb + alpha * n` (`u_p` is `0` unless `H_u > 0`). The anchoring strength is
    ///   `tau = tau_0 * H_u / n_total`, falling back to `tau = tau_0` when `n_total` is not
    ///   positive, and each side is solved with `l2_update`.
    ///   NOTE(review): the intended closed form is presumably
    ///   `m* = 1 + (tau * u_p + S_rb) / (S_bb + tau + alpha * n)`; confirm against
    ///   `l2_update`, which is defined elsewhere.
    ///
    /// The multiplier of a side is `1 + u`. When `update_clamp()` is finite, the multiplier
    /// goes through `clamp_multiplier_signed` and the gain is evaluated at the clamped
    /// step, so that it reflects the update that will actually be applied.
    ///
    /// Returns `(m_left, m_right, gain_left, gain_right)`.
    fn refinement_update_from_precomputed_statistics(
        &self,
        s_rb_left: f64,
        s_bb_left: f64,
        n_left: f64,
        s_rb_right: f64,
        s_bb_right: f64,
        n_right: f64,
    ) -> (f64, f64, f64, f64) {
        let ridge_left = self.alpha() * n_left;
        let ridge_right = self.alpha() * n_right;
        let tau_0 = self.prior_sample_size();

        // Written as `<=` on purpose: a NaN `tau_0` takes the anchored path.
        let anchoring_disabled = tau_0 <= 0.0;
        let (solved_left, solved_right) = if anchoring_disabled {
            (
                l2_update_unanchored(ridge_left, s_rb_left, s_bb_left),
                l2_update_unanchored(ridge_right, s_rb_right, s_bb_right),
            )
        } else {
            // Combined curvature and sample count of the parent (both sides together).
            let curvature_total = (s_bb_left + ridge_left) + (s_bb_right + ridge_right);
            let sample_total = n_left + n_right;

            // Step the undivided parent would take.
            let parent_step = if curvature_total > 0.0 {
                (s_rb_left + s_rb_right) / curvature_total
            } else {
                0.0
            };

            // Scale the prior sample size by the average curvature per sample so that
            // `tau` is expressed in the same units as `S_bb`.
            let per_sample_weight = if sample_total > 0.0 {
                curvature_total / sample_total
            } else {
                1.0
            };
            let tau = tau_0 * per_sample_weight;

            (
                l2_update(ridge_left, tau, parent_step, s_rb_left, s_bb_left),
                l2_update(ridge_right, tau, parent_step, s_rb_right, s_bb_right),
            )
        };

        let clamp = self.update_clamp();
        // Turns one side's solver output into `(multiplier, gain)`.
        let finish_side = |(u, denom, numerator): (f64, f64, f64), s_rb: f64, s_bb: f64| {
            let multiplier = u + 1.0;
            if clamp.is_finite() {
                let multiplier = clamp_multiplier_signed(multiplier, clamp);
                let u_clamped = multiplier - 1.0;
                (
                    multiplier,
                    self.clamped_gain(u_clamped, s_rb, s_bb, numerator, denom),
                )
            } else {
                (multiplier, self.gain(u, s_rb, s_bb, numerator, denom))
            }
        };

        let (m_left, gain_left) = finish_side(solved_left, s_rb_left, s_bb_left);
        let (m_right, gain_right) = finish_side(solved_right, s_rb_right, s_bb_right);
        (m_left, m_right, gain_left, gain_right)
    }

    fn gain(&self, u: f64, s_rb: f64, s_bb: f64, _numerator: f64, _denominator: f64) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { .. } => l2_gain_raw(u, s_rb, s_bb),
            RefinementStrategy::HuberRefinement { .. } => l2_gain_raw(u, s_rb, s_bb),
        }
    }

    fn clamped_gain(
        &self,
        u: f64,
        s_rb: f64,
        s_bb: f64,
        _numerator: f64,
        _denominator: f64,
    ) -> f64 {
        match self {
            RefinementStrategy::L2Refinement { .. } => l2_gain_raw(u, s_rb, s_bb),
            RefinementStrategy::HuberRefinement { .. } => l2_gain_raw(u, s_rb, s_bb),
        }
    }

    /// Gain obtained by forcing the multiplier to zero.
    ///
    /// A multiplier of `0` corresponds to the additive step `u = 0 - 1 = -1`, so this is
    /// `l2_gain_raw(-1.0, s_rb, s_bb)`. With the quadratic gain `2 u S_rb - u^2 S_bb`
    /// exercised by this module's unit tests, that evaluates to `-2 S_rb - S_bb`.
    /// The `_n` argument is accepted but not used.
    pub fn gain_for_zero_multiplier(&self, s_rb: f64, s_bb: f64, _n: f64) -> f64 {
        const ZEROING_STEP: f64 = -1.0;
        l2_gain_raw(ZEROING_STEP, s_rb, s_bb)
    }

    fn merge_update_from_precomputed_statistics(
        &self,
        s_rb_left: f64,
        s_bb_left: f64,
        s_rb_right: f64,
        s_bb_right: f64,
        v_left: f64,
        v_right: f64,
        n_left: f64,
        n_right: f64,
    ) -> (f64, f64) {
        match self {
            RefinementStrategy::L2Refinement { .. } => unpenalized_merge(
                self.alpha(),
                s_rb_left,
                s_bb_left,
                s_rb_right,
                s_bb_right,
                v_left,
                v_right,
                n_left,
                n_right,
            ),
            RefinementStrategy::HuberRefinement { .. } => unpenalized_merge(
                self.alpha(),
                s_rb_left,
                s_bb_left,
                s_rb_right,
                s_bb_right,
                v_left,
                v_right,
                n_left,
                n_right,
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::two_tensor_solver::{DEFAULT_RHO, DEFAULT_TAU};

    /// `(s_rb, s_bb, n)` for a side whose raw step is `5.0 / 1.0 = 5.0`.
    const BIG_POSITIVE_SIDE: (f64, f64, f64) = (5.0, 1.0, 1.0);
    /// `(s_rb, s_bb, n)` for a side whose raw step is `-0.2 / 1.0 = -0.2`.
    const SMALL_NEGATIVE_SIDE: (f64, f64, f64) = (-0.2, 1.0, 1.0);

    /// L2 strategy without ridge penalty or parent anchoring; only the clamp varies.
    fn l2_strategy(update_clamp: f64) -> RefinementStrategy {
        RefinementStrategy::L2Refinement {
            alpha: 0.0,
            tilt_tau: DEFAULT_TAU,
            tilt_rho: DEFAULT_RHO,
            prior_sample_size: 0.0,
            update_clamp,
        }
    }

    /// Runs the refinement update on two `(s_rb, s_bb, n)` triples, left side first.
    fn refine(
        strategy: &RefinementStrategy,
        left: (f64, f64, f64),
        right: (f64, f64, f64),
    ) -> (f64, f64, f64, f64) {
        let (s_rb_left, s_bb_left, n_left) = left;
        let (s_rb_right, s_bb_right, n_right) = right;
        strategy.refinement_update_from_precomputed_statistics(
            s_rb_left, s_bb_left, n_left, s_rb_right, s_bb_right, n_right,
        )
    }

    /// An infinite `update_clamp` must leave the multipliers unclamped.
    #[test]
    fn test_no_clamp_when_infinite() {
        let strategy = l2_strategy(f64::INFINITY);

        let (m_l, m_r, _g_l, _g_r) = refine(&strategy, BIG_POSITIVE_SIDE, SMALL_NEGATIVE_SIDE);

        // Raw steps are 5.0 and -0.2, so the multipliers are 1 + u = 6.0 and 0.8.
        assert!((m_l - 6.0).abs() < 1e-12);
        assert!((m_r - 0.8).abs() < 1e-12);
    }

    /// With `update_clamp = C`, multipliers are limited to `[exp(-C), exp(C)]`.
    #[test]
    fn test_clamps_multiplier_to_exp_range() {
        let strategy = l2_strategy(1.0);

        let (m_l, m_r, _, _) = refine(&strategy, BIG_POSITIVE_SIDE, SMALL_NEGATIVE_SIDE);

        let upper = 1.0f64.exp(); // 2.718281...

        // Left: 6.0 exceeds the upper bound and is clamped to it.
        assert!((m_l - upper).abs() < 1e-12);
        // Right: 0.8 lies inside [exp(-1), exp(1)] and is left unchanged.
        assert!((m_r - 0.8).abs() < 1e-12);
    }

    /// Signed clamping keeps the sign and bounds the magnitude to `[exp(-c), exp(c)]`.
    #[test]
    fn test_signed_clamp_preserves_sign_and_bounds_magnitude() {
        let c = 2.0f64;
        let lower = (-c).exp(); // 0.1353
        let upper = c.exp(); // 7.389

        // Large negative value: sign kept, magnitude capped at the upper bound.
        let huge_negative = clamp_multiplier_signed(-100.0, c);
        assert!(huge_negative.is_sign_negative());
        assert!((huge_negative.abs() - upper).abs() < 1e-12);

        // Tiny positive value: raised to the lower bound.
        let tiny_positive = clamp_multiplier_signed(1e-9, c);
        assert!(tiny_positive.is_sign_positive());
        assert!((tiny_positive - lower).abs() < 1e-12);

        // Zero is mapped to the (positive) lower bound.
        let from_zero = clamp_multiplier_signed(0.0, c);
        assert!((from_zero - lower).abs() < 1e-12);

        // Values already inside the bounds pass through unchanged, either sign.
        let inside_positive = clamp_multiplier_signed(1.5, c);
        assert!((inside_positive - 1.5).abs() < 1e-12);

        let inside_negative = clamp_multiplier_signed(-1.5, c);
        assert!((inside_negative - (-1.5)).abs() < 1e-12);
    }

    /// Gains must be evaluated at the clamped step, not at the raw one.
    #[test]
    fn test_gains_computed_with_clamped_step() {
        let strategy = l2_strategy(1.0);

        // The left raw step (5.0) would give a larger gain than its clamped version.
        let (m_l, m_r, g_l, g_r) = refine(&strategy, BIG_POSITIVE_SIDE, SMALL_NEGATIVE_SIDE);

        assert!(g_l.is_finite());
        assert!(g_r.is_finite());

        // Left: u_clamped = exp(1.0) - 1 ≈ 1.718, s_rb = 5.0, s_bb = 1.0
        // g_l = 2 * 1.718 * 5.0 - 1.0 * 1.718^2 = 17.18 - 2.95 ≈ 14.23
        let u_left = m_l - 1.0;
        let expected_g_l = 2.0 * u_left * 5.0 - 1.0 * u_left.powi(2);
        assert!((g_l - expected_g_l).abs() < 1e-10);

        // Right: u_clamped = -0.2 (unchanged), s_rb = -0.2, s_bb = 1.0
        // g_r = 2 * (-0.2) * (-0.2) - 1.0 * (-0.2)^2 = 0.08 - 0.04 = 0.04
        let u_right = m_r - 1.0;
        let expected_g_r = 2.0 * u_right * (-0.2) - 1.0 * u_right.powi(2);
        assert!((g_r - expected_g_r).abs() < 1e-10);
    }

    /// Extreme raw steps of either sign end up exactly on the clamp bound.
    #[test]
    fn test_clamp_edge_cases() {
        let strategy = l2_strategy(0.5);

        let upper = 0.5f64.exp(); // 1.6487

        // Raw steps are 1 / 0.00001 on the left and -1 / 0.00001 on the right.
        let (m_l, m_r, _, _) = refine(&strategy, (1.0, 0.00001, 1.0), (-1.0, 0.00001, 1.0));

        // Both magnitudes hit the upper bound; the right side keeps its negative sign.
        assert!((m_l - upper).abs() < 1e-12);
        assert!((m_r.abs() - upper).abs() < 1e-12);
        assert!(m_r.is_sign_negative());
    }
}
