/// Computes the dot product of two equally sized vectors.
///
/// The length check is a `debug_assert_eq!`, so it only fires in builds with
/// debug assertions enabled; otherwise `zip` stops at the shorter slice.
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "Vectors must have same length for dot product"
    );
    let pairwise_products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    pairwise_products.sum()
}

/// Computes the L2 (Euclidean) norm of a vector: the square root of the sum
/// of its squared components.
fn l2_norm(vec: &[f64]) -> f64 {
    let sum_of_squares: f64 = vec.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}

/// Multiplies every element of `vec` by `scale`, modifying the slice in place.
fn scale_vector(vec: &mut [f64], scale: f64) {
    vec.iter_mut().for_each(|element| *element *= scale);
}

/// Performs the axpy update `y = alpha * x + y`, accumulating into `y` in place.
///
/// The length check is a `debug_assert_eq!`, so it only fires in builds with
/// debug assertions enabled; otherwise only the overlapping prefix is updated.
fn axpy(alpha: f64, x: &[f64], y: &mut [f64]) {
    debug_assert_eq!(x.len(), y.len(), "Vectors must have same length for axpy");
    x.iter()
        .zip(y.iter_mut())
        .for_each(|(&source, target)| *target += alpha * source);
}

/// Returns the index of the element with the largest absolute value.
///
/// * Ties resolve to the last of the tied indices (the same choice
///   `Iterator::max_by` makes).
/// * NaN entries are skipped rather than causing a panic.
/// * Returns `0` when `values` is empty or contains only NaN.
fn argmax_abs(values: &[f64]) -> usize {
    let mut best_idx = 0;
    let mut best_magnitude = f64::NEG_INFINITY;
    for (idx, value) in values.iter().enumerate() {
        let magnitude = value.abs();
        // Any comparison against NaN is false, so NaN entries never become
        // the running maximum. `>=` (not `>`) lets a later tie replace an
        // earlier one.
        if magnitude >= best_magnitude {
            best_idx = idx;
            best_magnitude = magnitude;
        }
    }
    best_idx
}

/// Performs Alternating Least Squares (ALS) optimization to find the "Consensus Rank-1 Tensor"
/// from the bagged projections.
///
/// The consensus is seeded with the components of the bag that `argmax_abs` selects from
/// `bag_lambdas`. Each ALS iteration then sweeps the axes one at a time: every bag votes with
/// weight `lambda[b] * Product_{j != k} <bag_vec_j, h_j>`, the weighted axis-`k` vectors of all
/// bags are summed, and the sum is rescaled to unit L2 norm.
///
/// # Arguments
/// * `bagged_components` - B bags, P axes, K bins. dimensions: [bag][axis][bin]
/// * `bag_lambdas` - The projected lambdas from projection. dimensions: [bag]
/// * `n_iterations` - Number of ALS iterations to perform
///
/// # Returns
/// A tuple of:
/// - `Vec<Vec<f64>>`: Consensus components, one per axis. An axis is rescaled to unit L2 norm by
///   every update whose weighted sum has a norm above `1e-12`; a degenerate update leaves the
///   axis at its previous value. With `n_iterations == 0` the seed bag's components are
///   returned unchanged.
/// - `f64`: The final consensus scale (magnitude), averaged over the bags
///
/// An empty `bagged_components` yields `(Vec::new(), 0.0)`.
///
/// # Panics
/// Panics if `bag_lambdas` does not contain exactly one entry per bag, or if the bags do not
/// all share the same number of axes and the same number of bins per axis.
pub fn compute_consensus(
    bagged_components: &[Vec<Vec<f64>>],
    bag_lambdas: &[f64],
    n_iterations: usize,
) -> (Vec<Vec<f64>>, f64) {
    // An update whose norm does not exceed this is treated as degenerate and discarded.
    const MIN_UPDATE_NORM: f64 = 1e-12;

    let n_bags = bagged_components.len();
    if n_bags == 0 {
        return (Vec::new(), 0.0);
    }

    // Fail early with a clear message instead of an index-out-of-bounds panic deep inside
    // the loops (too few lambdas) or silently ignoring trailing entries (too many).
    assert_eq!(
        bag_lambdas.len(),
        n_bags,
        "bag_lambdas must have one entry per bag"
    );

    // The first bag defines the expected structure; the bin count may vary between axes.
    let first_bag = &bagged_components[0];
    let n_axes = first_bag.len();
    let bins_per_axis: Vec<usize> = first_bag.iter().map(Vec::len).collect();

    // Validate all bags have same structure
    for bag in bagged_components {
        assert_eq!(bag.len(), n_axes, "All bags must have same number of axes");
        for (axis_vec, &expected_bins) in bag.iter().zip(&bins_per_axis) {
            assert_eq!(
                axis_vec.len(),
                expected_bins,
                "All bags must have same number of bins per axis"
            );
        }
    }

    // --- Step 1: Initialize Consensus ---
    // Seed with the bag that has the largest absolute lambda.
    let best_bag_idx = argmax_abs(bag_lambdas);
    let mut h = bagged_components[best_bag_idx].clone();

    // --- Step 2: ALS Loop ---
    for _ in 0..n_iterations {
        for k in 0..n_axes {
            // h_k_new = Sum_{b} ( w[b] * bag_vec_k )
            // `h` is only written after all bags have voted, so every weight in this sweep
            // step is computed against the same consensus.
            let mut h_new_k = vec![0.0; bins_per_axis[k]];

            for (bag, &lambda) in bagged_components.iter().zip(bag_lambdas) {
                // Alignment of this bag with the consensus on every axis except `k`.
                let alignment_score = bag
                    .iter()
                    .zip(&h)
                    .enumerate()
                    .filter(|&(j, _)| j != k)
                    .fold(1.0, |score, (_, (bag_vec, h_vec))| {
                        score * dot_product(bag_vec, h_vec)
                    });

                // Sign correction: a bag that is anti-aligned with the consensus on another
                // axis gets a negative dot product there, which flips the sign of its vote.
                let weight = lambda * alignment_score;
                axpy(weight, &bag[k], &mut h_new_k);
            }

            // Normalize; a (near-)zero or NaN norm keeps the previous h[k].
            let norm = l2_norm(&h_new_k);
            if norm > MIN_UPDATE_NORM {
                scale_vector(&mut h_new_k, 1.0 / norm);
                h[k] = h_new_k;
            }
        }
    }

    // --- Step 3: Compute Final Scale ---
    // Final Scale = (1/B) * Sum_{b} ( lambda[b] * Product_{j} <bag_vec_j, h_j> )
    let mut final_total_energy = 0.0;
    for (bag, &lambda) in bagged_components.iter().zip(bag_lambdas) {
        let total_overlap = bag
            .iter()
            .zip(&h)
            .fold(1.0, |overlap, (bag_vec, h_vec)| {
                overlap * dot_product(bag_vec, h_vec)
            });
        final_total_energy += lambda * total_overlap;
    }

    let avg_scale = final_total_energy / (n_bags as f64);

    (h, avg_scale)
}
