use ndarray::ArrayView2;

/// A fixed coordinate system (per-feature quantile splits) that provides a common reference
/// for converting tree grids with different split points into fixed-length vectors.
///
/// As built by `ReferenceGrid::from_data`, every axis `j` satisfies
/// `axes_centers[j].len() == axes_splits[j].len() + 1`: `k` thresholds delimit `k + 1`
/// intervals, and each interval gets exactly one center.
#[derive(Debug, Clone, PartialEq)]
pub struct ReferenceGrid {
    /// The split points defining the bins for each axis (dimension).
    /// `axes_splits[j]` is an ascending vector of thresholds for feature j
    /// (for example, the quantiles of the training data for that feature).
    pub axes_splits: Vec<Vec<f64>>,

    /// One representative point per bin, used for evaluating the component functions.
    /// `axes_centers[j][k]` stands for the k-th interval of axis j, in the order
    /// `(-∞, s0), (s0, s1), …, (s_last, +∞)`. Inner intervals use their midpoint; the
    /// unbounded outer intervals are represented by a point just outside the outermost
    /// split. Example: splits `[0.0, 1.0]` give centers `[-0.5, 0.5, 1.5]`.
    pub axes_centers: Vec<Vec<f64>>,
}

impl ReferenceGrid {
    /// Absolute tolerance under which two neighbouring sorted values count as the same split.
    const DEDUP_EPS: f64 = 1e-12;

    /// Creates a ReferenceGrid from training data by computing quantiles for each feature.
    ///
    /// For every column `j` of `data`, the non-NaN values are sorted and turned into split
    /// thresholds:
    /// * no usable values (empty column or all NaN) → a single placeholder split at `0.0`;
    /// * at most `n_bins` distinct values → every distinct value becomes a split;
    /// * otherwise → the nearest-rank quantiles `0, 1/n_bins, …, 1` (i.e. `n_bins + 1`
    ///   thresholds), with near-duplicates merged.
    ///
    /// Each axis then receives `splits.len() + 1` bin centers, one per interval
    /// `(-∞, s0), (s0, s1), …, (s_last, +∞)`.
    ///
    /// # Arguments
    /// * `data` - Training data matrix (n_samples x n_features). NaN entries are ignored.
    /// * `n_bins` - Number of quantile intervals used when a feature has more than `n_bins`
    ///   distinct values. Because the two unbounded outer intervals are also bins, an axis can
    ///   end up with as many as `n_bins + 2` bins. `n_bins == 0` yields a single split at the
    ///   feature's minimum.
    ///
    /// # Returns
    /// A ReferenceGrid with fixed quantile splits and bin centers for each axis.
    pub fn from_data(data: ArrayView2<f64>, n_bins: usize) -> Self {
        let n_features = data.ncols();
        let mut axes_splits = Vec::with_capacity(n_features);
        let mut axes_centers = Vec::with_capacity(n_features);

        for j in 0..n_features {
            let values: Vec<f64> = data.column(j).iter().copied().collect();
            let splits = Self::quantile_splits(values, n_bins);
            axes_centers.push(Self::bin_centers(&splits));
            axes_splits.push(splits);
        }

        Self {
            axes_splits,
            axes_centers,
        }
    }

    /// Computes the ascending, deduplicated split thresholds for one feature's values.
    ///
    /// Never returns an empty vector: a feature without usable values gets `[0.0]`.
    fn quantile_splits(mut values: Vec<f64>, n_bins: usize) -> Vec<f64> {
        // NaN has no place in an ordered set of thresholds, and comparing it with
        // `partial_cmp(..).unwrap()` would panic, so it is dropped before sorting.
        values.retain(|v| !v.is_nan());
        values.sort_unstable_by(f64::total_cmp);

        let n_values = values.len();
        match n_values {
            0 => return vec![0.0],
            1 => return values,
            _ => {}
        }

        // `dedup_by` compares each element with the last *retained* one, so a run of
        // values closer than the tolerance collapses onto its first member.
        let mut unique = values.clone();
        unique.dedup_by(|cur, prev| Self::nearly_equal(*cur, *prev));
        if unique.len() <= n_bins {
            // Few distinct values (e.g. categorical features): use each one as a split.
            return unique;
        }

        // Nearest-rank quantiles over the full sorted sample (duplicates included), so
        // frequent values attract proportionally many thresholds before deduplication.
        // `n_bins.max(1)` avoids computing 0/0 when `n_bins == 0`; that case yields only
        // i = 0, i.e. the minimum.
        let last = n_values - 1;
        let mut splits: Vec<f64> = (0..=n_bins)
            .map(|i| {
                let quantile = i as f64 / n_bins.max(1) as f64;
                let idx = ((quantile * last as f64).round() as usize).min(last);
                values[idx]
            })
            .collect();

        // Deduplicate splits (crucial for categorical or sparse features).
        splits.dedup_by(|cur, prev| Self::nearly_equal(*cur, *prev));
        splits
    }

    /// Returns one representative point per interval delimited by `splits`.
    ///
    /// For `k` splits there are `k + 1` intervals: `(-∞, s0), (s0, s1), …, (s_last, +∞)`.
    /// Inner intervals use their midpoint. The two unbounded outer intervals mirror half the
    /// width of their neighbouring inner interval outward, or use `s0 ∓ 1.0` when there is
    /// only a single split. An empty `splits` yields the single center `[0.0]`.
    fn bin_centers(splits: &[f64]) -> Vec<f64> {
        let n = splits.len();
        let (below, above) = match n {
            0 => return vec![0.0],
            1 => (splits[0] - 1.0, splits[0] + 1.0),
            _ => (
                splits[0] - (splits[1] - splits[0]) / 2.0,
                splits[n - 1] + (splits[n - 1] - splits[n - 2]) / 2.0,
            ),
        };

        let mut centers = Vec::with_capacity(n + 1);
        centers.push(below);
        centers.extend(splits.windows(2).map(|pair| (pair[0] + pair[1]) / 2.0));
        centers.push(above);
        centers
    }

    /// Whether two thresholds are close enough to be merged into a single split.
    fn nearly_equal(a: f64, b: f64) -> bool {
        // The `==` arm also merges equal infinities, whose difference is NaN.
        a == b || (a - b).abs() <= Self::DEDUP_EPS
    }

    /// Returns the number of axes (features) in this reference grid.
    pub fn n_axes(&self) -> usize {
        self.axes_splits.len()
    }

    /// Returns the number of bins (intervals, i.e. centers) for a given axis.
    ///
    /// This is one more than the number of splits on that axis. Returns `0` when `axis` is
    /// out of range.
    pub fn n_bins(&self, axis: usize) -> usize {
        self.axes_centers.get(axis).map(|c| c.len()).unwrap_or(0)
    }
}
