mod data_utils;

#[cfg(test)]
mod data_tests {
    use crate::data_utils::{setup_data_csv, split_data};

    use mpf::{
        family::{
            params::CombinationStrategyParams,
            Aggregation,
        },
        forest::{fit_boosted, params::MPFBoostedParamsBuilder},
        grid::params::{RefinementStrategyParamsBuilder, SplitStrategyParamsBuilder},
    };

    /// Smoke test: fits a boosted MPF on the data loaded from
    /// `./target/x_train_residuals.csv` with `log_level("debug")` and prints the
    /// resulting error figures.
    ///
    /// The test contains no assertions; it only fails if loading, fitting or
    /// predicting panics.
    #[test]
    fn test_mpf_housing_with_logging() {
        let (x_train, y_train) = setup_data_csv("./target/x_train_residuals.csv");

        // The only logging-related setting is the `log_level("debug")` call at the
        // end of the builder chain; everything else configures the model itself.
        let params = MPFBoostedParamsBuilder::new()
            .epochs(2)
            .n_iter(60)
            .n_trees(200)
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(0.01)
                    .build(),
            )
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .similarity_threshold(0.097)
            .decay(0.9)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(10)
                    .colsample_bytree(1.0)
                    .min_interval_samples(20)
                    .min_split_loss(0.9)
                    .build(),
            )
            .log_level("debug")
            .build();

        println!("🏠 Starting MPF Housing Dataset Training with Logging");

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        println!("Fit result: {:?}", fit_result);

        // There is no held-out split in this test: the model predicts on the very
        // data it was fitted on, so the recomputed figure is a training error.
        let preds = model.predict(x_train.view());
        let mean = y_train.mean().unwrap();
        // Mean squared error of always predicting the target mean.
        let base_err = y_train.view().map(|v| v - mean).powi(2).mean().unwrap();
        let recomputed_train_err = (y_train - preds).powi(2).mean().unwrap();
        println!(
            "Base error: {:?}, Training Error: {:?}, Recomputed Training Error: {:?}",
            base_err, fit_result.err, recomputed_train_err
        );
    }

    /// Fits a boosted MPF on `./data/cps88wages_zero_alignment.csv` with one
    /// fixed, seeded hyper-parameter configuration.
    ///
    /// The test makes no assertions; it passes as long as `fit_boosted` returns
    /// without panicking.
    // NOTE(review): the name suggests this configuration once triggered a panic —
    // confirm against the issue history.
    #[test]
    fn test_mpf_cps_panic() {
        let (x, y) = setup_data_csv("./data/cps88wages_zero_alignment.csv");
        let params = MPFBoostedParamsBuilder::new()
            .epochs(8)
            .n_trees(105)
            .n_iter(87)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(2)
                    .colsample_bytree(0.9393363304182665)
                    .min_interval_samples(2)
                    .min_split_loss(0.04332548652534609)
                    .build(),
            )
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .similarity_threshold(0.9976482633753242)
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(9.927_027_908_649_666e-5)
                    .prior_sample_size(0.0)
                    .build(),
            )
            .bagged(false)
            .log_level("debug")
            .decay(0.6497822267834119)
            .seed(42)
            .build();

        // The fitted model is not inspected; the underscore prefix keeps the
        // binding without an unused-variable warning.
        let (fit_result, _model) = fit_boosted(x.view(), y.view(), &params);
        // `fit_result.err` is the error reported by the fit, not a mean-baseline error.
        println!("Training Error: {:?}", fit_result.err);
    }

    /// Fits a seeded boosted MPF on the training part of `./data/friedman1.csv`
    /// (as partitioned by `split_data`) and requires the mean squared error on
    /// the remaining part to be below 0.1.
    #[test]
    fn test_mpf_boosted_fit_friedman1_default() {
        let (features, targets) = setup_data_csv("./data/friedman1.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&features, &targets);

        // Split strategy is assembled up front to keep the main builder chain flat.
        let split_strategy = SplitStrategyParamsBuilder::new()
            .random_split()
            .split_try(14)
            .colsample_bytree(0.749816047538945)
            .min_interval_samples(3)
            .build();
        let params = MPFBoostedParamsBuilder::new()
            .epochs(3)
            .n_trees(100)
            .n_iter(10)
            .split_strategy(split_strategy)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .aggregation_method(Aggregation::Combined)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);

        // Baseline: mean squared error of always predicting the mean of `y_test`.
        let y_mean = y_test.mean().unwrap();
        let base_err = y_test.mapv(|v| (v - y_mean).powi(2)).mean().unwrap();

        let predictions = model.predict(x_test.view());
        let test_err = (y_test - predictions).powi(2).mean().unwrap();

        println!(
            "Base error: {:?}, Training Error: {:?}, Test Error: {:?}",
            base_err, fit_result.err, test_err
        );
        assert!(test_err < 0.1);
    }

    /// Fits a small seeded boosted MPF on the training part of
    /// `./data/2covars.csv` (as partitioned by `split_data`) and requires the
    /// mean squared error on the remaining part to stay below the fixed
    /// threshold 0.7.
    #[test]
    fn test_mpf_boosted_fit_2_covariates() {
        let (x, y) = setup_data_csv("./data/2covars.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(5)
            .n_trees(5)
            .n_iter(25)
            .seed(42)
            .build();
        let (fit_result, mpf) = fit_boosted(x_train.view(), y_train.view(), &params);

        // Baseline printed for reference only: MSE of always predicting the mean
        // of `y_test`. The assertion below does not compare against it.
        let mean = y_test.mean().unwrap();
        let base_err = y_test.view().map(|v| v - mean).powi(2).mean().unwrap();
        let preds = mpf.predict(x_test.view());
        let test_err = (y_test - preds).powi(2).mean().unwrap();
        println!(
            "Base error: {:?}, Training Error: {:?}, Test Error: {:?}",
            base_err, fit_result.err, test_err
        );

        // The check is against a hard-coded threshold, so the failure message
        // reports that threshold and the observed values.
        assert!(
            test_err < 0.7,
            "Test error {:?} is not below the 0.7 threshold (base error: {:?})",
            test_err,
            base_err
        );
    }

    /// Fits a boosted MPF on the training part of `./data/housing_full.csv`
    /// (as partitioned by `split_data`) and requires the mean squared error on
    /// the remaining part to be below 0.32.
    #[test]
    fn test_mpf_housing() {
        let (features, targets) = setup_data_csv("./data/housing_full.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&features, &targets);

        // Split strategy is assembled up front to keep the main builder chain flat.
        let split_strategy = SplitStrategyParamsBuilder::new()
            .random_split()
            .split_try(2)
            .colsample_bytree(1.0)
            .min_interval_samples(3)
            .min_split_loss(3.0)
            .build();
        let params = MPFBoostedParamsBuilder::new()
            .epochs(8)
            .n_iter(100)
            .n_trees(10)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .similarity_threshold(0.1)
            .split_strategy(split_strategy)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let predictions = model.predict(x_test.view());

        // Baseline: mean squared error of always predicting the mean of `y_test`.
        let y_mean = y_test.mean().unwrap();
        let base_err = y_test.mapv(|v| (v - y_mean).powi(2)).mean().unwrap();

        let err = (y_test - predictions).powi(2).mean().unwrap();

        println!(
            "Base error: {:?}, Train Error: {:?}, Test Error: {:?}",
            base_err, fit_result.err, err
        );
        assert!(err < 0.32);
    }
}

#[cfg(test)]
mod reproducibility_tests {
    use crate::data_utils::setup_data_csv;

    use mpf::forest::{fit_boosted, params::MPFBoostedParamsBuilder};

    use std::ops::Div;

    /// Fits two models from the same seeded parameter set on
    /// `./data/2covars.csv` and requires their predictions on that data to agree
    /// element-wise within an absolute tolerance of 1e-10.
    #[test]
    fn test_mpf_boosted_reproducibility() {
        let (features, targets) = setup_data_csv("./data/2covars.csv");

        let params = MPFBoostedParamsBuilder::new()
            .epochs(10)
            .n_trees(5)
            .n_iter(25)
            .seed(42)
            .build();

        // Both fits share the exact same `params`, including the seed.
        let (_, first_model) = fit_boosted(features.view(), targets.view(), &params);
        let (_, second_model) = fit_boosted(features.view(), targets.view(), &params);

        let first_preds = first_model.predict(features.view());
        let second_preds = second_model.predict(features.view());

        // Every element-wise difference must be within the tolerance.
        let delta = &first_preds - &second_preds;
        let identical = delta.iter().all(|&d| d.abs() < 1e-10);
        assert!(
            identical,
            "Models with same seed produced different predictions"
        );
    }

    /// Fits two models on `./data/2covars.csv` whose parameters differ only in
    /// the seed (42 vs. 43) and requires at least one prediction to differ by
    /// more than 1e-10.
    #[test]
    fn test_mpf_boosted_different_seeds() {
        let (features, targets) = setup_data_csv("./data/2covars.csv");

        // First configuration: seed 42.
        let params_seed_42 = MPFBoostedParamsBuilder::new()
            .epochs(2)
            .n_trees(5)
            .n_iter(25)
            .seed(42)
            .build();
        let (_, model_seed_42) = fit_boosted(features.view(), targets.view(), &params_seed_42);

        // Second configuration: identical except for the seed.
        let params_seed_43 = MPFBoostedParamsBuilder::new()
            .epochs(2)
            .n_trees(5)
            .n_iter(25)
            .seed(43)
            .build();
        let (_, model_seed_43) = fit_boosted(features.view(), targets.view(), &params_seed_43);

        let preds_seed_42 = model_seed_42.predict(features.view());
        let preds_seed_43 = model_seed_43.predict(features.view());

        // A single element-wise difference above the tolerance is enough.
        let delta = &preds_seed_42 - &preds_seed_43;
        let differs = delta.iter().any(|&d| d.abs() > 1e-10);
        assert!(
            differs,
            "Models with different seeds produced identical predictions"
        );
    }

    /// Checks that `fit_result.err` matches the mean squared error recomputed
    /// from the fitted model's predictions on the same data, within an absolute
    /// tolerance of 1e-15.
    #[test]
    fn test_fit_result_error_is_y_minus_sum_preds() {
        let (features, targets) = setup_data_csv("./data/2covars.csv");
        let params = MPFBoostedParamsBuilder::new()
            .epochs(10)
            .n_trees(10)
            .n_iter(10)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(features.view(), targets.view(), &params);
        let predictions = model.predict(features.view());

        // Sum of squared residuals, then divide by the number of targets.
        let squared_error_sum = targets
            .view()
            .iter()
            .zip(predictions.iter())
            .map(|(target, pred)| (target - pred).powi(2))
            .sum::<f64>();
        let sample_count = targets.len() as f64;
        let err = squared_error_sum.div(sample_count);

        let gap = (fit_result.err - err).abs();
        assert!(gap < 1e-15);
    }
}
