use super::reference_grid::ReferenceGrid;
use crate::grid::FittedTreeGrid;

/// Returns the Euclidean (L2) norm of `vec`: the square root of the sum of its squared entries.
///
/// An empty slice has norm zero.
fn l2_norm(vec: &[f64]) -> f64 {
    let sum_of_squares: f64 = vec.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}

/// Multiplies every element of `vec` by `scale`, modifying the slice in place.
fn scale_vector(vec: &mut [f64], scale: f64) {
    vec.iter_mut().for_each(|element| *element *= scale);
}

/// Converts a tree grid into a set of vectors (one per axis) sampled on a reference grid.
///
/// For every axis of `ref_grid`, each reference bin center is located among the sorted
/// `grid.splits` of that axis (binary search). A center lying exactly on a split is assigned
/// to the interval on its right, and the index is clamped to the last backbone interval. The
/// sampled value is `backbone * cosh(tilt)` for that interval.
///
/// Each per-axis vector is then rescaled to unit L2 norm. The removed magnitudes are
/// multiplied into the returned scaling factor, so that
/// `new_lambda * Π‖returned_i‖ == grid.scaling * Π‖raw_i‖` for non-degenerate axes.
///
/// # Arguments
/// * `grid` - The FittedTreeGrid to project
/// * `ref_grid` - The ReferenceGrid providing the fixed coordinate system
///
/// # Returns
/// A tuple of:
/// - `Vec<Vec<f64>>`: Projected vectors, one per axis of `ref_grid`. Each has unit L2 norm,
///   except for degenerate axes (norm not above `1e-12`). Those are returned as all-zero
///   vectors, or left as-is if their norm is NaN.
/// - `f64`: The new lambda: `grid.scaling` times the product of the per-axis norms, or
///   `0.0` if any axis is degenerate.
///
/// # Panics
/// Panics if `grid` has fewer axes than `ref_grid`, or if the per-axis data of `grid`
/// (backbone / tilt values) is empty or too short for an axis that has reference centers.
pub fn project_grid_to_vectors(
    grid: &FittedTreeGrid,
    ref_grid: &ReferenceGrid,
) -> (Vec<Vec<f64>>, f64) {
    // Vectors whose L2 norm does not exceed this are treated as identically zero.
    const NORM_EPS: f64 = 1e-12;

    // 1. Sample the grid's per-axis shape at every reference bin center.
    let mut projected_vectors: Vec<Vec<f64>> = ref_grid
        .axes_centers
        .iter()
        .enumerate()
        .map(|(axis_idx, axis_centers)| {
            let splits = &grid.splits[axis_idx];
            let backbone = &grid.backbone_values[axis_idx];
            let tilts = &grid.tilt_values[axis_idx];
            // Splits beyond the last backbone value (if any) map onto the final interval.
            let last_interval = backbone.len().saturating_sub(1);
            axis_centers
                .iter()
                .map(|&center_val| {
                    let col_idx = splits
                        .partition_point(|&split| split <= center_val)
                        .min(last_interval);
                    backbone[col_idx] * tilts[col_idx].cosh()
                })
                .collect()
        })
        .collect();

    // 2. Normalize each axis to unit L2 norm and move all magnitude into `new_lambda`.
    let mut new_lambda = grid.scaling;
    for vec in &mut projected_vectors {
        let norm = l2_norm(vec);
        if norm > NORM_EPS {
            scale_vector(vec, 1.0 / norm);
            new_lambda *= norm;
        } else {
            // Degenerate axis: the whole projection is zero, so lambda collapses to zero.
            new_lambda = 0.0;
            if !norm.is_nan() {
                // Snap tiny residuals to exact zeros so every returned vector is either
                // unit-norm or all-zero. NaN vectors are kept so the NaN stays visible.
                vec.fill(0.0);
            }
        }
    }

    (projected_vectors, new_lambda)
}
