mod data_utils;

#[cfg(test)]
mod tests {
    use crate::data_utils::{setup_data_csv, split_data};
    use mpf::{
        family::{
            fit_ensemble,
            params::{CombinationStrategyParams, TreeGridFamilyParamsBuilder},
        },
        grid::params::{RefinementStrategyParamsBuilder, SplitStrategyParamsBuilder},
        logging::init_logging,
    };
    use rand::{rngs::StdRng, SeedableRng};

    /// RNG seed shared by every test so that fits are reproducible between runs.
    const SEED: u64 = 42;

    /// Checks that the training error and in-sample predictions reported by `fit_ensemble`
    /// agree with what the returned model predicts when re-run on the same training data.
    #[test]
    fn test_family_fit_result_equals_computed() {
        init_logging("debug");
        let (x, y) = setup_data_csv("data/cps88wages.csv");
        // Only the training split is needed here; the held-out split is ignored.
        let (x_train, y_train, _, _) = split_data(&x, &y);
        // NOTE(review): these very specific values look like the output of a hyperparameter
        // search — confirm their provenance before changing them.
        let params = TreeGridFamilyParamsBuilder::new()
            .n_trees(30)
            .n_iter(16)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(2)
                    .colsample_bytree(0.6237241962915381)
                    .min_interval_samples(24)
                    .min_split_loss(0.4385901864308926)
                    .build(),
            )
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(0.0054461959546609874)
                    .prior_sample_size(0.0)
                    .build(),
            )
            .similarity_threshold(0.5116064236426068)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(true)
            .build();

        let mut rng = StdRng::seed_from_u64(SEED);
        let (fit_result, model) = fit_ensemble(x_train.view(), y_train.view(), &params, &mut rng);

        let preds = model.predict(x_train.view());
        // Mean squared error on the training set, recomputed independently of `fit_result`.
        // Subtracting by reference avoids cloning `preds`, which is still needed below.
        let computed_err = (&y_train - &preds)
            .powi(2)
            .mean()
            .expect("training split must be non-empty");
        println!("Preds mean: {:?}", preds.mean());
        println!(
            "Fit result: {:?}, Computed err: {:?}",
            fit_result.err, computed_err
        );
        assert!(
            (fit_result.err - computed_err).abs() < 1e-10,
            "reported train error {:?} differs from recomputed train error {:?}",
            fit_result.err,
            computed_err
        );
        assert!(
            (preds - fit_result.y_hat).abs().iter().all(|&d| d < 1e-10),
            "model predictions on the training set differ from fit_result.y_hat"
        );
    }

    /// Fits a 20-tree, non-bagged ensemble on `./data/2covars.csv` with the given
    /// `combination_strategy` (similarity threshold 0.1) and asserts that the mean squared
    /// error on the held-out split is below 0.1.
    fn assert_predicts_well(combination_strategy: CombinationStrategyParams) {
        let (x, y) = setup_data_csv("./data/2covars.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y);
        let mut rng = StdRng::seed_from_u64(SEED);
        let hyperparameters = TreeGridFamilyParamsBuilder::new()
            .n_trees(20)
            .bagged(false)
            .combination_strategy(combination_strategy)
            .similarity_threshold(0.1)
            .build();
        let (fit_result, tgf) =
            fit_ensemble(x_train.view(), y_train.view(), &hyperparameters, &mut rng);
        let pred = tgf.predict(x_test.view());
        let err = (y_test - pred)
            .powi(2)
            .mean()
            .expect("test split must be non-empty");
        println!("Train Error: {:?}, Test Error: {:?}", fit_result.err, err);
        assert!(err < 0.1, "test error {:?} is not below 0.1", err);
    }

    // NOTE(review): the four tests below are named after different combination strategies
    // (median, arithmetic-geometric mean, arithmetic mean, geometric mean), but all of them
    // pass `GeometricMeanTwoTensor`, so they currently exercise the same configuration.
    // Confirm the intended `CombinationStrategyParams` variant for each and substitute it.

    #[test]
    fn test_l2_median_combined_tree_grid_predicts_well() {
        assert_predicts_well(CombinationStrategyParams::GeometricMeanTwoTensor);
    }

    #[test]
    fn test_l2_arith_geom_mean_combined_tree_grid_predicts_well() {
        assert_predicts_well(CombinationStrategyParams::GeometricMeanTwoTensor);
    }

    #[test]
    fn test_l2_arith_mean_combined_tree_grid_predicts_well() {
        assert_predicts_well(CombinationStrategyParams::GeometricMeanTwoTensor);
    }

    #[test]
    fn test_l2_geom_mean_combined_tree_grid_predicts_well() {
        assert_predicts_well(CombinationStrategyParams::GeometricMeanTwoTensor);
    }
}
