use ndarray::{Array1, ArrayView2};
use serde::{Deserialize, Serialize};

use crate::family::TreeGridFamily;

mod fitter;
pub mod params;
pub use fitter::{fit_boosted, fit_boosted_with_test_error};

/// A model composed of a list of [`TreeGridFamily`] components.
///
/// The model is additive: [`MPF::predict`] returns the element-wise sum of the
/// predictions produced by every family in `tree_grid_families`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MPF {
    /// Components of the model, kept in the order passed to [`MPF::new`].
    /// Serialized by serde under the field name `tree_grid_families`.
    tree_grid_families: Vec<TreeGridFamily>,
}

impl MPF {
    /// Borrows the tree-grid families that make up this model.
    ///
    /// The families are returned in the order they were supplied to
    /// [`MPF::new`] (or restored by deserialization).
    pub fn get_tree_grid_families(&self) -> &Vec<TreeGridFamily> {
        // Destructuring through `&self` binds the field by reference.
        let Self { tree_grid_families } = self;
        tree_grid_families
    }
}

impl MPF {
    pub const fn new(tree_grid_families: Vec<TreeGridFamily>) -> Self {
        Self { tree_grid_families }
    }
}

impl MPF {
    /// Predicts one value per row of `x`.
    ///
    /// The output starts as a zero vector with one entry per row of `x`.
    /// Each family's `predict(x)` output is then added to it element-wise, so
    /// a model with no families predicts all zeros.
    ///
    /// NOTE(review): assumes every family returns one value per row of `x`;
    /// ndarray's `+=` panics if a family's output cannot be broadcast to that
    /// length.
    pub fn predict(&self, x: ArrayView2<f64>) -> Array1<f64> {
        let (n_rows, _) = x.dim();
        self.tree_grid_families
            .iter()
            .fold(Array1::zeros(n_rows), |mut total, family| {
                total += &family.predict(x);
                total
            })
    }
}
