//! Bagged Two-Tensor MPF Aggregation
//!
//! Implements `AI_CONTEXT/30_algorithms.md` §14: Robust aggregation procedure for bagged
//! two-tensor MPF stage-1 models that aggregates based on component shapes (backbone + tilt)
//! rather than just overall predictions.
//!
//! The procedure includes:
//! 1. Common grid alignment per axis
//! 2. Gauge fixing / canonicalization per bag
//! 3. Component-shape distance computation
//! 4. Median tensor selection (medoid)
//! 5. Robust averaging of kept components
//! 6. Post-aggregation normalization

use crate::family::combine_grids::refine_grids_to_union_two_tensor;
use crate::grid::identification::l2_identify;
use crate::grid::FittedTreeGrid;
use crate::logging::log_combination_choice;
use ndarray::{ArrayView1, ArrayView2};

#[cfg(feature = "use-rayon")]
use rayon::prelude::*;

// Numerical stability constants shared by the helpers in this module.

/// Positive floor applied to backbone values, a_± factors and lambdas before logarithms are taken.
const EPSILON: f64 = 1e-10;
/// Lower clamp for log-space values.
const LOG_EPSILON: f64 = -23.025850929940457; // ln(1e-10)
/// Upper clamp for log-space values.
const LOG_MAX: f64 = 23.025850929940457; // ln(1e10)
/// Weight substituted for bins with zero observations in `precompute_bin_weights`.
const EPSILON_N: f64 = 1e-12; // For empty bin weights

/// Build the per-axis, per-bin factors $a_\pm$ from a (backbone, tilt) pair.
///
/// For axis $j$ and bin $k$:
/// - $a_{+,j}^k = b_j^k e^{d_j^k}$
/// - $a_{-,j}^k = b_j^k e^{-d_j^k}$
///
/// The backbone value is floored at `EPSILON` and the exponent is capped at 50 so that
/// the exponential cannot overflow.
///
/// Returns `(a_plus, a_minus)`, each shaped like `backbone`.
///
/// # Panics
/// Panics if `tilt` lacks an entry for an `(axis, bin)` position present in `backbone`.
fn compute_a_factors_from_bd(
    backbone: &[Vec<f64>],
    tilt: &[Vec<f64>],
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    backbone
        .iter()
        .enumerate()
        .map(|(axis, backbone_axis)| {
            let (plus_axis, minus_axis): (Vec<f64>, Vec<f64>) = backbone_axis
                .iter()
                .enumerate()
                .map(|(bin, &raw_backbone)| {
                    let floor_b = raw_backbone.max(EPSILON);
                    let d = tilt[axis][bin];
                    // Cap the exponent on both signs before exponentiating.
                    let grow = d.min(50.0).exp();
                    let shrink = (-d).min(50.0).exp();
                    (floor_b * grow, floor_b * shrink)
                })
                .unzip();
            (plus_axis, minus_axis)
        })
        .unzip()
}

/// Weighted geometric mean of a_± factors, computed in log-space.
///
/// Returns `(combined_a_plus, combined_a_minus)` where, per axis and bin:
/// $$\log(\text{combined}_a) = \frac{\sum_b w_b \log(a_b)}{\sum_b w_b}$$
///
/// Each factor is floored at `EPSILON` and its logarithm is clamped to
/// `[LOG_EPSILON, LOG_MAX]` before averaging. The output shape is taken from the first
/// `a_plus` candidate.
///
/// # Arguments
/// * `a_plus_candidates` - $a_+$ factors, indexed `[candidate][axis][bin]`
/// * `a_minus_candidates` - $a_-$ factors, same layout and candidate count
/// * `weights` - One weight per candidate
///
/// # Panics
/// * If there are no candidates.
/// * If `a_minus_candidates` or `weights` does not have one entry per `a_plus` candidate.
///   Without this check the `zip` below would silently drop candidates while the
///   denominator still counted every weight.
/// * If the total weight is not strictly positive (including NaN).
/// * If any candidate has fewer axes or bins than the first `a_plus` candidate.
fn geometric_mean_combine_a_factors(
    a_plus_candidates: &[Vec<Vec<f64>>],
    a_minus_candidates: &[Vec<Vec<f64>>],
    weights: &[f64],
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    if a_plus_candidates.is_empty() {
        panic!("Cannot combine empty candidate grids");
    }
    if a_minus_candidates.len() != a_plus_candidates.len() {
        panic!("a_plus and a_minus candidate counts must match");
    }
    if weights.len() != a_plus_candidates.len() {
        panic!("Weights length must match candidate count");
    }

    let total_weight: f64 = weights.iter().sum();
    if total_weight.is_nan() || total_weight <= 0.0 {
        panic!("Total weight must be positive");
    }

    // Weighted log-mean of one (axis, bin) cell across all candidates, mapped back with exp.
    let weighted_geometric_mean = |candidates: &[Vec<Vec<f64>>], axis: usize, bin: usize| -> f64 {
        let log_sum = candidates
            .iter()
            .zip(weights)
            .fold(0.0_f64, |acc, (candidate, &weight)| {
                let value = candidate[axis][bin].max(EPSILON);
                let log_value = value.ln().max(LOG_EPSILON).min(LOG_MAX);
                acc + weight * log_value
            });
        (log_sum / total_weight).exp()
    };

    // The first a_plus candidate defines the axis/bin layout of the result.
    let layout = &a_plus_candidates[0];
    let mut combined_a_plus: Vec<Vec<f64>> = Vec::with_capacity(layout.len());
    let mut combined_a_minus: Vec<Vec<f64>> = Vec::with_capacity(layout.len());

    for (axis, layout_axis) in layout.iter().enumerate() {
        let n_bins = layout_axis.len();
        let plus_axis: Vec<f64> = (0..n_bins)
            .map(|bin| weighted_geometric_mean(a_plus_candidates, axis, bin))
            .collect();
        let minus_axis: Vec<f64> = (0..n_bins)
            .map(|bin| weighted_geometric_mean(a_minus_candidates, axis, bin))
            .collect();
        combined_a_plus.push(plus_axis);
        combined_a_minus.push(minus_axis);
    }

    (combined_a_plus, combined_a_minus)
}

/// Recover the (backbone, tilt) representation from a_± factors.
///
/// For every axis and bin:
/// - $b = \sqrt{a_+ \cdot a_-}$
/// - $d = \frac{1}{2} \ln\left(\frac{a_+}{a_-}\right)$
///
/// Both factors are floored at `EPSILON`, and the ratio is bounded to
/// `[EPSILON, 1 / EPSILON]` before the logarithm so the tilt stays finite.
///
/// Returns `(backbone, tilt)`, each shaped like `a_plus`.
///
/// # Panics
/// Panics if `a_minus` lacks an entry for an `(axis, bin)` position present in `a_plus`.
fn convert_a_factors_to_bd(
    a_plus: &[Vec<f64>],
    a_minus: &[Vec<f64>],
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    let mut backbone: Vec<Vec<f64>> = Vec::with_capacity(a_plus.len());
    let mut tilt: Vec<Vec<f64>> = Vec::with_capacity(a_plus.len());

    for (axis, plus_axis) in a_plus.iter().enumerate() {
        let (b_axis, d_axis): (Vec<f64>, Vec<f64>) = plus_axis
            .iter()
            .enumerate()
            .map(|(bin, &raw_plus)| {
                let pos = raw_plus.max(EPSILON);
                let neg = a_minus[axis][bin].max(EPSILON);

                // Geometric midpoint of the two factors gives the backbone.
                let b = (pos * neg).sqrt();

                // Half the log-ratio gives the tilt; bound the ratio first.
                let bounded_ratio = (pos / neg).max(EPSILON).min(1.0 / EPSILON);
                (b, 0.5 * bounded_ratio.ln())
            })
            .unzip();

        backbone.push(b_axis);
        tilt.push(d_axis);
    }

    (backbone, tilt)
}

/// Subtract the count-weighted mean tilt from every bin, independently per axis.
///
/// For each axis $j$:
/// $$c_j = \frac{\sum_k N_j^k d_j^k}{\sum_k N_j^k}, \quad d_j^k \leftarrow d_j^k - c_j$$
///
/// An axis is left untouched when it has no bins, no counts, or a total count of zero.
///
/// **Note**: Only the tilt values are changed; lambdas are not touched here (unlike
/// two_tensor_identify_l2). Used for per-bag canonicalization before distance computation.
///
/// # Panics
/// Panics if `observation_counts` has fewer axes than `tilt_values`.
fn center_tilt_per_axis(tilt_values: &mut [Vec<f64>], observation_counts: &[Vec<usize>]) {
    const EPS: f64 = 1e-12;

    for (dim, axis_tilt) in tilt_values.iter_mut().enumerate() {
        let axis_counts = &observation_counts[dim];

        let nothing_to_center = axis_counts.is_empty() || axis_tilt.is_empty();
        if nothing_to_center {
            continue;
        }

        let total_count: f64 = axis_counts.iter().map(|&n| n as f64).sum();
        if total_count <= EPS {
            continue;
        }

        let weighted_total = axis_tilt
            .iter()
            .zip(axis_counts.iter())
            .map(|(&value, &n)| value * (n as f64))
            .sum::<f64>();
        let shift = weighted_total / total_count;

        axis_tilt.iter_mut().for_each(|value| *value -= shift);
    }
}

/// Remove the count-weighted mean of the log backbone, independently per axis.
///
/// For each axis $j$:
/// $$\bar{\log b}_j = \frac{\sum_k N_j^k \log b_j^k}{\sum_k N_j^k}, \quad \log b_j^k \leftarrow \log b_j^k - \bar{\log b}_j$$
///
/// Values are floored at `EPSILON` and their logarithm clamped to `[LOG_EPSILON, LOG_MAX]`
/// before centering; the centered logs are exponentiated back in place. An axis is left
/// untouched when it has no bins, no counts, or a total count of zero.
///
/// This is a log-space centering, not the L2 normalization used in two_tensor_identify_l2.
///
/// # Panics
/// Panics if `observation_counts` has fewer axes than `backbone_values`.
fn canonicalize_backbone_log_space(
    backbone_values: &mut [Vec<f64>],
    observation_counts: &[Vec<usize>],
) {
    const EPS: f64 = 1e-12;

    // Clamped natural log, shared by the mean pass and the centering pass.
    let safe_ln = |value: f64| value.max(EPSILON).ln().max(LOG_EPSILON).min(LOG_MAX);

    for (dim, axis_backbone) in backbone_values.iter_mut().enumerate() {
        let axis_counts = &observation_counts[dim];

        let nothing_to_center = axis_counts.is_empty() || axis_backbone.is_empty();
        if nothing_to_center {
            continue;
        }

        let total_count: f64 = axis_counts.iter().map(|&n| n as f64).sum();
        if total_count <= EPS {
            continue;
        }

        let weighted_log_total = axis_backbone
            .iter()
            .zip(axis_counts.iter())
            .map(|(&value, &n)| safe_ln(value) * (n as f64))
            .sum::<f64>();
        let log_mean = weighted_log_total / total_count;

        // Shift in log-space, then map back to the original scale.
        for value in axis_backbone.iter_mut() {
            *value = (safe_ln(*value) - log_mean).exp();
        }
    }
}

/// Put one bag into canonical per-axis form so that only component shapes remain.
///
/// Applies log-space backbone centering followed by tilt centering, both weighted by
/// `observation_counts`. This strips per-axis constant drift before bags are compared.
/// **Note**: `lambda_plus` / `lambda_minus` are left as they are, so the grid's
/// predictions are not preserved by this step.
fn canonicalize_bag_per_axis(grid: &mut FittedTreeGrid, observation_counts: &[Vec<usize>]) {
    let backbone = &mut grid.backbone_values;
    canonicalize_backbone_log_space(backbone, observation_counts);

    let tilt = &mut grid.tilt_values;
    center_tilt_per_axis(tilt, observation_counts);
}

/// Cached log-space components of one grid, as built by `precompute_log_components`.
///
/// Both fields are indexed `[axis][bin]`. Caching them lets the pairwise distance
/// computation avoid recomputing logarithms for every pair of bags.
#[derive(Clone)]
struct PrecomputedGrid {
    /// Clamped `ln(b) + d` per axis and bin.
    log_f_plus: Vec<Vec<f64>>,
    /// Clamped `ln(b) - d` per axis and bin.
    log_f_minus: Vec<Vec<f64>>,
}

/// Turn per-axis observation counts into floating-point bin weights.
///
/// A bin with zero observations gets the tiny positive weight `EPSILON_N` instead of 0;
/// every other bin is weighted by its count. The output has the same shape as the input.
fn precompute_bin_weights(observation_counts: &[Vec<usize>]) -> Vec<Vec<f64>> {
    let mut weights_per_axis = Vec::with_capacity(observation_counts.len());

    for axis_counts in observation_counts {
        let mut axis_weights = Vec::with_capacity(axis_counts.len());
        for &count in axis_counts {
            let weight = match count {
                0 => EPSILON_N,
                observed => observed as f64,
            };
            axis_weights.push(weight);
        }
        weights_per_axis.push(axis_weights);
    }

    weights_per_axis
}

/// Precompute the clamped log-space factors of a grid for distance evaluation.
///
/// For every axis and bin, with the backbone floored at `EPSILON` and the tilt
/// contribution capped at 50:
/// - `log_f_plus  = clamp(ln(b) + d)`
/// - `log_f_minus = clamp(ln(b) - d)`
///
/// where `clamp` restricts the value to `[LOG_EPSILON, LOG_MAX]`.
///
/// # Panics
/// Panics if `tilt_values` has fewer axes or bins than `backbone_values`.
fn precompute_log_components(grid: &FittedTreeGrid) -> PrecomputedGrid {
    let num_axes = grid.backbone_values.len();
    let mut log_f_plus: Vec<Vec<f64>> = Vec::with_capacity(num_axes);
    let mut log_f_minus: Vec<Vec<f64>> = Vec::with_capacity(num_axes);

    let clamp_log = |value: f64| value.max(LOG_EPSILON).min(LOG_MAX);

    for (axis, axis_backbone) in grid.backbone_values.iter().enumerate() {
        let axis_tilt = &grid.tilt_values[axis];

        let (plus_axis, minus_axis): (Vec<f64>, Vec<f64>) = axis_backbone
            .iter()
            .enumerate()
            .map(|(bin, &raw_backbone)| {
                let floored = raw_backbone.max(EPSILON);
                let d = axis_tilt[bin];
                let log_b = floored.ln();
                (
                    clamp_log(log_b + d.min(50.0)),
                    clamp_log(log_b + (-d).min(50.0)),
                )
            })
            .unzip();

        log_f_plus.push(plus_axis);
        log_f_minus.push(minus_axis);
    }

    PrecomputedGrid {
        log_f_plus,
        log_f_minus,
    }
}

/// Component-shape distance between two aligned, precomputed grids.
///
/// Implements `AI_CONTEXT/30_algorithms.md` §14.3 on cached log-space coordinates
/// $p_j^{(b),k} = \log f_{+,j}^{(b),k}$ and $m_j^{(b),k} = \log f_{-,j}^{(b),k}$:
/// - Per-axis weighted squared distance:
///   $d_j^2(a,b) = \sum_k N_j^k[(p_j^{(a),k} - p_j^{(b),k})^2 + (m_j^{(a),k} - m_j^{(b),k})^2]$
/// - Total distance: $D(a,b) = \sqrt{\sum_j d_j^2(a,b)}$
///
/// # Arguments
/// * `grid_a`, `grid_b` - Precomputed log components of two grids with the same layout
/// * `bin_weights` - Weight $N_j^k$ per axis and bin
///
/// # Returns
/// The distance $D(a,b)$.
///
/// # Panics
/// Panics if `grid_b` or `bin_weights` has fewer axes or bins than `grid_a`. Debug builds
/// additionally assert that the axis and bin counts agree exactly.
fn compute_component_shape_distance_precomputed(
    grid_a: &PrecomputedGrid,
    grid_b: &PrecomputedGrid,
    bin_weights: &[Vec<f64>],
) -> f64 {
    debug_assert_eq!(grid_a.log_f_plus.len(), grid_b.log_f_plus.len());
    debug_assert_eq!(grid_a.log_f_plus.len(), bin_weights.len());

    let squared_total: f64 = (0..grid_a.log_f_plus.len()).fold(0.0_f64, |acc, axis| {
        let plus_a = &grid_a.log_f_plus[axis];
        let minus_a = &grid_a.log_f_minus[axis];
        let plus_b = &grid_b.log_f_plus[axis];
        let minus_b = &grid_b.log_f_minus[axis];
        let axis_weights = &bin_weights[axis];

        debug_assert_eq!(plus_a.len(), plus_b.len());
        debug_assert_eq!(plus_a.len(), axis_weights.len());

        // Weighted squared distance contributed by this axis.
        let axis_squared = (0..plus_a.len()).fold(0.0_f64, |sum, bin| {
            let weight = axis_weights[bin];
            let diff_plus = plus_a[bin] - plus_b[bin];
            let diff_minus = minus_a[bin] - minus_b[bin];
            sum + weight * (diff_plus * diff_plus + diff_minus * diff_minus)
        });

        acc + axis_squared
    });

    squared_total.sqrt()
}

/// Pick the medoid bag under the component-shape distance (sequential build).
///
/// Implements `AI_CONTEXT/30_algorithms.md` §14.4.2:
/// $$b^\star = \arg\min_{b \in \{1, \dots, B\}} \sum_{b'=1}^B D(b,b')$$
///
/// # Arguments
/// * `grids` - Precomputed log components of the aligned grids (same layout for all)
/// * `bin_weights` - Weight per axis and bin used by the distance
///
/// # Returns
/// `(medoid index, symmetric pairwise distance matrix)`; the matrix is returned so that
/// callers can reuse it.
///
/// # Edge Cases
/// * Single bag: returns index 0 and a 1x1 zero matrix
/// * Ties: returns smallest index (deterministic)
///
/// # Panics
/// Panics if `grids` is empty.
#[cfg(not(feature = "use-rayon"))]
fn select_median_tensor_component_shape(
    grids: &[PrecomputedGrid],
    bin_weights: &[Vec<f64>],
) -> (usize, Vec<Vec<f64>>) {
    let n = match grids.len() {
        0 => panic!("Cannot select median from empty grid list"),
        // A lone bag is trivially the medoid; its only distance is to itself.
        1 => return (0, vec![vec![0.0; 1]; 1]),
        count => count,
    };

    // distance_matrix[i][j] = D(grids[i], grids[j]); the diagonal stays at zero.
    let mut distance_matrix: Vec<Vec<f64>> = vec![vec![0.0; n]; n];

    log::info!(
        "Computing pairwise component-shape distances for {} bags",
        n
    );
    // Each unordered pair is evaluated once and mirrored across the diagonal.
    for (i, grid_i) in grids.iter().enumerate() {
        for (j, grid_j) in grids.iter().enumerate().skip(i + 1) {
            let dist = compute_component_shape_distance_precomputed(grid_i, grid_j, bin_weights);
            distance_matrix[i][j] = dist;
            distance_matrix[j][i] = dist;
        }
    }

    // The medoid is the row with the smallest sum; strict `<` keeps the first on ties.
    log::info!("Finding median bag (medoid) from distance matrix");
    let mut median_index = 0;
    let mut min_total_distance = f64::INFINITY;
    for (i, row) in distance_matrix.iter().enumerate() {
        let total_distance: f64 = row.iter().sum();
        if total_distance < min_total_distance {
            min_total_distance = total_distance;
            median_index = i;
        }
    }

    log::info!(
        "Selected grid {} as median (medoid) with total distance {:.6}",
        median_index,
        min_total_distance
    );
    (median_index, distance_matrix)
}

/// Pick the medoid bag under the component-shape distance (parallel build with rayon).
///
/// Implements `AI_CONTEXT/30_algorithms.md` §14.4.2:
/// $$b^\star = \arg\min_{b \in \{1, \dots, B\}} \sum_{b'=1}^B D(b,b')$$
///
/// Pairwise distances are evaluated in parallel; the medoid search itself is a cheap
/// sequential scan that uses the same strict `<` comparison as the sequential build, so
/// both builds select the same bag and neither panics on a NaN row total.
///
/// # Arguments
/// * `grids` - Precomputed log components of the aligned grids (same layout for all)
/// * `bin_weights` - Weight per axis and bin used by the distance
///
/// # Returns
/// `(medoid index, symmetric pairwise distance matrix)`; the matrix is returned so that
/// callers can reuse it.
///
/// # Edge Cases
/// * Single bag: returns index 0 and a 1x1 zero matrix
/// * Ties: returns smallest index (deterministic)
///
/// # Panics
/// Panics if `grids` is empty.
#[cfg(feature = "use-rayon")]
fn select_median_tensor_component_shape(
    grids: &[PrecomputedGrid],
    bin_weights: &[Vec<f64>],
) -> (usize, Vec<Vec<f64>>) {
    if grids.is_empty() {
        panic!("Cannot select median from empty grid list");
    }
    if grids.len() == 1 {
        // Single bag: it is the medoid; the 1x1 matrix holds its zero self-distance.
        return (0, vec![vec![0.0; 1]; 1]);
    }

    let n = grids.len();

    // distance_matrix[i][j] = D(grids[i], grids[j]); the diagonal stays at zero.
    let mut distance_matrix: Vec<Vec<f64>> = vec![vec![0.0; n]; n];

    log::info!(
        "Computing pairwise component-shape distances for {} bags (parallel)",
        n
    );

    // Evaluate the upper triangle (i < j) in parallel, one distance per unordered pair.
    let upper_triangle: Vec<(usize, usize, f64)> = (0..n)
        .into_par_iter()
        .flat_map(|i| {
            ((i + 1)..n).into_par_iter().map(move |j| {
                let dist =
                    compute_component_shape_distance_precomputed(&grids[i], &grids[j], bin_weights);
                (i, j, dist)
            })
        })
        .collect();

    // Mirror each distance across the diagonal.
    for (i, j, dist) in upper_triangle {
        distance_matrix[i][j] = dist;
        distance_matrix[j][i] = dist;
    }

    // The medoid is the row with the smallest sum. A strict `<` scan keeps the first
    // index on ties and simply skips NaN totals instead of unwrapping a failed comparison.
    log::info!("Finding median bag (medoid) from distance matrix");
    let mut median_index = 0;
    let mut min_total_distance = f64::INFINITY;
    for (i, row) in distance_matrix.iter().enumerate() {
        let total_distance: f64 = row.iter().sum();
        if total_distance < min_total_distance {
            min_total_distance = total_distance;
            median_index = i;
        }
    }

    log::info!(
        "Selected grid {} as median (medoid) with total distance {:.6}",
        median_index,
        min_total_distance
    );
    (median_index, distance_matrix)
}

/// Trim outliers by keeping the bags closest to the median.
///
/// Implements `AI_CONTEXT/30_algorithms.md` §14.4.3:
/// - Sort bags by distance to median
/// - Keep closest $\lceil \text{trim_percentage} \times B \rceil$ bags, but never fewer than one
///
/// # Arguments
/// * `grids` - Aligned grids (only their count is used)
/// * `median_index` - Index of median bag
/// * `distance_matrix` - Pre-computed pairwise distance matrix (reuse from median selection)
/// * `trim_percentage` - Fraction to keep (default 0.9 = 90%)
///
/// # Returns
/// Indices of kept bags (sorted by distance to median, closest first; equal distances keep
/// their original index order).
///
/// # Edge Cases
/// * $B=0$: returns an empty list
/// * $B=1$: keep all bags
/// * $\lceil \text{trim_percentage} \times B \rceil \ge B$: keep all bags
/// * `trim_percentage` that is zero, negative or NaN: keep exactly one bag (the closest),
///   so the caller never receives an empty kept set
///
/// # Panics
/// Panics if `distance_matrix` has fewer than `grids.len()` rows or a row has no column
/// `median_index`.
fn trim_outliers(
    grids: &[FittedTreeGrid],
    median_index: usize,
    distance_matrix: &[Vec<f64>],
    trim_percentage: f64,
) -> Vec<usize> {
    if grids.is_empty() {
        return Vec::new();
    }
    if grids.len() == 1 {
        return vec![0];
    }

    let n = grids.len();
    // The float-to-usize cast saturates, mapping negative and NaN requests to 0; the lower
    // bound of 1 then guarantees that at least the closest bag survives.
    let requested = (trim_percentage * n as f64).ceil() as usize;
    let keep_count = requested.max(1).min(n);

    // Reuse distances from pre-computed distance matrix (no need to recompute)
    log::info!(
        "Using pre-computed distances to median bag {}",
        median_index
    );
    let mut distances: Vec<(usize, f64)> = (0..n)
        .map(|i| (i, distance_matrix[i][median_index]))
        .collect();

    // Sort by distance (closest first). `total_cmp` is a total order, so a NaN distance
    // cannot make the comparison fail; positive NaNs sort after every finite distance.
    distances.sort_by(|a, b| a.1.total_cmp(&b.1));

    // Return indices of kept bags
    let kept: Vec<usize> = distances[..keep_count]
        .iter()
        .map(|(idx, _)| *idx)
        .collect();

    log::info!(
        "Selected {} out of {} grids (trim_percentage={:.2}, keep_count={})",
        kept.len(),
        n,
        trim_percentage,
        keep_count
    );

    kept
}

/// Combine lambdas using a (weighted) geometric mean.
///
/// Implements `AI_CONTEXT/30_algorithms.md` §14.6.2:
/// $$\log \bar{\lambda}_\pm = \frac{1}{|K|} \sum_{b \in K} \log \lambda_\pm^{(b)}$$
/// $$\bar{\lambda}_\pm = \exp(\log \bar{\lambda}_\pm)$$
///
/// Each lambda is floored at `EPSILON` and its logarithm is clamped to
/// `[LOG_EPSILON, LOG_MAX]` before averaging.
///
/// # Arguments
/// * `lambda_plus_candidates` - $\lambda_+^{(b)}$ for each bag
/// * `lambda_minus_candidates` - $\lambda_-^{(b)}$ for each bag
/// * `weights` - Optional weights for each bag (default: uniform)
///
/// # Returns
/// Combined lambdas $(\bar{\lambda}_+, \bar{\lambda}_-)$
///
/// # Panics
/// * If `lambda_plus_candidates` is empty.
/// * If `lambda_minus_candidates` or `weights` has a different length than
///   `lambda_plus_candidates`. Without this check the `zip` below would silently drop
///   entries while the denominator still counted every weight.
/// * If the total weight is not strictly positive (including NaN).
fn combine_lambdas_geometric_mean(
    lambda_plus_candidates: &[f64],
    lambda_minus_candidates: &[f64],
    weights: Option<&[f64]>,
) -> (f64, f64) {
    if lambda_plus_candidates.is_empty() {
        panic!("Cannot combine empty lambda list");
    }

    let n = lambda_plus_candidates.len();
    if lambda_minus_candidates.len() != n {
        panic!("lambda_plus and lambda_minus lengths must match");
    }

    // Only allocate the uniform weights when the caller did not supply any.
    let uniform_weights: Vec<f64>;
    let weights: &[f64] = match weights {
        Some(provided) => provided,
        None => {
            uniform_weights = vec![1.0; n];
            &uniform_weights
        }
    };

    if weights.len() != n {
        panic!("Weights length must match lambdas length");
    }

    let total_weight: f64 = weights.iter().sum();
    if total_weight.is_nan() || total_weight <= 0.0 {
        panic!("Total weight must be positive");
    }

    // Weighted mean of the clamped logs, mapped back with exp.
    let weighted_geometric_mean = |lambdas: &[f64]| -> f64 {
        let log_sum: f64 = lambdas
            .iter()
            .zip(weights.iter())
            .map(|(&l, &w)| {
                let log_l = l.max(EPSILON).ln().max(LOG_EPSILON).min(LOG_MAX);
                log_l * w
            })
            .sum();
        (log_sum / total_weight).exp()
    };

    (
        weighted_geometric_mean(lambda_plus_candidates),
        weighted_geometric_mean(lambda_minus_candidates),
    )
}

/// Main function: Aggregate bagged two-tensor MPF models using component-shape distance.
///
/// Implements `AI_CONTEXT/30_algorithms.md` §14: Robust aggregation procedure for bagged
/// two-tensor MPF stage-1 models that aggregates based on component shapes (backbone + tilt)
/// rather than just overall predictions.
///
/// Algorithm:
/// 1. Grid alignment: Refine all grids to union grid (reuse `refine_grids_to_union_two_tensor`)
/// 2. Recount observations per union-grid bin from `points`
/// 3. Per-bag canonicalization: Center backbone (log-space) and tilt per axis
/// 4. Component-shape distance computation
/// 5. Median selection (medoid)
/// 6. Trimming: Keep the closest `trim_percentage` fraction of bags to the median
/// 7. Robust averaging: Geometric mean of $a_{\pm}$ factors on kept set
/// 8. Reconstruct (b,d) from averaged $a_{\pm}$
/// 9. Post-aggregation normalization (via `l2_identify`)
/// 10. Geometric mean of lambdas
///
/// Steps 4-6 are skipped when `trim_percentage >= 1.0`, since every bag is kept anyway.
///
/// # Arguments
/// * `grids` - Input bagged grids (one per bag)
/// * `points` - Training data points (n × p matrix); column `j` is binned along axis `j`
/// * `_weights` - Optional weights for each point; currently unused
/// * `trim_percentage` - Fraction of bags to keep (default: 0.9 = 90%)
///
/// # Returns
/// Aggregated `FittedTreeGrid` model. With a single input grid, a clone of that grid is
/// returned unchanged.
///
/// # Panics
/// * If `grids` is empty.
/// * If `points` contains NaN (the per-axis sort cannot order it).
/// * If a sorted value lands past the last bin, i.e. an axis of the aligned grid does not
///   have one more interval than it has splits.
/// * Presumably also if `points` has fewer columns than the aligned grid has axes, since
///   each axis reads `points.column(axis)` — confirm against ndarray's contract.
pub fn aggregate_bagged_two_tensor(
    grids: &[FittedTreeGrid],
    points: ArrayView2<f64>,
    _weights: Option<ArrayView1<f64>>, // TODO: Use for weighted bin weights if provided
    trim_percentage: f64,
) -> FittedTreeGrid {
    if grids.is_empty() {
        panic!("Cannot aggregate empty grid list");
    }
    if grids.len() == 1 {
        // Single bag: nothing to aggregate. The clone is required because the function
        // returns an owned grid.
        return grids[0].clone();
    }

    // Step 1: Grid alignment - refine all grids to union grid
    let mut aligned_grids = refine_grids_to_union_two_tensor(grids);

    // Step 2: Recount observations per bin on the union grid.
    // All aligned grids share one structure, so the first grid's splits/intervals are used.
    // The counts are recomputed from `points` rather than read from the aligned grids
    // (same counting logic as combine_grids.rs).
    let num_axes = aligned_grids[0].intervals.len();
    let mut union_observation_counts: Vec<Vec<usize>> = Vec::with_capacity(num_axes);
    for axis in 0..num_axes {
        let n_bins = aligned_grids[0].intervals[axis].len();
        let mut counts: Vec<usize> = vec![0; n_bins];
        let splits = &aligned_grids[0].splits[axis];

        // Sort the column once so a single forward sweep over the splits bins every value.
        let mut vals: Vec<f64> = points.column(axis).iter().copied().collect();
        vals.sort_by(|a, b| a.partial_cmp(b).expect("points must not contain NaN"));

        let mut s_idx: usize = 0;
        let b_len = splits.len();
        for v in vals {
            // A value equal to a split belongs to the bin on the split's right.
            while s_idx < b_len && v >= splits[s_idx] {
                s_idx += 1;
            }
            counts[s_idx] += 1;
        }

        union_observation_counts.push(counts);
    }

    // Step 3: Per-bag canonicalization (remove per-axis constant drift).
    // The recounted bins are passed explicitly; the grids' own observation_counts are not
    // consulted by the canonicalization or the distance computation.
    for grid in &mut aligned_grids {
        canonicalize_bag_per_axis(grid, &union_observation_counts);
    }

    // Steps 4-6: choose the bags that enter the average.
    let kept_grids: Vec<&FittedTreeGrid> = if trim_percentage >= 1.0 {
        // Every bag is kept, so the pairwise distances would not influence the result.
        log::info!(
            "trim_percentage is {:.2}, combining all {} grids without distance computation",
            trim_percentage,
            aligned_grids.len()
        );
        let candidate_indices: Vec<(usize, f64)> =
            (0..aligned_grids.len()).map(|idx| (idx, 0.0)).collect();
        log_combination_choice("BaggedTwoTensor", None, &candidate_indices);
        aligned_grids.iter().collect()
    } else {
        let bin_weights = precompute_bin_weights(&union_observation_counts);
        let precomputed_grids: Vec<PrecomputedGrid> = aligned_grids
            .iter()
            .map(precompute_log_components)
            .collect();

        // Step 4 & 5: Compute distances and select median (returns distance matrix for reuse)
        let (median_index, distance_matrix) =
            select_median_tensor_component_shape(&precomputed_grids, &bin_weights);

        // Step 6: Trim outliers (keep closest trim_percentage) - reuse distance matrix
        let kept_indices = trim_outliers(
            &aligned_grids,
            median_index,
            &distance_matrix,
            trim_percentage,
        );

        // Log combination choice: kept bags with their distances to median
        let candidate_indices: Vec<(usize, f64)> = kept_indices
            .iter()
            .map(|&idx| (idx, distance_matrix[idx][median_index]))
            .collect();
        log_combination_choice("BaggedTwoTensor", Some(median_index), &candidate_indices);

        kept_indices.iter().map(|&i| &aligned_grids[i]).collect()
    };

    // Step 7: Robust averaging - geometric mean of a_± factors on kept set
    let mut a_plus_candidates: Vec<Vec<Vec<f64>>> = Vec::with_capacity(kept_grids.len());
    let mut a_minus_candidates: Vec<Vec<Vec<f64>>> = Vec::with_capacity(kept_grids.len());
    let mut lambda_plus_candidates: Vec<f64> = Vec::with_capacity(kept_grids.len());
    let mut lambda_minus_candidates: Vec<f64> = Vec::with_capacity(kept_grids.len());

    for grid in &kept_grids {
        let (a_plus, a_minus) = compute_a_factors_from_bd(&grid.backbone_values, &grid.tilt_values);
        a_plus_candidates.push(a_plus);
        a_minus_candidates.push(a_minus);
        lambda_plus_candidates.push(grid.lambda_plus);
        lambda_minus_candidates.push(grid.lambda_minus);
    }

    // Uniform weights for kept bags
    let kept_weights: Vec<f64> = vec![1.0; kept_grids.len()];

    let (combined_a_plus, combined_a_minus) =
        geometric_mean_combine_a_factors(&a_plus_candidates, &a_minus_candidates, &kept_weights);

    // Step 8: Reconstruct (b,d) from averaged a_±
    let (mut combined_backbone, mut combined_tilt) =
        convert_a_factors_to_bd(&combined_a_plus, &combined_a_minus);

    // Step 9: Post-aggregation normalization, applied in place to the combined components.
    // `l2_identify` also takes lambdas by mutable reference; scratch values are passed and
    // then discarded because the aggregated lambdas come from step 10.
    // NOTE(review): any rescaling that `l2_identify` writes into the scratch lambdas is
    // dropped here — confirm this is intended.
    let mut scratch_lambda_plus = 1.0;
    let mut scratch_lambda_minus = 1.0;
    l2_identify(
        &mut combined_backbone,
        &mut combined_tilt,
        &union_observation_counts,
        &mut scratch_lambda_plus,
        &mut scratch_lambda_minus,
    );

    // Step 10: Geometric mean of lambdas on the kept set (uniform weights)
    let (combined_lambda_plus, combined_lambda_minus) =
        combine_lambdas_geometric_mean(&lambda_plus_candidates, &lambda_minus_candidates, None);

    // Create aggregated FittedTreeGrid on the union grid structure
    FittedTreeGrid::new_two_tensor(
        aligned_grids[0].splits.clone(),
        union_observation_counts,
        aligned_grids[0].intervals.clone(),
        combined_backbone,
        combined_tilt,
        combined_lambda_plus,
        combined_lambda_minus,
    )
}

/// Unit tests for the private aggregation helpers and the public entry point.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build a two-tensor grid for tests; every bin is given an observation count of 10.
    fn create_simple_grid(
        backbone: Vec<Vec<f64>>,
        tilt: Vec<Vec<f64>>,
        lambda_plus: f64,
        lambda_minus: f64,
        splits: Vec<Vec<f64>>,
        intervals: Vec<Vec<(f64, f64)>>,
    ) -> FittedTreeGrid {
        let observation_counts: Vec<Vec<usize>> =
            backbone.iter().map(|axis| vec![10; axis.len()]).collect();
        FittedTreeGrid::new_two_tensor(
            splits,
            observation_counts,
            intervals,
            backbone,
            tilt,
            lambda_plus,
            lambda_minus,
        )
    }

    /// Two grids with identical backbone and tilt must be at distance zero.
    #[test]
    fn test_component_shape_distance_identical() {
        let intervals = vec![vec![(0.0, 1.0)], vec![(0.0, 1.0)]];
        let splits = vec![vec![], vec![]];
        let grid1 = create_simple_grid(
            vec![vec![1.0], vec![1.0]],
            vec![vec![0.0], vec![0.0]],
            1.0,
            0.5,
            splits.clone(),
            intervals.clone(),
        );
        let grid2 = create_simple_grid(
            vec![vec![1.0], vec![1.0]],
            vec![vec![0.0], vec![0.0]],
            1.0,
            0.5,
            splits,
            intervals.clone(),
        );
        let bin_weights = precompute_bin_weights(&grid1.observation_counts);
        let grid1_pre = precompute_log_components(&grid1);
        let grid2_pre = precompute_log_components(&grid2);
        let distance =
            compute_component_shape_distance_precomputed(&grid1_pre, &grid2_pre, &bin_weights);
        assert!(
            distance < 1e-10,
            "Identical grids should have distance ≈ 0, got {}",
            distance
        );
    }

    /// With a single bag, that bag is the medoid.
    #[test]
    fn test_select_median_single_bag() {
        let intervals = vec![vec![(0.0, 1.0)]];
        let splits = vec![vec![]];
        let grid = create_simple_grid(
            vec![vec![1.0]],
            vec![vec![0.0]],
            1.0,
            0.5,
            splits,
            intervals.clone(),
        );
        let grids = [grid];
        let observation_counts = vec![vec![10]];

        let bin_weights = precompute_bin_weights(&observation_counts);
        let precomputed_grids: Vec<PrecomputedGrid> =
            grids.iter().map(precompute_log_components).collect();
        let (median_index, _distance_matrix) =
            select_median_tensor_component_shape(&precomputed_grids, &bin_weights);
        assert_eq!(median_index, 0);
    }

    /// Two identical grids plus one differing grid, trimmed at 90%.
    #[test]
    fn test_trim_outliers() {
        let intervals = vec![vec![(0.0, 1.0)]];
        let splits = vec![vec![]];
        let grid1 = create_simple_grid(
            vec![vec![1.0]],
            vec![vec![0.0]],
            1.0,
            0.5,
            splits.clone(),
            intervals.clone(),
        );
        let grid2 = create_simple_grid(
            vec![vec![1.0]],
            vec![vec![0.0]],
            1.0,
            0.5,
            splits.clone(),
            intervals.clone(),
        );
        let grid3 = create_simple_grid(
            vec![vec![2.0]],
            vec![vec![0.5]],
            2.0,
            1.0,
            splits,
            intervals,
        );
        let grids = vec![grid1, grid2, grid3];
        let observation_counts = vec![vec![10]];

        let bin_weights = precompute_bin_weights(&observation_counts);
        let precomputed_grids: Vec<PrecomputedGrid> =
            grids.iter().map(precompute_log_components).collect();
        let (median_index, distance_matrix) =
            select_median_tensor_component_shape(&precomputed_grids, &bin_weights);
        let kept = trim_outliers(&grids, median_index, &distance_matrix, 0.9);

        // keep_count = ceil(0.9 * 3) = ceil(2.7) = 3, so all three bags are kept here.
        // The assertions are deliberately loose: at least two bags survive, and at least
        // one of the identical pair (grids 0 and 1) is among them.
        assert!(kept.len() >= 2);
        assert!(kept.contains(&0) || kept.contains(&1));
    }

    /// Uniformly weighted geometric means of known lambda triples.
    #[test]
    fn test_combine_lambdas_geometric_mean() {
        let lambda_plus = vec![1.0, 2.0, 4.0];
        let lambda_minus = vec![0.5, 1.0, 2.0];

        let (combined_plus, combined_minus) =
            combine_lambdas_geometric_mean(&lambda_plus, &lambda_minus, None);

        // Geometric mean of [1, 2, 4] = (1 * 2 * 4)^(1/3) = 8^(1/3) = 2.0
        let product: f64 = 1.0 * 2.0 * 4.0;
        let expected_plus = product.powf(1.0 / 3.0);
        assert!((combined_plus - expected_plus).abs() < 1e-10);

        // Geometric mean of [0.5, 1.0, 2.0] = (0.5 * 1.0 * 2.0)^(1/3) = 1.0^(1/3) = 1.0
        assert!((combined_minus - 1.0).abs() < 1e-10);
    }

    /// Aggregating a single bag goes through the early-return path.
    #[test]
    fn test_aggregate_bagged_two_tensor_single_bag() {
        let intervals = vec![vec![(0.0, 1.0)]];
        let splits = vec![vec![]];
        let grid = create_simple_grid(
            vec![vec![1.0]],
            vec![vec![0.0]],
            1.0,
            0.5,
            splits,
            intervals,
        );
        let grids = vec![grid.clone()];
        let points = Array2::from_shape_vec(
            (10, 1),
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95],
        )
        .unwrap();

        let aggregated = aggregate_bagged_two_tensor(&grids, points.view(), None, 0.9);

        // A single bag is returned as a clone; only the number of axes is checked here.
        assert_eq!(aggregated.backbone_values.len(), grid.backbone_values.len());
    }
}
