mod data_utils;

#[cfg(test)]
mod tests {
    //! Regression tests that fit boosted MPF models on bundled CSV datasets and
    //! require the held-out mean squared error to stay below per-dataset limits.

    use crate::data_utils::setup_data_csv;
    use mpf::{
        family::{params::CombinationStrategyParams, Aggregation},
        forest::{fit_boosted, fit_boosted_with_test_error, params::MPFBoostedParamsBuilder},
        grid::params::{RefinementStrategyParamsBuilder, SplitStrategyParamsBuilder},
    };
    use ndarray::{Array1, Array2, ArrayView1, Axis};
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::{Rng, SeedableRng};

    /// Randomly splits `(x, y)` into train and test halves using a `StdRng`
    /// seeded with `seed`, so a given seed always yields the same split (for a
    /// fixed `rand` version).
    ///
    /// Returns `(x_train, y_train, x_test, y_test)`; see [`split_data_rng`].
    ///
    /// # Panics
    /// Panics if `x` does not have exactly one row per element of `y`.
    pub fn split_data(
        x: &Array2<f64>,
        y: &Array1<f64>,
        seed: u64,
    ) -> (Array2<f64>, Array1<f64>, Array2<f64>, Array1<f64>) {
        let mut rng = StdRng::seed_from_u64(seed);
        split_data_rng(x, y, &mut rng)
    }

    /// Shuffles the row indices with `rng` and splits them in half.
    ///
    /// The first `n / 2` shuffled rows become the training set. The remaining
    /// rows become the test set, which therefore gets the extra row when `n`
    /// is odd.
    ///
    /// Returns `(x_train, y_train, x_test, y_test)`.
    ///
    /// # Panics
    /// Panics if `x` does not have exactly one row per element of `y`.
    pub fn split_data_rng<R: Rng + ?Sized>(
        x: &Array2<f64>,
        y: &Array1<f64>,
        rng: &mut R,
    ) -> (Array2<f64>, Array1<f64>, Array2<f64>, Array1<f64>) {
        let n = y.len();
        // Without this check a taller `x` would silently drop rows from both
        // splits, and a shorter one would panic inside `select` with an
        // unhelpful index-out-of-bounds message.
        assert_eq!(
            x.nrows(),
            n,
            "feature matrix has {} rows but target has {} elements",
            x.nrows(),
            n
        );

        let mut indices: Vec<usize> = (0..n).collect();
        indices.shuffle(rng);
        let (train_idx, test_idx) = indices.split_at(n / 2);

        (
            x.select(Axis(0), train_idx),
            y.select(Axis(0), train_idx),
            x.select(Axis(0), test_idx),
            y.select(Axis(0), test_idx),
        )
    }

    /// Mean squared error between targets `y` and predictions `preds`.
    ///
    /// # Panics
    /// Panics if the lengths differ or if `y` is empty. ndarray arithmetic
    /// would otherwise silently broadcast a length-1 operand.
    fn mse(y: ArrayView1<f64>, preds: ArrayView1<f64>) -> f64 {
        assert_eq!(
            y.len(),
            preds.len(),
            "prediction count does not match target count"
        );
        (&y - &preds)
            .mapv(|d| d.powi(2))
            .mean()
            .expect("cannot compute MSE of an empty array")
    }

    /// MSE of a constant predictor that always outputs the mean of `y`.
    ///
    /// Used as the reference "base error" the model is compared against.
    ///
    /// # Panics
    /// Panics if `y` is empty.
    fn baseline_mse(y: ArrayView1<f64>) -> f64 {
        let mean = y
            .mean()
            .expect("cannot compute the mean of an empty array");
        y.mapv(|v| (v - mean).powi(2))
            .mean()
            .expect("cannot compute MSE of an empty array")
    }

    #[test]
    fn test_mpf_brazilian_housing() {
        // Reference results noted for this dataset, presumably measured on this
        // same seed-1 split:
        //   MPF test MSE     ~ 11_655_274.69
        //   XGBoost test MSE ~ 43_249_928.20
        // The threshold below sits between the two. It leaves headroom over the
        // recorded MPF score while still requiring MPF to beat the recorded
        // XGBoost score.
        let (x, y) = setup_data_csv("data/brazilian_housing.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y, 1);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(7)
            .n_trees(267)
            .n_iter(27)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(19)
                    .colsample_bytree(0.7598028003624545)
                    .min_interval_samples(2)
                    .min_split_loss(0.7170702984732098)
                    .build(),
            )
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(0.10295766550155155)
                    .prior_sample_size(0.0)
                    .build(),
            )
            .similarity_threshold(0.021212651555301613)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(false)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let preds = model.predict(x_test.view());
        let preds_train = model.predict(x_train.view());

        let base_err = baseline_mse(y_test.view());
        let test_err = mse(y_test.view(), preds.view());
        let train_err = mse(y_train.view(), preds_train.view());

        println!(
            "Brazilian Housing | Base error: {:.4}, Training Error (fit): {:.4}, Training Error (recomputed): {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, train_err, test_err
        );
        assert!(
            test_err < 34_425_200.0,
            "Brazilian Housing test MSE {:.4} is not below 34425200",
            test_err
        );
    }

    #[test]
    fn test_mpf_cps88wages() {
        let (x, y) = setup_data_csv("data/cps88wages.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y, 1);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(3)
            .n_trees(213)
            .n_iter(16)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(2)
                    .colsample_bytree(0.6237241962915381)
                    .min_interval_samples(24)
                    .min_split_loss(0.4385901864308926)
                    .build(),
            )
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .huber()
                    .alpha(0.0054461959546609874)
                    .prior_sample_size(0.0)
                    .build(),
            )
            .similarity_threshold(0.5116064236426068)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(true)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let preds = model.predict(x_test.view());
        let preds_train = model.predict(x_train.view());

        let base_err = baseline_mse(y_test.view());
        let test_err = mse(y_test.view(), preds.view());
        let train_err = mse(y_train.view(), preds_train.view());

        println!(
            "CPS88 Wages | Base error: {:.4}, Training Error (fit): {:.4}, Training Error (recomputed): {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, train_err, test_err
        );
        assert!(
            test_err < 175_440.0,
            "CPS88 Wages test MSE {:.4} is not below 175440",
            test_err
        );
    }

    #[test]
    fn test_mpf_red_wine() {
        let (x, y) = setup_data_csv("data/red_wine_processed.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y, 1);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(1)
            .n_trees(201)
            .n_iter(25)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(3)
                    .colsample_bytree(0.652015893646319)
                    .min_interval_samples(43)
                    .min_split_loss(0.060839473717718556)
                    .build(),
            )
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(2.196_807_874_090_684_2e-9)
                    .prior_sample_size(0.0)
                    .build(),
            )
            .similarity_threshold(0.2713998798759759)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(false)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let preds = model.predict(x_test.view());

        let base_err = baseline_mse(y_test.view());
        let test_err = mse(y_test.view(), preds.view());

        println!(
            "Red Wine | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, test_err
        );
        assert!(
            test_err < 0.42,
            "Red Wine test MSE {:.4} is not below 0.42",
            test_err
        );
    }

    #[test]
    fn test_mpf_fish_toxicity() {
        let (x, y) = setup_data_csv("data/fish_toxicity.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y, 1);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(2)
            .n_trees(221)
            .n_iter(15)
            .decay(0.9301908910375473)
            .similarity_threshold(0.7027072768106223)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(16)
                    .colsample_bytree(0.7030641276642211)
                    .min_interval_samples(5)
                    .min_split_loss(0.3530290549534714)
                    .build(),
            )
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(true)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let preds = model.predict(x_test.view());

        let base_err = baseline_mse(y_test.view());
        let test_err = mse(y_test.view(), preds.view());

        println!(
            "Fish Toxicity | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, test_err
        );
        assert!(
            test_err < 0.9,
            "Fish Toxicity test MSE {:.4} is not below 0.9",
            test_err
        );
    }

    /// Same dataset and split-strategy settings as `test_mpf_fish_toxicity`.
    ///
    /// Differences: more epochs and `Aggregation::Combined`.
    #[test]
    fn test_mpf_fish_toxicity_combined_aggregation() {
        let (x, y) = setup_data_csv("data/fish_toxicity.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y, 1);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(9)
            .n_trees(221)
            .n_iter(15)
            .decay(0.9301908910375473)
            .similarity_threshold(0.7027072768106223)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(16)
                    .colsample_bytree(0.7030641276642211)
                    .min_interval_samples(5)
                    .min_split_loss(0.3530290549534714)
                    .build(),
            )
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .aggregation_method(Aggregation::Combined)
            .bagged(true)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let preds = model.predict(x_test.view());

        let base_err = baseline_mse(y_test.view());
        let test_err = mse(y_test.view(), preds.view());

        // Labelled distinctly so its output can be told apart from the plain
        // fish-toxicity test.
        println!(
            "Fish Toxicity (Combined aggregation) | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, test_err
        );
        assert!(
            test_err < 0.9,
            "Fish Toxicity (Combined aggregation) test MSE {:.4} is not below 0.9",
            test_err
        );
    }

    #[test]
    fn test_mpf_auction() {
        let (x, y) = setup_data_csv("data/openml/44958_auction_verification.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y, 1);

        let params = MPFBoostedParamsBuilder::new()
            .epochs(2)
            .n_trees(2)
            .n_iter(57)
            .bagged(true)
            .aggregation_method(Aggregation::Combined)
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(0.0008761191731033688)
                    .prior_sample_size(0.0)
                    .update_clamp(3.1388455603935297)
                    .build(),
            )
            .log_level("trace")
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .similarity_threshold(0.3573991614563707)
            .decay(0.8418796974151387)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(20)
                    .colsample_bytree(0.631161799847988)
                    .min_interval_samples(6)
                    .min_split_loss(0.0)
                    .build(),
            )
            // Pin the seed like every other test here so a failure can be
            // reproduced.
            .seed(42)
            .build();

        println!("🔨 Starting MPF Auction Dataset Training");
        println!(
            "Training samples: {}, Features: {}",
            x_train.nrows(),
            x_train.ncols()
        );
        println!("Test samples: {}", x_test.nrows());

        let (fit_result, model, _test_errors) = fit_boosted_with_test_error(
            x_train.view(),
            y_train.view(),
            x_test.view(),
            y_test.view(),
            &params,
        );
        let preds = model.predict(x_test.view());

        let base_err = baseline_mse(y_test.view());
        let test_err = mse(y_test.view(), preds.view());

        println!(
            "Auction | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, test_err
        );

        // Sanity check: the model must beat the constant-mean predictor by at
        // least 10%.
        assert!(
            test_err < base_err * 0.9,
            "MPF did not beat base error: test MSE {:.4} vs base {:.4}",
            test_err,
            base_err
        );
    }
}
