use crate::{
    family::Aggregation,
    grid::params::{
        RefinementStrategyParams, SplitStrategyParams, TreeGridParams, TreeGridParamsBuilder,
    },
};
use serde::Serialize;

/// Combination strategy selected through `TreeGridFamilyParams::combination_strategy`.
///
/// The combination itself is implemented by the code that consumes this enum; the
/// variant descriptions below record what each option is intended to do.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum CombinationStrategyParams {
    /// Geometric mean on two-tensor a_± factors.
    GeometricMeanTwoTensor,
    /// Bagged aggregation using component-shape distance (uses `similarity_threshold`).
    BaggedTwoTensor,
    /// Tensor-power combination, configured by bin count and iteration count.
    TensorPower {
        /// Number of quantile bins per axis (default: 256).
        /// NOTE(review): this enum does not supply that default itself; callers must pass it.
        n_bins: usize,
        /// Number of ALS iterations (default: 20).
        /// NOTE(review): this enum does not supply that default itself; callers must pass it.
        n_iterations: usize,
    },
}

/// Scaling strategy option.
///
/// NOTE(review): the meaning of each variant is defined by the code that consumes
/// this enum — document the exact behavior there and mirror it here.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum ScalingStrategy {
    /// No scaling (presumably a no-op — confirm in the consuming code).
    Off,
    /// Orthogonal greedy scaling — TODO document the algorithm.
    OrthogonalGreedy,
}

/// Configuration for a family of tree grids.
///
/// Usually produced by `TreeGridFamilyParamsBuilder::build`; `Default` yields the
/// builder's default configuration.
#[derive(Debug, Clone, Serialize)]
pub struct TreeGridFamilyParams {
    /// Number of trees in the family (builder default: 100).
    pub n_trees: usize,
    /// Bagging flag (builder default: `false`).
    pub bagged: bool,
    /// Tree-grid parameters (presumably shared by every tree in the family — confirm).
    pub tg_params: TreeGridParams,
    /// How the trees are combined (builder default: `GeometricMeanTwoTensor`).
    pub combination_strategy: CombinationStrategyParams,
    /// Threshold used by `CombinationStrategyParams::BaggedTwoTensor` to decide how many
    /// bags to keep (trim_percentage = 1.0 - similarity_threshold); builder default: 0.0.
    pub similarity_threshold: f64,
    /// Aggregation method (builder default: `Aggregation::Combined`).
    pub aggregation_method: Aggregation,
}

/// Builder for [`TreeGridFamilyParams`].
///
/// Start from `TreeGridFamilyParamsBuilder::new()` (or `Default`), chain the setter
/// methods, then call `build()` to obtain the finished params.
#[derive(Debug)]
pub struct TreeGridFamilyParamsBuilder {
    n_trees: usize,
    bagged: bool,
    // Nested builder for the inner `TreeGridParams`; it is only finalized in `build()`.
    tg_params_builder: TreeGridParamsBuilder,
    combination_strategy: CombinationStrategyParams,
    similarity_threshold: f64,
    aggregation_method: Aggregation,
}

impl TreeGridFamilyParamsBuilder {
    /// Creates a builder holding the default configuration:
    ///
    /// - `n_trees`: 100
    /// - `bagged`: `false`
    /// - tree-grid params: a fresh `TreeGridParamsBuilder::new()`
    /// - `combination_strategy`: `CombinationStrategyParams::GeometricMeanTwoTensor`
    /// - `similarity_threshold`: 0.0
    /// - `aggregation_method`: `Aggregation::Combined`
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_trees: 100,
            bagged: false,
            tg_params_builder: TreeGridParamsBuilder::new(),
            combination_strategy: CombinationStrategyParams::GeometricMeanTwoTensor,
            similarity_threshold: 0.0,
            aggregation_method: Aggregation::Combined,
        }
    }

    /// Sets the aggregation method copied into the built params.
    #[must_use]
    pub fn aggregation_method(mut self, aggregation_method: Aggregation) -> Self {
        self.aggregation_method = aggregation_method;
        self
    }

    /// Sets the number of trees in the family.
    #[must_use]
    pub fn n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }

    /// Sets the bagging flag.
    #[must_use]
    pub fn bagged(mut self, bagged: bool) -> Self {
        self.bagged = bagged;
        self
    }

    // Convenience methods for TreeGridParams configuration: each one forwards to the
    // nested `TreeGridParamsBuilder` so callers do not have to build it separately.

    /// Forwards `n_iter` to the nested `TreeGridParamsBuilder`.
    #[must_use]
    pub fn n_iter(mut self, n_iter: usize) -> Self {
        self.tg_params_builder = self.tg_params_builder.n_iter(n_iter);
        self
    }

    /// Forwards the refinement strategy to the nested `TreeGridParamsBuilder`.
    #[must_use]
    pub fn refinement_strategy(mut self, strategy: RefinementStrategyParams) -> Self {
        self.tg_params_builder = self.tg_params_builder.refinement_strategy(strategy);
        self
    }

    /// Forwards the split strategy to the nested `TreeGridParamsBuilder`.
    #[must_use]
    pub fn split_strategy(mut self, strategy: SplitStrategyParams) -> Self {
        self.tg_params_builder = self.tg_params_builder.split_strategy(strategy);
        self
    }

    /// Sets the combination strategy explicitly.
    ///
    /// `bagged_two_tensor` and `tensor_power` are shortcuts for common variants.
    #[must_use]
    pub fn combination_strategy(mut self, combination_strategy: CombinationStrategyParams) -> Self {
        self.combination_strategy = combination_strategy;
        self
    }

    /// Sets the similarity threshold (default 0.0).
    ///
    /// With `BaggedTwoTensor`, it determines the fraction of bags kept
    /// (trim_percentage = 1.0 - similarity_threshold). The value is not range-checked here.
    #[must_use]
    pub fn similarity_threshold(mut self, similarity_threshold: f64) -> Self {
        self.similarity_threshold = similarity_threshold;
        self
    }

    /// Sets the combination strategy to BaggedTwoTensor.
    /// Uses similarity_threshold to determine fraction of bags to keep (trim_percentage = 1.0 - similarity_threshold).
    #[must_use]
    pub fn bagged_two_tensor(mut self) -> Self {
        self.combination_strategy = CombinationStrategyParams::BaggedTwoTensor;
        self
    }

    /// Sets the combination strategy to TensorPower with specified parameters.
    ///
    /// `n_bins` is the number of quantile bins per axis; `n_iterations` is the
    /// number of ALS iterations.
    #[must_use]
    pub fn tensor_power(mut self, n_bins: usize, n_iterations: usize) -> Self {
        self.combination_strategy = CombinationStrategyParams::TensorPower {
            n_bins,
            n_iterations,
        };
        self
    }

    /// Consumes the builder and returns the finished [`TreeGridFamilyParams`].
    ///
    /// The nested tree-grid params are finalized here via `TreeGridParamsBuilder::build`.
    #[must_use]
    pub fn build(self) -> TreeGridFamilyParams {
        TreeGridFamilyParams {
            n_trees: self.n_trees,
            bagged: self.bagged,
            tg_params: self.tg_params_builder.build(),
            combination_strategy: self.combination_strategy,
            similarity_threshold: self.similarity_threshold,
            aggregation_method: self.aggregation_method,
        }
    }
}

impl Default for TreeGridFamilyParamsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for TreeGridFamilyParams {
    /// Params produced by building an unmodified default builder.
    fn default() -> Self {
        let fresh_builder = TreeGridFamilyParamsBuilder::new();
        fresh_builder.build()
    }
}
