mod data_utils;

#[cfg(test)]
mod tests {

    use crate::data_utils::split_data_rng;

    use super::data_utils::setup_data_csv;
    use mpf::{
        grid::{
            fit,
            params::{
                RefinementStrategyParamsBuilder, SplitStrategyParamsBuilder, TreeGridParams,
                TreeGridParamsBuilder,
            },
        },
    };
    use ndarray::{Array1, Array2};
    use rand::{rngs::StdRng, SeedableRng};

    /// Repeatedly splits the data, fits a tree grid on the training part and
    /// prints the held-out and training errors of every round, followed by the
    /// mean and standard deviation of both error series.
    ///
    /// * `x`, `y` - full data set handed to `split_data_rng` on every round.
    /// * `params` - fit parameters, reused unchanged for each round.
    /// * `iterations` - number of split / fit / evaluate rounds.
    ///
    /// A single RNG seeded with the fixed value 42 is shared by the split and
    /// the fit. Nothing is asserted here; the results are only printed.
    fn simulate_exact_tree_grid(
        x: &Array2<f64>,
        y: &Array1<f64>,
        params: TreeGridParams,
        iterations: usize,
    ) {
        /// Mean and population standard deviation (divisor `n`) of `values`.
        /// Both are NaN when `values` is empty.
        fn mean_and_std(values: &[f64]) -> (f64, f64) {
            let n = values.len() as f64;
            let mean = values.iter().sum::<f64>() / n;
            let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
            // The square root turns the mean squared deviation into the
            // standard deviation that the printed label promises.
            (mean, variance.sqrt())
        }

        let mut rng = StdRng::seed_from_u64(42);
        let mut test_errors = Vec::with_capacity(iterations);
        let mut train_errors = Vec::with_capacity(iterations);

        for _ in 0..iterations {
            let (x_train, y_train, x_test, y_test) = split_data_rng(x, y, &mut rng);
            let (fit_result, tg) = fit(x_train.view(), y_train.view(), &params, &mut rng);
            let preds = tg.predict(x_test.view());
            // Subtract by reference: no copy of the held-out targets is needed.
            let residuals = &y_test - &preds;
            let test_err = residuals.powi(2).mean().unwrap();
            let train_err = fit_result.err;
            println!("test err: {:#?}, train err: {:#?}", test_err, train_err);
            test_errors.push(test_err);
            train_errors.push(train_err);
        }

        let (test_errors_mean, test_errors_std) = mean_and_std(&test_errors);
        let (train_errors_mean, train_errors_std) = mean_and_std(&train_errors);
        println!(
            "test errors mean and std: {:#?}, {:#?}",
            test_errors_mean, test_errors_std
        );
        println!(
            "train errors mean and std: {:#?}, {:#?}",
            train_errors_mean, train_errors_std
        );
    }

    /// Fits a tree grid with the random-split strategy on the debug data set
    /// and checks that the error stored in the fit result agrees, to within
    /// 1e-10, with the mean squared residual recomputed from the model's
    /// predictions on the same data.
    #[test]
    fn test_housing_exact_tree_grid() {
        let (features, targets) = setup_data_csv("./data/debug_data.csv");

        let split_strategy = SplitStrategyParamsBuilder::new()
            .random_split()
            .split_try(2)
            .colsample_bytree(1.0)
            .min_interval_samples(3)
            .min_split_loss(3.0)
            .build();
        let grid_params = TreeGridParamsBuilder::new()
            .n_iter(100)
            .split_strategy(split_strategy)
            .build();

        let mut seeded_rng = StdRng::seed_from_u64(8256015244108642386);
        let (fit_result, model) = fit(
            features.view(),
            targets.view(),
            &grid_params,
            &mut seeded_rng,
        );

        // Mean of the squared residuals over the data the model was fitted on.
        let predictions = model.predict(features.view());
        let computed_err = (targets - predictions).powi(2).mean().unwrap();
        println!("err: {:#?}", computed_err);
        println!("Error: {:#?}", fit_result.err);

        assert!((computed_err - fit_result.err).abs() < 1e-10);
    }

    /// Fits a tree grid on the wages data set with a random-split strategy and
    /// a `huber` refinement strategy, then checks two things to within 1e-10:
    /// the reported fit error equals the mean squared residual recomputed from
    /// the model's predictions, and those predictions equal the `y_hat` stored
    /// in the fit result element by element.
    #[test]
    fn test_tree_grid_fit_result_equals_computed() {
        let (features, targets) = setup_data_csv("data/cps88wages.csv");

        let split_strategy = SplitStrategyParamsBuilder::new()
            .random_split()
            .split_try(2)
            .colsample_bytree(0.6237241962915381)
            .min_interval_samples(24)
            .min_split_loss(0.4385901864308926)
            .build();
        let refinement_strategy = RefinementStrategyParamsBuilder::new()
            .huber()
            .alpha(0.0054461959546609874)
            .prior_sample_size(0.0)
            .build();
        let grid_params = TreeGridParamsBuilder::new()
            .n_iter(16)
            .split_strategy(split_strategy)
            .refinement_strategy(refinement_strategy)
            .build();

        let mut seeded_rng = StdRng::seed_from_u64(42);
        let (fit_result, grid) = fit(
            features.view(),
            targets.view(),
            &grid_params,
            &mut seeded_rng,
        );

        let preds = grid.predict(features.view());
        // `preds` is needed again below, so subtract by reference.
        let squared_residuals = (&targets - &preds).powi(2);
        let computed_err = squared_residuals.mean().unwrap();
        println!(
            "Fit result: {:?}, Computed err: {:?}",
            fit_result.err, computed_err
        );

        assert!((fit_result.err - computed_err).abs() < 1e-10);
        assert!((preds - fit_result.y_hat).abs().iter().all(|&x| x < 1e-10));
    }

    /// Runs the repeated split / fit simulation on the full housing data set
    /// using the top-k split strategy (k = 10). Makes no assertions; the
    /// simulation only prints its error statistics.
    #[test]
    fn test_housing_exact_tree_grid_top_k_mean_and_std() {
        let (features, targets) = setup_data_csv("./data/housing_full.csv");

        let top_k_strategy = SplitStrategyParamsBuilder::new()
            .top_k_splits()
            .top_k(10)
            .must_fill_all_k(false)
            .min_interval_samples(1)
            .build();
        let grid_params = TreeGridParamsBuilder::new()
            .n_iter(50)
            .split_strategy(top_k_strategy)
            .build();

        simulate_exact_tree_grid(&features, &targets, grid_params, 100);
    }

    /// Runs the repeated split / fit simulation on the full housing data set
    /// using the best-split strategy. Makes no assertions; the simulation only
    /// prints its error statistics.
    #[test]
    fn test_housing_exact_tree_grid_best_split_mean_and_std() {
        let (features, targets) = setup_data_csv("./data/housing_full.csv");

        let best_split_strategy = SplitStrategyParamsBuilder::new()
            .best_split()
            .min_interval_samples(1)
            .build();
        let grid_params = TreeGridParamsBuilder::new()
            .n_iter(50)
            .split_strategy(best_split_strategy)
            .build();

        simulate_exact_tree_grid(&features, &targets, grid_params, 100);
    }

    /// Fits a tree grid with the best-split strategy on the two-covariate data
    /// set and checks that the reported fit error is strictly smaller than the
    /// error of predicting the target mean for every sample.
    #[test]
    fn test_tree_grid_fit_interval_split() {
        let (features, targets) = setup_data_csv("./data/2covars.csv");
        let mut seeded_rng = StdRng::seed_from_u64(42);

        let best_split_strategy = SplitStrategyParamsBuilder::new()
            .best_split()
            .min_interval_samples(1)
            .build();
        let grid_params = TreeGridParamsBuilder::new()
            .n_iter(24)
            .split_strategy(best_split_strategy)
            .build();

        let (fit_result, _) = fit(
            features.view(),
            targets.view(),
            &grid_params,
            &mut seeded_rng,
        );

        // Baseline: mean squared deviation of the targets from their own mean.
        let target_mean = targets.mean().unwrap();
        let centered = targets - target_mean;
        let base_err = centered.powi(2).mean().unwrap();
        println!("Base error: {:?}, Error: {:?}", base_err, fit_result.err);

        assert!(
            fit_result.err < base_err,
            "Error is not less than mean error"
        );
    }
}
