use super::grid::FittedTreeGrid;
pub use fitter::fit_ensemble;
use ndarray::{Array1, ArrayView2};
use serde::{Deserialize, Serialize};
pub mod params;

mod aggregate_bagged;
mod combine_grids;
mod fitter;
mod projection;
mod reconstruction;
mod reference_grid;
mod similarity;
mod tensor_power;

/// A collection of fitted tree grids together with the grid and coefficients
/// used to turn them into a single prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeGridFamily {
    /// Every member grid. `Aggregation::Mean` / `Aggregation::GeometricMean`
    /// aggregate over all of them, and `candidate_indices` indexes into this vector.
    pub tree_grids: Vec<FittedTreeGrid>,
    /// Grid whose two-tensor components (`f+`, `f-`) are evaluated by `predict`
    /// and by the `Aggregation::Combined` unscaled prediction.
    pub primary_tree_grid: FittedTreeGrid,
    /// Indices into `tree_grids` selecting the grids to combine in
    /// `combine_candidates_two_tensor_geometric_mean`; `None` selects all grids.
    pub candidate_indices: Option<Vec<usize>>,
    /// Strategy used by `predict_unscaled` to aggregate predictions.
    pub aggregation_method: Aggregation,
    /// Coefficient applied to `f+` in `predict`; treated as 1.0 when `None`.
    pub scaling_plus: Option<f64>,
    /// Coefficient applied to `-f-` in `predict`; treated as 0.0 when `None`.
    pub scaling_minus: Option<f64>,
    /// Not read anywhere in this file; both constructors initialise it to `None`.
    pub energy: Option<f64>,
}

/// How a `TreeGridFamily` aggregates its grids in `predict_unscaled`.
///
/// The scaled `predict` path does not branch on the variant.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum Aggregation {
    /// Arithmetic mean of the predictions of all grids in `tree_grids`.
    Mean,
    /// Per-point combination of all grid predictions via
    /// `combine_grids::geometric_mean_combiner`.
    GeometricMean,
    /// `f+ - f-` evaluated on the primary tree grid only.
    Combined,
}

impl Aggregation {
    /// Scaled prediction for every row of `x`.
    ///
    /// Evaluates the raw `f+` / `f-` components of `tgf.primary_tree_grid`
    /// directly (not through the grid's own `predict`) and returns
    /// `scaling_plus * f+ + scaling_minus * (-f-)`.
    ///
    /// A missing `scaling_plus` is treated as 1.0 and a missing `scaling_minus`
    /// as 0.0. The variant of `self` is not consulted.
    fn predict(&self, x: ArrayView2<f64>, tgf: &TreeGridFamily) -> Array1<f64> {
        let (raw_plus, raw_minus) =
            extract_two_tensor_predictions_unscaled(&tgf.primary_tree_grid, x);

        let coef_plus = tgf.scaling_plus.unwrap_or(1.0);
        // NOTE(review): with this default the `f-` term vanishes for families that never
        // had scalings assigned, whereas `predict_unscaled` for `Combined` returns
        // `f+ - f-` — confirm the asymmetry is intended.
        let coef_minus = tgf.scaling_minus.unwrap_or(0.0);

        // `coef_minus` multiplies the negated `f-` values, so the sign flip stays
        // on the values rather than on the coefficient.
        let plus_term = &raw_plus * coef_plus;
        let minus_term = &(-raw_minus) * coef_minus;
        plus_term + minus_term
    }

    /// Unscaled prediction for every row of `x`, aggregated according to `self`.
    ///
    /// * `Mean` – arithmetic mean over the predictions of all `tgf.tree_grids`.
    /// * `GeometricMean` – for each point, the per-grid predictions are handed to
    ///   `combine_grids::geometric_mean_combiner`.
    /// * `Combined` – `f+ - f-` of `tgf.primary_tree_grid`.
    fn predict_unscaled(&self, x: ArrayView2<f64>, tgf: &TreeGridFamily) -> Array1<f64> {
        let n_points = x.nrows();
        let grids = &tgf.tree_grids;
        match self {
            Self::Mean => {
                let mut total = Array1::zeros(n_points);
                for grid in grids {
                    total += &grid.predict(x);
                }
                total / grids.len() as f64
            }
            Self::GeometricMean => {
                let per_grid: Vec<Array1<f64>> =
                    grids.iter().map(|grid| grid.predict(x)).collect();

                // One scratch buffer holds the values of a single point across all
                // grids (in grid order); it is refilled for every point.
                let mut point_values: Vec<f64> = Vec::with_capacity(per_grid.len());
                Array1::from_shape_fn(n_points, |i| {
                    point_values.clear();
                    point_values.extend(per_grid.iter().map(|pred| pred[i]));
                    combine_grids::geometric_mean_combiner(&point_values[..])
                })
            }
            Self::Combined => {
                let (raw_plus, raw_minus) =
                    extract_two_tensor_predictions_unscaled(&tgf.primary_tree_grid, x);
                raw_plus - raw_minus
            }
        }
    }
}

/// Evaluate the raw `f+` and `f-` components of a two-tensor grid on each row of `x`.
///
/// For a row whose value on axis `j` falls into bin `k_j`:
///
/// ```text
/// f+ = lambda_plus  * prod_j ( backbone[j][k_j] * exp( tilt[j][k_j]) )
/// f- = lambda_minus * prod_j ( backbone[j][k_j] * exp(-tilt[j][k_j]) )
/// ```
///
/// The bin index is the number of entries of `grid.splits[j]` that are `<=` the
/// value, clamped to the last entry of `grid.backbone_values[j]`.
///
/// Only `lambda_plus`, `lambda_minus`, `splits`, `backbone_values` and
/// `tilt_values` are read; no other scaling is applied here. Callers such as
/// `Aggregation::predict` apply `scaling_plus` / `scaling_minus` afterwards.
///
/// # Returns
/// `(f_plus, f_minus)`, each with one entry per row of `x`.
///
/// # Panics
/// Panics if `x` has more columns than the grid has axes, if an axis has no
/// backbone values, or if `tilt_values[j]` is shorter than the selected bin index.
pub fn extract_two_tensor_predictions_unscaled(
    grid: &FittedTreeGrid,
    x: ArrayView2<f64>,
) -> (Array1<f64>, Array1<f64>) {
    let (plus, minus): (Vec<f64>, Vec<f64>) = x
        .outer_iter()
        .map(|row| {
            // Start from the lambdas and multiply in one factor per axis.
            row.iter().enumerate().fold(
                (grid.lambda_plus, grid.lambda_minus),
                |(acc_plus, acc_minus), (axis, &value)| {
                    let raw_bin = grid.splits[axis].partition_point(|&split| split <= value);
                    let backbone = &grid.backbone_values[axis];
                    let bin = raw_bin.min(backbone.len() - 1);

                    let base = backbone[bin];
                    let tilt = grid.tilt_values[axis][bin];
                    (
                        acc_plus * (base * tilt.exp()),
                        acc_minus * (base * (-tilt).exp()),
                    )
                },
            )
        })
        .unzip();

    (Array1::from(plus), Array1::from(minus))
}

impl TreeGridFamily {
    pub fn get_tree_grids(&self) -> &Vec<FittedTreeGrid> {
        &self.tree_grids
    }

    pub fn get_primary_tree_grid(&self) -> &FittedTreeGrid {
        &self.primary_tree_grid
    }

    pub fn get_candidate_indices(&self) -> Option<&[usize]> {
        self.candidate_indices.as_deref()
    }

    pub fn new_exact(tree_grids: Vec<FittedTreeGrid>, aggregation_method: Aggregation) -> Self {
        let primary_tree_grid = tree_grids[0].clone();
        Self {
            tree_grids,
            primary_tree_grid,
            candidate_indices: None,
            aggregation_method,
            scaling_plus: None,
            scaling_minus: None,
            energy: None,
        }
    }

    pub fn new_ensemble(
        tree_grids: Vec<FittedTreeGrid>,
        primary_tree_grid: FittedTreeGrid,
        candidate_indices: Vec<usize>,
        aggregation_method: Aggregation,
    ) -> Self {
        let scaling_plus = primary_tree_grid.lambda_plus;
        let scaling_minus = primary_tree_grid.lambda_minus;
        Self {
            tree_grids,
            primary_tree_grid,
            candidate_indices: Some(candidate_indices),
            aggregation_method,
            scaling_plus: Some(scaling_plus),
            scaling_minus: Some(scaling_minus),
            energy: None,
        }
    }

    pub fn predict_unscaled(&self, x: ArrayView2<f64>) -> Array1<f64> {
        self.aggregation_method.predict_unscaled(x, self)
    }

    /// Combine candidate grids using geometric mean on two-tensor a_± factors.
    /// This updates the primary_tree_grid with the combined result.
    ///
    /// # Arguments
    /// * `x` - Training points (used for computing observation counts)
    /// * `weights` - Optional weights for each candidate grid (defaults to uniform)
    ///
    /// # Returns
    /// A new `TreeGridFamily` with the combined primary_tree_grid.
    pub fn combine_candidates_two_tensor_geometric_mean(
        &self,
        x: ArrayView2<f64>,
        weights: Option<&[f64]>,
    ) -> Self {
        use crate::family::combine_grids::combine_two_tensor_grids_geometric_mean;

        // Get candidate grids (if candidate_indices is set, use those; otherwise use all)
        let candidate_grids: Vec<&FittedTreeGrid> =
            if let Some(ref indices) = self.candidate_indices {
                indices.iter().map(|&idx| &self.tree_grids[idx]).collect()
            } else {
                self.tree_grids.iter().collect()
            };

        if candidate_grids.is_empty() {
            return self.clone();
        }

        // Convert to owned grids for combination
        let candidate_grids_owned: Vec<FittedTreeGrid> =
            candidate_grids.into_iter().cloned().collect();

        // Combine using geometric mean
        let combined_grid =
            combine_two_tensor_grids_geometric_mean(&candidate_grids_owned, weights, x);

        // Create new TreeGridFamily with combined primary grid
        Self {
            tree_grids: self.tree_grids.clone(),
            primary_tree_grid: combined_grid,
            candidate_indices: self.candidate_indices.clone(),
            aggregation_method: self.aggregation_method.clone(),
            scaling_plus: self.scaling_plus,
            scaling_minus: self.scaling_minus,
            energy: self.energy,
        }
    }
}

impl TreeGridFamily {
    /// Scaled prediction for every row of `x`.
    ///
    /// Delegates to the aggregation method, which combines the primary grid's
    /// `f+` / `f-` components using `scaling_plus` and `scaling_minus`.
    pub fn predict(&self, x: ArrayView2<f64>) -> Array1<f64> {
        let method = &self.aggregation_method;
        method.predict(x, self)
    }
}
