use ndarray::ArrayView2;

use crate::grid::FittedTreeGrid;

// --- Two-Tensor Geometric Mean Combination ---

/// Build the per-axis `a_+` / `a_-` factor tables from backbone and tilt values.
///
/// For every axis and bin: `a_+ = b * exp(d)` and `a_- = b * exp(-d)`, where
/// `b` is the backbone value floored at `1e-10` and `d` is the tilt value.
/// The exponent arguments `d` and `-d` are each capped at `50.0` before
/// exponentiation so the factors cannot overflow.
///
/// Returns `(a_plus, a_minus)`, each shaped like `backbone`. `tilt` is indexed
/// with the same `(axis, bin)` positions as `backbone`; a missing entry panics.
fn compute_a_factors_from_bd(
    backbone: &[Vec<f64>],
    tilt: &[Vec<f64>],
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    const EPSILON: f64 = 1e-10;
    // Largest argument handed to `exp`; keeps the result finite.
    const MAX_EXPONENT: f64 = 50.0;

    backbone
        .iter()
        .enumerate()
        .map(|(axis, backbone_axis)| {
            let (plus_axis, minus_axis): (Vec<f64>, Vec<f64>) = backbone_axis
                .iter()
                .enumerate()
                .map(|(bin, &raw_backbone)| {
                    let b = raw_backbone.max(EPSILON);
                    let d = tilt[axis][bin];
                    let grow = d.min(MAX_EXPONENT).exp();
                    let shrink = (-d).min(MAX_EXPONENT).exp();
                    (b * grow, b * shrink)
                })
                .unzip();
            (plus_axis, minus_axis)
        })
        .unzip()
}

/// Weighted geometric mean of `a_±` factor tables, evaluated in log-space.
///
/// For every axis and bin the combined factor is
/// `exp(sum_i(w_i * ln(a_i)) / sum_i(w_i))`, computed separately for the
/// `a_+` and the `a_-` tables. Each factor is floored at `1e-10` before the
/// logarithm, so zero, negative, or NaN inputs contribute `ln(1e-10)`.
///
/// The axis/bin layout is taken from the first `a_+` candidate and every
/// other candidate is indexed with it.
///
/// Returns `(combined_a_plus, combined_a_minus)`.
///
/// # Panics
///
/// Panics if there are no candidates, if `a_minus_candidates` or `weights`
/// differ in length from `a_plus_candidates`, if the total weight is NaN or
/// infinite, or if the total weight is not positive. Indexing also panics if
/// a candidate has fewer axes or bins than the first `a_+` candidate.
fn geometric_mean_combine_a_factors(
    a_plus_candidates: &[Vec<Vec<f64>>],
    a_minus_candidates: &[Vec<Vec<f64>>],
    weights: &[f64],
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    const EPSILON: f64 = 1e-10;
    const LOG_EPSILON: f64 = -23.025850929940457; // ln(1e-10)

    if a_plus_candidates.is_empty() {
        panic!("Cannot combine empty candidate grids");
    }
    // `zip` stops at the shorter side, so a length mismatch would silently
    // drop candidates or weights while the normaliser below still sums every
    // weight. Reject it up front instead of returning a biased mean.
    if a_minus_candidates.len() != a_plus_candidates.len() {
        panic!("a_minus candidates length must match a_plus candidates length");
    }
    if weights.len() != a_plus_candidates.len() {
        panic!("Weights length must match candidates length");
    }

    let layout = &a_plus_candidates[0];
    let total_weight: f64 = weights.iter().sum();

    // A NaN total compares false against `<= 0.0`, so it needs its own guard.
    if !total_weight.is_finite() {
        panic!("Total weight must be finite");
    }
    if total_weight <= 0.0 {
        panic!("Total weight must be positive");
    }

    // Weighted geometric mean of one (axis, bin) cell across all candidates.
    let weighted_geometric_mean = |candidates: &[Vec<Vec<f64>>], axis: usize, bin: usize| -> f64 {
        let log_sum = candidates
            .iter()
            .zip(weights)
            .fold(0.0_f64, |acc, (candidate, &weight)| {
                let floored = candidate[axis][bin].max(EPSILON);
                acc + weight * floored.ln().max(LOG_EPSILON)
            });
        (log_sum / total_weight).exp()
    };

    let mut combined_a_plus: Vec<Vec<f64>> = Vec::with_capacity(layout.len());
    let mut combined_a_minus: Vec<Vec<f64>> = Vec::with_capacity(layout.len());

    for (axis, layout_axis) in layout.iter().enumerate() {
        let n_bins = layout_axis.len();
        combined_a_plus.push(
            (0..n_bins)
                .map(|bin| weighted_geometric_mean(a_plus_candidates, axis, bin))
                .collect(),
        );
        combined_a_minus.push(
            (0..n_bins)
                .map(|bin| weighted_geometric_mean(a_minus_candidates, axis, bin))
                .collect(),
        );
    }

    (combined_a_plus, combined_a_minus)
}

/// Recover the `(backbone, tilt)` representation from `a_±` factors.
///
/// Inverting `a_+ = b * exp(d)` and `a_- = b * exp(-d)` gives
/// `b = sqrt(a_+ * a_-)` and `d = 0.5 * ln(a_+ / a_-)`.
///
/// Both factors are floored at `1e-10` first, and the ratio `a_+ / a_-` is
/// confined to `[1e-10, 1e10]` before taking the logarithm so the tilt stays
/// bounded.
///
/// Returns `(backbone, tilt)`, each shaped like `a_plus`. `a_minus` is indexed
/// with the same `(axis, bin)` positions; a missing entry panics.
fn convert_a_factors_to_bd(
    a_plus: &[Vec<f64>],
    a_minus: &[Vec<f64>],
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    const EPSILON: f64 = 1e-10;

    a_plus
        .iter()
        .enumerate()
        .map(|(axis, plus_axis)| {
            let (backbone_axis, tilt_axis): (Vec<f64>, Vec<f64>) = plus_axis
                .iter()
                .enumerate()
                .map(|(bin, &raw_plus)| {
                    let plus = raw_plus.max(EPSILON);
                    let minus = a_minus[axis][bin].max(EPSILON);

                    // `max` then `min` (rather than `clamp`) so that a NaN
                    // ratio collapses to a bound instead of propagating.
                    let bounded_ratio = (plus / minus).max(EPSILON).min(1.0 / EPSILON);

                    ((plus * minus).sqrt(), 0.5 * bounded_ratio.ln())
                })
                .unzip();
            (backbone_axis, tilt_axis)
        })
        .unzip()
}

/// Weighted arithmetic mean of the `lambda_+` and `lambda_-` scalars.
///
/// Each result is `sum_i(w_i * lambda_i) / sum_i(w_i)`.
///
/// Returns `(lambda_plus, lambda_minus)`.
///
/// # Panics
///
/// Panics if either candidate slice differs in length from `weights`, if the
/// total weight is NaN or infinite, or if the total weight is not positive.
fn combine_lambdas_weighted(
    lambda_plus_candidates: &[f64],
    lambda_minus_candidates: &[f64],
    weights: &[f64],
) -> (f64, f64) {
    // `zip` stops at the shorter side; a mismatch would silently drop entries
    // while the divisor still sums every weight.
    if lambda_plus_candidates.len() != weights.len()
        || lambda_minus_candidates.len() != weights.len()
    {
        panic!("Lambda candidates length must match weights length");
    }

    let total_weight: f64 = weights.iter().sum();
    // A NaN total compares false against `<= 0.0`, so it needs its own guard.
    if !total_weight.is_finite() {
        panic!("Total weight must be finite");
    }
    if total_weight <= 0.0 {
        panic!("Total weight must be positive");
    }

    let weighted_mean = |lambdas: &[f64]| -> f64 {
        lambdas
            .iter()
            .zip(weights)
            .map(|(&lambda, &weight)| lambda * weight)
            .sum::<f64>()
            / total_weight
    };

    (
        weighted_mean(lambda_plus_candidates),
        weighted_mean(lambda_minus_candidates),
    )
}

/// Refine every grid onto the union of all grids' split points.
///
/// For each axis the split points of all input grids are pooled, sorted, and
/// de-duplicated: a split closer than `1e-12` to the previously kept split is
/// dropped, so each cluster is represented by its smallest member. The union
/// intervals are the consecutive pairs of the boundary sequence
/// `-inf, s_0, ..., s_k, +inf`.
///
/// Each refined grid copies, for every union interval, the backbone and tilt
/// values of the original interval that contains it, and keeps the grid's own
/// `lambda_plus` / `lambda_minus`. If no containing interval is found a
/// warning is logged and the neutral values `backbone = 1.0`, `tilt = 0.0`
/// are used. Observation counts of the refined grids are all zero.
///
/// The number of axes is taken from the first grid. Returns an empty vector
/// when `grids` is empty.
///
/// # Panics
///
/// Panics while sorting if a NaN split point has to be compared.
pub(crate) fn refine_grids_to_union_two_tensor(grids: &[FittedTreeGrid]) -> Vec<FittedTreeGrid> {
    // Splits closer than this are treated as one and the same split point.
    const SPLIT_TOLERANCE: f64 = 1e-12;

    if grids.is_empty() {
        return Vec::new();
    }

    let num_axes = grids[0].intervals.len();
    let mut union_splits: Vec<Vec<f64>> = Vec::with_capacity(num_axes);
    let mut union_intervals: Vec<Vec<(f64, f64)>> = Vec::with_capacity(num_axes);

    // Build union splits and intervals for each axis
    for axis in 0..num_axes {
        let mut splits: Vec<f64> = grids
            .iter()
            .flat_map(|grid| grid.splits[axis].iter().copied())
            .collect();

        splits.sort_by(|a, b| a.partial_cmp(b).expect("split points must not be NaN"));
        splits.dedup_by(|a, b| (*a - *b).abs() < SPLIT_TOLERANCE);

        // Walk the boundaries -inf, s_0, ..., s_k, +inf pairwise.
        let mut intervals: Vec<(f64, f64)> = Vec::with_capacity(splits.len() + 1);
        let mut lower = f64::NEG_INFINITY;
        for &split in &splits {
            intervals.push((lower, split));
            lower = split;
        }
        intervals.push((lower, f64::INFINITY));

        union_splits.push(splits);
        union_intervals.push(intervals);
    }

    // Refine each grid to the union structure
    let mut refined_grids = Vec::with_capacity(grids.len());
    for grid in grids {
        let mut refined_backbone: Vec<Vec<f64>> = Vec::with_capacity(num_axes);
        let mut refined_tilt: Vec<Vec<f64>> = Vec::with_capacity(num_axes);
        let mut refined_observation_counts: Vec<Vec<usize>> = Vec::with_capacity(num_axes);

        for axis in 0..num_axes {
            let n_union_bins = union_intervals[axis].len();
            let mut backbone_axis = Vec::with_capacity(n_union_bins);
            let mut tilt_axis = Vec::with_capacity(n_union_bins);

            for &(union_a, union_b) in &union_intervals[axis] {
                // Find the original interval that contains this union interval.
                //
                // De-duplication keeps the smallest split of each cluster, so a
                // grid's own split may sit up to SPLIT_TOLERANCE above the union
                // boundary that represents it. The lower-bound test therefore
                // has to allow that slack; an exact comparison would find no
                // parent for the interval starting at a merged split. Upper
                // bounds never exceed the original split, so that test stays
                // exact. The `>=` clause comes first so that two infinite
                // bounds never reach the subtraction (which would yield NaN).
                let parent = grid.intervals[axis].iter().position(|&(orig_a, orig_b)| {
                    let lower_ok = union_a >= orig_a || orig_a - union_a < SPLIT_TOLERANCE;
                    lower_ok && union_b <= orig_b
                });

                match parent {
                    Some(orig_idx) => {
                        // Copy values from the parent interval
                        backbone_axis.push(grid.backbone_values[axis][orig_idx]);
                        tilt_axis.push(grid.tilt_values[axis][orig_idx]);
                    }
                    None => {
                        // Should not happen if the grid's intervals match its splits
                        log::warn!(
                            "Could not find parent interval for union interval [{}, {})",
                            union_a,
                            union_b
                        );
                        backbone_axis.push(1.0);
                        tilt_axis.push(0.0);
                    }
                }
            }

            refined_backbone.push(backbone_axis);
            refined_tilt.push(tilt_axis);
            refined_observation_counts.push(vec![0; n_union_bins]);
        }

        // Create refined grid (observation counts will be recomputed later if needed)
        let refined_grid = FittedTreeGrid::new_two_tensor(
            union_splits.clone(),
            refined_observation_counts,
            union_intervals.clone(),
            refined_backbone,
            refined_tilt,
            grid.lambda_plus,
            grid.lambda_minus,
        );

        refined_grids.push(refined_grid);
    }

    refined_grids
}

/// Combine several two-tensor grids into one via a weighted geometric mean of
/// their `a_±` factors.
///
/// Procedure:
/// 1. Refine all grids onto the union grid (shared split points).
/// 2. Derive the `a_±` factors from `(backbone, tilt)` for each refined grid.
/// 3. Take the weighted geometric mean of the `a_±` factors in log-space.
/// 4. Take the weighted arithmetic mean of the `lambda_±` scalars.
/// 5. Convert the combined `a_±` back to `(backbone, tilt)`.
/// 6. Count the `points` falling into each union interval and build the
///    resulting `FittedTreeGrid`.
///
/// `weights` defaults to uniform weights when `None`. Column `axis` of
/// `points` supplies the coordinates counted along that axis. When `grids`
/// holds a single grid, a clone of it is returned and neither `weights` nor
/// `points` is consulted.
///
/// # Panics
///
/// Panics if `grids` is empty or if `weights` and `grids` differ in length.
/// Sorting the coordinates panics when a NaN has to be compared.
pub fn combine_two_tensor_grids_geometric_mean(
    grids: &[FittedTreeGrid],
    weights: Option<&[f64]>,
    points: ArrayView2<f64>,
) -> FittedTreeGrid {
    match grids {
        [] => panic!("Cannot combine empty grid list"),
        // Single grid: return a copy
        [only] => return only.clone(),
        _ => {}
    }

    // Fall back to uniform weights when the caller supplied none
    let fallback_weights: Vec<f64>;
    let weights: &[f64] = match weights {
        Some(given) => given,
        None => {
            fallback_weights = vec![1.0; grids.len()];
            &fallback_weights
        }
    };

    if weights.len() != grids.len() {
        panic!("Weights length must match grids length");
    }

    // Step 1: bring every grid onto the same per-axis interval structure
    let refined_grids = refine_grids_to_union_two_tensor(grids);

    // Step 2: per-candidate a_± factors and lambda scalars
    let (a_plus_candidates, a_minus_candidates): (Vec<Vec<Vec<f64>>>, Vec<Vec<Vec<f64>>>) =
        refined_grids
            .iter()
            .map(|grid| compute_a_factors_from_bd(&grid.backbone_values, &grid.tilt_values))
            .unzip();
    let lambda_plus_candidates: Vec<f64> =
        refined_grids.iter().map(|grid| grid.lambda_plus).collect();
    let lambda_minus_candidates: Vec<f64> =
        refined_grids.iter().map(|grid| grid.lambda_minus).collect();

    // Step 3: geometric mean of the a_± factors
    let (combined_a_plus, combined_a_minus) =
        geometric_mean_combine_a_factors(&a_plus_candidates, &a_minus_candidates, weights);

    // Step 4: arithmetic mean of the lambdas
    let (combined_lambda_plus, combined_lambda_minus) =
        combine_lambdas_weighted(&lambda_plus_candidates, &lambda_minus_candidates, weights);

    // Step 5: back to (backbone, tilt)
    let (combined_backbone, combined_tilt) =
        convert_a_factors_to_bd(&combined_a_plus, &combined_a_minus);

    // Step 6: observation counts per union interval. All refined grids share
    // the same splits and intervals, so the first one serves as the template.
    let template = &refined_grids[0];
    let combined_observation_counts: Vec<Vec<usize>> = (0..template.intervals.len())
        .map(|axis| {
            let mut bin_counts: Vec<usize> = vec![0; template.intervals[axis].len()];
            let axis_splits = &template.splits[axis];

            // Visiting the coordinates in ascending order lets one forward-only
            // cursor over the splits locate every bin.
            let mut coords: Vec<f64> = points.column(axis).iter().copied().collect();
            coords.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let mut bin: usize = 0;
            for coord in coords {
                // A coordinate equal to a split belongs to the bin on its right
                while bin < axis_splits.len() && coord >= axis_splits[bin] {
                    bin += 1;
                }
                bin_counts[bin] += 1;
            }
            bin_counts
        })
        .collect();

    // scaling field is kept for backward compatibility but not used in two-tensor model

    FittedTreeGrid::new_two_tensor(
        template.splits.clone(),
        combined_observation_counts,
        template.intervals.clone(),
        combined_backbone,
        combined_tilt,
        combined_lambda_plus,
        combined_lambda_minus,
    )
}

/// Combine two-tensor grids by averaging, per interval, only the candidates
/// whose backbone lies in a window around the median backbone.
///
/// Procedure:
/// 1. Refine all grids onto the union grid (same intervals).
/// 2. For each axis and interval, rank the candidates by backbone value.
/// 3. Select a window of `ceil(tresh * n)` candidates (at least 1) centred on
///    the median rank, with `tresh` first confined to `[0, 1]`.
/// 4. Take the unweighted geometric mean of the selected candidates' `a_±`
///    factors in log-space (each factor floored at `1e-10`).
/// 5. Combine `lambda_±` by arithmetic mean over all grids with uniform
///    weights.
/// 6. Convert the combined `a_±` back to `(backbone, tilt)`, count the
///    `points` per union interval, and build the resulting grid.
///
/// When `grids` holds a single grid, a clone of it is returned and neither
/// `tresh` nor `points` is consulted.
///
/// # Panics
///
/// Panics if `grids` is empty. Sorting the backbone values or the point
/// coordinates panics when a NaN has to be compared.
pub fn combine_median_two_tensor_grids_geometric_mean(
    grids: &[FittedTreeGrid],
    tresh: f64,
    points: ArrayView2<f64>,
) -> FittedTreeGrid {
    const EPSILON: f64 = 1e-10;

    match grids {
        [] => panic!("Cannot combine empty grid list"),
        [only] => return only.clone(),
        _ => {}
    }

    // Confine the selection fraction to [0, 1]
    let tresh = if tresh > 1.0 {
        1.0
    } else if tresh <= 0.0 {
        0.0
    } else {
        tresh
    };

    // Step 1: bring every grid onto the union intervals
    let refined_grids = refine_grids_to_union_two_tensor(grids);

    // Step 2: per-candidate a_± factors and lambda scalars
    let (a_plus_candidates, a_minus_candidates): (Vec<Vec<Vec<f64>>>, Vec<Vec<Vec<f64>>>) =
        refined_grids
            .iter()
            .map(|grid| compute_a_factors_from_bd(&grid.backbone_values, &grid.tilt_values))
            .unzip();
    let lambda_plus_candidates: Vec<f64> =
        refined_grids.iter().map(|grid| grid.lambda_plus).collect();
    let lambda_minus_candidates: Vec<f64> =
        refined_grids.iter().map(|grid| grid.lambda_minus).collect();

    // All refined grids share splits and intervals; use the first as template.
    let template = &refined_grids[0];
    let num_axes = template.intervals.len();
    let mut combined_a_plus: Vec<Vec<f64>> = Vec::with_capacity(num_axes);
    let mut combined_a_minus: Vec<Vec<f64>> = Vec::with_capacity(num_axes);

    for axis in 0..num_axes {
        let n_bins = template.intervals[axis].len();
        let mut plus_axis: Vec<f64> = Vec::with_capacity(n_bins);
        let mut minus_axis: Vec<f64> = Vec::with_capacity(n_bins);

        for bin in 0..n_bins {
            // Rank the candidates by their backbone value in this cell
            let mut ranked: Vec<(f64, usize)> = refined_grids
                .iter()
                .enumerate()
                .map(|(grid_idx, grid)| (grid.backbone_values[axis][bin], grid_idx))
                .collect();
            ranked.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).unwrap());

            // Window of ranks centred on the median, pulled back inside the
            // slice if it would run past the end
            let n = ranked.len();
            let window = ((tresh * n as f64).ceil() as usize).max(1);
            let centred_start = (n / 2).saturating_sub(window / 2);
            let (start, end) = if centred_start + window > n {
                (n.saturating_sub(window), n)
            } else {
                (centred_start, centred_start + window)
            };
            let selected = &ranked[start..end];

            // Geometric mean of a_± over the selected candidates (log-space)
            let (log_sum_plus, log_sum_minus) = selected.iter().fold(
                (0.0_f64, 0.0_f64),
                |(sum_plus, sum_minus), &(_, grid_idx)| {
                    let plus = a_plus_candidates[grid_idx][axis][bin].max(EPSILON);
                    let minus = a_minus_candidates[grid_idx][axis][bin].max(EPSILON);
                    (sum_plus + plus.ln(), sum_minus + minus.ln())
                },
            );
            let count = selected.len() as f64;
            let (combined_p, combined_m) = if count > 0.0 {
                ((log_sum_plus / count).exp(), (log_sum_minus / count).exp())
            } else {
                (EPSILON, EPSILON)
            };

            plus_axis.push(combined_p);
            minus_axis.push(combined_m);
        }

        combined_a_plus.push(plus_axis);
        combined_a_minus.push(minus_axis);
    }

    // Lambdas: arithmetic mean over all grids with uniform weights
    // (this function accepts no weights parameter)
    let uniform_weights: Vec<f64> = vec![1.0; grids.len()];
    let (combined_lambda_plus, combined_lambda_minus) = combine_lambdas_weighted(
        &lambda_plus_candidates,
        &lambda_minus_candidates,
        &uniform_weights,
    );

    // Back to backbone and tilt
    let (combined_backbone, combined_tilt) =
        convert_a_factors_to_bd(&combined_a_plus, &combined_a_minus);

    // Observation counts per union interval
    let combined_observation_counts: Vec<Vec<usize>> = (0..num_axes)
        .map(|axis| {
            let mut bin_counts: Vec<usize> = vec![0; template.intervals[axis].len()];
            let axis_splits = &template.splits[axis];

            // Ascending coordinates allow a single forward-only cursor over
            // the splits
            let mut coords: Vec<f64> = points.column(axis).iter().copied().collect();
            coords.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let mut bin: usize = 0;
            for coord in coords {
                // A coordinate equal to a split belongs to the bin on its right
                while bin < axis_splits.len() && coord >= axis_splits[bin] {
                    bin += 1;
                }
                bin_counts[bin] += 1;
            }
            bin_counts
        })
        .collect();

    // scaling field is kept for backward compatibility but not used in two-tensor model

    FittedTreeGrid::new_two_tensor(
        template.splits.clone(),
        combined_observation_counts,
        template.intervals.clone(),
        combined_backbone,
        combined_tilt,
        combined_lambda_plus,
        combined_lambda_minus,
    )
}

/// Signed geometric mean of `values`; kept for `Aggregation::GeometricMean`
/// (prediction space aggregation).
///
/// The magnitude is `exp(mean(ln|v|))`. The sign is the `signum` of the sum
/// of the individual `signum`s, i.e. the majority sign, with a tie resolving
/// to positive. Returns `0.0` for an empty slice.
pub(super) fn geometric_mean_combiner(values: &[f64]) -> f64 {
    if values.is_empty() {
        return 0.0;
    }

    // One pass: tally the signs and accumulate the log-magnitudes together.
    let mut sign_balance = 0.0_f64;
    let mut log_total = 0.0_f64;
    for &value in values {
        sign_balance += value.signum();
        log_total += value.abs().ln();
    }

    let magnitude = (log_total / values.len() as f64).exp();
    sign_balance.signum() * magnitude
}
