mod data_utils;

#[cfg(feature = "evo-logging")]
mod evo_logging_tests {

    use crate::data_utils::{save_x_y, setup_data_csv, split_data};
    use mpf::{
        family::{params::CombinationStrategyParams, Aggregation},
        forest::{fit_boosted, fit_boosted_with_test_error, params::MPFBoostedParamsBuilder},
        grid::params::{RefinementStrategyParamsBuilder, SplitStrategyParamsBuilder},
    };

    /// Fits a boosted MPF model on the SocMob CSV with a fixed seed and checks
    /// that the held-out mean squared error stays below a fixed bound.
    ///
    /// The run is configured with a `visualdb_path` pointing at a SQLite file
    /// under `target/`.
    #[test]
    fn test_socmob() {
        let (features, targets) = setup_data_csv("./data/openml/44987_socmob.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&features, &targets);

        let visual_db = "target/socmob_splits_all_positive.sqlite";

        // Build the nested strategy parameters up front so the top-level
        // builder chain below stays flat.
        let split_strategy = SplitStrategyParamsBuilder::new()
            .random_split()
            .split_try(3)
            .colsample_bytree(0.8828809775150922)
            .min_interval_samples(1)
            .min_split_loss(0.0)
            .build();
        let refinement_strategy = RefinementStrategyParamsBuilder::new()
            .l2()
            .alpha(0.0)
            .update_clamp(3.4919359535659664)
            .build();

        let params = MPFBoostedParamsBuilder::new()
            .epochs(1)
            .n_trees(200)
            .n_iter(92)
            .decay(0.9882439500155061)
            .visualdb_path(Some(String::from(visual_db)))
            .split_strategy(split_strategy)
            .refinement_strategy(refinement_strategy)
            .similarity_threshold(0.36343748787856406)
            .combination_strategy(CombinationStrategyParams::BaggedTwoTensor)
            .bagged(true)
            .seed(42)
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
        let test_preds = model.predict(x_test.view());

        // Baseline: mean squared error of always predicting the test-target mean.
        let target_mean = y_test.mean().unwrap();
        let base_err = y_test
            .mapv(|v| (v - target_mean).powi(2))
            .mean()
            .unwrap();

        // `y_test` is consumed by the subtraction; it is not used afterwards.
        let test_residuals = y_test - test_preds;
        let test_err = test_residuals.powi(2).mean().unwrap();

        // Recompute the training error from predictions so it can be printed
        // next to the value reported by the fit itself.
        let train_preds = model.predict(x_train.view());
        let train_residuals = y_train - train_preds;
        let train_err = train_residuals.powi(2).mean().unwrap();

        println!(
            "SocMob | Base error: {:.4}, Training Error: {:.4} and {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, train_err, test_err
        );
        assert!(test_err < 102.0);
    }

    /// Trains a boosted MPF model on the auction-verification CSV while
    /// passing the test split to `fit_boosted_with_test_error`, then requires
    /// the held-out mean squared error to be at least 10% below the error of
    /// a constant mean predictor.
    #[test]
    fn test_mpf_auction_with_logging() {
        let (x, y) = setup_data_csv("data/openml/44958_auction_verification.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y);

        // SQLite file handed to the model via `visualdb_path`.
        let db_path = "target/auction_splits_all_positive.sqlite";
        let params = MPFBoostedParamsBuilder::new()
            .epochs(2)
            .n_trees(200)
            .n_iter(57)
            .bagged(true)
            .aggregation_method(Aggregation::Combined)
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(0.0008761191731033688)
                    .prior_sample_size(0.0)
                    .update_clamp(3.1388455603935297)
                    .build(),
            )
            .log_level("info")
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .similarity_threshold(0.3573991614563707)
            .decay(0.8418796974151387)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(20)
                    .colsample_bytree(0.631161799847988)
                    .min_interval_samples(6)
                    .min_split_loss(0.0)
                    .build(),
            )
            .visualdb_path(Some(db_path.to_string()))
            .build();

        println!("🔨 Starting MPF Auction Dataset Training with Logging");
        println!(
            "Training samples: {}, Features: {}",
            x_train.nrows(),
            x_train.ncols()
        );
        println!("Test samples: {}", x_test.nrows());
        println!("Database: {}", db_path);

        let (fit_result, model, _test_errors) = fit_boosted_with_test_error(
            x_train.view(),
            y_train.view(),
            x_test.view(),
            y_test.view(),
            &params,
        );
        let preds = model.predict(x_test.view());

        // Baseline: mean squared error of always predicting the test-target
        // mean. `y_test` is already an owned array, so no copies are needed.
        let mean = y_test.mean().unwrap();
        let base_err = y_test.mapv(|v| (v - mean).powi(2)).mean().unwrap();
        // Last use of `y_test`: move it into the subtraction instead of cloning.
        let test_err = (y_test - preds).powi(2).mean().unwrap();

        println!(
            "Auction | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
            base_err, fit_result.err, test_err
        );

        // Sanity check: the model must beat the constant predictor by a 10% margin.
        assert!(
            test_err < base_err * 0.9,
            "MPF did not beat base error by at least 10%: test error {:.4}, base error {:.4}",
            test_err,
            base_err
        );
    }

    /// Fits a boosted MPF model on the synthetic `data_gen3` CSV, saves the
    /// training inputs together with the fit residuals to a CSV file, and
    /// checks the held-out mean squared error against a fixed bound.
    #[test]
    fn test_mpf_synthetic_data_with_logging() {
        let (features, targets) = setup_data_csv("./data/data_gen3.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&features, &targets);

        // Recorded tuning result:
        // # Best Hyperparameters: {'epochs': 4, 'n_trees': 149, 'n_iter': 23, 'alpha': 0.41267364726662886, 'tilt_tau': 5.558658765668141e-05, 'tilt_rho': 0.00021010963348998642, 'min_split_loss': 0.7540167160231375, 'min_interval_samples': 38, 'split_strategy': 'random', 'split_try': 2, 'colsample_bytree': 0.39588592353906327}
        //
        // The refinement and split settings below match that record; `epochs`,
        // `n_trees` and `n_iter` intentionally or otherwise do not.
        let refinement_strategy = RefinementStrategyParamsBuilder::new()
            .l2()
            .alpha(0.41267364726662886)
            .tilt_tau(5.558658765668141e-05)
            .tilt_rho(0.00021010963348998642)
            .build();
        let split_strategy = SplitStrategyParamsBuilder::new()
            .random_split()
            .split_try(2)
            .colsample_bytree(0.39588592353906327)
            .min_interval_samples(38)
            .min_split_loss(0.7540167160231375)
            .build();

        let params = MPFBoostedParamsBuilder::new()
            .visualdb_path(Some(String::from("target/synthetic_data_splits.sqlite")))
            .epochs(3)
            // NOTE(review): previously annotated as being the default value,
            // spelled out for clarity — confirm against the builder defaults.
            .n_iter(20)
            .n_trees(389)
            .combination_strategy(CombinationStrategyParams::BaggedTwoTensor)
            .refinement_strategy(refinement_strategy)
            .similarity_threshold(0.9)
            .split_strategy(split_strategy)
            .log_level("info")
            .build();

        let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);

        // Persist the training inputs alongside the residuals of this fit.
        save_x_y(&x_train, &fit_result.residuals, "data/data_gen3epoch3.csv")
            .expect("Failed to save x_train and residuals");

        let test_preds = model.predict(x_test.view());

        // Baseline: mean squared error of always predicting the test-target mean.
        let target_mean = y_test.mean().unwrap();
        let centered = y_test.mapv(|v| v - target_mean);
        let base_err = centered.powi(2).mean().unwrap();

        // `y_test` is consumed by the subtraction; it is not used afterwards.
        let residuals = y_test - test_preds;
        let err = residuals.powi(2).mean().unwrap();

        println!(
            "Base error: {:?}, Train Error: {:?}, Test Error: {:?}",
            base_err, fit_result.err, err
        );
        assert!(err < 0.32);
    }

    /// Trains a boosted MPF model on the housing CSV while passing the test
    /// split to `fit_boosted_with_test_error`, prints a short report, and
    /// checks the held-out mean squared error against a fixed bound.
    #[test]
    fn test_mpf_housing_with_logging() {
        let (x, y) = setup_data_csv("./data/housing_full.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y);

        // SQLite file handed to the model via `visualdb_path`.
        let db_path = "target/housing_splits_new.sqlite";
        // Single source of truth for the bound, so the assertion and its
        // failure message cannot drift apart.
        let max_test_err = 0.3;

        // NOTE(review): these were described as reduced parameters for faster
        // testing — confirm that still holds for this configuration.
        let params = MPFBoostedParamsBuilder::new()
            .epochs(7)
            .n_iter(150)
            .n_trees(200)
            .aggregation_method(Aggregation::Combined)
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(0.01)
                    .build(),
            )
            .combination_strategy(CombinationStrategyParams::BaggedTwoTensor)
            .similarity_threshold(0.1)
            .decay(0.9)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(10)
                    .colsample_bytree(1.0)
                    .min_interval_samples(20)
                    .min_split_loss(0.9)
                    .build(),
            )
            .visualdb_path(Some(db_path.to_string()))
            .build();

        println!("🏠 Starting MPF Housing Dataset Training with Logging");
        println!(
            "Training samples: {}, Features: {}",
            x_train.nrows(),
            x_train.ncols()
        );
        println!("Test samples: {}", x_test.nrows());
        println!("Database: {}", db_path);

        let (fit_result, model, _test_errors) = fit_boosted_with_test_error(
            x_train.view(),
            y_train.view(),
            x_test.view(),
            y_test.view(),
            &params,
        );

        // Evaluate model performance against a constant mean predictor.
        let preds = model.predict(x_test.view());
        let mean = y_test.mean().unwrap();
        let base_err = y_test.mapv(|v| (v - mean).powi(2)).mean().unwrap();
        // Last use of `y_test`: it is moved into the subtraction.
        let test_err = (y_test - preds).powi(2).mean().unwrap();

        println!("📊 Results:");
        println!("  Base error: {:.6}", base_err);
        println!("  Training error: {:.6}", fit_result.err);
        println!("  Test error: {:.6}", test_err);

        // Verify performance (relaxed assertion for reduced epochs).
        assert!(
            test_err < max_test_err,
            "Test error {:.6} should be less than {} (relaxed for fewer epochs)",
            test_err,
            max_test_err
        );

        println!("✅ Housing dataset test with logging completed successfully!");
        println!("📁 Split events database saved to: {}", db_path);
        println!("🎯 Use visualization tools to analyze the splits:");
        // Derive the hint from `db_path` so it always names the file this
        // test actually configured.
        println!(
            "   cd visualize_splits && python dashboard.py ../{}",
            db_path
        );
    }

    /// Smoke test on the processed red-wine CSV: trains a boosted MPF model
    /// once per seed (1 and 2) and prints the baseline, training and test
    /// errors.
    ///
    /// No accuracy threshold is asserted; the test only fails if loading,
    /// fitting or prediction panics.
    #[test]
    fn test_mpf_red_wine() {
        let (x, y) = setup_data_csv("data/red_wine_processed.csv");
        println!("ncols: {}", x.ncols());

        let db_path = "target/red_wine_splits.sqlite";
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y);

        let mut params = MPFBoostedParamsBuilder::new()
            .epochs(10)
            .n_trees(203)
            .n_iter(105)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(3)
                    .colsample_bytree(0.652015893646319)
                    .min_interval_samples(43)
                    .min_split_loss(0.060839473717718556)
                    .build(),
            )
            .refinement_strategy(
                RefinementStrategyParamsBuilder::new()
                    .l2()
                    .alpha(2.196_807_874_090_684_2e-9)
                    .prior_sample_size(0.0)
                    .build(),
            )
            .similarity_threshold(0.0)
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(false)
            .log_level("info")
            // Placeholder: the seed is overwritten on every loop iteration below.
            .seed(42)
            .visualdb_path(Some(db_path.to_string()))
            .build();

        // The baseline depends only on `y_test`, so compute it once rather
        // than once per seed.
        let mean = y_test.mean().unwrap();
        let base_err = y_test.mapv(|v| (v - mean).powi(2)).mean().unwrap();

        // Typed range so the seed can be assigned without an `as` cast.
        for seed in 1..3_u64 {
            params.seed = seed;

            let (fit_result, model, _test_errors) = fit_boosted_with_test_error(
                x_train.view(),
                y_train.view(),
                x_test.view(),
                y_test.view(),
                &params,
            );
            let preds = model.predict(x_test.view());

            // Subtract by reference: `y_test` is needed again on the next
            // iteration and must not be moved or cloned.
            let test_err = (&y_test - &preds).powi(2).mean().unwrap();

            println!(
                "Red Wine | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
                base_err, fit_result.err, test_err
            );
        }
    }

    /// Smoke test on the fish-toxicity CSV: trains a boosted MPF model once
    /// per seed in `100..200` and prints the baseline, training and test
    /// errors for each run.
    ///
    /// No accuracy threshold is asserted; the test only fails if loading,
    /// fitting or prediction panics.
    #[test]
    fn test_mpf_fish_toxicity() {
        let (x, y) = setup_data_csv("data/fish_toxicity.csv");
        let (x_train, y_train, x_test, y_test) = split_data(&x, &y);

        // Any file already present at this path is left in place; this test
        // does not remove it before training.
        let db_path = "target/fish_toxicity_splits.sqlite";

        let mut params = MPFBoostedParamsBuilder::new()
            .epochs(9)
            .n_trees(121)
            .n_iter(15)
            .decay(0.9301908910375473)
            .similarity_threshold(0.7027072768106223)
            .split_strategy(
                SplitStrategyParamsBuilder::new()
                    .random_split()
                    .split_try(16)
                    .colsample_bytree(0.7030641276642211)
                    .min_interval_samples(5)
                    .min_split_loss(0.3530290549534714)
                    .build(),
            )
            .combination_strategy(CombinationStrategyParams::GeometricMeanTwoTensor)
            .bagged(true)
            // Placeholder: the seed is overwritten on every loop iteration below.
            .seed(42)
            .visualdb_path(Some(db_path.to_string()))
            .build();

        // The baseline depends only on `y_test`, so compute it once instead
        // of once per seed.
        let mean = y_test.mean().unwrap();
        let base_err = y_test.mapv(|v| (v - mean).powi(2)).mean().unwrap();

        // Typed range so the seed can be assigned without an `as` cast.
        for seed in 100..200_u64 {
            params.seed = seed;

            let (fit_result, model) = fit_boosted(x_train.view(), y_train.view(), &params);
            let preds = model.predict(x_test.view());

            // Subtract by reference: `y_test` is needed again on the next
            // iteration and must not be moved or cloned.
            let test_err = (&y_test - &preds).powi(2).mean().unwrap();

            println!(
                "Fish Toxicity | Base error: {:.4}, Training Error: {:.4}, Test Error: {:.4}",
                base_err, fit_result.err, test_err
            );
        }
    }
}
