use super::reference_grid::ReferenceGrid;
use crate::grid::FittedTreeGrid;
use ndarray::ArrayView2;

/// Turns a list of split points into the `(lower, upper)` bounds of the bins they delimit.
///
/// For `n` split points the result holds `n + 1` pairs, in order:
/// - the first pair is `(-∞, splits[0])`,
/// - each following pair is `(splits[i], splits[i + 1])`,
/// - the last pair is `(splits[n - 1], +∞)`.
///
/// An empty `splits` slice yields the single pair `(-∞, +∞)`.
///
/// Only the bounds are produced here; the slice is used as given (no sorting or
/// de-duplication is performed).
fn derive_intervals(splits: &[f64]) -> Vec<(f64, f64)> {
    // Lower bounds are the splits shifted right by one slot, with -∞ in front.
    let lower_bounds = std::iter::once(f64::NEG_INFINITY).chain(splits.iter().copied());
    // Upper bounds are the splits themselves, with +∞ appended.
    let upper_bounds = splits.iter().copied().chain(std::iter::once(f64::INFINITY));

    // Both sequences have exactly `splits.len() + 1` items, so zipping loses nothing.
    lower_bounds.zip(upper_bounds).collect()
}

/// Counts, for every column of `data`, how many values fall into each bin defined by
/// that column's split points.
///
/// Column `axis` is binned with `axes_splits[axis]`. A value `v` is assigned to bin
/// `splits.partition_point(|&s| s <= v)`, i.e. the number of split points that are
/// `<= v`, so a value equal to a split point lands in the bin above it. An axis with
/// `n` split points therefore has `n + 1` bins (a single bin when there are none).
/// Per the note that accompanied the previous implementation, this is the same rule
/// `predict_single` uses.
///
/// Each split vector is assumed to be sorted in ascending order (required by
/// `partition_point`; the earlier pointer-walk implementation relied on the same
/// ordering).
///
/// NaN values compare false against every split point and are therefore counted in
/// bin 0 instead of aborting the computation.
///
/// # Returns
/// One count vector per column of `data`, in column order.
///
/// # Panics
/// Panics if `axes_splits` holds fewer entries than `data` has columns.
fn compute_histogram(data: ArrayView2<f64>, axes_splits: &[Vec<f64>]) -> Vec<Vec<usize>> {
    let n_axes = data.ncols();
    assert!(
        axes_splits.len() >= n_axes,
        "compute_histogram: data has {} columns but only {} split vectors were provided",
        n_axes,
        axes_splits.len()
    );

    axes_splits
        .iter()
        // Only the first `n_axes` split vectors correspond to a data column.
        .take(n_axes)
        .enumerate()
        .map(|(axis, splits)| {
            // n split points delimit n + 1 bins; zero split points give one bin.
            let mut axis_counts = vec![0usize; splits.len() + 1];
            for &v in data.column(axis).iter() {
                // The result is at most `splits.len()`, so the index is always in range.
                let bin = splits.partition_point(|&s| s <= v);
                axis_counts[bin] += 1;
            }
            axis_counts
        })
        .collect()
}

/// Builds a [`FittedTreeGrid`] from per-axis consensus vectors so that downstream code
/// can treat the consensus model like any other fitted grid.
///
/// The grid reuses the split points of `ref_grid` unchanged. The consensus vectors are
/// passed on as the backbone values, and every tilt value is set to zero.
///
/// # Arguments
/// * `ref_grid` - Reference grid whose `axes_splits` define the bins on each axis
/// * `consensus_components` - One vector of consensus values per axis
/// * `consensus_scale` - The consensus scaling factor, forwarded to the grid constructor
/// * `original_data` - Training data; histogrammed per column to obtain the
///   observation counts of each bin
///
/// # Returns
/// A `FittedTreeGrid` representing the consensus model.
///
/// # Panics
/// Panics if the number of consensus vectors differs from the number of axes in
/// `ref_grid`, or if any consensus vector's length is not the number of split points
/// on its axis plus one.
pub fn reconstruct_grid(
    ref_grid: &ReferenceGrid,
    consensus_components: Vec<Vec<f64>>,
    consensus_scale: f64,
    original_data: ArrayView2<f64>,
) -> FittedTreeGrid {
    let axes_splits = &ref_grid.axes_splits;
    let axis_count = axes_splits.len();

    // Bin occupancy of the training data under the reference splits.
    let observation_counts = compute_histogram(original_data, axes_splits);

    // One list of (lower, upper) bin bounds per axis.
    let mut intervals: Vec<Vec<(f64, f64)>> = Vec::with_capacity(axis_count);
    for splits in axes_splits.iter() {
        intervals.push(derive_intervals(splits));
    }

    // There must be exactly one consensus vector per axis.
    assert_eq!(
        consensus_components.len(),
        axis_count,
        "Consensus components must match number of axes"
    );

    // Each consensus vector needs one value per bin: n splits delimit n + 1 bins.
    for (axis, components) in consensus_components.iter().enumerate() {
        let actual_len = components.len();
        let expected_len = axes_splits[axis].len() + 1;
        assert_eq!(
            actual_len,
            expected_len,
            "Axis {}: consensus component length {} doesn't match expected {}",
            axis,
            actual_len,
            expected_len
        );
    }

    // Two-tensor form: the consensus values act as the backbone, paired with an
    // all-zero tilt vector of matching length on every axis.
    let mut tilt_values: Vec<Vec<f64>> = Vec::with_capacity(consensus_components.len());
    for axis_values in &consensus_components {
        tilt_values.push(vec![0.0; axis_values.len()]);
    }

    FittedTreeGrid::new_two_tensor(
        axes_splits.clone(),
        observation_counts,
        intervals,
        consensus_components,
        tilt_values,
        consensus_scale,
        // Last constructor argument is fixed at zero here; its meaning is defined by
        // `new_two_tensor` (not visible in this file).
        0.0,
    )
}
